use super::{random_vec, OpUnderTest, Output};
/// Test descriptor for a row-wise argmax over an `m` x `n` logits matrix.
///
/// The matrix is stored flat, one row after another (element `(row, col)`
/// lives at `row * n + col`).
pub struct ArgmaxRowsOp {
    /// Number of rows; one argmax index is produced per row.
    pub m: usize,
    /// Number of columns, i.e. the length of each row.
    pub n: usize,
}
impl ArgmaxRowsOp {
    /// Builds the flat `m * n` logits buffer plus the planted winner of each row.
    ///
    /// The buffer comes from `random_vec(m * n, -1.0, 1.0, seed)`; afterwards one
    /// element per row is overwritten with `100.0`. The chosen column is
    /// `(seed + 7 * row) % n` (wrapping arithmetic), so it varies from row to row
    /// and is reproducible for a given seed.
    ///
    /// Returns `(logits, expected)` where `expected[row]` is the planted column.
    /// NOTE(review): this is the true argmax only if `random_vec` keeps its
    /// values within the requested `-1.0..1.0` bounds — confirm against the helper.
    ///
    /// # Panics
    ///
    /// Panics with a remainder-by-zero error when `n == 0` and `m > 0`.
    fn build_input(&self, seed: u64) -> (Vec<f32>, Vec<u32>) {
        let mut data = random_vec(self.m * self.n, -1.0, 1.0, seed);
        let offset = seed as usize;

        let winners: Vec<u32> = (0..self.m)
            .map(|r| {
                // Stride of 7 spreads the planted peak across different columns.
                let peak = offset.wrapping_add(r.wrapping_mul(7)) % self.n;
                data[r * self.n + peak] = 100.0;
                peak as u32
            })
            .collect();

        (data, winners)
    }
}
impl OpUnderTest for ArgmaxRowsOp {
    /// Identifier reported for this operation.
    fn name(&self) -> &str {
        "argmax_rows_f16"
    }

    /// Runs `argmax_rows_f16` through the CPU backend on the seeded input.
    ///
    /// Only the logits from `build_input` are used; the planted expected
    /// indices are discarded here. Each returned index is cast to `f32` so the
    /// result fits the common `Output` type.
    ///
    /// # Panics
    ///
    /// Panics with `"cpu argmax_rows_f16"` if the backend call returns an error.
    fn run_cpu(&self, seed: u64) -> Output {
        use ferrum_kernels::backend::{cpu::CpuBackend, Backend};

        let values = self.build_input(seed).0;
        let mut context = CpuBackend::new_context();
        let buffer = CpuBackend::from_slice(&values);
        let indices = CpuBackend::argmax_rows_f16(&mut context, &buffer, self.m, self.n)
            .expect("cpu argmax_rows_f16");
        indices.into_iter().map(|index| index as f32).collect()
    }

    /// Runs `argmax_rows_f16` through the Metal backend on the seeded input.
    ///
    /// Compiled only on macOS with the `metal` feature enabled. Mirrors
    /// `run_cpu`: same input construction, same index-to-`f32` conversion.
    ///
    /// # Panics
    ///
    /// Panics with `"metal argmax_rows_f16"` if the backend call returns an error.
    #[cfg(all(target_os = "macos", feature = "metal"))]
    fn run_metal(&self, seed: u64) -> Output {
        use ferrum_kernels::backend::{metal::MetalBackend, Backend};

        let values = self.build_input(seed).0;
        let mut context = MetalBackend::new_context();
        let buffer = MetalBackend::from_slice(&values);
        let indices = MetalBackend::argmax_rows_f16(&mut context, &buffer, self.m, self.n)
            .expect("metal argmax_rows_f16");
        indices.into_iter().map(|index| index as f32).collect()
    }

    /// Runs `argmax_rows_f16` through the CUDA backend on the seeded input.
    ///
    /// Compiled only with the `cuda` feature enabled. Mirrors `run_cpu`: same
    /// input construction, same index-to-`f32` conversion.
    ///
    /// # Panics
    ///
    /// Panics with `"cuda argmax_rows_f16"` if the backend call returns an error.
    #[cfg(feature = "cuda")]
    fn run_cuda(&self, seed: u64) -> Output {
        use ferrum_kernels::backend::{cuda::CudaBackend, Backend};

        let values = self.build_input(seed).0;
        let mut context = CudaBackend::new_context();
        let buffer = CudaBackend::from_slice(&values);
        let indices = CudaBackend::argmax_rows_f16(&mut context, &buffer, self.m, self.n)
            .expect("cuda argmax_rows_f16");
        indices.into_iter().map(|index| index as f32).collect()
    }
}