ferrum_testkit/op_diff/
argmax_rows.rs1use super::{random_vec, OpUnderTest, Output};
10
/// Op-diff case for the `argmax_rows_f16` kernel.
///
/// Describes an `m x n` problem: `m` rows of `n` logits each, stored
/// row-major (row `r` occupies indices `r * n .. (r + 1) * n`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ArgmaxRowsOp {
    /// Number of rows; one argmax index is expected per row.
    pub m: usize,
    /// Number of columns (logits) in each row.
    pub n: usize,
}

16impl ArgmaxRowsOp {
17 fn build_input(&self, seed: u64) -> (Vec<f32>, Vec<u32>) {
18 let mut logits = random_vec(self.m * self.n, -1.0, 1.0, seed);
19 let mut expected = Vec::with_capacity(self.m);
20 for row in 0..self.m {
21 let col = ((seed as usize).wrapping_add(row.wrapping_mul(7))) % self.n;
22 logits[row * self.n + col] = 100.0; expected.push(col as u32);
24 }
25 (logits, expected)
26 }
27}
28
29impl OpUnderTest for ArgmaxRowsOp {
30 fn name(&self) -> &str {
31 "argmax_rows_f16"
32 }
33
34 fn run_cpu(&self, seed: u64) -> Output {
35 use ferrum_kernels::backend::cpu::CpuBackend;
36 use ferrum_kernels::backend::Backend;
37
38 let (logits, _) = self.build_input(seed);
39 let mut ctx = CpuBackend::new_context();
40 let buf = CpuBackend::from_slice(&logits);
41 let idx = CpuBackend::argmax_rows_f16(&mut ctx, &buf, self.m, self.n)
42 .expect("cpu argmax_rows_f16");
43 idx.into_iter().map(|i| i as f32).collect()
44 }
45
46 #[cfg(all(target_os = "macos", feature = "metal"))]
47 fn run_metal(&self, seed: u64) -> Output {
48 use ferrum_kernels::backend::metal::MetalBackend;
49 use ferrum_kernels::backend::Backend;
50
51 let (logits, _) = self.build_input(seed);
52 let mut ctx = MetalBackend::new_context();
53 let buf = MetalBackend::from_slice(&logits);
54 let idx = MetalBackend::argmax_rows_f16(&mut ctx, &buf, self.m, self.n)
55 .expect("metal argmax_rows_f16");
56 idx.into_iter().map(|i| i as f32).collect()
57 }
58
59 #[cfg(feature = "cuda")]
60 fn run_cuda(&self, seed: u64) -> Output {
61 use ferrum_kernels::backend::cuda::CudaBackend;
62 use ferrum_kernels::backend::Backend;
63
64 let (logits, _) = self.build_input(seed);
65 let mut ctx = CudaBackend::new_context();
66 let buf = CudaBackend::from_slice(&logits);
67 let idx = CudaBackend::argmax_rows_f16(&mut ctx, &buf, self.m, self.n)
68 .expect("cuda argmax_rows_f16");
69 idx.into_iter().map(|i| i as f32).collect()
70 }
71}