ferrum_testkit/op_diff/
silu_mul.rs1use super::{random_vec, OpUnderTest, Output};
10
/// Shape parameters for one fused SiLU-multiply op-diff case.
///
/// The helper methods on this type size the input buffer at
/// `tokens * 2 * intermediate` elements and the output buffer at
/// `tokens * intermediate` elements.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SiluMulOp {
    /// Row count of the problem (presumably the number of tokens — confirm
    /// against the kernel's contract).
    pub tokens: usize,
    /// Number of output values per token; the input carries twice this many
    /// values per token.
    pub intermediate: usize,
}
16
17impl SiluMulOp {
18 fn input_len(&self) -> usize {
19 self.tokens * 2 * self.intermediate
20 }
21 fn output_len(&self) -> usize {
22 self.tokens * self.intermediate
23 }
24
25 fn build_input(&self, seed: u64) -> Vec<f32> {
26 random_vec(self.input_len(), -3.0, 3.0, seed)
27 }
28}
29
30impl OpUnderTest for SiluMulOp {
31 fn name(&self) -> &str {
32 "fused_silu_mul"
33 }
34
35 fn run_cpu(&self, seed: u64) -> Output {
36 use ferrum_kernels::backend::cpu::CpuBackend;
37 use ferrum_kernels::backend::Backend;
38
39 let gate_up = self.build_input(seed);
40 let mut ctx = CpuBackend::new_context();
41 let gu_buf = CpuBackend::from_slice(&gate_up);
42 let mut out = CpuBackend::alloc(self.output_len());
43 CpuBackend::fused_silu_mul_split(
44 &mut ctx,
45 &gu_buf,
46 &mut out,
47 self.tokens,
48 self.intermediate,
49 );
50 CpuBackend::sync(&mut ctx);
51 CpuBackend::to_vec(&out, self.output_len())
52 }
53
54 #[cfg(all(target_os = "macos", feature = "metal"))]
55 fn run_metal(&self, seed: u64) -> Output {
56 use ferrum_kernels::backend::metal::MetalBackend;
57 use ferrum_kernels::backend::Backend;
58
59 let gate_up = self.build_input(seed);
60 let mut ctx = MetalBackend::new_context();
61 let gu_buf = MetalBackend::from_slice(&gate_up);
62 let mut out = MetalBackend::alloc(self.output_len());
63 MetalBackend::fused_silu_mul_split(
64 &mut ctx,
65 &gu_buf,
66 &mut out,
67 self.tokens,
68 self.intermediate,
69 );
70 MetalBackend::sync(&mut ctx);
71 MetalBackend::to_vec(&out, self.output_len())
72 }
73
74 #[cfg(feature = "cuda")]
75 fn run_cuda(&self, seed: u64) -> Output {
76 use ferrum_kernels::backend::cuda::CudaBackend;
77 use ferrum_kernels::backend::Backend;
78
79 let gate_up = self.build_input(seed);
80 let mut ctx = CudaBackend::new_context();
81 let gu_buf = CudaBackend::from_slice(&gate_up);
82 let mut out = CudaBackend::alloc(self.output_len());
83 CudaBackend::fused_silu_mul_split(
84 &mut ctx,
85 &gu_buf,
86 &mut out,
87 self.tokens,
88 self.intermediate,
89 );
90 CudaBackend::sync(&mut ctx);
91 CudaBackend::to_vec(&out, self.output_len())
92 }
93}