Function dfdx::tensor_ops::map
```rust
pub fn map<T: Tensor<Dtype = f32>, F, DF>(t: T, f: F, df: DF) -> T
where
    F: 'static + FnMut(&f32) -> f32 + Copy,
    DF: 'static + FnMut(&f32) -> f32 + Copy,
```
Applies a function `f` to every element of the Tensor. Because `f` is an opaque closure, its derivative `df` must also be provided; it is used when computing gradients during backpropagation.
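Conceptually, the contract has two halves: the forward pass computes `f(x)` for each element, and the backward pass multiplies each incoming gradient by `df(x)` (the chain rule). The std-only sketch below models that contract on a plain slice. `map_forward` and `map_backward` are hypothetical helpers for illustration, not part of dfdx, and how dfdx actually wires the backward step is not shown here.

```rust
/// Hypothetical forward pass: apply `f` to every element.
fn map_forward<F: FnMut(&f32) -> f32>(xs: &[f32], f: F) -> Vec<f32> {
    xs.iter().map(f).collect()
}

/// Hypothetical backward pass (chain rule):
/// grad_in[i] = grad_out[i] * df(xs[i]).
fn map_backward<DF: FnMut(&f32) -> f32>(xs: &[f32], grad_out: &[f32], mut df: DF) -> Vec<f32> {
    xs.iter()
        .zip(grad_out)
        .map(|(x, g)| g * df(x))
        .collect()
}

fn main() {
    let xs = [-2.0_f32, -1.0, 0.0, 1.0, 2.0];

    // Same pair as the documented example: f(x) = 2x, df(x) = 2.
    let ys = map_forward(&xs, |x| 2.0 * *x);
    assert_eq!(ys, vec![-4.0, -2.0, 0.0, 2.0, 4.0]);

    // With an upstream gradient of 1.0 per element (as from summing the
    // outputs), each input gradient is simply df(x) = 2.0.
    let grads = map_backward(&xs, &[1.0; 5], |_| 2.0);
    assert_eq!(grads, vec![2.0; 5]);

    println!("{:?} {:?}", ys, grads);
}
```

Keeping `f` and `df` as separate closures mirrors the signature above: nothing is differentiated automatically, so the caller is responsible for `df` actually being the derivative of `f`.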
This is primarily used to implement standard functions such as `relu()` and `exp()`, but users can also use it to implement their own activation functions.
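As an illustration of a user-defined activation, here is a leaky ReLU written as the two closures `map` expects: one for the value and one for its derivative. The `SLOPE` of 0.01 is an arbitrary, hypothetical choice, and the std-only `apply` helper is a stand-in with the same closure bounds as `map` (`'static + FnMut(&f32) -> f32 + Copy`) so the snippet runs without dfdx. With dfdx itself, the same two closures would be passed as the `f` and `df` arguments.

```rust
const SLOPE: f32 = 0.01; // hypothetical leak factor

/// Hypothetical stand-in for `map` on a plain slice, using the same
/// closure bounds. Returns (values, derivatives) per element.
fn apply<F, DF>(xs: &[f32], f: F, df: DF) -> (Vec<f32>, Vec<f32>)
where
    F: 'static + FnMut(&f32) -> f32 + Copy,
    DF: 'static + FnMut(&f32) -> f32 + Copy,
{
    let ys = xs.iter().map(f).collect();
    let ds = xs.iter().map(df).collect();
    (ys, ds)
}

fn main() {
    // Leaky ReLU: x for x > 0, SLOPE * x otherwise.
    let leaky_relu = |x: &f32| if *x > 0.0 { *x } else { SLOPE * *x };
    // Its derivative: 1 for x > 0, SLOPE otherwise (x = 0 gets the left slope).
    let leaky_relu_grad = |x: &f32| if *x > 0.0 { 1.0 } else { SLOPE };

    // Non-capturing closures are both 'static and Copy, so they satisfy the bounds.
    let (ys, ds) = apply(&[-2.0, 0.0, 3.0], leaky_relu, leaky_relu_grad);
    assert_eq!(ys, vec![-2.0 * SLOPE, 0.0, 3.0]);
    assert_eq!(ds, vec![SLOPE, SLOPE, 1.0]);

    println!("{:?} {:?}", ys, ds);
}
```

The `Copy` and `'static` bounds rule out closures that capture non-`Copy` or borrowed state, which is why the leak factor is a `const` rather than a captured local.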
Examples:
```rust
use dfdx::prelude::*;

// f(x) = 2x, so its derivative is the constant df(x) = 2.
let t = Tensor1D::new([-2.0, -1.0, 0.0, 1.0, 2.0]);
let r = map(t, |x| 2.0 * x, |_| 2.0);
assert_eq!(r.data(), &[-4.0, -2.0, 0.0, 2.0, 4.0]);
```
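Since `map` uses whatever `df` it is given, a `df` that does not match `f` would produce wrong gradients without any error. One way to catch this before relying on a custom pair is a central finite-difference check. `check_derivative` below is a hypothetical std-only helper, not part of dfdx; the step size `h` and the tolerance are arbitrary choices suited to `f32` precision.

```rust
/// Hypothetical helper: compare `df` against a central finite difference
/// of `f` at each sample point. Returns the first point that disagrees.
fn check_derivative<F, DF>(f: F, df: DF, points: &[f32], tol: f32) -> Result<(), f32>
where
    F: Fn(&f32) -> f32,
    DF: Fn(&f32) -> f32,
{
    // Step size: f32 rounding makes very small steps unreliable.
    let h = 1e-2_f32;
    for x in points {
        // (f(x + h) - f(x - h)) / 2h approximates f'(x).
        let numeric = (f(&(x + h)) - f(&(x - h))) / (2.0 * h);
        if (numeric - df(x)).abs() > tol {
            return Err(*x);
        }
    }
    Ok(())
}

fn main() {
    let points = [-2.0_f32, -1.0, 0.5, 1.0, 2.0];

    // The documented pair f(x) = 2x, df(x) = 2 passes.
    assert!(check_derivative(|x| 2.0 * *x, |_| 2.0, &points, 1e-2).is_ok());

    // A wrong derivative (df = 1) is reported at the first sample point.
    assert_eq!(check_derivative(|x| 2.0 * *x, |_| 1.0, &points, 1e-2), Err(-2.0));

    println!("derivative checks passed");
}
```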