Skip to main content

ferrum_testkit/op_diff/
split_qkv.rs

//! `split_qkv` op-diff harness — see `crate::op_diff`.
//!
//! Splits a fused `[tokens, q_dim + 2*kv_dim]` projection into separate q/k/v
//! buffers. Pure slicing (exact); compared output is `[q, k, v]` concatenated.
5
6use super::{random_vec, OpUnderTest, Output};
7
/// Shape parameters for one `split_qkv` op-diff case.
///
/// The fused input is treated as `[tokens, q_dim + 2 * kv_dim]`; the op
/// writes one q buffer of `tokens * q_dim` elements and k/v buffers of
/// `tokens * kv_dim` elements each.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SplitQkvOp {
    /// Number of rows in the fused projection.
    pub tokens: usize,
    /// Per-token width of the q slice.
    pub q_dim: usize,
    /// Per-token width of each of the k and v slices.
    pub kv_dim: usize,
}
13
14impl SplitQkvOp {
15    fn fused(&self) -> usize {
16        self.tokens * (self.q_dim + 2 * self.kv_dim)
17    }
18}
19
20impl OpUnderTest for SplitQkvOp {
21    fn name(&self) -> &str {
22        "split_qkv"
23    }
24
25    fn run_cpu(&self, seed: u64) -> Output {
26        use ferrum_kernels::backend::cpu::CpuBackend;
27        use ferrum_kernels::backend::Backend;
28
29        let qkv = random_vec(self.fused(), -2.0, 2.0, seed);
30        let mut ctx = CpuBackend::new_context();
31        let qkv_buf = CpuBackend::from_slice(&qkv);
32        let mut q = CpuBackend::alloc(self.tokens * self.q_dim);
33        let mut k = CpuBackend::alloc(self.tokens * self.kv_dim);
34        let mut v = CpuBackend::alloc(self.tokens * self.kv_dim);
35        CpuBackend::split_qkv(
36            &mut ctx,
37            &qkv_buf,
38            &mut q,
39            &mut k,
40            &mut v,
41            self.tokens,
42            self.q_dim,
43            self.kv_dim,
44        );
45        CpuBackend::sync(&mut ctx);
46        let mut out = CpuBackend::to_vec(&q, self.tokens * self.q_dim);
47        out.extend(CpuBackend::to_vec(&k, self.tokens * self.kv_dim));
48        out.extend(CpuBackend::to_vec(&v, self.tokens * self.kv_dim));
49        out
50    }
51
52    #[cfg(all(target_os = "macos", feature = "metal"))]
53    fn run_metal(&self, seed: u64) -> Output {
54        use ferrum_kernels::backend::metal::MetalBackend;
55        use ferrum_kernels::backend::Backend;
56
57        let qkv = random_vec(self.fused(), -2.0, 2.0, seed);
58        let mut ctx = MetalBackend::new_context();
59        let qkv_buf = MetalBackend::from_slice(&qkv);
60        let mut q = MetalBackend::alloc(self.tokens * self.q_dim);
61        let mut k = MetalBackend::alloc(self.tokens * self.kv_dim);
62        let mut v = MetalBackend::alloc(self.tokens * self.kv_dim);
63        MetalBackend::split_qkv(
64            &mut ctx,
65            &qkv_buf,
66            &mut q,
67            &mut k,
68            &mut v,
69            self.tokens,
70            self.q_dim,
71            self.kv_dim,
72        );
73        MetalBackend::sync(&mut ctx);
74        let mut out = MetalBackend::to_vec(&q, self.tokens * self.q_dim);
75        out.extend(MetalBackend::to_vec(&k, self.tokens * self.kv_dim));
76        out.extend(MetalBackend::to_vec(&v, self.tokens * self.kv_dim));
77        out
78    }
79
80    #[cfg(feature = "cuda")]
81    fn run_cuda(&self, seed: u64) -> Output {
82        use ferrum_kernels::backend::cuda::CudaBackend;
83        use ferrum_kernels::backend::Backend;
84
85        let qkv = random_vec(self.fused(), -2.0, 2.0, seed);
86        let mut ctx = CudaBackend::new_context();
87        let qkv_buf = CudaBackend::from_slice(&qkv);
88        let mut q = CudaBackend::alloc(self.tokens * self.q_dim);
89        let mut k = CudaBackend::alloc(self.tokens * self.kv_dim);
90        let mut v = CudaBackend::alloc(self.tokens * self.kv_dim);
91        CudaBackend::split_qkv(
92            &mut ctx,
93            &qkv_buf,
94            &mut q,
95            &mut k,
96            &mut v,
97            self.tokens,
98            self.q_dim,
99            self.kv_dim,
100        );
101        CudaBackend::sync(&mut ctx);
102        let mut out = CudaBackend::to_vec(&q, self.tokens * self.q_dim);
103        out.extend(CudaBackend::to_vec(&k, self.tokens * self.kv_dim));
104        out.extend(CudaBackend::to_vec(&v, self.tokens * self.kv_dim));
105        out
106    }
107}