Skip to main content

ferrum_testkit/op_diff/
qk_norm_rope.rs

1//! `qk_norm_rope` op-diff harness — covers the fused
2//! rms_norm + rotary-position-embedding + head-major transpose used in
3//! all transformer attention layers (`split_qkv_norm_rope_into_paged_cache_f16`
4//! in nsys traces — top-10 kernel on M3).
5//!
6//! Layout:
7//!   input  `[tokens, heads, head_dim]` — token-major
8//!   norm_w `[head_dim]`
9//!   cos    `[max_pos, head_dim/2]`
10//!   sin    `[max_pos, head_dim/2]`
11//!   output `[heads, tokens, head_dim]` — head-major after the fused
12//!          transpose
13//!
14//! `mode = 1` exercises the actual RoPE pairs path; `mode = 0` is the
15//! transpose-only fallback (which the harness can also test for fast
16//! sanity).
17
18use super::{random_vec, OpUnderTest, Output};
19
/// Shape and parameter set for one `qk_norm_rope` op-diff case.
///
/// The same values are used to size the generated inputs and are forwarded
/// unchanged to each backend's `qk_norm_rope` call.
#[derive(Debug, Clone, PartialEq)]
pub struct QkNormRopeOp {
    /// Size of the `tokens` axis of the `[tokens, heads, head_dim]` input.
    pub tokens: usize,
    /// Size of the `heads` axis of the input.
    pub heads: usize,
    /// Size of the `head_dim` axis; the cos/sin tables hold `head_dim / 2`
    /// entries per position.
    pub head_dim: usize,
    /// Position offset forwarded to the kernel. The cos/sin tables are sized
    /// to `pos_offset + tokens + 16` rows.
    pub pos_offset: usize,
    /// Epsilon forwarded to the kernel (presumably the rms_norm stabiliser —
    /// confirm against the backend implementation).
    pub eps: f32,
    /// Kernel mode flag. Per the module docs: `1` exercises the RoPE pairs
    /// path, `0` is the transpose-only fallback.
    pub mode: i32,
}
28
29impl QkNormRopeOp {
30    fn max_pos(&self) -> usize {
31        self.pos_offset + self.tokens + 16 // small safety margin
32    }
33    fn output_len(&self) -> usize {
34        self.tokens * self.heads * self.head_dim
35    }
36
37    /// Inputs derived from seed:
38    ///   x:     [-2, 2)
39    ///   norm:  [0.5, 1.5)
40    ///   cos/sin: precomputed RoPE rotation tables (theta_i = 10000^{-2i/d})
41    fn build_input(&self, seed: u64) -> (Vec<f32>, Vec<f32>, Vec<f32>, Vec<f32>) {
42        let half = self.head_dim / 2;
43        let x = random_vec(self.tokens * self.heads * self.head_dim, -2.0, 2.0, seed);
44        let norm = random_vec(self.head_dim, 0.5, 1.5, seed.wrapping_add(1));
45
46        let mut cos = Vec::with_capacity(self.max_pos() * half);
47        let mut sin = Vec::with_capacity(self.max_pos() * half);
48        for pos in 0..self.max_pos() {
49            for i in 0..half {
50                let theta = 10000f32.powf(-(i as f32) * 2.0 / self.head_dim as f32);
51                let angle = pos as f32 * theta;
52                cos.push(angle.cos());
53                sin.push(angle.sin());
54            }
55        }
56        (x, norm, cos, sin)
57    }
58}
59
60impl OpUnderTest for QkNormRopeOp {
61    fn name(&self) -> &str {
62        "qk_norm_rope"
63    }
64
65    fn run_cpu(&self, seed: u64) -> Output {
66        use ferrum_kernels::backend::cpu::CpuBackend;
67        use ferrum_kernels::backend::Backend;
68        let (x, w, cos, sin) = self.build_input(seed);
69        let mut ctx = CpuBackend::new_context();
70        let x_buf = CpuBackend::from_slice(&x);
71        let w_buf = CpuBackend::from_slice(&w);
72        let cos_buf = CpuBackend::from_slice(&cos);
73        let sin_buf = CpuBackend::from_slice(&sin);
74        let mut out = CpuBackend::alloc(self.output_len());
75        CpuBackend::qk_norm_rope(
76            &mut ctx,
77            &x_buf,
78            &w_buf,
79            &cos_buf,
80            &sin_buf,
81            &mut out,
82            self.tokens,
83            self.heads,
84            self.head_dim,
85            self.pos_offset,
86            self.eps,
87            self.mode,
88        );
89        CpuBackend::sync(&mut ctx);
90        CpuBackend::to_vec(&out, self.output_len())
91    }
92
93    #[cfg(all(target_os = "macos", feature = "metal"))]
94    fn run_metal(&self, seed: u64) -> Output {
95        use ferrum_kernels::backend::metal::MetalBackend;
96        use ferrum_kernels::backend::Backend;
97        let (x, w, cos, sin) = self.build_input(seed);
98        let mut ctx = MetalBackend::new_context();
99        let x_buf = MetalBackend::from_slice(&x);
100        let w_buf = MetalBackend::from_slice(&w);
101        let cos_buf = MetalBackend::from_slice(&cos);
102        let sin_buf = MetalBackend::from_slice(&sin);
103        let mut out = MetalBackend::alloc(self.output_len());
104        MetalBackend::qk_norm_rope(
105            &mut ctx,
106            &x_buf,
107            &w_buf,
108            &cos_buf,
109            &sin_buf,
110            &mut out,
111            self.tokens,
112            self.heads,
113            self.head_dim,
114            self.pos_offset,
115            self.eps,
116            self.mode,
117        );
118        MetalBackend::sync(&mut ctx);
119        MetalBackend::to_vec(&out, self.output_len())
120    }
121
122    #[cfg(feature = "cuda")]
123    fn run_cuda(&self, seed: u64) -> Output {
124        use ferrum_kernels::backend::cuda::CudaBackend;
125        use ferrum_kernels::backend::Backend;
126        let (x, w, cos, sin) = self.build_input(seed);
127        let mut ctx = CudaBackend::new_context();
128        let x_buf = CudaBackend::from_slice(&x);
129        let w_buf = CudaBackend::from_slice(&w);
130        let cos_buf = CudaBackend::from_slice(&cos);
131        let sin_buf = CudaBackend::from_slice(&sin);
132        let mut out = CudaBackend::alloc(self.output_len());
133        CudaBackend::qk_norm_rope(
134            &mut ctx,
135            &x_buf,
136            &w_buf,
137            &cos_buf,
138            &sin_buf,
139            &mut out,
140            self.tokens,
141            self.heads,
142            self.head_dim,
143            self.pos_offset,
144            self.eps,
145            self.mode,
146        );
147        CudaBackend::sync(&mut ctx);
148        CudaBackend::to_vec(&out, self.output_len())
149    }
150}