Skip to main content

provable_contracts/kernels/
softmax.rs

1//! Softmax kernel: numerically stable exponential normalization.
2//!
3//! Matches `softmax-kernel-v1.yaml`.
4//! Four phases: `find_max` -> `exp_subtract` -> `sum_exp` -> `normalize`.
5
6#[cfg(target_arch = "x86_64")]
7use std::arch::x86_64::*;
8
9// ────────────────────────────────────────────────────────────────────────────
10// Scalar implementation
11// ────────────────────────────────────────────────────────────────────────────
12
/// Reference (non-SIMD) softmax using the max-subtraction trick.
///
/// For every index `i`, writes `exp(x_i - m) / sum_j exp(x_j - m)` into
/// `output[i]`, where `m` is the largest element of `input`. Subtracting `m`
/// keeps every exponent at or below zero, so `exp` cannot overflow for finite
/// inputs.
///
/// # Panics
///
/// Panics if `input` and `output` have different lengths or if `input` is empty.
pub fn softmax_scalar(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len(), "input/output length mismatch");
    assert!(!input.is_empty(), "softmax requires non-empty input");

    // Phase 1: running maximum, seeded with the first element.
    let peak = input[1..]
        .iter()
        .fold(input[0], |best, &x| if x > best { x } else { best });

    // Phase 2: shifted exponentials, written straight into the output buffer.
    for (slot, &x) in output.iter_mut().zip(input) {
        *slot = (x - peak).exp();
    }

    // Phase 3: left-to-right sum of the exponentials.
    let total = output.iter().fold(0.0_f32, |acc, &e| acc + e);

    // Phase 4: scale by the reciprocal of the sum (one division, n multiplies).
    let scale = 1.0 / total;
    output.iter_mut().for_each(|slot| *slot *= scale);
}
49
50// ────────────────────────────────────────────────────────────────────────────
51// AVX2 implementation
52// ────────────────────────────────────────────────────────────────────────────
53
/// AVX2 SIMD implementation of numerically stable softmax.
///
/// The work is split into the same four phases as the scalar kernel:
///
/// 1. **Max** — `_mm256_max_ps` keeps a lane-wise running maximum over every
///    full group of 8 elements; the 8 lanes and the leftover tail elements are
///    then reduced to a single maximum with scalar comparisons.
/// 2. **Exp** — `exp(x_i - max)` is evaluated one element at a time (there is
///    no AVX2 `exp` intrinsic).
/// 3. **Sum** — `_mm256_add_ps` accumulates 8 partial sums, which are then
///    added together with the tail elements in scalar code. The summation
///    order therefore differs from a plain left-to-right sum, so results may
///    differ from the scalar kernel in the last few bits.
/// 4. **Normalize** — every element is multiplied by `1 / sum` (vectorized
///    `_mm256_mul_ps` for full groups, scalar multiply for the tail).
///
/// # Safety
///
/// Requires AVX2 support. Caller must verify with `is_x86_feature_detected!("avx2")`.
///
/// # Panics
///
/// Panics if `input` and `output` have different lengths or if `input` is empty.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn softmax_avx2(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len(), "input/output length mismatch");
    let n = input.len();
    assert!(n > 0, "softmax requires non-empty input");

    // Elements before `tail_start` form complete 8-wide groups; the rest
    // (at most 7) are handled by scalar code in every phase.
    let tail_start = n - n % 8;

    // Scratch space used to spill a 256-bit register for scalar reduction.
    let mut lanes = [0.0_f32; 8];

    // Phase 1: lane-wise max over the full groups.
    // SAFETY: the caller guarantees AVX2 is available. `chunks_exact(8)` only
    // yields slices of exactly 8 `f32`s, so each unaligned 256-bit load stays
    // in bounds, and `lanes` provides the 8 `f32`s the store writes.
    unsafe {
        let mut max_vec = _mm256_set1_ps(f32::NEG_INFINITY);
        for group in input.chunks_exact(8) {
            max_vec = _mm256_max_ps(max_vec, _mm256_loadu_ps(group.as_ptr()));
        }
        _mm256_storeu_ps(lanes.as_mut_ptr(), max_vec);
    }
    // Reduce the 8 lanes, then the tail, to one scalar maximum.
    let mut max_val = f32::NEG_INFINITY;
    for &x in lanes.iter().chain(&input[tail_start..]) {
        if x > max_val {
            max_val = x;
        }
    }

    // Phase 2: exp(x_i - max) — scalar (no AVX2 exp intrinsic).
    for (slot, &x) in output.iter_mut().zip(input) {
        *slot = (x - max_val).exp();
    }

    // Phase 3: 8 partial sums over the full groups.
    // SAFETY: same reasoning as phase 1 — every load covers exactly one
    // 8-element chunk of `output`, and the store targets the 8-element `lanes`.
    unsafe {
        let mut sum_vec = _mm256_setzero_ps();
        for group in output.chunks_exact(8) {
            sum_vec = _mm256_add_ps(sum_vec, _mm256_loadu_ps(group.as_ptr()));
        }
        _mm256_storeu_ps(lanes.as_mut_ptr(), sum_vec);
    }
    // Fold the partial sums first, then the tail elements.
    let mut sum = 0.0_f32;
    for &e in lanes.iter().chain(&output[tail_start..]) {
        sum += e;
    }

    // Phase 4: scale by the reciprocal of the sum.
    let inv_sum = 1.0 / sum;
    let (full_groups, tail) = output.split_at_mut(tail_start);
    // SAFETY: `chunks_exact_mut(8)` yields exclusive slices of exactly 8
    // `f32`s, so the load and the store through the chunk pointer both stay
    // within that chunk.
    unsafe {
        let inv_vec = _mm256_set1_ps(inv_sum);
        for group in full_groups.chunks_exact_mut(8) {
            let scaled = _mm256_mul_ps(_mm256_loadu_ps(group.as_ptr()), inv_vec);
            _mm256_storeu_ps(group.as_mut_ptr(), scaled);
        }
    }
    for slot in tail {
        *slot *= inv_sum;
    }
}
135
// Textually splices `softmax_ptx.rs` into this module at compile time.
// NOTE(review): the tests below call `softmax_ptx()`, which is presumably
// defined in that file — confirm.
include!("softmax_ptx.rs");
137
138// ────────────────────────────────────────────────────────────────────────────
139// Tests
140// ────────────────────────────────────────────────────────────────────────────
141
#[cfg(test)]
mod tests {
    use super::super::ulp::assert_ulp_eq;
    use super::*;
    use proptest::prelude::*;

    // ── Helpers ──────────────────────────────────────────────────────────

    /// Runs the scalar kernel into a freshly allocated buffer of matching length.
    fn run_scalar(input: &[f32]) -> Vec<f32> {
        let mut out = vec![0.0_f32; input.len()];
        softmax_scalar(input, &mut out);
        out
    }

    /// Runs the AVX2 kernel into a fresh buffer, or returns `None` when the
    /// host CPU does not report AVX2 (the calling test is then a no-op).
    #[cfg(target_arch = "x86_64")]
    fn run_avx2(input: &[f32]) -> Option<Vec<f32>> {
        if !is_x86_feature_detected!("avx2") {
            return None;
        }
        let mut out = vec![0.0_f32; input.len()];
        // SAFETY: AVX2 availability was verified just above.
        unsafe { softmax_avx2(input, &mut out) };
        Some(out)
    }

    /// Asserts that the generated PTX text contains `needle`, failing with `failure`.
    fn assert_ptx_contains(needle: &str, failure: &str) {
        let ptx = softmax_ptx();
        assert!(ptx.contains(needle), "{failure}");
    }

    // ── Scalar known-answer tests ────────────────────────────────────────

    #[test]
    fn test_softmax_uniform() {
        let expected = 1.0_f32 / 3.0;
        for o in run_scalar(&[1.0, 1.0, 1.0]) {
            assert!((o - expected).abs() < 1e-6, "expected ~{expected}, got {o}");
        }
    }

    #[test]
    fn test_softmax_two_equal() {
        for o in run_scalar(&[0.0, 0.0]) {
            assert!((o - 0.5).abs() < 1e-6, "expected 0.5, got {o}");
        }
    }

    #[test]
    fn test_softmax_numerical_stability() {
        // A huge logit must not overflow: the kernel subtracts the max first.
        let out = run_scalar(&[1000.0, 0.0, 0.0]);
        for (i, o) in out.iter().enumerate() {
            assert!(o.is_finite(), "output[{i}] must be finite");
        }
        // The dominant element takes essentially all of the probability mass.
        assert!((out[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_softmax_single_element() {
        let out = run_scalar(&[42.0]);
        assert!(
            (out[0] - 1.0).abs() < 1e-7,
            "softmax of single element must be 1.0"
        );
    }

    #[test]
    #[should_panic(expected = "input/output length mismatch")]
    fn test_softmax_length_mismatch() {
        let mut too_long = [0.0_f32; 3];
        softmax_scalar(&[1.0_f32, 2.0], &mut too_long);
    }

    #[test]
    #[should_panic(expected = "softmax requires non-empty input")]
    fn test_softmax_empty_input() {
        let nothing_in: [f32; 0] = [];
        let mut nothing_out: [f32; 0] = [];
        softmax_scalar(&nothing_in, &mut nothing_out);
    }

    // ── Property-based tests ─────────────────────────────────────────────

    proptest! {
        #[test]
        fn prop_softmax_sums_to_one(
            v in proptest::collection::vec(-100.0_f32..100.0, 1..64)
        ) {
            let sum: f32 = run_scalar(&v).iter().sum();
            prop_assert!(
                (sum - 1.0).abs() < 1e-5,
                "softmax sum = {sum}, expected ~1.0"
            );
        }

        #[test]
        fn prop_softmax_outputs_in_unit_interval(
            v in proptest::collection::vec(-100.0_f32..100.0, 1..64)
        ) {
            for (i, o) in run_scalar(&v).into_iter().enumerate() {
                prop_assert!(
                    (0.0..=1.0).contains(&o),
                    "output[{i}] = {o} not in [0,1]"
                );
            }
        }

        #[test]
        fn prop_softmax_order_preservation(
            v in proptest::collection::vec(-50.0_f32..50.0, 2..32)
        ) {
            let out = run_scalar(&v);
            // Check every ordered pair (i, j) with i < j.
            for (i, (&vi, &oi)) in v.iter().zip(&out).enumerate() {
                for (j, (&vj, &oj)) in v.iter().zip(&out).enumerate().skip(i + 1) {
                    if vi > vj {
                        prop_assert!(
                            oi >= oj,
                            "order violated: v[{i}]={} > v[{j}]={} but out[{i}]={} < out[{j}]={}",
                            vi, vj, oi, oj
                        );
                    }
                }
            }
        }

        #[test]
        fn prop_softmax_translation_invariance(
            v in proptest::collection::vec(-50.0_f32..50.0, 2..32),
            c in -50.0_f32..50.0
        ) {
            let baseline = run_scalar(&v);

            let shifted: Vec<f32> = v.iter().map(|&x| x + c).collect();
            let translated = run_scalar(&shifted);

            for (i, (a, b)) in baseline.iter().zip(&translated).enumerate() {
                prop_assert!(
                    (a - b).abs() < 1e-5,
                    "translation invariance violated at {i}: {} vs {}",
                    a, b
                );
            }
        }
    }

    // ── AVX2 parity tests ────────────────────────────────────────────────

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_softmax_avx2_basic() {
        // 1.0, 2.0, …, 16.0 — exactly two full 8-wide groups.
        let input: Vec<f32> = (1..=16_u8).map(f32::from).collect();
        if let Some(simd) = run_avx2(&input) {
            assert_ulp_eq(&run_scalar(&input), &simd, 8);
        }
    }

    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_softmax_avx2_non_multiple_of_8() {
        // Five elements: exercises the path with no full 8-wide group.
        let input = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0];
        if let Some(simd) = run_avx2(&input) {
            assert_ulp_eq(&run_scalar(&input), &simd, 8);
        }
    }

    #[cfg(target_arch = "x86_64")]
    proptest! {
        #[test]
        fn prop_softmax_avx2_parity(
            v in proptest::collection::vec(-100.0_f32..100.0, 1..64)
        ) {
            if let Some(simd) = run_avx2(&v) {
                assert_ulp_eq(&run_scalar(&v), &simd, 8);
            }
        }
    }

    // ── PTX structural tests ─────────────────────────────────────────────

    #[test]
    fn test_softmax_ptx_version() {
        assert_ptx_contains(".version 8.5", "missing PTX version");
    }

    #[test]
    fn test_softmax_ptx_target() {
        assert_ptx_contains(".target sm_90", "missing PTX target");
    }

    #[test]
    fn test_softmax_ptx_entry() {
        assert_ptx_contains(".entry softmax_kernel", "missing entry point");
    }

    #[test]
    fn test_softmax_ptx_ret() {
        assert_ptx_contains("ret;", "missing ret instruction");
    }

    #[test]
    fn test_softmax_ptx_shared_memory() {
        assert_ptx_contains(".shared", "missing shared memory declaration");
    }

    #[test]
    fn test_softmax_ptx_warp_shuffle() {
        assert_ptx_contains("shfl.sync", "missing warp shuffle instructions");
    }

    #[test]
    fn test_softmax_ptx_bar_sync() {
        assert_ptx_contains("bar.sync", "missing bar.sync for block synchronization");
    }

    #[test]
    fn test_softmax_ptx_balanced_braces() {
        let ptx = softmax_ptx();
        let (open, close) = (ptx.matches('{').count(), ptx.matches('}').count());
        assert_eq!(
            open, close,
            "unbalanced braces: {open} open vs {close} close"
        );
    }
}