Skip to main content

ferrum_testkit/op_diff/
flash_attention.rs

//! `flash_attention` op-diff harness — see `crate::op_diff`.
//!
//! Dense causal attention. `CpuBackend` implements the reference; Metal/CUDA
//! run their flash kernels against the same Q/K/V.
5
6use super::{random_vec, OpUnderTest, Output};
7use ferrum_kernels::backend::AttnConfig;
8
/// Problem shape for one dense causal attention op-diff case.
///
/// Q and the attention output each hold `batch * q_len * num_heads * head_dim`
/// elements; K and V each hold `batch * kv_len * num_kv_heads * head_dim`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FlashAttentionOp {
    /// Batch size.
    pub batch: usize,
    /// Query sequence length.
    pub q_len: usize,
    /// Key/value sequence length.
    pub kv_len: usize,
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads. NOTE(review): presumably `num_heads` must be a
    /// multiple of this (grouped-query attention) — confirm against the kernels.
    pub num_kv_heads: usize,
    /// Per-head feature dimension; also determines the softmax scale
    /// `1 / sqrt(head_dim)` used for this op.
    pub head_dim: usize,
}
17
18impl FlashAttentionOp {
19    fn q_elems(&self) -> usize {
20        self.batch * self.q_len * self.num_heads * self.head_dim
21    }
22    fn kv_elems(&self) -> usize {
23        self.batch * self.kv_len * self.num_kv_heads * self.head_dim
24    }
25    fn cfg(&self) -> AttnConfig {
26        AttnConfig {
27            num_heads: self.num_heads,
28            num_kv_heads: self.num_kv_heads,
29            head_dim: self.head_dim,
30            causal: true,
31            scale: 1.0 / (self.head_dim as f32).sqrt(),
32            kv_seq_stride: 0,
33            sliding_window: 0,
34        }
35    }
36}
37
/// Runs `flash_attention` on backend `$B` for the shape in `$self` and returns
/// the output copied back into a host vector via `to_vec`.
///
/// Every backend expands this same macro, so for a given seed they all receive
/// identical inputs. Q is generated from `$seed`, K from `$seed + 1` and V from
/// `$seed + 2` (wrapping), each via `random_vec(len, -1.0, 1.0, seed)`.
///
/// `$self` and `$seed` are each evaluated exactly once, so callers may pass
/// arbitrary expressions without duplicated side effects or repeated work.
macro_rules! run_backend {
    ($B:ty, $self:expr, $seed:expr) => {{
        use ferrum_kernels::backend::Backend;

        // Bind macro arguments once. Expanding them at every use site would
        // re-evaluate any non-trivial caller expression.
        let op = $self;
        let seed = $seed;
        let q_elems = op.q_elems();
        let kv_elems = op.kv_elems();
        let cfg = op.cfg();

        // Distinct, deterministic seeds per tensor, so Q, K and V differ from
        // each other but match across backends.
        let q = random_vec(q_elems, -1.0, 1.0, seed);
        let k = random_vec(kv_elems, -1.0, 1.0, seed.wrapping_add(1));
        let v = random_vec(kv_elems, -1.0, 1.0, seed.wrapping_add(2));

        let mut ctx = <$B>::new_context();
        let qb = <$B>::from_slice(&q);
        let kb = <$B>::from_slice(&k);
        let vb = <$B>::from_slice(&v);
        let mut out = <$B>::alloc(q_elems);
        <$B>::flash_attention(
            &mut ctx,
            &qb,
            &kb,
            &vb,
            &mut out,
            op.batch,
            op.q_len,
            op.kv_len,
            // NOTE(review): meaning of this positional argument is defined by
            // `Backend::flash_attention` (presumably a position offset) — confirm.
            0,
            &cfg,
        );
        // Make sure queued device work has finished before reading back.
        <$B>::sync(&mut ctx);
        <$B>::to_vec(&out, q_elems)
    }};
}
65
66impl OpUnderTest for FlashAttentionOp {
67    fn name(&self) -> &str {
68        "flash_attention"
69    }
70
71    fn run_cpu(&self, seed: u64) -> Output {
72        run_backend!(ferrum_kernels::backend::cpu::CpuBackend, self, seed)
73    }
74
75    #[cfg(all(target_os = "macos", feature = "metal"))]
76    fn run_metal(&self, seed: u64) -> Output {
77        run_backend!(ferrum_kernels::backend::metal::MetalBackend, self, seed)
78    }
79
80    #[cfg(feature = "cuda")]
81    fn run_cuda(&self, seed: u64) -> Output {
82        run_backend!(ferrum_kernels::backend::cuda::CudaBackend, self, seed)
83    }
84}