Skip to main content

trueno/blis/
softmax.rs

//! SIMD-accelerated softmax.
//!
//! `softmax_1d_alloc` selects an implementation at runtime:
//!
//! - **AVX2 + FMA** (x86_64 only): three passes over memory.
//!   1. Max: AVX2 horizontal reduction with 4-way unrolling.
//!   2. Exp + sum (fused): vectorized polynomial `exp(x - max)`, stored and
//!      accumulated in the same loop.
//!   3. Normalize: AVX2 multiply by `1 / sum` with 4-way unrolling.
//! - **Scalar fallback**: four passes (max, `exp` + store, sum, normalize) using
//!   the standard library's `f32::exp`.
//!
//! An earlier version fused a scalar `exp` with the sum in one loop. The
//! loop-carried dependency on `sum` prevented LLVM from vectorizing the
//! surrounding code. Computing `exp` with SIMD intrinsics removes that obstacle,
//! which is what lets the AVX2 path fuse the two steps again.
//!
//! Contract: contracts/softmax-kernel-v1.yaml

15/// Softmax on a 1D slice with zero-copy output allocation.
16///
17/// Uses AVX2 acceleration for max/sum/normalize passes when available.
18///
19/// # Contract
20///
21/// - `softmax(x)_i = exp(x_i - max(x)) / Σ_j exp(x_j - max(x))`
22/// - Output sums to 1.0 (within f32 tolerance)
23/// - All outputs ≥ 0
24/// - Monotonicity: x_i > x_j → y_i > y_j
25/// - Shift-invariant: softmax(x + c) = softmax(x)
26#[must_use]
27pub fn softmax_1d_alloc(logits: &[f32]) -> Vec<f32> {
28    let n = logits.len();
29    if n == 0 {
30        return Vec::new();
31    }
32    if n == 1 {
33        return vec![1.0];
34    }
35
36    // Contract: softmax-kernel-v1.yaml precondition (pv codegen)
37    contract_pre_softmax!(logits);
38
39    #[cfg(target_arch = "x86_64")]
40    {
41        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
42            // SAFETY: AVX2+FMA verified by feature detection above.
43            let result = unsafe { softmax_avx2(logits) };
44            contract_post_softmax!(&result);
45            return result;
46        }
47    }
48
49    let result = softmax_scalar(logits);
50    contract_post_softmax!(&result);
51    result
52}
53
/// Scalar 4-pass softmax. It is the reference implementation and the fallback path.
///
/// Computes `exp(x_i - max(x)) / Σ_j exp(x_j - max(x))` as four passes:
/// max, shifted `exp`, sum, normalize. Subtracting the maximum first keeps every
/// exponent ≤ 0, so `exp` cannot overflow for finite inputs.
///
/// Returns an empty `Vec` for an empty slice.
fn softmax_scalar(logits: &[f32]) -> Vec<f32> {
    // Pass 1: max. `f32::max` ignores a NaN operand, so NaN never becomes the shift.
    let max_val = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // Pass 2: exp + store. Collecting from an exact-size iterator allocates once
    // and never exposes uninitialized memory. The previous `set_len` before
    // writing violated `Vec::set_len`'s "elements must be initialized" requirement.
    let mut out: Vec<f32> = logits.iter().map(|&x| (x - max_val).exp()).collect();

    // Pass 3: sum, accumulated left to right.
    let sum: f32 = out.iter().sum();

    // Pass 4: normalize (the EPSILON floor guards against a zero sum).
    let inv_sum = 1.0 / sum.max(f32::EPSILON);
    for v in &mut out {
        *v *= inv_sum;
    }

    out
}
88
89/// AVX2 3-pass softmax: max → fused SIMD-exp+sum → normalize.
90///
91/// Fuses passes 2+3 into a single SIMD exp+accumulate pass, eliminating one
92/// full memory traversal. Uses polynomial exp approximation (6th-order Remez
93/// minimax on [-ln(2)/2, ln(2)/2] with range reduction), giving <1 ULP error.
94///
95/// # Safety
96///
97/// Requires AVX2 + FMA support.
98#[cfg(target_arch = "x86_64")]
99#[target_feature(enable = "avx2", enable = "fma")]
100unsafe fn softmax_avx2(logits: &[f32]) -> Vec<f32> {
101    use std::arch::x86_64::*;
102
103    let n = logits.len();
104    let chunks = n / 32;
105    let remainder_32 = chunks * 32;
106
107    // ── Pass 1: AVX2 horizontal max ──────────────────────────────────────
108    let mut max0;
109    let mut max1;
110    let mut max2;
111    let mut max3;
112    unsafe {
113        max0 = _mm256_set1_ps(f32::NEG_INFINITY);
114        max1 = max0;
115        max2 = max0;
116        max3 = max0;
117
118        for i in 0..chunks {
119            let base = i * 32;
120            let v0 = _mm256_loadu_ps(logits.as_ptr().add(base));
121            let v1 = _mm256_loadu_ps(logits.as_ptr().add(base + 8));
122            let v2 = _mm256_loadu_ps(logits.as_ptr().add(base + 16));
123            let v3 = _mm256_loadu_ps(logits.as_ptr().add(base + 24));
124            max0 = _mm256_max_ps(max0, v0);
125            max1 = _mm256_max_ps(max1, v1);
126            max2 = _mm256_max_ps(max2, v2);
127            max3 = _mm256_max_ps(max3, v3);
128        }
129
130        max0 = _mm256_max_ps(max0, max1);
131        max2 = _mm256_max_ps(max2, max3);
132        max0 = _mm256_max_ps(max0, max2);
133
134        let hi = _mm256_permute2f128_ps(max0, max0, 1);
135        max0 = _mm256_max_ps(max0, hi);
136        let shuf = _mm256_shuffle_ps(max0, max0, 0b01_00_11_10);
137        max0 = _mm256_max_ps(max0, shuf);
138        let shuf2 = _mm256_shuffle_ps(max0, max0, 0b10_11_00_01);
139        max0 = _mm256_max_ps(max0, shuf2);
140    }
141
142    let mut max_val = _mm_cvtss_f32(_mm256_castps256_ps128(max0));
143    for i in remainder_32..n {
144        max_val = max_val.max(logits[i]);
145    }
146
147    // ── Pass 2 (fused): SIMD exp(x - max) + accumulate sum ──────────────
148    // Uninit: every element written by storeu_ps (main) or out[i]=e (remainder).
149    let mut out: Vec<f32> = Vec::with_capacity(n);
150    // SAFETY: Pass 2 writes all n elements before any read (pass 3).
151    unsafe {
152        out.set_len(n);
153    }
154    let mut sum0;
155    let mut sum1;
156    let mut sum2;
157    let mut sum3;
158    unsafe {
159        let max_v = _mm256_set1_ps(max_val);
160        sum0 = _mm256_setzero_ps();
161        sum1 = sum0;
162        sum2 = sum0;
163        sum3 = sum0;
164
165        for i in 0..chunks {
166            let base = i * 32;
167            let x0 = _mm256_sub_ps(_mm256_loadu_ps(logits.as_ptr().add(base)), max_v);
168            let x1 = _mm256_sub_ps(_mm256_loadu_ps(logits.as_ptr().add(base + 8)), max_v);
169            let x2 = _mm256_sub_ps(_mm256_loadu_ps(logits.as_ptr().add(base + 16)), max_v);
170            let x3 = _mm256_sub_ps(_mm256_loadu_ps(logits.as_ptr().add(base + 24)), max_v);
171
172            let e0 = fast_exp_avx2(x0);
173            let e1 = fast_exp_avx2(x1);
174            let e2 = fast_exp_avx2(x2);
175            let e3 = fast_exp_avx2(x3);
176
177            _mm256_storeu_ps(out.as_mut_ptr().add(base), e0);
178            _mm256_storeu_ps(out.as_mut_ptr().add(base + 8), e1);
179            _mm256_storeu_ps(out.as_mut_ptr().add(base + 16), e2);
180            _mm256_storeu_ps(out.as_mut_ptr().add(base + 24), e3);
181
182            sum0 = _mm256_add_ps(sum0, e0);
183            sum1 = _mm256_add_ps(sum1, e1);
184            sum2 = _mm256_add_ps(sum2, e2);
185            sum3 = _mm256_add_ps(sum3, e3);
186        }
187
188        sum0 = _mm256_add_ps(sum0, sum1);
189        sum2 = _mm256_add_ps(sum2, sum3);
190        sum0 = _mm256_add_ps(sum0, sum2);
191
192        let hi = _mm256_permute2f128_ps(sum0, sum0, 1);
193        sum0 = _mm256_add_ps(sum0, hi);
194        let shuf = _mm256_shuffle_ps(sum0, sum0, 0b01_00_11_10);
195        sum0 = _mm256_add_ps(sum0, shuf);
196        let shuf2 = _mm256_shuffle_ps(sum0, sum0, 0b10_11_00_01);
197        sum0 = _mm256_add_ps(sum0, shuf2);
198    }
199
200    let mut sum_val = _mm_cvtss_f32(_mm256_castps256_ps128(sum0));
201    // Scalar remainder
202    for i in remainder_32..n {
203        let e = (logits[i] - max_val).exp();
204        out[i] = e;
205        sum_val += e;
206    }
207
208    // ── Pass 3: AVX2 normalize ──────────────────────────────────────────
209    let inv_sum = 1.0 / sum_val.max(f32::EPSILON);
210    unsafe {
211        let inv = _mm256_set1_ps(inv_sum);
212
213        for i in 0..chunks {
214            let base = i * 32;
215            let v0 = _mm256_loadu_ps(out.as_ptr().add(base));
216            let v1 = _mm256_loadu_ps(out.as_ptr().add(base + 8));
217            let v2 = _mm256_loadu_ps(out.as_ptr().add(base + 16));
218            let v3 = _mm256_loadu_ps(out.as_ptr().add(base + 24));
219            _mm256_storeu_ps(out.as_mut_ptr().add(base), _mm256_mul_ps(v0, inv));
220            _mm256_storeu_ps(out.as_mut_ptr().add(base + 8), _mm256_mul_ps(v1, inv));
221            _mm256_storeu_ps(out.as_mut_ptr().add(base + 16), _mm256_mul_ps(v2, inv));
222            _mm256_storeu_ps(out.as_mut_ptr().add(base + 24), _mm256_mul_ps(v3, inv));
223        }
224    }
225    for i in remainder_32..n {
226        out[i] *= inv_sum;
227    }
228
229    out
230}
231
/// Fast SIMD `exp(x)` for eight `f32` lanes, using range reduction and a degree-6 polynomial.
///
/// Algorithm: `e^x = 2^n · e^r`, with `n = round(x / ln 2)` and `r = x - n·ln 2`.
/// `n·ln 2` is subtracted as a high part plus a low correction so that `r` keeps
/// full precision. `e^r` is evaluated on `r ∈ [-ln(2)/2, ln(2)/2]` with a
/// Horner-form polynomial. `2^n` is assembled directly in the IEEE-754 exponent
/// field.
///
/// Edge behavior:
/// - Inputs are clamped to `[-87.33654, 88.72284]` first. Arguments below the
///   lower bound therefore return ≈ `f32::MIN_POSITIVE` (≈ 1.18e-38), never 0.
/// - Near the top of the range, `n` reaches 128, one past the largest normal
///   exponent. `2^n` is therefore applied as two factors,
///   `2^(n>>1) · 2^(n - (n>>1))`. Results stay finite up to about `ln(f32::MAX)`,
///   instead of jumping to `+inf` at `x ≈ 88.376`. Products exceeding `f32::MAX`
///   still round to `+inf`.
/// - NaN lanes are not propagated. `vmaxps` returns its second operand when
///   either input is NaN, so the clamp maps NaN to the lower bound.
///
/// Stated accuracy target: relative error within about 2 ULP for in-range
/// inputs. This is not measured in this file beyond the softmax parity test.
///
/// CGP-DBUF: made pub(crate) for reuse in AttentionOp::simd_softmax_row.
///
/// # Safety
///
/// The caller must ensure the running CPU supports AVX2 and FMA.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub(crate) unsafe fn fast_exp_avx2(x: std::arch::x86_64::__m256) -> std::arch::x86_64::__m256 {
    use std::arch::x86_64::*;

    let log2e = _mm256_set1_ps(std::f32::consts::LOG2_E);
    let ln2_hi = _mm256_set1_ps(0.693_145_751_953_125); // ln(2) high bits
    let ln2_lo = _mm256_set1_ps(1.428_606_765_330_187_1e-6); // ln(2) low bits
    let one = _mm256_set1_ps(1.0);
    let exp_bias = _mm256_set1_epi32(127);

    // Polynomial coefficients for e^r on [-ln2/2, ln2/2] (close to 1/k!).
    let c2 = _mm256_set1_ps(0.500_000_0); // 1/2!
    let c3 = _mm256_set1_ps(0.166_666_671_6); // ~1/3!
    let c4 = _mm256_set1_ps(0.041_666_645_8); // ~1/4!
    let c5 = _mm256_set1_ps(0.008_333_345_2); // ~1/5!
    let c6 = _mm256_set1_ps(0.001_388_731_6); // ~1/6!

    // Clamp so that n = round(x / ln 2) stays within [-126, 128].
    let x = _mm256_max_ps(x, _mm256_set1_ps(-87.33654));
    let x = _mm256_min_ps(x, _mm256_set1_ps(88.72284));

    // Range reduction: n = floor(x·log2(e) + 0.5) = round(x / ln 2).
    let t = _mm256_fmadd_ps(x, log2e, _mm256_set1_ps(0.5));
    let n = _mm256_floor_ps(t);

    // r = x - n·ln(2), subtracting the high and low parts separately for precision.
    let r = _mm256_sub_ps(x, _mm256_mul_ps(n, ln2_hi));
    let r = _mm256_sub_ps(r, _mm256_mul_ps(n, ln2_lo));

    // e^r ≈ 1 + r + c2·r² + c3·r³ + c4·r⁴ + c5·r⁵ + c6·r⁶, Horner's scheme (one FMA per step).
    let p = _mm256_fmadd_ps(c6, r, c5);
    let p = _mm256_fmadd_ps(p, r, c4);
    let p = _mm256_fmadd_ps(p, r, c3);
    let p = _mm256_fmadd_ps(p, r, c2);
    let p = _mm256_fmadd_ps(p, r, one);
    let p = _mm256_fmadd_ps(p, r, one);

    // Reconstruct 2^n. For n = 128 a single biased exponent would be 255, which
    // encodes inf. So split n = n1 + n2, with n1 = n >> 1 (arithmetic shift) and
    // n2 = n - n1. Both lie in [-63, 64]. Scaling by 2^n1 and then by 2^n2 gives
    // the same result as a single 2^n scaling whenever 2^n is representable.
    let n_i = _mm256_cvtps_epi32(n);
    let n1 = _mm256_srai_epi32(n_i, 1);
    let n2 = _mm256_sub_epi32(n_i, n1);
    let pow2_n1 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_add_epi32(n1, exp_bias), 23));
    let pow2_n2 = _mm256_castsi256_ps(_mm256_slli_epi32(_mm256_add_epi32(n2, exp_bias), 23));

    _mm256_mul_ps(_mm256_mul_ps(p, pow2_n1), pow2_n2)
}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Reproducible pseudo-random logits in [-0.5, 0.5), repeating every 1000 elements.
    fn pseudo_logits(len: usize) -> Vec<f32> {
        let mut values = Vec::with_capacity(len);
        for i in 0..len {
            let bucket = (i * 17 + 31) % 1000;
            values.push(bucket as f32 / 1000.0 - 0.5);
        }
        values
    }

    /// FALSIFY-SM-001: sum(softmax(x)) ≈ 1.0
    #[test]
    fn test_softmax_sums_to_one() {
        for &n in &[32, 127, 256, 1000, 32000] {
            let sum: f32 = softmax_1d_alloc(&pseudo_logits(n)).into_iter().sum();
            assert!((sum - 1.0).abs() < 1e-5, "sum = {sum} for n={n}, expected 1.0");
        }
    }

    /// FALSIFY-SM-002: all elements ≥ 0
    #[test]
    fn test_softmax_non_negative() {
        let data: Vec<f32> = (0..1000).map(|k| -100.0 + k as f32 * 0.1).collect();
        for (i, v) in softmax_1d_alloc(&data).into_iter().enumerate() {
            assert!(v >= 0.0, "element [{i}] = {v} < 0");
        }
    }

    /// FALSIFY-SM-003: monotonicity on strictly increasing input
    #[test]
    fn test_softmax_monotonic() {
        let result = softmax_1d_alloc(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        for (k, pair) in result.windows(2).enumerate() {
            let i = k + 1;
            assert!(pair[1] > pair[0], "Not monotonic at [{i}]: {} <= {}", pair[1], pair[0]);
        }
    }

    /// FALSIFY-SM-004: shift invariance
    #[test]
    fn test_softmax_shift_invariance() {
        let data = pseudo_logits(1000);
        let shifted: Vec<f32> = data.iter().map(|&x| x + 1000.0).collect();

        let lhs = softmax_1d_alloc(&data);
        let rhs = softmax_1d_alloc(&shifted);

        for (i, (a, b)) in lhs.into_iter().zip(rhs).enumerate() {
            assert!((a - b).abs() < 1e-6, "Shift invariance broken at [{i}]: {a} vs {b}");
        }
    }

    /// FALSIFY-SM-005: uniform input yields 1/n everywhere
    #[test]
    fn test_softmax_uniform() {
        for &n in &[4, 100, 1000] {
            let data = vec![std::f32::consts::PI; n];
            let expected = 1.0 / n as f32;
            for (i, v) in softmax_1d_alloc(&data).into_iter().enumerate() {
                assert!((v - expected).abs() < 1e-6, "Uniform at [{i}]: {v} vs {expected}");
            }
        }
    }

    /// FALSIFY-SM-006: AVX2 vs scalar parity (trivially equal without AVX2)
    #[test]
    fn test_softmax_avx2_scalar_parity() {
        for &n in &[32, 127, 1000, 32000] {
            let data = pseudo_logits(n);
            let fast = softmax_1d_alloc(&data);
            let reference = softmax_scalar(&data);

            // The SIMD polynomial exp differs from libm exp by a few ULP.
            for (i, (&a, &s)) in fast.iter().zip(&reference).enumerate() {
                assert!((a - s).abs() < 1e-6, "AVX2/scalar mismatch at [{i}] n={n}: {a} vs {s}");
            }
        }
    }

    /// FALSIFY-SM-007: lengths that are not multiples of the 32-element SIMD chunk
    #[test]
    fn test_softmax_remainder_sizes() {
        for &n in &[1, 2, 7, 8, 15, 31, 33, 63, 65, 127, 255] {
            let result = softmax_1d_alloc(&pseudo_logits(n));
            let sum: f32 = result.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5, "sum = {sum} for n={n}, expected 1.0");
            assert_eq!(result.len(), n);
        }
    }

    /// FALSIFY-SM-008: numerical stability with widely spread logits
    #[test]
    fn test_softmax_numerical_stability() {
        let mut data = vec![0.0f32; 100];
        // e^88 ≈ 1.65e38 is close to f32::MAX (≈ 3.4e38); subtracting the max
        // before exponentiating keeps this from overflowing.
        data[0] = 88.0;
        // e^-88 is below f32::MIN_POSITIVE (subnormal range).
        data[50] = -88.0;

        let result = softmax_1d_alloc(&data);
        assert!(result.iter().all(|v| !v.is_nan()), "Got NaN");
        assert!(result.iter().all(|v| !v.is_infinite()), "Got Inf");
        let sum: f32 = result.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "sum = {sum}");
    }

    /// FALSIFY-SM-009: argmax preservation
    #[test]
    fn test_softmax_argmax_preserved() {
        // Index of the maximum; on ties `max_by` picks the last occurrence.
        fn argmax(xs: &[f32]) -> usize {
            xs.iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
                .map(|(i, _)| i)
                .unwrap()
        }

        let data = pseudo_logits(32000);
        let result = softmax_1d_alloc(&data);

        assert_eq!(argmax(&data), argmax(&result), "Argmax not preserved");
    }

    /// Edge: empty input
    #[test]
    fn test_softmax_empty() {
        assert!(softmax_1d_alloc(&[]).is_empty());
    }

    /// Edge: single element
    #[test]
    fn test_softmax_single() {
        assert_eq!(softmax_1d_alloc(&[42.0]), vec![1.0]);
    }
}