Skip to main content

ferrum_testkit/op_diff/
silu_mul.rs

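//! `fused_silu_mul_split` op-diff harness.
//!
//! Input layout (matches the kernel API):
//!   - `gate_up`: `tokens × (2 * intermediate)`, one row per token.
//!   - Each token row is `[gate ‖ up]`: the first `intermediate` values are
//!     the gate half and the next `intermediate` values are the up half.
//!
//! Output:
//!   - `out`: `tokens × intermediate`, where
//!     `out[i, j] = silu(gate[i, j]) * up[i, j]` and `silu(x) = x * sigmoid(x)`.
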
9use super::{random_vec, OpUnderTest, Output};
10
/// Op-diff configuration for the fused SiLU-and-multiply kernel that reads a
/// concatenated `[gate ‖ up]` input buffer.
///
/// The op is fully described by its two dimensions. It derives `Debug` so a
/// failing comparison can print the exact shape under test, and `Clone`/`Copy`
/// so callers can reuse one configuration across several backends.
#[derive(Debug, Clone, Copy)]
pub struct SiluMulOp {
    /// Number of token rows in both the input and the output.
    pub tokens: usize,
    /// Width of one half (gate or up) of each input row. It is also the width
    /// of each output row, so the `gate_up` buffer is
    /// `tokens × (2 * intermediate)`.
    pub intermediate: usize,
}
16
17impl SiluMulOp {
18    fn input_len(&self) -> usize {
19        self.tokens * 2 * self.intermediate
20    }
21    fn output_len(&self) -> usize {
22        self.tokens * self.intermediate
23    }
24
25    fn build_input(&self, seed: u64) -> Vec<f32> {
26        random_vec(self.input_len(), -3.0, 3.0, seed)
27    }
28}
29
30impl OpUnderTest for SiluMulOp {
31    fn name(&self) -> &str {
32        "fused_silu_mul"
33    }
34
35    fn run_cpu(&self, seed: u64) -> Output {
36        use ferrum_kernels::backend::cpu::CpuBackend;
37        use ferrum_kernels::backend::Backend;
38
39        let gate_up = self.build_input(seed);
40        let mut ctx = CpuBackend::new_context();
41        let gu_buf = CpuBackend::from_slice(&gate_up);
42        let mut out = CpuBackend::alloc(self.output_len());
43        CpuBackend::fused_silu_mul_split(
44            &mut ctx,
45            &gu_buf,
46            &mut out,
47            self.tokens,
48            self.intermediate,
49        );
50        CpuBackend::sync(&mut ctx);
51        CpuBackend::to_vec(&out, self.output_len())
52    }
53
54    #[cfg(all(target_os = "macos", feature = "metal"))]
55    fn run_metal(&self, seed: u64) -> Output {
56        use ferrum_kernels::backend::metal::MetalBackend;
57        use ferrum_kernels::backend::Backend;
58
59        let gate_up = self.build_input(seed);
60        let mut ctx = MetalBackend::new_context();
61        let gu_buf = MetalBackend::from_slice(&gate_up);
62        let mut out = MetalBackend::alloc(self.output_len());
63        MetalBackend::fused_silu_mul_split(
64            &mut ctx,
65            &gu_buf,
66            &mut out,
67            self.tokens,
68            self.intermediate,
69        );
70        MetalBackend::sync(&mut ctx);
71        MetalBackend::to_vec(&out, self.output_len())
72    }
73
74    #[cfg(feature = "cuda")]
75    fn run_cuda(&self, seed: u64) -> Output {
76        use ferrum_kernels::backend::cuda::CudaBackend;
77        use ferrum_kernels::backend::Backend;
78
79        let gate_up = self.build_input(seed);
80        let mut ctx = CudaBackend::new_context();
81        let gu_buf = CudaBackend::from_slice(&gate_up);
82        let mut out = CudaBackend::alloc(self.output_len());
83        CudaBackend::fused_silu_mul_split(
84            &mut ctx,
85            &gu_buf,
86            &mut out,
87            self.tokens,
88            self.intermediate,
89        );
90        CudaBackend::sync(&mut ctx);
91        CudaBackend::to_vec(&out, self.output_len())
92    }
93}