use sapient_core::error::{Result, SapientError};
use sapient_core::Tensor;
/// Resolves a possibly-negative axis into a zero-based dimension index.
///
/// Negative values count back from the last dimension, so `-1` maps to
/// `ndim - 1` and `-ndim` maps to `0`. Non-negative values are returned
/// unchanged when they fit in `usize`.
///
/// This function never fails. An axis that cannot be represented (a negative
/// axis that reaches past the first dimension, or a positive axis wider than
/// `usize`) is reported as `usize::MAX`, which no valid rank can exceed, so a
/// caller's `index >= ndim` check always rejects it. A representable but
/// too-large positive axis is passed through as-is for the caller to reject.
fn normalise_axis(axis: i64, ndim: usize) -> usize {
    if axis >= 0 {
        // Compare in `u64` so a value wider than the platform's `usize`
        // cannot silently truncate onto a valid axis on 32-bit targets.
        let forward = axis as u64;
        if forward <= usize::MAX as u64 {
            forward as usize
        } else {
            usize::MAX
        }
    } else {
        // `unsigned_abs` is well defined even for `i64::MIN`.
        let back = axis.unsigned_abs();
        if back <= ndim as u64 {
            // `back <= ndim` guarantees it fits in `usize` and that the
            // subtraction cannot underflow.
            ndim - back as usize
        } else {
            usize::MAX
        }
    }
}
/// Computes the softmax of `x` along `axis`.
///
/// A negative `axis` counts back from the last dimension. The result is built
/// from `f32` values and carries the same shape as `x`.
///
/// # Errors
///
/// Returns an error when `axis` does not name a dimension of `x`, or when the
/// output tensor cannot be constructed.
pub fn softmax(x: &Tensor, axis: i64) -> Result<Tensor> {
    let log_mode = false;
    apply_softmax_impl(x, axis, log_mode)
}
/// Computes the logarithm of the softmax of `x` along `axis`.
///
/// A negative `axis` counts back from the last dimension. The result is built
/// from `f32` values and carries the same shape as `x`.
///
/// # Errors
///
/// Returns an error when `axis` does not name a dimension of `x`, or when the
/// output tensor cannot be constructed.
pub fn log_softmax(x: &Tensor, axis: i64) -> Result<Tensor> {
    let log_mode = true;
    apply_softmax_impl(x, axis, log_mode)
}
/// Shared kernel behind the softmax and log-softmax entry points.
///
/// The tensor is viewed as `(outer, dim_size, inner)`, where `dim_size` is the
/// extent of the reduced axis, `outer` is the product of the dimensions before
/// it and `inner` the product of those after it. Each `(outer, inner)` pair
/// selects one strided lane that is normalised independently.
/// NOTE(review): the indexing assumes `to_f32_cow` yields elements in
/// row-major order — confirm against `Tensor`.
///
/// The lane maximum is subtracted before exponentiating so large inputs do not
/// overflow. When `log_mode` is `true` the output is
/// `(v - max) - ln(sum(exp(v - max)))`; otherwise it is
/// `exp(v - max) / sum(exp(v - max))`.
///
/// # Errors
///
/// Returns an internal error when `axis` does not resolve to a dimension of
/// `x`, and propagates any error from building the output tensor.
fn apply_softmax_impl(x: &Tensor, axis: i64, log_mode: bool) -> Result<Tensor> {
    let shape = x.shape();
    let ndim = shape.ndim();
    let ax = normalise_axis(axis, ndim);
    if ax >= ndim {
        return Err(SapientError::internal(format!(
            "softmax axis {axis} out of range for rank {ndim}"
        )));
    }

    let data_cow = x.to_f32_cow();
    let data = data_cow.as_ref();
    let mut out = vec![0.0f32; data.len()];

    let outer: usize = shape.dims()[..ax].iter().product();
    let dim_size = shape.dims()[ax];
    let inner: usize = shape.dims()[ax + 1..].iter().product();

    // Scratch buffers are allocated once and overwritten for every lane,
    // instead of building two fresh vectors per (outer, inner) pair.
    let mut lane = vec![0.0f32; dim_size];
    let mut exps = vec![0.0f32; dim_size];

    for o in 0..outer {
        for i in 0..inner {
            // Flat offset of element 0 of this lane; successive elements
            // along the reduced axis are `inner` apart.
            let base = o * dim_size * inner + i;

            for (d, slot) in lane.iter_mut().enumerate() {
                *slot = data[base + d * inner];
            }

            let mut max_v = lane.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            if max_v == f32::NEG_INFINITY {
                // Every element is -inf (or the lane is empty): shifting by
                // -inf would produce NaN, so leave the values unshifted.
                max_v = 0.0;
            }

            for (e, &v) in exps.iter_mut().zip(lane.iter()) {
                *e = (v - max_v).exp();
            }

            let mut sum_e: f32 = exps.iter().sum();
            if sum_e == 0.0 {
                // Keep the division / logarithm below well defined.
                sum_e = f32::EPSILON;
            }

            if log_mode {
                // The log of the normaliser is the same for the whole lane.
                let log_sum = sum_e.ln();
                for (d, &v) in lane.iter().enumerate() {
                    out[base + d * inner] = (v - max_v) - log_sum;
                }
            } else {
                for (d, &e) in exps.iter().enumerate() {
                    out[base + d * inner] = e / sum_e;
                }
            }
        }
    }

    Tensor::from_f32(&out, shape.clone())
}
#[cfg(test)]
mod tests {
    use super::*;
    use sapient_core::Tensor;

    /// Softmax over a single row must produce values that add up to one.
    #[test]
    fn softmax_sums_to_one() {
        let input = Tensor::from_f32(&[1.0, 2.0, 3.0, 4.0], vec![1, 4]).unwrap();
        let probs = softmax(&input, 1).unwrap();
        let sum = probs.as_f32_slice().iter().sum::<f32>();
        assert!((sum - 1.0).abs() < 1e-6, "sum = {sum}");
    }

    /// Inputs large enough to overflow a naive `exp` must still give finite
    /// values that add up to one.
    #[test]
    fn softmax_stable_large() {
        let input = Tensor::from_f32(&[1000.0, 1001.0, 1002.0], vec![1, 3]).unwrap();
        let probs = softmax(&input, 1).unwrap();
        let values = probs.as_f32_slice();
        values
            .iter()
            .for_each(|&v| assert!(v.is_finite(), "non-finite: {v}"));
        let sum = values.iter().sum::<f32>();
        assert!((sum - 1.0).abs() < 1e-5, "sum = {sum}");
    }

    /// Log-softmax of ordinary finite inputs must itself be finite.
    #[test]
    fn log_softmax_finite() {
        let input = Tensor::from_f32(&[1.0, 2.0, 3.0], vec![1, 3]).unwrap();
        let log_probs = log_softmax(&input, 1).unwrap();
        log_probs
            .as_f32_slice()
            .iter()
            .for_each(|&v| assert!(v.is_finite()));
    }
}