use super::{Backend, ComputeOp};
use crate::error::TruenoError;
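/// Scaled dot-product attention for a single head:
/// `output = softmax(Q Kᵀ · scale) V` with `scale = 1 / sqrt(head_dim)`.
///
/// `Q` is `[seq_len, head_dim]`; `K` and `V` are `[kv_seq_len, head_dim]`,
/// all stored row-major as flat `Vec<f32>` buffers.
///
/// Minimal usage sketch (`ignore`d: needs the surrounding crate's
/// `Backend` and error types):
///
/// ```ignore
/// let op = AttentionOp::self_attention(2, 4);
/// let (q, k, v) = (vec![0.0; 8], vec![0.0; 8], vec![1.0; 8]);
/// let output = op.execute((q, k, v), Backend::Scalar).unwrap();
/// assert_eq!(output.len(), 8);
/// ```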
#[derive(Debug, Clone)]
pub struct AttentionOp {
pub seq_len: usize,
pub kv_seq_len: usize,
pub head_dim: usize,
pub scale: f32,
}
impl AttentionOp {
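/// Creates an op with separate query and key/value sequence lengths;
/// `scale` is initialized to `1 / sqrt(head_dim)`.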
#[must_use]
pub fn new(seq_len: usize, kv_seq_len: usize, head_dim: usize) -> Self {
Self { seq_len, kv_seq_len, head_dim, scale: 1.0 / (head_dim as f32).sqrt() }
}
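/// Convenience constructor for self-attention, where
/// `kv_seq_len == seq_len`.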
#[must_use]
pub fn self_attention(seq_len: usize, head_dim: usize) -> Self {
Self::new(seq_len, seq_len, head_dim)
}
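/// Dot product of two equal-length slices (lengths are checked only in
/// debug builds).
///
/// On x86_64 the AVX2+FMA kernel is used when the CPU supports it;
/// otherwise a portable scalar loop with four independent accumulators
/// (for instruction-level parallelism) handles the whole slice.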
#[inline]
pub(crate) fn simd_dot(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len());
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") {
return unsafe { Self::avx2_dot(a, b) };
}
}
let mut sum0 = 0.0f32;
let mut sum1 = 0.0f32;
let mut sum2 = 0.0f32;
let mut sum3 = 0.0f32;
let chunks = a.len() / 4;
for i in 0..chunks {
let base = i * 4;
sum0 += a[base] * b[base];
sum1 += a[base + 1] * b[base + 1];
sum2 += a[base + 2] * b[base + 2];
sum3 += a[base + 3] * b[base + 3];
}
for i in (chunks * 4)..a.len() {
sum0 += a[i] * b[i];
}
sum0 + sum1 + sum2 + sum3
}
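/// AVX2 dot-product kernel: 8 f32 lanes per fused multiply-add,
/// plus a scalar loop for the tail elements.
///
/// # Safety
/// The CPU must support AVX2 and FMA; `simd_dot` verifies both with
/// `is_x86_feature_detected!` before calling this.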
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn avx2_dot(a: &[f32], b: &[f32]) -> f32 {
unsafe {
use std::arch::x86_64::*;
let mut sum = _mm256_setzero_ps();
let chunks = a.len() / 8;
for i in 0..chunks {
let base = i * 8;
let va = _mm256_loadu_ps(a.as_ptr().add(base));
let vb = _mm256_loadu_ps(b.as_ptr().add(base));
sum = _mm256_fmadd_ps(va, vb, sum);
}
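// Horizontal reduction: fold the 256-bit accumulator to 128 bits,
// then to 64, then to a single lane.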
let high = _mm256_extractf128_ps(sum, 1);
let low = _mm256_castps256_ps128(sum);
let sum128 = _mm_add_ps(high, low);
let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
let mut result = _mm_cvtss_f32(sum32);
for i in (chunks * 8)..a.len() {
result += a[i] * b[i];
}
result
}
}
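/// In-place numerically stable softmax over one score row: subtracts
/// the row max before exponentiating so `exp` cannot overflow, then
/// normalizes with a guard against a zero sum.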
#[inline]
pub(crate) fn simd_softmax_row(scores: &mut [f32]) {
if scores.is_empty() {
return;
}
let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0f32;
for s in scores.iter_mut() {
*s = (*s - max).exp();
sum += *s;
}
let inv_sum = 1.0 / sum.max(f32::EPSILON);
for s in scores.iter_mut() {
*s *= inv_sum;
}
}
}
impl ComputeOp for AttentionOp {
type Input = (Vec<f32>, Vec<f32>, Vec<f32>);
type Output = Vec<f32>;
fn name(&self) -> &'static str {
"attention"
}
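/// Computes `softmax(Q Kᵀ · scale) V` after validating buffer sizes.
/// The `_backend` hint is currently unused; SIMD dispatch happens
/// inside `simd_dot` via runtime feature detection.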
fn execute(&self, input: Self::Input, _backend: Backend) -> Result<Self::Output, TruenoError> {
let (q, k, v) = input;
let expected_q = self.seq_len * self.head_dim;
let expected_kv = self.kv_seq_len * self.head_dim;
if q.len() != expected_q {
return Err(TruenoError::SizeMismatch { expected: expected_q, actual: q.len() });
}
if k.len() != expected_kv {
return Err(TruenoError::SizeMismatch { expected: expected_kv, actual: k.len() });
}
if v.len() != expected_kv {
return Err(TruenoError::SizeMismatch { expected: expected_kv, actual: v.len() });
}
let mut output = vec![0.0f32; expected_q];
let mut scores = vec![0.0f32; self.kv_seq_len];
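// For each query row: score against every key row, softmax-normalize,
// then accumulate the weighted V rows into the output row.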
for qi in 0..self.seq_len {
let q_row = &q[qi * self.head_dim..(qi + 1) * self.head_dim];
for ki in 0..self.kv_seq_len {
let k_row = &k[ki * self.head_dim..(ki + 1) * self.head_dim];
scores[ki] = Self::simd_dot(q_row, k_row) * self.scale;
}
Self::simd_softmax_row(&mut scores);
let out_row = &mut output[qi * self.head_dim..(qi + 1) * self.head_dim];
out_row.fill(0.0);
for ki in 0..self.kv_seq_len {
let v_row = &v[ki * self.head_dim..(ki + 1) * self.head_dim];
let weight = scores[ki];
for (o, &vi) in out_row.iter_mut().zip(v_row.iter()) {
*o += weight * vi;
}
}
}
Ok(output)
}
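/// Reports `seq_len * head_dim`, i.e. one token per output element.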
fn tokens(&self, _input: &Self::Input) -> usize {
self.seq_len * self.head_dim
}
}
#[cfg(test)]
mod tests {
use super::*;
fn assert_dot(a: &[f32], b: &[f32], expected: f32) {
let dot = AttentionOp::simd_dot(a, b);
assert!((dot - expected).abs() < 1e-3, "dot={dot}, expected={expected}");
}
fn assert_dot_iota(n: usize) {
let a: Vec<f32> = (1..=n).map(|x| x as f32).collect();
let b = vec![1.0f32; n];
let expected = (n * (n + 1)) / 2;
assert_dot(&a, &b, expected as f32);
}
fn assert_softmax_normalized(values: &[f32]) {
let mut scores = values.to_vec();
AttentionOp::simd_softmax_row(&mut scores);
let sum: f32 = scores.iter().sum();
assert!((sum - 1.0).abs() < 1e-5, "softmax sum={sum}");
}
fn assert_attention_ok(
op: &AttentionOp,
q: Vec<f32>,
k: Vec<f32>,
v: Vec<f32>,
expected_len: usize,
) -> Vec<f32> {
let output = op.execute((q, k, v), Backend::Scalar).unwrap();
assert_eq!(output.len(), expected_len);
for val in &output {
assert!(val.is_finite());
}
output
}
#[test]
fn test_attention_basic() {
let op = AttentionOp::self_attention(2, 4);
let q = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
let k = vec![1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0];
let v = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let output = op.execute((q, k, v), Backend::Scalar).unwrap();
assert_eq!(output.len(), 8);
}
#[test]
fn test_attention_dimension_mismatch_q() {
let op = AttentionOp::self_attention(2, 4);
let q = vec![1.0; 4];
let k = vec![1.0; 8];
let v = vec![1.0; 8];
let result = op.execute((q, k, v), Backend::Scalar);
assert!(result.is_err());
}
#[test]
fn test_attention_dimension_mismatch_kv() {
let op = AttentionOp::self_attention(2, 4);
let q = vec![1.0; 8];
let k = vec![1.0; 4];
let v = vec![1.0; 8];
let result = op.execute((q, k, v), Backend::Scalar);
assert!(result.is_err());
}
#[test]
fn test_attention_cross_attention() {
let op = AttentionOp::new(1, 4, 8);
let q = vec![1.0; 8];
let k = vec![1.0; 32];
let v = vec![1.0; 32];
let output = op.execute((q, k, v), Backend::Scalar).unwrap();
assert_eq!(output.len(), 8);
}
#[test]
fn test_attention_tokens() {
let op = AttentionOp::self_attention(16, 64);
let input = (vec![], vec![], vec![]);
assert_eq!(op.tokens(&input), 1024);
}
#[test]
fn test_simd_softmax_row_empty() {
let mut scores: Vec<f32> = vec![];
AttentionOp::simd_softmax_row(&mut scores);
assert!(scores.is_empty());
}
#[test]
fn test_simd_softmax_row_single() {
let mut scores = vec![5.0];
AttentionOp::simd_softmax_row(&mut scores);
assert!((scores[0] - 1.0).abs() < 1e-6);
}
#[test]
fn test_simd_softmax_row_uniform() {
let mut scores = vec![1.0, 1.0, 1.0, 1.0];
AttentionOp::simd_softmax_row(&mut scores);
for s in &scores {
assert!((s - 0.25).abs() < 1e-6);
}
}
#[test]
fn test_simd_softmax_row_sum_to_one() {
assert_softmax_normalized(&[1.0, 2.0, 3.0, 4.0, 5.0]);
}
#[test]
fn test_simd_dot_basic() {
assert_dot(&[1.0, 2.0, 3.0, 4.0], &[1.0, 1.0, 1.0, 1.0], 10.0);
}
#[test]
fn test_simd_dot_unaligned() {
assert_dot(&[1.0, 2.0, 3.0, 4.0, 5.0], &[2.0; 5], 30.0);
}
#[test]
fn test_attention_op_fields() {
let op = AttentionOp::new(4, 8, 64);
assert_eq!(op.seq_len, 4);
assert_eq!(op.kv_seq_len, 8);
assert_eq!(op.head_dim, 64);
assert!((op.scale - 0.125).abs() < 1e-6);
}
#[test]
fn test_attention_self_attention_fields() {
let op = AttentionOp::self_attention(16, 32);
assert_eq!(op.seq_len, 16);
assert_eq!(op.kv_seq_len, 16);
assert_eq!(op.head_dim, 32);
}
#[test]
fn test_attention_name() {
let op = AttentionOp::self_attention(1, 4);
assert_eq!(op.name(), "attention");
}
#[test]
fn test_attention_v_size_mismatch() {
let op = AttentionOp::self_attention(2, 4);
let q = vec![1.0; 8];
let k = vec![1.0; 8];
let v = vec![1.0; 4];
let result = op.execute((q, k, v), Backend::Scalar);
assert!(result.is_err());
}
#[test]
fn test_attention_single_position() {
let op = AttentionOp::self_attention(1, 4);
let q = vec![1.0, 0.0, 0.0, 0.0];
let k = vec![1.0, 0.0, 0.0, 0.0];
let v = vec![2.0, 3.0, 4.0, 5.0];
let output = op.execute((q, k, v), Backend::Scalar).unwrap();
assert_eq!(output.len(), 4);
assert!((output[0] - 2.0).abs() < 1e-5);
assert!((output[1] - 3.0).abs() < 1e-5);
assert!((output[2] - 4.0).abs() < 1e-5);
assert!((output[3] - 5.0).abs() < 1e-5);
}
#[test]
fn test_attention_uniform_scores() {
let op = AttentionOp::new(1, 2, 2);
let head_dim = 2;
let q = vec![1.0, 1.0];
let k = vec![1.0, 1.0, 1.0, 1.0];
let v = vec![1.0, 0.0, 0.0, 1.0];
let output = op.execute((q, k, v), Backend::Scalar).unwrap();
assert_eq!(output.len(), head_dim);
assert!((output[0] - 0.5).abs() < 1e-5);
assert!((output[1] - 0.5).abs() < 1e-5);
}
#[test]
fn test_simd_dot_exact_multiple_of_four() {
assert_dot_iota(8);
}
#[test]
fn test_simd_dot_single_element() {
assert_dot(&[3.0], &[4.0], 12.0);
}
#[test]
fn test_simd_dot_two_elements() {
assert_dot(&[2.0, 3.0], &[4.0, 5.0], 23.0);
}
#[test]
fn test_simd_dot_three_elements() {
assert_dot(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], 32.0);
}
#[test]
fn test_simd_dot_large_avx2_aligned() {
assert_dot_iota(16);
}
#[test]
fn test_simd_dot_large_avx2_remainder() {
assert_dot_iota(19);
}
#[test]
fn test_simd_dot_zeros() {
assert_dot(&[0.0; 16], &[1.0; 16], 0.0);
}
#[test]
fn test_simd_dot_negative_values() {
assert_dot(&[-1.0, -2.0, -3.0, -4.0], &[1.0; 4], -10.0);
}
#[test]
fn test_simd_softmax_row_large_values() {
assert_softmax_normalized(&[1000.0, 1001.0, 1002.0]);
}
#[test]
fn test_simd_softmax_row_negative_values() {
assert_softmax_normalized(&[-10.0, -20.0, -5.0]);
}
#[test]
fn test_attention_clone() {
let op = AttentionOp::new(4, 8, 64);
let cloned = op.clone();
assert_eq!(cloned.seq_len, 4);
assert_eq!(cloned.kv_seq_len, 8);
assert_eq!(cloned.head_dim, 64);
assert!((cloned.scale - op.scale).abs() < 1e-10);
}
#[test]
fn test_attention_multi_query_rows() {
let op = AttentionOp::new(3, 2, 2);
let q = vec![1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
let k = vec![1.0, 0.0, 0.0, 1.0];
let v = vec![10.0, 20.0, 30.0, 40.0];
assert_attention_ok(&op, q, k, v, 6);
}
#[test]
fn test_attention_tokens_cross_attention() {
let op = AttentionOp::new(1, 100, 64);
assert_eq!(op.tokens(&(vec![], vec![], vec![])), 64);
}
#[test]
fn test_simd_dot_avx2_remainders() {
for n in [9, 10, 15, 24, 5, 6, 7] {
assert_dot_iota(n);
}
}
#[test]
fn test_simd_dot_large_64_elements() {
assert_dot_iota(64);
}
#[test]
fn test_simd_dot_orthogonal() {
let mut a = vec![0.0; 9];
let mut b = vec![0.0; 9];
a[0] = 1.0;
b[1] = 1.0;
assert_dot(&a, &b, 0.0);
}
#[test]
fn test_attention_execute_non_aligned_head_dim() {
let op = AttentionOp::self_attention(2, 9);
let output = assert_attention_ok(&op, vec![1.0; 18], vec![1.0; 18], vec![1.0; 18], 18);
for val in &output {
assert!((val - 1.0).abs() < 1e-4);
}
}
#[test]
fn test_attention_execute_head_dim_17() {
let op = AttentionOp::new(1, 3, 17);
let q: Vec<f32> = (0..17).map(|i| (i as f32) * 0.1).collect();
let k: Vec<f32> = (0..51).map(|i| ((i % 5) as f32) * 0.2).collect();
let v: Vec<f32> = (0..51).map(|i| (i as f32) * 0.01).collect();
assert_attention_ok(&op, q, k, v, 17);
}
fn assert_dot_scalar_ref(n: usize) {
let a: Vec<f32> = (0..n).map(|i| (i as f32) * 0.3 + 1.0).collect();
let b: Vec<f32> = (0..n).map(|i| (i as f32) * 0.7 - 0.5).collect();
let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let result = AttentionOp::simd_dot(&a, &b);
assert!(
(result - expected).abs() < 1e-2 * expected.abs().max(1.0),
"n={n}: dot={result}, expected={expected}"
);
}
#[test]
fn test_simd_dot_avx2_remainder_0() {
assert_dot_scalar_ref(32);
}
#[test]
fn test_simd_dot_avx2_remainder_1() {
assert_dot_scalar_ref(33);
}
#[test]
fn test_simd_dot_avx2_remainder_2() {
assert_dot_scalar_ref(34);
}
#[test]
fn test_simd_dot_avx2_remainder_3() {
assert_dot_scalar_ref(35);
}
#[test]
fn test_simd_dot_avx2_remainder_4() {
assert_dot_scalar_ref(36);
}
#[test]
fn test_simd_dot_avx2_remainder_5() {
assert_dot_scalar_ref(37);
}
#[test]
fn test_simd_dot_avx2_remainder_6() {
assert_dot_scalar_ref(38);
}
#[test]
fn test_simd_dot_avx2_remainder_7() {
assert_dot_scalar_ref(39);
}
#[test]
fn test_simd_dot_large_128() {
assert_dot_scalar_ref(128);
}
#[test]
fn test_simd_dot_large_1024() {
assert_dot_scalar_ref(1024);
}
#[test]
fn test_simd_dot_large_1024_plus_5() {
assert_dot_scalar_ref(1029);
}
#[test]
fn test_simd_dot_known_identity() {
let n = 64;
let a: Vec<f32> = {
let mut v = vec![0.0; n];
v[0] = 1.0;
v
};
let b = a.clone();
let result = AttentionOp::simd_dot(&a, &b);
assert!((result - 1.0).abs() < 1e-6, "identity dot = {result}");
}
#[test]
fn test_simd_dot_alternating_signs() {
let n = 64;
let a: Vec<f32> = (0..n).map(|i| if i % 2 == 0 { 1.0 } else { -1.0 }).collect();
let b = vec![1.0; n];
let result = AttentionOp::simd_dot(&a, &b);
assert!(result.abs() < 1e-5, "alternating dot = {result}");
}
#[test]
fn test_simd_dot_large_values() {
let a = vec![1000.0; 16];
let b = vec![1000.0; 16];
let expected = 1000.0 * 1000.0 * 16.0;
let result = AttentionOp::simd_dot(&a, &b);
assert!((result - expected).abs() < 1.0, "large dot = {result}, expected = {expected}");
}
#[test]
fn test_simd_dot_mixed_positive_negative() {
let a = vec![1.0, -2.0, 3.0, -4.0, 5.0, -6.0, 7.0, -8.0, 9.0, -10.0];
let b = vec![10.0, 9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0];
let expected: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let result = AttentionOp::simd_dot(&a, &b);
assert!((result - expected).abs() < 1e-3, "mixed dot = {result}, expected = {expected}");
}
#[test]
fn test_simd_dot_very_small_values() {
let a = vec![1e-10; 16];
let b = vec![1e-10; 16];
let expected = 1e-20 * 16.0;
let result = AttentionOp::simd_dot(&a, &b);
assert!((result - expected).abs() < 1e-24, "small dot = {result}, expected = {expected}");
}
#[test]
fn test_attention_head_dim_64_multi_seq() {
let op = AttentionOp::self_attention(4, 64);
let q = vec![0.1; 4 * 64];
let k = vec![0.1; 4 * 64];
let v = vec![1.0; 4 * 64];
let output = assert_attention_ok(&op, q, k, v, 4 * 64);
for val in &output {
assert!((val - 1.0).abs() < 1e-4, "expected ~1.0, got {val}");
}
}
#[test]
fn test_attention_head_dim_128() {
let op = AttentionOp::new(2, 3, 128);
let q: Vec<f32> = (0..2 * 128).map(|i| (i as f32) * 0.001).collect();
let k: Vec<f32> = (0..3 * 128).map(|i| ((i % 7) as f32) * 0.01).collect();
let v: Vec<f32> = (0..3 * 128).map(|i| (i as f32) * 0.005).collect();
assert_attention_ok(&op, q, k, v, 2 * 128);
}
#[test]
fn test_attention_head_dim_33() {
let op = AttentionOp::new(2, 2, 33);
let q = vec![0.5; 2 * 33];
let k = vec![0.5; 2 * 33];
let v = vec![2.0; 2 * 33];
let output = assert_attention_ok(&op, q, k, v, 2 * 33);
for val in &output {
assert!((val - 2.0).abs() < 1e-4, "expected ~2.0, got {val}");
}
}
#[test]
fn test_attention_head_dim_7() {
let op = AttentionOp::self_attention(2, 7);
let q = vec![1.0; 2 * 7];
let k = vec![1.0; 2 * 7];
let v = vec![3.0; 2 * 7];
let output = assert_attention_ok(&op, q, k, v, 2 * 7);
for val in &output {
assert!((val - 3.0).abs() < 1e-4, "expected ~3.0, got {val}");
}
}
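// Property-falsification tests: each checks an attention invariant
// (ATT-001 row normalization, ATT-002 output convexity, ATT-003
// 1/sqrt(d_k) scaling, ATT-005 strictly bounded weights) and reports
// "FALSIFIED" on violation.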
#[test]
fn falsify_att_001_weight_normalization() {
let test_rows: Vec<Vec<f32>> = vec![
vec![1.0, 2.0, 3.0, 4.0],
vec![-5.0, 0.0, 5.0, 10.0],
vec![1000.0, 1001.0, 1002.0],
vec![1e-7, 1e-7, 1e-7],
vec![0.0; 8],
vec![-100.0, 100.0],
];
for values in &test_rows {
let mut scores = values.clone();
AttentionOp::simd_softmax_row(&mut scores);
let sum: f32 = scores.iter().sum();
assert!(
(sum - 1.0).abs() < 1e-5,
"FALSIFIED ATT-001: softmax row sum = {sum}, expected 1.0 for input {values:?}"
);
}
}
#[test]
fn falsify_att_002_output_convexity() {
let seq_len = 2;
let kv_seq_len = 3;
let head_dim = 4;
let op = AttentionOp::new(seq_len, kv_seq_len, head_dim);
let q = vec![1.0, 0.5, -0.3, 0.8, -1.0, 0.2, 0.7, -0.5];
let k = vec![0.3, -0.7, 1.0, 0.2, -0.5, 0.8, 0.1, -0.3, 0.6, -0.1, 0.4, 0.9];
let v = vec![2.0, -3.0, 5.0, 1.0, -1.0, 4.0, -2.0, 7.0, 3.0, 0.0, -4.0, 6.0];
let output = op.execute((q, k, v.clone()), Backend::Scalar).unwrap();
for qi in 0..seq_len {
for d in 0..head_dim {
let out_val = output[qi * head_dim + d];
let v_col_min =
(0..kv_seq_len).map(|ki| v[ki * head_dim + d]).fold(f32::INFINITY, f32::min);
let v_col_max = (0..kv_seq_len)
.map(|ki| v[ki * head_dim + d])
.fold(f32::NEG_INFINITY, f32::max);
assert!(
out_val >= v_col_min - 1e-5 && out_val <= v_col_max + 1e-5,
"FALSIFIED ATT-002: output[{qi}][{d}] = {out_val} outside V column [{v_col_min}, {v_col_max}]"
);
}
}
}
#[test]
fn falsify_att_003_scaling_factor() {
for d_k in [4, 8, 16, 32, 64, 128] {
let op = AttentionOp::self_attention(1, d_k);
let expected = 1.0 / (d_k as f32).sqrt();
assert!(
(op.scale - expected).abs() < 1e-6,
"FALSIFIED ATT-003: scale = {}, expected 1/√{d_k} = {expected}",
op.scale
);
if d_k > 1 {
let wrong = 1.0 / d_k as f32;
assert!(
(op.scale - wrong).abs() > 1e-6,
"FALSIFIED ATT-003: scale matches wrong 1/{d_k} = {wrong}",
);
}
}
}
#[test]
fn falsify_att_005_weights_bounded() {
let test_rows: Vec<Vec<f32>> = vec![
vec![1.0, 2.0, 3.0, 4.0, 5.0],
vec![-5.0, 0.0, 5.0],
vec![0.0, 0.0, 0.0, 0.0],
vec![1e-10, 1e-10],
vec![-10.0, -10.0, -10.0],
vec![20.0, 20.5, 21.0],
];
for values in &test_rows {
let mut scores = values.clone();
AttentionOp::simd_softmax_row(&mut scores);
for (j, &w) in scores.iter().enumerate() {
assert!(
w > 0.0,
"FALSIFIED ATT-005: weight[{j}] = {w} not > 0 for input {values:?}"
);
assert!(
w < 1.0,
"FALSIFIED ATT-005: weight[{j}] = {w} not < 1 for input {values:?} (m >= 2)"
);
}
}
}
#[test]
fn falsify_att_002b_uniform_v_identity() {
let op = AttentionOp::new(2, 4, 8);
let q: Vec<f32> = (0..16).map(|i| (i as f32) * 0.37).collect();
let k: Vec<f32> = (0..32).map(|i| (i as f32) * 0.13).collect();
let v_row = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
let v: Vec<f32> = v_row.iter().copied().cycle().take(32).collect();
let output = op.execute((q, k, v), Backend::Scalar).unwrap();
for qi in 0..2 {
for d in 0..8 {
let diff = (output[qi * 8 + d] - v_row[d]).abs();
assert!(
diff < 1e-5,
"FALSIFIED ATT-002: uniform V output[{qi}][{d}] = {}, expected {}",
output[qi * 8 + d],
v_row[d]
);
}
}
}
}