Skip to main content

ferrum_testkit/op_diff/
residual_add.rs

//! `residual_add` (`add_inplace`) op-diff harness — see `crate::op_diff`.
//!
//! Computes `residual += x` elementwise; the output is the updated residual buffer.
4
5use super::{random_vec, OpUnderTest, Output};
6
/// Op-diff case for the in-place residual add (`residual += x`).
///
/// The same seeded inputs are fed to every available backend so their outputs
/// can be compared.
pub struct ResidualAddOp {
    /// Element count of both the `residual` and `x` vectors (and of the output).
    pub len: usize,
}
10
11impl ResidualAddOp {
12    fn build_input(&self, seed: u64) -> (Vec<f32>, Vec<f32>) {
13        let residual = random_vec(self.len, -2.0, 2.0, seed);
14        let x = random_vec(self.len, -2.0, 2.0, seed.wrapping_add(1));
15        (residual, x)
16    }
17}
18
19impl OpUnderTest for ResidualAddOp {
20    fn name(&self) -> &str {
21        "residual_add"
22    }
23
24    fn run_cpu(&self, seed: u64) -> Output {
25        use ferrum_kernels::backend::cpu::CpuBackend;
26        use ferrum_kernels::backend::Backend;
27
28        let (residual, x) = self.build_input(seed);
29        let mut ctx = CpuBackend::new_context();
30        let mut residual_buf = CpuBackend::from_slice(&residual);
31        let x_buf = CpuBackend::from_slice(&x);
32        CpuBackend::add_inplace(&mut ctx, &mut residual_buf, &x_buf, self.len);
33        CpuBackend::sync(&mut ctx);
34        CpuBackend::to_vec(&residual_buf, self.len)
35    }
36
37    #[cfg(all(target_os = "macos", feature = "metal"))]
38    fn run_metal(&self, seed: u64) -> Output {
39        use ferrum_kernels::backend::metal::MetalBackend;
40        use ferrum_kernels::backend::Backend;
41
42        let (residual, x) = self.build_input(seed);
43        let mut ctx = MetalBackend::new_context();
44        let mut residual_buf = MetalBackend::from_slice(&residual);
45        let x_buf = MetalBackend::from_slice(&x);
46        MetalBackend::add_inplace(&mut ctx, &mut residual_buf, &x_buf, self.len);
47        MetalBackend::sync(&mut ctx);
48        MetalBackend::to_vec(&residual_buf, self.len)
49    }
50
51    #[cfg(feature = "cuda")]
52    fn run_cuda(&self, seed: u64) -> Output {
53        use ferrum_kernels::backend::cuda::CudaBackend;
54        use ferrum_kernels::backend::Backend;
55
56        let (residual, x) = self.build_input(seed);
57        let mut ctx = CudaBackend::new_context();
58        let mut residual_buf = CudaBackend::from_slice(&residual);
59        let x_buf = CudaBackend::from_slice(&x);
60        CudaBackend::add_inplace(&mut ctx, &mut residual_buf, &x_buf, self.len);
61        CudaBackend::sync(&mut ctx);
62        CudaBackend::to_vec(&residual_buf, self.len)
63    }
64}