// mlx-native 0.9.0
//
// Pure-Rust Metal GPU compute library for MLX-compatible inference on Apple Silicon
// Documentation
//! Wave 5b.1 iter 4.5 (T1) — direct unit test for chunk_tri_solve_invert.
//!
//! Codex iter-4 audit (missed-test): the chunk_tri_solve_invert kernel had only
//! end-to-end coverage via test_chunk_gated_delta_rule_fwd.rs. A sign-flip
//! ((I - A)^-1 vs (I + A)^-1) or a stride bug could be masked by downstream
//! stages. This file asserts the kernel's semantics directly:
//!
//!     A_inv == numpy.linalg.inv(I + A_strict)
//!
//! at f32 atol=1e-5 (tight; both sides are f32, the kernel walks O(BT^2)
//! FMAs per column so round-off is bounded by roughly BT * eps * max(|L|);
//! on the well-conditioned fixture inputs the observed max error is ~1 ULP,
//! ~1.2e-7, well inside that bar).
//!
//! Reference fixture is generated by:
//! ```sh
//! python3 tests/fixtures/chunk_tri_solve_invert_reference.py
//! ```
//!
//! Test cases:
//!   1. `test_chunk_tri_solve_invert_random_matrices` — 4 random strict-lower
//!      [BT=64, BT=64] matrices vs numpy.linalg.inv(I+A) at 1e-5 atol.
//!   2. `test_chunk_tri_solve_invert_zero_input_yields_identity` — A_strict = 0
//!      must yield A_inv = I (degenerate case).
//!   3. `test_chunk_tri_solve_invert_near_singular_no_nan` — A_strict filled
//!      with ±0.5 so that I + A_strict is badly conditioned (it is unit
//!      lower-triangular, so never truly singular); the kernel must still
//!      produce no NaN/Inf (forward-substitution doesn't divide, so this is
//!      well-defined).

#![allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]

use std::fs;
use std::path::PathBuf;

use mlx_native::ops::chunk_gated_delta_rule_tri_solve_invert::{
    build_chunk_tri_solve_invert_params, dispatch_chunk_tri_solve_invert,
    ChunkTriSolveInvertParams, FIXED_BT,
};
use mlx_native::{DType, KernelRegistry, MlxBuffer, MlxDevice};

/// Batch size: number of independent batch entries per kernel launch.
const B: u32 = 4;
/// Sequence length. Kept equal to `BT` so every batch entry is exactly one
/// [BT, BT] tile, which the per-tile index math in the tests relies on.
const T: u32 = BT;
/// Number of heads (the tile index math below assumes a single head).
const H: u32 = 1;
/// Chunk width; must equal the kernel's `FIXED_BT` (asserted in `run_kernel`).
const BT: u32 = 64;

/// Directory holding the binary fixtures written by
/// `tests/fixtures/chunk_tri_solve_invert_reference.py`, resolved against this
/// crate's manifest directory at compile time.
fn fixture_dir() -> PathBuf {
    let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    dir.push("tests/fixtures");
    dir
}

/// Read the raw bytes of fixture `name` from the fixtures directory.
///
/// # Panics
/// Panics if the file cannot be read, with a hint to regenerate the fixtures
/// via the Python reference script.
fn read_bytes(name: &str) -> Vec<u8> {
    let path = fixture_dir().join(name);
    match fs::read(&path) {
        Ok(bytes) => bytes,
        Err(err) => panic!(
            "failed to read fixture {} — did you run \
             `python3 tests/fixtures/chunk_tri_solve_invert_reference.py`? ({})",
            path.display(),
            err
        ),
    }
}

/// Decode fixture `name` as a flat array of little-endian `f32` values.
///
/// # Panics
/// Panics if the fixture is missing, or if its size is not a whole number of
/// `f32` words.
fn read_f32(name: &str) -> Vec<f32> {
    const F32_BYTES: usize = std::mem::size_of::<f32>();
    let raw = read_bytes(name);
    assert!(raw.len() % F32_BYTES == 0, "f32 byte length not multiple of 4");
    raw.chunks_exact(F32_BYTES)
        .map(|word| f32::from_le_bytes([word[0], word[1], word[2], word[3]]))
        .collect()
}

/// Allocate an `F32` device buffer of shape `[data.len()]` and copy `data`
/// into it.
///
/// # Panics
/// Panics if allocation fails or the buffer cannot be viewed as `f32`.
fn upload_f32(device: &MlxDevice, data: &[f32]) -> MlxBuffer {
    let count = data.len();
    let byte_len = count * 4;
    let mut buffer = device
        .alloc_buffer(byte_len, DType::F32, vec![count])
        .expect("alloc f32");
    // Keep the mutable view in its own scope so it ends before `buffer` is
    // moved out to the caller.
    {
        let dst = buffer.as_mut_slice::<f32>().expect("mut");
        dst.copy_from_slice(data);
    }
    buffer
}

/// Run `chunk_tri_solve_invert` once, on a freshly created device.
///
/// `a_strict` is the flat f32 input in the `[B, T, H, BT]` layout the tests
/// index into. The output buffer is zero-filled before dispatch. The kernel's
/// `A_inv` result is copied back to the host and returned, with the same
/// element count.
///
/// # Panics
/// Panics if `BT` differs from the kernel's `FIXED_BT`, if `a_strict` has the
/// wrong length, or if any device, allocation, dispatch or commit step fails.
fn run_kernel(a_strict: &[f32]) -> Vec<f32> {
    assert_eq!(BT, FIXED_BT, "test must use kernel's FIXED_BT");
    let device = MlxDevice::new().expect("MlxDevice::new");
    let mut kernels = KernelRegistry::new();

    let n = (B * T * H * BT) as usize;
    assert_eq!(a_strict.len(), n, "input length mismatch");

    let input = upload_f32(&device, a_strict);
    let output = upload_f32(&device, &vec![0.0f32; n]);

    let shape = ChunkTriSolveInvertParams {
        b: B,
        t: T,
        h: H,
        bt: BT,
    };
    let shape_buf = build_chunk_tri_solve_invert_params(&device, shape).expect("params");

    let mut encoder = device.command_encoder().expect("enc");
    dispatch_chunk_tri_solve_invert(
        &mut encoder,
        &mut kernels,
        device.metal_device(),
        &input,
        &output,
        &shape_buf,
        shape,
    )
    .expect("dispatch");
    // Block until the GPU is done before reading the output back.
    encoder.commit_and_wait().expect("commit");

    output.as_slice::<f32>().expect("read A_inv").to_vec()
}

/// Compare the kernel against the numpy `inv(I + A_strict)` fixture at a tight
/// absolute tolerance.
///
/// Every element of both the GPU output and the reference must be finite.
#[test]
fn test_chunk_tri_solve_invert_random_matrices() {
    // Load the deterministic seed-0xC0FFEE inputs + numpy.linalg.inv reference.
    let a_strict = read_f32("chunk_tri_solve_invert_input_a_strict.bin");
    let a_inv_ref = read_f32("chunk_tri_solve_invert_a_inv_ref.bin");

    let elems = (B * T * H * BT) as usize;
    assert_eq!(a_strict.len(), elems, "A_strict length");
    assert_eq!(a_inv_ref.len(), elems, "A_inv ref length");

    let got = run_kernel(&a_strict);
    assert_eq!(got.len(), a_inv_ref.len(), "A_inv length mismatch");

    // Tight bar: f32 vs f32, well-conditioned inputs, observed max_err
    // ~1.19e-7 (~1 ULP). 1e-5 leaves 100x safety margin while still catching
    // any real numerical regression. Codex iter-4.5 audit recommended-revision.
    let atol: f32 = 1e-5;
    let mut max_err = 0.0f32;
    let mut max_err_pos = 0usize;
    for (i, (&g, &r)) in got.iter().zip(a_inv_ref.iter()).enumerate() {
        // Check finiteness of BOTH sides before diffing. A NaN `err` compares
        // false against `max_err`, so a NaN in the reference fixture would
        // otherwise slip through the max-error scan unnoticed.
        assert!(
            r.is_finite(),
            "reference A_inv[{}] is non-finite: {} (corrupt fixture?)",
            i,
            r
        );
        assert!(g.is_finite(), "A_inv[{}] is non-finite: {}", i, g);
        let err = (g - r).abs();
        if err > max_err {
            max_err = err;
            max_err_pos = i;
        }
    }
    if max_err > atol {
        panic!(
            "chunk_tri_solve_invert: max_err {:.3e} > atol {:.0e} at idx {} \
             (gpu={} ref={})",
            max_err, atol, max_err_pos, got[max_err_pos], a_inv_ref[max_err_pos]
        );
    }
    eprintln!(
        "chunk_tri_solve_invert random OK   max_err={:.3e} (atol={:.0e}, B={}, BT={})",
        max_err, atol, B, BT
    );
}

/// Degenerate case: with `A_strict = 0`, `(I + 0)^-1 = I`, so every [BT, BT]
/// tile of the output must be the identity matrix.
#[test]
fn test_chunk_tri_solve_invert_zero_input_yields_identity() {
    let (bt, seq, heads) = (BT as usize, T as usize, H as usize);
    let zeros = vec![0.0f32; (B * T * H * BT) as usize];

    let got = run_kernel(&zeros);

    // With T == BT and H == 1 there is exactly one tile per batch entry; row
    // `i` of tile `b` starts at flat offset ((b*T + i) * H + 0) * BT.
    for b in 0..B as usize {
        for i in 0..bt {
            let start = (b * seq + i) * heads * bt;
            let row = &got[start..start + bt];
            for (j, &v) in row.iter().enumerate() {
                let expected: f32 = if j == i { 1.0 } else { 0.0 };
                assert!(
                    (v - expected).abs() < 1e-6,
                    "zero-input: A_inv[b={}, i={}, j={}] = {} != {}",
                    b,
                    i,
                    j,
                    v,
                    expected
                );
            }
        }
    }
    eprintln!("chunk_tri_solve_invert zero-input OK   A_inv == I");
}

/// Stress the kernel with a badly conditioned input. Check that the output is
/// finite and has the exact structure of `(I + A_strict)^-1`.
#[test]
fn test_chunk_tri_solve_invert_near_singular_no_nan() {
    // I + A_strict is unit lower-triangular, so its determinant (and every
    // eigenvalue) is exactly 1: it is never actually singular. Large
    // strict-lower entries do make it badly conditioned, though, and entries
    // of (I + A_strict)^-1 can grow quickly with row index. The
    // forward-substitution algorithm doesn't DIVIDE — it walks rows
    // accumulating products — so the result should always be finite. This
    // catches a NaN-producing algorithm change (e.g., accidental introduction
    // of division).
    let elems = (B * T * H * BT) as usize;
    let (seq, heads, bt) = (T as usize, H as usize, BT as usize);
    // Flat offset of element [b, row, h = 0, col]. With T == BT and H == 1,
    // each batch entry is exactly one [BT, BT] tile.
    let idx = |b: usize, row: usize, col: usize| (b * seq + row) * heads * bt + col;

    let mut a_strict = vec![0.0f32; elems];
    // ±0.5 checkerboard (sign set by the parity of i + j) on the strict-lower
    // triangle of every tile.
    for b in 0..B as usize {
        for i in 0..bt {
            for j in 0..i {
                a_strict[idx(b, i, j)] = 0.5f32 * (((i + j) % 2) as f32 * 2.0 - 1.0);
            }
        }
    }

    let got = run_kernel(&a_strict);

    let mut max_abs = 0.0f32;
    for (i, &v) in got.iter().enumerate() {
        assert!(v.is_finite(), "near-singular: A_inv[{}] = {} (non-finite)", i, v);
        max_abs = max_abs.max(v.abs());
    }

    // Finiteness alone is too weak: run_kernel zero-fills the output buffer,
    // so a kernel that wrote nothing would pass. The inverse X of a unit
    // lower-triangular matrix is itself unit lower-triangular. From
    // (I + A) X = I, its first sub-diagonal is exactly X[i][i-1] = -A[i][i-1].
    // That last identity also tells (I + A)^-1 apart from (I - A)^-1.
    for b in 0..B as usize {
        for i in 0..bt {
            let diag = got[idx(b, i, i)];
            assert!(
                (diag - 1.0).abs() < 1e-6,
                "near-singular: A_inv[b={}, i={}, j={}] = {} != 1 (unit diagonal)",
                b,
                i,
                i,
                diag
            );
            if i > 0 {
                let sub = got[idx(b, i, i - 1)];
                let want = -a_strict[idx(b, i, i - 1)];
                assert!(
                    (sub - want).abs() < 1e-5,
                    "near-singular: A_inv[b={}, i={}, j={}] = {} != {} (= -A[i][i-1])",
                    b,
                    i,
                    i - 1,
                    sub,
                    want
                );
            }
            for j in (i + 1)..bt {
                let v = got[idx(b, i, j)];
                assert!(
                    v.abs() < 1e-6,
                    "near-singular: A_inv[b={}, i={}, j={}] = {} != 0 (strict upper)",
                    b,
                    i,
                    j,
                    v
                );
            }
        }
    }
    eprintln!(
        "chunk_tri_solve_invert near-singular OK   no NaN/Inf, max|A_inv| = {:.3e}",
        max_abs
    );
}