Skip to main content

ferrum_testkit/op_diff/
fused_add_rms_norm.rs

1//! `fused_add_rms_norm` op-diff harness — see `crate::op_diff`.
2//!
3//! Fused `residual += x; out = rms_norm(residual, w)`. The output compared
4//! across backends is `[residual_after, out]` concatenated, so a divergence
5//! in either the in-place residual update or the norm is caught.
6
7use super::{random_vec, OpUnderTest, Output};
8
/// Shape and epsilon configuration for one `fused_add_rms_norm` op-diff case.
///
/// The type is plain data, so it derives `Debug` (for readable failure
/// output), `Clone`/`Copy` (so a case can be reused freely) and `PartialEq`
/// (so cases can be compared in assertions).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FusedAddRmsNormOp {
    /// Number of rows; multiplied by `dim` to size the residual, input and
    /// output buffers.
    pub tokens: usize,
    /// Elements per row; also the length of the generated weight vector.
    pub dim: usize,
    /// Epsilon forwarded unchanged to each backend's `fused_add_rms_norm`.
    pub eps: f32,
}
14
15impl FusedAddRmsNormOp {
16    fn elems(&self) -> usize {
17        self.tokens * self.dim
18    }
19
20    fn build_input(&self, seed: u64) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
21        let residual = random_vec(self.elems(), -2.0, 2.0, seed);
22        let x = random_vec(self.elems(), -2.0, 2.0, seed.wrapping_add(1));
23        let w = random_vec(self.dim, 0.5, 1.5, seed.wrapping_add(2));
24        (residual, x, w)
25    }
26}
27
28impl OpUnderTest for FusedAddRmsNormOp {
29    fn name(&self) -> &str {
30        "fused_add_rms_norm"
31    }
32
33    fn run_cpu(&self, seed: u64) -> Output {
34        use ferrum_kernels::backend::cpu::CpuBackend;
35        use ferrum_kernels::backend::Backend;
36
37        let (residual, x, w) = self.build_input(seed);
38        let mut ctx = CpuBackend::new_context();
39        let mut residual_buf = CpuBackend::from_slice(&residual);
40        let x_buf = CpuBackend::from_slice(&x);
41        let w_buf = CpuBackend::from_slice(&w);
42        let mut out = CpuBackend::alloc(self.elems());
43        CpuBackend::fused_add_rms_norm(
44            &mut ctx,
45            &mut residual_buf,
46            &x_buf,
47            &w_buf,
48            self.eps,
49            &mut out,
50            self.tokens,
51            self.dim,
52        );
53        CpuBackend::sync(&mut ctx);
54        let mut combined = CpuBackend::to_vec(&residual_buf, self.elems());
55        combined.extend(CpuBackend::to_vec(&out, self.elems()));
56        combined
57    }
58
59    #[cfg(all(target_os = "macos", feature = "metal"))]
60    fn run_metal(&self, seed: u64) -> Output {
61        use ferrum_kernels::backend::metal::MetalBackend;
62        use ferrum_kernels::backend::Backend;
63
64        let (residual, x, w) = self.build_input(seed);
65        let mut ctx = MetalBackend::new_context();
66        let mut residual_buf = MetalBackend::from_slice(&residual);
67        let x_buf = MetalBackend::from_slice(&x);
68        let w_buf = MetalBackend::from_slice(&w);
69        let mut out = MetalBackend::alloc(self.elems());
70        MetalBackend::fused_add_rms_norm(
71            &mut ctx,
72            &mut residual_buf,
73            &x_buf,
74            &w_buf,
75            self.eps,
76            &mut out,
77            self.tokens,
78            self.dim,
79        );
80        MetalBackend::sync(&mut ctx);
81        let mut combined = MetalBackend::to_vec(&residual_buf, self.elems());
82        combined.extend(MetalBackend::to_vec(&out, self.elems()));
83        combined
84    }
85
86    #[cfg(feature = "cuda")]
87    fn run_cuda(&self, seed: u64) -> Output {
88        use ferrum_kernels::backend::cuda::CudaBackend;
89        use ferrum_kernels::backend::Backend;
90
91        let (residual, x, w) = self.build_input(seed);
92        let mut ctx = CudaBackend::new_context();
93        let mut residual_buf = CudaBackend::from_slice(&residual);
94        let x_buf = CudaBackend::from_slice(&x);
95        let w_buf = CudaBackend::from_slice(&w);
96        let mut out = CudaBackend::alloc(self.elems());
97        CudaBackend::fused_add_rms_norm(
98            &mut ctx,
99            &mut residual_buf,
100            &x_buf,
101            &w_buf,
102            self.eps,
103            &mut out,
104            self.tokens,
105            self.dim,
106        );
107        CudaBackend::sync(&mut ctx);
108        let mut combined = CudaBackend::to_vec(&residual_buf, self.elems());
109        combined.extend(CudaBackend::to_vec(&out, self.elems()));
110        combined
111    }
112}