#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;
use crate::error::Result;
use numr::runtime::cpu::CpuRuntime;
use numr::tensor::Tensor;
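
/// Decode-step attention: each query head attends a single new token against
/// the cached K/V sequence, with grouped-query attention (GQA) support.
///
/// Expected contiguous `f32` shapes:
/// - `q`: `[batch, num_heads, 1, head_dim]`
/// - `k`, `v`: `[batch, num_kv_heads, seq_len_k, head_dim]`
///
/// Returns `(output, lse)`, where `output` is `[batch, num_heads, 1, head_dim]`
/// and `lse` is the per-head log-sum-exp of the scaled scores, `[batch, num_heads, 1]`.
///
/// ```ignore
/// // Hypothetical call site with q/k/v shaped as above:
/// let (out, lse) = fused_decode_attention(&q, &k, &v, 32, 8, 128)?;
/// ```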
pub fn fused_decode_attention(
q: &Tensor<CpuRuntime>,
k: &Tensor<CpuRuntime>,
v: &Tensor<CpuRuntime>,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
let q_shape = q.shape();
let k_shape = k.shape();
let batch = q_shape[0];
let seq_len_k = k_shape[2];
debug_assert_eq!(q_shape[1], num_heads);
debug_assert_eq!(q_shape[2], 1);
debug_assert_eq!(q_shape[3], head_dim);
debug_assert_eq!(k_shape[1], num_kv_heads);
debug_assert_eq!(k_shape[3], head_dim);
debug_assert!(seq_len_k > 0, "decode attention needs a non-empty KV cache");
debug_assert_eq!(num_heads % num_kv_heads, 0, "num_heads must be a multiple of num_kv_heads");
// Softmax scale 1/sqrt(head_dim), kept in f64 to match the f64 score math below.
let scale = (head_dim as f64).sqrt().recip();
let kv_group_size = num_heads / num_kv_heads;
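// Raw host views of the (dense, f32) tensor data; shapes were validated above.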
let q_data = unsafe { q.storage().as_host_slice::<f32>() };
let k_data = unsafe { k.storage().as_host_slice::<f32>() };
let v_data = unsafe { v.storage().as_host_slice::<f32>() };
let mut output = vec![0.0f32; batch * num_heads * head_dim];
let mut lse_data = vec![0.0f32; batch * num_heads];
// Flat strides into the contiguous [batch, heads, seq, head_dim] buffers.
let q_stride_b = num_heads * head_dim;
let k_stride_b = num_kv_heads * seq_len_k * head_dim;
let k_stride_h = seq_len_k * head_dim;
let v_stride_b = k_stride_b;
let v_stride_h = k_stride_h;
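// Scratch buffer for one head's scores, allocated once and reused.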
let mut scores = vec![0.0f32; seq_len_k];
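// Per (batch, head): scores, numerically stable softmax, then a weighted sum over V rows.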
for b in 0..batch {
for h in 0..num_heads {
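// GQA: each group of `kv_group_size` query heads shares one KV head.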
let kv_h = h / kv_group_size;
let q_offset = b * q_stride_b + h * head_dim;
let k_base = b * k_stride_b + kv_h * k_stride_h;
let v_base = b * v_stride_b + kv_h * v_stride_h;
let q_row = &q_data[q_offset..q_offset + head_dim];
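// Pass 1: scaled dot(q, k_j) scores, tracking the running max for softmax stability.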
let mut max_score = f32::NEG_INFINITY;
for j in 0..seq_len_k {
let k_row = &k_data[k_base + j * head_dim..k_base + j * head_dim + head_dim];
let score = (dot_f32_simd(q_row, k_row) as f64 * scale) as f32;
scores[j] = score;
if score > max_score {
max_score = score;
}
}
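// Pass 2: exponentiate the shifted scores in place; accumulate the normalizer in f64.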
let mut sum_exp = 0.0f64;
for s in scores[..seq_len_k].iter_mut() {
let w = (*s - max_score).exp();
*s = w;
sum_exp += w as f64;
}
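// Pass 3: output = sum_j (exp_j / normalizer) * v_j.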
let out_offset = b * num_heads * head_dim + h * head_dim;
let out_row = &mut output[out_offset..out_offset + head_dim];
out_row.fill(0.0);
let inv_sum = (1.0f64 / sum_exp) as f32;
for j in 0..seq_len_k {
let w = scores[j] * inv_sum;
let v_row = &v_data[v_base + j * head_dim..v_base + j * head_dim + head_dim];
accumulate_weighted_simd(out_row, v_row, w);
}
// Log-sum-exp of the scaled scores: max + ln(normalizer), with ln taken in f64.
lse_data[b * num_heads + h] = max_score + sum_exp.ln() as f32;
}
}
let out_tensor =
Tensor::<CpuRuntime>::from_slice(&output, &[batch, num_heads, 1, head_dim], q.device());
let lse_tensor =
Tensor::<CpuRuntime>::from_slice(&lse_data, &[batch, num_heads, 1], q.device());
Ok((out_tensor, lse_tensor))
}
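
/// Dot product of two equal-length `f32` slices.
///
/// Uses AVX2+FMA on x86_64 when detected at runtime, NEON on aarch64, and a
/// scalar loop otherwise.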
#[inline]
fn dot_f32_simd(a: &[f32], b: &[f32]) -> f32 {
debug_assert_eq!(a.len(), b.len());
#[cfg(target_arch = "x86_64")]
{
let len = a.len();
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
return unsafe {
crate::quant::cpu::kernels::simd::dot_f32::dot_f32_avx2_fma(
a.as_ptr(),
b.as_ptr(),
len,
)
};
}
// Scalar fallback when AVX2/FMA is not detected at runtime.
a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
}
// NEON is a baseline feature on aarch64, so no runtime detection is needed.
#[cfg(target_arch = "aarch64")]
unsafe {
crate::quant::cpu::kernels::simd::aarch64::dot_f32::dot_f32_neon(
a.as_ptr(),
b.as_ptr(),
a.len(),
)
}
// Portable scalar fallback for other architectures.
#[cfg(not(any(target_arch = "x86_64", target_arch = "aarch64")))]
a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum()
}
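
/// Fused multiply-accumulate over slices: `out[i] += weight * v[i]`.
///
/// Takes the AVX2+FMA fast path on x86_64 when detected at runtime; otherwise
/// falls back to a scalar loop.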
#[inline]
fn accumulate_weighted_simd(out: &mut [f32], v: &[f32], weight: f32) {
debug_assert_eq!(out.len(), v.len());
#[cfg(target_arch = "x86_64")]
{
if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
unsafe {
accumulate_weighted_avx2(out.as_mut_ptr(), v.as_ptr(), weight, out.len());
}
return;
}
}
for (o, &vi) in out.iter_mut().zip(v.iter()) {
*o += weight * vi;
}
}
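
/// AVX2+FMA kernel behind [`accumulate_weighted_simd`].
///
/// # Safety
///
/// `out` must be valid for `len` writes and `v` for `len` reads, and the
/// caller must have verified AVX2 and FMA support before calling.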
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
unsafe fn accumulate_weighted_avx2(out: *mut f32, v: *const f32, weight: f32, len: usize) {
unsafe {
const LANES: usize = 8; // f32 lanes in a 256-bit AVX2 register
let chunks = len / LANES;
let remainder = len % LANES;
let w_vec = _mm256_set1_ps(weight);
for i in 0..chunks {
let offset = i * LANES;
let vo = _mm256_loadu_ps(out.add(offset));
let vv = _mm256_loadu_ps(v.add(offset));
let result = _mm256_fmadd_ps(w_vec, vv, vo);
_mm256_storeu_ps(out.add(offset), result);
}
// Scalar tail for the remaining `len % LANES` elements.
for i in 0..remainder {
let offset = chunks * LANES + i;
*out.add(offset) += weight * *v.add(offset);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use numr::runtime::cpu::CpuDevice;
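
// Test helper: build a CPU tensor from raw f32 data.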
fn make_tensor(data: &[f32], shape: &[usize]) -> Tensor<CpuRuntime> {
let device = CpuDevice::new();
Tensor::<CpuRuntime>::from_slice(data, shape, &device)
}
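
// Smoke test: correct output shape and a non-trivial result.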
#[test]
fn test_decode_attention_basic() {
let q_data: Vec<f32> = (0..8).map(|i| (i as f32 + 1.0) * 0.1).collect();
let k_data: Vec<f32> = (0..24).map(|i| (i as f32) * 0.05).collect();
let v_data: Vec<f32> = (0..24).map(|i| (i as f32 + 1.0) * 0.1).collect();
let q = make_tensor(&q_data, &[1, 2, 1, 4]);
let k = make_tensor(&k_data, &[1, 2, 3, 4]);
let v = make_tensor(&v_data, &[1, 2, 3, 4]);
let (out, _lse) = fused_decode_attention(&q, &k, &v, 2, 2, 4).unwrap();
assert_eq!(out.shape(), &[1, 2, 1, 4]);
let out_data = out.to_vec::<f32>();
let sum: f32 = out_data.iter().map(|x| x.abs()).sum();
assert!(sum > 0.0, "Output should be non-zero");
}
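
// GQA path: 4 query heads sharing 2 KV heads.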
#[test]
fn test_decode_attention_gqa() {
let q_data: Vec<f32> = (0..16).map(|i| (i as f32 + 1.0) * 0.1).collect();
let k_data: Vec<f32> = (0..24).map(|i| (i as f32) * 0.05).collect();
let v_data: Vec<f32> = (0..24).map(|i| (i as f32 + 1.0) * 0.1).collect();
let q = make_tensor(&q_data, &[1, 4, 1, 4]);
let k = make_tensor(&k_data, &[1, 2, 3, 4]);
let v = make_tensor(&v_data, &[1, 2, 3, 4]);
let (out, _lse) = fused_decode_attention(&q, &k, &v, 4, 2, 4).unwrap();
assert_eq!(out.shape(), &[1, 4, 1, 4]);
let out_data = out.to_vec::<f32>();
let sum: f32 = out_data.iter().map(|x| x.abs()).sum();
assert!(sum > 0.0);
}
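
// The fused kernel must match an unfused matmul -> scale -> softmax -> matmul reference.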
#[test]
fn test_decode_attention_matches_standard() {
use numr::ops::{ActivationOps, MatmulOps, ScalarOps};
use numr::runtime::cpu::CpuClient;
let device = CpuDevice::new();
let client = CpuClient::new(device.clone());
let num_heads = 2;
let num_kv_heads = 2;
let head_dim = 8;
let seq_len_k = 5;
let scale = (head_dim as f64).sqrt().recip();
let q_data: Vec<f32> = (0..num_heads * head_dim)
.map(|i| ((i as f32) * 0.3).sin())
.collect();
let k_data: Vec<f32> = (0..num_kv_heads * seq_len_k * head_dim)
.map(|i| ((i as f32) * 0.2).cos())
.collect();
let v_data: Vec<f32> = (0..num_kv_heads * seq_len_k * head_dim)
.map(|i| ((i as f32) * 0.1 + 0.5).sin())
.collect();
let q = Tensor::<CpuRuntime>::from_slice(&q_data, &[1, num_heads, 1, head_dim], &device);
let k = Tensor::<CpuRuntime>::from_slice(
&k_data,
&[1, num_kv_heads, seq_len_k, head_dim],
&device,
);
let v = Tensor::<CpuRuntime>::from_slice(
&v_data,
&[1, num_kv_heads, seq_len_k, head_dim],
&device,
);
let (fused_out, _) =
fused_decode_attention(&q, &k, &v, num_heads, num_kv_heads, head_dim).unwrap();
let k_t = k.transpose(-2isize, -1isize).unwrap().contiguous();
let scores = client.matmul(&q, &k_t).unwrap();
let scores = client.mul_scalar(&scores, scale).unwrap();
let weights = client.softmax(&scores, -1).unwrap();
let ref_out = client.matmul(&weights, &v).unwrap();
let fused_data = fused_out.to_vec::<f32>();
let ref_data = ref_out.to_vec::<f32>();
let mut max_diff = 0.0f32;
for (i, (&f, &r)) in fused_data.iter().zip(ref_data.iter()).enumerate() {
let diff = (f - r).abs();
if diff > max_diff {
max_diff = diff;
}
assert!(
diff < 1e-5,
"mismatch at {}: fused={}, ref={}, diff={}",
i,
f,
r,
diff
);
}
eprintln!("max diff (small test): {max_diff:.2e}");
}
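
// Same reference check at realistic GQA dimensions (32 query heads, 8 KV heads, head_dim 128).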
#[test]
fn test_decode_attention_matches_standard_realistic() {
use numr::ops::{ActivationOps, MatmulOps, ScalarOps, ShapeOps};
use numr::runtime::cpu::CpuClient;
let device = CpuDevice::new();
let client = CpuClient::new(device.clone());
let num_heads = 32;
let num_kv_heads = 8;
let head_dim = 128;
let seq_len_k = 64;
let scale = (head_dim as f64).sqrt().recip();
let q_data: Vec<f32> = (0..num_heads * head_dim)
.map(|i| ((i as f32) * 0.037).sin())
.collect();
let k_data: Vec<f32> = (0..num_kv_heads * seq_len_k * head_dim)
.map(|i| ((i as f32) * 0.023).cos())
.collect();
let v_data: Vec<f32> = (0..num_kv_heads * seq_len_k * head_dim)
.map(|i| ((i as f32) * 0.011 + 0.5).sin())
.collect();
let q = Tensor::<CpuRuntime>::from_slice(&q_data, &[1, num_heads, 1, head_dim], &device);
let k = Tensor::<CpuRuntime>::from_slice(
&k_data,
&[1, num_kv_heads, seq_len_k, head_dim],
&device,
);
let v = Tensor::<CpuRuntime>::from_slice(
&v_data,
&[1, num_kv_heads, seq_len_k, head_dim],
&device,
);
let (fused_out, _) =
fused_decode_attention(&q, &k, &v, num_heads, num_kv_heads, head_dim).unwrap();
let repeats = num_heads / num_kv_heads;
let k_exp = client.repeat_interleave(&k, repeats, Some(1)).unwrap();
let v_exp = client.repeat_interleave(&v, repeats, Some(1)).unwrap();
let k_t = k_exp.transpose(-2isize, -1isize).unwrap().contiguous();
let scores = client.matmul(&q, &k_t).unwrap();
let scores = client.mul_scalar(&scores, scale).unwrap();
let weights = client.softmax(&scores, -1).unwrap();
let ref_out = client.matmul(&weights, &v_exp).unwrap();
let fused_data = fused_out.to_vec::<f32>();
let ref_data = ref_out.to_vec::<f32>();
let mut max_diff = 0.0f32;
for (i, (&f, &r)) in fused_data.iter().zip(ref_data.iter()).enumerate() {
let diff = (f - r).abs();
if diff > max_diff {
max_diff = diff;
}
assert!(
diff < 1e-5,
"mismatch at {}: fused={}, ref={}, diff={}",
i,
f,
r,
diff
);
}
eprintln!("max diff (realistic test): {max_diff:.2e}");
}
}