use crate::error::Result;
use crate::ops::traits::VarLenAttentionOps;
use numr::runtime::cpu::{CpuClient, CpuRuntime};
use numr::tensor::Tensor;
impl VarLenAttentionOps<CpuRuntime> for CpuClient {
fn varlen_attention_fwd(
&self,
q: &Tensor<CpuRuntime>,
k: &Tensor<CpuRuntime>,
v: &Tensor<CpuRuntime>,
cu_seqlens_q: &Tensor<CpuRuntime>,
cu_seqlens_k: &Tensor<CpuRuntime>,
batch_size: usize,
num_heads: usize,
_max_seqlen_q: usize,
_max_seqlen_k: usize,
head_dim: usize,
causal: bool,
) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
let total_tokens_q = q.shape()[0];
let device = q.device();
let q_data = q.to_vec::<f32>();
let k_data = k.to_vec::<f32>();
let v_data = v.to_vec::<f32>();
let cu_q = cu_seqlens_q.to_vec::<i32>();
let cu_k = cu_seqlens_k.to_vec::<i32>();
let mut out = vec![0.0f32; total_tokens_q * num_heads * head_dim];
let mut lse = vec![0.0f32; total_tokens_q * num_heads];
let scale = (head_dim as f32).sqrt().recip();
for b in 0..batch_size {
let sq_start = cu_q[b] as usize;
let sq_end = cu_q[b + 1] as usize;
let sk_start = cu_k[b] as usize;
let sk_end = cu_k[b + 1] as usize;
let seq_len_q = sq_end - sq_start;
let seq_len_k = sk_end - sk_start;
for h in 0..num_heads {
for qi in 0..seq_len_q {
let q_offset = ((sq_start + qi) * num_heads + h) * head_dim;
let mut max_score = f32::NEG_INFINITY;
let mut scores = Vec::with_capacity(seq_len_k);
for ki in 0..seq_len_k {
if causal && qi < ki {
scores.push(f32::NEG_INFINITY);
continue;
}
let k_offset = ((sk_start + ki) * num_heads + h) * head_dim;
let mut dot = 0.0f32;
for d in 0..head_dim {
dot += q_data[q_offset + d] * k_data[k_offset + d];
}
let s = dot * scale;
max_score = max_score.max(s);
scores.push(s);
}
let mut sum_exp = 0.0f32;
let mut exp_scores = Vec::with_capacity(seq_len_k);
for &s in &scores {
let e = (s - max_score).exp();
sum_exp += e;
exp_scores.push(e);
}
let o_offset = ((sq_start + qi) * num_heads + h) * head_dim;
for (ki, &exp_s) in exp_scores.iter().enumerate() {
let weight = exp_s / sum_exp;
let v_offset = ((sk_start + ki) * num_heads + h) * head_dim;
for d in 0..head_dim {
out[o_offset + d] += weight * v_data[v_offset + d];
}
}
lse[(sq_start + qi) * num_heads + h] = max_score + sum_exp.ln();
}
}
}
let output =
Tensor::<CpuRuntime>::from_slice(&out, &[total_tokens_q, num_heads, head_dim], device);
let lse_tensor =
Tensor::<CpuRuntime>::from_slice(&lse, &[total_tokens_q, num_heads], device);
Ok((output, lse_tensor))
}
fn varlen_attention_bwd(
&self,
dout: &Tensor<CpuRuntime>,
q: &Tensor<CpuRuntime>,
k: &Tensor<CpuRuntime>,
v: &Tensor<CpuRuntime>,
output: &Tensor<CpuRuntime>,
lse: &Tensor<CpuRuntime>,
cu_seqlens_q: &Tensor<CpuRuntime>,
cu_seqlens_k: &Tensor<CpuRuntime>,
batch_size: usize,
num_heads: usize,
_max_seqlen_q: usize,
_max_seqlen_k: usize,
head_dim: usize,
causal: bool,
) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
let total_tokens_q = q.shape()[0];
let total_tokens_k = k.shape()[0];
let device = q.device();
let q_data = q.to_vec::<f32>();
let k_data = k.to_vec::<f32>();
let v_data = v.to_vec::<f32>();
let o_data = output.to_vec::<f32>();
let lse_data = lse.to_vec::<f32>();
let do_data = dout.to_vec::<f32>();
let cu_q = cu_seqlens_q.to_vec::<i32>();
let cu_k = cu_seqlens_k.to_vec::<i32>();
let mut dq = vec![0.0f32; total_tokens_q * num_heads * head_dim];
let mut dk = vec![0.0f32; total_tokens_k * num_heads * head_dim];
let mut dv = vec![0.0f32; total_tokens_k * num_heads * head_dim];
let scale = (head_dim as f32).sqrt().recip();
for b in 0..batch_size {
let sq_start = cu_q[b] as usize;
let sq_end = cu_q[b + 1] as usize;
let sk_start = cu_k[b] as usize;
let sk_end = cu_k[b + 1] as usize;
let seq_len_q = sq_end - sq_start;
let seq_len_k = sk_end - sk_start;
for h in 0..num_heads {
for qi in 0..seq_len_q {
let q_off = ((sq_start + qi) * num_heads + h) * head_dim;
let l = lse_data[(sq_start + qi) * num_heads + h];
let o_off = ((sq_start + qi) * num_heads + h) * head_dim;
let mut d_val = 0.0f32;
for d in 0..head_dim {
d_val += do_data[o_off + d] * o_data[o_off + d];
}
for ki in 0..seq_len_k {
if causal && qi < ki {
continue;
}
let k_off = ((sk_start + ki) * num_heads + h) * head_dim;
let mut score = 0.0f32;
for d in 0..head_dim {
score += q_data[q_off + d] * k_data[k_off + d];
}
score *= scale;
let prob = (score - l).exp();
let v_off = ((sk_start + ki) * num_heads + h) * head_dim;
let mut grad_prob = 0.0f32;
for d in 0..head_dim {
grad_prob += v_data[v_off + d] * do_data[o_off + d];
}
let grad_score = prob * (grad_prob - d_val);
for d in 0..head_dim {
dq[q_off + d] += scale * grad_score * k_data[k_off + d];
dk[k_off + d] += scale * grad_score * q_data[q_off + d];
dv[v_off + d] += prob * do_data[o_off + d];
}
}
}
}
}
let dq_t =
Tensor::<CpuRuntime>::from_slice(&dq, &[total_tokens_q, num_heads, head_dim], device);
let dk_t =
Tensor::<CpuRuntime>::from_slice(&dk, &[total_tokens_k, num_heads, head_dim], device);
let dv_t =
Tensor::<CpuRuntime>::from_slice(&dv, &[total_tokens_k, num_heads, head_dim], device);
Ok((dq_t, dk_t, dv_t))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::cpu_setup;
#[test]
fn test_varlen_fwd_shape() {
let (client, dev) = cpu_setup();
let total_q = 5;
let num_heads = 2;
let head_dim = 4;
let q_data = vec![0.1f32; total_q * num_heads * head_dim];
let k_data = vec![0.1f32; total_q * num_heads * head_dim];
let v_data = vec![0.2f32; total_q * num_heads * head_dim];
let q = Tensor::<CpuRuntime>::from_slice(&q_data, &[total_q, num_heads, head_dim], &dev);
let k = Tensor::<CpuRuntime>::from_slice(&k_data, &[total_q, num_heads, head_dim], &dev);
let v = Tensor::<CpuRuntime>::from_slice(&v_data, &[total_q, num_heads, head_dim], &dev);
let cu_seqlens = vec![0i32, 3, 5];
let cu = Tensor::<CpuRuntime>::from_slice(&cu_seqlens, &[3], &dev);
let (out, lse) = client
.varlen_attention_fwd(&q, &k, &v, &cu, &cu, 2, num_heads, 3, 3, head_dim, false)
.unwrap();
assert_eq!(out.shape(), &[total_q, num_heads, head_dim]);
assert_eq!(lse.shape(), &[total_q, num_heads]);
}
#[test]
fn test_varlen_fwd_causal() {
let (client, dev) = cpu_setup();
let total_q = 4;
let num_heads = 1;
let head_dim = 2;
let q_data: Vec<f32> = (0..total_q * num_heads * head_dim)
.map(|i| (i as f32) * 0.1 + 0.1)
.collect();
let k_data = q_data.clone();
let v_data: Vec<f32> = (0..total_q * num_heads * head_dim)
.map(|i| (i as f32) * 0.05)
.collect();
let q = Tensor::<CpuRuntime>::from_slice(&q_data, &[total_q, num_heads, head_dim], &dev);
let k = Tensor::<CpuRuntime>::from_slice(&k_data, &[total_q, num_heads, head_dim], &dev);
let v = Tensor::<CpuRuntime>::from_slice(&v_data, &[total_q, num_heads, head_dim], &dev);
let cu_seqlens = vec![0i32, 4];
let cu = Tensor::<CpuRuntime>::from_slice(&cu_seqlens, &[2], &dev);
let (out_causal, _) = client
.varlen_attention_fwd(&q, &k, &v, &cu, &cu, 1, num_heads, 4, 4, head_dim, true)
.unwrap();
let (out_full, _) = client
.varlen_attention_fwd(&q, &k, &v, &cu, &cu, 1, num_heads, 4, 4, head_dim, false)
.unwrap();
let causal_data = out_causal.to_vec::<f32>();
let full_data = out_full.to_vec::<f32>();
let last_off = (total_q - 1) * num_heads * head_dim;
for d in 0..head_dim {
assert!(
(causal_data[last_off + d] - full_data[last_off + d]).abs() < 1e-5,
"Last token should match between causal and non-causal"
);
}
let second_off = num_heads * head_dim;
let differs = (0..head_dim)
.any(|d| (causal_data[second_off + d] - full_data[second_off + d]).abs() > 1e-6);
assert!(
differs,
"Middle tokens should differ between causal and non-causal"
);
}
#[test]
fn test_varlen_bwd_shapes() {
let (client, dev) = cpu_setup();
let total_q = 5;
let num_heads = 2;
let head_dim = 4;
let n = total_q * num_heads * head_dim;
let q_data: Vec<f32> = (0..n).map(|i| (i as f32 * 0.3).sin()).collect();
let k_data: Vec<f32> = (0..n).map(|i| (i as f32 * 0.7).cos()).collect();
let v_data: Vec<f32> = (0..n).map(|i| (i as f32 * 0.5 + 1.0).sin()).collect();
let q = Tensor::<CpuRuntime>::from_slice(&q_data, &[total_q, num_heads, head_dim], &dev);
let k = Tensor::<CpuRuntime>::from_slice(&k_data, &[total_q, num_heads, head_dim], &dev);
let v = Tensor::<CpuRuntime>::from_slice(&v_data, &[total_q, num_heads, head_dim], &dev);
let cu_seqlens = vec![0i32, 3, 5];
let cu = Tensor::<CpuRuntime>::from_slice(&cu_seqlens, &[3], &dev);
let (out, lse) = client
.varlen_attention_fwd(&q, &k, &v, &cu, &cu, 2, num_heads, 3, 3, head_dim, false)
.unwrap();
let do_data: Vec<f32> = (0..n).map(|i| (i as f32 * 0.2).cos() * 0.1).collect();
let dout =
Tensor::<CpuRuntime>::from_slice(&do_data, &[total_q, num_heads, head_dim], &dev);
let (dq, dk, dv) = client
.varlen_attention_bwd(
&dout, &q, &k, &v, &out, &lse, &cu, &cu, 2, num_heads, 3, 3, head_dim, false,
)
.unwrap();
assert_eq!(dq.shape(), &[total_q, num_heads, head_dim]);
assert_eq!(dk.shape(), &[total_q, num_heads, head_dim]);
assert_eq!(dv.shape(), &[total_q, num_heads, head_dim]);
let dq_data = dq.to_vec::<f32>();
let has_nonzero = dq_data.iter().any(|&x: &f32| x.abs() > 1e-10);
assert!(has_nonzero, "dQ should have non-zero gradients");
}
}