1#[must_use]
27pub fn softmax_1d_alloc(logits: &[f32]) -> Vec<f32> {
28 let n = logits.len();
29 if n == 0 {
30 return Vec::new();
31 }
32 if n == 1 {
33 return vec![1.0];
34 }
35
36 contract_pre_softmax!(logits);
38
39 #[cfg(target_arch = "x86_64")]
40 {
41 if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
42 let result = unsafe { softmax_avx2(logits) };
44 contract_post_softmax!(&result);
45 return result;
46 }
47 }
48
49 let result = softmax_scalar(logits);
50 contract_post_softmax!(&result);
51 result
52}
53
/// Portable (non-SIMD) softmax over a 1-D slice.
///
/// Uses the usual max-subtraction form: every logit has the maximum subtracted
/// before exponentiation, so each exponent argument is `<= 0` and `exp` cannot
/// overflow for finite inputs. The exponentials are then scaled by the
/// reciprocal of their sum.
///
/// Returns a vector of the same length as `logits` (empty for empty input).
fn softmax_scalar(logits: &[f32]) -> Vec<f32> {
    // `f32::max` ignores a NaN operand, matching a plain running-max loop.
    let max_val = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    // Build the output safely instead of `set_len` on an uninitialized buffer:
    // `Vec::set_len` requires the exposed elements to be initialized already.
    let mut out: Vec<f32> = logits.iter().map(|&v| (v - max_val).exp()).collect();

    let sum: f32 = out.iter().sum();

    // Clamp the denominator so a zero (or NaN) sum cannot divide by zero.
    let inv_sum = 1.0 / sum.max(f32::EPSILON);
    for v in &mut out {
        *v *= inv_sum;
    }

    out
}
88
89#[cfg(target_arch = "x86_64")]
99#[target_feature(enable = "avx2", enable = "fma")]
100unsafe fn softmax_avx2(logits: &[f32]) -> Vec<f32> {
101 use std::arch::x86_64::*;
102
103 let n = logits.len();
104 let chunks = n / 32;
105 let remainder_32 = chunks * 32;
106
107 let mut max0;
109 let mut max1;
110 let mut max2;
111 let mut max3;
112 unsafe {
113 max0 = _mm256_set1_ps(f32::NEG_INFINITY);
114 max1 = max0;
115 max2 = max0;
116 max3 = max0;
117
118 for i in 0..chunks {
119 let base = i * 32;
120 let v0 = _mm256_loadu_ps(logits.as_ptr().add(base));
121 let v1 = _mm256_loadu_ps(logits.as_ptr().add(base + 8));
122 let v2 = _mm256_loadu_ps(logits.as_ptr().add(base + 16));
123 let v3 = _mm256_loadu_ps(logits.as_ptr().add(base + 24));
124 max0 = _mm256_max_ps(max0, v0);
125 max1 = _mm256_max_ps(max1, v1);
126 max2 = _mm256_max_ps(max2, v2);
127 max3 = _mm256_max_ps(max3, v3);
128 }
129
130 max0 = _mm256_max_ps(max0, max1);
131 max2 = _mm256_max_ps(max2, max3);
132 max0 = _mm256_max_ps(max0, max2);
133
134 let hi = _mm256_permute2f128_ps(max0, max0, 1);
135 max0 = _mm256_max_ps(max0, hi);
136 let shuf = _mm256_shuffle_ps(max0, max0, 0b01_00_11_10);
137 max0 = _mm256_max_ps(max0, shuf);
138 let shuf2 = _mm256_shuffle_ps(max0, max0, 0b10_11_00_01);
139 max0 = _mm256_max_ps(max0, shuf2);
140 }
141
142 let mut max_val = _mm_cvtss_f32(_mm256_castps256_ps128(max0));
143 for i in remainder_32..n {
144 max_val = max_val.max(logits[i]);
145 }
146
147 let mut out: Vec<f32> = Vec::with_capacity(n);
150 unsafe {
152 out.set_len(n);
153 }
154 let mut sum0;
155 let mut sum1;
156 let mut sum2;
157 let mut sum3;
158 unsafe {
159 let max_v = _mm256_set1_ps(max_val);
160 sum0 = _mm256_setzero_ps();
161 sum1 = sum0;
162 sum2 = sum0;
163 sum3 = sum0;
164
165 for i in 0..chunks {
166 let base = i * 32;
167 let x0 = _mm256_sub_ps(_mm256_loadu_ps(logits.as_ptr().add(base)), max_v);
168 let x1 = _mm256_sub_ps(_mm256_loadu_ps(logits.as_ptr().add(base + 8)), max_v);
169 let x2 = _mm256_sub_ps(_mm256_loadu_ps(logits.as_ptr().add(base + 16)), max_v);
170 let x3 = _mm256_sub_ps(_mm256_loadu_ps(logits.as_ptr().add(base + 24)), max_v);
171
172 let e0 = fast_exp_avx2(x0);
173 let e1 = fast_exp_avx2(x1);
174 let e2 = fast_exp_avx2(x2);
175 let e3 = fast_exp_avx2(x3);
176
177 _mm256_storeu_ps(out.as_mut_ptr().add(base), e0);
178 _mm256_storeu_ps(out.as_mut_ptr().add(base + 8), e1);
179 _mm256_storeu_ps(out.as_mut_ptr().add(base + 16), e2);
180 _mm256_storeu_ps(out.as_mut_ptr().add(base + 24), e3);
181
182 sum0 = _mm256_add_ps(sum0, e0);
183 sum1 = _mm256_add_ps(sum1, e1);
184 sum2 = _mm256_add_ps(sum2, e2);
185 sum3 = _mm256_add_ps(sum3, e3);
186 }
187
188 sum0 = _mm256_add_ps(sum0, sum1);
189 sum2 = _mm256_add_ps(sum2, sum3);
190 sum0 = _mm256_add_ps(sum0, sum2);
191
192 let hi = _mm256_permute2f128_ps(sum0, sum0, 1);
193 sum0 = _mm256_add_ps(sum0, hi);
194 let shuf = _mm256_shuffle_ps(sum0, sum0, 0b01_00_11_10);
195 sum0 = _mm256_add_ps(sum0, shuf);
196 let shuf2 = _mm256_shuffle_ps(sum0, sum0, 0b10_11_00_01);
197 sum0 = _mm256_add_ps(sum0, shuf2);
198 }
199
200 let mut sum_val = _mm_cvtss_f32(_mm256_castps256_ps128(sum0));
201 for i in remainder_32..n {
203 let e = (logits[i] - max_val).exp();
204 out[i] = e;
205 sum_val += e;
206 }
207
208 let inv_sum = 1.0 / sum_val.max(f32::EPSILON);
210 unsafe {
211 let inv = _mm256_set1_ps(inv_sum);
212
213 for i in 0..chunks {
214 let base = i * 32;
215 let v0 = _mm256_loadu_ps(out.as_ptr().add(base));
216 let v1 = _mm256_loadu_ps(out.as_ptr().add(base + 8));
217 let v2 = _mm256_loadu_ps(out.as_ptr().add(base + 16));
218 let v3 = _mm256_loadu_ps(out.as_ptr().add(base + 24));
219 _mm256_storeu_ps(out.as_mut_ptr().add(base), _mm256_mul_ps(v0, inv));
220 _mm256_storeu_ps(out.as_mut_ptr().add(base + 8), _mm256_mul_ps(v1, inv));
221 _mm256_storeu_ps(out.as_mut_ptr().add(base + 16), _mm256_mul_ps(v2, inv));
222 _mm256_storeu_ps(out.as_mut_ptr().add(base + 24), _mm256_mul_ps(v3, inv));
223 }
224 }
225 for i in remainder_32..n {
226 out[i] *= inv_sum;
227 }
228
229 out
230}
231
/// Approximates `e^x` independently for each of the eight `f32` lanes of `x`.
///
/// Method:
/// 1. clamp `x` to `[-87.33654, 88.72284]` (about `-126 * ln 2` and
///    `ln(f32::MAX)`),
/// 2. pick `k = floor(x * log2(e) + 0.5)`,
/// 3. form `r = x - k * ln 2`, subtracting `ln 2` as a high part plus a low
///    correction to limit rounding error,
/// 4. evaluate a degree-6 polynomial in `r` with Horner's scheme,
/// 5. multiply by `2^k`, built by placing `k + 127` in the exponent bits.
///
/// NOTE(review): `_mm256_max_ps` returns its second operand when either input
/// is NaN, so NaN lanes appear to be replaced by the lower clamp instead of
/// propagating — confirm this is intended.
///
/// NOTE(review): for inputs from roughly `127.5 * ln 2` (about 88.376) up to
/// the upper clamp, `k` is 128 and the biased exponent becomes 255, which
/// encodes `+inf` as the scale factor — confirm callers never rely on that
/// range.
///
/// # Safety
///
/// The caller must ensure the running CPU supports the `avx2` and `fma`
/// target features.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
pub(crate) unsafe fn fast_exp_avx2(x: std::arch::x86_64::__m256) -> std::arch::x86_64::__m256 {
    use std::arch::x86_64::*;

    // Horner coefficients that follow the leading one, highest degree first.
    // The two trailing `1.0` entries are the linear and constant terms.
    const LEADING_COEFF: f32 = 0.001_388_731_6;
    const REMAINING_COEFFS: [f32; 6] = [
        0.008_333_345_2,
        0.041_666_645_8,
        0.166_666_671_6,
        0.500_000_0,
        1.0,
        1.0,
    ];

    // Step 1: clamp the argument (lower bound first, then upper bound).
    let lower = _mm256_set1_ps(-87.33654);
    let upper = _mm256_set1_ps(88.72284);
    let clamped = _mm256_min_ps(_mm256_max_ps(x, lower), upper);

    // Step 2: k = floor(x * log2(e) + 0.5).
    let log2e = _mm256_set1_ps(std::f32::consts::LOG2_E);
    let half = _mm256_set1_ps(0.5);
    let k = _mm256_floor_ps(_mm256_fmadd_ps(clamped, log2e, half));

    // Step 3: r = x - k * ln2_hi - k * ln2_lo.
    let ln2_hi = _mm256_set1_ps(0.693_145_751_953_125);
    let ln2_lo = _mm256_set1_ps(1.428_606_765_330_187_1e-6);
    let reduced = _mm256_sub_ps(clamped, _mm256_mul_ps(k, ln2_hi));
    let reduced = _mm256_sub_ps(reduced, _mm256_mul_ps(k, ln2_lo));

    // Step 4: Horner evaluation, one fused multiply-add per coefficient.
    let mut poly = _mm256_set1_ps(LEADING_COEFF);
    for &coeff in REMAINING_COEFFS.iter() {
        poly = _mm256_fmadd_ps(poly, reduced, _mm256_set1_ps(coeff));
    }

    // Step 5: 2^k via the IEEE-754 exponent field (bias 127, shifted by 23).
    let biased_exponent = _mm256_add_epi32(_mm256_cvtps_epi32(k), _mm256_set1_epi32(127));
    let two_pow_k = _mm256_castsi256_ps(_mm256_slli_epi32(biased_exponent, 23));

    _mm256_mul_ps(poly, two_pow_k)
}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `len` reproducible values of the form
    /// `((i * 17 + 31) % 1000) / 1000 - 0.5`.
    fn deterministic_f32(len: usize) -> Vec<f32> {
        let mut values = Vec::with_capacity(len);
        for i in 0..len {
            values.push(((i * 17 + 31) % 1000) as f32 / 1000.0 - 0.5);
        }
        values
    }

    /// Index of the largest element (the last one on ties). Panics on empty
    /// input or if two elements cannot be ordered.
    fn argmax(values: &[f32]) -> usize {
        values
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(i, _)| i)
            .unwrap()
    }

    /// Probabilities add up to 1 for several lengths.
    #[test]
    fn test_softmax_sums_to_one() {
        for &n in &[32usize, 127, 256, 1000, 32000] {
            let result = softmax_1d_alloc(&deterministic_f32(n));
            let sum: f32 = result.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5, "sum = {sum} for n={n}, expected 1.0");
        }
    }

    /// No output is negative, even for strongly negative logits.
    #[test]
    fn test_softmax_non_negative() {
        let data: Vec<f32> = (0..1000).map(|i| -100.0 + i as f32 * 0.1).collect();
        let result = softmax_1d_alloc(&data);
        for (i, v) in result.iter().copied().enumerate() {
            assert!(v >= 0.0, "element [{i}] = {v} < 0");
        }
    }

    /// Strictly increasing logits give strictly increasing probabilities.
    #[test]
    fn test_softmax_monotonic() {
        let result = softmax_1d_alloc(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        for (k, pair) in result.windows(2).enumerate() {
            let i = k + 1;
            assert!(pair[1] > pair[0], "Not monotonic at [{i}]: {} <= {}", pair[1], pair[0]);
        }
    }

    /// Adding a constant to every logit leaves the output unchanged.
    #[test]
    fn test_softmax_shift_invariance() {
        let data = deterministic_f32(1000);
        let shifted: Vec<f32> = data.iter().map(|&x| x + 1000.0).collect();

        let result_a = softmax_1d_alloc(&data);
        let result_b = softmax_1d_alloc(&shifted);

        for i in 0..result_a.len().min(result_b.len()) {
            let (a, b) = (result_a[i], result_b[i]);
            assert!((a - b).abs() < 1e-6, "Shift invariance broken at [{i}]: {a} vs {b}");
        }
    }

    /// Identical logits give the uniform distribution `1 / n`.
    #[test]
    fn test_softmax_uniform() {
        for &n in &[4usize, 100, 1000] {
            let result = softmax_1d_alloc(&vec![std::f32::consts::PI; n]);
            let expected = 1.0 / n as f32;
            for (i, v) in result.iter().copied().enumerate() {
                assert!((v - expected).abs() < 1e-6, "Uniform at [{i}]: {v} vs {expected}");
            }
        }
    }

    /// The dispatched implementation agrees with the scalar reference.
    #[test]
    fn test_softmax_avx2_scalar_parity() {
        for &n in &[32usize, 127, 1000, 32000] {
            let data = deterministic_f32(n);
            let avx2_result = softmax_1d_alloc(&data);
            let scalar_result = softmax_scalar(&data);

            for (i, (a, s)) in avx2_result
                .iter()
                .copied()
                .zip(scalar_result.iter().copied())
                .enumerate()
            {
                assert!((a - s).abs() < 1e-6, "AVX2/scalar mismatch at [{i}] n={n}: {a} vs {s}");
            }
        }
    }

    /// Lengths around the 32-element chunk boundaries still normalize.
    #[test]
    fn test_softmax_remainder_sizes() {
        for &n in &[1usize, 2, 7, 8, 15, 31, 33, 63, 65, 127, 255] {
            let result = softmax_1d_alloc(&deterministic_f32(n));
            let sum: f32 = result.iter().sum();
            assert!((sum - 1.0).abs() < 1e-5, "sum = {sum} for n={n}, expected 1.0");
            assert_eq!(result.len(), n);
        }
    }

    /// Logits of very large magnitude produce neither NaN nor infinity.
    #[test]
    fn test_softmax_numerical_stability() {
        let mut data = vec![0.0f32; 100];
        data[0] = 88.0;
        data[50] = -88.0;

        let result = softmax_1d_alloc(&data);
        assert!(result.iter().all(|v| !v.is_nan()), "Got NaN");
        assert!(result.iter().all(|v| !v.is_infinite()), "Got Inf");
        let sum: f32 = result.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "sum = {sum}");
    }

    /// The position of the largest logit is also the largest probability.
    #[test]
    fn test_softmax_argmax_preserved() {
        let data = deterministic_f32(32000);
        let result = softmax_1d_alloc(&data);

        let input_argmax = argmax(&data);
        let output_argmax = argmax(&result);

        assert_eq!(input_argmax, output_argmax, "Argmax not preserved");
    }

    /// Empty input gives empty output.
    #[test]
    fn test_softmax_empty() {
        assert!(softmax_1d_alloc(&[]).is_empty());
    }

    /// A single logit maps to probability 1.
    #[test]
    fn test_softmax_single() {
        let result = softmax_1d_alloc(&[42.0]);
        assert_eq!(result, vec![1.0]);
    }
}