Function dfdx::tensor_ops::log_softmax

pub fn log_softmax<Ax: Axes, S, E: Dtype, D: Device<E>, T: Tape<E, D>>(
    t: Tensor<S, E, D, T>
) -> Tensor<S, E, D, T>
where
    S: ReduceShape<Ax> + Shape,

Computes log(softmax(t)) in a numerically stable way across the axes Ax. Under the hood this is computed as t - logsumexp(t).

PyTorch equivalent: t.log_softmax(Ax)
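The t - logsumexp(t) formulation is stable because logsumexp subtracts the maximum before exponentiating, so exp never overflows. A minimal sketch of this computation over a single axis, written in plain std Rust on f32 slices (no dfdx types; the function name log_softmax_1d here is illustrative, not part of the dfdx API):

```rust
// Stable log_softmax over a 1-D slice:
//   log_softmax(x) = x - logsumexp(x)
//   logsumexp(x)   = m + ln(sum_i exp(x_i - m)),  m = max_i x_i
fn log_softmax_1d(x: &[f32]) -> Vec<f32> {
    // Subtracting the max keeps every exp argument <= 0, avoiding overflow.
    let m = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let lse = m + x.iter().map(|&v| (v - m).exp()).sum::<f32>().ln();
    x.iter().map(|&v| v - lse).collect()
}

fn main() {
    // For an all-zero input of length 5, every entry equals ln(1/5).
    let out = log_softmax_1d(&[0.0; 5]);
    println!("{:?}", out);

    // Exponentiating the result recovers a probability distribution.
    let total: f32 = out.iter().map(|v| v.exp()).sum();
    println!("sum of probabilities = {total}");
}
```

Exponentiating the output gives softmax(t), so the entries of exp(log_softmax(t)) always sum to 1 along the reduced axes.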

Example:

use dfdx::prelude::*;
let dev: Cpu = Default::default();
let t: Tensor<Rank3<2, 3, 5>, f32, _> = dev.zeros();
let _ = t.log_softmax::<Axis<2>>();

Using multi axis log_softmax:

let _ = t.log_softmax::<Axes2<0, 2>>();