ferrum_testkit/op_diff/
gemm.rs1use super::{random_vec, OpUnderTest, Output};
8use ferrum_kernels::backend::Backend;
9
/// Dimensions of one GEMM case run through [`OpUnderTest`].
///
/// The three extents fix every buffer length used by this op: the first
/// input holds `m * k` values, the second holds `n * k`, and the output
/// holds `m * n`.
///
/// The common traits are derived so a case can be printed in test
/// diagnostics, compared, and passed around by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GemmOp {
    /// Extent shared by the first input (`m * k`) and the output (`m * n`).
    pub m: usize,
    /// Extent shared by the second input (`n * k`) and the output (`m * n`).
    pub n: usize,
    /// Extent shared by both inputs; it does not affect the output length.
    pub k: usize,
}
17
18impl GemmOp {
19 fn input_a_len(&self) -> usize {
20 self.m * self.k
21 }
22 fn input_b_len(&self) -> usize {
23 self.n * self.k
24 }
25 fn output_len(&self) -> usize {
26 self.m * self.n
27 }
28
29 fn build_input(&self, seed: u64) -> (Vec<f32>, Vec<f32>) {
30 let a = random_vec(self.input_a_len(), -1.0, 1.0, seed);
31 let b = random_vec(self.input_b_len(), -1.0, 1.0, seed.wrapping_add(1));
32 (a, b)
33 }
34}
35
36impl OpUnderTest for GemmOp {
37 fn name(&self) -> &str {
38 "gemm"
39 }
40
41 fn run_cpu(&self, seed: u64) -> Output {
42 use ferrum_kernels::backend::cpu::CpuBackend;
43 let (a, b) = self.build_input(seed);
44 let mut ctx = CpuBackend::new_context();
45 let a_buf = CpuBackend::from_slice(&a);
46 let b_buf = CpuBackend::from_slice(&b);
47 let mut out = CpuBackend::alloc(self.output_len());
48 CpuBackend::gemm(&mut ctx, &a_buf, &b_buf, &mut out, self.m, self.n, self.k);
49 CpuBackend::sync(&mut ctx);
50 CpuBackend::to_vec(&out, self.output_len())
51 }
52
53 #[cfg(all(target_os = "macos", feature = "metal"))]
54 fn run_metal(&self, seed: u64) -> Output {
55 use ferrum_kernels::backend::metal::MetalBackend;
56 let (a, b) = self.build_input(seed);
57 let mut ctx = MetalBackend::new_context();
58 let a_buf = MetalBackend::from_slice(&a);
59 let b_buf = MetalBackend::from_slice(&b);
60 let mut out = MetalBackend::alloc(self.output_len());
61 MetalBackend::gemm(&mut ctx, &a_buf, &b_buf, &mut out, self.m, self.n, self.k);
62 MetalBackend::sync(&mut ctx);
63 MetalBackend::to_vec(&out, self.output_len())
64 }
65
66 #[cfg(feature = "cuda")]
67 fn run_cuda(&self, seed: u64) -> Output {
68 use ferrum_kernels::backend::cuda::CudaBackend;
69 let (a, b) = self.build_input(seed);
70 let mut ctx = CudaBackend::new_context();
71 let a_buf = CudaBackend::from_slice(&a);
72 let b_buf = CudaBackend::from_slice(&b);
73 let mut out = CudaBackend::alloc(self.output_len());
74 CudaBackend::gemm(&mut ctx, &a_buf, &b_buf, &mut out, self.m, self.n, self.k);
75 CudaBackend::sync(&mut ctx);
76 CudaBackend::to_vec(&out, self.output_len())
77 }
78}