Skip to main content

ferrum_testkit/op_diff/
rms_norm.rs

1//! `rms_norm` op-diff harness — see `crate::op_diff` for the framework.
2
3use super::{random_vec, OpUnderTest, Output};
4
/// One concrete rms_norm invocation, described by its shape and epsilon.
///
/// The tensors themselves are not stored here; each run regenerates them
/// from a seed (see `build_input`):
///   - `x`: tokens × dim activation
///   - `w`: dim weight (per-channel scale)
///   - `eps`: the usual RMSNorm epsilon
///
/// Output: tokens × dim values, reported in each backend's compute dtype
/// (typically fp16 on Metal/CUDA, fp32 on CPU).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RmsNormOp {
    /// Number of rows (tokens) in the activation.
    pub tokens: usize,
    /// Length of each row; also the length of the weight vector.
    pub dim: usize,
    /// Epsilon forwarded unchanged to every backend's `rms_norm`.
    pub eps: f32,
}
17
18impl RmsNormOp {
19    fn output_len(&self) -> usize {
20        self.tokens * self.dim
21    }
22
23    /// Inputs are derived from seed so per-backend runs see identical x/w.
24    fn build_input(&self, seed: u64) -> (Vec<f32>, Vec<f32>) {
25        let x = random_vec(self.tokens * self.dim, -2.0, 2.0, seed);
26        let w = random_vec(self.dim, 0.5, 1.5, seed.wrapping_add(1));
27        (x, w)
28    }
29}
30
31impl OpUnderTest for RmsNormOp {
32    fn name(&self) -> &str {
33        "rms_norm"
34    }
35
36    fn run_cpu(&self, seed: u64) -> Output {
37        use ferrum_kernels::backend::cpu::CpuBackend;
38        use ferrum_kernels::backend::Backend;
39
40        let (x, w) = self.build_input(seed);
41        let mut ctx = CpuBackend::new_context();
42        let x_buf = CpuBackend::from_slice(&x);
43        let w_buf = CpuBackend::from_slice(&w);
44        let mut out = CpuBackend::alloc(self.output_len());
45        CpuBackend::rms_norm(
46            &mut ctx,
47            &x_buf,
48            &w_buf,
49            self.eps,
50            &mut out,
51            self.tokens,
52            self.dim,
53        );
54        CpuBackend::sync(&mut ctx);
55        CpuBackend::to_vec(&out, self.output_len())
56    }
57
58    #[cfg(all(target_os = "macos", feature = "metal"))]
59    fn run_metal(&self, seed: u64) -> Output {
60        use ferrum_kernels::backend::metal::MetalBackend;
61        use ferrum_kernels::backend::Backend;
62
63        let (x, w) = self.build_input(seed);
64        let mut ctx = MetalBackend::new_context();
65        let x_buf = MetalBackend::from_slice(&x);
66        let w_buf = MetalBackend::from_slice(&w);
67        let mut out = MetalBackend::alloc(self.output_len());
68        MetalBackend::rms_norm(
69            &mut ctx,
70            &x_buf,
71            &w_buf,
72            self.eps,
73            &mut out,
74            self.tokens,
75            self.dim,
76        );
77        MetalBackend::sync(&mut ctx);
78        MetalBackend::to_vec(&out, self.output_len())
79    }
80
81    #[cfg(feature = "cuda")]
82    fn run_cuda(&self, seed: u64) -> Output {
83        use ferrum_kernels::backend::cuda::CudaBackend;
84        use ferrum_kernels::backend::Backend;
85
86        let (x, w) = self.build_input(seed);
87        let mut ctx = CudaBackend::new_context();
88        let x_buf = CudaBackend::from_slice(&x);
89        let w_buf = CudaBackend::from_slice(&w);
90        let mut out = CudaBackend::alloc(self.output_len());
91        CudaBackend::rms_norm(
92            &mut ctx,
93            &x_buf,
94            &w_buf,
95            self.eps,
96            &mut out,
97            self.tokens,
98            self.dim,
99        );
100        CudaBackend::sync(&mut ctx);
101        CudaBackend::to_vec(&out, self.output_len())
102    }
103}