// trueno/blis/attention.rs

//! Fused single-query attention for autoregressive decode.
//!
//! Computes: output = softmax(Q @ K^T / sqrt(head_dim)) @ V
//! in a single pass over the KV cache without materializing the
//! (1, seq_len) attention scores to memory.
//!
//! Uses online softmax (Milakov & Gimelshein, FlashAttention [64] Algorithm 1):
//! For each block of K/V rows:
//!   1. Compute partial scores = Q · K_block^T / sqrt(D)
//!   2. Update running max and running sum
//!   3. Rescale previous output accumulator
//!   4. Accumulate exp(scores - max) @ V_block into output
//!
//! Contract: contracts/cgp/cgp-flash-attn-cpu-v1.yaml
//! FALSIFY: FALSIFY-FLASH-ATTN-001 through 004

17/// Fused decode attention: output = softmax(Q @ K^T / sqrt(D)) @ V.
18///
19/// No heap allocation. Scores stay in a stack buffer (block_size elements).
20/// AVX2 GEMV for dot products, scalar exp for transcendentals.
21///
22/// # Arguments
23/// - `q`: query vector, length `head_dim`
24/// - `k_cache`: key cache, row-major (seq_len × head_dim)
25/// - `v_cache`: value cache, row-major (seq_len × head_dim)
26/// - `head_dim`: dimension D
27/// - `seq_len`: number of cached K/V rows
28/// - `output`: result buffer, length `head_dim` (will be overwritten)
29pub fn fused_attention_decode(
30    q: &[f32],
31    k_cache: &[f32],
32    v_cache: &[f32],
33    head_dim: usize,
34    seq_len: usize,
35    output: &mut [f32],
36) {
37    contract_pre_attention!();
38    assert_eq!(q.len(), head_dim);
39    assert_eq!(k_cache.len(), seq_len * head_dim);
40    assert_eq!(v_cache.len(), seq_len * head_dim);
41    assert_eq!(output.len(), head_dim);
42
43    if seq_len == 0 {
44        output.fill(0.0);
45        contract_post_attention!(output);
46        return;
47    }
48
49    #[cfg(target_arch = "x86_64")]
50    if std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma") {
51        // SAFETY: AVX2+FMA verified. Slice lengths checked by asserts above.
52        unsafe {
53            fused_attention_decode_avx2(q, k_cache, v_cache, head_dim, seq_len, output);
54        }
55        contract_post_attention!(output);
56        return;
57    }
58
59    fused_attention_decode_scalar(q, k_cache, v_cache, head_dim, seq_len, output);
60    contract_post_attention!(output);
61}
/// Scalar fallback for non-x86 or non-AVX2 platforms.
///
/// Single pass over the KV cache with the online-softmax recurrence: keep a
/// running max `m` and running denominator `l`; whenever the max grows,
/// rescale both `l` and the partial output by exp(old_m - new_m) so earlier
/// contributions stay expressed relative to the new max.
fn fused_attention_decode_scalar(
    q: &[f32],
    k_cache: &[f32],
    v_cache: &[f32],
    head_dim: usize,
    seq_len: usize,
    output: &mut [f32],
) {
    let inv_sqrt_d = 1.0 / (head_dim as f32).sqrt();
    let mut m = f32::NEG_INFINITY; // running max of scores seen so far
    let mut l = 0.0f32; // running sum of exp(score - m)
    output.fill(0.0);

    for s in 0..seq_len {
        let base = s * head_dim;
        let k_row = &k_cache[base..base + head_dim];

        // score = (Q · K[s]) / sqrt(D); sequential sum matches SIMD remainder order.
        let dot: f32 = q.iter().zip(k_row.iter()).map(|(a, b)| a * b).sum();
        let score = dot * inv_sqrt_d;

        // Max grew (or first row): rescale accumulated denominator and output.
        let m_new = m.max(score);
        if m != f32::NEG_INFINITY {
            let alpha = (m - m_new).exp();
            l *= alpha;
            for o in output.iter_mut() {
                *o *= alpha;
            }
        }

        // This row's unnormalized weight and its V contribution.
        let w = (score - m_new).exp();
        l += w;
        let v_row = &v_cache[base..base + head_dim];
        for (o, &v) in output.iter_mut().zip(v_row.iter()) {
            *o += w * v;
        }
        m = m_new;
    }

    // Final normalization; l == 0 only when seq_len == 0 (output stays zero).
    if l > 0.0 {
        let inv_l = 1.0 / l;
        for o in output.iter_mut() {
            *o *= inv_l;
        }
    }
}
/// AVX2 fused attention: SIMD dot product, SIMD V accumulation, SIMD rescale.
///
/// Three SIMD-accelerated hot paths:
/// 1. Q·K dot product: 4 ymm accumulators × 8 f32 = 32-wide, hadd reduction
/// 2. Output rescale (correction *= exp(...)): broadcast + vfmadd
/// 3. w * V accumulation: broadcast weight, vfmadd per 8 elements
///
/// Transcendentals (`exp`) stay scalar: there is exactly one per K/V row, so
/// they are not on the bandwidth-bound critical path.
///
/// Uses AVX2 (not AVX-512) because attention is bandwidth-bound [60][61]:
/// Zen 4 throttles clock during 512-bit ops, and GEMV-class workloads
/// cannot compensate with wider SIMD.
///
/// # Safety
/// Caller must guarantee:
/// - the CPU supports AVX2 and FMA (e.g. checked via `is_x86_feature_detected!`);
/// - `q.len() == head_dim`, `k_cache.len() == v_cache.len() == seq_len * head_dim`,
///   and `output.len() == head_dim` — the raw-pointer loads/stores below rely
///   on these lengths and perform no bounds checks.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn fused_attention_decode_avx2(
    q: &[f32],
    k_cache: &[f32],
    v_cache: &[f32],
    head_dim: usize,
    seq_len: usize,
    output: &mut [f32],
) {
    unsafe {
        use std::arch::x86_64::*;

        let scale = 1.0 / (head_dim as f32).sqrt();
        // Largest multiple of 8 <= head_dim: loop bound for the 8-wide paths.
        let d8 = head_dim / 8 * 8;

        let mut running_max = f32::NEG_INFINITY;
        let mut running_sum = 0.0f32;
        output.fill(0.0);

        // Process one K/V row per iteration (online softmax, no blocking needed
        // since we SIMD the inner dim, not the seq_len dim).
        for s in 0..seq_len {
            let k_ptr = k_cache.as_ptr().add(s * head_dim);
            let q_ptr = q.as_ptr();

            // SIMD dot product: Q · K[s] with 4 ymm accumulators.
            // Four independent FMA chains hide the FMA latency (~4-5 cycles).
            let mut dot0 = _mm256_setzero_ps();
            let mut dot1 = _mm256_setzero_ps();
            let mut dot2 = _mm256_setzero_ps();
            let mut dot3 = _mm256_setzero_ps();

            let mut j = 0;
            // Main loop: 32 floats (4 × 8) per iteration.
            let d32 = head_dim / 32 * 32;
            while j < d32 {
                dot0 = _mm256_fmadd_ps(
                    _mm256_loadu_ps(q_ptr.add(j)),
                    _mm256_loadu_ps(k_ptr.add(j)),
                    dot0,
                );
                dot1 = _mm256_fmadd_ps(
                    _mm256_loadu_ps(q_ptr.add(j + 8)),
                    _mm256_loadu_ps(k_ptr.add(j + 8)),
                    dot1,
                );
                dot2 = _mm256_fmadd_ps(
                    _mm256_loadu_ps(q_ptr.add(j + 16)),
                    _mm256_loadu_ps(k_ptr.add(j + 16)),
                    dot2,
                );
                dot3 = _mm256_fmadd_ps(
                    _mm256_loadu_ps(q_ptr.add(j + 24)),
                    _mm256_loadu_ps(k_ptr.add(j + 24)),
                    dot3,
                );
                j += 32;
            }
            // Tail: remaining full 8-wide chunks go into dot0 only.
            while j < d8 {
                dot0 = _mm256_fmadd_ps(
                    _mm256_loadu_ps(q_ptr.add(j)),
                    _mm256_loadu_ps(k_ptr.add(j)),
                    dot0,
                );
                j += 8;
            }

            // Horizontal sum: dot0+dot1+dot2+dot3 → scalar
            dot0 = _mm256_add_ps(_mm256_add_ps(dot0, dot1), _mm256_add_ps(dot2, dot3));
            // 256-bit → 128-bit: add high and low halves
            let hi = _mm256_extractf128_ps(dot0, 1);
            let lo = _mm256_castps256_ps128(dot0);
            let sum128 = _mm_add_ps(lo, hi);
            // 128-bit → scalar: hadd twice
            let sum64 = _mm_hadd_ps(sum128, sum128);
            let sum32 = _mm_hadd_ps(sum64, sum64);
            let mut dot_scalar = _mm_cvtss_f32(sum32);

            // Scalar remainder (j == d8 here; covers head_dim % 8 elements).
            while j < head_dim {
                dot_scalar += *q.get_unchecked(j) * *k_cache.get_unchecked(s * head_dim + j);
                j += 1;
            }

            let score = dot_scalar * scale;

            // Online softmax update: if the max grows, previously accumulated
            // exp() terms (in running_sum and output) are rescaled by
            // exp(old_max - new_max) so they remain relative to the new max.
            let new_max = running_max.max(score);
            if running_max != f32::NEG_INFINITY {
                let correction = (running_max - new_max).exp();
                running_sum *= correction;

                // SIMD rescale output: output[d] *= correction
                let corr_v = _mm256_set1_ps(correction);
                let out_ptr = output.as_mut_ptr();
                let mut d = 0;
                while d < d8 {
                    let ov = _mm256_loadu_ps(out_ptr.add(d));
                    _mm256_storeu_ps(out_ptr.add(d), _mm256_mul_ps(ov, corr_v));
                    d += 8;
                }
                while d < head_dim {
                    *output.get_unchecked_mut(d) *= correction;
                    d += 1;
                }
            }

            // Unnormalized weight for this row; normalization happens once at the end.
            let w = (score - new_max).exp();
            running_sum += w;

            // SIMD V accumulation: output[d] += w * V[s][d]
            let w_v = _mm256_set1_ps(w);
            let v_ptr = v_cache.as_ptr().add(s * head_dim);
            let out_ptr = output.as_mut_ptr();
            let mut d = 0;
            while d < d8 {
                let ov = _mm256_loadu_ps(out_ptr.add(d));
                let vv = _mm256_loadu_ps(v_ptr.add(d));
                _mm256_storeu_ps(out_ptr.add(d), _mm256_fmadd_ps(w_v, vv, ov));
                d += 8;
            }
            while d < head_dim {
                *output.get_unchecked_mut(d) += w * *v_cache.get_unchecked(s * head_dim + d);
                d += 1;
            }

            running_max = new_max;
        }

        // Final normalization: output /= running_sum
        // (running_sum == 0 cannot happen for seq_len > 0, since exp() > 0;
        // the guard also keeps seq_len == 0 output at zero.)
        if running_sum > 0.0 {
            let inv_v = _mm256_set1_ps(1.0 / running_sum);
            let out_ptr = output.as_mut_ptr();
            let mut d = 0;
            while d < d8 {
                let ov = _mm256_loadu_ps(out_ptr.add(d));
                _mm256_storeu_ps(out_ptr.add(d), _mm256_mul_ps(ov, inv_v));
                d += 8;
            }
            while d < head_dim {
                *output.get_unchecked_mut(d) /= running_sum;
                d += 1;
            }
        }
    } // unsafe
}
/// Unfused reference: separate Q@K^T, softmax, scores@V for validation.
///
/// Materializes the full (seq_len) score vector — the O(seq_len) memory the
/// fused kernel exists to avoid — and is used only to cross-check it.
#[cfg(test)]
fn unfused_attention_decode_reference(
    q: &[f32],
    k_cache: &[f32],
    v_cache: &[f32],
    head_dim: usize,
    seq_len: usize,
    output: &mut [f32],
) {
    let scale = 1.0 / (head_dim as f32).sqrt();

    // Pass 1: raw scores = (Q · K[s]) / sqrt(D) for every cached row.
    let mut scores: Vec<f32> = (0..seq_len)
        .map(|s| {
            let k_row = &k_cache[s * head_dim..(s + 1) * head_dim];
            let dot: f32 = q.iter().zip(k_row.iter()).map(|(a, b)| a * b).sum();
            dot * scale
        })
        .collect();

    // Pass 2: numerically stable softmax, in place.
    let max_score = scores.iter().fold(f32::NEG_INFINITY, |acc, &x| acc.max(x));
    let mut denom = 0.0f32;
    for w in scores.iter_mut() {
        *w = (*w - max_score).exp();
        denom += *w;
    }
    for w in scores.iter_mut() {
        *w /= denom;
    }

    // Pass 3: output = Σ_s scores[s] * V[s].
    output.fill(0.0);
    for (s, &w) in scores.iter().enumerate() {
        let v_row = &v_cache[s * head_dim..(s + 1) * head_dim];
        for (o, &v) in output.iter_mut().zip(v_row.iter()) {
            *o += w * v;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-data in [-0.5, 0.5): modular sequences avoid an
    /// RNG dependency so failures reproduce identically across runs/platforms.
    fn gen_data(head_dim: usize, seq_len: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let q: Vec<f32> = (0..head_dim).map(|i| ((i * 7 + 3) % 100) as f32 / 100.0 - 0.5).collect();
        let k: Vec<f32> =
            (0..seq_len * head_dim).map(|i| ((i * 13 + 7) % 100) as f32 / 100.0 - 0.5).collect();
        let v: Vec<f32> =
            (0..seq_len * head_dim).map(|i| ((i * 11 + 5) % 100) as f32 / 100.0 - 0.5).collect();
        (q, k, v)
    }

    /// FALSIFY-FLASH-ATTN-001: Fused matches unfused reference.
    #[test]
    fn test_fused_matches_reference() {
        for &(d, s) in &[(128, 64), (128, 512), (128, 1024), (64, 256)] {
            let (q, k, v) = gen_data(d, s);
            let mut out_fused = vec![0.0f32; d];
            let mut out_ref = vec![0.0f32; d];

            fused_attention_decode(&q, &k, &v, d, s, &mut out_fused);
            unfused_attention_decode_reference(&q, &k, &v, d, s, &mut out_ref);

            let max_diff = out_fused
                .iter()
                .zip(out_ref.iter())
                .map(|(a, b)| (a - b).abs())
                .fold(0.0f32, f32::max);

            assert!(max_diff < 1e-4, "FALSIFY-FLASH-ATTN-001: d={d} s={s} max_diff={max_diff}");
        }
    }

    /// FALSIFY-FLASH-ATTN-004: online softmax agrees with two-pass softmax
    /// and normalized weights sum to 1.0.
    ///
    /// Previously this test only asserted `running_sum > 0.0` and output
    /// finiteness — it never actually verified the sum its name claims.
    /// It now cross-checks the online recurrence against a two-pass
    /// (global-max) reference and asserts the normalized weights sum to ~1.
    #[test]
    fn test_softmax_sums_to_one() {
        let d = 128;
        let s = 512;
        let (q, k, v) = gen_data(d, s);
        let scale = 1.0 / (d as f32).sqrt();

        // All raw scores, for the two-pass reference.
        let scores: Vec<f32> = (0..s)
            .map(|i| {
                let k_row = &k[i * d..(i + 1) * d];
                let dot: f32 = q.iter().zip(k_row.iter()).map(|(a, b)| a * b).sum();
                dot * scale
            })
            .collect();

        // Online (single-pass) running max/sum, mirroring the fused kernel.
        let mut running_max = f32::NEG_INFINITY;
        let mut running_sum = 0.0f32;
        for &score in &scores {
            let new_max = running_max.max(score);
            if running_max != f32::NEG_INFINITY {
                running_sum *= (running_max - new_max).exp();
            }
            running_sum += (score - new_max).exp();
            running_max = new_max;
        }

        // Two-pass reference: global max, then sum of exp(score - max).
        let global_max = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let ref_sum: f32 = scores.iter().map(|&x| (x - global_max).exp()).sum();

        assert!(running_sum > 0.0);
        assert_eq!(running_max, global_max, "online max must equal global max");
        let rel_err = (running_sum - ref_sum).abs() / ref_sum;
        assert!(rel_err < 1e-4, "FALSIFY-FLASH-ATTN-004: online sum diverges: rel_err={rel_err}");

        // Normalized weights must sum to ~1.0.
        let total: f32 = scores.iter().map(|&x| (x - global_max).exp() / running_sum).sum();
        assert!((total - 1.0).abs() < 1e-4, "FALSIFY-FLASH-ATTN-004: weights sum to {total}");

        // And the fused kernel's output must be finite at this size.
        let mut out = vec![0.0f32; d];
        fused_attention_decode(&q, &k, &v, d, s, &mut out);
        assert!(out.iter().all(|x| x.is_finite()), "FALSIFY-FLASH-ATTN-004: NaN/Inf in output");
    }

    /// FALSIFY-FLASH-ATTN-001b: Edge case — seq_len=1.
    #[test]
    fn test_fused_seq_len_one() {
        let d = 128;
        let (q, k, v) = gen_data(d, 1);
        let mut out_fused = vec![0.0f32; d];
        let mut out_ref = vec![0.0f32; d];

        fused_attention_decode(&q, &k, &v, d, 1, &mut out_fused);
        unfused_attention_decode_reference(&q, &k, &v, d, 1, &mut out_ref);

        // With seq_len=1, softmax weight is 1.0, output = V[0]
        let max_diff =
            out_fused.iter().zip(out_ref.iter()).map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
        assert!(max_diff < 1e-6, "seq_len=1: max_diff={max_diff}");
    }

    /// FALSIFY-FLASH-ATTN-001c: Edge case — seq_len=0.
    #[test]
    fn test_fused_seq_len_zero() {
        let d = 128;
        let q = vec![1.0f32; d];
        // Prefill with a sentinel so a missed fill(0.0) is caught.
        let mut out = vec![99.0f32; d];
        fused_attention_decode(&q, &[], &[], d, 0, &mut out);
        assert!(out.iter().all(|&x| x == 0.0), "seq_len=0 should zero output");
    }

    /// Smoke test at a benchmark-representative size (no timing assertions;
    /// criterion owns actual perf measurement).
    #[test]
    fn test_fused_perf_smoke() {
        let d = 128;
        let s = 512;
        let (q, k, v) = gen_data(d, s);
        let mut out = vec![0.0f32; d];

        // Just verify it runs without panic at benchmark-representative size
        fused_attention_decode(&q, &k, &v, d, s, &mut out);
        assert!(out.iter().any(|&x| x != 0.0), "Output should be non-zero");
    }
}