use std::collections::HashMap;
use std::path::Path;
use deepmd::descriptor_dpa1::{build_dpa1_descriptor, DPA1Config, DPA1Inputs};
use deepmd::descriptor_se_t_tebd::TebdInputMode;
use rlx_ir::{DType, Graph, Shape};
use rlx_runtime::{Device, Session};
use serde::Deserialize;
/// A single named tensor read from the JSON fixture, with its values stored as a flat
/// `f32` buffer (presumably row-major — confirm against the fixture generator).
#[derive(Deserialize)]
struct Tensor {
    /// Identifier that `find_named` matches against.
    name: String,
    /// Dimensions recorded by the fixture. Deserialized so the JSON schema is fully
    /// described, but never read by this test — hence the lint allowance.
    #[allow(dead_code)]
    shape: Vec<usize>,
    /// Flattened element values.
    data: Vec<f32>,
}
/// Hyper-parameters and system sizes stored in the fixture's `config` object.
#[derive(Deserialize)]
struct Config {
    /// Forwarded to `DPA1Config::rcut` for the repinit block.
    rcut_repinit: f64,
    /// Forwarded to `DPA1Config::rcut_smth` for the repinit block.
    rcut_smth_repinit: f64,
    /// Neighbor-list width; used as the `sel` entry and as the neighbor axis of the inputs.
    nsel: usize,
    /// Number of atom types.
    ntypes: usize,
    /// Width of each type-embedding row.
    tebd_dim: usize,
    /// Hidden widths of the repinit embedding network.
    repinit_neuron: Vec<usize>,
    /// Forwarded to `DPA1Config::axis_neuron`.
    repinit_axis_neuron: usize,
    /// Leading (frame) dimension of every input.
    nf: usize,
    /// Number of local atoms per frame.
    nloc: usize,
    /// Size of the extended atom axis (presumably local + ghost atoms — confirm).
    nall: usize,
}
/// Top-level layout of the Python-generated reference fixture.
#[derive(Deserialize)]
struct Fixture {
    /// Model hyper-parameters and system sizes.
    config: Config,
    /// Trained parameters keyed by their parameter name.
    params: HashMap<String, Tensor>,
    /// Graph inputs, looked up by name via `find_named`.
    inputs: Vec<Tensor>,
    /// Reference outputs, looked up by name via `find_named`.
    outputs: Vec<Tensor>,
}
/// Returns the first tensor in `list` whose `name` equals `name`.
///
/// # Panics
/// Panics when no tensor has that name. The message includes both the requested name
/// and every name that *is* present, so a renamed or missing fixture entry is obvious.
fn find_named<'a>(list: &'a [Tensor], name: &str) -> &'a Tensor {
    list.iter().find(|t| t.name == name).unwrap_or_else(|| {
        let available: Vec<&str> = list.iter().map(|t| t.name.as_str()).collect();
        panic!("missing fixture tensor `{name}`; available: {available:?}")
    })
}
/// Absolute tolerance when comparing against the Python reference. This is the same
/// threshold the diagnostics use to flag individual elements.
const REPINIT_ABS_TOL: f32 = 1e-4;

/// Maximum number of offending elements echoed to stderr per compared tensor.
const REPINIT_MAX_REPORTS: usize = 8;

/// Returns the largest element-wise absolute difference between `actual` and `expected`.
///
/// Prints (to stderr) up to [`REPINIT_MAX_REPORTS`] elements whose difference exceeds
/// [`REPINIT_ABS_TOL`]. A NaN difference is reported as `f32::INFINITY`: `f32::max` and
/// `>` both ignore NaN, so without this a NaN output would look like a perfect match.
///
/// # Panics
/// Panics when the two buffers have different lengths. `zip` would otherwise compare
/// only the common prefix and hide a shape mismatch.
fn repinit_max_abs_diff<A, I>(label: &str, actual: I, expected: &[f32]) -> f32
where
    I: IntoIterator<Item = A>,
    A: std::borrow::Borrow<f32>,
{
    let actual: Vec<f32> = actual.into_iter().map(|a| *a.borrow()).collect();
    assert_eq!(
        actual.len(),
        expected.len(),
        "{label}: output has {} elements but the reference has {}",
        actual.len(),
        expected.len()
    );
    let mut max_abs = 0f32;
    let mut reported = 0usize;
    for (i, (&a, &e)) in actual.iter().zip(expected).enumerate() {
        let d = (a - e).abs();
        let d = if d.is_nan() { f32::INFINITY } else { d };
        max_abs = max_abs.max(d);
        if d > REPINIT_ABS_TOL && reported < REPINIT_MAX_REPORTS {
            eprintln!(" {label}[{i}]: actual={a:.6e} expected={e:.6e} |Δ|={d:.3e}");
            reported += 1;
        }
    }
    max_abs
}

/// Builds the DPA1-style repinit descriptor with rlx and checks it against the
/// Python-generated `dpa2_small.json` fixture.
///
/// Two outputs are compared:
/// - the normalized environment matrix, `(env_mat - davg[atype]) / dstd[atype]`,
///   against `repinit_normed_env`;
/// - the descriptor itself against `repinit_g1`.
///
/// Checking the normalized env matrix separately helps locate where a mismatch starts.
#[test]
fn dpa2_repinit_matches_python() {
    // Extension trait used by the normalization ops below (presumably supplies `gather_`).
    use rlx_ir::infer::GraphExt;

    let path = Path::new(env!("CARGO_MANIFEST_DIR"))
        .ancestors()
        .nth(2)
        .expect("crate must live two directories below the workspace root")
        .join("tests/fixtures/dpa2_small.json");
    let json = std::fs::read_to_string(&path)
        .unwrap_or_else(|e| panic!("failed to read fixture {}: {e}", path.display()));
    let fx: Fixture = serde_json::from_str(&json)
        .unwrap_or_else(|e| panic!("failed to parse fixture {}: {e}", path.display()));

    let nf = fx.config.nf;
    let nloc = fx.config.nloc;
    let nall = fx.config.nall;
    let nsel = fx.config.nsel;

    // Sizes and network widths come from the fixture; the remaining options are hard-coded
    // here and must agree with whatever the Python fixture generator used.
    let cfg = DPA1Config {
        rcut: fx.config.rcut_repinit, rcut_smth: fx.config.rcut_smth_repinit,
        sel: vec![nsel], ntypes: fx.config.ntypes,
        neuron: fx.config.repinit_neuron.clone(),
        axis_neuron: fx.config.repinit_axis_neuron,
        tebd_dim: fx.config.tebd_dim,
        tebd_input_mode: TebdInputMode::Concat,
        resnet_dt: false,
        attn: 16, attn_layer: 0, attn_dotr: false,
        num_heads: 1, normalize: false,
        scaling_factor: 1.0, ln_eps: 1e-5,
        smooth: true, type_one_side: true,
        activation_function: "tanh".into(),
        concat_output_tebd: false,
    };

    let mut g = Graph::new("repinit_debug");
    let env_mat = g.input("env_mat", Shape::new(&[nf, nloc, nsel, 4], DType::F32));
    let atype = g.input("atype", Shape::new(&[nf, nloc], DType::I32));
    let nei_atype = g.input("nei_atype", Shape::new(&[nf, nloc, nsel], DType::I32));
    let nlist_mask = g.input("nlist_mask", Shape::new(&[nf, nloc, nsel], DType::F32));
    let sw = g.input("sw", Shape::new(&[nf, nloc, nsel], DType::F32));
    let t_table = g.input("type_embed", Shape::new(&[cfg.ntypes + 1, cfg.tebd_dim], DType::F32));
    let out = build_dpa1_descriptor(
        &mut g, &cfg,
        DPA1Inputs {
            env_mat_raw: env_mat, atype_loc: atype, nei_atype,
            nlist_mask, sw, type_embedding: t_table,
        },
        nf, nloc,
    ).expect("build");

    // Secondary output: the env matrix normalized by the per-type statistics selected by
    // `atype`. It is compared against the reference to isolate errors before the network.
    let davg = g.param("descriptor.davg", Shape::new(&[cfg.ntypes, nsel, 4], DType::F32));
    let dstd = g.param("descriptor.dstd", Shape::new(&[cfg.ntypes, nsel, 4], DType::F32));
    let davg_g = g.gather_(davg, atype, 0);
    let dstd_g = g.gather_(dstd, atype, 0);
    let centered = g.sub(env_mat, davg_g);
    let normed_env = g.div(centered, dstd_g);
    g.set_outputs(vec![out.descriptor, normed_env]);

    let session = Session::new(Device::Cpu);
    let mut compiled = session.compile(g);
    // Parameters of later DPA2 stages are not part of this repinit-only graph and are
    // skipped. The key "g1_shape_tranform.w" is spelled exactly as it appears in the fixture.
    for (name, t) in &fx.params {
        if name.starts_with("repformers.") || name == "g1_shape_tranform.w" {
            continue;
        }
        let bytes: Vec<u8> = t.data.iter().flat_map(|v| v.to_le_bytes()).collect();
        compiled.set_param_typed(name, &bytes, DType::F32);
    }

    let em = find_named(&fx.inputs, "env_mat_raw_repinit");
    let atype_ext = find_named(&fx.inputs, "atype_ext");
    let nei = find_named(&fx.inputs, "nei_atype");
    let nm = find_named(&fx.inputs, "nlist_mask_repinit");
    let sw_in = find_named(&fx.inputs, "sw_repinit");
    let tab = find_named(&fx.inputs, "type_embedding");

    // `atype_ext` is assumed to be laid out [nf, nall] (checked here). The graph only takes
    // the first `nloc` entries of every frame. Taking the first nf * nloc flat elements would
    // only be correct for nf == 1.
    assert!(
        nall > 0 && nloc <= nall,
        "invalid sizes: nloc={nloc}, nall={nall}"
    );
    assert_eq!(
        atype_ext.data.len(),
        nf * nall,
        "atype_ext is expected to have shape [nf, nall] = [{nf}, {nall}]"
    );
    let atype_loc_f32: Vec<f32> = atype_ext
        .data
        .chunks_exact(nall)
        .flat_map(|frame| frame[..nloc].iter().copied())
        .collect();

    let outputs = compiled.run(&[
        ("env_mat", em.data.as_slice()),
        ("atype", atype_loc_f32.as_slice()),
        ("nei_atype", nei.data.as_slice()),
        ("nlist_mask", nm.data.as_slice()),
        ("sw", sw_in.data.as_slice()),
        ("type_embed", tab.data.as_slice()),
    ]);

    let expected_env = find_named(&fx.outputs, "repinit_normed_env");
    let env_max_abs = repinit_max_abs_diff("normed_env", outputs[1].iter(), &expected_env.data);
    eprintln!("normed_env: max_abs={env_max_abs:.3e}");

    let expected = find_named(&fx.outputs, "repinit_g1");
    let max_abs = repinit_max_abs_diff("repinit_g1", outputs[0].iter(), &expected.data);
    eprintln!("repinit_g1: max_abs={max_abs:.3e}");

    // Both summaries are printed before asserting so a failure still shows full diagnostics.
    assert!(
        env_max_abs <= REPINIT_ABS_TOL,
        "normed_env differs from the Python reference: max_abs={env_max_abs:.3e}"
    );
    assert!(
        max_abs <= REPINIT_ABS_TOL,
        "repinit_g1 differs from the Python reference: max_abs={max_abs:.3e}"
    );
}