// ferrum_testkit/op_diff/rms_norm.rs
use super::{random_vec, OpUnderTest, Output};
4
/// Parameters describing one `rms_norm` test case.
///
/// The inherent helpers on this type size the input and output buffers as
/// `tokens * dim` elements and the weight vector as `dim` elements; all three
/// fields are forwarded to the backend `rms_norm` call.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RmsNormOp {
    /// Token count handed to the backend kernel. Together with `dim` it
    /// fixes the input/output buffer length (`tokens * dim`).
    pub tokens: usize,
    /// Width handed to the backend kernel; also the length of the weight
    /// vector (presumably the per-token hidden size — confirm against the
    /// kernel's contract).
    pub dim: usize,
    /// Epsilon forwarded unchanged to the backend `rms_norm` call
    /// (presumably the numerical-stability term — confirm against the kernel).
    pub eps: f32,
}
17
18impl RmsNormOp {
19 fn output_len(&self) -> usize {
20 self.tokens * self.dim
21 }
22
23 fn build_input(&self, seed: u64) -> (Vec<f32>, Vec<f32>) {
25 let x = random_vec(self.tokens * self.dim, -2.0, 2.0, seed);
26 let w = random_vec(self.dim, 0.5, 1.5, seed.wrapping_add(1));
27 (x, w)
28 }
29}
30
31impl OpUnderTest for RmsNormOp {
32 fn name(&self) -> &str {
33 "rms_norm"
34 }
35
36 fn run_cpu(&self, seed: u64) -> Output {
37 use ferrum_kernels::backend::cpu::CpuBackend;
38 use ferrum_kernels::backend::Backend;
39
40 let (x, w) = self.build_input(seed);
41 let mut ctx = CpuBackend::new_context();
42 let x_buf = CpuBackend::from_slice(&x);
43 let w_buf = CpuBackend::from_slice(&w);
44 let mut out = CpuBackend::alloc(self.output_len());
45 CpuBackend::rms_norm(
46 &mut ctx,
47 &x_buf,
48 &w_buf,
49 self.eps,
50 &mut out,
51 self.tokens,
52 self.dim,
53 );
54 CpuBackend::sync(&mut ctx);
55 CpuBackend::to_vec(&out, self.output_len())
56 }
57
58 #[cfg(all(target_os = "macos", feature = "metal"))]
59 fn run_metal(&self, seed: u64) -> Output {
60 use ferrum_kernels::backend::metal::MetalBackend;
61 use ferrum_kernels::backend::Backend;
62
63 let (x, w) = self.build_input(seed);
64 let mut ctx = MetalBackend::new_context();
65 let x_buf = MetalBackend::from_slice(&x);
66 let w_buf = MetalBackend::from_slice(&w);
67 let mut out = MetalBackend::alloc(self.output_len());
68 MetalBackend::rms_norm(
69 &mut ctx,
70 &x_buf,
71 &w_buf,
72 self.eps,
73 &mut out,
74 self.tokens,
75 self.dim,
76 );
77 MetalBackend::sync(&mut ctx);
78 MetalBackend::to_vec(&out, self.output_len())
79 }
80
81 #[cfg(feature = "cuda")]
82 fn run_cuda(&self, seed: u64) -> Output {
83 use ferrum_kernels::backend::cuda::CudaBackend;
84 use ferrum_kernels::backend::Backend;
85
86 let (x, w) = self.build_input(seed);
87 let mut ctx = CudaBackend::new_context();
88 let x_buf = CudaBackend::from_slice(&x);
89 let w_buf = CudaBackend::from_slice(&w);
90 let mut out = CudaBackend::alloc(self.output_len());
91 CudaBackend::rms_norm(
92 &mut ctx,
93 &x_buf,
94 &w_buf,
95 self.eps,
96 &mut out,
97 self.tokens,
98 self.dim,
99 );
100 CudaBackend::sync(&mut ctx);
101 CudaBackend::to_vec(&out, self.output_len())
102 }
103}