#![cfg(target_os = "macos")]
mod common;
use std::collections::BTreeMap;
use common::{Dt, gpu_lock, pack_bytes, unpack_bytes};
use metaltile_core::ir::KernelMode;
use metaltile_runtime::Context;
use metaltile_std::mlx::sort::{mt_merge, mt_sort};
/// Runs the `mt_sort` kernel over `inp` and returns the decoded `out` buffer.
///
/// `inp` must contain exactly `n_blocks * 1024` elements. Host values are encoded with `dt`
/// via `pack_bytes` before dispatch and decoded back to `f32` with `unpack_bytes` afterwards.
/// The dispatch is issued with `[n_blocks, 1, 1]` and `[256, 1, 1]` as its two grid arguments.
///
/// # Panics
/// Panics if `inp` has the wrong length, if the context cannot be created, if the dispatch
/// fails, or if the dispatch result carries no `out` buffer.
fn run_sort(inp: &[f32], dt: Dt, n_blocks: usize) -> Vec<f32> {
    const N: usize = 1024;
    let total = n_blocks * N;
    assert_eq!(inp.len(), total, "input must be exactly n_blocks * 1024 elements");

    // The output buffer starts zero-filled and is sized like the input.
    let zeros = vec![0.0f32; inp.len()];
    let buffers: BTreeMap<String, Vec<u8>> = BTreeMap::from([
        (String::from("inp"), pack_bytes(inp, dt)),
        (String::from("out"), pack_bytes(&zeros, dt)),
        // NOTE(review): `n` is always 1024 here, presumably the per-block element count — confirm.
        (String::from("n"), (N as u32).to_le_bytes().to_vec()),
    ]);

    let ctx = Context::new().expect("Context::new on macOS");

    let mut kernel = mt_sort::kernel_ir_for(dt.to_dtype());
    kernel.mode = KernelMode::Reduction;

    let dispatched = ctx
        .dispatch_with_grid(&kernel, &buffers, &BTreeMap::new(), [n_blocks, 1, 1], [256, 1, 1])
        .expect("sort dispatch");

    let raw_out = dispatched.outputs.get("out").expect("out");
    let mut decoded = unpack_bytes(raw_out, dt);
    // Drop anything beyond the logical element count.
    decoded.truncate(total);
    decoded
}
/// CPU reference: returns a copy of `block` sorted ascending under `f32::total_cmp`.
///
/// The input slice is left untouched.
fn cpu_sort_block(block: &[f32]) -> Vec<f32> {
    let mut sorted = Vec::with_capacity(block.len());
    sorted.extend_from_slice(block);
    sorted.sort_unstable_by(|lhs, rhs| lhs.total_cmp(rhs));
    sorted
}
/// A strictly descending 1024-element f32 block must come back equal to the CPU reference sort.
#[test]
fn sort_single_block_matches_cpu_f32() {
    let _g = gpu_lock();
    const N: usize = 1024;
    let inp: Vec<f32> = (0..N).rev().map(|i| i as f32 * 0.1).collect();
    let expected = cpu_sort_block(&inp);
    let actual = run_sort(&inp, Dt::F32, 1);
    // `zip` stops at the shorter iterator, so a short GPU result would otherwise leave the
    // tail of `expected` unchecked.
    assert_eq!(actual.len(), expected.len(), "sort output length mismatch");
    assert!(actual.iter().any(|&v| v != 0.0), "sort output all zeros — empty kernel body?");
    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        assert!((e - a).abs() < 1e-6, "sort mismatch at [{i}]: expected {e:.4}, got {a:.4}");
    }
}
/// A single block of scrambled values with many duplicates (only 100 distinct keys) must match
/// the CPU reference sort.
#[test]
fn sort_single_block_random_f32() {
    let _g = gpu_lock();
    const N: usize = 1024;
    let inp: Vec<f32> = (0..N).map(|i| ((i * 37 + 13) % 100) as f32 * 0.1 - 5.0).collect();
    let expected = cpu_sort_block(&inp);
    let actual = run_sort(&inp, Dt::F32, 1);
    // Without this, an empty or short output makes the `zip` loop below pass vacuously.
    assert_eq!(actual.len(), expected.len(), "sort random output length mismatch");
    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        assert!((e - a).abs() < 1e-6, "sort random mismatch at [{i}]: expected {e:.4}, got {a:.4}");
    }
}
/// Two blocks dispatched together must each be sorted on their own: the output is compared
/// block-by-block against per-block CPU sorts, not against a sort of the whole input.
#[test]
fn sort_two_independent_blocks_f32() {
    let _g = gpu_lock();
    const N: usize = 1024;
    let block0: Vec<f32> = (0..N).rev().map(|i| i as f32).collect();
    let block1: Vec<f32> = (0..N).map(|i| ((i * 53 + 7) % 1000) as f32 * 0.01).collect();
    let inp: Vec<f32> = block0.iter().chain(block1.iter()).copied().collect();
    let expected0 = cpu_sort_block(&block0);
    let expected1 = cpu_sort_block(&block1);
    let actual = run_sort(&inp, Dt::F32, 2);
    // `split_at(N)` only guarantees the first half exists; pin the full length so a short
    // second block cannot slip through the `zip` comparison.
    assert_eq!(actual.len(), 2 * N, "sort output length mismatch for two blocks");
    let (actual0, actual1) = actual.split_at(N);
    for (i, (e, a)) in expected0.iter().zip(actual0.iter()).enumerate() {
        assert!((e - a).abs() < 1e-6, "sort block0 mismatch at [{i}]");
    }
    for (i, (e, a)) in expected1.iter().zip(actual1.iter()).enumerate() {
        assert!((e - a).abs() < 1e-6, "sort block1 mismatch at [{i}]");
    }
}
/// Single-block sort with `Dt::F16` encoding, compared against the CPU reference with a looser
/// tolerance than the f32 tests.
#[test]
fn sort_single_block_f16() {
    let _g = gpu_lock();
    const N: usize = 1024;
    // Inputs go through `Dt::F16.round` first — presumably so they are exactly representable
    // in f16 and survive the encode/decode round trip (confirm against `common::Dt`).
    let inp: Vec<f32> = (0..N).map(|i| Dt::F16.round(((N - 1 - i) as f32) * 0.1)).collect();
    let expected = cpu_sort_block(&inp);
    let actual = run_sort(&inp, Dt::F16, 1);
    // `zip` truncates to the shorter side; check the length explicitly.
    assert_eq!(actual.len(), expected.len(), "sort f16 output length mismatch");
    for (i, (e, a)) in expected.iter().zip(actual.iter()).enumerate() {
        assert!((e - a).abs() < 1e-3, "sort f16 mismatch at [{i}]: expected {e:.4}, got {a:.4}");
    }
}
/// Property check: the single-block sort output must be non-decreasing, independent of any
/// CPU reference.
#[test]
fn sort_output_is_non_decreasing_f32() {
    let _g = gpu_lock();
    const N: usize = 1024;
    let inp: Vec<f32> = (0..N).map(|i| ((i * 97 + 31) % 200) as f32 - 100.0).collect();
    let actual = run_sort(&inp, Dt::F32, 1);
    // `windows(2)` yields nothing for an empty or one-element slice, so the ordering check
    // alone would accept a missing result.
    assert_eq!(actual.len(), N, "sort output length mismatch");
    for window in actual.windows(2) {
        assert!(window[0] <= window[1], "sort output not non-decreasing at {:?}", window);
    }
}
/// Elements per block in the multi-block tests: `run_multiblock_sort` requires its input length
/// to be a whole multiple of this and uses it as the first merge run length.
const BLOCK: usize = 1024;
/// Dispatches one `mt_merge` pass over `inp` and returns the decoded `out` buffer.
///
/// * `ctx` — context used for the dispatch.
/// * `inp` — host data, encoded with `dt` before upload.
/// * `n` — value passed to the kernel as `n`; also determines the grid size
///   (`n.div_ceil(256)` groups of 256).
/// * `run` — value passed to the kernel as `run`.
///
/// The returned vector is truncated to `inp.len()` elements.
///
/// # Panics
/// Panics if `2 * run` exceeds `2^20` (the fixed `log_steps`), if `n` or `run` does not fit in
/// the `u32` they are encoded as, if the dispatch fails, or if the result has no `out` buffer.
fn run_merge_pass(ctx: &Context, inp: &[f32], dt: Dt, n: usize, run: usize) -> Vec<f32> {
    /// Narrows a host-side count to the `u32` the kernel scalars are encoded as, panicking
    /// instead of silently truncating.
    fn to_u32(value: usize, what: &str) -> u32 {
        <u32 as std::convert::TryFrom<usize>>::try_from(value)
            .unwrap_or_else(|_| panic!("{what}={value} does not fit in u32"))
    }

    const LOG_STEPS: u32 = 20;
    const TPG: usize = 256;

    // Widen before doubling so the check itself cannot overflow `usize`.
    assert!(
        (1u64 << LOG_STEPS) >= (run as u64).saturating_mul(2),
        "log_steps too small for run={run}"
    );
    // NOTE(review): presumably `n` must not exceed `inp.len()` — confirm against the kernel.

    let mut buffers: BTreeMap<String, Vec<u8>> = BTreeMap::new();
    buffers.insert("inp".into(), pack_bytes(inp, dt));
    buffers.insert("out".into(), pack_bytes(&vec![0.0f32; inp.len()], dt));
    buffers.insert("n".into(), to_u32(n, "n").to_le_bytes().to_vec());
    buffers.insert("run".into(), to_u32(run, "run").to_le_bytes().to_vec());
    buffers.insert("log_steps".into(), LOG_STEPS.to_le_bytes().to_vec());

    let mut kernel = mt_merge::kernel_ir_for(dt.to_dtype());
    kernel.mode = KernelMode::Grid3D;

    let result = ctx
        .dispatch_with_grid(
            &kernel,
            &buffers,
            &BTreeMap::new(),
            [n.div_ceil(TPG), 1, 1],
            [TPG, 1, 1],
        )
        .expect("merge dispatch");

    let mut out = unpack_bytes(result.outputs.get("out").expect("out"), dt);
    out.truncate(inp.len());
    out
}
/// Full multi-block sort: one `run_sort` dispatch over every 1024-element block, followed by
/// `run_merge_pass` calls with run lengths `BLOCK`, `2 * BLOCK`, ... for as long as the run
/// length is smaller than `inp.len()`.
///
/// `n` is forwarded to every merge pass, and the final result is truncated to `n` elements.
///
/// # Panics
/// Panics if `inp.len()` is not a multiple of `BLOCK`, if the context cannot be created, or if
/// any of the underlying dispatches panics.
fn run_multiblock_sort(inp: &[f32], dt: Dt, n: usize) -> Vec<f32> {
    let padded_len = inp.len();
    assert_eq!(padded_len % BLOCK, 0, "input must be a whole number of 1024-element blocks");

    let ctx = Context::new().expect("Context::new on macOS");
    let block_sorted = run_sort(inp, dt, padded_len / BLOCK);

    // Each pass consumes the previous pass's output; run lengths double from BLOCK upwards.
    let mut merged = std::iter::successors(Some(BLOCK), |&width| width.checked_mul(2))
        .take_while(|&width| width < padded_len)
        .fold(block_sorted, |runs, width| run_merge_pass(&ctx, &runs, dt, n, width));

    merged.truncate(n);
    merged
}
/// CPU reference for the multi-block tests: returns a copy of `v` sorted ascending under
/// `f32::total_cmp`, leaving `v` itself unchanged.
fn cpu_sort_full(v: &[f32]) -> Vec<f32> {
    let mut ordered: Vec<f32> = v.iter().copied().collect();
    ordered.sort_unstable_by(|a, b| a.total_cmp(b));
    ordered
}
/// Builds a strictly descending input spanning `n_blocks` blocks:
/// `n_blocks * BLOCK - 1, ..., 1, 0`, each converted to `f32`.
fn reverse_input(n_blocks: usize) -> Vec<f32> {
    let len = n_blocks * BLOCK;
    let mut values: Vec<f32> = (0..len).map(|i| i as f32).collect();
    values.reverse();
    values
}
/// Two descending blocks: one block-sort dispatch plus a single merge pass must reproduce the
/// CPU sort of the whole input.
#[test]
fn sort_two_blocks_merge_matches_cpu_f32() {
    let _g = gpu_lock();
    let inp = reverse_input(2);
    let expected = cpu_sort_full(&inp);
    let actual = run_multiblock_sort(&inp, Dt::F32, inp.len());
    // `zip` stops at the shorter side; pin the length so a truncated result cannot pass.
    assert_eq!(actual.len(), expected.len(), "2-block merge output length mismatch");
    assert!(actual.iter().any(|&v| v != 0.0), "merge output all zeros — empty kernel body?");
    for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
        assert!((e - a).abs() < 1e-6, "2-block merge mismatch at [{i}]: expected {e}, got {a}");
    }
}
/// Four descending blocks (two merge passes) must reproduce the CPU sort of the whole input.
#[test]
fn sort_four_blocks_merge_matches_cpu_f32() {
    let _g = gpu_lock();
    let inp = reverse_input(4);
    let expected = cpu_sort_full(&inp);
    let actual = run_multiblock_sort(&inp, Dt::F32, inp.len());
    // Without this, an empty or short result makes the `zip` loop below pass vacuously.
    assert_eq!(actual.len(), expected.len(), "4-block merge output length mismatch");
    for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
        assert!((e - a).abs() < 1e-6, "4-block merge mismatch at [{i}]: expected {e}, got {a}");
    }
}
/// Eight blocks of scrambled values (index times a large multiplier, reduced modulo 1_000_003)
/// must reproduce the CPU sort of the whole input after three merge passes.
#[test]
fn sort_eight_blocks_merge_matches_cpu_f32() {
    let _g = gpu_lock();
    let inp: Vec<f32> = (0..8 * BLOCK)
        .map(|i| ((i * 2_654_435_761usize) % 1_000_003) as f32 * 0.001 - 500.0)
        .collect();
    let expected = cpu_sort_full(&inp);
    let actual = run_multiblock_sort(&inp, Dt::F32, inp.len());
    // `zip` truncates to the shorter iterator; check the length explicitly.
    assert_eq!(actual.len(), expected.len(), "8-block merge output length mismatch");
    for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
        assert!((e - a).abs() < 1e-6, "8-block merge mismatch at [{i}]: expected {e}, got {a}");
    }
}
/// Property check: the four-block sort-and-merge output must be non-decreasing, independent of
/// any CPU reference.
#[test]
fn sort_multiblock_output_is_non_decreasing_f32() {
    let _g = gpu_lock();
    let inp = reverse_input(4);
    let actual = run_multiblock_sort(&inp, Dt::F32, inp.len());
    // `windows(2)` is empty for fewer than two elements, so the ordering check alone would
    // accept a missing result.
    assert_eq!(actual.len(), inp.len(), "multi-block sort output length mismatch");
    for w in actual.windows(2) {
        assert!(w[0] <= w[1], "multi-block sort output not non-decreasing: {:?}", w);
    }
}
/// Eight descending blocks sorted and merged with `Dt::F16` encoding, compared against the CPU
/// reference with a looser tolerance than the f32 tests.
#[test]
fn sort_eight_blocks_merge_f16() {
    let _g = gpu_lock();
    // Inputs go through `Dt::F16.round` first — presumably so they are exactly representable
    // in f16 (confirm against `common::Dt`).
    let inp: Vec<f32> =
        (0..8 * BLOCK).map(|i| Dt::F16.round(((8 * BLOCK - 1 - i) as f32) * 0.25)).collect();
    let expected = cpu_sort_full(&inp);
    let actual = run_multiblock_sort(&inp, Dt::F16, inp.len());
    // `zip` stops at the shorter side; pin the length so a truncated result cannot pass.
    assert_eq!(actual.len(), expected.len(), "8-block f16 merge output length mismatch");
    for (i, (e, a)) in expected.iter().zip(&actual).enumerate() {
        assert!((e - a).abs() < 1e-2, "8-block f16 merge mismatch at [{i}]: expected {e}, got {a}");
    }
}
/// Merges two identical pre-sorted runs made of long stretches of equal keys (64 copies of
/// each value) with a single `run_merge_pass` call.
///
/// Bare `f32` keys carry no identity, so stability itself is not observable here; the test
/// checks that the merged output is non-decreasing and that it is a permutation of the input
/// (nothing dropped or duplicated).
#[test]
fn sort_merge_stable_equal_keys_f32() {
    let _g = gpu_lock();
    let ctx = Context::new().expect("Context::new on macOS");
    const RUN: usize = 1024;
    let run_a: Vec<f32> = (0..RUN).map(|i| (i / 64) as f32).collect();
    let run_b: Vec<f32> = (0..RUN).map(|i| (i / 64) as f32).collect();
    let inp: Vec<f32> = run_a.iter().chain(&run_b).copied().collect();
    let merged = run_merge_pass(&ctx, &inp, Dt::F32, inp.len(), RUN);
    for w in merged.windows(2) {
        assert!(w[0] <= w[1], "stable merge not non-decreasing: {:?}", w);
    }
    // Neither vector is needed again in its original order, so sort them in place rather than
    // cloning. The `Vec` equality below also covers the length.
    let mut got = merged;
    let mut want = inp;
    got.sort_unstable_by(f32::total_cmp);
    want.sort_unstable_by(f32::total_cmp);
    assert_eq!(got, want, "merge changed the multiset — element dropped/duplicated");
}
/// Sorts a logical length of 2500 that is not a whole number of blocks: the input is padded up
/// to the next multiple of `BLOCK` with a large sentinel (`1.0e30`), `n = 2500` is passed as
/// the logical length, and the result is compared against the CPU sort of the unpadded data.
#[test]
fn sort_non_power_of_two_n_matches_cpu_f32() {
    let _g = gpu_lock();
    const N: usize = 2500;
    let pad = 1.0e30_f32;
    let padded_len = N.div_ceil(BLOCK) * BLOCK;

    let real: Vec<f32> = (0..N).map(|i| ((i * 7919 + 17) % 10_000) as f32 * 0.1 - 500.0).collect();
    // Real data first, then sentinel padding up to a whole number of blocks.
    let inp: Vec<f32> =
        real.iter().copied().chain(std::iter::repeat(pad)).take(padded_len).collect();

    let expected = cpu_sort_full(&real);
    let actual = run_multiblock_sort(&inp, Dt::F32, N);
    assert_eq!(actual.len(), N, "result truncated to logical n");

    expected.iter().zip(&actual).enumerate().for_each(|(i, (e, a))| {
        assert!((e - a).abs() < 1e-3, "non-pow2 sort mismatch at [{i}]: expected {e}, got {a}");
    });
}