use std::collections::HashMap;
use std::path::Path;
use deepmd::descriptor_repflows::{build_repflows, RepflowsConfig, RepflowsInputs};
use rlx_ir::{DType, Graph, Shape};
use rlx_runtime::{Device, Session};
use serde::Deserialize;
/// A single named tensor as stored in the JSON fixture.
///
/// `shape` and `data` are kept as separate fields; `data` holds the element
/// values as `f32` (assumed flattened — layout not checked here).
#[derive(Deserialize)]
struct Tensor {
    /// Key used to look the tensor up among the fixture's inputs/outputs.
    name: String,
    /// Declared shape from the fixture.
    // Deserialized so the schema is complete, but not read by the test yet.
    #[allow(dead_code)]
    shape: Vec<usize>,
    /// Element values.
    data: Vec<f32>,
}
/// Descriptor hyper-parameters and problem sizes read from the fixture's
/// `config` object. Field names must match the JSON keys.
#[derive(Deserialize)]
struct Config {
    // Edge (neighbor) cutoff settings, forwarded to `RepflowsConfig`.
    rcut: f64,
    rcut_smth: f64,
    // Angle cutoff settings, forwarded to `RepflowsConfig`.
    a_rcut: f64,
    a_rcut_smth: f64,
    // Selection sizes and number of atom types.
    nsel: usize,
    a_sel: usize,
    ntypes: usize,
    // Layer count and embedding widths.
    nlayers: usize,
    n_dim: usize,
    e_dim: usize,
    a_dim: usize,
    // Activation name and smoothing flag, forwarded verbatim.
    activation_function: String,
    smooth: bool,
    // Problem sizes used to shape the graph inputs.
    nf: usize,
    nloc: usize,
    nall: usize,
    // Present in the fixture but unused: the test derives the neighbor count
    // from `nsel` instead, hence the lint allowance.
    #[allow(dead_code)]
    nnei: usize,
}
/// Top-level layout of the parity fixture JSON file.
#[derive(Deserialize)]
struct Fixture {
    /// Descriptor configuration and problem sizes.
    config: Config,
    /// Model parameters keyed by parameter name.
    params: HashMap<String, Tensor>,
    /// Named input tensors, looked up by `Tensor::name`.
    inputs: Vec<Tensor>,
    /// Named reference outputs, looked up by `Tensor::name`.
    outputs: Vec<Tensor>,
}
/// Returns the first tensor in `list` whose `name` equals `name`.
///
/// # Panics
///
/// Panics if no tensor has that name. The message reports both the requested
/// name and every name that is present, so fixture/key mismatches are easy to
/// diagnose. `#[track_caller]` attributes the panic to the call site.
#[track_caller]
fn find_named<'a>(list: &'a [Tensor], name: &str) -> &'a Tensor {
    match list.iter().find(|t| t.name == name) {
        Some(t) => t,
        None => {
            let available: Vec<&str> = list.iter().map(|t| t.name.as_str()).collect();
            panic!("missing tensor {name:?}; available: {available:?}");
        }
    }
}
/// Asserts element-wise closeness: `|actual[i] - expected[i]| <= atol + rtol * |expected[i]|`.
///
/// Panics on a length mismatch or at the first element outside tolerance
/// (NaN differences never satisfy the bound). On success, prints the largest
/// absolute difference seen to stderr.
fn assert_allclose(actual: &[f32], expected: &[f32], atol: f32, rtol: f32, label: &str) {
    assert_eq!(actual.len(), expected.len(), "{label}: length mismatch");
    let max_abs = actual
        .iter()
        .zip(expected)
        .enumerate()
        .fold(0f32, |worst, (i, (&a, &b))| {
            let abs = (a - b).abs();
            let bound = atol + rtol * b.abs();
            assert!(abs <= bound, "{label}[{i}]: actual={a:.6e} expected={b:.6e} |Δ|={abs:.3e}");
            worst.max(abs)
        });
    eprintln!("{label}: max_abs={max_abs:.3e}");
}
/// Parity test for the repflows descriptor against the fixture
/// `tests/fixtures/repflows_nlayers0.json`.
///
/// Builds the graph from the fixture config, loads the fixture parameters,
/// feeds the fixture inputs on the CPU backend, and compares `node_ebd`,
/// `edge_ebd` and `h2` against the stored reference outputs (presumably
/// produced by the Python implementation, per the test name).
#[test]
fn repflows_nlayers0_matches_python_reference() {
    // Element-wise tolerances handed to `assert_allclose`.
    const ATOL: f32 = 1e-5;
    const RTOL: f32 = 1e-4;

    // The fixture lives two directories above this crate's manifest dir.
    let path = Path::new(env!("CARGO_MANIFEST_DIR"))
        .parent()
        .and_then(Path::parent)
        .expect("CARGO_MANIFEST_DIR should have two ancestor directories")
        .join("tests/fixtures/repflows_nlayers0.json");
    let json = std::fs::read_to_string(&path)
        .unwrap_or_else(|e| panic!("cannot read fixture {}: {e}", path.display()));
    let fx: Fixture = serde_json::from_str(&json)
        .unwrap_or_else(|e| panic!("cannot parse fixture {}: {e}", path.display()));

    let cfg = RepflowsConfig {
        rcut: fx.config.rcut,
        rcut_smth: fx.config.rcut_smth,
        nsel: fx.config.nsel,
        a_rcut: fx.config.a_rcut,
        a_rcut_smth: fx.config.a_rcut_smth,
        a_sel: fx.config.a_sel,
        ntypes: fx.config.ntypes,
        nlayers: fx.config.nlayers,
        n_dim: fx.config.n_dim,
        e_dim: fx.config.e_dim,
        a_dim: fx.config.a_dim,
        activation_function: fx.config.activation_function.clone(),
        smooth: fx.config.smooth,
        ln_eps: 1e-5,
        axis_neuron: 2,
        update_angle: true,
    };
    let nf = fx.config.nf;
    let nloc = fx.config.nloc;
    let nall = fx.config.nall;
    let nnei = cfg.nsel;
    let a_sel = cfg.a_sel;

    // Declare the graph inputs and build the descriptor on top of them.
    let mut g = Graph::new("repflows_parity");
    let env_mat = g.input("env_mat", Shape::new(&[nf, nloc, nnei, 4], DType::F32));
    let atype = g.input("atype", Shape::new(&[nf, nloc], DType::I32));
    let nlist = g.input("nlist", Shape::new(&[nf, nloc, nnei], DType::I32));
    let nlist_mask = g.input("nlist_mask", Shape::new(&[nf, nloc, nnei], DType::F32));
    let sw = g.input("sw", Shape::new(&[nf, nloc, nnei], DType::F32));
    let angle_input = g.input("angle_input", Shape::new(&[nf, nloc, a_sel, a_sel, 1], DType::F32));
    let angle_mask = g.input("angle_mask", Shape::new(&[nf, nloc, a_sel, a_sel], DType::F32));
    let node_ext = g.input("node_ext", Shape::new(&[nf, nall, cfg.n_dim], DType::F32));
    let out = build_repflows(
        &mut g,
        &cfg,
        RepflowsInputs {
            node_ebd_ext: node_ext,
            env_mat_raw: env_mat,
            nlist,
            nlist_mask,
            sw,
            angle_input,
            angle_mask,
            a_sw: None,
            atype_loc: atype,
        },
        nf,
        nloc,
        nall,
    )
    .expect("build");
    g.set_outputs(vec![out.node_ebd, out.edge_ebd, out.h2]);

    let session = Session::new(Device::Cpu);
    let mut compiled = session.compile(g);
    // Parameters are stored as f32 and uploaded as little-endian bytes.
    for (name, t) in &fx.params {
        let bytes: Vec<u8> = t.data.iter().flat_map(|v| v.to_le_bytes()).collect();
        compiled.set_param_typed(name, &bytes, DType::F32);
    }

    let em_raw = find_named(&fx.inputs, "env_mat_raw");
    let atype_ext = find_named(&fx.inputs, "atype_ext");
    let nlist_in = find_named(&fx.inputs, "nlist");
    let nlist_mask_in = find_named(&fx.inputs, "nlist_mask");
    let sw_in = find_named(&fx.inputs, "sw");
    let node_ext_in = find_named(&fx.inputs, "node_ebd_ext");

    // Assumes `atype_ext` is laid out per frame as [nf, nall] (like `node_ext`
    // is [nf, nall, n_dim]), with the local atoms first in each frame. The local
    // types are therefore the first `nloc` entries of *every* frame. Taking the
    // first nf * nloc values overall would only be correct for nf == 1.
    assert!(nloc <= nall, "nloc ({nloc}) must not exceed nall ({nall})");
    assert_eq!(
        atype_ext.data.len(),
        nf * nall,
        "atype_ext: expected nf * nall = {} entries",
        nf * nall
    );
    let atype_loc_f32: Vec<f32> = atype_ext
        .data
        .chunks_exact(nall)
        .flat_map(|frame| frame[..nloc].iter().copied())
        .collect();

    // Both angle inputs are fed as zeros. They have the same element count,
    // so one buffer serves both.
    let zero_angle = vec![0f32; nf * nloc * a_sel * a_sel];

    // NOTE(review): `atype` and `nlist` are declared I32 above but are fed f32
    // data here; presumably `run` converts per the declared dtype — confirm.
    let outputs = compiled.run(&[
        ("env_mat", em_raw.data.as_slice()),
        ("atype", atype_loc_f32.as_slice()),
        ("nlist", nlist_in.data.as_slice()),
        ("nlist_mask", nlist_mask_in.data.as_slice()),
        ("sw", sw_in.data.as_slice()),
        ("angle_input", zero_angle.as_slice()),
        ("angle_mask", zero_angle.as_slice()),
        ("node_ext", node_ext_in.data.as_slice()),
    ]);

    // Order must match the `set_outputs` call above.
    for (i, name) in ["node_ebd", "edge_ebd", "h2"].iter().enumerate() {
        let expected = find_named(&fx.outputs, name);
        assert_allclose(&outputs[i], &expected.data, ATOL, RTOL, name);
    }
}