// deepmd 0.1.0
//
// DeePMD-kit deep potential models as RLX IR graph builders.
//
// Parity test for InvarFitting (energy fitting net).

use std::collections::HashMap;
use std::path::Path;

use deepmd::config::EnerFittingConfig;
use deepmd::fitting::{build_invar_fitting, FittingInputs};
use rlx_ir::{DType, Graph, Shape};
use rlx_runtime::{Device, Session};
use serde::Deserialize;

/// One named tensor from a JSON fixture: a flat `f32` buffer plus its shape.
#[derive(Deserialize)]
struct Tensor {
    /// Identifier that `find_named` matches against.
    name: String,
    /// Dimensions of `data`. It is deserialized for completeness but not read
    /// by this file's tests, hence the lint allowance.
    #[allow(dead_code)]
    shape: Vec<usize>,
    /// Element values as a flat vector (NOTE(review): presumably row-major — confirm).
    data: Vec<f32>,
}

/// The `config` section of a fixture.
///
/// Most fields are copied one-to-one into `EnerFittingConfig`. `nf` and `nloc`
/// are the leading dimensions of the `descriptor` and `atype` graph inputs.
#[derive(Deserialize)]
struct Config {
    /// Present in the fixture JSON but not read by these tests, hence the lint allowance.
    #[allow(dead_code)]
    var_name: String,
    ntypes: usize,
    dim_descrpt: usize,
    dim_out: usize,
    neuron: Vec<usize>,
    resnet_dt: bool,
    activation_function: String,
    mixed_types: bool,
    nf: usize,
    nloc: usize,
    numb_fparam: usize,
    numb_aparam: usize,
    dim_case_embd: usize,
}

/// A complete parity fixture as loaded from `tests/fixtures/*.json`.
#[derive(Deserialize)]
struct Fixture {
    /// Model hyper-parameters and input geometry.
    config: Config,
    /// Parameter tensors. Each map key is the parameter name handed to the runtime.
    params: HashMap<String, Tensor>,
    /// Graph inputs, looked up by name.
    inputs: Vec<Tensor>,
    /// Reference outputs, looked up by name.
    outputs: Vec<Tensor>,
}

/// Returns the first tensor in `list` whose `name` equals `name`.
///
/// # Panics
///
/// Panics with the missing tensor's name if no entry matches. That usually
/// means the fixture and the test disagree on input/output naming.
fn find_named<'a>(list: &'a [Tensor], name: &str) -> &'a Tensor {
    list.iter()
        .find(|t| t.name == name)
        .unwrap_or_else(|| panic!("missing tensor `{name}`"))
}

/// Absolute tolerance for the element-wise parity check.
const ATOL: f32 = 1e-5;
/// Relative tolerance for the parity check, scaled by `|expected|`.
const RTOL: f32 = 1e-4;

/// Runs one InvarFitting parity fixture end to end.
///
/// The fixture is resolved as `tests/fixtures/<fixture_name>` two directories
/// above this crate's manifest. The function:
/// 1. builds the fitting graph from the fixture config;
/// 2. uploads every parameter as little-endian `f32` bytes;
/// 3. runs the graph on the fixture's `descriptor`/`atype` inputs;
/// 4. compares the per-atom output element-wise with the reference tensor
///    named `"energy"`.
///
/// # Panics
///
/// Panics (failing the calling test) if:
/// - the fixture cannot be read or parsed;
/// - the graph fails to build;
/// - a named input/output tensor is missing;
/// - the output length differs from the reference;
/// - any element exceeds `ATOL + RTOL * |expected|`.
fn run_invar_fitting_fixture(fixture_name: &str) {
    let path = Path::new(env!("CARGO_MANIFEST_DIR"))
        .parent()
        .and_then(Path::parent)
        .expect("crate manifest dir has no grandparent")
        .join("tests/fixtures")
        .join(fixture_name);
    let text = std::fs::read_to_string(&path)
        .unwrap_or_else(|e| panic!("failed to read fixture {}: {e}", path.display()));
    let fx: Fixture = serde_json::from_str(&text)
        .unwrap_or_else(|e| panic!("failed to parse fixture {}: {e}", path.display()));

    let cfg = EnerFittingConfig {
        ntypes: fx.config.ntypes,
        dim_descrpt: fx.config.dim_descrpt,
        dim_out: fx.config.dim_out,
        neuron: fx.config.neuron.clone(),
        resnet_dt: fx.config.resnet_dt,
        numb_fparam: fx.config.numb_fparam,
        numb_aparam: fx.config.numb_aparam,
        dim_case_embd: fx.config.dim_case_embd,
        activation_function: fx.config.activation_function.clone(),
        mixed_types: fx.config.mixed_types,
    };

    let nf = fx.config.nf;
    let nloc = fx.config.nloc;
    let dim_descrpt = cfg.dim_descrpt;

    let mut g = Graph::new("invar_fitting_parity");
    let descriptor = g.input("descriptor", Shape::new(&[nf, nloc, dim_descrpt], DType::F32));
    let atype = g.input("atype", Shape::new(&[nf, nloc], DType::I32));

    let out = build_invar_fitting(
        &mut g,
        &cfg,
        FittingInputs {
            descriptor,
            atype_loc: atype,
            fparam: None,
            aparam: None,
            case_embd: None,
            exclude_mask: None,
            gr: None,
        },
        nf,
        nloc,
        "fitting",
        "fitting.bias_atom_e",
    )
    .expect("build");
    g.set_outputs(vec![out.atom_output]);

    let session = Session::new(Device::Cpu);
    let mut compiled = session.compile(g);
    for (name, t) in &fx.params {
        // Parameters are handed to the runtime as raw little-endian f32 bytes.
        let bytes: Vec<u8> = t.data.iter().flat_map(|v| v.to_le_bytes()).collect();
        compiled.set_param_typed(name, &bytes, DType::F32);
    }

    let desc_in = find_named(&fx.inputs, "descriptor");
    let atype_in = find_named(&fx.inputs, "atype");

    // NOTE(review): `atype` is declared I32 in the graph but is fed the fixture's
    // f32 values directly (as before); presumably `run` casts per input dtype — confirm.
    let outputs = compiled.run(&[
        ("descriptor", desc_in.data.as_slice()),
        ("atype", atype_in.data.as_slice()),
    ]);
    let exp = find_named(&fx.outputs, "energy");
    let max_abs = assert_allclose("energy", &outputs[0], &exp.data, ATOL, RTOL);
    eprintln!("energy: max_abs={max_abs:.3e}");
}

/// Asserts `actual` and `expected` match element-wise and returns the largest
/// absolute difference seen.
///
/// Each element must satisfy `|a - e| <= atol + rtol * |e|`.
///
/// # Panics
///
/// Panics if the lengths differ; an unchecked `zip` would silently ignore the
/// surplus and let an empty output pass. Also panics on the first element
/// outside tolerance, reported as `label[i]`.
fn assert_allclose(
    label: &str,
    actual: &[f32],
    expected: &[f32],
    atol: f32,
    rtol: f32,
) -> f32 {
    assert_eq!(
        actual.len(),
        expected.len(),
        "{label}: output has {} elements, reference has {}",
        actual.len(),
        expected.len()
    );
    let mut max_abs = 0f32;
    for (i, (a, b)) in actual.iter().zip(expected).enumerate() {
        let d = (a - b).abs();
        max_abs = max_abs.max(d);
        assert!(d <= atol + rtol * b.abs(), "{label}[{i}]: actual={a:.6e} expected={b:.6e}");
    }
    max_abs
}

/// Parity check against the `invar_fitting_small.json` reference fixture.
#[test]
fn invar_fitting_matches_python_reference() {
    let fixture = "invar_fitting_small.json";
    run_invar_fitting_fixture(fixture);
}

/// Parity check against the `invar_fitting_pertype.json` reference fixture.
#[test]
fn invar_fitting_pertype_matches_python_reference() {
    let fixture = "invar_fitting_pertype.json";
    run_invar_fitting_fixture(fixture);
}

/// Parity check against the `invar_fitting_dos.json` reference fixture.
#[test]
fn invar_fitting_dos_matches_python_reference() {
    let fixture = "invar_fitting_dos.json";
    run_invar_fitting_fixture(fixture);
}