ferrum_testkit/op_diff/
fused_add_rms_norm.rs1use super::{random_vec, OpUnderTest, Output};
8
/// Configuration for one `fused_add_rms_norm` comparison run.
///
/// The residual and `x` inputs each hold `tokens * dim` values and the weight
/// vector holds `dim` values; all three are generated from the seed handed to
/// the individual `run_*` methods.
#[derive(Debug, Clone, PartialEq)]
pub struct FusedAddRmsNormOp {
    /// Row count forwarded to the backend kernel as its `tokens` argument.
    pub tokens: usize,
    /// Row width forwarded to the backend kernel; also the weight vector length.
    pub dim: usize,
    /// Epsilon forwarded unchanged to the backend kernel
    /// (presumably the RMS-norm stabiliser; confirm against the kernel docs).
    pub eps: f32,
}
14
15impl FusedAddRmsNormOp {
16 fn elems(&self) -> usize {
17 self.tokens * self.dim
18 }
19
20 fn build_input(&self, seed: u64) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
21 let residual = random_vec(self.elems(), -2.0, 2.0, seed);
22 let x = random_vec(self.elems(), -2.0, 2.0, seed.wrapping_add(1));
23 let w = random_vec(self.dim, 0.5, 1.5, seed.wrapping_add(2));
24 (residual, x, w)
25 }
26}
27
28impl OpUnderTest for FusedAddRmsNormOp {
29 fn name(&self) -> &str {
30 "fused_add_rms_norm"
31 }
32
33 fn run_cpu(&self, seed: u64) -> Output {
34 use ferrum_kernels::backend::cpu::CpuBackend;
35 use ferrum_kernels::backend::Backend;
36
37 let (residual, x, w) = self.build_input(seed);
38 let mut ctx = CpuBackend::new_context();
39 let mut residual_buf = CpuBackend::from_slice(&residual);
40 let x_buf = CpuBackend::from_slice(&x);
41 let w_buf = CpuBackend::from_slice(&w);
42 let mut out = CpuBackend::alloc(self.elems());
43 CpuBackend::fused_add_rms_norm(
44 &mut ctx,
45 &mut residual_buf,
46 &x_buf,
47 &w_buf,
48 self.eps,
49 &mut out,
50 self.tokens,
51 self.dim,
52 );
53 CpuBackend::sync(&mut ctx);
54 let mut combined = CpuBackend::to_vec(&residual_buf, self.elems());
55 combined.extend(CpuBackend::to_vec(&out, self.elems()));
56 combined
57 }
58
59 #[cfg(all(target_os = "macos", feature = "metal"))]
60 fn run_metal(&self, seed: u64) -> Output {
61 use ferrum_kernels::backend::metal::MetalBackend;
62 use ferrum_kernels::backend::Backend;
63
64 let (residual, x, w) = self.build_input(seed);
65 let mut ctx = MetalBackend::new_context();
66 let mut residual_buf = MetalBackend::from_slice(&residual);
67 let x_buf = MetalBackend::from_slice(&x);
68 let w_buf = MetalBackend::from_slice(&w);
69 let mut out = MetalBackend::alloc(self.elems());
70 MetalBackend::fused_add_rms_norm(
71 &mut ctx,
72 &mut residual_buf,
73 &x_buf,
74 &w_buf,
75 self.eps,
76 &mut out,
77 self.tokens,
78 self.dim,
79 );
80 MetalBackend::sync(&mut ctx);
81 let mut combined = MetalBackend::to_vec(&residual_buf, self.elems());
82 combined.extend(MetalBackend::to_vec(&out, self.elems()));
83 combined
84 }
85
86 #[cfg(feature = "cuda")]
87 fn run_cuda(&self, seed: u64) -> Output {
88 use ferrum_kernels::backend::cuda::CudaBackend;
89 use ferrum_kernels::backend::Backend;
90
91 let (residual, x, w) = self.build_input(seed);
92 let mut ctx = CudaBackend::new_context();
93 let mut residual_buf = CudaBackend::from_slice(&residual);
94 let x_buf = CudaBackend::from_slice(&x);
95 let w_buf = CudaBackend::from_slice(&w);
96 let mut out = CudaBackend::alloc(self.elems());
97 CudaBackend::fused_add_rms_norm(
98 &mut ctx,
99 &mut residual_buf,
100 &x_buf,
101 &w_buf,
102 self.eps,
103 &mut out,
104 self.tokens,
105 self.dim,
106 );
107 CudaBackend::sync(&mut ctx);
108 let mut combined = CudaBackend::to_vec(&residual_buf, self.elems());
109 combined.extend(CudaBackend::to_vec(&out, self.elems()));
110 combined
111 }
112}