use std::mem::MaybeUninit;

use rten_simd::functional::{simd_apply, simd_map};
use rten_simd::ops::{FloatOps, NumOps};
use rten_simd::span::SrcDest;
use rten_simd::{Isa, Simd, SimdIterable, SimdOp, SimdUnaryOp};

use crate::exp::ReducedRangeExp;

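/// Vectorized softmax operation over a slice of f32 values.
///
/// For input `x`, the output is `exp(x[i] - max(x)) / sum_j(exp(x[j] - max(x)))`.
/// Subtracting the maximum before exponentiation keeps the arguments to `exp`
/// non-positive, which avoids overflow for large inputs without changing the
/// mathematical result.
///
/// A minimal usage sketch. `as_uninit` is assumed here to convert
/// `&mut [f32]` into `&mut [MaybeUninit<f32>]`, as the test helper of the
/// same name does below:
///
/// ```ignore
/// use rten_simd::SimdOp;
///
/// let input = [0.1f32, 0.2, 0.7];
/// let mut output = vec![0.0f32; input.len()];
/// let result = Softmax::new(&input, output.as_mut_slice().as_uninit()).dispatch();
/// ```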
pub struct Softmax<'src, 'dst> {
    src_dest: SrcDest<'src, 'dst, f32>,
}

impl<'src, 'dst> Softmax<'src, 'dst> {
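    /// Create a softmax operation which reads `input` and writes to `output`.
    ///
    /// `input` and `output` must have the same length.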
    pub fn new(input: &'src [f32], output: &'dst mut [MaybeUninit<f32>]) -> Self {
        Softmax {
            src_dest: (input, output).into(),
        }
    }

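    /// Create a softmax operation which reads and writes `input` in place.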
    pub fn new_mut(input: &'dst mut [f32]) -> Self
    where
        'dst: 'src,
    {
        Softmax {
            src_dest: input.into(),
        }
    }
}

impl<'dst> SimdOp for Softmax<'_, 'dst> {
    type Output = &'dst mut [f32];

    #[inline(always)]
    fn eval<I: Isa>(self, isa: I) -> Self::Output {
        let ops = isa.f32();

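        // Step 1: find the maximum input value using an unrolled SIMD fold.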
        let max_val = self.src_dest.src().simd_iter(ops).fold_unroll::<4>(
            ops.splat(f32::MIN),
            #[inline(always)]
            |max, x| ops.max(max, x),
            #[inline(always)]
            |max, x| ops.max(max, x),
        );
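        // Reduce the per-lane maxima to a single scalar.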
        let max_val = max_val
            .to_array()
            .into_iter()
            .fold(f32::MIN, |max, x| max.max(x));

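        // Step 2: compute `exp(x - max)` into the destination and sum the results.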
        let (dest, exp_sum) = exp_sum_minus_max(isa, self.src_dest, max_val);

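        // Step 3: normalize by the sum. Multiplying by the reciprocal avoids a
        // division per element; depending on the ISA, `reciprocal` may be a
        // refined approximation rather than an exact division.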
        let exp_sum = ops.splat(exp_sum);
        let inv_exp_sum = ops.reciprocal(exp_sum);
        const UNROLL: usize = 2;
        simd_apply::<_, _, _, UNROLL>(
            ops,
            dest,
            #[inline(always)]
            |x| ops.mul(x, inv_exp_sum),
        );

        dest
    }
}

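/// Compute `exp(x - max_val)` for each element in `src_dest`, writing the
/// results to the destination and returning it along with the sum of the
/// computed exponentials.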
#[inline(always)]
fn exp_sum_minus_max<'dst, I: Isa>(
    isa: I,
    src_dest: SrcDest<'_, 'dst, f32>,
    max_val: f32,
) -> (&'dst mut [f32], f32) {
    let ops = isa.f32();

    let max_val = ops.splat(max_val);

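    // Two running sums are kept because the final `simd_map` step may operate
    // on a padded partial vector, in which case the lanes of `exp_sum` beyond
    // the valid prefix contain garbage contributions. The previous sum is used
    // below to discard them.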
    let mut prev_exp_sum = ops.zero();
    let mut exp_sum = ops.zero();
    let dest = simd_map(
        ops,
        src_dest,
        #[inline(always)]
        |x| {
            let y = ReducedRangeExp::apply(isa, ops.sub(x, max_val));
            prev_exp_sum = exp_sum;
            exp_sum = ops.add(exp_sum, y);
            y
        },
    );

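    // If the length is not a multiple of the vector width, keep only the first
    // `remainder` lanes from the final sum and take the remaining lanes from
    // the sum prior to the last update.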
    let remainder = dest.len() % ops.len();
    if remainder != 0 {
        let remainder_mask = ops.first_n_mask(remainder);
        exp_sum = ops.select(exp_sum, prev_exp_sum, remainder_mask);
    }
    let exp_sum = exp_sum.to_array().into_iter().sum();

    (dest, exp_sum)
}

#[cfg(test)]
mod tests {
    use rten_simd::SimdOp;

    use super::Softmax;
    use crate::testing::{AsUninit, benchmark_op, check_f32s_are_equal_ulps, triples};

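    /// Scalar reference implementation of softmax used to validate the SIMD
    /// version.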
    fn reference_softmax(xs: &[f32], ys: &mut [f32]) {
        let max = xs.iter().copied().fold(f32::MIN, |max, x| max.max(x));

        let mut exp_sum = 0.;
        for (x, y) in xs.iter().zip(ys.iter_mut()) {
            *y = (*x - max).exp();
            exp_sum += *y;
        }

        for el in ys.iter_mut() {
            *el /= exp_sum;
        }
    }

    #[test]
    fn test_softmax() {
        let input = vec![0.1634, 0.8647, 0.6401, 0.8265, 0.0560, 0.2304];
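        // Expected softmax values for the input above.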
        let expected = &[
            0.11715934, 0.23623686, 0.18871443, 0.2273828, 0.10522857, 0.12527795,
        ];
        let mut actual = vec![0.; input.len()];

        Softmax::new(&input, actual.as_mut_slice().as_uninit()).dispatch();
        check_f32s_are_equal_ulps(triples(&input, &actual, expected), 1.);

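        // Test a range of lengths to exercise the partial-vector tail paths.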
        for len in 1..20 {
            let input: Vec<f32> = (0..len).map(|x| x as f32 + 0.1).collect();
            let mut expected = vec![0.; input.len()];
            reference_softmax(&input, &mut expected);

            let mut actual = vec![0.; input.len()];
            Softmax::new(&input, actual.as_mut_slice().as_uninit()).dispatch();

            check_f32s_are_equal_ulps(triples(&input, &actual, &expected), 3.);
        }
    }

    #[test]
    #[ignore]
    fn bench_softmax() {
        benchmark_op(reference_softmax, |src, dest| {
            Softmax::new(src, dest).dispatch();
        });
    }
}