use super::{Backend, ComputeOp};
use crate::error::TruenoError;

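/// Scaled dot-product attention for a single head:
/// `softmax(Q · Kᵀ * scale) · V`, with `scale = 1 / sqrt(head_dim)`.
///
/// `seq_len` counts query rows, `kv_seq_len` counts key/value rows
/// (equal for self-attention), and `head_dim` is the per-head width.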
#[derive(Debug, Clone)]
pub struct AttentionOp {
    pub seq_len: usize,
    pub kv_seq_len: usize,
    pub head_dim: usize,
    pub scale: f32,
}

impl AttentionOp {
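    /// Creates an op with the standard scaling factor `1 / sqrt(head_dim)`.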
    #[must_use]
    pub fn new(seq_len: usize, kv_seq_len: usize, head_dim: usize) -> Self {
        Self { seq_len, kv_seq_len, head_dim, scale: 1.0 / (head_dim as f32).sqrt() }
    }

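    /// Convenience constructor for self-attention, where queries and
    /// keys/values share the same sequence length.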
    #[must_use]
    pub fn self_attention(seq_len: usize, head_dim: usize) -> Self {
        Self::new(seq_len, seq_len, head_dim)
    }

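    /// Dot product of two equal-length slices.
    ///
    /// Uses the AVX2+FMA kernel when the CPU supports it; the portable
    /// fallback keeps four independent accumulators so the additions do
    /// not form a single serial dependency chain.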
    #[inline]
    pub(crate) fn simd_dot(a: &[f32], b: &[f32]) -> f32 {
        debug_assert_eq!(a.len(), b.len());

        #[cfg(target_arch = "x86_64")]
        {
            // `avx2_dot` is compiled with FMA enabled, so both features
            // must be present before dispatching to it.
            if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
                return unsafe { Self::avx2_dot(a, b) };
            }
        }

        let mut sum0 = 0.0f32;
        let mut sum1 = 0.0f32;
        let mut sum2 = 0.0f32;
        let mut sum3 = 0.0f32;

        let chunks = a.len() / 4;
        for i in 0..chunks {
            let base = i * 4;
            sum0 += a[base] * b[base];
            sum1 += a[base + 1] * b[base + 1];
            sum2 += a[base + 2] * b[base + 2];
            sum3 += a[base + 3] * b[base + 3];
        }

        for i in (chunks * 4)..a.len() {
            sum0 += a[i] * b[i];
        }

        sum0 + sum1 + sum2 + sum3
    }

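    /// AVX2+FMA dot product kernel: 8 fused multiply-adds per iteration.
    ///
    /// # Safety
    ///
    /// The caller must ensure AVX2 and FMA are available; `simd_dot`
    /// checks this via `is_x86_feature_detected!` before dispatching here.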
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2", enable = "fma")]
    unsafe fn avx2_dot(a: &[f32], b: &[f32]) -> f32 {
        unsafe {
            use std::arch::x86_64::*;

            let mut sum = _mm256_setzero_ps();
            let chunks = a.len() / 8;

            for i in 0..chunks {
                let base = i * 8;
                let va = _mm256_loadu_ps(a.as_ptr().add(base));
                let vb = _mm256_loadu_ps(b.as_ptr().add(base));
                sum = _mm256_fmadd_ps(va, vb, sum);
            }

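            // Horizontal reduction: fold the 8 lanes down to one f32.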
            let high = _mm256_extractf128_ps(sum, 1);
            let low = _mm256_castps256_ps128(sum);
            let sum128 = _mm_add_ps(high, low);
            let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
            let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
            let mut result = _mm_cvtss_f32(sum32);

            for i in (chunks * 8)..a.len() {
                result += a[i] * b[i];
            }

            result
        }
    }

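    /// AXPY update `out[i] += alpha * x[i]`, dispatching to AVX2+FMA
    /// when available.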
    #[inline]
    pub(crate) fn simd_axpy(alpha: f32, x: &[f32], out: &mut [f32]) {
        debug_assert_eq!(x.len(), out.len());

        #[cfg(target_arch = "x86_64")]
        {
            if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
                unsafe {
                    Self::avx2_axpy(alpha, x, out);
                }
                return;
            }
        }

        for (o, &xi) in out.iter_mut().zip(x.iter()) {
            *o += alpha * xi;
        }
    }

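    /// AVX2+FMA AXPY kernel.
    ///
    /// # Safety
    ///
    /// The caller must ensure AVX2 and FMA are available; `simd_axpy`
    /// checks this before dispatching here.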
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2", enable = "fma")]
    unsafe fn avx2_axpy(alpha: f32, x: &[f32], out: &mut [f32]) {
        unsafe {
            use std::arch::x86_64::*;

            let alpha_v = _mm256_set1_ps(alpha);
            let n = x.len();
            let n8 = n / 8 * 8;

            let mut i = 0;
            while i < n8 {
                let xv = _mm256_loadu_ps(x.as_ptr().add(i));
                let ov = _mm256_loadu_ps(out.as_ptr().add(i));
                let r = _mm256_fmadd_ps(alpha_v, xv, ov);
                _mm256_storeu_ps(out.as_mut_ptr().add(i), r);
                i += 8;
            }
            while i < n {
                *out.get_unchecked_mut(i) += alpha * *x.get_unchecked(i);
                i += 1;
            }
        }
    }

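    /// In-place, numerically stable softmax over one score row: the row
    /// maximum is subtracted before exponentiating so `exp` cannot overflow.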
    #[inline]
    pub(crate) fn simd_softmax_row(scores: &mut [f32]) {
        if scores.is_empty() {
            return;
        }

        #[cfg(target_arch = "x86_64")]
        if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
            unsafe {
                Self::avx2_softmax_row(scores);
            }
            return;
        }

        Self::scalar_softmax_row(scores);
    }

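    /// Portable fallback: max, shifted exp-and-sum, then normalize.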
    fn scalar_softmax_row(scores: &mut [f32]) {
        let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let mut sum = 0.0f32;
        for s in scores.iter_mut() {
            *s = (*s - max).exp();
            sum += *s;
        }
        let inv_sum = 1.0 / sum.max(f32::EPSILON);
        for s in scores.iter_mut() {
            *s *= inv_sum;
        }
    }

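    /// AVX2+FMA softmax kernel: three vector passes (max, exp + sum,
    /// scale), using the crate's `fast_exp_avx2` vectorized exponential.
    ///
    /// # Safety
    ///
    /// The caller must ensure AVX2 and FMA are available; `simd_softmax_row`
    /// checks this before dispatching here.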
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2", enable = "fma")]
    unsafe fn avx2_softmax_row(scores: &mut [f32]) {
        unsafe {
            use std::arch::x86_64::*;

            let n = scores.len();
            let n8 = n / 8 * 8;

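            // Pass 1: running 8-lane maximum over the row.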
            let mut max_v = _mm256_set1_ps(f32::NEG_INFINITY);
            let mut i = 0;
            while i < n8 {
                let v = _mm256_loadu_ps(scores.as_ptr().add(i));
                max_v = _mm256_max_ps(max_v, v);
                i += 8;
            }
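            // Reduce the lanes to a scalar maximum, then fold in the tail.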
            let hi = _mm256_permute2f128_ps(max_v, max_v, 1);
            max_v = _mm256_max_ps(max_v, hi);
            let shuf = _mm256_shuffle_ps(max_v, max_v, 0b01_00_11_10);
            max_v = _mm256_max_ps(max_v, shuf);
            let shuf2 = _mm256_shuffle_ps(max_v, max_v, 0b10_11_00_01);
            max_v = _mm256_max_ps(max_v, shuf2);
            let mut max_val = _mm_cvtss_f32(_mm256_castps256_ps128(max_v));
            for j in n8..n {
                max_val = max_val.max(scores[j]);
            }

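            // Pass 2: exponentiate the shifted scores and accumulate their sum.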
            let max_broadcast = _mm256_set1_ps(max_val);
            let mut sum_v = _mm256_setzero_ps();
            i = 0;
            while i < n8 {
                let x = _mm256_sub_ps(_mm256_loadu_ps(scores.as_ptr().add(i)), max_broadcast);
                let e = crate::blis::softmax::fast_exp_avx2(x);
                _mm256_storeu_ps(scores.as_mut_ptr().add(i), e);
                sum_v = _mm256_add_ps(sum_v, e);
                i += 8;
            }
            let hi = _mm256_permute2f128_ps(sum_v, sum_v, 1);
            sum_v = _mm256_add_ps(sum_v, hi);
            let shuf = _mm256_shuffle_ps(sum_v, sum_v, 0b01_00_11_10);
            sum_v = _mm256_add_ps(sum_v, shuf);
            let shuf2 = _mm256_shuffle_ps(sum_v, sum_v, 0b10_11_00_01);
            sum_v = _mm256_add_ps(sum_v, shuf2);
            let mut sum_val = _mm_cvtss_f32(_mm256_castps256_ps128(sum_v));
            for j in n8..n {
                let e = (scores[j] - max_val).exp();
                scores[j] = e;
                sum_val += e;
            }

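            // Pass 3: scale by 1 / sum, clamped away from zero.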
            let inv_sum = 1.0 / sum_val.max(f32::EPSILON);
            let inv_v = _mm256_set1_ps(inv_sum);
            i = 0;
            while i < n8 {
                let v = _mm256_loadu_ps(scores.as_ptr().add(i));
                _mm256_storeu_ps(scores.as_mut_ptr().add(i), _mm256_mul_ps(v, inv_v));
                i += 8;
            }
            for j in n8..n {
                scores[j] *= inv_sum;
            }
        }
    }
}

impl ComputeOp for AttentionOp {
    type Input = (Vec<f32>, Vec<f32>, Vec<f32>);
    type Output = Vec<f32>;

    fn name(&self) -> &'static str {
        "attention"
    }

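    /// Row-streaming attention: for each query row, score it against every
    /// key row, softmax the scores, then accumulate the weighted value rows.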
    fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
        let (q, k, v) = input;

        let expected_q = self.seq_len * self.head_dim;
        let expected_kv = self.kv_seq_len * self.head_dim;

        if q.len() != expected_q {
            return Err(TruenoError::SizeMismatch { expected: expected_q, actual: q.len() });
        }
        if k.len() != expected_kv {
            return Err(TruenoError::SizeMismatch { expected: expected_kv, actual: k.len() });
        }
        if v.len() != expected_kv {
            return Err(TruenoError::SizeMismatch { expected: expected_kv, actual: v.len() });
        }

        // Zero-initialised output and per-row score buffers.
        let mut output = vec![0.0f32; expected_q];
        let mut scores = vec![0.0f32; self.kv_seq_len];

        for qi in 0..self.seq_len {
            let q_row = &q[qi * self.head_dim..(qi + 1) * self.head_dim];

            for ki in 0..self.kv_seq_len {
                let k_row = &k[ki * self.head_dim..(ki + 1) * self.head_dim];
                scores[ki] = Self::simd_dot(q_row, k_row) * self.scale;
            }

            Self::simd_softmax_row(&mut scores);

            let out_row = &mut output[qi * self.head_dim..(qi + 1) * self.head_dim];
            out_row.fill(0.0);

            for ki in 0..self.kv_seq_len {
                let v_row = &v[ki * self.head_dim..(ki + 1) * self.head_dim];
                let weight = scores[ki];

                Self::simd_axpy(weight, v_row, out_row);
            }
        }

        Ok(output)
    }

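    /// Work size reported for this op: `seq_len * head_dim` output elements.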
    fn tokens(&self, _input: &Self::Input) -> usize {
        self.seq_len * self.head_dim
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    fn assert_dot(a: &[f32], b: &[f32], expected: f32) {
        let dot = AttentionOp::simd_dot(a, b);
        assert!((dot - expected).abs() < 1e-3, "dot={dot}, expected={expected}");
    }

    fn assert_dot_iota(n: usize) {
        let a: Vec<f32> = (1..=n).map(|x| x as f32).collect();
        let b = vec![1.0f32; n];
        let expected = (n * (n + 1)) / 2;
        assert_dot(&a, &b, expected as f32);
    }

    fn assert_softmax_normalized(values: &[f32]) {
        let mut scores = values.to_vec();
        AttentionOp::simd_softmax_row(&mut scores);
        let sum: f32 = scores.iter().sum();
        assert!((sum - 1.0).abs() < 1e-5, "softmax sum={sum}");
    }

    fn assert_attention_ok(
        op: &AttentionOp,
        q: Vec<f32>,
        k: Vec<f32>,
        v: Vec<f32>,
        expected_len: usize,
    ) -> Vec<f32> {
        let output = op.execute((q, k, v), Backend::Scalar).unwrap();
        assert_eq!(output.len(), expected_len);
        for val in &output {
            assert!(val.is_finite());
        }
        output
    }

    #[test]
    fn test_attention_basic() {
        let op = AttentionOp::self_attention(2, 4);
        let q = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let k = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
        let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let output = op.execute((q, k, v), Backend::Scalar).unwrap();

        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_attention_dimension_mismatch_q() {
        let op = AttentionOp::self_attention(2, 4);
        let q = vec![1.0; 4];
        let k = vec![1.0; 8];
        let v = vec![1.0; 8];

        let result = op.execute((q, k, v), Backend::Scalar);
        assert!(result.is_err());
    }

    #[test]
    fn test_attention_dimension_mismatch_kv() {
        let op = AttentionOp::self_attention(2, 4);
        let q = vec![1.0; 8];
        let k = vec![1.0; 4];
        let v = vec![1.0; 8];

        let result = op.execute((q, k, v), Backend::Scalar);
        assert!(result.is_err());
    }

    #[test]
    fn test_attention_cross_attention() {
        let op = AttentionOp::new(1, 4, 8);
        let q = vec![1.0; 8];
        let k = vec![1.0; 32];
        let v = vec![1.0; 32];
        let output = op.execute((q, k, v), Backend::Scalar).unwrap();
        assert_eq!(output.len(), 8);
    }

    #[test]
    fn test_attention_tokens() {
        let op = AttentionOp::self_attention(16, 64);
        let input = (vec![], vec![], vec![]);
        assert_eq!(op.tokens(&input), 1024);
    }

    #[test]
    fn test_simd_softmax_row_empty() {
        let mut scores: Vec<f32> = vec![];
        AttentionOp::simd_softmax_row(&mut scores);
        assert!(scores.is_empty());
    }

    #[test]
    fn test_simd_softmax_row_single() {
        let mut scores = vec![5.0];
        AttentionOp::simd_softmax_row(&mut scores);
        assert!((scores[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_simd_softmax_row_uniform() {
        let mut scores = vec![1.0, 1.0, 1.0, 1.0];
        AttentionOp::simd_softmax_row(&mut scores);

        for s in &scores {
            assert!((s - 0.25).abs() < 1e-6);
        }
    }

    #[test]
    fn test_simd_softmax_row_sum_to_one() {
        assert_softmax_normalized(&[1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_simd_dot_basic() {
        assert_dot(&[1.0, 2.0, 3.0, 4.0], &[1.0, 1.0, 1.0, 1.0], 10.0);
    }

    #[test]
    fn test_simd_dot_unaligned() {
        assert_dot(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0; 5], 30.0);
    }

    #[test]
    fn test_attention_op_fields() {
        let op = AttentionOp::new(4, 8, 64);
        assert_eq!(op.seq_len, 4);
        assert_eq!(op.kv_seq_len, 8);
        assert_eq!(op.head_dim, 64);
        assert!((op.scale - 0.125).abs() < 1e-6);
    }

    #[test]
    fn test_attention_self_attention_fields() {
        let op = AttentionOp::self_attention(16, 32);
        assert_eq!(op.seq_len, 16);
        assert_eq!(op.kv_seq_len, 16);
        assert_eq!(op.head_dim, 32);
    }

    #[test]
    fn test_attention_name() {
        let op = AttentionOp::self_attention(1, 4);
        assert_eq!(op.name(), "attention");
    }

    #[test]
    fn test_attention_v_size_mismatch() {
        let op = AttentionOp::self_attention(2, 4);
        let q = vec![1.0; 8];
        let k = vec![1.0; 8];
        let v = vec![1.0; 4];
        let result = op.execute((q, k, v), Backend::Scalar);
        assert!(result.is_err());
    }

    #[test]
    fn test_attention_single_position() {
        let op = AttentionOp::self_attention(1, 4);
        let q = vec![1.0, 0.0, 0.0, 0.0];
        let k = vec![1.0, 0.0, 0.0, 0.0];
        let v = vec![2.0, 3.0, 4.0, 5.0];

        let output = op.execute((q, k, v), Backend::Scalar).unwrap();
        assert_eq!(output.len(), 4);
        assert!((output[0] - 2.0).abs() < 1e-5);
        assert!((output[1] - 3.0).abs() < 1e-5);
        assert!((output[2] - 4.0).abs() < 1e-5);
        assert!((output[3] - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_attention_uniform_scores() {
        let op = AttentionOp::new(1, 2, 2);
        let head_dim = 2;

        let q = vec![1.0, 1.0];
        let k = vec![1.0, 1.0, 1.0, 1.0];
        let v = vec![1.0, 0.0, 0.0, 1.0];
        let output = op.execute((q, k, v), Backend::Scalar).unwrap();
        assert_eq!(output.len(), head_dim);
        assert!((output[0] - 0.5).abs() < 1e-5);
        assert!((output[1] - 0.5).abs() < 1e-5);
    }

    #[test]
    fn test_simd_dot_exact_multiple_of_four() {
        assert_dot_iota(8);
    }

    #[test]
    fn test_simd_dot_single_element() {
        assert_dot(&[3.0], &[4.0], 12.0);
    }

    #[test]
    fn test_simd_dot_two_elements() {
        assert_dot(&[2.0, 3.0], &[4.0, 5.0], 23.0);
    }

    #[test]
    fn test_simd_dot_three_elements() {
        assert_dot(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], 32.0);
    }

    #[test]
    fn test_simd_dot_large_avx2_aligned() {
        assert_dot_iota(16);
    }

    #[test]
    fn test_simd_dot_large_avx2_remainder() {
        assert_dot_iota(19);
    }

    #[test]
    fn test_simd_dot_zeros() {
        assert_dot(&[0.0; 16], &[1.0; 16], 0.0);
    }

    #[test]
    fn test_simd_dot_negative_values() {
        assert_dot(&[-1.0, -2.0, -3.0, -4.0], &[1.0; 4], -10.0);
    }

    #[test]
    fn test_simd_softmax_row_large_values() {
        assert_softmax_normalized(&[1000.0, 1001.0, 1002.0]);
    }

    #[test]
    fn test_simd_softmax_row_negative_values() {
        assert_softmax_normalized(&[-10.0, -20.0, -5.0]);
    }

    #[test]
    fn test_attention_clone() {
        let op = AttentionOp::new(4, 8, 64);
        let cloned = op.clone();
        assert_eq!(cloned.seq_len, 4);
        assert_eq!(cloned.kv_seq_len, 8);
        assert_eq!(cloned.head_dim, 64);
        assert!((cloned.scale - op.scale).abs() < 1e-10);
    }

    #[test]
    fn test_attention_multi_query_rows() {
        let op = AttentionOp::new(3, 2, 2);
        let q = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
        let k = vec![1.0, 0.0, 0.0, 1.0];
        let v = vec![10.0, 20.0, 30.0, 40.0];
        assert_attention_ok(&op, q, k, v, 6);
    }

    #[test]
    fn test_attention_tokens_cross_attention() {
        let op = AttentionOp::new(1, 100, 64);
        assert_eq!(op.tokens(&(vec![], vec![], vec![])), 64);
    }

    #[test]
    fn test_simd_dot_avx2_remainders() {
        for n in [9, 10, 15, 24, 5, 6, 7] {
            assert_dot_iota(n);
        }
    }

    #[test]
    fn test_simd_dot_large_64_elements() {
        assert_dot_iota(64);
    }

    #[test]
    fn test_simd_dot_orthogonal() {
        let mut a = vec![0.0; 9];
        let mut b = vec![0.0; 9];
        a[0] = 1.0;
        b[1] = 1.0;
        assert_dot(&a, &b, 0.0);
    }

    #[test]
    fn test_attention_execute_non_aligned_head_dim() {
        let op = AttentionOp::self_attention(2, 9);
        let output = assert_attention_ok(&op, vec![1.0; 18], vec![1.0; 18], vec![1.0; 18], 18);
        for val in &output {
            assert!((val - 1.0).abs() < 1e-4);
        }
    }

    #[test]
    fn test_attention_execute_head_dim_17() {
        let op = AttentionOp::new(1, 3, 17);
        let q: Vec<f32> = (0..17).map(|i| (i as f32) * 0.1).collect();
        let k: Vec<f32> = (0..51).map(|i| ((i % 5) as f32) * 0.2).collect();
        let v: Vec<f32> = (0..51).map(|i| (i as f32) * 0.01).collect();
        assert_attention_ok(&op, q, k, v, 17);
    }

    fn assert_dot_scalar_ref(n: usize) {
        let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.3 + 1.0).collect();
        let b: Vec<f32> = (0..n).map(|i| (i as f32) * 0.7 - 0.5).collect();
        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let result = AttentionOp::simd_dot(&a, &b);
        assert!(
            (result - expected).abs() < 1e-2 * expected.abs().max(1.0),
            "n={n}: dot={result}, expected={expected}"
        );
    }

    #[test]
    fn test_simd_dot_avx2_remainder_0() {
        assert_dot_scalar_ref(32);
    }

    #[test]
    fn test_simd_dot_avx2_remainder_1() {
        assert_dot_scalar_ref(33);
    }

    #[test]
    fn test_simd_dot_avx2_remainder_2() {
        assert_dot_scalar_ref(34);
    }

    #[test]
    fn test_simd_dot_avx2_remainder_3() {
        assert_dot_scalar_ref(35);
    }

    #[test]
    fn test_simd_dot_avx2_remainder_4() {
        assert_dot_scalar_ref(36);
    }

    #[test]
    fn test_simd_dot_avx2_remainder_5() {
        assert_dot_scalar_ref(37);
    }

    #[test]
    fn test_simd_dot_avx2_remainder_6() {
        assert_dot_scalar_ref(38);
    }

    #[test]
    fn test_simd_dot_avx2_remainder_7() {
        assert_dot_scalar_ref(39);
    }

    #[test]
    fn test_simd_dot_large_128() {
        assert_dot_scalar_ref(128);
    }

    #[test]
    fn test_simd_dot_large_1024() {
        assert_dot_scalar_ref(1024);
    }

    #[test]
    fn test_simd_dot_large_1024_plus_5() {
        assert_dot_scalar_ref(1029);
    }

    #[test]
    fn test_simd_dot_known_identity() {
        let n = 64;
        let a: Vec<f32> = {
            let mut v = vec![0.0; n];
            v[0] = 1.0;
            v
        };
        let b = a.clone();
        let result = AttentionOp::simd_dot(&a, &b);
        assert!((result - 1.0).abs() < 1e-6, "identity dot = {result}");
    }

    #[test]
    fn test_simd_dot_alternating_signs() {
        let n = 64;
        let a: Vec<f32> = (0..n).map(|i| if i % 2 == 0 { 1.0 } else { -1.0 }).collect();
        let b = vec![1.0; n];
        let result = AttentionOp::simd_dot(&a, &b);
        assert!((result).abs() < 1e-5, "alternating dot = {result}");
    }

    #[test]
    fn test_simd_dot_large_values() {
        let a = vec![1000.0; 16];
        let b = vec![1000.0; 16];
        let expected = 1000.0 * 1000.0 * 16.0;
        let result = AttentionOp::simd_dot(&a, &b);
        assert!((result - expected).abs() < 1.0, "large dot = {result}, expected = {expected}");
    }

    #[test]
    fn test_simd_dot_mixed_positive_negative() {
        let a = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
        let b = vec![10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
        let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let result = AttentionOp::simd_dot(&a, &b);
        assert!((result - expected).abs() < 1e-3, "mixed dot = {result}, expected = {expected}");
    }

    #[test]
    fn test_simd_dot_very_small_values() {
        let a = vec![1e-10; 16];
        let b = vec![1e-10; 16];
        let expected = 1e-20 * 16.0;
        let result = AttentionOp::simd_dot(&a, &b);
        assert!((result - expected).abs() < 1e-24, "small dot = {result}, expected = {expected}");
    }

    #[test]
    fn test_attention_head_dim_64_multi_seq() {
        let op = AttentionOp::self_attention(4, 64);
        let q = vec![0.1; 4 * 64];
        let k = vec![0.1; 4 * 64];
        let v = vec![1.0; 4 * 64];
        let output = assert_attention_ok(&op, q, k, v, 4 * 64);
        for val in &output {
            assert!((val - 1.0).abs() < 1e-4, "expected ~1.0, got {val}");
        }
    }

    #[test]
    fn test_attention_head_dim_128() {
        let op = AttentionOp::new(2, 3, 128);
        let q: Vec<f32> = (0..2 * 128).map(|i| (i as f32) * 0.001).collect();
        let k: Vec<f32> = (0..3 * 128).map(|i| ((i % 7) as f32) * 0.01).collect();
        let v: Vec<f32> = (0..3 * 128).map(|i| (i as f32) * 0.005).collect();
        assert_attention_ok(&op, q, k, v, 2 * 128);
    }

    #[test]
    fn test_attention_head_dim_33() {
        let op = AttentionOp::new(2, 2, 33);
        let q = vec![0.5; 2 * 33];
        let k = vec![0.5; 2 * 33];
        let v = vec![2.0; 2 * 33];
        let output = assert_attention_ok(&op, q, k, v, 2 * 33);
        for val in &output {
            assert!((val - 2.0).abs() < 1e-4, "expected ~2.0, got {val}");
        }
    }

    #[test]
    fn test_attention_head_dim_7() {
        let op = AttentionOp::self_attention(2, 7);
        let q = vec![1.0; 2 * 7];
        let k = vec![1.0; 2 * 7];
        let v = vec![3.0; 2 * 7];
        let output = assert_attention_ok(&op, q, k, v, 2 * 7);
        for val in &output {
            assert!((val - 3.0).abs() < 1e-4, "expected ~3.0, got {val}");
        }
    }

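    // Falsification tests: each asserts a mathematical invariant of softmax
    // attention (labelled ATT-00x in the failure messages) and fails loudly
    // if the implementation violates it.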
    #[test]
    fn falsify_att_001_weight_normalization() {
        let test_rows: Vec<Vec<f32>> = vec![
            vec![1.0, 2.0, 3.0, 4.0],
            vec![-5.0, 0.0, 5.0, 10.0],
            vec![1000.0, 1001.0, 1002.0],
            vec![1e-7, 1e-7, 1e-7],
            vec![0.0; 8],
            vec![-100.0, 100.0],
        ];

        for values in &test_rows {
            let mut scores = values.clone();
            AttentionOp::simd_softmax_row(&mut scores);
            let sum: f32 = scores.iter().sum();
            assert!(
                (sum - 1.0).abs() < 1e-5,
                "FALSIFIED ATT-001: softmax row sum = {sum}, expected 1.0 for input {values:?}"
            );
        }
    }

    #[test]
    fn falsify_att_002_output_convexity() {
        let seq_len = 2;
        let kv_seq_len = 3;
        let head_dim = 4;
        let op = AttentionOp::new(seq_len, kv_seq_len, head_dim);

        let q = vec![1.0, 0.5, -0.3, 0.8, -1.0, 0.2, 0.7, -0.5];
        let k = vec![0.3, -0.7, 1.0, 0.2, -0.5, 0.8, 0.1, -0.3, 0.6, -0.1, 0.4, 0.9];
        let v = vec![2.0, -3.0, 5.0, 1.0, -1.0, 4.0, -2.0, 7.0, 3.0, 0.0, -4.0, 6.0];

        let output = op.execute((q, k, v.clone()), Backend::Scalar).unwrap();

        for qi in 0..seq_len {
            for d in 0..head_dim {
                let out_val = output[qi * head_dim + d];

                let v_col_min =
                    (0..kv_seq_len).map(|ki| v[ki * head_dim + d]).fold(f32::INFINITY, f32::min);
                let v_col_max = (0..kv_seq_len)
                    .map(|ki| v[ki * head_dim + d])
                    .fold(f32::NEG_INFINITY, f32::max);

                assert!(
                    out_val >= v_col_min - 1e-5 && out_val <= v_col_max + 1e-5,
                    "FALSIFIED ATT-002: output[{qi}][{d}] = {out_val} outside V column [{v_col_min}, {v_col_max}]"
                );
            }
        }
    }

    #[test]
    fn falsify_att_003_scaling_factor() {
        for d_k in [4, 8, 16, 32, 64, 128] {
            let op = AttentionOp::self_attention(1, d_k);
            let expected = 1.0 / (d_k as f32).sqrt();
            assert!(
                (op.scale - expected).abs() < 1e-6,
                "FALSIFIED ATT-003: scale = {}, expected 1/√{d_k} = {expected}",
                op.scale
            );
            if d_k > 1 {
                let wrong = 1.0 / d_k as f32;
                assert!(
                    (op.scale - wrong).abs() > 1e-6,
                    "FALSIFIED ATT-003: scale matches wrong 1/{d_k} = {wrong}",
                );
            }
        }
    }

    #[test]
    fn falsify_att_005_weights_bounded() {
        let test_rows: Vec<Vec<f32>> = vec![
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            vec![-5.0, 0.0, 5.0],
            vec![0.0, 0.0, 0.0, 0.0],
            vec![1e-10, 1e-10],
            vec![-10.0, -10.0, -10.0],
            vec![20.0, 20.5, 21.0],
        ];

        for values in &test_rows {
            let mut scores = values.clone();
            AttentionOp::simd_softmax_row(&mut scores);
            for (j, &w) in scores.iter().enumerate() {
                assert!(
                    w > 0.0,
                    "FALSIFIED ATT-005: weight[{j}] = {w} not > 0 for input {values:?}"
                );
                assert!(
                    w < 1.0,
                    "FALSIFIED ATT-005: weight[{j}] = {w} not < 1 for input {values:?} (m >= 2)"
                );
            }
        }
    }

    #[test]
    fn falsify_att_002b_uniform_v_identity() {
        let op = AttentionOp::new(2, 4, 8);
        let q: Vec<f32> = (0..16).map(|i| (i as f32) * 0.37).collect();
        let k: Vec<f32> = (0..32).map(|i| (i as f32) * 0.13).collect();
        let v_row = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let v: Vec<f32> = v_row.iter().copied().cycle().take(32).collect();

        let output = op.execute((q, k, v), Backend::Scalar).unwrap();

        for qi in 0..2 {
            for d in 0..8 {
                let diff = (output[qi * 8 + d] - v_row[d]).abs();
                assert!(
                    diff < 1e-5,
                    "FALSIFIED ATT-002: uniform V output[{qi}][{d}] = {}, expected {}",
                    output[qi * 8 + d],
                    v_row[d]
                );
            }
        }
    }
}