use std::mem::MaybeUninit;
use rten_simd::functional::{simd_apply, simd_map};
use rten_simd::ops::{FloatOps, NumOps};
use rten_simd::span::SrcDest;
use rten_simd::{Isa, Simd, SimdIterable, SimdOp, SimdUnaryOp};
use crate::exp::ReducedRangeExp;
/// Vectorized softmax over a slice of `f32` values.
///
/// Values are either read from a source slice and written into a separate
/// uninitialized destination (`Softmax::new`), or normalized in place
/// (`Softmax::new_mut`). Constructing the operation performs no work; it only
/// runs when evaluated through the [`SimdOp`] trait (e.g. `dispatch`).
#[must_use = "a Softmax operation does nothing until it is dispatched"]
pub struct Softmax<'src, 'dst> {
    /// Input values and the buffer that receives the normalized outputs.
    src_dest: SrcDest<'src, 'dst, f32>,
    /// When `true`, NaNs in the normalized output are replaced with zero.
    flush_nans_to_zero: bool,
}
impl<'src, 'dst> Softmax<'src, 'dst> {
#[track_caller]
pub fn new(input: &'src [f32], output: &'dst mut [MaybeUninit<f32>]) -> Self {
Softmax {
src_dest: (input, output).into(),
flush_nans_to_zero: false,
}
}
pub fn new_mut(input: &'dst mut [f32]) -> Self
where
'dst: 'src,
{
Softmax {
src_dest: input.into(),
flush_nans_to_zero: false,
}
}
pub fn flush_nans_to_zero(mut self, flush: bool) -> Self {
self.flush_nans_to_zero = flush;
self
}
}
impl<'dst> SimdOp for Softmax<'_, 'dst> {
    /// The destination slice, fully initialized with the normalized values.
    type Output = &'dst mut [f32];

    /// Evaluates softmax in three passes over the data:
    /// 1. find the maximum input value,
    /// 2. write `exp(x - max)` for every element and accumulate the sum,
    /// 3. scale every written value by the reciprocal of that sum
    ///    (optionally replacing NaN results with zero).
    #[inline(always)]
    fn eval<I: Isa>(self, isa: I) -> Self::Output {
        let ops = isa.f32();

        // Pass 1: lane-wise maximum over the input (loop unrolled 4x),
        // followed by a horizontal maximum across the lanes.
        let lane_max = self.src_dest.src().simd_iter(ops).fold_unroll::<4>(
            ops.splat(f32::MIN),
            #[inline(always)]
            |acc, v| ops.max(acc, v),
            #[inline(always)]
            |acc, v| ops.max(acc, v),
        );
        let max_val = lane_max.to_array().into_iter().fold(f32::MIN, f32::max);

        // Pass 2: dest[i] = exp(src[i] - max_val), returning the total.
        let (dest, exp_sum) = exp_sum_minus_max(isa, self.src_dest, max_val);

        // Pass 3: multiply by 1 / sum rather than dividing each element.
        let scale = ops.reciprocal(ops.splat(exp_sum));
        const UNROLL: usize = 2;

        if !self.flush_nans_to_zero {
            simd_apply::<_, _, _, UNROLL>(
                ops,
                dest,
                #[inline(always)]
                |v| ops.mul(v, scale),
            );
        } else {
            let zero = ops.zero();
            simd_apply::<_, _, _, UNROLL>(
                ops,
                dest,
                #[inline(always)]
                |v| {
                    let scaled = ops.mul(v, scale);
                    // NaN is the only value that compares unequal to itself.
                    let is_number = ops.eq(scaled, scaled);
                    ops.select(scaled, zero, is_number)
                },
            );
        }

        dest
    }
}
/// Writes `exp(x - max_val)` for every source element `x` into the
/// destination, then returns the initialized destination together with the
/// sum of all written values.
///
/// `max_val` is the maximum of the source as computed by `Softmax::eval`, so
/// each `x - max_val` should be `<= 0`. Presumably that reduced input range is
/// what `ReducedRangeExp` is designed for — NOTE(review): confirm.
#[inline(always)]
fn exp_sum_minus_max<'dst, I: Isa>(
    isa: I,
    src_dest: SrcDest<'_, 'dst, f32>,
    max_val: f32,
) -> (&'dst mut [f32], f32) {
    let ops = isa.f32();
    let max_val = ops.splat(max_val);
    // Per-lane running sum of the exponentials, plus its value before the most
    // recent update so that the last update can be undone for unused lanes.
    let mut prev_exp_sum = ops.zero();
    let mut exp_sum = ops.zero();
    let dest = simd_map(
        ops,
        src_dest,
        #[inline(always)]
        |x| {
            let y = ReducedRangeExp::apply(isa, ops.sub(x, max_val));
            prev_exp_sum = exp_sum;
            exp_sum = ops.add(exp_sum, y);
            y
        },
    );
    // If the length is not a multiple of the vector width, the last vector
    // given to the closure is only partially filled with real elements; the
    // remaining lanes presumably hold padding whose exponential must not be
    // counted. Keep the newest sum in the first `remainder` lanes (which held
    // real data) and restore the previous sum in the others. `select` takes its
    // first argument where the mask is set.
    // NOTE(review): this relies on `simd_map` calling the closure for the
    // partial vector last, after all full vectors — confirm in rten_simd.
    let remainder = dest.len() % ops.len();
    if remainder != 0 {
        let remainder_mask = ops.first_n_mask(remainder);
        exp_sum = ops.select(exp_sum, prev_exp_sum, remainder_mask);
    }
    // Horizontal reduction of the per-lane sums into a single scalar.
    let exp_sum = exp_sum.to_array().into_iter().sum();
    (dest, exp_sum)
}
#[cfg(test)]
mod tests {
    use rten_simd::SimdOp;

    use super::Softmax;
    use crate::testing::{AsUninit, benchmark_op, check_f32s_are_equal_ulps, triples};

    /// Scalar softmax used as the ground truth for the SIMD implementation.
    fn reference_softmax(xs: &[f32], ys: &mut [f32]) {
        let max = xs.iter().fold(f32::MIN, |acc, &x| acc.max(x));

        let mut total = 0.;
        for (y, &x) in ys.iter_mut().zip(xs) {
            let e = (x - max).exp();
            *y = e;
            total += e;
        }

        ys.iter_mut().for_each(|y| *y /= total);
    }

    #[test]
    fn test_softmax() {
        // Fixed input with precomputed expected probabilities.
        let input = vec![0.1634, 0.8647, 0.6401, 0.8265, 0.0560, 0.2304];
        let expected = [
            0.11715934, 0.23623686, 0.18871443, 0.2273828, 0.10522857, 0.12527795,
        ];
        let mut actual = vec![0.; input.len()];
        Softmax::new(&input, actual.as_mut_slice().as_uninit()).dispatch();
        check_f32s_are_equal_ulps(triples(&input, &actual, &expected), 1.);

        // Lengths around the SIMD vector width, checked against the scalar
        // reference implementation.
        for len in 1..20 {
            let xs: Vec<f32> = (0..len).map(|i| i as f32 + 0.1).collect();

            let mut reference = vec![0.; xs.len()];
            reference_softmax(&xs, &mut reference);

            let mut actual = vec![0.; xs.len()];
            Softmax::new(&xs, actual.as_mut_slice().as_uninit()).dispatch();

            check_f32s_are_equal_ulps(triples(&xs, &actual, &reference), 3.);
        }
    }

    #[test]
    fn test_softmax_flush_nans_to_zero() {
        // All-`-inf` input is expected to yield NaN outputs by default...
        let mut nan_input = [f32::NEG_INFINITY; 3];
        Softmax::new_mut(&mut nan_input).dispatch();
        assert!(nan_input.iter().all(|v| v.is_nan()));

        // ...and zeros when NaN flushing is enabled.
        let mut flushed_input = [f32::NEG_INFINITY; 3];
        Softmax::new_mut(&mut flushed_input)
            .flush_nans_to_zero(true)
            .dispatch();
        assert_eq!(flushed_input, [0.; 3]);
    }

    #[test]
    #[ignore]
    fn bench_softmax() {
        benchmark_op(reference_softmax, |src, dest| {
            Softmax::new(src, dest).dispatch();
        });
    }
}