Function dfdx::tensor_ops::softmax

```rust
pub fn softmax<Ax: Axes, S, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    t: Tensor<S, E, D, T>
) -> Tensor<S, E, D, T>
where
    S: ReduceShape<Ax> + Shape,
```

Computes the softmax function across the axis or axes `Ax`. After normalization, the values along `Ax` sum to 1.
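Written out, let $i$ and $j$ run over positions along `Ax`, with every other coordinate held fixed. The two functions are then:

```latex
\operatorname{softmax}(t)_i = \frac{\exp(t_i)}{\sum_j \exp(t_j)},
\qquad
\operatorname{log\_softmax}(t)_i = t_i - \log \sum_j \exp(t_j)
```

Exponentiating the right-hand expression recovers the left one.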

Equivalent to exp(log_softmax(t)).
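The following is a plain-Rust illustration of this identity. It is independent of dfdx and works on a single 1-D slice rather than a tensor axis. It defines two hypothetical helpers, `log_softmax` and `softmax`. The first computes log-softmax after subtracting the maximum; the second exponentiates the result. It is a sketch, not dfdx's actual kernel.

```rust
/// Hypothetical helper (not part of dfdx): numerically stable log-softmax
/// over a 1-D slice, using log_softmax(x)_i = x_i - logsumexp(x).
/// The maximum is subtracted first so `exp` never overflows.
fn log_softmax(x: &[f32]) -> Vec<f32> {
    let max = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let sum_exp: f32 = x.iter().map(|&v| (v - max).exp()).sum();
    let log_sum_exp = max + sum_exp.ln();
    x.iter().map(|&v| v - log_sum_exp).collect()
}

/// Hypothetical helper: softmax computed as exp(log_softmax(x)).
fn softmax(x: &[f32]) -> Vec<f32> {
    log_softmax(x).into_iter().map(f32::exp).collect()
}

fn main() {
    let y = softmax(&[1.0_f32, 2.0, 3.0]);
    let total: f32 = y.iter().sum();
    println!("{:?} sums to {}", y, total);

    // Large logits stay finite because of the max shift.
    let big = softmax(&[1000.0_f32, 1000.0]);
    println!("{:?}", big);
}
```

Routing through log-softmax with the max shift keeps large logits such as `1000.0` finite. A naive `exp(x_i) / sum(exp(x_j))` in `f32` would overflow to infinity and produce NaN.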

PyTorch equivalent: `t.softmax(dim)`, where `dim` corresponds to `Ax`.

Example:

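```rust
use dfdx::prelude::*;
let dev: Cpu = Default::default();
let t: Tensor<Rank3<2, 3, 5>, f32, _> = dev.zeros();
let _ = t.softmax::<Axis<2>>(); // normalize over the last axis (size 5)
```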
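To see what this example produces, here is a plain-Rust sketch that does not use dfdx. The helper `softmax_last_axis` is hypothetical. It applies softmax independently to each length-5 row of a 2x3x5 array, which is the same normalization that `softmax::<Axis<2>>()` performs on a `Rank3<2, 3, 5>` tensor. With an all-zeros input, every row becomes uniform, so each entry is 1/5.

```rust
/// Hypothetical helper (not part of dfdx): softmax along the last axis of a
/// 2x3x5 array, analogous to `softmax::<Axis<2>>()` on a `Rank3<2, 3, 5>` tensor.
fn softmax_last_axis(t: &[[[f32; 5]; 3]; 2]) -> [[[f32; 5]; 3]; 2] {
    let mut out = [[[0.0_f32; 5]; 3]; 2];
    for (i, plane) in t.iter().enumerate() {
        for (j, row) in plane.iter().enumerate() {
            // Shift by the row maximum for numerical stability.
            let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let sum: f32 = row.iter().map(|&v| (v - max).exp()).sum();
            for (k, &v) in row.iter().enumerate() {
                out[i][j][k] = (v - max).exp() / sum;
            }
        }
    }
    out
}

fn main() {
    // Same input as the dfdx example: all zeros.
    let t = [[[0.0_f32; 5]; 3]; 2];
    let y = softmax_last_axis(&t);
    // Five equal logits give a uniform distribution.
    println!("{:?}", y[0][0]); // prints [0.2, 0.2, 0.2, 0.2, 0.2]
}
```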