#![allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
#![cfg(target_vendor = "apple")]
use mlx_native::ops::argsort;
use mlx_native::{DType, KernelRegistry, MlxDevice};
fn pseudo_random_f32(seed: u64, n: usize) -> Vec<f32> {
let mut state = seed;
(0..n)
.map(|_| {
state = state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
((state >> 33) as f32) / (u32::MAX as f32) - 0.5
})
.collect()
}
fn setup() -> (MlxDevice, KernelRegistry) {
let device = MlxDevice::new().expect("MlxDevice::new");
let mut registry = KernelRegistry::new();
argsort::register(&mut registry);
(device, registry)
}
fn cpu_argsort_desc(data: &[f32]) -> Vec<u32> {
let mut indices: Vec<u32> = (0..data.len() as u32).collect();
indices.sort_unstable_by(|&a, &b| {
data[b as usize]
.partial_cmp(&data[a as usize])
.unwrap_or(std::cmp::Ordering::Equal)
});
indices
}
#[test]
fn test_argsort_random_128() {
let (device, mut registry) = setup();
let row_len: u32 = 128;
let batch_size: u32 = 1;
let data = pseudo_random_f32(42, row_len as usize);
let expected = cpu_argsort_desc(&data);
let byte_len = row_len as usize * 4;
let mut input_buf = device
.alloc_buffer(byte_len, DType::F32, vec![row_len as usize])
.expect("alloc input");
input_buf
.as_mut_slice::<f32>()
.expect("write input")
.copy_from_slice(&data);
let output_buf = device
.alloc_buffer(byte_len, DType::U32, vec![row_len as usize])
.expect("alloc output");
let mut encoder = device.command_encoder().expect("encoder");
argsort::dispatch_argsort_desc_f32(
&mut encoder,
&mut registry,
device.metal_device(),
&input_buf,
&output_buf,
batch_size,
row_len,
)
.expect("dispatch_argsort_desc_f32");
encoder.commit_and_wait().expect("commit_and_wait");
let actual = output_buf.as_slice::<u32>().expect("read output");
for i in 0..row_len as usize - 1 {
assert!(
data[actual[i] as usize] >= data[actual[i + 1] as usize],
"argsort descending violated at position {}: val[{}]={} < val[{}]={}",
i,
actual[i],
data[actual[i] as usize],
actual[i + 1],
data[actual[i + 1] as usize]
);
}
for i in 0..row_len as usize {
assert_eq!(
actual[i], expected[i],
"argsort index mismatch at position {}: GPU={}, CPU={}",
i, actual[i], expected[i]
);
}
}
#[test]
fn test_argsort_batch_4x8() {
let (device, mut registry) = setup();
let row_len: u32 = 8;
let batch_size: u32 = 4;
let total = batch_size as usize * row_len as usize;
let data = pseudo_random_f32(123, total);
let byte_len = total * 4;
let mut input_buf = device
.alloc_buffer(byte_len, DType::F32, vec![batch_size as usize, row_len as usize])
.expect("alloc input");
input_buf
.as_mut_slice::<f32>()
.expect("write input")
.copy_from_slice(&data);
let output_buf = device
.alloc_buffer(byte_len, DType::U32, vec![batch_size as usize, row_len as usize])
.expect("alloc output");
let mut encoder = device.command_encoder().expect("encoder");
argsort::dispatch_argsort_desc_f32(
&mut encoder,
&mut registry,
device.metal_device(),
&input_buf,
&output_buf,
batch_size,
row_len,
)
.expect("dispatch_argsort_desc_f32");
encoder.commit_and_wait().expect("commit_and_wait");
let actual = output_buf.as_slice::<u32>().expect("read output");
for row in 0..batch_size as usize {
let row_data = &data[row * row_len as usize..(row + 1) * row_len as usize];
let row_indices = &actual[row * row_len as usize..(row + 1) * row_len as usize];
let _expected = cpu_argsort_desc(row_data);
for i in 0..row_len as usize - 1 {
assert!(
row_data[row_indices[i] as usize] >= row_data[row_indices[i + 1] as usize],
"batch row {}: descending order violated at {}: val[{}]={} < val[{}]={}",
row,
i,
row_indices[i],
row_data[row_indices[i] as usize],
row_indices[i + 1],
row_data[row_indices[i + 1] as usize]
);
}
let mut seen = vec![false; row_len as usize];
for &idx in row_indices {
assert!(
(idx as usize) < row_len as usize,
"batch row {}: index {} out of range",
row, idx
);
seen[idx as usize] = true;
}
assert!(
seen.iter().all(|&s| s),
"batch row {}: indices are not a valid permutation",
row
);
}
}
#[test]
fn test_argsort_already_sorted() {
let (device, mut registry) = setup();
let row_len: u32 = 8;
let batch_size: u32 = 1;
let data: Vec<f32> = (0..row_len).rev().map(|x| x as f32).collect();
let byte_len = row_len as usize * 4;
let mut input_buf = device
.alloc_buffer(byte_len, DType::F32, vec![row_len as usize])
.expect("alloc input");
input_buf
.as_mut_slice::<f32>()
.expect("write input")
.copy_from_slice(&data);
let output_buf = device
.alloc_buffer(byte_len, DType::U32, vec![row_len as usize])
.expect("alloc output");
let mut encoder = device.command_encoder().expect("encoder");
argsort::dispatch_argsort_desc_f32(
&mut encoder,
&mut registry,
device.metal_device(),
&input_buf,
&output_buf,
batch_size,
row_len,
)
.expect("dispatch_argsort_desc_f32");
encoder.commit_and_wait().expect("commit_and_wait");
let actual = output_buf.as_slice::<u32>().expect("read output");
for i in 0..row_len as usize {
assert_eq!(
actual[i], i as u32,
"already-sorted: expected identity permutation at {}, got {}",
i, actual[i]
);
}
}