#![cfg(all(feature = "cpu", feature = "mlx", target_os = "macos"))]
use rlx_ir::{DType, Graph, NodeId, Op, Shape};
use rlx_runtime::{Device, Session};
use rlx_sparse::SparseTensor;
/// Decodes a little-endian byte buffer into `f64` values.
///
/// The buffer is read in 8-byte groups. Any trailing bytes that do not form a
/// complete group are ignored.
fn bytes_to_f64s(bytes: &[u8]) -> Vec<f64> {
    let mut decoded = Vec::with_capacity(bytes.len() / 8);
    for group in bytes.chunks_exact(8) {
        let mut word = [0u8; 8];
        word.copy_from_slice(group);
        decoded.push(f64::from_le_bytes(word));
    }
    decoded
}
/// Adds a 1-D `I32` constant node holding `xs` to `g` and returns its id.
///
/// Each element is serialized as 4 little-endian bytes, concatenated in order.
fn const_i32(g: &mut Graph, xs: &[i32]) -> NodeId {
    let data: Vec<u8> = xs.iter().flat_map(|x| x.to_le_bytes()).collect();
    let shape = Shape::new(&[xs.len()], DType::I32);
    g.add_node(Op::Constant { data }, vec![], shape)
}
/// Adds a 1-D `F64` constant node holding `xs` to `g` and returns its id.
///
/// Each element is serialized as 8 little-endian bytes, concatenated in order.
fn const_f64(g: &mut Graph, xs: &[f64]) -> NodeId {
    let data: Vec<u8> = xs.iter().flat_map(|x| x.to_le_bytes()).collect();
    let shape = Shape::new(&[xs.len()], DType::F64);
    g.add_node(Op::Constant { data }, vec![], shape)
}
/// Returns `(values, col_idx, row_ptr)` for a fixed 4x4 nonsymmetric
/// tridiagonal matrix in CSR form:
///
/// ```text
/// [  5 -1  0  0 ]
/// [ -2  4 -1  0 ]
/// [  0 -2  4 -1 ]
/// [  0  0 -2  3 ]
/// ```
fn build_nonsym_4() -> (Vec<f64>, Vec<i32>, Vec<i32>) {
    (
        // Stored entries, row by row.
        vec![5.0, -1.0, -2.0, 4.0, -1.0, -2.0, 4.0, -1.0, -2.0, 3.0],
        // Column of each stored entry.
        vec![0, 1, 0, 1, 2, 1, 2, 3, 2, 3],
        // Row r occupies entries row_ptr[r]..row_ptr[r + 1].
        vec![0, 2, 5, 8, 10],
    )
}
/// Transposes an `n x n` CSR matrix, returning `(values, col_idx, row_ptr)` of
/// the transpose, also in CSR form.
///
/// The algorithm is a counting sort with three steps:
/// 1. Count how many stored entries fall in each column.
/// 2. Prefix-sum those counts into the transposed row pointer.
/// 3. Scatter each entry into the next free slot of its column.
///
/// Source rows are visited in ascending order, so every transposed row ends up
/// with ascending column indices.
fn transpose_csr(
    values: &[f64],
    col_idx: &[i32],
    row_ptr: &[i32],
    n: usize,
) -> (Vec<f64>, Vec<i32>, Vec<i32>) {
    // Step 1: per-column histogram. Each column becomes one row of the transpose.
    let mut per_col = vec![0i32; n];
    col_idx.iter().for_each(|&col| per_col[col as usize] += 1);

    // Step 2: exclusive prefix sum gives the transpose's row pointer.
    let mut out_row_ptr = Vec::with_capacity(n + 1);
    out_row_ptr.push(0i32);
    let mut running = 0i32;
    for &cnt in &per_col {
        running += cnt;
        out_row_ptr.push(running);
    }

    // Step 3: scatter. next_slot[c] is where column c's next entry goes.
    let nnz = values.len();
    let mut out_col_idx = vec![0i32; nnz];
    let mut out_values = vec![0f64; nnz];
    let mut next_slot = out_row_ptr.clone();
    for row in 0..n {
        let (start, end) = (row_ptr[row] as usize, row_ptr[row + 1] as usize);
        for k in start..end {
            let col = col_idx[k] as usize;
            let dst = next_slot[col] as usize;
            out_col_idx[dst] = row as i32;
            out_values[dst] = values[k];
            next_slot[col] += 1;
        }
    }

    (out_values, out_col_idx, out_row_ptr)
}
/// Solves `A x = b` with `SparseTensor::solve_general` on the CPU and MLX
/// backends for the 4x4 nonsymmetric test matrix, and checks three things:
/// - both backends return exactly `n` values;
/// - the MLX result is exactly equal to the CPU result;
/// - the CPU result satisfies `A x ≈ b`, so the test cannot pass if both
///   backends agree on a wrong answer.
#[test]
fn lu_general_runs_on_mlx_and_matches_cpu() {
    // Computes y = A x for a CSR matrix. Used only to check the residual.
    fn csr_matvec(values: &[f64], col_idx: &[i32], row_ptr: &[i32], x: &[f64]) -> Vec<f64> {
        row_ptr
            .windows(2)
            .map(|w| {
                (w[0] as usize..w[1] as usize)
                    .map(|k| values[k] * x[col_idx[k] as usize])
                    .sum::<f64>()
            })
            .collect()
    }

    rlx_sparse::register();
    let (values, col_idx, row_ptr) = build_nonsym_4();
    let n = 4;
    // The solver also receives A's transpose as a second CSR operand.
    let (vt, cit, rpt) = transpose_csr(&values, &col_idx, &row_ptr, n);
    let b_data = [1.0_f64, 2.5, -1.0, 3.0];
    // Builds a fresh, identical graph for each session.
    let build = || {
        let mut g = Graph::new("lu_general");
        let v = const_f64(&mut g, &values);
        let ci = const_i32(&mut g, &col_idx);
        let rp = const_i32(&mut g, &row_ptr);
        let vt_n = const_f64(&mut g, &vt);
        let cit_n = const_i32(&mut g, &cit);
        let rpt_n = const_i32(&mut g, &rpt);
        let b = const_f64(&mut g, &b_data);
        let a = SparseTensor::from_csr(v, ci, rp, n, n);
        let at = SparseTensor::from_csr(vt_n, cit_n, rpt_n, n, n);
        let x = a.solve_general(&mut g, b, &at);
        g.set_outputs(vec![x]);
        g
    };
    let mut cpu = Session::new(Device::Cpu).compile(build());
    let mut mlx = Session::new(Device::Mlx).compile(build());
    let cpu_x = bytes_to_f64s(&cpu.run_typed(&[])[0].0);
    let mlx_x = bytes_to_f64s(&mlx.run_typed(&[])[0].0);

    // Check lengths first so a short output fails clearly instead of panicking on an index.
    assert_eq!(cpu_x.len(), n, "lu_general: cpu output length");
    assert_eq!(mlx_x.len(), n, "lu_general: mlx output length");

    for (i, (&c, &m)) in cpu_x.iter().zip(&mlx_x).enumerate() {
        assert_eq!(c, m, "lu_general[{i}]: cpu={c} mlx={m}");
    }

    // The matrix is small and diagonally dominant, so a direct solve should
    // reach a residual near machine precision.
    let ax = csr_matvec(&values, &col_idx, &row_ptr, &cpu_x);
    for (i, (&got, &want)) in ax.iter().zip(&b_data).enumerate() {
        assert!(
            (got - want).abs() < 1e-10,
            "lu_general residual[{i}]: (A x)={got} b={want}"
        );
    }
}
/// Solves `A x = b` with `SparseTensor::gmres_solve` on the CPU and MLX
/// backends for the 4x4 nonsymmetric test matrix, and checks three things:
/// - both backends return exactly `n` values;
/// - the MLX result is exactly equal to the CPU result;
/// - the CPU result satisfies `A x ≈ b`, so the test cannot pass if both
///   backends agree on a wrong answer.
#[test]
fn gmres_runs_on_mlx_and_matches_cpu() {
    // Computes y = A x for a CSR matrix. Used only to check the residual.
    fn csr_matvec(values: &[f64], col_idx: &[i32], row_ptr: &[i32], x: &[f64]) -> Vec<f64> {
        row_ptr
            .windows(2)
            .map(|w| {
                (w[0] as usize..w[1] as usize)
                    .map(|k| values[k] * x[col_idx[k] as usize])
                    .sum::<f64>()
            })
            .collect()
    }

    rlx_sparse::register();
    let (values, col_idx, row_ptr) = build_nonsym_4();
    let n = 4;
    // The solver also receives A's transpose as a second CSR operand.
    let (vt, cit, rpt) = transpose_csr(&values, &col_idx, &row_ptr, n);
    let b_data = [1.0_f64, 2.5, -1.0, 3.0];
    // Builds a fresh, identical graph for each session.
    let build = || {
        let mut g = Graph::new("gmres");
        let v = const_f64(&mut g, &values);
        let ci = const_i32(&mut g, &col_idx);
        let rp = const_i32(&mut g, &row_ptr);
        let vt_n = const_f64(&mut g, &vt);
        let cit_n = const_i32(&mut g, &cit);
        let rpt_n = const_i32(&mut g, &rpt);
        let b = const_f64(&mut g, &b_data);
        let a = SparseTensor::from_csr(v, ci, rp, n, n);
        let at = SparseTensor::from_csr(vt_n, cit_n, rpt_n, n, n);
        // 100 = iteration cap, 1e-12 = convergence tolerance (as passed to gmres_solve).
        let x = a.gmres_solve(&mut g, b, 100, 1e-12, &at);
        g.set_outputs(vec![x]);
        g
    };
    let mut cpu = Session::new(Device::Cpu).compile(build());
    let mut mlx = Session::new(Device::Mlx).compile(build());
    let cpu_x = bytes_to_f64s(&cpu.run_typed(&[])[0].0);
    let mlx_x = bytes_to_f64s(&mlx.run_typed(&[])[0].0);

    // Check lengths first so a short output fails clearly instead of panicking on an index.
    assert_eq!(cpu_x.len(), n, "gmres: cpu output length");
    assert_eq!(mlx_x.len(), n, "gmres: mlx output length");

    for (i, (&c, &m)) in cpu_x.iter().zip(&mlx_x).enumerate() {
        assert_eq!(c, m, "gmres[{i}]: cpu={c} mlx={m}");
    }

    // The residual bound is deliberately looser than the solver tolerance.
    // NOTE(review): this absorbs whichever norm (absolute or relative) gmres_solve
    // uses for its stopping test. Tighten it if that is confirmed.
    let ax = csr_matvec(&values, &col_idx, &row_ptr, &cpu_x);
    for (i, (&got, &want)) in ax.iter().zip(&b_data).enumerate() {
        assert!(
            (got - want).abs() < 1e-8,
            "gmres residual[{i}]: (A x)={got} b={want}"
        );
    }
}
/// Runs the `SPARSE_VALUES_GRAD` custom op directly on the CPU and MLX backends.
///
/// It checks that both produce one output per stored entry, that the two
/// backends agree exactly, and that each CPU output equals `u[row] * v[col]`
/// for the entry's row and column.
#[test]
fn values_grad_runs_on_mlx_and_matches_cpu() {
    rlx_sparse::register();

    // Tridiagonal 4x4 sparsity pattern. The op needs only the structure, not the values.
    let col_idx = vec![0, 1, 0, 1, 2, 1, 2, 3, 2, 3];
    let row_ptr = vec![0, 2, 5, 8, 10];
    let u = [0.5_f64, -1.5, 2.0, -0.25];
    let v = [1.0_f64, 3.0, -0.5, 2.5];
    let n = 4;
    let nnz = col_idx.len();

    let build = || {
        let mut g = Graph::new("values_grad_direct");
        let ci = const_i32(&mut g, &col_idx);
        let rp = const_i32(&mut g, &row_ptr);
        let u_n = const_f64(&mut g, &u);
        let v_n = const_f64(&mut g, &v);
        let inputs = vec![ci, rp, u_n, v_n];
        let out = g.custom_op(rlx_sparse::SPARSE_VALUES_GRAD, Vec::new(), inputs);
        g.set_outputs(vec![out]);
        g
    };

    let mut cpu = Session::new(Device::Cpu).compile(build());
    let mut mlx = Session::new(Device::Mlx).compile(build());
    let cpu_out = bytes_to_f64s(&cpu.run_typed(&[])[0].0);
    let mlx_out = bytes_to_f64s(&mlx.run_typed(&[])[0].0);

    assert_eq!(cpu_out.len(), nnz);
    assert_eq!(mlx_out.len(), nnz);
    for (k, (&c, &m)) in cpu_out.iter().zip(mlx_out.iter()).enumerate() {
        assert_eq!(c, m, "values_grad[{k}]: cpu={} mlx={}", c, m);
    }

    // Map each stored entry index back to the row that owns it.
    let row_of: Vec<usize> = (0..n)
        .flat_map(|r| (row_ptr[r] as usize..row_ptr[r + 1] as usize).map(move |_| r))
        .collect();

    for (k, (&got, &col)) in cpu_out.iter().zip(col_idx.iter()).enumerate() {
        let want = u[row_of[k]] * v[col as usize];
        assert!(
            (got - want).abs() < 1e-12,
            "values_grad[{k}]: got {}, expected {}",
            got,
            want
        );
    }
}