// File: trueno/brick/attention.rs
1//! SIMD-Optimized Attention Operation (PMAT-017)
2//!
3//! This module contains the scaled dot-product attention operation
4//! with SIMD optimization for CPU inference.
5//!
6//! # Algorithm
7//!
8//! Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
9//!
10//! # SIMD Optimizations
11//!
12//! - Q @ K^T: Batched dot products with AVX2/AVX-512/FMA
13//! - Softmax: Row-wise numerically stable implementation
14//! - Scores @ V: SIMD-friendly weighted accumulation
15//!
16//! # Performance Target
17//!
18//! Close the 1.66x gap in CPU inference (25.4 → 42 tok/s) by replacing
19//! scalar triple-nested loops with SIMD operations.
20
21use super::{Backend, ComputeOp};
22use crate::error::TruenoError;
23
24/// Scaled dot-product attention operation.
25///
26/// Computes: Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V
27///
28/// # SIMD Optimization (PMAT-017)
29///
30/// Uses trueno's SIMD backends for:
31/// - Q @ K^T: Batched dot products with AVX2/AVX-512
32/// - Softmax: Row-wise numerically stable softmax
33/// - Scores @ V: Batched weighted sums
34///
35/// # Performance Target
36///
37/// Close the 1.66x gap in CPU inference (25.4 → 42 tok/s) by replacing
38/// scalar triple-nested loops with SIMD operations.
#[derive(Debug, Clone)]
pub struct AttentionOp {
    /// Query sequence length (number of rows in Q).
    pub seq_len: usize,
    /// Key/Value sequence length; may differ from `seq_len` for cross-attention.
    pub kv_seq_len: usize,
    /// Dimension per attention head (length of each Q/K/V row).
    pub head_dim: usize,
    /// Softmax temperature scale, precomputed by `new` as 1/sqrt(head_dim).
    pub scale: f32,
}
50
51impl AttentionOp {
52    /// Create a new attention operation.
53    ///
54    /// # Arguments
55    ///
56    /// * `seq_len` - Query sequence length
57    /// * `kv_seq_len` - Key/Value sequence length
58    /// * `head_dim` - Dimension per head
59    #[must_use]
60    pub fn new(seq_len: usize, kv_seq_len: usize, head_dim: usize) -> Self {
61        Self { seq_len, kv_seq_len, head_dim, scale: 1.0 / (head_dim as f32).sqrt() }
62    }
63
64    /// Create for self-attention (seq_len == kv_seq_len).
65    #[must_use]
66    pub fn self_attention(seq_len: usize, head_dim: usize) -> Self {
67        Self::new(seq_len, seq_len, head_dim)
68    }
69
70    /// SIMD-optimized dot product for attention scores.
71    ///
72    /// Computes Q[i] · K[j] using SIMD when available.
73    #[inline]
74    pub(crate) fn simd_dot(a: &[f32], b: &[f32]) -> f32 {
75        debug_assert_eq!(a.len(), b.len());
76
77        // Use architecture-specific SIMD
78        #[cfg(target_arch = "x86_64")]
79        {
80            if is_x86_feature_detected!("avx2") {
81                // SAFETY: preconditions verified by caller
82                return unsafe { Self::avx2_dot(a, b) };
83            }
84        }
85
86        // Scalar fallback with manual unrolling for better vectorization
87        let mut sum0 = 0.0f32;
88        let mut sum1 = 0.0f32;
89        let mut sum2 = 0.0f32;
90        let mut sum3 = 0.0f32;
91
92        let chunks = a.len() / 4;
93        for i in 0..chunks {
94            let base = i * 4;
95            sum0 += a[base] * b[base];
96            sum1 += a[base + 1] * b[base + 1];
97            sum2 += a[base + 2] * b[base + 2];
98            sum3 += a[base + 3] * b[base + 3];
99        }
100
101        // Handle remainder
102        for i in (chunks * 4)..a.len() {
103            sum0 += a[i] * b[i];
104        }
105
106        sum0 + sum1 + sum2 + sum3
107    }
108
109    /// AVX2-optimized dot product.
110    #[cfg(target_arch = "x86_64")]
111    #[target_feature(enable = "avx2", enable = "fma")]
112    // SAFETY: caller verifies AVX2 support, input slices meet alignment/length requirements
113    unsafe fn avx2_dot(a: &[f32], b: &[f32]) -> f32 {
114        unsafe {
115            use std::arch::x86_64::*;
116
117            let mut sum = _mm256_setzero_ps();
118            let chunks = a.len() / 8;
119
120            for i in 0..chunks {
121                let base = i * 8;
122                let va = _mm256_loadu_ps(a.as_ptr().add(base));
123                let vb = _mm256_loadu_ps(b.as_ptr().add(base));
124                sum = _mm256_fmadd_ps(va, vb, sum);
125            }
126
127            // Horizontal sum
128            let high = _mm256_extractf128_ps(sum, 1);
129            let low = _mm256_castps256_ps128(sum);
130            let sum128 = _mm_add_ps(high, low);
131            let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
132            let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
133            let mut result = _mm_cvtss_f32(sum32);
134
135            // Handle remainder
136            for i in (chunks * 8)..a.len() {
137                result += a[i] * b[i];
138            }
139
140            result
141        }
142    }
143
144    /// SIMD axpy: out[i] += alpha * x[i] for all i.
145    /// Used in attention weighted sum: out += weight * v_row.
146    #[inline]
147    pub(crate) fn simd_axpy(alpha: f32, x: &[f32], out: &mut [f32]) {
148        debug_assert_eq!(x.len(), out.len());
149
150        #[cfg(target_arch = "x86_64")]
151        {
152            if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
153                // SAFETY: AVX2+FMA verified, slices same length
154                unsafe {
155                    Self::avx2_axpy(alpha, x, out);
156                }
157                return;
158            }
159        }
160
161        // Scalar fallback
162        for (o, &xi) in out.iter_mut().zip(x.iter()) {
163            *o += alpha * xi;
164        }
165    }
166
167    /// AVX2-optimized axpy: out[i] += alpha * x[i].
168    #[cfg(target_arch = "x86_64")]
169    #[target_feature(enable = "avx2", enable = "fma")]
170    unsafe fn avx2_axpy(alpha: f32, x: &[f32], out: &mut [f32]) {
171        unsafe {
172            use std::arch::x86_64::*;
173
174            let alpha_v = _mm256_set1_ps(alpha);
175            let n = x.len();
176            let n8 = n / 8 * 8;
177
178            let mut i = 0;
179            while i < n8 {
180                let xv = _mm256_loadu_ps(x.as_ptr().add(i));
181                let ov = _mm256_loadu_ps(out.as_ptr().add(i));
182                let r = _mm256_fmadd_ps(alpha_v, xv, ov);
183                _mm256_storeu_ps(out.as_mut_ptr().add(i), r);
184                i += 8;
185            }
186            // Scalar remainder
187            while i < n {
188                *out.get_unchecked_mut(i) += alpha * *x.get_unchecked(i);
189                i += 1;
190            }
191        }
192    }
193
194    /// Row-wise softmax with AVX2 SIMD exp + horizontal ops.
195    /// CGP-DBUF: replaced scalar exp() loop with AVX2 polynomial fast_exp
196    /// from blis::softmax (6th-order Remez minimax, <1 ULP error).
197    /// For seq_len=512: 64 AVX2 iterations vs 512 scalar exp() calls.
198    #[inline]
199    pub(crate) fn simd_softmax_row(scores: &mut [f32]) {
200        if scores.is_empty() {
201            return;
202        }
203
204        #[cfg(target_arch = "x86_64")]
205        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
206            // SAFETY: AVX2+FMA verified. scores is valid mutable slice.
207            unsafe {
208                Self::avx2_softmax_row(scores);
209            }
210            return;
211        }
212
213        // Scalar fallback
214        Self::scalar_softmax_row(scores);
215    }
216
217    /// Scalar softmax fallback.
218    fn scalar_softmax_row(scores: &mut [f32]) {
219        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
220        let mut sum = 0.0f32;
221        for s in scores.iter_mut() {
222            *s = (*s - max).exp();
223            sum += *s;
224        }
225        let inv_sum = 1.0 / sum.max(f32::EPSILON);
226        for s in scores.iter_mut() {
227            *s *= inv_sum;
228        }
229    }
230
231    /// AVX2 softmax: fused exp+sum in one pass, then SIMD normalize.
232    #[cfg(target_arch = "x86_64")]
233    #[target_feature(enable = "avx2", enable = "fma")]
234    unsafe fn avx2_softmax_row(scores: &mut [f32]) {
235        unsafe {
236            use std::arch::x86_64::*;
237
238            let n = scores.len();
239            let n8 = n / 8 * 8;
240
241            // Pass 1: find max (AVX2 horizontal max)
242            let mut max_v = _mm256_set1_ps(f32::NEG_INFINITY);
243            let mut i = 0;
244            while i < n8 {
245                let v = _mm256_loadu_ps(scores.as_ptr().add(i));
246                max_v = _mm256_max_ps(max_v, v);
247                i += 8;
248            }
249            // Horizontal reduce max_v
250            let hi = _mm256_permute2f128_ps(max_v, max_v, 1);
251            max_v = _mm256_max_ps(max_v, hi);
252            let shuf = _mm256_shuffle_ps(max_v, max_v, 0b01_00_11_10);
253            max_v = _mm256_max_ps(max_v, shuf);
254            let shuf2 = _mm256_shuffle_ps(max_v, max_v, 0b10_11_00_01);
255            max_v = _mm256_max_ps(max_v, shuf2);
256            let mut max_val = _mm_cvtss_f32(_mm256_castps256_ps128(max_v));
257            for j in n8..n {
258                max_val = max_val.max(scores[j]);
259            }
260
261            // Pass 2: fused exp(x-max) + sum (using fast_exp_avx2)
262            let max_broadcast = _mm256_set1_ps(max_val);
263            let mut sum_v = _mm256_setzero_ps();
264            i = 0;
265            while i < n8 {
266                let x = _mm256_sub_ps(_mm256_loadu_ps(scores.as_ptr().add(i)), max_broadcast);
267                let e = crate::blis::softmax::fast_exp_avx2(x);
268                _mm256_storeu_ps(scores.as_mut_ptr().add(i), e);
269                sum_v = _mm256_add_ps(sum_v, e);
270                i += 8;
271            }
272            // Horizontal reduce sum_v
273            let hi = _mm256_permute2f128_ps(sum_v, sum_v, 1);
274            sum_v = _mm256_add_ps(sum_v, hi);
275            let shuf = _mm256_shuffle_ps(sum_v, sum_v, 0b01_00_11_10);
276            sum_v = _mm256_add_ps(sum_v, shuf);
277            let shuf2 = _mm256_shuffle_ps(sum_v, sum_v, 0b10_11_00_01);
278            sum_v = _mm256_add_ps(sum_v, shuf2);
279            let mut sum_val = _mm_cvtss_f32(_mm256_castps256_ps128(sum_v));
280            // Scalar remainder for exp+sum
281            for j in n8..n {
282                let e = (scores[j] - max_val).exp();
283                scores[j] = e;
284                sum_val += e;
285            }
286
287            // Pass 3: normalize (SIMD multiply by 1/sum)
288            let inv_sum = 1.0 / sum_val.max(f32::EPSILON);
289            let inv_v = _mm256_set1_ps(inv_sum);
290            i = 0;
291            while i < n8 {
292                let v = _mm256_loadu_ps(scores.as_ptr().add(i));
293                _mm256_storeu_ps(scores.as_mut_ptr().add(i), _mm256_mul_ps(v, inv_v));
294                i += 8;
295            }
296            for j in n8..n {
297                scores[j] *= inv_sum;
298            }
299        } // unsafe
300    }
301}
302
303impl ComputeOp for AttentionOp {
304    /// Input: (Q, K, V) tensors as flat vectors
305    /// Q: [seq_len * head_dim]
306    /// K: [kv_seq_len * head_dim]
307    /// V: [kv_seq_len * head_dim]
308    type Input = (Vec<f32>, Vec<f32>, Vec<f32>);
309    /// Output: attention output [seq_len * head_dim]
310    type Output = Vec<f32>;
311
312    fn name(&self) -> &'static str {
313        "attention"
314    }
315
316    fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
317        let (q, k, v) = input;
318
319        // Validate dimensions
320        let expected_q = self.seq_len * self.head_dim;
321        let expected_kv = self.kv_seq_len * self.head_dim;
322
323        if q.len() != expected_q {
324            return Err(TruenoError::SizeMismatch { expected: expected_q, actual: q.len() });
325        }
326        if k.len() != expected_kv || v.len() != expected_kv {
327            return Err(TruenoError::SizeMismatch { expected: expected_kv, actual: k.len() });
328        }
329
330        // Uninit: output is zeroed per-row via out_row.fill(0.0) before accumulation.
331        // scores is SET via scores[ki] = simd_dot(...) before softmax reads.
332        let mut output: Vec<f32> = Vec::with_capacity(expected_q);
333        // SAFETY: Each qi loop iteration calls out_row.fill(0.0) before accumulating.
334        unsafe {
335            output.set_len(expected_q);
336        }
337        let mut scores: Vec<f32> = Vec::with_capacity(self.kv_seq_len);
338        // SAFETY: scores[ki] = ... (SET) for all ki before simd_softmax_row reads.
339        unsafe {
340            scores.set_len(self.kv_seq_len);
341        }
342
343        // For each query position
344        for qi in 0..self.seq_len {
345            let q_row = &q[qi * self.head_dim..(qi + 1) * self.head_dim];
346
347            // Compute Q[qi] · K[ki] for all ki (SIMD dot products)
348            for ki in 0..self.kv_seq_len {
349                let k_row = &k[ki * self.head_dim..(ki + 1) * self.head_dim];
350                scores[ki] = Self::simd_dot(q_row, k_row) * self.scale;
351            }
352
353            // Softmax over scores
354            Self::simd_softmax_row(&mut scores);
355
356            // Compute weighted sum: output[qi] = sum(scores[ki] * V[ki])
357            let out_row = &mut output[qi * self.head_dim..(qi + 1) * self.head_dim];
358            out_row.fill(0.0);
359
360            for ki in 0..self.kv_seq_len {
361                let v_row = &v[ki * self.head_dim..(ki + 1) * self.head_dim];
362                let weight = scores[ki];
363
364                // CGP-DBUF: AVX2 broadcast-multiply-add (was scalar loop)
365                Self::simd_axpy(weight, v_row, out_row);
366            }
367        }
368
369        Ok(output)
370    }
371
372    fn tokens(&self, _input: &Self::Input) -> usize {
373        // Output tokens = seq_len * head_dim
374        self.seq_len * self.head_dim
375    }
376}
377
378#[cfg(test)]
379mod tests {
380    use super::*;
381
382    /// Assert simd_dot of two slices equals expected within tolerance.
383    fn assert_dot(a: &[f32], b: &[f32], expected: f32) {
384        let dot = AttentionOp::simd_dot(a, b);
385        assert!((dot - expected).abs() < 1e-3, "dot={dot}, expected={expected}");
386    }
387
388    /// Assert simd_dot of [1..=n] · [1.0; n] equals n*(n+1)/2.
389    fn assert_dot_iota(n: usize) {
390        let a: Vec<f32> = (1..=n).map(|x| x as f32).collect();
391        let b = vec![1.0f32; n];
392        let expected = (n * (n + 1)) / 2;
393        assert_dot(&a, &b, expected as f32);
394    }
395
396    /// Assert softmax normalizes scores to sum=1.
397    fn assert_softmax_normalized(values: &[f32]) {
398        let mut scores = values.to_vec();
399        AttentionOp::simd_softmax_row(&mut scores);
400        let sum: f32 = scores.iter().sum();
401        assert!((sum - 1.0).abs() < 1e-5, "softmax sum={sum}");
402    }
403
404    /// Execute attention and assert output length and finiteness.
405    fn assert_attention_ok(
406        op: &AttentionOp,
407        q: Vec<f32>,
408        k: Vec<f32>,
409        v: Vec<f32>,
410        expected_len: usize,
411    ) -> Vec<f32> {
412        let output = op.execute((q, k, v), Backend::Scalar).unwrap();
413        assert_eq!(output.len(), expected_len);
414        for val in &output {
415            assert!(val.is_finite());
416        }
417        output
418    }
419
420    #[test]
421    fn test_attention_basic() {
422        let op = AttentionOp::self_attention(2, 4); // seq=2, head_dim=4
423
424        // Simple identity-like setup
425        let q = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]; // 2x4
426        let k = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]; // 2x4
427        let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; // 2x4
428
429        let output = op.execute((q, k, v), Backend::Scalar).unwrap();
430
431        assert_eq!(output.len(), 8);
432        // Output should be weighted combination of V rows
433    }
434
435    #[test]
436    fn test_attention_dimension_mismatch_q() {
437        let op = AttentionOp::self_attention(2, 4);
438        let q = vec![1.0; 4]; // Wrong size - should be 8
439        let k = vec![1.0; 8];
440        let v = vec![1.0; 8];
441
442        let result = op.execute((q, k, v), Backend::Scalar);
443        assert!(result.is_err());
444    }
445
446    #[test]
447    fn test_attention_dimension_mismatch_kv() {
448        let op = AttentionOp::self_attention(2, 4);
449        let q = vec![1.0; 8];
450        let k = vec![1.0; 4]; // Wrong size - should be 8
451        let v = vec![1.0; 8];
452
453        let result = op.execute((q, k, v), Backend::Scalar);
454        assert!(result.is_err());
455    }
456
457    #[test]
458    fn test_attention_cross_attention() {
459        // Cross-attention: Q from decoder (seq=1), K/V from encoder (seq=4)
460        let op = AttentionOp::new(1, 4, 8); // q_seq=1, kv_seq=4, head_dim=8
461
462        let q = vec![1.0; 8]; // 1 x 8
463        let k = vec![1.0; 32]; // 4 x 8
464        let v = vec![1.0; 32]; // 4 x 8
465
466        let output = op.execute((q, k, v), Backend::Scalar).unwrap();
467        assert_eq!(output.len(), 8);
468    }
469
470    #[test]
471    fn test_attention_tokens() {
472        let op = AttentionOp::self_attention(16, 64);
473        let input = (vec![], vec![], vec![]);
474        // tokens = seq_len * head_dim = 16 * 64 = 1024
475        assert_eq!(op.tokens(&input), 1024);
476    }
477
478    #[test]
479    fn test_simd_softmax_row_empty() {
480        let mut scores: Vec<f32> = vec![];
481        AttentionOp::simd_softmax_row(&mut scores);
482        assert!(scores.is_empty());
483    }
484
485    #[test]
486    fn test_simd_softmax_row_single() {
487        let mut scores = vec![5.0];
488        AttentionOp::simd_softmax_row(&mut scores);
489        assert!((scores[0] - 1.0).abs() < 1e-6);
490    }
491
492    #[test]
493    fn test_simd_softmax_row_uniform() {
494        let mut scores = vec![1.0, 1.0, 1.0, 1.0];
495        AttentionOp::simd_softmax_row(&mut scores);
496
497        // All equal inputs → uniform distribution
498        for s in &scores {
499            assert!((s - 0.25).abs() < 1e-6);
500        }
501    }
502
503    #[test]
504    fn test_simd_softmax_row_sum_to_one() {
505        assert_softmax_normalized(&[1.0, 2.0, 3.0, 4.0, 5.0]);
506    }
507
508    #[test]
509    fn test_simd_dot_basic() {
510        assert_dot(&[1.0, 2.0, 3.0, 4.0], &[1.0, 1.0, 1.0, 1.0], 10.0);
511    }
512
513    #[test]
514    fn test_simd_dot_unaligned() {
515        assert_dot(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0; 5], 30.0);
516    }
517
518    // =========================================================================
519    // Additional Coverage Tests
520    // =========================================================================
521
522    #[test]
523    fn test_attention_op_fields() {
524        let op = AttentionOp::new(4, 8, 64);
525        assert_eq!(op.seq_len, 4);
526        assert_eq!(op.kv_seq_len, 8);
527        assert_eq!(op.head_dim, 64);
528        // scale = 1/sqrt(64) = 1/8 = 0.125
529        assert!((op.scale - 0.125).abs() < 1e-6);
530    }
531
532    #[test]
533    fn test_attention_self_attention_fields() {
534        let op = AttentionOp::self_attention(16, 32);
535        assert_eq!(op.seq_len, 16);
536        assert_eq!(op.kv_seq_len, 16); // Self-attention: same lengths
537        assert_eq!(op.head_dim, 32);
538    }
539
540    #[test]
541    fn test_attention_name() {
542        let op = AttentionOp::self_attention(1, 4);
543        assert_eq!(op.name(), "attention");
544    }
545
546    #[test]
547    fn test_attention_v_size_mismatch() {
548        let op = AttentionOp::self_attention(2, 4);
549        let q = vec![1.0; 8];
550        let k = vec![1.0; 8];
551        let v = vec![1.0; 4]; // Wrong: should be 8
552
553        let result = op.execute((q, k, v), Backend::Scalar);
554        assert!(result.is_err());
555    }
556
557    #[test]
558    fn test_attention_single_position() {
559        // seq=1, kv=1, head_dim=4
560        let op = AttentionOp::self_attention(1, 4);
561        let q = vec![1.0, 0.0, 0.0, 0.0];
562        let k = vec![1.0, 0.0, 0.0, 0.0];
563        let v = vec![2.0, 3.0, 4.0, 5.0];
564
565        let output = op.execute((q, k, v), Backend::Scalar).unwrap();
566        assert_eq!(output.len(), 4);
567        // With single position, softmax of single score is 1.0
568        // Output = 1.0 * V = V
569        assert!((output[0] - 2.0).abs() < 1e-5);
570        assert!((output[1] - 3.0).abs() < 1e-5);
571        assert!((output[2] - 4.0).abs() < 1e-5);
572        assert!((output[3] - 5.0).abs() < 1e-5);
573    }
574
575    #[test]
576    fn test_attention_uniform_scores() {
577        // If Q and K are identical for all positions, scores are equal
578        // Output should be average of V rows
579        let op = AttentionOp::new(1, 2, 2);
580        let head_dim = 2;
581
582        let q = vec![1.0, 1.0]; // 1x2
583        let k = vec![1.0, 1.0, 1.0, 1.0]; // 2x2, both identical
584        let v = vec![1.0, 0.0, 0.0, 1.0]; // 2x2
585
586        let output = op.execute((q, k, v), Backend::Scalar).unwrap();
587        assert_eq!(output.len(), head_dim);
588        // Scores are equal => softmax gives [0.5, 0.5]
589        // Output = 0.5 * [1, 0] + 0.5 * [0, 1] = [0.5, 0.5]
590        assert!((output[0] - 0.5).abs() < 1e-5);
591        assert!((output[1] - 0.5).abs() < 1e-5);
592    }
593
594    #[test]
595    fn test_simd_dot_exact_multiple_of_four() {
596        assert_dot_iota(8); // sum(1..=8) = 36
597    }
598
599    #[test]
600    fn test_simd_dot_single_element() {
601        assert_dot(&[3.0], &[4.0], 12.0);
602    }
603
604    #[test]
605    fn test_simd_dot_two_elements() {
606        assert_dot(&[2.0, 3.0], &[4.0, 5.0], 23.0);
607    }
608
609    #[test]
610    fn test_simd_dot_three_elements() {
611        assert_dot(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], 32.0);
612    }
613
614    #[test]
615    fn test_simd_dot_large_avx2_aligned() {
616        assert_dot_iota(16); // sum(1..=16) = 136
617    }
618
619    #[test]
620    fn test_simd_dot_large_avx2_remainder() {
621        assert_dot_iota(19); // sum(1..=19) = 190
622    }
623
624    #[test]
625    fn test_simd_dot_zeros() {
626        assert_dot(&[0.0; 16], &[1.0; 16], 0.0);
627    }
628
629    #[test]
630    fn test_simd_dot_negative_values() {
631        assert_dot(&[-1.0, -2.0, -3.0, -4.0], &[1.0; 4], -10.0);
632    }
633
634    #[test]
635    fn test_simd_softmax_row_large_values() {
636        assert_softmax_normalized(&[1000.0, 1001.0, 1002.0]);
637    }
638
639    #[test]
640    fn test_simd_softmax_row_negative_values() {
641        assert_softmax_normalized(&[-10.0, -20.0, -5.0]);
642    }
643
644    #[test]
645    fn test_attention_clone() {
646        let op = AttentionOp::new(4, 8, 64);
647        let cloned = op.clone();
648        assert_eq!(cloned.seq_len, 4);
649        assert_eq!(cloned.kv_seq_len, 8);
650        assert_eq!(cloned.head_dim, 64);
651        assert!((cloned.scale - op.scale).abs() < 1e-10);
652    }
653
654    #[test]
655    fn test_attention_multi_query_rows() {
656        let op = AttentionOp::new(3, 2, 2);
657        let q = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
658        let k = vec![1.0, 0.0, 0.0, 1.0];
659        let v = vec![10.0, 20.0, 30.0, 40.0];
660        assert_attention_ok(&op, q, k, v, 6);
661    }
662
663    #[test]
664    fn test_attention_tokens_cross_attention() {
665        let op = AttentionOp::new(1, 100, 64);
666        assert_eq!(op.tokens(&(vec![], vec![], vec![])), 64);
667    }
668
669    // simd_dot coverage — AVX2 remainder and various non-aligned sizes
670
671    #[test]
672    fn test_simd_dot_avx2_remainders() {
673        // Test various sizes: 1 chunk+1, +2, +7, 3 chunks exact, sub-chunk
674        for n in [9, 10, 15, 24, 5, 6, 7] {
675            assert_dot_iota(n);
676        }
677    }
678
679    #[test]
680    fn test_simd_dot_large_64_elements() {
681        assert_dot_iota(64); // sum(1..=64) = 2080
682    }
683
684    #[test]
685    fn test_simd_dot_orthogonal() {
686        let mut a = vec![0.0; 9];
687        let mut b = vec![0.0; 9];
688        a[0] = 1.0;
689        b[1] = 1.0;
690        assert_dot(&a, &b, 0.0);
691    }
692
693    #[test]
694    fn test_attention_execute_non_aligned_head_dim() {
695        let op = AttentionOp::self_attention(2, 9);
696        let output = assert_attention_ok(&op, vec![1.0; 18], vec![1.0; 18], vec![1.0; 18], 18);
697        // Uniform Q/K → uniform softmax → output = mean of V rows = 1.0
698        for val in &output {
699            assert!((val - 1.0).abs() < 1e-4);
700        }
701    }
702
703    #[test]
704    fn test_attention_execute_head_dim_17() {
705        let op = AttentionOp::new(1, 3, 17);
706        let q: Vec<f32> = (0..17).map(|i| (i as f32) * 0.1).collect();
707        let k: Vec<f32> = (0..51).map(|i| ((i % 5) as f32) * 0.2).collect();
708        let v: Vec<f32> = (0..51).map(|i| (i as f32) * 0.01).collect();
709        assert_attention_ok(&op, q, k, v, 17);
710    }
711
712    // =========================================================================
713    // simd_dot coverage: AVX2 path with every remainder size (Refs CB-130)
714    // =========================================================================
715
716    /// Verify simd_dot with vectors of exactly `n` elements where each element
717    /// is a known value, checking against a scalar reference implementation.
718    fn assert_dot_scalar_ref(n: usize) {
719        let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.3 + 1.0).collect();
720        let b: Vec<f32> = (0..n).map(|i| (i as f32) * 0.7 - 0.5).collect();
721        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
722        let result = AttentionOp::simd_dot(&a, &b);
723        assert!(
724            (result - expected).abs() < 1e-2 * expected.abs().max(1.0),
725            "n={n}: dot={result}, expected={expected}"
726        );
727    }
728
729    #[test]
730    fn test_simd_dot_avx2_remainder_0() {
731        // 32 elements: exactly 4 AVX2 chunks, 0 remainder
732        assert_dot_scalar_ref(32);
733    }
734
735    #[test]
736    fn test_simd_dot_avx2_remainder_1() {
737        // 33 elements: 4 AVX2 chunks + 1 remainder
738        assert_dot_scalar_ref(33);
739    }
740
741    #[test]
742    fn test_simd_dot_avx2_remainder_2() {
743        // 34 elements: 4 AVX2 chunks + 2 remainder
744        assert_dot_scalar_ref(34);
745    }
746
747    #[test]
748    fn test_simd_dot_avx2_remainder_3() {
749        // 35 elements: 4 AVX2 chunks + 3 remainder
750        assert_dot_scalar_ref(35);
751    }
752
753    #[test]
754    fn test_simd_dot_avx2_remainder_4() {
755        // 36 elements: 4 AVX2 chunks + 4 remainder
756        assert_dot_scalar_ref(36);
757    }
758
759    #[test]
760    fn test_simd_dot_avx2_remainder_5() {
761        // 37 elements: 4 AVX2 chunks + 5 remainder
762        assert_dot_scalar_ref(37);
763    }
764
765    #[test]
766    fn test_simd_dot_avx2_remainder_6() {
767        // 38 elements: 4 AVX2 chunks + 6 remainder
768        assert_dot_scalar_ref(38);
769    }
770
771    #[test]
772    fn test_simd_dot_avx2_remainder_7() {
773        // 39 elements: 4 AVX2 chunks + 7 remainder
774        assert_dot_scalar_ref(39);
775    }
776
777    #[test]
778    fn test_simd_dot_large_128() {
779        // 128 elements: 16 AVX2 chunks, exercises sustained SIMD loop
780        assert_dot_scalar_ref(128);
781    }
782
783    #[test]
784    fn test_simd_dot_large_1024() {
785        // 1024 elements: 128 AVX2 chunks, large vector stress test
786        assert_dot_scalar_ref(1024);
787    }
788
789    #[test]
790    fn test_simd_dot_large_1024_plus_5() {
791        // 1029 elements: 128 AVX2 chunks + 5 remainder, large + non-aligned
792        assert_dot_scalar_ref(1029);
793    }
794
795    #[test]
796    fn test_simd_dot_known_identity() {
797        // Unit vector dot product = 1.0
798        let n = 64;
799        let a: Vec<f32> = {
800            let mut v = vec![0.0; n];
801            v[0] = 1.0;
802            v
803        };
804        let b = a.clone();
805        let result = AttentionOp::simd_dot(&a, &b);
806        assert!((result - 1.0).abs() < 1e-6, "identity dot = {result}");
807    }
808
809    #[test]
810    fn test_simd_dot_alternating_signs() {
811        // Alternating +1/-1 should cancel to 0 for even length
812        let n = 64;
813        let a: Vec<f32> = (0..n).map(|i| if i % 2 == 0 { 1.0 } else { -1.0 }).collect();
814        let b = vec![1.0; n];
815        let result = AttentionOp::simd_dot(&a, &b);
816        assert!((result).abs() < 1e-5, "alternating dot = {result}");
817    }
818
819    #[test]
820    fn test_simd_dot_large_values() {
821        // Large values should still compute correctly
822        let a = vec![1000.0; 16];
823        let b = vec![1000.0; 16];
824        let expected = 1000.0 * 1000.0 * 16.0;
825        let result = AttentionOp::simd_dot(&a, &b);
826        assert!((result - expected).abs() < 1.0, "large dot = {result}, expected = {expected}");
827    }
828
829    #[test]
830    fn test_simd_dot_mixed_positive_negative() {
831        // 10 elements: 1 AVX2 chunk (8) + 2 remainder, with mixed signs
832        let a = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
833        let b = vec![10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
834        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
835        let result = AttentionOp::simd_dot(&a, &b);
836        assert!((result - expected).abs() < 1e-3, "mixed dot = {result}, expected = {expected}");
837    }
838
839    #[test]
840    fn test_simd_dot_very_small_values() {
841        let a = vec![1e-10; 16];
842        let b = vec![1e-10; 16];
843        let expected = 1e-20 * 16.0;
844        let result = AttentionOp::simd_dot(&a, &b);
845        assert!((result - expected).abs() < 1e-24, "small dot = {result}, expected = {expected}");
846    }
847
848    // =========================================================================
849    // Full attention execute path with head_dim sizes that stress simd_dot
850    // =========================================================================
851
852    #[test]
853    fn test_attention_head_dim_64_multi_seq() {
854        // head_dim=64 (8 AVX2 chunks exactly): realistic transformer config
855        let op = AttentionOp::self_attention(4, 64);
856        let q = vec![0.1; 4 * 64];
857        let k = vec![0.1; 4 * 64];
858        let v = vec![1.0; 4 * 64];
859        let output = assert_attention_ok(&op, q, k, v, 4 * 64);
860        // Uniform input => uniform softmax => output = mean of V rows = 1.0
861        for val in &output {
862            assert!((val - 1.0).abs() < 1e-4, "expected ~1.0, got {val}");
863        }
864    }
865
866    #[test]
867    fn test_attention_head_dim_128() {
868        // head_dim=128 (16 AVX2 chunks): large head dimension
869        let op = AttentionOp::new(2, 3, 128);
870        let q: Vec<f32> = (0..2 * 128).map(|i| (i as f32) * 0.001).collect();
871        let k: Vec<f32> = (0..3 * 128).map(|i| ((i % 7) as f32) * 0.01).collect();
872        let v: Vec<f32> = (0..3 * 128).map(|i| (i as f32) * 0.005).collect();
873        assert_attention_ok(&op, q, k, v, 2 * 128);
874    }
875
876    #[test]
877    fn test_attention_head_dim_33() {
878        // head_dim=33: 4 AVX2 chunks + 1 remainder element in simd_dot
879        let op = AttentionOp::new(2, 2, 33);
880        let q = vec![0.5; 2 * 33];
881        let k = vec![0.5; 2 * 33];
882        let v = vec![2.0; 2 * 33];
883        let output = assert_attention_ok(&op, q, k, v, 2 * 33);
884        for val in &output {
885            assert!((val - 2.0).abs() < 1e-4, "expected ~2.0, got {val}");
886        }
887    }
888
889    #[test]
890    fn test_attention_head_dim_7() {
891        // head_dim=7: 0 AVX2 chunks, all remainder (exercises pure remainder path)
892        let op = AttentionOp::self_attention(2, 7);
893        let q = vec![1.0; 2 * 7];
894        let k = vec![1.0; 2 * 7];
895        let v = vec![3.0; 2 * 7];
896        let output = assert_attention_ok(&op, q, k, v, 2 * 7);
897        for val in &output {
898            assert!((val - 3.0).abs() < 1e-4, "expected ~3.0, got {val}");
899        }
900    }
901
902    // =========================================================================
903    // FALSIFY-ATT: attention-kernel-v1.yaml contract (trueno AttentionOp)
904    //
905    // Five-Whys (PMAT-354):
906    //   Why 1: trueno had 50+ attention unit tests but zero FALSIFY-ATT-* tests
907    //   Why 2: unit tests verify shapes/finiteness, not mathematical invariants
908    //   Why 3: no mapping from attention-kernel-v1.yaml to trueno test names
909    //   Why 4: trueno predates the provable-contracts YAML convention
910    //   Why 5: attention was "obviously correct" (standard formula)
911    //
912    // References:
913    //   - provable-contracts/contracts/attention-kernel-v1.yaml
914    //   - Vaswani et al. (2017) "Attention Is All You Need"
915    // =========================================================================
916
917    /// FALSIFY-ATT-001: Weight normalization — each softmax row sums to 1.0
918    ///
919    /// Contract: Σ_j softmax(QK^T/√d_k)_{ij} = 1 for all i
920    #[test]
921    fn falsify_att_001_weight_normalization() {
922        let test_rows: Vec<Vec<f32>> = vec![
923            vec![1.0, 2.0, 3.0, 4.0],
924            vec![-5.0, 0.0, 5.0, 10.0],
925            vec![1000.0, 1001.0, 1002.0],
926            vec![1e-7, 1e-7, 1e-7],
927            vec![0.0; 8],
928            vec![-100.0, 100.0],
929        ];
930
931        for values in &test_rows {
932            let mut scores = values.clone();
933            AttentionOp::simd_softmax_row(&mut scores);
934            let sum: f32 = scores.iter().sum();
935            assert!(
936                (sum - 1.0).abs() < 1e-5,
937                "FALSIFIED ATT-001: softmax row sum = {sum}, expected 1.0 for input {values:?}"
938            );
939        }
940    }
941
942    /// FALSIFY-ATT-002: Output convexity — output rows are convex combinations of V rows
943    ///
944    /// Contract: min_j(V[j][d]) ≤ output[i][d] ≤ max_j(V[j][d]) for all i, d
945    #[test]
946    fn falsify_att_002_output_convexity() {
947        let seq_len = 2;
948        let kv_seq_len = 3;
949        let head_dim = 4;
950        let op = AttentionOp::new(seq_len, kv_seq_len, head_dim);
951
952        let q = vec![1.0, 0.5, -0.3, 0.8, -1.0, 0.2, 0.7, -0.5];
953        let k = vec![0.3, -0.7, 1.0, 0.2, -0.5, 0.8, 0.1, -0.3, 0.6, -0.1, 0.4, 0.9];
954        let v = vec![2.0, -3.0, 5.0, 1.0, -1.0, 4.0, -2.0, 7.0, 3.0, 0.0, -4.0, 6.0];
955
956        let output = op.execute((q, k, v.clone()), Backend::Scalar).unwrap();
957
958        for qi in 0..seq_len {
959            for d in 0..head_dim {
960                let out_val = output[qi * head_dim + d];
961
962                let v_col_min =
963                    (0..kv_seq_len).map(|ki| v[ki * head_dim + d]).fold(f32::INFINITY, f32::min);
964                let v_col_max = (0..kv_seq_len)
965                    .map(|ki| v[ki * head_dim + d])
966                    .fold(f32::NEG_INFINITY, f32::max);
967
968                assert!(
969                    out_val >= v_col_min - 1e-5 && out_val <= v_col_max + 1e-5,
970                    "FALSIFIED ATT-002: output[{qi}][{d}] = {out_val} outside V column [{v_col_min}, {v_col_max}]"
971                );
972            }
973        }
974    }
975
976    /// FALSIFY-ATT-003: Scaling factor — uses 1/√d_k not 1/d_k
977    ///
978    /// Contract: scale = 1/√d_k
979    #[test]
980    fn falsify_att_003_scaling_factor() {
981        for d_k in [4, 8, 16, 32, 64, 128] {
982            let op = AttentionOp::self_attention(1, d_k);
983            let expected = 1.0 / (d_k as f32).sqrt();
984            assert!(
985                (op.scale - expected).abs() < 1e-6,
986                "FALSIFIED ATT-003: scale = {}, expected 1/√{d_k} = {expected}",
987                op.scale
988            );
989            // Verify it's NOT the wrong 1/d_k scaling
990            if d_k > 1 {
991                let wrong = 1.0 / d_k as f32;
992                assert!(
993                    (op.scale - wrong).abs() > 1e-6,
994                    "FALSIFIED ATT-003: scale matches wrong 1/{d_k} = {wrong}",
995                );
996            }
997        }
998    }
999
1000    /// FALSIFY-ATT-005: Weights bounded — all attention weights in [0, 1)
1001    ///
1002    /// Contract: 0 < attn_{ij} < 1 for all i,j in exact arithmetic.
1003    /// In f32, exp(-200) underflows to 0.0 so we test w >= 0 and w < 1.
1004    /// For moderate inputs (max gap < 80), strict w > 0 holds.
1005    #[test]
1006    fn falsify_att_005_weights_bounded() {
1007        // Moderate-range inputs where exp() doesn't underflow to 0
1008        let test_rows: Vec<Vec<f32>> = vec![
1009            vec![1.0, 2.0, 3.0, 4.0, 5.0],
1010            vec![-5.0, 0.0, 5.0],
1011            vec![0.0, 0.0, 0.0, 0.0],
1012            vec![1e-10, 1e-10],
1013            vec![-10.0, -10.0, -10.0],
1014            vec![20.0, 20.5, 21.0],
1015        ];
1016
1017        for values in &test_rows {
1018            let mut scores = values.clone();
1019            AttentionOp::simd_softmax_row(&mut scores);
1020            for (j, &w) in scores.iter().enumerate() {
1021                assert!(
1022                    w > 0.0,
1023                    "FALSIFIED ATT-005: weight[{j}] = {w} not > 0 for input {values:?}"
1024                );
1025                assert!(
1026                    w < 1.0,
1027                    "FALSIFIED ATT-005: weight[{j}] = {w} not < 1 for input {values:?} (m >= 2)"
1028                );
1029            }
1030        }
1031    }
1032
1033    /// FALSIFY-ATT-002b: Convexity with uniform V — output must equal V
1034    ///
1035    /// If all V rows are identical, output = V regardless of Q, K
1036    #[test]
1037    fn falsify_att_002b_uniform_v_identity() {
1038        let op = AttentionOp::new(2, 4, 8);
1039        let q: Vec<f32> = (0..16).map(|i| (i as f32) * 0.37).collect();
1040        let k: Vec<f32> = (0..32).map(|i| (i as f32) * 0.13).collect();
1041        // All 4 V rows are identical: [1, 2, 3, 4, 5, 6, 7, 8]
1042        let v_row = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
1043        let v: Vec<f32> = v_row.iter().copied().cycle().take(32).collect();
1044
1045        let output = op.execute((q, k, v), Backend::Scalar).unwrap();
1046
1047        for qi in 0..2 {
1048            for d in 0..8 {
1049                let diff = (output[qi * 8 + d] - v_row[d]).abs();
1050                assert!(
1051                    diff < 1e-5,
1052                    "FALSIFIED ATT-002: uniform V output[{qi}][{d}] = {}, expected {}",
1053                    output[qi * 8 + d],
1054                    v_row[d]
1055                );
1056            }
1057        }
1058    }
1059}