use super::flash_helpers::{standard_attention_bwd, standard_attention_fwd};
use crate::error::{Error, Result};
use crate::ops::impl_generic::attention::multi_head_attention_impl;
use crate::ops::traits::{AttentionOps, FlashAttentionOps};
use numr::autograd::Var;
use numr::dtype::DType;
use numr::runtime::cpu::{CpuClient, CpuRuntime};
use numr::tensor::Tensor;
impl AttentionOps<CpuRuntime> for CpuClient {
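    /// Multi-head attention on CPU, delegating to the shared generic implementation.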
fn multi_head_attention(
&self,
q: &Var<CpuRuntime>,
k: &Var<CpuRuntime>,
v: &Var<CpuRuntime>,
mask: Option<&Var<CpuRuntime>>,
num_heads: usize,
) -> Result<Var<CpuRuntime>> {
multi_head_attention_impl(self, q, k, v, mask, num_heads)
}
}
impl FlashAttentionOps<CpuRuntime> for CpuClient {
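    /// Flash-attention forward pass on CPU.
    ///
    /// Dispatch order:
    /// 1. If `kv_seq_len` is `Some`, K/V are narrowed to that prefix along the
    ///    sequence dimension and the call recurses with `kv_seq_len` cleared.
    /// 2. A single-token f32 query with no causal mask and no sliding window
    ///    uses the fused decode kernel.
    /// 3. Everything else falls back to the standard (unfused) implementation.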
fn flash_attention_fwd(
&self,
q: &Tensor<CpuRuntime>,
k: &Tensor<CpuRuntime>,
v: &Tensor<CpuRuntime>,
num_heads: usize,
num_kv_heads: usize,
head_dim: usize,
causal: bool,
window_size: usize,
kv_seq_len: Option<usize>,
) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
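        // When `kv_seq_len` is given, only that prefix of K/V is valid: narrow
        // both along the sequence dimension and recurse with the length cleared.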
if let Some(seq_len) = kv_seq_len {
let k_narrow = k.narrow(2, 0, seq_len)?;
let v_narrow = v.narrow(2, 0, seq_len)?;
let k_c = k_narrow.contiguous();
let v_c = v_narrow.contiguous();
return self.flash_attention_fwd(
q,
&k_c,
&v_c,
num_heads,
num_kv_heads,
head_dim,
causal,
window_size,
None,
);
}
let seq_len_q = q.shape()[2];
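        // Fast path: a single-token f32 query with no causal mask and no
        // sliding window can use the fused decode kernel.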
if seq_len_q == 1
&& !causal
&& window_size == 0
&& q.dtype() == DType::F32
&& k.dtype() == DType::F32
&& v.dtype() == DType::F32
{
return super::decode_attention::fused_decode_attention(
q,
k,
v,
num_heads,
num_kv_heads,
head_dim,
);
}
        // Fallback: standard (unfused) attention. `head_dim` is not used on
        // this path, so it is explicitly discarded.
        let _ = head_dim;
        standard_attention_fwd(self, q, k, v, causal, num_heads, num_kv_heads, window_size)
}
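    /// FP8 flash attention is not implemented for the CPU backend; this always
    /// returns an `InvalidArgument` error.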
fn flash_attention_fwd_fp8(
&self,
_q: &Tensor<CpuRuntime>,
_k: &Tensor<CpuRuntime>,
_v: &Tensor<CpuRuntime>,
_num_heads: usize,
_num_kv_heads: usize,
_head_dim: usize,
_causal: bool,
_q_scale: f32,
_k_scale: f32,
_v_scale: f32,
_o_scale: f32,
) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
Err(Error::InvalidArgument {
arg: "dtype",
reason: "FP8 Flash Attention is not supported on CPU".into(),
})
}
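    /// Flash-attention backward pass on CPU, computed via the standard
    /// (unfused) attention gradients.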
fn flash_attention_bwd(
&self,
dout: &Tensor<CpuRuntime>,
q: &Tensor<CpuRuntime>,
k: &Tensor<CpuRuntime>,
v: &Tensor<CpuRuntime>,
output: &Tensor<CpuRuntime>,
lse: &Tensor<CpuRuntime>,
num_heads: usize,
num_kv_heads: usize,
_head_dim: usize,
causal: bool,
window_size: usize,
) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
standard_attention_bwd(
self,
dout,
q,
k,
v,
output,
lse,
causal,
num_heads,
num_kv_heads,
window_size,
)
}
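    /// FP8 flash attention backward is not implemented for the CPU backend;
    /// this always returns an `InvalidArgument` error.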
fn flash_attention_bwd_fp8(
&self,
_dout: &Tensor<CpuRuntime>,
_q: &Tensor<CpuRuntime>,
_k: &Tensor<CpuRuntime>,
_v: &Tensor<CpuRuntime>,
_output: &Tensor<CpuRuntime>,
_lse: &Tensor<CpuRuntime>,
_num_heads: usize,
_num_kv_heads: usize,
_head_dim: usize,
_causal: bool,
_q_scale: f32,
_k_scale: f32,
_v_scale: f32,
_do_scale: f32,
_o_scale: f32,
_dq_scale: f32,
_dk_scale: f32,
_dv_scale: f32,
) -> Result<(Tensor<CpuRuntime>, Tensor<CpuRuntime>, Tensor<CpuRuntime>)> {
Err(Error::InvalidArgument {
arg: "dtype",
reason: "FP8 Flash Attention backward is not supported on CPU".into(),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::test_utils::cpu_setup;
use numr::ops::{BinaryOps, ReduceOps, UnaryOps};
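    // Deterministic pseudo-random test data (sine-based), so failures are reproducible.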
fn rand_tensor(
shape: &[usize],
_client: &CpuClient,
device: &<CpuRuntime as numr::runtime::Runtime>::Device,
) -> Tensor<CpuRuntime> {
let n: usize = shape.iter().product();
let data: Vec<f32> = (0..n).map(|i| (i as f32 * 0.1).sin() * 0.5).collect();
Tensor::<CpuRuntime>::from_slice(&data, shape, device)
}
#[test]
fn test_flash_fwd_output_shape() {
let (client, device) = cpu_setup();
let (b, h, s, d) = (2, 4, 8, 16);
let q = rand_tensor(&[b, h, s, d], &client, &device);
let k = rand_tensor(&[b, h, s, d], &client, &device);
let v = rand_tensor(&[b, h, s, d], &client, &device);
let (out, lse) = client
.flash_attention_fwd(&q, &k, &v, h, h, d, false, 0, None)
.unwrap();
assert_eq!(out.shape(), &[b, h, s, d]);
assert_eq!(lse.shape(), &[b, h, s]);
}
#[test]
fn test_flash_fwd_causal() {
let (client, device) = cpu_setup();
let (b, h, s, d) = (1, 2, 6, 8);
let q = rand_tensor(&[b, h, s, d], &client, &device);
let k = rand_tensor(&[b, h, s, d], &client, &device);
let v = rand_tensor(&[b, h, s, d], &client, &device);
let (out_causal, _) = client
.flash_attention_fwd(&q, &k, &v, h, h, d, true, 0, None)
.unwrap();
let (out_full, _) = client
.flash_attention_fwd(&q, &k, &v, h, h, d, false, 0, None)
.unwrap();
let diff = client.sub(&out_causal, &out_full).unwrap();
let abs_diff = client.abs(&diff).unwrap();
let max_diff = client.max(&abs_diff, &[], false).unwrap();
let max_val = max_diff.to_vec::<f32>()[0];
assert!(
max_val > 1e-6,
"Causal and non-causal outputs should differ"
);
}
#[test]
fn test_flash_fwd_sliding_window() {
let (client, device) = cpu_setup();
let (b, h, s, d) = (1, 2, 12, 8);
let q = rand_tensor(&[b, h, s, d], &client, &device);
let k = rand_tensor(&[b, h, s, d], &client, &device);
let v = rand_tensor(&[b, h, s, d], &client, &device);
let (out_window, _) = client
.flash_attention_fwd(&q, &k, &v, h, h, d, false, 4, None)
.unwrap();
let (out_full, _) = client
.flash_attention_fwd(&q, &k, &v, h, h, d, false, 0, None)
.unwrap();
let ow = out_window.to_vec::<f32>();
let of = out_full.to_vec::<f32>();
let max_diff = ow
.iter()
.zip(of.iter())
.map(|(a, b)| (a - b).abs())
.fold(0.0f32, f32::max);
assert!(
max_diff > 1e-6,
"Sliding window should differ from full attention"
);
}
#[test]
fn test_flash_fwd_gqa() {
let (client, device) = cpu_setup();
let (b, h, nkv, s, d) = (1, 8, 2, 4, 16);
let q = rand_tensor(&[b, h, s, d], &client, &device);
let k = rand_tensor(&[b, nkv, s, d], &client, &device);
let v = rand_tensor(&[b, nkv, s, d], &client, &device);
let (out, lse) = client
.flash_attention_fwd(&q, &k, &v, h, nkv, d, false, 0, None)
.unwrap();
assert_eq!(out.shape(), &[b, h, s, d]);
assert_eq!(lse.shape(), &[b, h, s]);
}
#[test]
fn test_flash_bwd_produces_gradients() {
let (client, device) = cpu_setup();
let (b, h, s, d) = (1, 2, 4, 8);
let q = rand_tensor(&[b, h, s, d], &client, &device);
let k = rand_tensor(&[b, h, s, d], &client, &device);
let v = rand_tensor(&[b, h, s, d], &client, &device);
let (out, lse) = client
.flash_attention_fwd(&q, &k, &v, h, h, d, false, 0, None)
.unwrap();
let dout = rand_tensor(&[b, h, s, d], &client, &device);
let (dq, dk, dv) = client
.flash_attention_bwd(&dout, &q, &k, &v, &out, &lse, h, h, d, false, 0)
.unwrap();
assert_eq!(dq.shape(), &[b, h, s, d]);
assert_eq!(dk.shape(), &[b, h, s, d]);
assert_eq!(dv.shape(), &[b, h, s, d]);
let dq_abs = client.abs(&dq).unwrap();
let dq_sum = client.sum(&dq_abs, &[], false).unwrap();
assert!(dq_sum.to_vec::<f32>()[0] > 1e-6);
}
#[test]
fn test_flash_bwd_causal_gradients() {
let (client, device) = cpu_setup();
let (b, h, s, d) = (1, 2, 4, 8);
let q = rand_tensor(&[b, h, s, d], &client, &device);
let k = rand_tensor(&[b, h, s, d], &client, &device);
let v = rand_tensor(&[b, h, s, d], &client, &device);
let (out, lse) = client
.flash_attention_fwd(&q, &k, &v, h, h, d, true, 0, None)
.unwrap();
let dout = rand_tensor(&[b, h, s, d], &client, &device);
let (dq, dk, dv) = client
.flash_attention_bwd(&dout, &q, &k, &v, &out, &lse, h, h, d, true, 0)
.unwrap();
assert_eq!(dq.shape(), &[b, h, s, d]);
assert_eq!(dk.shape(), &[b, h, s, d]);
assert_eq!(dv.shape(), &[b, h, s, d]);
}
#[test]
fn test_flash_bwd_gqa_gradient_shapes() {
let (client, device) = cpu_setup();
let (b, h, nkv, s, d) = (1, 8, 2, 4, 16);
let q = rand_tensor(&[b, h, s, d], &client, &device);
let k = rand_tensor(&[b, nkv, s, d], &client, &device);
let v = rand_tensor(&[b, nkv, s, d], &client, &device);
let (out, lse) = client
.flash_attention_fwd(&q, &k, &v, h, nkv, d, false, 0, None)
.unwrap();
let dout = rand_tensor(&[b, h, s, d], &client, &device);
let (dq, dk, dv) = client
.flash_attention_bwd(&dout, &q, &k, &v, &out, &lse, h, nkv, d, false, 0)
.unwrap();
assert_eq!(dq.shape(), &[b, h, s, d]);
assert_eq!(dk.shape(), &[b, nkv, s, d]);
assert_eq!(dv.shape(), &[b, nkv, s, d]);
}
#[test]
fn test_var_flash_attention_autograd() {
use crate::ops::autograd_attention::var_flash_attention;
use numr::autograd::Var;
        let (client, device) = cpu_setup();
        let (b, h, s, d) = (1, 2, 4, 8);
        let q_t = rand_tensor(&[b, h, s, d], &client, &device);
        let k_t = rand_tensor(&[b, h, s, d], &client, &device);
        let v_t = rand_tensor(&[b, h, s, d], &client, &device);
let q = Var::new(q_t, true);
let k = Var::new(k_t, true);
let v = Var::new(v_t, true);
let out = var_flash_attention::<CpuRuntime>(&q, &k, &v, h, h, d, false, 0).unwrap();
assert_eq!(out.tensor().shape(), &[b, h, s, d]);
assert!(
out.grad_fn().is_some(),
"Output should have grad_fn when inputs require grad"
);
}
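    // The two tests below are sketches covering the remaining dispatch paths of
    // flash_attention_fwd (kv_seq_len truncation and the fused decode fast
    // path). They only use APIs already referenced in this file; any behavior
    // they assume beyond that is noted inline.
    #[test]
    fn test_flash_fwd_kv_seq_len_truncation() {
        // The kv_seq_len path should match manually narrowing K/V to the valid
        // prefix before calling the kernel (assumes dim 2 is the sequence dim,
        // as in the implementation above).
        let (client, device) = cpu_setup();
        let (b, h, s_q, s_kv, s_valid, d) = (1, 2, 4, 8, 5, 16);
        let q = rand_tensor(&[b, h, s_q, d], &client, &device);
        let k = rand_tensor(&[b, h, s_kv, d], &client, &device);
        let v = rand_tensor(&[b, h, s_kv, d], &client, &device);
        let (out_trunc, _) = client
            .flash_attention_fwd(&q, &k, &v, h, h, d, false, 0, Some(s_valid))
            .unwrap();
        let k_ref = k.narrow(2, 0, s_valid).unwrap().contiguous();
        let v_ref = v.narrow(2, 0, s_valid).unwrap().contiguous();
        let (out_ref, _) = client
            .flash_attention_fwd(&q, &k_ref, &v_ref, h, h, d, false, 0, None)
            .unwrap();
        let actual = out_trunc.to_vec::<f32>();
        let expected = out_ref.to_vec::<f32>();
        let max_diff = actual
            .iter()
            .zip(expected.iter())
            .map(|(a, b)| (a - b).abs())
            .fold(0.0f32, f32::max);
        assert!(
            max_diff < 1e-5,
            "kv_seq_len truncation should match manual narrowing"
        );
    }
    #[test]
    fn test_flash_fwd_decode_path_shape() {
        // A single-token f32 query without causal masking or a sliding window
        // takes the fused decode path; only the output shape is checked here,
        // and the LSE layout of the fused kernel is deliberately not asserted.
        let (client, device) = cpu_setup();
        let (b, h, s_kv, d) = (1, 4, 16, 8);
        let q = rand_tensor(&[b, h, 1, d], &client, &device);
        let k = rand_tensor(&[b, h, s_kv, d], &client, &device);
        let v = rand_tensor(&[b, h, s_kv, d], &client, &device);
        let (out, _lse) = client
            .flash_attention_fwd(&q, &k, &v, h, h, d, false, 0, None)
            .unwrap();
        assert_eq!(out.shape(), &[b, h, 1, d]);
    }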
}