1use std::mem::MaybeUninit;
2
3use rten_simd::functional::{simd_apply, simd_map};
4use rten_simd::ops::{FloatOps, NumOps};
5use rten_simd::span::SrcDest;
6use rten_simd::{Isa, Simd, SimdIterable, SimdOp, SimdUnaryOp};
7
8use crate::exp::ReducedRangeExp;
9
/// Vectorized softmax operation over a slice of f32 values.
///
/// Construct with [`Softmax::new`] (separate input and output buffers) or
/// [`Softmax::new_mut`] (in-place update), then run it via the `SimdOp`
/// dispatch mechanism.
pub struct Softmax<'src, 'dst> {
    // Paired source / destination buffers. For the in-place case
    // (`new_mut`) both refer to the same slice.
    src_dest: SrcDest<'src, 'dst, f32>,
}
20
21impl<'src, 'dst> Softmax<'src, 'dst> {
22    pub fn new(input: &'src [f32], output: &'dst mut [MaybeUninit<f32>]) -> Self {
25        Softmax {
26            src_dest: (input, output).into(),
27        }
28    }
29
30    pub fn new_mut(input: &'dst mut [f32]) -> Self
32    where
33        'dst: 'src,
34    {
35        Softmax {
36            src_dest: input.into(),
37        }
38    }
39}
40
41impl<'dst> SimdOp for Softmax<'_, 'dst> {
42    type Output = &'dst mut [f32];
44
45    #[inline(always)]
46    fn eval<I: Isa>(self, isa: I) -> Self::Output {
47        let ops = isa.f32();
48
49        let max_val = self.src_dest.src().simd_iter(ops).fold_unroll::<4>(
50            ops.splat(f32::MIN),
51            #[inline(always)]
52            |max, x| ops.max(max, x),
53            #[inline(always)]
54            |max, x| ops.max(max, x),
55        );
56        let max_val = max_val
57            .to_array()
58            .into_iter()
59            .fold(f32::MIN, |max, x| max.max(x));
60
61        let (dest, exp_sum) = exp_sum_minus_max(isa, self.src_dest, max_val);
63
64        let exp_sum = ops.splat(exp_sum);
66        let inv_exp_sum = ops.reciprocal(exp_sum);
67        const UNROLL: usize = 2;
68        simd_apply::<_, _, _, UNROLL>(
69            ops,
70            dest,
71            #[inline(always)]
72            |x| ops.mul(x, inv_exp_sum),
73        );
74
75        dest
76    }
77}
78
/// Compute `exp(x - max_val)` for each element of `src_dest`'s source,
/// writing the results to its destination.
///
/// Returns the initialized destination slice together with the scalar sum
/// of the computed exponentials.
///
/// `ReducedRangeExp` presumably only supports a limited argument range;
/// the caller passes the input maximum as `max_val`, so the argument
/// `x - max_val` is `<= 0` here.
#[inline(always)]
fn exp_sum_minus_max<'dst, I: Isa>(
    isa: I,
    src_dest: SrcDest<'_, 'dst, f32>,
    max_val: f32,
) -> (&'dst mut [f32], f32) {
    let ops = isa.f32();

    let max_val = ops.splat(max_val);

    // Per-lane running sums of the exponentials. `prev_exp_sum` trails
    // `exp_sum` by exactly one vector so the contribution of the final
    // (possibly partial) vector can be selectively discarded below.
    let mut prev_exp_sum = ops.zero();
    let mut exp_sum = ops.zero();
    let dest = simd_map(
        ops,
        src_dest,
        #[inline(always)]
        |x| {
            let y = ReducedRangeExp::apply(isa, ops.sub(x, max_val));
            prev_exp_sum = exp_sum;
            exp_sum = ops.add(exp_sum, y);
            y
        },
    );

    // If the length is not a multiple of the vector width, the final
    // closure invocation processed a vector whose trailing lanes did not
    // hold real elements (NOTE(review): assumes `simd_map` pads the tail
    // vector — confirm against rten_simd's documentation). Keep the last
    // addition only for the first `remainder` lanes; the other lanes
    // revert to the sum as it was before the final vector.
    let remainder = dest.len() % ops.len();
    if remainder != 0 {
        let remainder_mask = ops.first_n_mask(remainder);
        exp_sum = ops.select(exp_sum, prev_exp_sum, remainder_mask);
    }
    // Reduce the per-lane sums to a single scalar.
    let exp_sum = exp_sum.to_array().into_iter().sum();

    (dest, exp_sum)
}
116
#[cfg(test)]
mod tests {
    use rten_simd::SimdOp;

    use super::Softmax;
    use crate::testing::{AsUninit, benchmark_op, check_f32s_are_equal_ulps, triples};

    /// Scalar softmax used as the reference for the SIMD implementation.
    fn reference_softmax(xs: &[f32], ys: &mut [f32]) {
        // Subtract the max before exponentiating, for numerical stability.
        let max = xs.iter().fold(f32::MIN, |m, &x| m.max(x));

        let mut exp_sum = 0.;
        for (y, &x) in ys.iter_mut().zip(xs) {
            let e = (x - max).exp();
            exp_sum += e;
            *y = e;
        }

        for y in ys.iter_mut() {
            *y /= exp_sum;
        }
    }

    /// Run the SIMD softmax over `src` into a freshly-allocated output.
    fn run_softmax(src: &[f32]) -> Vec<f32> {
        let mut out = vec![0.; src.len()];
        Softmax::new(src, out.as_mut_slice().as_uninit()).dispatch();
        out
    }

    #[test]
    fn test_softmax() {
        // Fixed input with precomputed expected values.
        let input = vec![0.1634, 0.8647, 0.6401, 0.8265, 0.0560, 0.2304];
        let expected = &([
            0.11715934, 0.23623686, 0.18871443, 0.2273828, 0.10522857, 0.12527795,
        ]);
        let actual = run_softmax(&input);
        check_f32s_are_equal_ulps(triples(&input, &actual, expected), 1.);

        // Sweep lengths 1..20 to exercise vector-width tail handling,
        // comparing against the scalar reference.
        for len in 1..20 {
            let input: Vec<f32> = (0..len).map(|x| x as f32 + 0.1).collect();
            let mut expected = vec![0.; input.len()];
            reference_softmax(&input, &mut expected);

            let actual = run_softmax(&input);
            check_f32s_are_equal_ulps(triples(&input, &actual, &expected), 3.);
        }
    }

    /// Benchmark the SIMD softmax against the scalar reference. Ignored by
    /// default; run with `cargo test --release -- --ignored`.
    #[test]
    #[ignore]
    fn bench_softmax() {
        benchmark_op(reference_softmax, |src, dest| {
            Softmax::new(src, dest).dispatch();
        });
    }
}