use std::collections::HashMap;
use std::path::Path;
use deepmd::config::EnerFittingConfig;
use deepmd::fitting::{build_invar_fitting, FittingInputs};
use rlx_ir::{DType, Graph, Shape};
use rlx_runtime::{Device, Session};
use serde::Deserialize;
/// One named tensor read from a fixture; its values are held as a flat list of `f32`.
#[derive(Deserialize)]
struct Tensor {
    name: String,
    // Deserialized from the fixture but never read by this harness (only `name` and `data`
    // are consumed), hence the lint allowance.
    #[allow(dead_code)]
    shape: Vec<usize>,
    data: Vec<f32>,
}
/// The `config` section of a fixture: fitting hyper-parameters plus the leading dimensions
/// used to shape the graph inputs.
#[derive(Deserialize)]
struct Config {
    // Present in the fixture but not read by this harness, hence the lint allowance.
    #[allow(dead_code)]
    var_name: String,
    // Copied into `EnerFittingConfig` by the harness.
    ntypes: usize,
    dim_descrpt: usize,
    dim_out: usize,
    neuron: Vec<usize>,
    resnet_dt: bool,
    activation_function: String,
    mixed_types: bool,
    // Leading dimensions of the graph inputs: descriptor is `[nf, nloc, dim_descrpt]`,
    // atype is `[nf, nloc]`.
    nf: usize,
    nloc: usize,
    // Also copied into `EnerFittingConfig`.
    numb_fparam: usize,
    numb_aparam: usize,
    dim_case_embd: usize,
}
/// A complete parity fixture: configuration, named parameters, inputs and reference outputs.
#[derive(Deserialize)]
struct Fixture {
    config: Config,
    // Parameter name -> values; each entry is bound onto the compiled graph as F32.
    params: HashMap<String, Tensor>,
    // Graph inputs, looked up by name.
    inputs: Vec<Tensor>,
    // Expected results, looked up by name and compared against the graph output.
    outputs: Vec<Tensor>,
}
/// Returns the first tensor in `list` whose `name` equals `name`.
///
/// # Panics
///
/// Panics, naming the requested tensor, when no entry in `list` has that name.
fn find_named<'a>(list: &'a [Tensor], name: &str) -> &'a Tensor {
    list.iter()
        .find(|t| t.name == name)
        .unwrap_or_else(|| panic!("missing tensor {name:?}"))
}
fn run_invar_fitting_fixture(fixture_name: &str) {
let path = Path::new(env!("CARGO_MANIFEST_DIR"))
.parent().unwrap().parent().unwrap()
.join("tests/fixtures").join(fixture_name);
let fx: Fixture = serde_json::from_str(&std::fs::read_to_string(&path).unwrap()).unwrap();
let cfg = EnerFittingConfig {
ntypes: fx.config.ntypes,
dim_descrpt: fx.config.dim_descrpt,
dim_out: fx.config.dim_out,
neuron: fx.config.neuron.clone(),
resnet_dt: fx.config.resnet_dt,
numb_fparam: fx.config.numb_fparam,
numb_aparam: fx.config.numb_aparam,
dim_case_embd: fx.config.dim_case_embd,
activation_function: fx.config.activation_function.clone(),
mixed_types: fx.config.mixed_types,
};
let nf = fx.config.nf;
let nloc = fx.config.nloc;
let dim_descrpt = cfg.dim_descrpt;
let mut g = Graph::new("invar_fitting_parity");
let descriptor = g.input("descriptor", Shape::new(&[nf, nloc, dim_descrpt], DType::F32));
let atype = g.input("atype", Shape::new(&[nf, nloc], DType::I32));
let out = build_invar_fitting(
&mut g, &cfg,
FittingInputs {
descriptor, atype_loc: atype,
fparam: None, aparam: None, case_embd: None,
exclude_mask: None, gr: None,
},
nf, nloc,
"fitting", "fitting.bias_atom_e",
).expect("build");
g.set_outputs(vec![out.atom_output]);
let session = Session::new(Device::Cpu);
let mut compiled = session.compile(g);
for (name, t) in &fx.params {
let bytes: Vec<u8> = t.data.iter().flat_map(|v| v.to_le_bytes()).collect();
compiled.set_param_typed(name, &bytes, DType::F32);
}
let desc_in = find_named(&fx.inputs, "descriptor");
let atype_in = find_named(&fx.inputs, "atype");
let atype_f32: Vec<f32> = atype_in.data.iter().copied().collect();
let outputs = compiled.run(&[
("descriptor", desc_in.data.as_slice()),
("atype", atype_f32.as_slice()),
]);
let exp = find_named(&fx.outputs, "energy");
let mut max_abs = 0f32;
for (i, (a, b)) in outputs[0].iter().zip(exp.data.iter()).enumerate() {
let d = (a - b).abs();
if d > max_abs { max_abs = d; }
assert!(d <= 1e-5 + 1e-4 * b.abs(), "energy[{i}]: actual={a:.6e} expected={b:.6e}");
}
eprintln!("energy: max_abs={max_abs:.3e}");
}
/// Parity check of the invariant fitting against the `invar_fitting_small.json` fixture.
#[test]
fn invar_fitting_matches_python_reference() {
    let fixture = "invar_fitting_small.json";
    run_invar_fitting_fixture(fixture);
}

/// Parity check of the invariant fitting against the `invar_fitting_pertype.json` fixture.
#[test]
fn invar_fitting_pertype_matches_python_reference() {
    let fixture = "invar_fitting_pertype.json";
    run_invar_fitting_fixture(fixture);
}

/// Parity check of the invariant fitting against the `invar_fitting_dos.json` fixture.
#[test]
fn invar_fitting_dos_matches_python_reference() {
    let fixture = "invar_fitting_dos.json";
    run_invar_fitting_fixture(fixture);
}