use arrow::array::{Array, ArrowPrimitiveType, PrimitiveArray};
use arrow::buffer::Buffer;
use arrow::tensor::Tensor;
use arrow_ml_common::KernelError;
use arrow_ml_common::Result;
use num_traits::{Float, Zero};
use std::ops::AddAssign;
/// Computes the softmax of a one-dimensional primitive array.
///
/// The largest element is subtracted from every value before exponentiation, so each
/// exponent is at most zero; every exponential is then divided by the sum of all
/// exponentials.
///
/// # Errors
///
/// Returns [`KernelError::NullsNotSupported`] if the array contains any null, and
/// [`KernelError::EmptyArray`] if it has no elements. The null check runs first.
pub fn softmax<T>(array: &PrimitiveArray<T>) -> Result<PrimitiveArray<T>>
where
    T: ArrowPrimitiveType,
    T::Native: Float + AddAssign,
{
    if array.null_count() > 0 {
        return Err(KernelError::NullsNotSupported {
            operation: "softmax",
        });
    }
    if array.is_empty() {
        return Err(KernelError::EmptyArray {
            operation: "softmax",
        });
    }

    let input = array.values();

    // Pass 1: find the largest element to use as the shift.
    let mut peak = T::Native::neg_infinity();
    for &v in input.iter() {
        peak = peak.max(v);
    }

    // Pass 2: shifted exponentials, accumulating their sum in element order.
    let mut total = T::Native::zero();
    let mut probs: Vec<T::Native> = Vec::with_capacity(input.len());
    for &v in input.iter() {
        let e = (v - peak).exp();
        total += e;
        probs.push(e);
    }

    // Pass 3: normalize in place so the outputs sum to one.
    for p in probs.iter_mut() {
        *p = *p / total;
    }

    Ok(PrimitiveArray::from_iter_values(probs))
}
/// Applies softmax along one axis of a row-major tensor.
///
/// Every one-dimensional lane along `axis` is normalized independently: the lane
/// maximum is subtracted before exponentiation, and each exponential is divided by
/// the lane's sum. The result is a newly allocated row-major tensor with the same
/// shape as `input`.
///
/// `axis` may be negative, in which case it counts back from the last dimension
/// (`-1` selects the innermost axis).
///
/// # Errors
///
/// Returns [`KernelError::InvalidArgument`] when the tensor has no shape, has zero
/// dimensions, is not laid out in row-major order, or when `axis` lies outside
/// `[-ndim, ndim)`. Errors from the layout query or from building the output tensor
/// are converted with `KernelError::from`.
pub fn softmax_tensor<T>(input: &Tensor<'_, T>, axis: i64) -> Result<Tensor<'static, T>>
where
    T: ArrowPrimitiveType,
    T::Native: Float + AddAssign,
{
    let shape = input.shape().ok_or_else(|| {
        KernelError::InvalidArgument("softmax_tensor: tensor has no shape".into())
    })?;
    let ndim = shape.len();
    if ndim == 0 {
        return Err(KernelError::InvalidArgument(
            "softmax_tensor: tensor must be at least 1D".into(),
        ));
    }

    // Resolve negative axes, but report the axis exactly as the caller passed it.
    let resolved = if axis < 0 { ndim as i64 + axis } else { axis };
    if resolved < 0 || resolved >= ndim as i64 {
        return Err(KernelError::InvalidArgument(format!(
            "softmax_tensor: axis {} out of range for {}D tensor",
            axis, ndim
        )));
    }
    let axis = resolved as usize;

    // The flat index arithmetic below is only valid for row-major data, and the
    // output is always built as row-major. Reject other layouts rather than
    // silently computing softmax over the wrong elements.
    if !input.is_row_major().map_err(KernelError::from)? {
        return Err(KernelError::InvalidArgument(
            "softmax_tensor: only row-major tensors are supported".into(),
        ));
    }

    // An empty product is already 1, so leading/trailing axes need no special case;
    // a zero-length dimension simply makes the loops below run zero times.
    let outer_size: usize = shape[..axis].iter().product();
    let dim_size = shape[axis];
    let inner_size: usize = shape[axis + 1..].iter().product();
    // Number of elements covered by one step along the outer dimensions.
    let slab = dim_size * inner_size;

    let data: &[T::Native] = input.data().typed_data();
    let mut out = data.to_vec();

    for o in 0..outer_size {
        for i in 0..inner_size {
            // Flat index of the lane's first element; successive elements along
            // `axis` are `inner_size` apart.
            let base = o * slab + i;

            let mut max_val = T::Native::neg_infinity();
            for d in 0..dim_size {
                let v = data[base + d * inner_size];
                if v > max_val {
                    max_val = v;
                }
            }

            let mut sum = T::Native::zero();
            for d in 0..dim_size {
                let idx = base + d * inner_size;
                let e = (data[idx] - max_val).exp();
                out[idx] = e;
                sum += e;
            }

            for d in 0..dim_size {
                let idx = base + d * inner_size;
                out[idx] = out[idx] / sum;
            }
        }
    }

    let buf = Buffer::from_vec(out);
    Tensor::new_row_major(buf, Some(shape.to_vec()), None).map_err(KernelError::from)
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Float32Array;
    use arrow::buffer::ScalarBuffer;
    use arrow::datatypes::Float32Type;

    /// Asserts that `actual` differs from `expected` by strictly less than `tol`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!((actual - expected).abs() < tol);
    }

    /// Sums a slice of `f32` values in element order.
    fn total(values: &[f32]) -> f32 {
        values.iter().sum()
    }

    /// Builds a row-major `f32` tensor from flat data and a shape.
    fn make_f32(data: Vec<f32>, shape: Vec<usize>) -> Tensor<'static, Float32Type> {
        let scalars = ScalarBuffer::<f32>::from(data);
        let buffer = Buffer::from(scalars.into_inner());
        Tensor::new_row_major(buffer, Some(shape), None).unwrap()
    }

    // ---- 1-D array softmax ----

    #[test]
    fn test_softmax_uniform() {
        let probs = softmax(&Float32Array::from(vec![1.0_f32, 1.0, 1.0, 1.0])).unwrap();
        // Equal inputs must yield equal probabilities.
        for idx in 0..4 {
            assert_close(probs.value(idx), 0.25, 1e-6);
        }
    }

    #[test]
    fn test_softmax_sums_to_one() {
        let probs = softmax(&Float32Array::from(vec![1.0_f32, 2.0, 3.0, 4.0])).unwrap();
        assert_close(total(probs.values()), 1.0, 1e-6);
    }

    #[test]
    fn test_softmax_ordering() {
        let probs = softmax(&Float32Array::from(vec![1.0_f32, 3.0, 2.0])).unwrap();
        // Larger inputs map to larger probabilities.
        assert!(probs.value(1) > probs.value(2));
        assert!(probs.value(2) > probs.value(0));
    }

    #[test]
    fn test_softmax_numerical_stability() {
        // Inputs large enough that an unshifted exp() would overflow f32.
        let probs = softmax(&Float32Array::from(vec![1000.0_f32, 1001.0, 1002.0])).unwrap();
        assert_close(total(probs.values()), 1.0, 1e-5);
        assert!(probs.value(2) > probs.value(1));
        assert!(probs.value(1) > probs.value(0));
    }

    #[test]
    fn test_softmax_rejects_nulls() {
        let with_null = Float32Array::from(vec![Some(1.0_f32), None, Some(3.0)]);
        assert!(softmax(&with_null).is_err());
    }

    #[test]
    fn test_softmax_rejects_empty() {
        let empty = Float32Array::from(Vec::<f32>::new());
        assert!(softmax(&empty).is_err());
    }

    // ---- N-D tensor softmax ----

    #[test]
    fn test_softmax_tensor_2d_axis1() {
        let tensor = make_f32(vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0], vec![2, 3]);
        let result = softmax_tensor::<Float32Type>(&tensor, 1).unwrap();
        assert_eq!(result.shape().unwrap(), &vec![2, 3]);
        let flat = result.data().typed_data::<f32>();
        // Each row is normalized independently.
        let first_row = total(&flat[0..3]);
        let second_row = total(&flat[3..6]);
        assert_close(first_row, 1.0, 1e-6);
        assert_close(second_row, 1.0, 1e-6);
    }

    #[test]
    fn test_softmax_tensor_2d_axis0() {
        let tensor = make_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let result = softmax_tensor::<Float32Type>(&tensor, 0).unwrap();
        let flat = result.data().typed_data::<f32>();
        // Each column (elements j and 3 + j) is normalized independently.
        for col in 0..3 {
            assert_close(flat[col] + flat[3 + col], 1.0, 1e-6);
        }
    }

    #[test]
    fn test_softmax_tensor_3d_attention() {
        let tensor = make_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![1, 2, 3]);
        let result = softmax_tensor::<Float32Type>(&tensor, -1).unwrap();
        assert_eq!(result.shape().unwrap(), &vec![1, 2, 3]);
        let flat = result.data().typed_data::<f32>();
        let first_lane = total(&flat[0..3]);
        let second_lane = total(&flat[3..6]);
        assert_close(first_lane, 1.0, 1e-6);
        assert_close(second_lane, 1.0, 1e-6);
    }

    #[test]
    fn test_softmax_tensor_negative_axis() {
        let tensor = make_f32(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
        let result = softmax_tensor::<Float32Type>(&tensor, -1).unwrap();
        let flat = result.data().typed_data::<f32>();
        assert_close(total(&flat[0..3]), 1.0, 1e-6);
    }
}