// provable_contracts/kernels/softmax.rs
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
8
/// Scalar reference softmax: writes `exp(x_i - max) / sum_j exp(x_j - max)`
/// into `output` for every element `x_i` of `input`.
///
/// The maximum is subtracted before exponentiating so that the largest
/// exponent is `exp(0) = 1`, which keeps large inputs from overflowing.
///
/// # Panics
///
/// Panics if `input` and `output` differ in length, or if `input` is empty.
pub fn softmax_scalar(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len(), "input/output length mismatch");
    assert!(!input.is_empty(), "softmax requires non-empty input");

    // Running maximum seeded with the first element; the explicit `>` keeps
    // the comparison semantics of a plain "replace if strictly greater" scan.
    let max_val = input[1..]
        .iter()
        .fold(input[0], |best, &x| if x > best { x } else { best });

    // Exponentiate and accumulate the normaliser in one pass. The additions
    // happen left to right starting from 0.0, one per element.
    let mut sum = 0.0_f32;
    for (slot, &x) in output.iter_mut().zip(input) {
        let e = (x - max_val).exp();
        *slot = e;
        sum += e;
    }

    // Multiply by the reciprocal once instead of dividing per element.
    let inv_sum = 1.0 / sum;
    for slot in output.iter_mut() {
        *slot *= inv_sum;
    }
}
49
/// AVX2 softmax: same contract as the scalar kernel, with the max reduction,
/// the sum reduction and the final scaling done eight `f32` lanes at a time.
///
/// The exponentiation itself is performed element by element with
/// `f32::exp`; only the reductions and the normalisation use 256-bit vectors.
/// Elements past the last full group of eight are handled by scalar code.
///
/// # Safety
///
/// The function is compiled with `#[target_feature(enable = "avx2")]`, so the
/// caller must ensure the executing CPU supports AVX2 (for example via
/// `is_x86_feature_detected!("avx2")`).
///
/// # Panics
///
/// Panics if `input` and `output` differ in length, or if `input` is empty.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn softmax_avx2(input: &[f32], output: &mut [f32]) {
    assert_eq!(input.len(), output.len(), "input/output length mismatch");
    let n = input.len();
    assert!(n > 0, "softmax requires non-empty input");

    // Pass 1: lane-wise maximum over the full 8-element groups.
    let mut in_chunks = input.chunks_exact(8);
    // SAFETY: every `chunk` from `chunks_exact(8)` holds exactly 8 `f32`s, so
    // the unaligned 256-bit load stays in bounds; `lanes` is an 8-element
    // array, so the unaligned store does too. AVX2 availability is the
    // caller's obligation (see `# Safety`).
    let lane_max = unsafe {
        let mut max_vec = _mm256_set1_ps(f32::NEG_INFINITY);
        for chunk in in_chunks.by_ref() {
            max_vec = _mm256_max_ps(max_vec, _mm256_loadu_ps(chunk.as_ptr()));
        }
        let mut lanes = [0.0_f32; 8];
        _mm256_storeu_ps(lanes.as_mut_ptr(), max_vec);
        lanes
    };

    // Horizontal reduction of the eight lanes, then the scalar tail.
    let mut max_val = f32::NEG_INFINITY;
    for &v in lane_max.iter().chain(in_chunks.remainder()) {
        if v > max_val {
            max_val = v;
        }
    }

    // Pass 2: shifted exponentials, computed per element.
    for (slot, &x) in output.iter_mut().zip(input) {
        *slot = (x - max_val).exp();
    }

    // Pass 3: lane-wise sum of the exponentials over the full groups.
    let mut exp_chunks = output.chunks_exact(8);
    // SAFETY: same bounds argument as pass 1 — each `chunk` is exactly 8
    // `f32`s and `lanes` is an 8-element array.
    let lane_sums = unsafe {
        let mut sum_vec = _mm256_setzero_ps();
        for chunk in exp_chunks.by_ref() {
            sum_vec = _mm256_add_ps(sum_vec, _mm256_loadu_ps(chunk.as_ptr()));
        }
        let mut lanes = [0.0_f32; 8];
        _mm256_storeu_ps(lanes.as_mut_ptr(), sum_vec);
        lanes
    };

    // Lanes first, then the tail, accumulated left to right from 0.0.
    let sum = lane_sums
        .iter()
        .chain(exp_chunks.remainder())
        .fold(0.0_f32, |acc, &e| acc + e);

    // Pass 4: scale everything by the reciprocal of the sum.
    let inv_sum = 1.0 / sum;
    let mut out_chunks = output.chunks_exact_mut(8);
    // SAFETY: each `block` from `chunks_exact_mut(8)` is exactly 8 `f32`s and
    // is exclusively borrowed, so the in-place load/store pair is in bounds
    // and does not alias any other live reference.
    unsafe {
        let inv_vec = _mm256_set1_ps(inv_sum);
        for block in out_chunks.by_ref() {
            let scaled = _mm256_mul_ps(_mm256_loadu_ps(block.as_ptr()), inv_vec);
            _mm256_storeu_ps(block.as_mut_ptr(), scaled);
        }
    }
    for slot in out_chunks.into_remainder() {
        *slot *= inv_sum;
    }
}
135
136include!("softmax_ptx.rs");
137
#[cfg(test)]
mod tests {
    use super::super::ulp::assert_ulp_eq;
    use super::*;
    use proptest::prelude::*;

    /// Equal inputs must produce the uniform distribution.
    #[test]
    fn test_softmax_uniform() {
        let input = [1.0_f32, 1.0, 1.0];
        let mut output = [0.0_f32; 3];
        softmax_scalar(&input, &mut output);
        let expected = 1.0 / 3.0;
        for &o in &output {
            assert!((o - expected).abs() < 1e-6, "expected ~{expected}, got {o}");
        }
    }

    /// Two equal inputs split the mass evenly.
    #[test]
    fn test_softmax_two_equal() {
        let input = [0.0_f32, 0.0];
        let mut output = [0.0_f32; 2];
        softmax_scalar(&input, &mut output);
        for &o in &output {
            assert!((o - 0.5).abs() < 1e-6, "expected 0.5, got {o}");
        }
    }

    /// A very large input must not overflow: every output stays finite and
    /// the dominant entry takes essentially all of the mass.
    #[test]
    fn test_softmax_numerical_stability() {
        let input = [1000.0_f32, 0.0, 0.0];
        let mut output = [0.0_f32; 3];
        softmax_scalar(&input, &mut output);
        assert!(output[0].is_finite(), "output[0] must be finite");
        assert!(output[1].is_finite(), "output[1] must be finite");
        assert!(output[2].is_finite(), "output[2] must be finite");
        assert!((output[0] - 1.0).abs() < 1e-6);
    }

    /// A single element always maps to probability 1.
    #[test]
    fn test_softmax_single_element() {
        let input = [42.0_f32];
        let mut output = [0.0_f32; 1];
        softmax_scalar(&input, &mut output);
        assert!(
            (output[0] - 1.0).abs() < 1e-7,
            "softmax of single element must be 1.0"
        );
    }

    /// Mismatched slice lengths trip the length assertion.
    #[test]
    #[should_panic(expected = "input/output length mismatch")]
    fn test_softmax_length_mismatch() {
        let input = [1.0_f32, 2.0];
        let mut output = [0.0_f32; 3];
        softmax_scalar(&input, &mut output);
    }

    /// Empty input trips the non-empty assertion.
    #[test]
    #[should_panic(expected = "softmax requires non-empty input")]
    fn test_softmax_empty_input() {
        let input: [f32; 0] = [];
        let mut output: [f32; 0] = [];
        softmax_scalar(&input, &mut output);
    }

    proptest! {
        // Outputs form a probability distribution: they sum to ~1.
        #[test]
        fn prop_softmax_sums_to_one(
            v in proptest::collection::vec(-100.0_f32..100.0, 1..64)
        ) {
            let mut out = vec![0.0_f32; v.len()];
            softmax_scalar(&v, &mut out);
            let sum: f32 = out.iter().sum();
            prop_assert!(
                (sum - 1.0).abs() < 1e-5,
                "softmax sum = {sum}, expected ~1.0"
            );
        }

        // Every output lies in the closed unit interval.
        #[test]
        fn prop_softmax_outputs_in_unit_interval(
            v in proptest::collection::vec(-100.0_f32..100.0, 1..64)
        ) {
            let mut out = vec![0.0_f32; v.len()];
            softmax_scalar(&v, &mut out);
            for (i, &o) in out.iter().enumerate() {
                prop_assert!(
                    (0.0..=1.0).contains(&o),
                    "output[{i}] = {o} not in [0,1]"
                );
            }
        }

        // A strictly larger input never receives a smaller output.
        #[test]
        fn prop_softmax_order_preservation(
            v in proptest::collection::vec(-50.0_f32..50.0, 2..32)
        ) {
            let mut out = vec![0.0_f32; v.len()];
            softmax_scalar(&v, &mut out);
            for i in 0..v.len() {
                for j in (i + 1)..v.len() {
                    if v[i] > v[j] {
                        prop_assert!(
                            out[i] >= out[j],
                            "order violated: v[{i}]={} > v[{j}]={} but out[{i}]={} < out[{j}]={}",
                            v[i], v[j], out[i], out[j]
                        );
                    }
                }
            }
        }

        // Adding the same constant to every input leaves the result
        // unchanged up to a small absolute tolerance.
        #[test]
        fn prop_softmax_translation_invariance(
            v in proptest::collection::vec(-50.0_f32..50.0, 2..32),
            c in -50.0_f32..50.0
        ) {
            let mut out1 = vec![0.0_f32; v.len()];
            softmax_scalar(&v, &mut out1);

            let shifted: Vec<f32> = v.iter().map(|&x| x + c).collect();
            let mut out2 = vec![0.0_f32; v.len()];
            softmax_scalar(&shifted, &mut out2);

            for i in 0..v.len() {
                prop_assert!(
                    (out1[i] - out2[i]).abs() < 1e-5,
                    "translation invariance violated at {i}: {} vs {}",
                    out1[i], out2[i]
                );
            }
        }
    }

    /// AVX2 and scalar kernels agree on a length that is a multiple of 8.
    /// The test returns early (passing) when the CPU lacks AVX2.
    // NOTE(review): the trailing `8` is presumably a maximum ULP distance —
    // confirm against `ulp::assert_ulp_eq`.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_softmax_avx2_basic() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input = [
            1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
            16.0,
        ];
        let mut scalar_out = [0.0_f32; 16];
        let mut avx2_out = [0.0_f32; 16];
        softmax_scalar(&input, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime just above.
        unsafe { softmax_avx2(&input, &mut avx2_out) };
        assert_ulp_eq(&scalar_out, &avx2_out, 8);
    }

    /// AVX2 and scalar kernels agree when the length is shorter than one
    /// 8-lane group, so only the scalar tail paths run.
    #[cfg(target_arch = "x86_64")]
    #[test]
    fn test_softmax_avx2_non_multiple_of_8() {
        if !is_x86_feature_detected!("avx2") {
            return;
        }
        let input = [1.0_f32, 2.0, 3.0, 4.0, 5.0];
        let mut scalar_out = [0.0_f32; 5];
        let mut avx2_out = [0.0_f32; 5];
        softmax_scalar(&input, &mut scalar_out);
        // SAFETY: AVX2 support was verified at runtime just above.
        unsafe { softmax_avx2(&input, &mut avx2_out) };
        assert_ulp_eq(&scalar_out, &avx2_out, 8);
    }

    #[cfg(target_arch = "x86_64")]
    proptest! {
        // AVX2 and scalar kernels agree on random inputs of length 1..64.
        #[test]
        fn prop_softmax_avx2_parity(
            v in proptest::collection::vec(-100.0_f32..100.0, 1..64)
        ) {
            if !is_x86_feature_detected!("avx2") {
                return Ok(());
            }
            let mut scalar_out = vec![0.0_f32; v.len()];
            let mut avx2_out = vec![0.0_f32; v.len()];
            softmax_scalar(&v, &mut scalar_out);
            // SAFETY: AVX2 support was verified at runtime just above.
            unsafe { softmax_avx2(&v, &mut avx2_out) };
            assert_ulp_eq(&scalar_out, &avx2_out, 8);
        }
    }

    // The remaining tests are textual checks on the string returned by
    // `softmax_ptx()` (not defined in this module; presumably supplied by the
    // `include!`d `softmax_ptx.rs`). They only look for substrings.

    /// The PTX text declares `.version 8.5`.
    #[test]
    fn test_softmax_ptx_version() {
        let ptx = softmax_ptx();
        assert!(ptx.contains(".version 8.5"), "missing PTX version");
    }

    /// The PTX text declares `.target sm_90`.
    #[test]
    fn test_softmax_ptx_target() {
        let ptx = softmax_ptx();
        assert!(ptx.contains(".target sm_90"), "missing PTX target");
    }

    /// The PTX text contains the `softmax_kernel` entry point.
    #[test]
    fn test_softmax_ptx_entry() {
        let ptx = softmax_ptx();
        assert!(ptx.contains(".entry softmax_kernel"), "missing entry point");
    }

    /// The PTX text contains a `ret;` instruction.
    #[test]
    fn test_softmax_ptx_ret() {
        let ptx = softmax_ptx();
        assert!(ptx.contains("ret;"), "missing ret instruction");
    }

    /// The PTX text contains a `.shared` declaration.
    #[test]
    fn test_softmax_ptx_shared_memory() {
        let ptx = softmax_ptx();
        assert!(ptx.contains(".shared"), "missing shared memory declaration");
    }

    /// The PTX text contains `shfl.sync`.
    #[test]
    fn test_softmax_ptx_warp_shuffle() {
        let ptx = softmax_ptx();
        assert!(
            ptx.contains("shfl.sync"),
            "missing warp shuffle instructions"
        );
    }

    /// The PTX text contains `bar.sync`.
    #[test]
    fn test_softmax_ptx_bar_sync() {
        let ptx = softmax_ptx();
        assert!(
            ptx.contains("bar.sync"),
            "missing bar.sync for block synchronization"
        );
    }

    /// Opening and closing braces in the PTX text occur equally often.
    #[test]
    fn test_softmax_ptx_balanced_braces() {
        let ptx = softmax_ptx();
        let open = ptx.matches('{').count();
        let close = ptx.matches('}').count();
        assert_eq!(
            open, close,
            "unbalanced braces: {open} open vs {close} close"
        );
    }
}