use metaltile::{bench_kernel, kernel};
// Argmax over the first `n` elements of `inp`: writes the index of the
// largest element to `out[0]`. Ties resolve to the smallest index (the
// first occurrence).
//
// Launch assumptions (implied by the code below, not enforced):
// - The tree reduction is hard-coded for 256 threads per threadgroup
//   (256-entry scratch arrays, strides 128..2, final merge of slots 0/1),
//   matching `tpg=256`. The strided load loop uses `lsize`. If `lsize` is
//   not 256, the reduction either reads slots no thread wrote
//   (lsize < 256) or stores past the 256-entry arrays (lsize > 256).
// - Only thread 0 writes, to `out[0]`, with no per-threadgroup offset. The
//   kernel therefore presumably expects a single threadgroup, with `tid`
//   being the thread index within it.
//   NOTE(review): confirm against the dispatch code.
// - Elements are compared after `.cast::<f32>()`. For wide integer `T`,
//   distinct values that round to the same f32 compare equal, so the
//   smaller index wins between them.
//
// The `bench_kernel` fields register this kernel with the benchmark harness.
// `mlx="argmax_{tn}"` presumably names the MLX reference kernel it is
// compared against — TODO confirm.
#[bench_kernel(
op="arg_reduce",
subop="argmax",
class=ArgReduce,
n=1048576,
check_n=4096,
tpg=256,
tol=0.5,
mlx="argmax_{tn}",
metal_file="arg_reduce.metal",
)]
#[kernel]
pub fn mt_argmax<T>(inp: Tensor<T>, out: Tensor<u32>, #[constexpr] n: u32) {
// Thread index used for striding and scratch slots (see the
// single-threadgroup note above).
let lid = tid;
// Per-thread running best, starting at (-inf, 0).
// `lid - lid` is a u32 zero derived from `lid`.
// NOTE(review): presumably written this way instead of `0u32` so the DSL
// treats it as a mutable runtime u32 — confirm.
let mut best_val = neg_infinity();
let mut best_idx = lid - lid;
// Scratch arrays holding one (value, index) partial per thread, 256 entries each.
threadgroup_alloc("tg_vals", 256);
threadgroup_alloc("tg_idxs", 256);
// ceil(n / lsize): iterations needed for `lsize` threads to cover `n` elements.
let n_iters = (n + lsize - 1u32) / lsize;
for _r in range(0u32, n_iters, 1u32) {
// Thread `lid` visits lid, lid + lsize, lid + 2*lsize, ...; the
// `pos < n` guard skips the ragged tail.
let pos = _r * lsize + lid;
if pos < n {
let v = load(inp[pos]).cast::<f32>();
// Strict `>` keeps the earliest position on ties, because `pos` grows
// each iteration. NaN never compares greater, so it is never selected.
// A thread whose elements are all -inf (or that has none, when
// n < lsize) keeps (-inf, 0). That partial can only win when the
// overall maximum is -inf, and then index 0 is reported.
let better = v > best_val;
if better {
best_val = v;
best_idx = pos;
}
}
}
// Publish this thread's partial, then wait for every thread before reducing.
threadgroup_store("tg_vals", lid, best_val);
threadgroup_store("tg_idxs", lid, best_idx);
threadgroup_barrier();
// Tree reduction over the scratch arrays: 7 stages with stride 128, 64, ..., 2.
// In each stage, threads with lid < stride merge slot lid + stride into
// slot lid. Slots >= stride are not written during that stage, so the
// reads do not race with the writes.
for _stage in range(0u32, 7u32, 1u32) {
let stride = 128u32 >> _stage;
if lid < stride {
let ov = threadgroup_load("tg_vals", lid + stride);
let oi = threadgroup_load("tg_idxs", lid + stride);
let tv = threadgroup_load("tg_vals", lid);
let ti = threadgroup_load("tg_idxs", lid);
// Take the other partial if its value is larger, or if the values
// are equal and its index is smaller (first-occurrence tie-break).
let bet = (ov > tv) | ((ov == tv) & (oi < ti));
threadgroup_store("tg_vals", lid, select(bet, ov, tv));
threadgroup_store("tg_idxs", lid, select(bet, oi, ti));
}
// Outside the `if`, so every thread reaches the barrier each stage.
threadgroup_barrier();
}
// After the stride-2 stage, two partials remain, in slots 0 and 1.
// Thread 0 merges them directly instead of running an 8th stage plus
// barrier, then writes only the winning index.
if lid == 0u32 {
let ov = threadgroup_load("tg_vals", 1u32);
let oi = threadgroup_load("tg_idxs", 1u32);
let tv = threadgroup_load("tg_vals", 0u32);
let ti = threadgroup_load("tg_idxs", 0u32);
let bet = (ov > tv) | ((ov == tv) & (oi < ti));
let final_idx = select(bet, oi, ti);
store(out[0], final_idx);
}
}
// Argmin over the first `n` elements of `inp`: writes the index of the
// smallest element to `out[0]`. Ties resolve to the smallest index (the
// first occurrence). This mirrors `mt_argmax`, with `<` and +inf in place
// of `>` and -inf.
//
// Launch assumptions (implied by the code below, not enforced):
// - The tree reduction is hard-coded for 256 threads per threadgroup
//   (256-entry scratch arrays, strides 128..2, final merge of slots 0/1),
//   matching `tpg=256`. The strided load loop uses `lsize`. If `lsize` is
//   not 256, the reduction either reads slots no thread wrote
//   (lsize < 256) or stores past the 256-entry arrays (lsize > 256).
// - Only thread 0 writes, to `out[0]`, with no per-threadgroup offset. The
//   kernel therefore presumably expects a single threadgroup, with `tid`
//   being the thread index within it.
//   NOTE(review): confirm against the dispatch code.
// - Elements are compared after `.cast::<f32>()`. For wide integer `T`,
//   distinct values that round to the same f32 compare equal, so the
//   smaller index wins between them.
//
// The `bench_kernel` fields register this kernel with the benchmark harness.
// `mlx="argmin_{tn}"` presumably names the MLX reference kernel it is
// compared against — TODO confirm.
#[bench_kernel(
op="arg_reduce",
subop="argmin",
class=ArgReduce,
n=1048576,
check_n=4096,
tpg=256,
tol=0.5,
mlx="argmin_{tn}",
metal_file="arg_reduce.metal",
)]
#[kernel]
pub fn mt_argmin<T>(inp: Tensor<T>, out: Tensor<u32>, #[constexpr] n: u32) {
// Thread index used for striding and scratch slots (see the
// single-threadgroup note above).
let lid = tid;
// Per-thread running best, starting at (+inf, 0).
// `lid - lid` is a u32 zero derived from `lid`.
// NOTE(review): presumably written this way instead of `0u32` so the DSL
// treats it as a mutable runtime u32 — confirm.
let mut best_val = infinity();
let mut best_idx = lid - lid;
// Scratch arrays holding one (value, index) partial per thread, 256 entries each.
threadgroup_alloc("tg_vals", 256);
threadgroup_alloc("tg_idxs", 256);
// ceil(n / lsize): iterations needed for `lsize` threads to cover `n` elements.
let n_iters = (n + lsize - 1u32) / lsize;
for _r in range(0u32, n_iters, 1u32) {
// Thread `lid` visits lid, lid + lsize, lid + 2*lsize, ...; the
// `pos < n` guard skips the ragged tail.
let pos = _r * lsize + lid;
if pos < n {
let v = load(inp[pos]).cast::<f32>();
// Strict `<` keeps the earliest position on ties, because `pos` grows
// each iteration. NaN never compares less, so it is never selected.
// A thread whose elements are all +inf (or that has none, when
// n < lsize) keeps (+inf, 0). That partial can only win when the
// overall minimum is +inf, and then index 0 is reported.
let better = v < best_val;
if better {
best_val = v;
best_idx = pos;
}
}
}
// Publish this thread's partial, then wait for every thread before reducing.
threadgroup_store("tg_vals", lid, best_val);
threadgroup_store("tg_idxs", lid, best_idx);
threadgroup_barrier();
// Tree reduction over the scratch arrays: 7 stages with stride 128, 64, ..., 2.
// In each stage, threads with lid < stride merge slot lid + stride into
// slot lid. Slots >= stride are not written during that stage, so the
// reads do not race with the writes.
for _stage in range(0u32, 7u32, 1u32) {
let stride = 128u32 >> _stage;
if lid < stride {
let ov = threadgroup_load("tg_vals", lid + stride);
let oi = threadgroup_load("tg_idxs", lid + stride);
let tv = threadgroup_load("tg_vals", lid);
let ti = threadgroup_load("tg_idxs", lid);
// Take the other partial if its value is smaller, or if the values
// are equal and its index is smaller (first-occurrence tie-break).
let bet = (ov < tv) | ((ov == tv) & (oi < ti));
threadgroup_store("tg_vals", lid, select(bet, ov, tv));
threadgroup_store("tg_idxs", lid, select(bet, oi, ti));
}
// Outside the `if`, so every thread reaches the barrier each stage.
threadgroup_barrier();
}
// After the stride-2 stage, two partials remain, in slots 0 and 1.
// Thread 0 merges them directly instead of running an 8th stage plus
// barrier, then writes only the winning index.
if lid == 0u32 {
let ov = threadgroup_load("tg_vals", 1u32);
let oi = threadgroup_load("tg_idxs", 1u32);
let tv = threadgroup_load("tg_vals", 0u32);
let ti = threadgroup_load("tg_idxs", 0u32);
let bet = (ov < tv) | ((ov == tv) & (oi < ti));
let final_idx = select(bet, oi, ti);
store(out[0], final_idx);
}
}