Function dfdx::tensor_ops::log_softmax
```rust
pub fn log_softmax<Ax: Axes, S, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    t: Tensor<S, E, D, T>
) -> Tensor<S, E, D, T>
where
    S: ReduceShape<Ax> + Shape,
```
Computes `log(softmax(t))` in a numerically stable way across the axes `Ax`, by evaluating `t - logsumexp(t)` under the hood.
PyTorch equivalent: `t.log_softmax(Ax)`
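A minimal plain-Rust sketch of the `t - logsumexp(t)` idea for a single 1-D slice of `f32` values. `log_softmax_1d` is a hypothetical helper written for illustration; it is not part of dfdx and is not necessarily how dfdx implements the op. It subtracts the maximum `m` before exponentiating, using `log_softmax(x)_i = x_i - (m + ln(sum_j exp(x_j - m)))`, so every `exp` argument is at most 0 and cannot overflow.

```rust
/// Hypothetical helper (not part of dfdx): numerically stable
/// log-softmax over a 1-D slice, computed as `x - logsumexp(x)`.
fn log_softmax_1d(x: &[f32]) -> Vec<f32> {
    // Shift by the max so the largest exponent is exp(0) = 1; a naive
    // exp(1000.0) would overflow to infinity in f32.
    let max = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let sum_exp: f32 = x.iter().map(|&v| (v - max).exp()).sum();
    // logsumexp(x) = max + ln(sum_j exp(x_j - max))
    let logsumexp = max + sum_exp.ln();
    x.iter().map(|&v| v - logsumexp).collect()
}
```

On five zeros every output is `-ln 5` (about -1.609), which is the value each entry along the size-5 axis takes in the `Axis<2>` example that follows.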
Example:
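```rust
use dfdx::prelude::*;
let dev: Cpu = Default::default();
let t: Tensor<Rank3<2, 3, 5>, f32, _> = dev.zeros();
let _ = t.clone().log_softmax::<Axis<2>>(); // normalize over the last axis (size 5); clone keeps `t` usable
```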
Using multi-axis log_softmax:
```rust
let _ = t.log_softmax::<Axes2<0, 2>>(); // normalize jointly over axes 0 and 2
```
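To make the multi-axis semantics concrete, here is a hypothetical plain-Rust helper, `log_softmax_axes_0_2` (not part of dfdx), hard-coded to the `[2, 3, 5]` shape used above purely for readability. Reducing over `Axes2<0, 2>` means the 2 * 5 = 10 elements that share an index on axis 1 are normalized together, so each of the 3 slices sums to 1 after exponentiating. It applies the same max-shifted `logsumexp` as the 1-D case.

```rust
/// Hypothetical helper (not part of dfdx): log-softmax of a 2x3x5 array
/// over axes 0 and 2 jointly, i.e. one normalization per index of axis 1.
fn log_softmax_axes_0_2(t: &[[[f32; 5]; 3]; 2]) -> [[[f32; 5]; 3]; 2] {
    let mut out = [[[0.0f32; 5]; 3]; 2];
    for j in 0..3 {
        // Max over the 10 elements that share index j on axis 1.
        let mut max = f32::NEG_INFINITY;
        for i in 0..2 {
            for k in 0..5 {
                max = max.max(t[i][j][k]);
            }
        }
        // Stable logsumexp over the same 10 elements.
        let mut sum_exp = 0.0f32;
        for i in 0..2 {
            for k in 0..5 {
                sum_exp += (t[i][j][k] - max).exp();
            }
        }
        let logsumexp = max + sum_exp.ln();
        // Subtract it from every element in this slice.
        for i in 0..2 {
            for k in 0..5 {
                out[i][j][k] = t[i][j][k] - logsumexp;
            }
        }
    }
    out
}
```

For an all-zeros input every output is `-ln 10`, since each normalization group has 10 elements rather than the 5 of the single-axis case.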