use std::collections::HashMap;
use std::path::Path;
use deepmd::config::EnerFittingConfig;
use deepmd::fitting::{build_dipole_fitting, DipoleFittingConfig, FittingInputs};
use rlx_ir::{DType, Graph, Shape};
use rlx_runtime::{Device, Session};
use serde::Deserialize;
/// A named tensor as stored in the JSON fixture.
#[derive(Deserialize)]
struct Tensor {
    /// Lookup key; matched exactly by `find_named`.
    name: String,
    /// Dimensions recorded in the fixture. Deserialized but never read in this
    /// file, hence the `dead_code` allow.
    #[allow(dead_code)]
    shape: Vec<usize>,
    /// Element values as one flat list (layout presumably row-major — confirm
    /// against the fixture generator).
    data: Vec<f32>,
}
/// Fitting-net hyper-parameters and problem sizes from the fixture's `config`
/// object. Field names double as the JSON keys (serde maps them by name), so
/// they must not be renamed.
#[derive(Deserialize)]
struct Config {
    // Fitting-net settings forwarded into `EnerFittingConfig`.
    ntypes: usize,
    dim_descrpt: usize,
    embedding_width: usize,
    neuron: Vec<usize>,
    resnet_dt: bool,
    activation_function: String,
    mixed_types: bool,
    // Leading dimensions used when declaring the graph inputs.
    nf: usize,
    nloc: usize,
}
/// Top-level layout of the reference fixture JSON.
#[derive(Deserialize)]
struct Fixture {
    /// Model settings and problem sizes.
    config: Config,
    /// Model parameters, keyed by the name passed to `set_param_typed`.
    params: HashMap<String, Tensor>,
    /// Graph inputs, looked up by tensor name.
    inputs: Vec<Tensor>,
    /// Reference outputs from the Python implementation, looked up by name.
    outputs: Vec<Tensor>,
}
/// Returns the tensor called `name` from a fixture tensor list.
///
/// # Panics
/// Panics if no tensor in `list` has that name. The message includes the
/// requested name and the names that are present, so a stale or mismatched
/// fixture is easy to diagnose.
fn find_named<'a>(list: &'a [Tensor], name: &str) -> &'a Tensor {
    list.iter().find(|t| t.name == name).unwrap_or_else(|| {
        let available: Vec<&str> = list.iter().map(|t| t.name.as_str()).collect();
        panic!("missing tensor {name:?}; fixture has {available:?}")
    })
}
/// Parity test: builds the dipole fitting graph from the fixture's config,
/// loads the reference parameters, runs it on the fixture inputs, and checks
/// that every output element is within `TOL` of the Python reference.
///
/// The fixture is read from `tests/fixtures/dipole_fitting_small.json`,
/// resolved two directories above this crate's manifest directory.
#[test]
fn dipole_fitting_matches_python_reference() {
    // Maximum tolerated element-wise absolute error against the reference.
    const TOL: f32 = 1e-5;

    let path = Path::new(env!("CARGO_MANIFEST_DIR"))
        .ancestors()
        .nth(2)
        .expect("crate manifest dir should have two ancestor directories")
        .join("tests/fixtures/dipole_fitting_small.json");
    let json = std::fs::read_to_string(&path)
        .unwrap_or_else(|e| panic!("cannot read fixture {}: {e}", path.display()));
    let Fixture { config, params, inputs, outputs: expected } =
        serde_json::from_str::<Fixture>(&json)
            .unwrap_or_else(|e| panic!("cannot parse fixture {}: {e}", path.display()));

    let nf = config.nf;
    let nloc = config.nloc;
    let nd = config.dim_descrpt;
    let mw = config.embedding_width;

    // `config` is not needed afterwards, so its owned fields are moved, not cloned.
    let base = EnerFittingConfig {
        ntypes: config.ntypes,
        dim_descrpt: nd,
        dim_out: mw,
        neuron: config.neuron,
        resnet_dt: config.resnet_dt,
        numb_fparam: 0,
        numb_aparam: 0,
        dim_case_embd: 0,
        activation_function: config.activation_function,
        mixed_types: config.mixed_types,
    };
    let cfg = DipoleFittingConfig { base, embedding_width: mw };

    let mut g = Graph::new("dipole_fitting_parity");
    let descriptor = g.input("descriptor", Shape::new(&[nf, nloc, nd], DType::F32));
    let atype = g.input("atype", Shape::new(&[nf, nloc], DType::I32));
    let gr = g.input("gr", Shape::new(&[nf, nloc, mw, 3], DType::F32));
    let out = build_dipole_fitting(
        &mut g,
        &cfg,
        FittingInputs {
            descriptor,
            atype_loc: atype,
            fparam: None,
            aparam: None,
            case_embd: None,
            exclude_mask: None,
            gr: Some(gr),
        },
        nf,
        nloc,
    )
    .expect("failed to build dipole fitting graph");
    g.set_outputs(vec![out.dipole]);

    let session = Session::new(Device::Cpu);
    let mut compiled = session.compile(g);
    for (name, t) in &params {
        // Parameters are uploaded as raw little-endian f32 bytes.
        let bytes: Vec<u8> = t.data.iter().flat_map(|v| v.to_le_bytes()).collect();
        compiled.set_param_typed(name, &bytes, DType::F32);
    }

    let desc_in = find_named(&inputs, "descriptor");
    let atype_in = find_named(&inputs, "atype");
    let gr_in = find_named(&inputs, "gr");
    // Catch a fixture/config mismatch up front instead of handing the runtime
    // buffers whose sizes disagree with the declared graph input shapes.
    assert_eq!(desc_in.data.len(), nf * nloc * nd, "descriptor length");
    assert_eq!(atype_in.data.len(), nf * nloc, "atype length");
    assert_eq!(gr_in.data.len(), nf * nloc * mw * 3, "gr length");

    // Every input is passed as an f32 slice, including `atype`, which the graph
    // declares as I32. NOTE(review): presumably the runtime converts according
    // to the declared dtype — confirm.
    let outputs = compiled.run(&[
        ("descriptor", desc_in.data.as_slice()),
        ("atype", atype_in.data.as_slice()),
        ("gr", gr_in.data.as_slice()),
    ]);

    let exp = find_named(&expected, "dipole");
    let got = &outputs[0];
    // `zip` stops at the shorter side, so check lengths first; otherwise a
    // truncated or empty output would pass trivially.
    let n_got = got.iter().count();
    assert_eq!(
        n_got,
        exp.data.len(),
        "dipole: output has {n_got} values, reference has {}",
        exp.data.len()
    );

    // `f32::max` discards NaN operands, so propagate NaN explicitly. A NaN
    // anywhere then makes `max_abs <= TOL` false and fails the test.
    let max_abs = got
        .iter()
        .zip(&exp.data)
        .map(|(a, b)| (a - b).abs())
        .fold(0f32, |acc, d| {
            if acc.is_nan() || d.is_nan() {
                f32::NAN
            } else {
                acc.max(d)
            }
        });
    eprintln!("dipole: max_abs={max_abs:.3e}");
    assert!(max_abs <= TOL, "dipole drift {max_abs:.3e}");
}