1pub fn fused_attention_decode(
30 q: &[f32],
31 k_cache: &[f32],
32 v_cache: &[f32],
33 head_dim: usize,
34 seq_len: usize,
35 output: &mut [f32],
36) {
37 contract_pre_attention!();
38 assert_eq!(q.len(), head_dim);
39 assert_eq!(k_cache.len(), seq_len * head_dim);
40 assert_eq!(v_cache.len(), seq_len * head_dim);
41 assert_eq!(output.len(), head_dim);
42
43 if seq_len == 0 {
44 output.fill(0.0);
45 contract_post_attention!(output);
46 return;
47 }
48
49 #[cfg(target_arch = "x86_64")]
50 if std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma") {
51 unsafe {
53 fused_attention_decode_avx2(q, k_cache, v_cache, head_dim, seq_len, output);
54 }
55 contract_post_attention!(output);
56 return;
57 }
58
59 fused_attention_decode_scalar(q, k_cache, v_cache, head_dim, seq_len, output);
60 contract_post_attention!(output);
61}
62
/// Scalar fallback for single-query decode attention over a KV cache.
///
/// Single pass over the cached rows using the online-softmax
/// (flash-attention) recurrence: maintain a running maximum score, a
/// running sum of exponentials, and a partially accumulated output that
/// is rescaled whenever the maximum increases.
///
/// Fix vs. previous version: the O(head_dim) rescale pass used to run on
/// every step even when the maximum did not increase — in that case the
/// correction factor is exp(0) == 1.0 exactly, so the pass was a
/// guaranteed no-op. It is now skipped, which is bit-identical for the
/// same inputs and avoids the wasted work in the common case where the
/// running maximum has already settled.
fn fused_attention_decode_scalar(
    q: &[f32],
    k_cache: &[f32],
    v_cache: &[f32],
    head_dim: usize,
    seq_len: usize,
    output: &mut [f32],
) {
    // Standard 1/sqrt(d) attention scaling.
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut running_max = f32::NEG_INFINITY;
    let mut running_sum = 0.0f32;
    output.fill(0.0);

    for s in 0..seq_len {
        let k_row = &k_cache[s * head_dim..(s + 1) * head_dim];
        // score = (q · k_s) / sqrt(d); zip preserves the sequential
        // accumulation order of the original index loop.
        let dot: f32 = q.iter().zip(k_row).map(|(a, b)| a * b).sum();
        let score = dot * scale;

        // Rescale previous accumulators only when the running max actually
        // increases. (The NEG_INFINITY guard avoids exp(-inf - score)
        // on the very first row, where the accumulators are still empty.)
        if score > running_max {
            if running_max != f32::NEG_INFINITY {
                let correction = (running_max - score).exp();
                running_sum *= correction;
                for val in output.iter_mut() {
                    *val *= correction;
                }
            }
            running_max = score;
        }

        // Weight of this position relative to the current maximum
        // (equals 1.0 when this row set the new maximum).
        let w = (score - running_max).exp();
        running_sum += w;

        // output += w * V[s]
        let v_row = &v_cache[s * head_dim..(s + 1) * head_dim];
        for (o, &v) in output.iter_mut().zip(v_row) {
            *o += w * v;
        }
    }

    // Normalize by the softmax denominator; the guard keeps the
    // seq_len == 0 case (running_sum == 0) at all-zeros.
    if running_sum > 0.0 {
        let inv_sum = 1.0 / running_sum;
        for val in output.iter_mut() {
            *val *= inv_sum;
        }
    }
}
111
/// AVX2 + FMA kernel for single-query decode attention. Same online-softmax
/// recurrence as `fused_attention_decode_scalar`, with 8-lane vectorized
/// dot products, output rescaling, and V accumulation.
///
/// # Safety
/// - The CPU must support `avx2` and `fma` (this fn is
///   `#[target_feature]`-gated, not runtime-checked here; the public
///   wrapper performs the runtime detection).
/// - Slice lengths must satisfy `q.len() == head_dim`,
///   `output.len() == head_dim`, and both caches `seq_len * head_dim`
///   — the raw-pointer and `get_unchecked` accesses below rely on this.
///   These are asserted in the public wrapper.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn fused_attention_decode_avx2(
    q: &[f32],
    k_cache: &[f32],
    v_cache: &[f32],
    head_dim: usize,
    seq_len: usize,
    output: &mut [f32],
) {
    // SAFETY: all pointer arithmetic below stays within the slice bounds
    // guaranteed by the caller (see # Safety above).
    unsafe {
        use std::arch::x86_64::*;

        let scale = 1.0 / (head_dim as f32).sqrt();
        // Largest multiple of 8 lanes that fits in head_dim (vector body);
        // the remainder is handled by scalar tail loops.
        let d8 = head_dim / 8 * 8;

        let mut running_max = f32::NEG_INFINITY;
        let mut running_sum = 0.0f32;
        output.fill(0.0);

        for s in 0..seq_len {
            let k_ptr = k_cache.as_ptr().add(s * head_dim);
            let q_ptr = q.as_ptr();

            // Four independent accumulators so consecutive FMAs do not
            // serialize on one register's latency.
            let mut dot0 = _mm256_setzero_ps();
            let mut dot1 = _mm256_setzero_ps();
            let mut dot2 = _mm256_setzero_ps();
            let mut dot3 = _mm256_setzero_ps();

            // Main q·k loop: 32 floats (4 × 8 lanes) per iteration.
            let mut j = 0;
            let d32 = head_dim / 32 * 32;
            while j < d32 {
                dot0 = _mm256_fmadd_ps(
                    _mm256_loadu_ps(q_ptr.add(j)),
                    _mm256_loadu_ps(k_ptr.add(j)),
                    dot0,
                );
                dot1 = _mm256_fmadd_ps(
                    _mm256_loadu_ps(q_ptr.add(j + 8)),
                    _mm256_loadu_ps(k_ptr.add(j + 8)),
                    dot1,
                );
                dot2 = _mm256_fmadd_ps(
                    _mm256_loadu_ps(q_ptr.add(j + 16)),
                    _mm256_loadu_ps(k_ptr.add(j + 16)),
                    dot2,
                );
                dot3 = _mm256_fmadd_ps(
                    _mm256_loadu_ps(q_ptr.add(j + 24)),
                    _mm256_loadu_ps(k_ptr.add(j + 24)),
                    dot3,
                );
                j += 32;
            }
            // Leftover full 8-lane chunks (head_dim % 32, in steps of 8).
            while j < d8 {
                dot0 = _mm256_fmadd_ps(
                    _mm256_loadu_ps(q_ptr.add(j)),
                    _mm256_loadu_ps(k_ptr.add(j)),
                    dot0,
                );
                j += 8;
            }

            // Horizontal reduction: combine the four accumulators, then
            // fold 256 -> 128 -> 64 -> 32 bits down to a single f32.
            dot0 = _mm256_add_ps(_mm256_add_ps(dot0, dot1), _mm256_add_ps(dot2, dot3));
            let hi = _mm256_extractf128_ps(dot0, 1);
            let lo = _mm256_castps256_ps128(dot0);
            let sum128 = _mm_add_ps(lo, hi);
            let sum64 = _mm_hadd_ps(sum128, sum128);
            let sum32 = _mm_hadd_ps(sum64, sum64);
            let mut dot_scalar = _mm_cvtss_f32(sum32);

            // Scalar tail: remaining head_dim % 8 elements of the dot product.
            while j < head_dim {
                dot_scalar += *q.get_unchecked(j) * *k_cache.get_unchecked(s * head_dim + j);
                j += 1;
            }

            let score = dot_scalar * scale;

            // Online softmax: if the running max changes, rescale the sum
            // and the partial output by exp(old_max - new_max). The guard
            // skips the very first row (accumulators still empty) and
            // avoids exp involving -inf.
            let new_max = running_max.max(score);
            if running_max != f32::NEG_INFINITY {
                let correction = (running_max - new_max).exp();
                running_sum *= correction;

                let corr_v = _mm256_set1_ps(correction);
                let out_ptr = output.as_mut_ptr();
                let mut d = 0;
                while d < d8 {
                    let ov = _mm256_loadu_ps(out_ptr.add(d));
                    _mm256_storeu_ps(out_ptr.add(d), _mm256_mul_ps(ov, corr_v));
                    d += 8;
                }
                while d < head_dim {
                    *output.get_unchecked_mut(d) *= correction;
                    d += 1;
                }
            }

            // This row's softmax weight relative to the current maximum.
            let w = (score - new_max).exp();
            running_sum += w;

            // output += w * V[s], vector body plus scalar tail.
            let w_v = _mm256_set1_ps(w);
            let v_ptr = v_cache.as_ptr().add(s * head_dim);
            let out_ptr = output.as_mut_ptr();
            let mut d = 0;
            while d < d8 {
                let ov = _mm256_loadu_ps(out_ptr.add(d));
                let vv = _mm256_loadu_ps(v_ptr.add(d));
                _mm256_storeu_ps(out_ptr.add(d), _mm256_fmadd_ps(w_v, vv, ov));
                d += 8;
            }
            while d < head_dim {
                *output.get_unchecked_mut(d) += w * *v_cache.get_unchecked(s * head_dim + d);
                d += 1;
            }

            running_max = new_max;
        }

        // Final normalization by the softmax denominator; guard keeps
        // all-zeros if nothing was accumulated.
        if running_sum > 0.0 {
            let inv_v = _mm256_set1_ps(1.0 / running_sum);
            let out_ptr = output.as_mut_ptr();
            let mut d = 0;
            while d < d8 {
                let ov = _mm256_loadu_ps(out_ptr.add(d));
                _mm256_storeu_ps(out_ptr.add(d), _mm256_mul_ps(ov, inv_v));
                d += 8;
            }
            // NOTE(review): the tail divides by running_sum while the vector
            // body multiplies by a precomputed reciprocal — the tail lanes
            // can differ from the vector lanes by 1 ulp. Harmless, but worth
            // unifying if bit-exact lane parity ever matters.
            while d < head_dim {
                *output.get_unchecked_mut(d) /= running_sum;
                d += 1;
            }
        }
    }
}
267
/// Naive three-pass attention used as the ground truth in tests:
/// (1) compute all scores, (2) a full (stable) softmax, (3) the weighted
/// V sum. Deliberately unfused and allocation-heavy for clarity.
#[cfg(test)]
fn unfused_attention_decode_reference(
    q: &[f32],
    k_cache: &[f32],
    v_cache: &[f32],
    head_dim: usize,
    seq_len: usize,
    output: &mut [f32],
) {
    let scale = 1.0 / (head_dim as f32).sqrt();

    // Pass 1: scaled dot-product score against every cached key row.
    let mut scores: Vec<f32> = (0..seq_len)
        .map(|s| {
            let k_row = &k_cache[s * head_dim..(s + 1) * head_dim];
            let dot: f32 = q.iter().zip(k_row).map(|(a, b)| a * b).sum();
            dot * scale
        })
        .collect();

    // Pass 2: numerically stable softmax (subtract the max, exponentiate,
    // normalize).
    let max_score = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut sum = 0.0f32;
    for s in scores.iter_mut() {
        *s = (*s - max_score).exp();
        sum += *s;
    }
    for s in scores.iter_mut() {
        *s /= sum;
    }

    // Pass 3: output = Σ_s softmax(s) · V[s].
    output.fill(0.0);
    for (s, &w) in scores.iter().enumerate() {
        let v_row = &v_cache[s * head_dim..(s + 1) * head_dim];
        for (o, &v) in output.iter_mut().zip(v_row) {
            *o += w * v;
        }
    }
}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random fixtures for a (head_dim, seq_len)
    /// problem: returns (q, k_cache, v_cache) with values in [-0.5, 0.5).
    fn gen_data(head_dim: usize, seq_len: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        let fill = |mul: usize, add: usize, len: usize| -> Vec<f32> {
            (0..len).map(|i| ((i * mul + add) % 100) as f32 / 100.0 - 0.5).collect()
        };
        (
            fill(7, 3, head_dim),
            fill(13, 7, seq_len * head_dim),
            fill(11, 5, seq_len * head_dim),
        )
    }

    /// Largest absolute element-wise difference between two slices.
    fn max_abs_diff(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| (x - y).abs()).fold(0.0f32, f32::max)
    }

    #[test]
    fn test_fused_matches_reference() {
        // Cover the AVX2 multiple-of-32 path (d=128) and a smaller head.
        for (d, s) in [(128usize, 64usize), (128, 512), (128, 1024), (64, 256)] {
            let (q, k, v) = gen_data(d, s);
            let mut fused = vec![0.0f32; d];
            let mut reference = vec![0.0f32; d];

            fused_attention_decode(&q, &k, &v, d, s, &mut fused);
            unfused_attention_decode_reference(&q, &k, &v, d, s, &mut reference);

            let max_diff = max_abs_diff(&fused, &reference);
            assert!(max_diff < 1e-4, "FALSIFY-FLASH-ATTN-001: d={d} s={s} max_diff={max_diff}");
        }
    }

    #[test]
    fn test_softmax_sums_to_one() {
        let (d, s) = (128usize, 512usize);
        let (q, k, v) = gen_data(d, s);
        let scale = 1.0 / (d as f32).sqrt();

        // Replay the online-softmax recurrence and check the denominator
        // stays positive, then check the fused output is finite.
        let mut running_max = f32::NEG_INFINITY;
        let mut running_sum = 0.0f32;
        for row in k.chunks_exact(d) {
            let score: f32 = q.iter().zip(row).map(|(a, b)| a * b).sum::<f32>() * scale;
            let new_max = running_max.max(score);
            if running_max != f32::NEG_INFINITY {
                running_sum *= (running_max - new_max).exp();
            }
            running_sum += (score - new_max).exp();
            running_max = new_max;
        }
        assert!(running_sum > 0.0);

        let mut out = vec![0.0f32; d];
        fused_attention_decode(&q, &k, &v, d, s, &mut out);
        assert!(out.iter().all(|x| x.is_finite()), "FALSIFY-FLASH-ATTN-004: NaN/Inf in output");
    }

    #[test]
    fn test_fused_seq_len_one() {
        // With one row the softmax weight is exactly 1; fused and reference
        // should agree to near machine precision.
        let d = 128;
        let (q, k, v) = gen_data(d, 1);
        let mut fused = vec![0.0f32; d];
        let mut reference = vec![0.0f32; d];

        fused_attention_decode(&q, &k, &v, d, 1, &mut fused);
        unfused_attention_decode_reference(&q, &k, &v, d, 1, &mut reference);

        let max_diff = max_abs_diff(&fused, &reference);
        assert!(max_diff < 1e-6, "seq_len=1: max_diff={max_diff}");
    }

    #[test]
    fn test_fused_seq_len_zero() {
        // Empty cache must overwrite any stale contents with zeros.
        let d = 128;
        let q = vec![1.0f32; d];
        let mut out = vec![99.0f32; d];
        fused_attention_decode(&q, &[], &[], d, 0, &mut out);
        assert!(out.iter().all(|&x| x == 0.0), "seq_len=0 should zero output");
    }

    #[test]
    fn test_fused_perf_smoke() {
        // Smoke test: a representative decode shape runs and produces a
        // non-trivial result.
        let (d, s) = (128usize, 512usize);
        let (q, k, v) = gen_data(d, s);
        let mut out = vec![0.0f32; d];

        fused_attention_decode(&q, &k, &v, d, s, &mut out);
        assert!(out.iter().any(|&x| x != 0.0), "Output should be non-zero");
    }
}