Function dfdx::tensor_ops::softmax
pub fn softmax<Ax: Axes, S, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    t: Tensor<S, E, D, T>
) -> Tensor<S, E, D, T>
where
    S: ReduceShape<Ax> + Shape,
Computes the softmax function across Ax.
Equivalent to exp(log_softmax(t)).
PyTorch equivalent: t.softmax(Axes)
Example:
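use dfdx::prelude::*;
let dev: Cpu = Default::default(); // CPU device used to allocate the tensor
let t: Tensor<Rank3<2, 3, 5>, f32, _> = dev.zeros();
let _ = t.softmax::<Axis<2>>(); // softmax over the last axis (size 5)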