ferrum_testkit/op_diff/
split_qkv.rs1use super::{random_vec, OpUnderTest, Output};
7
/// Differential-test case for the `split_qkv` kernel: a fused QKV buffer is
/// split into separate Q, K and V outputs.
///
/// Sizes used by the test harness:
/// - fused input: `tokens * (q_dim + 2 * kv_dim)` elements,
/// - Q output: `tokens * q_dim` elements,
/// - K and V outputs: `tokens * kv_dim` elements each.
///
/// NOTE(review): the element ordering inside the fused buffer is defined by
/// each backend's `split_qkv` kernel, not by this type.
#[derive(Debug, Clone, Copy)]
pub struct SplitQkvOp {
    /// Number of tokens (rows) in the fused input.
    pub tokens: usize,
    /// Per-token width of the Q projection.
    pub q_dim: usize,
    /// Per-token width of each of the K and V projections.
    pub kv_dim: usize,
}

impl SplitQkvOp {
    /// Number of elements in the fused QKV input: `tokens * (q_dim + 2 * kv_dim)`.
    ///
    /// Because `tokens * q_dim` and `tokens * kv_dim` are both no larger than
    /// this value, a successful call also guarantees those per-output sizes fit
    /// in `usize`.
    ///
    /// # Panics
    ///
    /// Panics if the element count overflows `usize`. Plain arithmetic would
    /// wrap silently in release builds and hand every backend a wrongly sized
    /// buffer.
    fn fused(&self) -> usize {
        self.kv_dim
            .checked_mul(2)
            .and_then(|kv_pair| kv_pair.checked_add(self.q_dim))
            .and_then(|row_width| row_width.checked_mul(self.tokens))
            .expect("split_qkv: fused QKV element count overflows usize")
    }
}
19
20impl OpUnderTest for SplitQkvOp {
21 fn name(&self) -> &str {
22 "split_qkv"
23 }
24
25 fn run_cpu(&self, seed: u64) -> Output {
26 use ferrum_kernels::backend::cpu::CpuBackend;
27 use ferrum_kernels::backend::Backend;
28
29 let qkv = random_vec(self.fused(), -2.0, 2.0, seed);
30 let mut ctx = CpuBackend::new_context();
31 let qkv_buf = CpuBackend::from_slice(&qkv);
32 let mut q = CpuBackend::alloc(self.tokens * self.q_dim);
33 let mut k = CpuBackend::alloc(self.tokens * self.kv_dim);
34 let mut v = CpuBackend::alloc(self.tokens * self.kv_dim);
35 CpuBackend::split_qkv(
36 &mut ctx,
37 &qkv_buf,
38 &mut q,
39 &mut k,
40 &mut v,
41 self.tokens,
42 self.q_dim,
43 self.kv_dim,
44 );
45 CpuBackend::sync(&mut ctx);
46 let mut out = CpuBackend::to_vec(&q, self.tokens * self.q_dim);
47 out.extend(CpuBackend::to_vec(&k, self.tokens * self.kv_dim));
48 out.extend(CpuBackend::to_vec(&v, self.tokens * self.kv_dim));
49 out
50 }
51
52 #[cfg(all(target_os = "macos", feature = "metal"))]
53 fn run_metal(&self, seed: u64) -> Output {
54 use ferrum_kernels::backend::metal::MetalBackend;
55 use ferrum_kernels::backend::Backend;
56
57 let qkv = random_vec(self.fused(), -2.0, 2.0, seed);
58 let mut ctx = MetalBackend::new_context();
59 let qkv_buf = MetalBackend::from_slice(&qkv);
60 let mut q = MetalBackend::alloc(self.tokens * self.q_dim);
61 let mut k = MetalBackend::alloc(self.tokens * self.kv_dim);
62 let mut v = MetalBackend::alloc(self.tokens * self.kv_dim);
63 MetalBackend::split_qkv(
64 &mut ctx,
65 &qkv_buf,
66 &mut q,
67 &mut k,
68 &mut v,
69 self.tokens,
70 self.q_dim,
71 self.kv_dim,
72 );
73 MetalBackend::sync(&mut ctx);
74 let mut out = MetalBackend::to_vec(&q, self.tokens * self.q_dim);
75 out.extend(MetalBackend::to_vec(&k, self.tokens * self.kv_dim));
76 out.extend(MetalBackend::to_vec(&v, self.tokens * self.kv_dim));
77 out
78 }
79
80 #[cfg(feature = "cuda")]
81 fn run_cuda(&self, seed: u64) -> Output {
82 use ferrum_kernels::backend::cuda::CudaBackend;
83 use ferrum_kernels::backend::Backend;
84
85 let qkv = random_vec(self.fused(), -2.0, 2.0, seed);
86 let mut ctx = CudaBackend::new_context();
87 let qkv_buf = CudaBackend::from_slice(&qkv);
88 let mut q = CudaBackend::alloc(self.tokens * self.q_dim);
89 let mut k = CudaBackend::alloc(self.tokens * self.kv_dim);
90 let mut v = CudaBackend::alloc(self.tokens * self.kv_dim);
91 CudaBackend::split_qkv(
92 &mut ctx,
93 &qkv_buf,
94 &mut q,
95 &mut k,
96 &mut v,
97 self.tokens,
98 self.q_dim,
99 self.kv_dim,
100 );
101 CudaBackend::sync(&mut ctx);
102 let mut out = CudaBackend::to_vec(&q, self.tokens * self.q_dim);
103 out.extend(CudaBackend::to_vec(&k, self.tokens * self.kv_dim));
104 out.extend(CudaBackend::to_vec(&v, self.tokens * self.kv_dim));
105 out
106 }
107}