#![allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
use std::fs;
use std::path::PathBuf;
use mlx_native::ops::chunk_gated_delta_rule_tri_solve_invert::{
build_chunk_tri_solve_invert_params, dispatch_chunk_tri_solve_invert,
ChunkTriSolveInvertParams, FIXED_BT,
};
use mlx_native::{DType, KernelRegistry, MlxBuffer, MlxDevice};
// Problem shape shared by every test below; fed to `ChunkTriSolveInvertParams { b, t, h, bt }`.
// Buffers hold `B * T * H * BT` f32 elements.
/// Value for the kernel's `b` parameter (presumably the batch dimension).
const B: u32 = 4;
/// Value for the kernel's `t` parameter (presumably the sequence length).
const T: u32 = 64;
/// Value for the kernel's `h` parameter. The hand-written index math in the
/// identity / near-singular tests only addresses the first head, so keep this at 1.
const H: u32 = 1;
/// Value for the kernel's `bt` parameter; `run_kernel` asserts it equals `FIXED_BT`.
const BT: u32 = 64;
/// Path to this crate's `tests/fixtures` directory, anchored at `CARGO_MANIFEST_DIR`
/// (resolved at compile time).
fn fixture_dir() -> PathBuf {
    let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    dir.push("tests/fixtures");
    dir
}
/// Reads the fixture file `name` from the fixtures directory into memory.
///
/// # Panics
/// Panics if the file cannot be read. The message hints at the Python script
/// that generates the fixtures.
fn read_bytes(name: &str) -> Vec<u8> {
    let fixture_path = fixture_dir().join(name);
    match fs::read(&fixture_path) {
        Ok(contents) => contents,
        Err(err) => panic!(
            "failed to read fixture {} — did you run \
             `python3 tests/fixtures/chunk_tri_solve_invert_reference.py`? ({})",
            fixture_path.display(),
            err
        ),
    }
}
/// Loads fixture `name` and decodes it as a packed array of little-endian `f32`s.
///
/// # Panics
/// Panics if the fixture cannot be read, or if its size is not a multiple of 4 bytes.
fn read_f32(name: &str) -> Vec<f32> {
    let raw = read_bytes(name);
    assert!(raw.len() % 4 == 0, "f32 byte length not multiple of 4");
    raw.chunks_exact(4)
        .map(|word| f32::from_le_bytes([word[0], word[1], word[2], word[3]]))
        .collect()
}
/// Allocates a 1-D `DType::F32` buffer on `device` with shape `[data.len()]`
/// and copies `data` into it.
///
/// # Panics
/// Panics if allocation fails or the buffer cannot be viewed as a mutable `f32` slice.
fn upload_f32(device: &MlxDevice, data: &[f32]) -> MlxBuffer {
    let len = data.len();
    let byte_len = len * std::mem::size_of::<f32>();
    let mut buffer = device
        .alloc_buffer(byte_len, DType::F32, vec![len])
        .expect("alloc f32");
    let dst = buffer.as_mut_slice::<f32>().expect("mut");
    dst.copy_from_slice(data);
    buffer
}
/// Runs `chunk_tri_solve_invert` once on a fresh device and returns the `A_inv` output.
///
/// `a_strict` must contain exactly `B * T * H * BT` f32 values. The output buffer is
/// zero-filled before dispatch and read back after the command buffer has completed.
///
/// # Panics
/// Panics if `BT != FIXED_BT`, if `a_strict` has the wrong length, or if any device,
/// allocation, dispatch, or commit step fails.
fn run_kernel(a_strict: &[f32]) -> Vec<f32> {
    assert_eq!(BT, FIXED_BT, "test must use kernel's FIXED_BT");
    let device = MlxDevice::new().expect("MlxDevice::new");
    let mut registry = KernelRegistry::new();
    // Compute in usize so larger test shapes cannot overflow a u32 product.
    let elems = B as usize * T as usize * H as usize * BT as usize;
    assert_eq!(a_strict.len(), elems, "input length mismatch");
    let a_strict_buf = upload_f32(&device, a_strict);
    // Zero-fill the output so any element the kernel does not write reads back as 0.
    let a_inv_buf = upload_f32(&device, &vec![0.0f32; elems]);
    let p = ChunkTriSolveInvertParams {
        b: B,
        t: T,
        h: H,
        bt: BT,
    };
    let params_buf = build_chunk_tri_solve_invert_params(&device, p).expect("params");
    let mut enc = device.command_encoder().expect("enc");
    dispatch_chunk_tri_solve_invert(
        &mut enc,
        &mut registry,
        device.metal_device(),
        &a_strict_buf,
        &a_inv_buf,
        &params_buf,
        p,
    )
    .expect("dispatch");
    // Block until the GPU has finished before reading the output back.
    enc.commit_and_wait().expect("commit");
    a_inv_buf.as_slice::<f32>().expect("read A_inv").to_vec()
}
/// Compares the GPU `A_inv` against the precomputed reference fixture.
/// Every output must be finite and the max abs error must be within `atol`.
#[test]
fn test_chunk_tri_solve_invert_random_matrices() {
    let elems = (B * T * H * BT) as usize;
    let a_strict = read_f32("chunk_tri_solve_invert_input_a_strict.bin");
    let a_inv_ref = read_f32("chunk_tri_solve_invert_a_inv_ref.bin");
    assert_eq!(a_strict.len(), elems, "A_strict length");
    assert_eq!(a_inv_ref.len(), elems, "A_inv ref length");

    let got = run_kernel(&a_strict);
    assert_eq!(got.len(), a_inv_ref.len(), "A_inv length mismatch");

    // Reject the first non-finite GPU value before looking at error magnitudes.
    if let Some(bad_idx) = got.iter().position(|v| !v.is_finite()) {
        panic!("A_inv[{}] is non-finite: {}", bad_idx, got[bad_idx]);
    }

    let atol: f32 = 1e-5;
    // Track the first index holding the largest absolute error (strict `>` keeps the earliest).
    let (max_err, max_err_pos) = got
        .iter()
        .zip(a_inv_ref.iter())
        .map(|(&g, &r)| (g - r).abs())
        .enumerate()
        .fold((0.0f32, 0usize), |(best, best_pos), (pos, err)| {
            if err > best {
                (err, pos)
            } else {
                (best, best_pos)
            }
        });

    assert!(
        max_err <= atol,
        "chunk_tri_solve_invert: max_err {:.3e} > atol {:.0e} at idx {} \
         (gpu={} ref={})",
        max_err,
        atol,
        max_err_pos,
        got[max_err_pos],
        a_inv_ref[max_err_pos]
    );

    eprintln!(
        "chunk_tri_solve_invert random OK max_err={:.3e} (atol={:.0e}, B={}, BT={})",
        max_err, atol, B, BT
    );
}
/// With an all-zero strictly-lower input, each of the first `BT` rows per batch
/// must read back as the identity matrix (within 1e-6).
#[test]
fn test_chunk_tri_solve_invert_zero_input_yields_identity() {
    let total = (B * T * H * BT) as usize;
    let got = run_kernel(&vec![0.0f32; total]);

    let (n_batch, seq_len, chunk) = (B as usize, T as usize, BT as usize);
    let row_stride = H as usize * chunk;
    for b in 0..n_batch {
        for i in 0..chunk {
            let row_base = (b * seq_len + i) * row_stride;
            for j in 0..chunk {
                let expected: f32 = if i == j { 1.0 } else { 0.0 };
                let v = got[row_base + j];
                assert!(
                    (v - expected).abs() < 1e-6,
                    "zero-input: A_inv[b={}, i={}, j={}] = {} != {}",
                    b,
                    i,
                    j,
                    v,
                    expected
                );
            }
        }
    }
    eprintln!("chunk_tri_solve_invert zero-input OK A_inv == I");
}
/// Fills the strictly-lower part of each row with alternating ±0.5 values and
/// checks that every element of the GPU output is finite.
#[test]
fn test_chunk_tri_solve_invert_near_singular_no_nan() {
    let n_elems = (B * T * H * BT) as usize;
    let row_stride = H as usize * BT as usize;
    let mut a_strict = vec![0.0f32; n_elems];
    for b in 0..B as usize {
        for i in 0..BT as usize {
            let row = (b * T as usize + i) * row_stride;
            // Columns j < i get +0.5 when (i + j) is odd and -0.5 when it is even.
            for (j, slot) in a_strict[row..row + i].iter_mut().enumerate() {
                *slot = if (i + j) % 2 == 1 { 0.5 } else { -0.5 };
            }
        }
    }

    let got = run_kernel(&a_strict);
    let max_abs = got.iter().enumerate().fold(0.0f32, |acc, (i, &v)| {
        assert!(v.is_finite(), "near-singular: A_inv[{}] = {} (non-finite)", i, v);
        acc.max(v.abs())
    });
    eprintln!(
        "chunk_tri_solve_invert near-singular OK no NaN/Inf, max|A_inv| = {:.3e}",
        max_abs
    );
}