deepmd 0.1.0

DeePMD-kit deep potential models as RLX IR graph builders
Documentation
// Parity test: the Rust DipoleFitting graph must reproduce the Python reference
// outputs stored in the fixture (fixture config: mixed_types=True).

use std::collections::HashMap;
use std::path::Path;

use deepmd::config::EnerFittingConfig;
use deepmd::fitting::{build_dipole_fitting, DipoleFittingConfig, FittingInputs};
use rlx_ir::{DType, Graph, Shape};
use rlx_runtime::{Device, Session};
use serde::Deserialize;

/// One named tensor deserialized from the JSON fixture, held as a flat `f32` buffer.
#[derive(Deserialize)]
struct Tensor {
    /// Identifier used to look the tensor up among the fixture inputs/outputs.
    name: String,
    /// Declared dimensions. Required by the fixture schema but never read by
    /// this test, hence the `dead_code` allowance.
    #[allow(dead_code)]
    shape: Vec<usize>,
    /// Flattened element values (presumably row-major w.r.t. `shape` — confirm
    /// against the fixture generator).
    data: Vec<f32>,
}

/// Model and problem-size settings recorded in the fixture's `config` object.
#[derive(Deserialize)]
struct Config {
    /// Forwarded to `EnerFittingConfig::ntypes` (presumably the number of atom types).
    ntypes: usize,
    /// Width of the descriptor; last dimension of the `descriptor` graph input.
    dim_descrpt: usize,
    /// Used both as the fitting net's `dim_out` and as the third dimension of `gr`.
    embedding_width: usize,
    /// Forwarded to `EnerFittingConfig::neuron` (presumably hidden-layer widths).
    neuron: Vec<usize>,
    /// Forwarded to `EnerFittingConfig::resnet_dt`.
    resnet_dt: bool,
    /// Activation name forwarded verbatim to the fitting config.
    activation_function: String,
    /// Forwarded to `EnerFittingConfig::mixed_types`.
    mixed_types: bool,
    /// Leading dimension of every graph input (presumably number of frames).
    nf: usize,
    /// Second dimension of every graph input (presumably local atoms per frame).
    nloc: usize,
}

/// Top-level layout of the parity fixture JSON file.
#[derive(Deserialize)]
struct Fixture {
    /// Settings used to build the graph under test.
    config: Config,
    /// Trained parameters keyed by the graph parameter name they are loaded into.
    params: HashMap<String, Tensor>,
    /// Graph inputs, located by `Tensor::name`.
    inputs: Vec<Tensor>,
    /// Reference outputs, located by `Tensor::name`.
    outputs: Vec<Tensor>,
}

/// Returns the tensor in `list` whose `name` equals `name`.
///
/// # Panics
///
/// Panics if no tensor matches. The message names the requested tensor and
/// lists the available ones, because a missing entry means the fixture and the
/// test disagree. That is a bug to diagnose, not a recoverable condition.
fn find_named<'a>(list: &'a [Tensor], name: &str) -> &'a Tensor {
    list.iter().find(|t| t.name == name).unwrap_or_else(|| {
        let available: Vec<&str> = list.iter().map(|t| t.name.as_str()).collect();
        panic!("missing tensor {name:?}; available: {available:?}")
    })
}

/// End-to-end parity check for the dipole fitting graph.
///
/// The test builds the `DipoleFitting` graph from the fixture config and loads
/// the fixture parameters. It runs the graph on CPU with the fixture inputs and
/// compares the `dipole` output element-wise against the stored reference.
///
/// The comparison rejects a length mismatch and any NaN difference. Without
/// that, an empty, truncated or NaN output could pass vacuously.
#[test]
fn dipole_fitting_matches_python_reference() {
    // Maximum allowed absolute per-element deviation from the reference.
    const TOL: f32 = 1e-5;

    // The fixture lives two directories above this crate's manifest.
    let path = Path::new(env!("CARGO_MANIFEST_DIR"))
        .parent()
        .and_then(Path::parent)
        .expect("CARGO_MANIFEST_DIR has fewer than two parent directories")
        .join("tests/fixtures/dipole_fitting_small.json");
    let json = std::fs::read_to_string(&path)
        .unwrap_or_else(|e| panic!("failed to read fixture {}: {e}", path.display()));
    let fx: Fixture = serde_json::from_str(&json)
        .unwrap_or_else(|e| panic!("failed to parse fixture {}: {e}", path.display()));

    let base = EnerFittingConfig {
        ntypes: fx.config.ntypes,
        dim_descrpt: fx.config.dim_descrpt,
        // The dipole head's inner fitting net emits `embedding_width` channels.
        dim_out: fx.config.embedding_width,
        neuron: fx.config.neuron.clone(),
        resnet_dt: fx.config.resnet_dt,
        numb_fparam: 0,
        numb_aparam: 0,
        dim_case_embd: 0,
        activation_function: fx.config.activation_function.clone(),
        mixed_types: fx.config.mixed_types,
    };
    let cfg = DipoleFittingConfig { base, embedding_width: fx.config.embedding_width };

    let nf = fx.config.nf;
    let nloc = fx.config.nloc;
    let dim_descrpt = fx.config.dim_descrpt;
    let mw = fx.config.embedding_width;

    let mut g = Graph::new("dipole_fitting_parity");
    let descriptor = g.input("descriptor", Shape::new(&[nf, nloc, dim_descrpt], DType::F32));
    let atype = g.input("atype", Shape::new(&[nf, nloc], DType::I32));
    let gr = g.input("gr", Shape::new(&[nf, nloc, mw, 3], DType::F32));

    let out = build_dipole_fitting(
        &mut g,
        &cfg,
        FittingInputs {
            descriptor,
            atype_loc: atype,
            fparam: None,
            aparam: None,
            case_embd: None,
            exclude_mask: None,
            gr: Some(gr),
        },
        nf,
        nloc,
    )
    .expect("build");
    g.set_outputs(vec![out.dipole]);

    let session = Session::new(Device::Cpu);
    let mut compiled = session.compile(g);
    for (name, t) in &fx.params {
        // Parameters are uploaded as raw little-endian f32 bytes.
        let bytes: Vec<u8> = t.data.iter().flat_map(|v| v.to_le_bytes()).collect();
        compiled.set_param_typed(name, &bytes, DType::F32);
    }

    let desc_in = find_named(&fx.inputs, "descriptor");
    let atype_in = find_named(&fx.inputs, "atype");
    let gr_in = find_named(&fx.inputs, "gr");

    // Catch fixture/config drift here with a clear message, not deep in the runtime.
    assert_eq!(desc_in.data.len(), nf * nloc * dim_descrpt, "descriptor input length");
    assert_eq!(atype_in.data.len(), nf * nloc, "atype input length");
    assert_eq!(gr_in.data.len(), nf * nloc * mw * 3, "gr input length");

    // NOTE(review): `atype` is declared as an I32 graph input but is fed as f32
    // values (that is how the fixture stores it); presumably the runtime
    // converts on input — confirm.
    let outputs = compiled.run(&[
        ("descriptor", desc_in.data.as_slice()),
        ("atype", atype_in.data.as_slice()),
        ("gr", gr_in.data.as_slice()),
    ]);

    let got = &outputs[0];
    let exp = find_named(&fx.outputs, "dipole");
    // `zip` stops at the shorter side, so check lengths explicitly first.
    assert_eq!(got.len(), exp.data.len(), "dipole output length differs from reference");

    // `f32::max` silently discards NaN. Fold manually so that a NaN difference
    // becomes (and stays) the result and fails the tolerance check below.
    let max_abs = got
        .iter()
        .zip(&exp.data)
        .map(|(a, b)| (a - b).abs())
        .fold(0f32, |acc, d| if d.is_nan() || d > acc { d } else { acc });

    eprintln!("dipole: max_abs={max_abs:.3e}");
    assert!(max_abs <= TOL, "dipole drift {max_abs:.3e}");
}