/// Fused single-query attention for one decode step: computes
/// softmax(q · Kᵀ / sqrt(head_dim)) · V over a contiguous KV cache in a
/// single pass, dispatching to an AVX2+FMA kernel at runtime when available
/// and falling back to a portable scalar implementation otherwise.
pub fn fused_attention_decode(
q: &[f32],
k_cache: &[f32],
v_cache: &[f32],
head_dim: usize,
seq_len: usize,
output: &mut [f32],
) {
contract_pre_attention!();
assert_eq!(q.len(), head_dim);
assert_eq!(k_cache.len(), seq_len * head_dim);
assert_eq!(v_cache.len(), seq_len * head_dim);
assert_eq!(output.len(), head_dim);
if seq_len == 0 {
output.fill(0.0);
contract_post_attention!(output);
return;
}
#[cfg(target_arch = "x86_64")]
if std::arch::is_x86_feature_detected!("avx2") && std::arch::is_x86_feature_detected!("fma") {
unsafe {
fused_attention_decode_avx2(q, k_cache, v_cache, head_dim, seq_len, output);
}
contract_post_attention!(output);
return;
}
fused_attention_decode_scalar(q, k_cache, v_cache, head_dim, seq_len, output);
contract_post_attention!(output);
}
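/// Portable scalar fallback using the online (streaming) softmax recurrence.
/// With running max `m`, denominator `l`, and accumulator `o`, each new score
/// `s_i` updates:
///   m' = max(m, s_i)
///   l' = l * exp(m - m') + exp(s_i - m')
///   o' = o * exp(m - m') + exp(s_i - m') * v_i
/// so the KV cache is read exactly once and no score buffer is allocated.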
fn fused_attention_decode_scalar(
q: &[f32],
k_cache: &[f32],
v_cache: &[f32],
head_dim: usize,
seq_len: usize,
output: &mut [f32],
) {
let scale = 1.0 / (head_dim as f32).sqrt();
let mut running_max = f32::NEG_INFINITY;
let mut running_sum = 0.0f32;
output.fill(0.0);
for s in 0..seq_len {
let k_row = &k_cache[s * head_dim..(s + 1) * head_dim];
let mut dot = 0.0f32;
for d in 0..head_dim {
dot += q[d] * k_row[d];
}
let score = dot * scale;
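        // If the running max increases, rescale the partial denominator and
        // accumulator by exp(old_max - new_max) before folding in this score.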
let new_max = running_max.max(score);
if running_max != f32::NEG_INFINITY {
let correction = (running_max - new_max).exp();
running_sum *= correction;
for val in output.iter_mut() {
*val *= correction;
}
}
let w = (score - new_max).exp();
running_sum += w;
let v_row = &v_cache[s * head_dim..(s + 1) * head_dim];
for d in 0..head_dim {
output[d] += w * v_row[d];
}
running_max = new_max;
}
if running_sum > 0.0 {
let inv_sum = 1.0 / running_sum;
for val in output.iter_mut() {
*val *= inv_sum;
}
}
}
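/// AVX2 + FMA kernel: same online-softmax algorithm as the scalar path, with
/// the q·k dot product, the rescaling, and the V accumulation vectorized
/// 8 lanes at a time (the dot product is additionally unrolled 4x).
///
/// # Safety
/// The caller must ensure the CPU supports AVX2 and FMA (checked via
/// `is_x86_feature_detected!` in the dispatcher).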
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn fused_attention_decode_avx2(
q: &[f32],
k_cache: &[f32],
v_cache: &[f32],
head_dim: usize,
seq_len: usize,
output: &mut [f32],
) {
unsafe {
use std::arch::x86_64::*;
let scale = 1.0 / (head_dim as f32).sqrt();
let d8 = head_dim / 8 * 8;
let mut running_max = f32::NEG_INFINITY;
let mut running_sum = 0.0f32;
output.fill(0.0);
for s in 0..seq_len {
let k_ptr = k_cache.as_ptr().add(s * head_dim);
let q_ptr = q.as_ptr();
let mut dot0 = _mm256_setzero_ps();
let mut dot1 = _mm256_setzero_ps();
let mut dot2 = _mm256_setzero_ps();
let mut dot3 = _mm256_setzero_ps();
let mut j = 0;
let d32 = head_dim / 32 * 32;
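            // Main dot-product loop: four independent FMA accumulators over
            // 32 floats per iteration to hide FMA latency.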
while j < d32 {
dot0 = _mm256_fmadd_ps(
_mm256_loadu_ps(q_ptr.add(j)),
_mm256_loadu_ps(k_ptr.add(j)),
dot0,
);
dot1 = _mm256_fmadd_ps(
_mm256_loadu_ps(q_ptr.add(j + 8)),
_mm256_loadu_ps(k_ptr.add(j + 8)),
dot1,
);
dot2 = _mm256_fmadd_ps(
_mm256_loadu_ps(q_ptr.add(j + 16)),
_mm256_loadu_ps(k_ptr.add(j + 16)),
dot2,
);
dot3 = _mm256_fmadd_ps(
_mm256_loadu_ps(q_ptr.add(j + 24)),
_mm256_loadu_ps(k_ptr.add(j + 24)),
dot3,
);
j += 32;
}
while j < d8 {
dot0 = _mm256_fmadd_ps(
_mm256_loadu_ps(q_ptr.add(j)),
_mm256_loadu_ps(k_ptr.add(j)),
dot0,
);
j += 8;
}
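            // Horizontal reduction: combine the four accumulators, then fold
            // 256 -> 128 -> 32 bits down to a scalar dot product.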
dot0 = _mm256_add_ps(_mm256_add_ps(dot0, dot1), _mm256_add_ps(dot2, dot3));
let hi = _mm256_extractf128_ps(dot0, 1);
let lo = _mm256_castps256_ps128(dot0);
let sum128 = _mm_add_ps(lo, hi);
let sum64 = _mm_hadd_ps(sum128, sum128);
let sum32 = _mm_hadd_ps(sum64, sum64);
let mut dot_scalar = _mm_cvtss_f32(sum32);
while j < head_dim {
dot_scalar += *q.get_unchecked(j) * *k_cache.get_unchecked(s * head_dim + j);
j += 1;
}
let score = dot_scalar * scale;
let new_max = running_max.max(score);
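            // Online-softmax rescale, vectorized: multiply the partial output
            // and the denominator by exp(old_max - new_max).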
if running_max != f32::NEG_INFINITY {
let correction = (running_max - new_max).exp();
running_sum *= correction;
let corr_v = _mm256_set1_ps(correction);
let out_ptr = output.as_mut_ptr();
let mut d = 0;
while d < d8 {
let ov = _mm256_loadu_ps(out_ptr.add(d));
_mm256_storeu_ps(out_ptr.add(d), _mm256_mul_ps(ov, corr_v));
d += 8;
}
while d < head_dim {
*output.get_unchecked_mut(d) *= correction;
d += 1;
}
}
let w = (score - new_max).exp();
running_sum += w;
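            // Accumulate w * V[s] into the output with fused multiply-add.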
let w_v = _mm256_set1_ps(w);
let v_ptr = v_cache.as_ptr().add(s * head_dim);
let out_ptr = output.as_mut_ptr();
let mut d = 0;
while d < d8 {
let ov = _mm256_loadu_ps(out_ptr.add(d));
let vv = _mm256_loadu_ps(v_ptr.add(d));
_mm256_storeu_ps(out_ptr.add(d), _mm256_fmadd_ps(w_v, vv, ov));
d += 8;
}
while d < head_dim {
*output.get_unchecked_mut(d) += w * *v_cache.get_unchecked(s * head_dim + d);
d += 1;
}
running_max = new_max;
}
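        // Final normalization by the softmax denominator.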
if running_sum > 0.0 {
let inv_v = _mm256_set1_ps(1.0 / running_sum);
let out_ptr = output.as_mut_ptr();
let mut d = 0;
while d < d8 {
let ov = _mm256_loadu_ps(out_ptr.add(d));
_mm256_storeu_ps(out_ptr.add(d), _mm256_mul_ps(ov, inv_v));
d += 8;
}
while d < head_dim {
*output.get_unchecked_mut(d) /= running_sum;
d += 1;
}
}
    }
}
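/// Reference implementation for tests: materializes all scores, applies a
/// conventional two-pass max-subtracted softmax, then mixes the V rows.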
#[cfg(test)]
fn unfused_attention_decode_reference(
q: &[f32],
k_cache: &[f32],
v_cache: &[f32],
head_dim: usize,
seq_len: usize,
output: &mut [f32],
) {
let scale = 1.0 / (head_dim as f32).sqrt();
let mut scores = vec![0.0f32; seq_len];
for s in 0..seq_len {
let k_row = &k_cache[s * head_dim..(s + 1) * head_dim];
let mut dot = 0.0f32;
for d in 0..head_dim {
dot += q[d] * k_row[d];
}
scores[s] = dot * scale;
}
let max_score = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let mut sum = 0.0f32;
for s in scores.iter_mut() {
*s = (*s - max_score).exp();
sum += *s;
}
for s in scores.iter_mut() {
*s /= sum;
}
output.fill(0.0);
for s in 0..seq_len {
let v_row = &v_cache[s * head_dim..(s + 1) * head_dim];
let w = scores[s];
for d in 0..head_dim {
output[d] += w * v_row[d];
}
}
}
#[cfg(test)]
mod tests {
use super::*;
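    /// Deterministic pseudo-random data in [-0.5, 0.5) so test failures are
    /// reproducible without an RNG dependency.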
fn gen_data(head_dim: usize, seq_len: usize) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
let q: Vec<f32> = (0..head_dim).map(|i| ((i * 7 + 3) % 100) as f32 / 100.0 - 0.5).collect();
let k: Vec<f32> =
(0..seq_len * head_dim).map(|i| ((i * 13 + 7) % 100) as f32 / 100.0 - 0.5).collect();
let v: Vec<f32> =
(0..seq_len * head_dim).map(|i| ((i * 11 + 5) % 100) as f32 / 100.0 - 0.5).collect();
(q, k, v)
}
#[test]
fn test_fused_matches_reference() {
for &(d, s) in &[(128, 64), (128, 512), (128, 1024), (64, 256)] {
let (q, k, v) = gen_data(d, s);
let mut out_fused = vec![0.0f32; d];
let mut out_ref = vec![0.0f32; d];
fused_attention_decode(&q, &k, &v, d, s, &mut out_fused);
unfused_attention_decode_reference(&q, &k, &v, d, s, &mut out_ref);
let max_diff = out_fused
.iter()
.zip(out_ref.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max);
assert!(max_diff < 1e-4, "FALSIFY-FLASH-ATTN-001: d={d} s={s} max_diff={max_diff}");
}
}
#[test]
fn test_softmax_sums_to_one() {
let d = 128;
let s = 512;
let (q, k, v) = gen_data(d, s);
let scale = 1.0 / (d as f32).sqrt();
let mut running_max = f32::NEG_INFINITY;
let mut running_sum = 0.0f32;
for i in 0..s {
let k_row = &k[i * d..(i + 1) * d];
let dot: f32 = q.iter().zip(k_row.iter()).map(|(a, b)| a * b).sum();
let score = dot * scale;
let new_max = running_max.max(score);
if running_max != f32::NEG_INFINITY {
running_sum *= (running_max - new_max).exp();
}
running_sum += (score - new_max).exp();
running_max = new_max;
}
        // The online recurrence must reproduce the directly computed
        // denominator; that equality is what makes the normalized weights
        // sum to one.
        let direct_sum: f32 = (0..s)
            .map(|i| {
                let k_row = &k[i * d..(i + 1) * d];
                let dot: f32 = q.iter().zip(k_row.iter()).map(|(a, b)| a * b).sum();
                (dot * scale - running_max).exp()
            })
            .sum();
        assert!(running_sum > 0.0);
        assert!(
            (running_sum - direct_sum).abs() / direct_sum < 1e-4,
            "online softmax denominator {running_sum} != direct {direct_sum}"
        );
let mut out = vec![0.0f32; d];
fused_attention_decode(&q, &k, &v, d, s, &mut out);
assert!(out.iter().all(|x| x.is_finite()), "FALSIFY-FLASH-ATTN-004: NaN/Inf in output");
}
#[test]
fn test_fused_seq_len_one() {
let d = 128;
let (q, k, v) = gen_data(d, 1);
let mut out_fused = vec![0.0f32; d];
let mut out_ref = vec![0.0f32; d];
fused_attention_decode(&q, &k, &v, d, 1, &mut out_fused);
unfused_attention_decode_reference(&q, &k, &v, d, 1, &mut out_ref);
let max_diff =
out_fused.iter().zip(out_ref.iter()).map(|(a, b)| (a - b).abs()).fold(0.0f32, f32::max);
assert!(max_diff < 1e-6, "seq_len=1: max_diff={max_diff}");
}
#[test]
fn test_fused_seq_len_zero() {
let d = 128;
let q = vec![1.0f32; d];
let mut out = vec![99.0f32; d];
fused_attention_decode(&q, &[], &[], d, 0, &mut out);
assert!(out.iter().all(|&x| x == 0.0), "seq_len=0 should zero output");
}
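    // Property sketch (an added check, not part of the original suite):
    // softmax weights are nonnegative and sum to one, so every output element
    // must be a convex combination of the corresponding V entries and lie in
    // their per-dimension [min, max] range. The 1e-5 slack for f32 rounding
    // is an assumption.
    #[test]
    fn test_output_within_v_range() {
        let d = 64;
        let s = 256;
        let (q, k, v) = gen_data(d, s);
        let mut out = vec![0.0f32; d];
        fused_attention_decode(&q, &k, &v, d, s, &mut out);
        for dim in 0..d {
            let mut lo = f32::INFINITY;
            let mut hi = f32::NEG_INFINITY;
            for row in 0..s {
                lo = lo.min(v[row * d + dim]);
                hi = hi.max(v[row * d + dim]);
            }
            assert!(
                out[dim] >= lo - 1e-5 && out[dim] <= hi + 1e-5,
                "dim={dim}: output {} outside V range [{lo}, {hi}]",
                out[dim]
            );
        }
    }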
#[test]
fn test_fused_perf_smoke() {
let d = 128;
let s = 512;
let (q, k, v) = gen_data(d, s);
let mut out = vec![0.0f32; d];
fused_attention_decode(&q, &k, &v, d, s, &mut out);
assert!(out.iter().any(|&x| x != 0.0), "Output should be non-zero");
}
}