ferrum_testkit/op_diff/
flash_attention.rs1use super::{random_vec, OpUnderTest, Output};
7use ferrum_kernels::backend::AttnConfig;
8
/// Differential-test case describing the problem shape for the backend
/// `flash_attention` kernel.
///
/// Only sizes are stored here; the Q/K/V contents are generated from a seed
/// when the op is run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FlashAttentionOp {
    /// Batch size.
    pub batch: usize,
    /// Number of query positions per batch entry.
    pub q_len: usize,
    /// Number of key/value positions per batch entry.
    pub kv_len: usize,
    /// Number of query heads.
    pub num_heads: usize,
    /// Number of key/value heads. Presumably `<= num_heads` and dividing it
    /// (grouped-query attention); not validated here.
    pub num_kv_heads: usize,
    /// Per-head dimension; also determines the softmax scale `1/sqrt(head_dim)`.
    pub head_dim: usize,
}
17
18impl FlashAttentionOp {
19 fn q_elems(&self) -> usize {
20 self.batch * self.q_len * self.num_heads * self.head_dim
21 }
22 fn kv_elems(&self) -> usize {
23 self.batch * self.kv_len * self.num_kv_heads * self.head_dim
24 }
25 fn cfg(&self) -> AttnConfig {
26 AttnConfig {
27 num_heads: self.num_heads,
28 num_kv_heads: self.num_kv_heads,
29 head_dim: self.head_dim,
30 causal: true,
31 scale: 1.0 / (self.head_dim as f32).sqrt(),
32 kv_seq_stride: 0,
33 sliding_window: 0,
34 }
35 }
36}
37
/// Runs a `FlashAttentionOp` once on backend `$B` and returns the output read
/// back from the backend buffer.
///
/// * `$B`    — backend type implementing `ferrum_kernels::backend::Backend`.
/// * `$self` — the `FlashAttentionOp` (or a reference to it) describing the shape.
/// * `$seed` — base seed; Q, K and V use `seed`, `seed + 1` and `seed + 2`
///   (wrapping) so the three inputs differ but stay reproducible.
///
/// Each macro argument is evaluated exactly once.
macro_rules! run_backend {
    ($B:ty, $self:expr, $seed:expr) => {{
        use ferrum_kernels::backend::Backend;

        // Bind arguments once: expanding `$self` / `$seed` at every use would
        // re-evaluate them, which is wrong for expressions with side effects.
        let op = &$self;
        let seed = $seed;
        let q_elems = op.q_elems();
        let kv_elems = op.kv_elems();

        let q = random_vec(q_elems, -1.0, 1.0, seed);
        let k = random_vec(kv_elems, -1.0, 1.0, seed.wrapping_add(1));
        let v = random_vec(kv_elems, -1.0, 1.0, seed.wrapping_add(2));

        let mut ctx = <$B>::new_context();
        let qb = <$B>::from_slice(&q);
        let kb = <$B>::from_slice(&k);
        let vb = <$B>::from_slice(&v);
        // Output has the same element count as Q.
        let mut out = <$B>::alloc(q_elems);
        <$B>::flash_attention(
            &mut ctx,
            &qb,
            &kb,
            &vb,
            &mut out,
            op.batch,
            op.q_len,
            op.kv_len,
            // NOTE(review): the meaning of this literal `0` is not visible in
            // this file (presumably a position/offset argument) — confirm
            // against `Backend::flash_attention`'s signature.
            0,
            &op.cfg(),
        );
        // Wait for the backend to finish before reading results back.
        <$B>::sync(&mut ctx);
        <$B>::to_vec(&out, q_elems)
    }};
}
65
66impl OpUnderTest for FlashAttentionOp {
67 fn name(&self) -> &str {
68 "flash_attention"
69 }
70
71 fn run_cpu(&self, seed: u64) -> Output {
72 run_backend!(ferrum_kernels::backend::cpu::CpuBackend, self, seed)
73 }
74
75 #[cfg(all(target_os = "macos", feature = "metal"))]
76 fn run_metal(&self, seed: u64) -> Output {
77 run_backend!(ferrum_kernels::backend::metal::MetalBackend, self, seed)
78 }
79
80 #[cfg(feature = "cuda")]
81 fn run_cuda(&self, seed: u64) -> Output {
82 run_backend!(ferrum_kernels::backend::cuda::CudaBackend, self, seed)
83 }
84}