ferrum_testkit/op_diff/
residual_add.rs1use super::{random_vec, OpUnderTest, Output};
6
/// Test-kit descriptor for the `residual_add` op.
///
/// The only configuration is the number of elements processed per run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResidualAddOp {
    /// Number of elements in each generated input vector and in the read-back output.
    pub len: usize,
}
10
11impl ResidualAddOp {
12 fn build_input(&self, seed: u64) -> (Vec<f32>, Vec<f32>) {
13 let residual = random_vec(self.len, -2.0, 2.0, seed);
14 let x = random_vec(self.len, -2.0, 2.0, seed.wrapping_add(1));
15 (residual, x)
16 }
17}
18
19impl OpUnderTest for ResidualAddOp {
20 fn name(&self) -> &str {
21 "residual_add"
22 }
23
24 fn run_cpu(&self, seed: u64) -> Output {
25 use ferrum_kernels::backend::cpu::CpuBackend;
26 use ferrum_kernels::backend::Backend;
27
28 let (residual, x) = self.build_input(seed);
29 let mut ctx = CpuBackend::new_context();
30 let mut residual_buf = CpuBackend::from_slice(&residual);
31 let x_buf = CpuBackend::from_slice(&x);
32 CpuBackend::add_inplace(&mut ctx, &mut residual_buf, &x_buf, self.len);
33 CpuBackend::sync(&mut ctx);
34 CpuBackend::to_vec(&residual_buf, self.len)
35 }
36
37 #[cfg(all(target_os = "macos", feature = "metal"))]
38 fn run_metal(&self, seed: u64) -> Output {
39 use ferrum_kernels::backend::metal::MetalBackend;
40 use ferrum_kernels::backend::Backend;
41
42 let (residual, x) = self.build_input(seed);
43 let mut ctx = MetalBackend::new_context();
44 let mut residual_buf = MetalBackend::from_slice(&residual);
45 let x_buf = MetalBackend::from_slice(&x);
46 MetalBackend::add_inplace(&mut ctx, &mut residual_buf, &x_buf, self.len);
47 MetalBackend::sync(&mut ctx);
48 MetalBackend::to_vec(&residual_buf, self.len)
49 }
50
51 #[cfg(feature = "cuda")]
52 fn run_cuda(&self, seed: u64) -> Output {
53 use ferrum_kernels::backend::cuda::CudaBackend;
54 use ferrum_kernels::backend::Backend;
55
56 let (residual, x) = self.build_input(seed);
57 let mut ctx = CudaBackend::new_context();
58 let mut residual_buf = CudaBackend::from_slice(&residual);
59 let x_buf = CudaBackend::from_slice(&x);
60 CudaBackend::add_inplace(&mut ctx, &mut residual_buf, &x_buf, self.len);
61 CudaBackend::sync(&mut ctx);
62 CudaBackend::to_vec(&residual_buf, self.len)
63 }
64}