ferrum_testkit/op_diff/
qk_norm_rope.rs1use super::{random_vec, OpUnderTest, Output};
19
/// Differential-test configuration for the fused `qk_norm_rope` kernel.
///
/// The same configuration and seed are run on every available backend and the
/// outputs are compared.
#[derive(Debug, Clone, PartialEq)]
pub struct QkNormRopeOp {
    /// Number of tokens in the input activation.
    pub tokens: usize,
    /// Number of heads per token.
    pub heads: usize,
    /// Elements per head. The cos/sin tables hold `head_dim / 2` entries per position.
    pub head_dim: usize,
    /// Position of the first token. It is forwarded to the kernel, and the cos/sin
    /// tables are sized to cover `pos_offset + tokens` (plus slack).
    pub pos_offset: usize,
    /// Epsilon forwarded unchanged to the kernel.
    /// NOTE(review): presumably the norm's variance epsilon; confirm against the kernel.
    pub eps: f32,
    /// RoPE mode forwarded unchanged to the kernel. Its meaning is defined by the
    /// kernel implementation, not here.
    pub mode: i32,
}
28
29impl QkNormRopeOp {
30 fn max_pos(&self) -> usize {
31 self.pos_offset + self.tokens + 16 }
33 fn output_len(&self) -> usize {
34 self.tokens * self.heads * self.head_dim
35 }
36
37 fn build_input(&self, seed: u64) -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
42 let half = self.head_dim / 2;
43 let x = random_vec(self.tokens * self.heads * self.head_dim, -2.0, 2.0, seed);
44 let norm = random_vec(self.head_dim, 0.5, 1.5, seed.wrapping_add(1));
45
46 let mut cos = Vec::with_capacity(self.max_pos() * half);
47 let mut sin = Vec::with_capacity(self.max_pos() * half);
48 for pos in 0..self.max_pos() {
49 for i in 0..half {
50 let theta = 10000f32.powf(-(i as f32) * 2.0 / self.head_dim as f32);
51 let angle = pos as f32 * theta;
52 cos.push(angle.cos());
53 sin.push(angle.sin());
54 }
55 }
56 (x, norm, cos, sin)
57 }
58}
59
60impl OpUnderTest for QkNormRopeOp {
61 fn name(&self) -> &str {
62 "qk_norm_rope"
63 }
64
65 fn run_cpu(&self, seed: u64) -> Output {
66 use ferrum_kernels::backend::cpu::CpuBackend;
67 use ferrum_kernels::backend::Backend;
68 let (x, w, cos, sin) = self.build_input(seed);
69 let mut ctx = CpuBackend::new_context();
70 let x_buf = CpuBackend::from_slice(&x);
71 let w_buf = CpuBackend::from_slice(&w);
72 let cos_buf = CpuBackend::from_slice(&cos);
73 let sin_buf = CpuBackend::from_slice(&sin);
74 let mut out = CpuBackend::alloc(self.output_len());
75 CpuBackend::qk_norm_rope(
76 &mut ctx,
77 &x_buf,
78 &w_buf,
79 &cos_buf,
80 &sin_buf,
81 &mut out,
82 self.tokens,
83 self.heads,
84 self.head_dim,
85 self.pos_offset,
86 self.eps,
87 self.mode,
88 );
89 CpuBackend::sync(&mut ctx);
90 CpuBackend::to_vec(&out, self.output_len())
91 }
92
93 #[cfg(all(target_os = "macos", feature = "metal"))]
94 fn run_metal(&self, seed: u64) -> Output {
95 use ferrum_kernels::backend::metal::MetalBackend;
96 use ferrum_kernels::backend::Backend;
97 let (x, w, cos, sin) = self.build_input(seed);
98 let mut ctx = MetalBackend::new_context();
99 let x_buf = MetalBackend::from_slice(&x);
100 let w_buf = MetalBackend::from_slice(&w);
101 let cos_buf = MetalBackend::from_slice(&cos);
102 let sin_buf = MetalBackend::from_slice(&sin);
103 let mut out = MetalBackend::alloc(self.output_len());
104 MetalBackend::qk_norm_rope(
105 &mut ctx,
106 &x_buf,
107 &w_buf,
108 &cos_buf,
109 &sin_buf,
110 &mut out,
111 self.tokens,
112 self.heads,
113 self.head_dim,
114 self.pos_offset,
115 self.eps,
116 self.mode,
117 );
118 MetalBackend::sync(&mut ctx);
119 MetalBackend::to_vec(&out, self.output_len())
120 }
121
122 #[cfg(feature = "cuda")]
123 fn run_cuda(&self, seed: u64) -> Output {
124 use ferrum_kernels::backend::cuda::CudaBackend;
125 use ferrum_kernels::backend::Backend;
126 let (x, w, cos, sin) = self.build_input(seed);
127 let mut ctx = CudaBackend::new_context();
128 let x_buf = CudaBackend::from_slice(&x);
129 let w_buf = CudaBackend::from_slice(&w);
130 let cos_buf = CudaBackend::from_slice(&cos);
131 let sin_buf = CudaBackend::from_slice(&sin);
132 let mut out = CudaBackend::alloc(self.output_len());
133 CudaBackend::qk_norm_rope(
134 &mut ctx,
135 &x_buf,
136 &w_buf,
137 &cos_buf,
138 &sin_buf,
139 &mut out,
140 self.tokens,
141 self.heads,
142 self.head_dim,
143 self.pos_offset,
144 self.eps,
145 self.mode,
146 );
147 CudaBackend::sync(&mut ctx);
148 CudaBackend::to_vec(&out, self.output_len())
149 }
150}