#![cfg(feature = "cpu")]
use rlx_ir::{DType, Graph, Shape};
use rlx_opt::vmap::vmap;
use rlx_runtime::{Device, Session};
fn f64s_to_bytes(xs: &[f64]) -> Vec<u8> {
let mut o = Vec::with_capacity(xs.len() * 8);
for x in xs {
o.extend_from_slice(&x.to_le_bytes());
}
o
}
fn bytes_to_f64s(b: &[u8]) -> Vec<f64> {
b.chunks_exact(8)
.map(|c| f64::from_le_bytes(c.try_into().unwrap()))
.collect()
}
#[test]
fn vmap_diag_extract_batches_correctly() {
rlx_linalg::register();
let n = 3;
let batch = 2;
let mut g = Graph::new("vm_de");
let a = g.input("a", Shape::new(&[n, n], DType::F64));
let d = rlx_linalg::diag_extract(&mut g, a);
g.set_outputs(vec![d]);
let vg = vmap(&g, &["a"], batch);
let mut c = Session::new(Device::Cpu).compile(vg);
let a_data = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 0.0, 0.0, 0.0, 20.0, 0.0, 0.0, 0.0, 30.0,
];
let outs = c.run_typed(&[("a", &f64s_to_bytes(&a_data), DType::F64)]);
let d_got = bytes_to_f64s(&outs[0].0);
let want = [1.0, 5.0, 9.0, 10.0, 20.0, 30.0];
assert_eq!(d_got.len(), batch * n);
for i in 0..(batch * n) {
assert!(
(d_got[i] - want[i]).abs() < 1e-12,
"vmap[diag][{i}]={} want {}",
d_got[i],
want[i]
);
}
}
#[test]
fn vmap_diag_set_batches_correctly() {
rlx_linalg::register();
let n = 3;
let batch = 2;
let mut g = Graph::new("vm_ds");
let v = g.input("v", Shape::new(&[n], DType::F64));
let m = rlx_linalg::diag_set(&mut g, v);
g.set_outputs(vec![m]);
let vg = vmap(&g, &["v"], batch);
let mut c = Session::new(Device::Cpu).compile(vg);
let v_data = vec![
2.0, 3.0, 5.0, 7.0, 11.0, 13.0,
];
let outs = c.run_typed(&[("v", &f64s_to_bytes(&v_data), DType::F64)]);
let m_got = bytes_to_f64s(&outs[0].0);
assert_eq!(m_got.len(), batch * n * n);
let want = vec![
2.0, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 5.0, 7.0, 0.0, 0.0, 0.0, 11.0, 0.0, 0.0, 0.0, 13.0,
];
for i in 0..(batch * n * n) {
assert!(
(m_got[i] - want[i]).abs() < 1e-12,
"vmap[ds][{i}]={} want {}",
m_got[i],
want[i]
);
}
}
#[test]
fn vmap_trace_via_composition() {
rlx_linalg::register();
let n = 3;
let batch = 2;
let mut g = Graph::new("vm_tr");
let a = g.input("a", Shape::new(&[n, n], DType::F64));
let t = rlx_linalg::trace(&mut g, a);
g.set_outputs(vec![t]);
let vg = vmap(&g, &["a"], batch);
let mut c = Session::new(Device::Cpu).compile(vg);
let a_data = vec![
1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0,
1.0, 1.0, 1.0, 1.0, 4.0, 1.0, 1.0, 1.0, 6.0,
];
let outs = c.run_typed(&[("a", &f64s_to_bytes(&a_data), DType::F64)]);
let t_got = bytes_to_f64s(&outs[0].0);
assert!(
(t_got[0] - 15.0).abs() < 1e-12,
"vmap[trace][0]={}",
t_got[0]
);
assert!(
(t_got[1] - 11.0).abs() < 1e-12,
"vmap[trace][1]={}",
t_got[1]
);
}
#[test]
fn vmap_unrelated_op_without_rule_panics_clearly() {
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
rlx_linalg::register();
let n = 3;
let mut g = Graph::new("vm_logdet_panic");
let a = g.input("a", Shape::new(&[n, n], DType::F64));
let l = rlx_linalg::logdet(&mut g, a);
g.set_outputs(vec![l]);
let _ = vmap(&g, &["a"], 2);
}));
assert!(result.is_err(), "expected vmap to panic on logdet");
}