#![cfg(rlx_mlx_host)]
use rlx_ir::{DType, Graph, Shape};
use rlx_mlx::{MlxExecutable, MlxMode};
/// Element-wise approximate equality for `f32` slices.
///
/// Returns `false` when the slices differ in length. Otherwise returns `true`
/// only if every pair of corresponding elements is within the absolute
/// tolerance `tol`. A `NaN` on either side never satisfies `<= tol`, so it
/// makes the comparison fail.
fn close(a: &[f32], b: &[f32], tol: f32) -> bool {
    if a.len() != b.len() {
        return false;
    }
    let within_tol = |(lhs, rhs): (&f32, &f32)| (lhs - rhs).abs() <= tol;
    a.iter().zip(b.iter()).all(within_tol)
}
/// Solves a 3x3 symmetric tridiagonal system with `Graph::dense_solve`.
/// The graph is compiled via `MlxExecutable` in `MlxMode::Lazy`. The test
/// checks the result against the hand-computed solution `x = [1, 2, 1]`.
#[test]
fn dense_solve_3x3_matches_reference() {
    // Flattened A = [[4, -1, 0], [-1, 4, -1], [0, -1, 4]]. A is symmetric,
    // so row-major and column-major readings of this buffer agree.
    let matrix: Vec<f32> = vec![
        4.0, -1.0, 0.0, //
        -1.0, 4.0, -1.0, //
        0.0, -1.0, 4.0,
    ];
    // b = A · [1, 2, 1].
    let rhs: Vec<f32> = vec![2.0, 6.0, 2.0];
    let reference = [1.0_f32, 2.0, 1.0];

    let mut graph = Graph::new("dense_solve_3x3");
    let a_node = graph.input("a", Shape::new(&[3, 3], DType::F32));
    let b_node = graph.input("b", Shape::new(&[3], DType::F32));
    let solution = graph.dense_solve(a_node, b_node, Shape::new(&[3], DType::F32));
    graph.set_outputs(vec![solution]);

    let mut exe = MlxExecutable::compile_with_mode(graph, MlxMode::Lazy);
    let outs = exe.run(&[("a", &matrix), ("b", &rhs)]);

    // Exactly one output was registered above.
    assert_eq!(outs.len(), 1);
    let got = &outs[0];
    assert!(
        close(got, &reference, 1e-5),
        "got {:?}, expected {:?}",
        got,
        reference,
    );
}
/// Solves two independent 2x2 systems in a single `batched_dense_solve` call.
/// The graph is compiled via `MlxExecutable` in `MlxMode::Lazy`. Both
/// solutions are compared against hand-computed values:
///
/// - batch 0: A = [[2, 0], [0, 2]], b = [4, 6] → x = [2, 3]
/// - batch 1: A = [[1, 1], [0, 1]], b = [3, 1] → x = [2, 1]
#[test]
fn batched_dense_solve_two_systems() {
    let mut g = Graph::new("batched_dense_solve_2x2");
    let a = g.input("a", Shape::new(&[2, 2, 2], DType::F32));
    let b = g.input("b", Shape::new(&[2, 2], DType::F32));
    let x = g.batched_dense_solve(a, b, Shape::new(&[2, 2], DType::F32));
    g.set_outputs(vec![x]);
    let mut exe = MlxExecutable::compile_with_mode(g, MlxMode::Lazy);

    // Batch-leading, row-major layout: one flattened 2x2 matrix per batch.
    // The expected solution for batch 1 only holds under the row-major
    // reading, so this test pins that layout.
    let a_data: Vec<f32> = vec![
        2.0, 0.0, 0.0, 2.0, // batch 0
        1.0, 1.0, 0.0, 1.0, // batch 1
    ];
    let b_data: Vec<f32> = vec![
        4.0, 6.0, // batch 0
        3.0, 1.0, // batch 1
    ];
    let outs = exe.run(&[("a", &a_data), ("b", &b_data)]);
    let expected = [2.0_f32, 3.0, 2.0, 1.0];

    // Check the output count before indexing, so a missing output is reported
    // as an assertion failure rather than an index-out-of-bounds panic.
    assert_eq!(outs.len(), 1);
    assert!(
        close(&outs[0], &expected, 1e-5),
        "got {:?}, expected {:?}",
        outs[0],
        expected,
    );
}