deepmd 0.1.0

DeePMD-kit deep potential models as RLX IR graph builders
Documentation
// Parity test for DescrptBlockRepformers (DPA-2 core).

use std::collections::HashMap;
use std::path::Path;

use deepmd::descriptor_repformers::{build_repformers, RepformersConfig, RepformersInputs};
use rlx_ir::{DType, Graph, Shape};
use rlx_runtime::{Device, Session};
use serde::Deserialize;

/// A single named tensor as serialized in the JSON fixture.
#[derive(Deserialize)]
struct Tensor {
    /// Identifier used to look the tensor up in the fixture's input/output lists.
    name: String,
    /// Shape recorded alongside the data. It is deserialized for completeness,
    /// but this test only consumes the flat `data` buffer, hence the allow.
    #[allow(dead_code)]
    shape: Vec<usize>,
    /// Element values as a 1-D buffer (presumably row-major w.r.t. `shape` — TODO confirm).
    data: Vec<f32>,
}

/// The `config` section of the fixture: repformer hyperparameters used to
/// produce the reference outputs, plus the dimensions of the input tensors.
#[derive(Deserialize)]
struct Config {
    // Repformer hyperparameters, copied into `RepformersConfig` by the test.
    rcut: f64,
    rcut_smth: f64,
    sel: Vec<usize>,
    ntypes: usize,
    nlayers: usize,
    g1_dim: usize,
    g2_dim: usize,
    axis_neuron: usize,
    activation_function: String,
    smooth: bool,
    use_sqrt_nnei: bool,
    // Input tensor dimensions.
    /// Leading dimension of every graph input.
    nf: usize,
    /// Second dimension of the per-local-atom inputs (`env_mat`, `atype`, `nlist`, ...).
    nloc: usize,
    /// Second dimension of the extended input `g1_ext`.
    nall: usize,
}

/// Top-level layout of the JSON parity fixture produced from the Python reference.
#[derive(Deserialize)]
struct Fixture {
    /// Hyperparameters and input dimensions.
    config: Config,
    /// Model parameters, keyed by the parameter name handed to the compiled graph.
    params: HashMap<String, Tensor>,
    /// Graph inputs, located by `Tensor::name`.
    inputs: Vec<Tensor>,
    /// Expected outputs, located by `Tensor::name`.
    outputs: Vec<Tensor>,
}

/// Returns the tensor called `name` from `list`.
///
/// # Panics
/// If no tensor in `list` has that name. The message names the missing tensor
/// and lists the names that are present, so a fixture/test naming mismatch is
/// obvious from the failure alone.
fn find_named<'a>(list: &'a [Tensor], name: &str) -> &'a Tensor {
    list.iter().find(|t| t.name == name).unwrap_or_else(|| {
        let available: Vec<&str> = list.iter().map(|t| t.name.as_str()).collect();
        panic!("missing tensor `{name}` (available: {available:?})")
    })
}

/// Asserts element-wise `|actual - expected| <= atol + rtol * |expected|`
/// (the `numpy.allclose` criterion; a NaN on either side is never close).
///
/// On success, prints the maximum absolute deviation for `label` to stderr.
/// On failure, scans the whole buffer first and then panics with the number
/// of violating elements, the tolerances used, and the worst offender (the
/// element exceeding its tolerance by the most). One failing run therefore
/// shows how widespread a parity drift is.
///
/// # Panics
/// If the lengths differ, or if any element is outside tolerance.
fn assert_allclose(actual: &[f32], expected: &[f32], atol: f32, rtol: f32, label: &str) {
    assert_eq!(actual.len(), expected.len(), "{label}: length mismatch");
    let n = actual.len();
    let mut max_abs = 0f32;
    let mut n_bad = 0usize;
    // (excess over tolerance, index, actual, expected, |Δ|, tolerance) of the worst element.
    let mut worst: Option<(f32, usize, f32, f32, f32, f32)> = None;
    for (i, (&a, &b)) in actual.iter().zip(expected).enumerate() {
        let abs = (a - b).abs();
        let tol = atol + rtol * b.abs();
        max_abs = max_abs.max(abs);
        // `abs <= tol` is false for NaN, so NaN values are counted as mismatches.
        let close = abs <= tol;
        if !close {
            n_bad += 1;
            // Rank NaN mismatches above every finite one.
            let excess = if abs.is_nan() { f32::INFINITY } else { abs - tol };
            if worst.map_or(true, |w| excess > w.0) {
                worst = Some((excess, i, a, b, abs, tol));
            }
        }
    }
    if let Some((_, i, a, b, abs, tol)) = worst {
        panic!(
            "{label}: {n_bad}/{n} elements outside tolerance (atol={atol:.1e}, rtol={rtol:.1e}); \
             worst at [{i}]: actual={a:.6e} expected={b:.6e} |Δ|={abs:.3e} > tol={tol:.3e}"
        );
    }
    eprintln!("{label}: max_abs={max_abs:.3e}");
}

/// Parity for the repformer block at `nlayers=0` (just env_mat
/// normalization + g2_embd) — this validates the input pipeline.
/// Full-layer parity (`nlayers ≥ 1`) is a follow-up; the current
/// `RepformerLayer` translation has approximate fills for the
/// `Atten2*` / `_update_g1_conv` paths and lands at ~1e-2.
///
/// Loads `tests/fixtures/repformers_nlayers0.json`, resolved two directories
/// above this crate's manifest. Builds the block into an RLX graph, loads the
/// fixture parameters, runs the graph on CPU, and compares `g1`, `g2` and `h2`
/// against the reference outputs.
#[test]
fn repformers_nlayers0_matches_python_reference() {
    let path = Path::new(env!("CARGO_MANIFEST_DIR"))
        .parent()
        .and_then(Path::parent)
        .expect("CARGO_MANIFEST_DIR should have a grandparent directory")
        .join("tests/fixtures/repformers_nlayers0.json");
    let text = std::fs::read_to_string(&path)
        .unwrap_or_else(|e| panic!("cannot read fixture {}: {e}", path.display()));
    let fx: Fixture = serde_json::from_str(&text)
        .unwrap_or_else(|e| panic!("cannot parse fixture {}: {e}", path.display()));

    let nf = fx.config.nf;
    let nloc = fx.config.nloc;
    let nall = fx.config.nall;
    // The total neighbor count is the sum over `sel`, matching DeePMD's repformer
    // block (`nnei = sum(sel)`). Indexing `sel[0]` would silently drop extra
    // entries and panic on an empty list.
    let nnei: usize = fx.config.sel.iter().sum();

    let cfg = RepformersConfig {
        rcut: fx.config.rcut,
        rcut_smth: fx.config.rcut_smth,
        sel: nnei,
        ntypes: fx.config.ntypes,
        nlayers: fx.config.nlayers,
        g1_dim: fx.config.g1_dim,
        g2_dim: fx.config.g2_dim,
        axis_neuron: fx.config.axis_neuron,
        update_g1_has_conv: false,
        update_g1_has_drrd: true,
        update_g1_has_grrg: true,
        update_g1_has_attn: false,
        update_g2_has_g1g1: false,
        update_g2_has_attn: false,
        update_h2: false,
        attn1_hidden: 8,
        attn1_nhead: 1,
        attn2_hidden: 4,
        attn2_nhead: 1,
        attn2_has_gate: false,
        activation_function: fx.config.activation_function.clone(),
        update_style: "res_avg".into(),
        smooth: fx.config.smooth,
        use_sqrt_nnei: fx.config.use_sqrt_nnei,
        g1_out_conv: true,
        g1_out_mlp: true,
        ln_eps: 1e-5,
        epsilon: 1e-4,
    };

    let mut g = Graph::new("repformers_parity");
    let env_mat = g.input("env_mat", Shape::new(&[nf, nloc, nnei, 4], DType::F32));
    let atype = g.input("atype", Shape::new(&[nf, nloc], DType::I32));
    let nlist = g.input("nlist", Shape::new(&[nf, nloc, nnei], DType::I32));
    let nlist_mask = g.input("nlist_mask", Shape::new(&[nf, nloc, nnei], DType::F32));
    let sw = g.input("sw", Shape::new(&[nf, nloc, nnei], DType::F32));
    let g1_ext = g.input("g1_ext", Shape::new(&[nf, nall, cfg.g1_dim], DType::F32));

    let out = build_repformers(
        &mut g,
        &cfg,
        RepformersInputs {
            g1_ext,
            env_mat_raw: env_mat,
            nlist,
            nlist_mask,
            sw,
            atype_loc: atype,
            mapping: None,
        },
        nf,
        nloc,
        nall,
    )
    .expect("build");
    // NOTE(review): `rot_mat` is exported but not compared against the fixture below.
    g.set_outputs(vec![out.g1, out.g2, out.h2, out.rot_mat]);

    let session = Session::new(Device::Cpu);
    let mut compiled = session.compile(g);
    for (name, t) in &fx.params {
        let bytes: Vec<u8> = t.data.iter().flat_map(|v| v.to_le_bytes()).collect();
        compiled.set_param_typed(name, &bytes, DType::F32);
    }

    let em_raw = find_named(&fx.inputs, "env_mat_raw");
    let atype_ext = find_named(&fx.inputs, "atype_ext");
    let nlist_in = find_named(&fx.inputs, "nlist");
    let nlist_mask_in = find_named(&fx.inputs, "nlist_mask");
    let sw_in = find_named(&fx.inputs, "sw");
    let g1_ext_in = find_named(&fx.inputs, "atype_embd_ext");

    // `atype_ext` holds types for the extended atoms, assumed frame-major as
    // [nf, nall] (the length is checked below). The graph's `atype` input needs
    // only the first `nloc` (local) atoms of each frame. Taking the first
    // `nf * nloc` values of the flat buffer is only correct when nf == 1.
    assert_eq!(
        atype_ext.data.len(),
        nf * nall,
        "atype_ext: expected nf * nall = {} entries",
        nf * nall
    );
    let atype_loc: Vec<f32> = atype_ext
        .data
        .chunks_exact(nall)
        .flat_map(|frame| frame[..nloc].iter().copied())
        .collect();

    let outputs = compiled.run(&[
        ("env_mat", em_raw.data.as_slice()),
        ("atype", atype_loc.as_slice()),
        ("nlist", nlist_in.data.as_slice()),
        ("nlist_mask", nlist_mask_in.data.as_slice()),
        ("sw", sw_in.data.as_slice()),
        ("g1_ext", g1_ext_in.data.as_slice()),
    ]);
    let exp_g1 = find_named(&fx.outputs, "g1");
    let exp_g2 = find_named(&fx.outputs, "g2");
    let exp_h2 = find_named(&fx.outputs, "h2");
    assert_allclose(&outputs[0], &exp_g1.data, 1e-5, 1e-4, "g1");
    assert_allclose(&outputs[1], &exp_g2.data, 1e-5, 1e-4, "g2");
    assert_allclose(&outputs[2], &exp_h2.data, 1e-5, 1e-4, "h2");
}