pub fn softmax(logits: &ArrayD<f32>) -> Array2<f32>
Computes the softmax of an array of shape `batch_size * num_classes`.
batch_size * num_classes