Skip to main content

ferrum_testkit/op_diff/
argmax_rows.rs

//! `argmax_rows_f16` op-diff harness — see `crate::op_diff`.
//!
//! Per-row argmax over an `[m, n]` logits buffer. Metal stores logits as f16,
//! so to keep the argmax unambiguous across the f32 reference and the f16
//! kernel, each row gets a well-separated spike at a deterministic column —
//! both backends must select it regardless of f16 rounding. The compared
//! output is the m winning indices (as f32).
8
9use super::{random_vec, OpUnderTest, Output};
10
/// Shape of the `[m, n]` logits buffer fed to the `argmax_rows_f16` op.
///
/// The value is plain data (two `usize` dimensions), so it derives the
/// common value-type traits to make it printable in test diagnostics and
/// comparable / copyable in harness code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ArgmaxRowsOp {
    /// Number of rows; one argmax index is produced per row.
    pub m: usize,
    /// Number of columns, i.e. the width of each row.
    pub n: usize,
}
15
16impl ArgmaxRowsOp {
17    fn build_input(&self, seed: u64) -> (Vec<f32>, Vec<u32>) {
18        let mut logits = random_vec(self.m * self.n, -1.0, 1.0, seed);
19        let mut expected = Vec::with_capacity(self.m);
20        for row in 0..self.m {
21            let col = ((seed as usize).wrapping_add(row.wrapping_mul(7))) % self.n;
22            logits[row * self.n + col] = 100.0; // unambiguous spike
23            expected.push(col as u32);
24        }
25        (logits, expected)
26    }
27}
28
29impl OpUnderTest for ArgmaxRowsOp {
30    fn name(&self) -> &str {
31        "argmax_rows_f16"
32    }
33
34    fn run_cpu(&self, seed: u64) -> Output {
35        use ferrum_kernels::backend::cpu::CpuBackend;
36        use ferrum_kernels::backend::Backend;
37
38        let (logits, _) = self.build_input(seed);
39        let mut ctx = CpuBackend::new_context();
40        let buf = CpuBackend::from_slice(&logits);
41        let idx = CpuBackend::argmax_rows_f16(&mut ctx, &buf, self.m, self.n)
42            .expect("cpu argmax_rows_f16");
43        idx.into_iter().map(|i| i as f32).collect()
44    }
45
46    #[cfg(all(target_os = "macos", feature = "metal"))]
47    fn run_metal(&self, seed: u64) -> Output {
48        use ferrum_kernels::backend::metal::MetalBackend;
49        use ferrum_kernels::backend::Backend;
50
51        let (logits, _) = self.build_input(seed);
52        let mut ctx = MetalBackend::new_context();
53        let buf = MetalBackend::from_slice(&logits);
54        let idx = MetalBackend::argmax_rows_f16(&mut ctx, &buf, self.m, self.n)
55            .expect("metal argmax_rows_f16");
56        idx.into_iter().map(|i| i as f32).collect()
57    }
58
59    #[cfg(feature = "cuda")]
60    fn run_cuda(&self, seed: u64) -> Output {
61        use ferrum_kernels::backend::cuda::CudaBackend;
62        use ferrum_kernels::backend::Backend;
63
64        let (logits, _) = self.build_input(seed);
65        let mut ctx = CudaBackend::new_context();
66        let buf = CudaBackend::from_slice(&logits);
67        let idx = CudaBackend::argmax_rows_f16(&mut ctx, &buf, self.m, self.n)
68            .expect("cuda argmax_rows_f16");
69        idx.into_iter().map(|i| i as f32).collect()
70    }
71}