deepmd 0.1.0

DeePMD-kit deep potential models as RLX IR graph builders
Documentation
// Parity tests for DescrptBlockRepflows (DPA-3 RepFlowLayer) against Python
// reference fixtures, covering nlayers=1 and nlayers=2 plus the `_nf2` fixture variants.
// Canonical config: update_style="res_avg", update_angle=True,
// a_compress_rate=0, n_multi_edge_message=1, smooth_edge_update=True.

use std::collections::HashMap;
use std::path::Path;

use deepmd::descriptor_repflows::{build_repflows, RepflowsConfig, RepflowsInputs};
use rlx_ir::{DType, Graph, Shape};
use rlx_runtime::{Device, Session};
use serde::Deserialize;

/// One named tensor as stored in a JSON fixture.
#[derive(Deserialize)]
struct Tensor {
    /// Key used to look the tensor up in a fixture list or parameter map.
    name: String,
    /// Shape recorded in the fixture; deserialized but never read by this test.
    #[allow(dead_code)]
    shape: Vec<usize>,
    /// Flat list of element values.
    data: Vec<f32>,
}

/// The `config` object of a fixture: descriptor hyper-parameters plus the
/// problem sizes used to declare the graph inputs.
#[derive(Deserialize)]
struct Config {
    // Cutoff settings, forwarded unchanged into `RepflowsConfig`.
    rcut: f64,
    rcut_smth: f64,
    a_rcut: f64,
    a_rcut_smth: f64,

    // Selection sizes and type count, forwarded into `RepflowsConfig`.
    nsel: usize,
    a_sel: usize,
    ntypes: usize,

    // Layer count and embedding widths, forwarded into `RepflowsConfig`.
    nlayers: usize,
    n_dim: usize,
    e_dim: usize,
    a_dim: usize,
    axis_neuron: usize,
    update_angle: bool,

    /// Present in the fixture but not consumed by this test.
    #[allow(dead_code)]
    update_style: String,

    activation_function: String,
    smooth: bool,

    // Problem sizes used for the input shapes passed to the graph builder.
    nf: usize,
    nloc: usize,
    nall: usize,

    /// Present in the fixture but not consumed; the test sizes neighbour
    /// inputs from `nsel` instead.
    #[allow(dead_code)]
    nnei: usize,
}

/// Top-level layout of a fixture JSON file.
#[derive(Deserialize)]
struct Fixture {
    /// Hyper-parameters and problem sizes for the run.
    config: Config,
    /// Parameter tensors keyed by the name they are uploaded under.
    params: HashMap<String, Tensor>,
    /// Input tensors, looked up by their `name` field.
    inputs: Vec<Tensor>,
    /// Reference output tensors, looked up by their `name` field.
    outputs: Vec<Tensor>,
}

/// Returns the first tensor in `list` whose `name` equals `name`.
///
/// # Panics
///
/// Panics if no tensor matches. The message includes the requested name so a
/// fixture/test mismatch can be diagnosed without a debugger.
fn find_named<'a>(list: &'a [Tensor], name: &str) -> &'a Tensor {
    list.iter()
        .find(|t| t.name == name)
        .unwrap_or_else(|| panic!("missing tensor: {}", name))
}

/// Computes the largest element-wise absolute difference between `actual` and
/// `expected`, logs it to stderr under `label`, and returns it.
///
/// A NaN difference (e.g. a NaN in either slice) is never smaller than a
/// finite one for the purpose of this check: the first NaN encountered is
/// reported and NaN is returned, so a caller's `drift <= tol` comparison
/// fails instead of silently passing.
///
/// Returns `0.0` when both slices are empty.
///
/// # Panics
///
/// Panics if the two slices have different lengths.
fn report_drift(actual: &[f32], expected: &[f32], label: &str) -> f32 {
    assert_eq!(actual.len(), expected.len(), "{label}: len {} vs {}", actual.len(), expected.len());
    let mut max_abs = 0f32;
    let mut max_at = 0usize;
    for (i, (a, b)) in actual.iter().zip(expected.iter()).enumerate() {
        let d = (a - b).abs();
        // `d > max_abs` is false for NaN, so NaN must be handled explicitly or
        // it would be skipped and the drift under-reported.
        if d.is_nan() {
            max_abs = d;
            max_at = i;
            break;
        }
        if d > max_abs {
            max_abs = d;
            max_at = i;
        }
    }
    // Checked lookup: with empty inputs there is no element at `max_at`.
    match (actual.get(max_at), expected.get(max_at)) {
        (Some(a), Some(e)) => eprintln!(
            "{label}: max_abs={:.3e} at [{}] (actual={:.6e} expected={:.6e})",
            max_abs, max_at, a, e
        ),
        _ => eprintln!("{label}: no elements to compare"),
    }
    max_abs
}

/// Extracts the local-atom entries from a flattened per-atom array.
///
/// `data` is split into `nf` equally sized frames and the first `nloc`
/// entries of each frame are kept. For a single frame, or when each frame
/// holds exactly `nloc` entries, this equals taking the first `nf * nloc`
/// values; with several frames that carry extra (ghost) entries it avoids
/// picking up those extras.
///
/// Assumes a frame-major layout (leading axis `nf`), matching how the other
/// graph inputs are declared — TODO confirm against the fixture exporter.
fn local_atypes(data: &[f32], nf: usize, nloc: usize) -> Vec<f32> {
    if nf == 0 {
        return Vec::new();
    }
    assert_eq!(
        data.len() % nf, 0,
        "atype_ext: len {} is not a multiple of nf={}", data.len(), nf
    );
    let per_frame = data.len() / nf;
    assert!(
        per_frame >= nloc,
        "atype_ext: {} entries per frame < nloc={}", per_frame, nloc
    );
    if per_frame == 0 {
        return Vec::new();
    }
    data.chunks_exact(per_frame)
        .flat_map(|frame| frame[..nloc].iter().copied())
        .collect()
}

/// Runs one parity check against the JSON fixture `fixture_name`.
///
/// Loads the fixture from `<two levels above CARGO_MANIFEST_DIR>/tests/fixtures`,
/// builds the repflows graph from the fixture's config, uploads the fixture
/// parameters, runs the graph on CPU with the fixture inputs, and asserts that
/// the max-abs drift of `node_ebd`, `edge_ebd` and `h2` against the reference
/// outputs is at most `tol`.
///
/// # Panics
///
/// Panics if the fixture cannot be read or parsed, if a required tensor is
/// missing, if graph construction fails, or if any drift exceeds `tol`.
fn run_for_fixture(fixture_name: &str, tol: f32) {
    let path = Path::new(env!("CARGO_MANIFEST_DIR"))
        .parent()
        .and_then(Path::parent)
        .expect("CARGO_MANIFEST_DIR should have two parent directories")
        .join("tests/fixtures")
        .join(fixture_name);
    // Include the path in failures: a missing or malformed fixture is the most
    // likely way for this test to break.
    let raw = std::fs::read_to_string(&path)
        .unwrap_or_else(|e| panic!("failed to read fixture {}: {}", path.display(), e));
    let fx: Fixture = serde_json::from_str(&raw)
        .unwrap_or_else(|e| panic!("failed to parse fixture {}: {}", path.display(), e));

    let cfg = RepflowsConfig {
        rcut: fx.config.rcut, rcut_smth: fx.config.rcut_smth,
        nsel: fx.config.nsel,
        a_rcut: fx.config.a_rcut, a_rcut_smth: fx.config.a_rcut_smth,
        a_sel: fx.config.a_sel, ntypes: fx.config.ntypes,
        nlayers: fx.config.nlayers,
        n_dim: fx.config.n_dim, e_dim: fx.config.e_dim, a_dim: fx.config.a_dim,
        activation_function: fx.config.activation_function.clone(),
        smooth: fx.config.smooth, ln_eps: 1e-5,
        axis_neuron: fx.config.axis_neuron,
        update_angle: fx.config.update_angle,
    };
    let nf = fx.config.nf;
    let nloc = fx.config.nloc;
    let nall = fx.config.nall;
    // Neighbour dimension is taken from `nsel`, not from the fixture's `nnei`.
    let nnei = cfg.nsel;
    let a_sel = cfg.a_sel;

    // Declare graph inputs; the names here are the keys used in `compiled.run` below.
    // NOTE(review): `atype` and `nlist` are declared I32 but are fed from f32
    // fixture data below — presumably the runtime converts; confirm.
    let mut g = Graph::new("repflows_n1");
    let env_mat = g.input("env_mat", Shape::new(&[nf, nloc, nnei, 4], DType::F32));
    let atype = g.input("atype", Shape::new(&[nf, nloc], DType::I32));
    let nlist = g.input("nlist", Shape::new(&[nf, nloc, nnei], DType::I32));
    let nlist_mask = g.input("nlist_mask", Shape::new(&[nf, nloc, nnei], DType::F32));
    let sw = g.input("sw", Shape::new(&[nf, nloc, nnei], DType::F32));
    let angle_input = g.input("angle_input", Shape::new(&[nf, nloc, a_sel, a_sel, 1], DType::F32));
    let angle_mask = g.input("angle_mask", Shape::new(&[nf, nloc, a_sel, a_sel], DType::F32));
    let a_sw = g.input("a_sw", Shape::new(&[nf, nloc, a_sel], DType::F32));
    let node_ext = g.input("node_ext", Shape::new(&[nf, nall, cfg.n_dim], DType::F32));

    let out = build_repflows(
        &mut g, &cfg,
        RepflowsInputs {
            node_ebd_ext: node_ext, env_mat_raw: env_mat,
            nlist, nlist_mask, sw,
            angle_input, angle_mask,
            a_sw: Some(a_sw),
            atype_loc: atype,
        },
        nf, nloc, nall,
    ).expect("build");
    // Output order here fixes the indices used when reading `outputs` below.
    g.set_outputs(vec![out.node_ebd, out.edge_ebd, out.h2]);

    let session = Session::new(Device::Cpu);
    let mut compiled = session.compile(g);
    // Upload every fixture parameter as little-endian f32 bytes.
    for (name, t) in &fx.params {
        let bytes: Vec<u8> = t.data.iter().flat_map(|v| v.to_le_bytes()).collect();
        compiled.set_param_typed(name, &bytes, DType::F32);
    }

    let em_raw = find_named(&fx.inputs, "env_mat_raw");
    let atype_ext = find_named(&fx.inputs, "atype_ext");
    let nlist_in = find_named(&fx.inputs, "nlist");
    let nlist_mask_in = find_named(&fx.inputs, "nlist_mask");
    let sw_in = find_named(&fx.inputs, "sw");
    let node_ext_in = find_named(&fx.inputs, "node_ebd_ext");
    // Slice per frame rather than taking a flat prefix, so multi-frame
    // fixtures with more than `nloc` entries per frame feed the right types.
    let atype_loc_f32: Vec<f32> = local_atypes(&atype_ext.data, nf, nloc);

    let angle_input_in = find_named(&fx.inputs, "angle_input");
    let angle_mask_in = find_named(&fx.inputs, "angle_mask");
    let a_sw_in = find_named(&fx.inputs, "a_sw");

    let outputs = compiled.run(&[
        ("env_mat", em_raw.data.as_slice()),
        ("atype", atype_loc_f32.as_slice()),
        ("nlist", nlist_in.data.as_slice()),
        ("nlist_mask", nlist_mask_in.data.as_slice()),
        ("sw", sw_in.data.as_slice()),
        ("angle_input", angle_input_in.data.as_slice()),
        ("angle_mask", angle_mask_in.data.as_slice()),
        ("a_sw", a_sw_in.data.as_slice()),
        ("node_ext", node_ext_in.data.as_slice()),
    ]);
    let exp_node = find_named(&fx.outputs, "node_ebd");
    let exp_edge = find_named(&fx.outputs, "edge_ebd");
    let exp_h2 = find_named(&fx.outputs, "h2");
    // Report all three drifts before asserting so a failure log shows every one.
    let d_n = report_drift(&outputs[0], &exp_node.data, "node_ebd");
    let d_e = report_drift(&outputs[1], &exp_edge.data, "edge_ebd");
    let d_h = report_drift(&outputs[2], &exp_h2.data, "h2");
    assert!(d_n <= tol, "node drift {d_n:.3e} > {tol:.0e}");
    assert!(d_e <= tol, "edge drift {d_e:.3e} > {tol:.0e}");
    assert!(d_h <= tol, "h2 drift {d_h:.3e} > {tol:.0e}");
}

/// Checks the `repflows_nlayers1.json` fixture with a max-abs tolerance of 1e-5.
#[test]
fn repflows_nlayers1_matches_python_reference() {
    const FIXTURE: &str = "repflows_nlayers1.json";
    const MAX_ABS_DRIFT: f32 = 1e-5;
    run_for_fixture(FIXTURE, MAX_ABS_DRIFT);
}

/// Checks the `repflows_nlayers2.json` fixture with a max-abs tolerance of 1e-5.
#[test]
fn repflows_nlayers2_matches_python_reference() {
    const FIXTURE: &str = "repflows_nlayers2.json";
    const MAX_ABS_DRIFT: f32 = 1e-5;
    run_for_fixture(FIXTURE, MAX_ABS_DRIFT);
}

/// Checks the `repflows_nlayers1_nf2.json` fixture with a max-abs tolerance of 1e-5.
#[test]
fn repflows_nlayers1_nf2_matches_python_reference() {
    const FIXTURE: &str = "repflows_nlayers1_nf2.json";
    const MAX_ABS_DRIFT: f32 = 1e-5;
    run_for_fixture(FIXTURE, MAX_ABS_DRIFT);
}

/// Checks the `repflows_nlayers2_nf2.json` fixture with a max-abs tolerance of 1e-5.
#[test]
fn repflows_nlayers2_nf2_matches_python_reference() {
    const FIXTURE: &str = "repflows_nlayers2_nf2.json";
    const MAX_ABS_DRIFT: f32 = 1e-5;
    run_for_fixture(FIXTURE, MAX_ABS_DRIFT);
}