use std::collections::HashMap;
use std::path::Path;
use deepmd::descriptor_repformers::{build_repformers, RepformersConfig, RepformersInputs};
use rlx_ir::{DType, Graph, Shape};
use rlx_runtime::{Device, Session};
use serde::Deserialize;
/// One named tensor as stored in the JSON parity fixture.
///
/// `data` holds the tensor's values as a flat list (presumably row-major
/// w.r.t. `shape` — TODO confirm against the Python exporter).
#[derive(Deserialize)]
struct Tensor {
    name: String,
    // Deserialized to mirror the fixture schema, but not read by the test,
    // hence the `dead_code` allowance.
    #[allow(dead_code)]
    shape: Vec<usize>,
    data: Vec<f32>,
}
/// Hyperparameters and system sizes recorded alongside the Python reference
/// outputs in the parity fixture.
#[derive(Deserialize)]
struct Config {
    // Descriptor settings consumed when building `RepformersConfig`.
    rcut: f64,
    rcut_smth: f64,
    // Serialized as a list in the fixture.
    sel: Vec<usize>,
    ntypes: usize,
    nlayers: usize,
    g1_dim: usize,
    g2_dim: usize,
    axis_neuron: usize,
    activation_function: String,
    smooth: bool,
    use_sqrt_nnei: bool,
    // Frame / atom counts used to size the graph inputs.
    nf: usize,
    nloc: usize,
    nall: usize,
}
/// Top-level layout of the parity fixture JSON file.
#[derive(Deserialize)]
struct Fixture {
    /// Reference hyperparameters and system sizes.
    config: Config,
    /// Model parameters keyed by the graph parameter name they are bound to.
    params: HashMap<String, Tensor>,
    /// Named input tensors, looked up by `Tensor::name`.
    inputs: Vec<Tensor>,
    /// Named reference outputs, looked up by `Tensor::name`.
    outputs: Vec<Tensor>,
}
/// Returns the first tensor in `list` whose `name` equals `name`.
///
/// # Panics
/// If no tensor has that name. The message names the missing tensor and
/// lists the available ones so a fixture/schema mismatch is easy to spot.
fn find_named<'a>(list: &'a [Tensor], name: &str) -> &'a Tensor {
    list.iter().find(|t| t.name == name).unwrap_or_else(|| {
        let available: Vec<&str> = list.iter().map(|t| t.name.as_str()).collect();
        panic!("missing tensor `{name}` (available: {available:?})")
    })
}
/// Asserts that every element satisfies
/// `|actual - expected| <= atol + rtol * |expected|` (the numpy `allclose`
/// criterion), then prints the largest absolute deviation to stderr.
///
/// # Panics
/// If the slices differ in length, or on the first element outside
/// tolerance. A NaN in either input also fails, because `NaN <= tol` is false.
fn assert_allclose(actual: &[f32], expected: &[f32], atol: f32, rtol: f32, label: &str) {
    assert_eq!(actual.len(), expected.len(), "{label}: length mismatch");
    let max_abs = actual
        .iter()
        .zip(expected)
        .enumerate()
        .fold(0f32, |worst, (i, (&a, &b))| {
            let abs = (a - b).abs();
            let bound = atol + rtol * b.abs();
            let worst = worst.max(abs);
            assert!(abs <= bound, "{label}[{i}]: actual={a:.6e} expected={b:.6e} |Δ|={abs:.3e}");
            worst
        });
    eprintln!("{label}: max_abs={max_abs:.3e}");
}
/// Parity test for the repformers descriptor block against a Python reference.
///
/// Steps:
/// 1. Load `tests/fixtures/repformers_nlayers0.json`, two directories above
///    this crate's manifest.
/// 2. Build the descriptor graph from the fixture config.
/// 3. Bind the fixture parameters and run the graph on the CPU backend.
/// 4. Compare `g1`, `g2` and `h2` element-wise against the reference outputs.
///
/// `rot_mat` is registered as a graph output but is not compared here.
#[test]
fn repformers_nlayers0_matches_python_reference() {
    // Shared numpy-style allclose tolerances for every compared output.
    const ATOL: f32 = 1e-5;
    const RTOL: f32 = 1e-4;

    let path = Path::new(env!("CARGO_MANIFEST_DIR"))
        .parent()
        .and_then(Path::parent)
        .expect("CARGO_MANIFEST_DIR should have at least two parent directories")
        .join("tests/fixtures/repformers_nlayers0.json");
    let text = std::fs::read_to_string(&path)
        .unwrap_or_else(|e| panic!("failed to read fixture {}: {e}", path.display()));
    let fx: Fixture = serde_json::from_str(&text)
        .unwrap_or_else(|e| panic!("failed to parse fixture {}: {e}", path.display()));

    // The fixture stores `sel` as a list, but the repformers config takes a
    // single neighbor count. Reject anything other than exactly one entry
    // instead of silently dropping extra entries (or panicking on an empty list).
    let &[sel] = fx.config.sel.as_slice() else {
        panic!("expected exactly one `sel` entry, got {:?}", fx.config.sel);
    };

    let cfg = RepformersConfig {
        rcut: fx.config.rcut,
        rcut_smth: fx.config.rcut_smth,
        sel,
        ntypes: fx.config.ntypes,
        nlayers: fx.config.nlayers,
        g1_dim: fx.config.g1_dim,
        g2_dim: fx.config.g2_dim,
        axis_neuron: fx.config.axis_neuron,
        update_g1_has_conv: false,
        update_g1_has_drrd: true,
        update_g1_has_grrg: true,
        update_g1_has_attn: false,
        update_g2_has_g1g1: false,
        update_g2_has_attn: false,
        update_h2: false,
        attn1_hidden: 8,
        attn1_nhead: 1,
        attn2_hidden: 4,
        attn2_nhead: 1,
        attn2_has_gate: false,
        activation_function: fx.config.activation_function.clone(),
        update_style: "res_avg".into(),
        smooth: fx.config.smooth,
        use_sqrt_nnei: fx.config.use_sqrt_nnei,
        g1_out_conv: true,
        g1_out_mlp: true,
        ln_eps: 1e-5,
        epsilon: 1e-4,
    };

    let nf = fx.config.nf;
    let nloc = fx.config.nloc;
    let nall = fx.config.nall;
    let nnei = cfg.sel;

    let mut g = Graph::new("repformers_parity");
    let env_mat = g.input("env_mat", Shape::new(&[nf, nloc, nnei, 4], DType::F32));
    let atype = g.input("atype", Shape::new(&[nf, nloc], DType::I32));
    let nlist = g.input("nlist", Shape::new(&[nf, nloc, nnei], DType::I32));
    let nlist_mask = g.input("nlist_mask", Shape::new(&[nf, nloc, nnei], DType::F32));
    let sw = g.input("sw", Shape::new(&[nf, nloc, nnei], DType::F32));
    let g1_ext = g.input("g1_ext", Shape::new(&[nf, nall, cfg.g1_dim], DType::F32));
    let out = build_repformers(
        &mut g,
        &cfg,
        RepformersInputs {
            g1_ext,
            env_mat_raw: env_mat,
            nlist,
            nlist_mask,
            sw,
            atype_loc: atype,
            mapping: None,
        },
        nf,
        nloc,
        nall,
    )
    .expect("build");
    g.set_outputs(vec![out.g1, out.g2, out.h2, out.rot_mat]);

    let session = Session::new(Device::Cpu);
    let mut compiled = session.compile(g);
    for (name, t) in &fx.params {
        // Parameters are uploaded as raw little-endian f32 bytes.
        let bytes: Vec<u8> = t.data.iter().flat_map(|v| v.to_le_bytes()).collect();
        compiled.set_param_typed(name, &bytes, DType::F32);
    }

    let em_raw = find_named(&fx.inputs, "env_mat_raw");
    let atype_ext = find_named(&fx.inputs, "atype_ext");
    let nlist_in = find_named(&fx.inputs, "nlist");
    let nlist_mask_in = find_named(&fx.inputs, "nlist_mask");
    let sw_in = find_named(&fx.inputs, "sw");
    let g1_ext_in = find_named(&fx.inputs, "atype_embd_ext");

    // `atype_ext` covers all `nall` atoms of every frame (flattened [nf, nall]).
    // The local types are the first `nloc` entries of EACH frame. Taking the
    // first `nf * nloc` values overall is only correct when nf == 1.
    assert_eq!(
        atype_ext.data.len(),
        nf * nall,
        "atype_ext: expected nf * nall = {} entries",
        nf * nall
    );
    let atype_loc_f32: Vec<f32> = atype_ext
        .data
        .chunks_exact(nall)
        .flat_map(|frame| frame[..nloc].iter().copied())
        .collect();

    // Integer inputs (`atype`, `nlist`) are fed as f32 slices like the others.
    // Presumably the runtime casts them to the declared I32 dtype — TODO confirm.
    let outputs = compiled.run(&[
        ("env_mat", em_raw.data.as_slice()),
        ("atype", atype_loc_f32.as_slice()),
        ("nlist", nlist_in.data.as_slice()),
        ("nlist_mask", nlist_mask_in.data.as_slice()),
        ("sw", sw_in.data.as_slice()),
        ("g1_ext", g1_ext_in.data.as_slice()),
    ]);

    let exp_g1 = find_named(&fx.outputs, "g1");
    let exp_g2 = find_named(&fx.outputs, "g2");
    let exp_h2 = find_named(&fx.outputs, "h2");
    // Output order follows `set_outputs` above: g1, g2, h2, rot_mat.
    assert_allclose(&outputs[0], &exp_g1.data, ATOL, RTOL, "g1");
    assert_allclose(&outputs[1], &exp_g2.data, ATOL, RTOL, "g2");
    assert_allclose(&outputs[2], &exp_h2.data, ATOL, RTOL, "h2");
}