use metaltile::{bench_kernel, kernel};
use metaltile_core::ir::KernelMode;
use crate::{
bench_types::DType,
spec::{BenchDispatch, BenchSpec},
};
// Bitonic sort of one 1024-element block per threadgroup.
//
// Block `program_id::<0>()` covers `inp[base .. base + 1024]` with
// `base = block_id * n`; thread `t` owns the four consecutive slots
// `4t .. 4t + 3` of the threadgroup buffer "shared". Exactly 1024 elements are
// read and written per block regardless of `n`, so `n` is presumably 1024 (the
// benchmark attribute above uses n=1024, tpg=256) — confirm before reusing the
// kernel with another row length.
//
// The final pass (`_k == 10`) uses direction bit 10 of the element index,
// which is 0 for every index below 1024, so the block ends up ascending.
#[bench_kernel(
    op="sort",
    subop="sort",
    class=Sort,
    b=1024,
    n=1024,
    tpg=256,
    tol=0.0,
    mlx="c_block_sort_{tn}_{tn}_bn256_tn4",
    metal_file="sort.metal",
)]
#[kernel]
pub fn mt_sort<T>(inp: Tensor<T>, out: Tensor<T>, #[constexpr] n: u32) {
    let block_id = program_id::<0>();
    let t = tid;
    threadgroup_alloc("shared", 1024, T);
    let base = block_id * n;

    // Stage this thread's four elements into the threadgroup buffer.
    threadgroup_store("shared", t * 4u32, load(inp[base + t * 4u32]));
    threadgroup_store("shared", t * 4u32 + 1u32, load(inp[base + t * 4u32 + 1u32]));
    threadgroup_store("shared", t * 4u32 + 2u32, load(inp[base + t * 4u32 + 2u32]));
    threadgroup_store("shared", t * 4u32 + 3u32, load(inp[base + t * 4u32 + 3u32]));
    threadgroup_barrier();

    // `_k` selects the merge size 2^_k; `flip` walks the compare distance
    // 2^flip from 2^(_k-1) down to 1.
    for _k in range(1u32, 11u32, 1u32) {
        for _jb in range(0u32, _k, 1u32) {
            let flip = _k - _jb - 1u32;
            // A step with flip >= 7 pairs slots owned by threads at least 32
            // apart, and only the lower-indexed side of each pair performs the
            // exchange, writing BOTH slots. A barrier is therefore needed
            // before such a step (to see the partner's previous value) and
            // also before the step that follows it, i.e. flip == 6, so the
            // owning thread observes the value written into its slot. The
            // previous condition (`flip >= 7`) omitted the second barrier.
            // NOTE(review): skipping barriers for smaller distances presumes
            // that threads within a 32-wide SIMD group run in lockstep —
            // confirm for the target hardware.
            if flip >= 6u32 {
                threadgroup_barrier();
            }
            for _e in range(0u32, 4u32, 1u32) {
                let gi = t * 4u32 + _e;
                let partner = gi ^ (1u32 << flip);
                // Only the lower index of each pair does the compare-exchange.
                if gi < partner {
                    let a = threadgroup_load("shared", gi);
                    let b = threadgroup_load("shared", partner);
                    // Bit `_k` of the index picks the direction:
                    // 0 -> ascending (swap when a > b), 1 -> descending.
                    let dir = (gi >> _k) & 1u32;
                    let want_swap = select(dir == 0u32, a > b, a < b);
                    threadgroup_store("shared", gi, select(want_swap, b, a));
                    threadgroup_store("shared", partner, select(want_swap, a, b));
                }
            }
        }
    }
    threadgroup_barrier();

    // Write this thread's four sorted slots back to the output block.
    store(out[base + t * 4u32], threadgroup_load("shared", t * 4u32));
    store(out[base + t * 4u32 + 1u32], threadgroup_load("shared", t * 4u32 + 1u32));
    store(out[base + t * 4u32 + 2u32], threadgroup_load("shared", t * 4u32 + 2u32));
    store(out[base + t * 4u32 + 3u32], threadgroup_load("shared", t * 4u32 + 3u32));
}
// Merges adjacent runs of `inp` pairwise into `out`, one output element per
// invocation.
//
// Parameters:
//   inp / out  - source and destination tensors; indices `0 .. n` are used.
//   n          - total element count; invocations with `gi >= n` do nothing.
//   run        - length of each input run; runs `2p` and `2p + 1` are merged
//                into the span starting at `p * 2 * run`. The last runs are
//                truncated at `n`.
//   log_steps  - number of binary-search iterations performed.
//                NOTE(review): presumably chosen by the caller so the search
//                converges (about log2(run) + 1) — confirm.
//
// For output offset `o` inside its pair, the kernel searches for `i`, the
// number of elements taken from run A, with `j = o - i` taken from run B, then
// emits the smaller of A[i] and B[j]. Ties prefer A (`<=`). Values are
// compared as f32 and exhausted runs are represented by `infinity()`.
// NOTE(review): assumes each input run is already sorted ascending.
#[kernel]
pub fn mt_merge<T>(
    inp: Tensor<T>,
    out: Tensor<T>,
    #[constexpr] n: u32,
    #[constexpr] run: u32,
    #[constexpr] log_steps: u32,
) {
    let gi = program_id::<0>();
    if gi < n {
        let merged = run + run;
        let pair = gi / merged;
        // Offset of this output element inside its merged span.
        let o = gi - pair * merged;

        // Run boundaries, each clamped to `n` (select(c, x, n) acts as min).
        let a_start_raw = pair * merged;
        let a_start = select(a_start_raw < n, a_start_raw, n);
        let a_end_raw = a_start + run;
        let a_end = select(a_end_raw < n, a_end_raw, n);
        let b_start = a_end;
        let b_end_raw = b_start + run;
        let b_end = select(b_end_raw < n, b_end_raw, n);
        let a_len = a_end - a_start;
        let b_len = b_end - b_start;

        // Feasible range for i: max(o - b_len, 0) ..= min(o, a_len).
        let lo0 = select(o > b_len, o - b_len, 0u32);
        let hi0 = select(o < a_len, o, a_len);
        let mut lo = lo0;
        let mut hi = hi0;
        for _s in range(0u32, log_steps, 1u32) {
            let active = lo < hi;
            // Upper midpoint, so `mid >= 1` whenever the search is active.
            let mid = (lo + hi + 1u32) / 2u32;
            let a_idx = select(mid > 0u32, mid - 1u32, 0u32);
            let b_idx = o - mid;
            let a_in_range = active & (a_start + a_idx < n);
            // The probe loads execute even on iterations whose result is
            // discarded. Once the search has converged, `b_start + b_idx` can
            // equal `n` (run B fully consumed in the last pair), so clamp the
            // indices the same way the final loads below do.
            let a_probe_idx = select(a_start + a_idx < n, a_start + a_idx, 0u32);
            let a_load = load(inp[a_probe_idx]).cast::<f32>();
            let a_probe = select(a_in_range, a_load, infinity());
            let b_in_range = active & (b_idx < b_len);
            let b_probe_idx = select(b_start + b_idx < n, b_start + b_idx, 0u32);
            let b_load = load(inp[b_probe_idx]).cast::<f32>();
            let b_probe = select(b_in_range, b_load, infinity());
            // A[mid - 1] <= B[o - mid] means at least `mid` elements of A
            // precede output position `o`.
            let take_more_a = a_probe <= b_probe;
            lo = select(active, select(take_more_a, mid, lo), lo);
            hi = select(active, select(take_more_a, hi, mid - 1u32), hi);
        }

        let i = lo;
        let j = o - i;
        let a_real = i < a_len;
        let b_real = j < b_len;
        // Clamp to index 0 so the unconditional loads stay inside the tensor;
        // the loaded value is ignored when the run is exhausted.
        let a_safe = select(a_start + i < n, a_start + i, 0u32);
        let b_safe = select(b_start + j < n, b_start + j, 0u32);
        let a_val = select(a_real, load(inp[a_safe]).cast::<f32>(), infinity());
        let b_val = select(b_real, load(inp[b_safe]).cast::<f32>(), infinity());
        let pick_a = a_val <= b_val;
        let chosen = select(pick_a, a_val, b_val);
        store(out[gi], chosen.cast::<T>());
    }
}
// Registers `mt_merge` in the benchmark inventory under op "sort" / subop "merge".
inventory::submit! {
    BenchSpec {
        // Identity of the benchmark entry and the kernel it drives.
        kernel_name: "mt_merge",
        kernel_ir: mt_merge::kernel_ir_for,
        op: "sort",
        subop: "merge",

        // Launch configuration.
        kernel_mode: Some(KernelMode::Grid3D),
        dispatch: BenchDispatch::Generic,
        shapes: &[],

        // Element types covered and the comparison tolerance.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        tol: 0.0,

        // No MLX source or pattern is attached to this entry.
        mlx_src: None,
        mlx_pattern: None,
    }
}
// Sorts each row of `inp` independently with an in-threadgroup bitonic
// network, for row lengths `n` up to 1024.
//
// Threadgroup `tgid_x` handles the row starting at `row * n`; thread `t` owns
// slots `4t .. 4t + 3` of the 1024-element threadgroup buffer "shared". Slots
// at or beyond `n` are padded with `infinity()` so they end up at the tail of
// the ascending result, and only the first `n` slots are written back.
// NOTE(review): presumably launched with 256 threads per threadgroup so that
// all 1024 slots are covered — confirm against the dispatcher. Elements at
// index >= 1024 of a row are never touched.
#[kernel]
pub fn mt_sort_segmented<T>(inp: Tensor<T>, out: Tensor<T>, #[constexpr] n: u32) {
    let row = tgid_x;
    let t = tid;
    threadgroup_alloc("shared", 1024, T);
    let row_base = row * n;
    let i0 = t * 4u32;
    let i1 = i0 + 1u32;
    let i2 = i0 + 2u32;
    let i3 = i0 + 3u32;
    let inf_f = infinity();

    // The loads below execute even for padded slots, so clamp their index to
    // 0 when the slot lies outside the row. Without the clamp a padded slot
    // would read from the next row, or past the end of the tensor on the last
    // row. The loaded value is discarded for padded slots either way.
    let s0 = select(i0 < n, row_base + i0, 0u32);
    let s1 = select(i1 < n, row_base + i1, 0u32);
    let s2 = select(i2 < n, row_base + i2, 0u32);
    let s3 = select(i3 < n, row_base + i3, 0u32);
    let v0 = select(i0 < n, load(inp[s0]).cast::<f32>(), inf_f);
    let v1 = select(i1 < n, load(inp[s1]).cast::<f32>(), inf_f);
    let v2 = select(i2 < n, load(inp[s2]).cast::<f32>(), inf_f);
    let v3 = select(i3 < n, load(inp[s3]).cast::<f32>(), inf_f);
    threadgroup_store("shared", i0, v0.cast::<T>());
    threadgroup_store("shared", i1, v1.cast::<T>());
    threadgroup_store("shared", i2, v2.cast::<T>());
    threadgroup_store("shared", i3, v3.cast::<T>());
    threadgroup_barrier();

    // `_k` selects the merge size 2^_k; `flip` walks the compare distance
    // 2^flip from 2^(_k-1) down to 1.
    for _k in range(1u32, 11u32, 1u32) {
        for _jb in range(0u32, _k, 1u32) {
            let flip = _k - _jb - 1u32;
            // A step with flip >= 7 pairs slots owned by threads at least 32
            // apart, and only the lower-indexed side of each pair performs the
            // exchange, writing BOTH slots. A barrier is therefore needed
            // before such a step and also before the step that follows it
            // (flip == 6), so the owning thread observes the value written
            // into its slot. The previous condition (`flip >= 7`) omitted the
            // second barrier.
            // NOTE(review): skipping barriers for smaller distances presumes
            // that threads within a 32-wide SIMD group run in lockstep —
            // confirm for the target hardware.
            if flip >= 6u32 {
                threadgroup_barrier();
            }
            for _e in range(0u32, 4u32, 1u32) {
                let gi = t * 4u32 + _e;
                let partner = gi ^ (1u32 << flip);
                // Only the lower index of each pair does the compare-exchange.
                if gi < partner {
                    let a = threadgroup_load("shared", gi);
                    let b = threadgroup_load("shared", partner);
                    // Bit `_k` of the index picks the direction:
                    // 0 -> ascending, 1 -> descending.
                    let dir = (gi >> _k) & 1u32;
                    // Compare as f32; the stored values keep their type T.
                    let a_f = a.cast::<f32>();
                    let b_f = b.cast::<f32>();
                    let want_swap = select(dir == 0u32, a_f > b_f, a_f < b_f);
                    threadgroup_store("shared", gi, select(want_swap, b, a));
                    threadgroup_store("shared", partner, select(want_swap, a, b));
                }
            }
        }
    }
    threadgroup_barrier();

    // Write back only the slots that belong to the row.
    if i0 < n {
        store(out[row_base + i0], threadgroup_load("shared", i0));
    }
    if i1 < n {
        store(out[row_base + i1], threadgroup_load("shared", i1));
    }
    if i2 < n {
        store(out[row_base + i2], threadgroup_load("shared", i2));
    }
    if i3 < n {
        store(out[row_base + i3], threadgroup_load("shared", i3));
    }
}
// Registers `mt_sort_segmented` in the benchmark inventory under
// op "sort" / subop "sort_segmented".
inventory::submit! {
    BenchSpec {
        // Identity of the benchmark entry and the kernel it drives.
        kernel_name: "mt_sort_segmented",
        kernel_ir: mt_sort_segmented::kernel_ir_for,
        op: "sort",
        subop: "sort_segmented",

        // Launch configuration.
        kernel_mode: Some(KernelMode::Reduction),
        dispatch: BenchDispatch::Generic,
        shapes: &[],

        // Element types covered and the comparison tolerance.
        dtypes: &[DType::F32, DType::F16, DType::BF16],
        tol: 0.0,

        // No MLX source or pattern is attached to this entry.
        mlx_src: None,
        mlx_pattern: None,
    }
}