deepmd 0.1.0

DeePMD-kit deep potential models as RLX IR graph builders
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Parity-testing utilities — compare Rust graph output against
//! Python-reference golden fixtures.
//!
//! The `deepmd` graphs are *graph builders* — they don't execute on
//! their own.  Element-wise parity therefore requires:
//!
//! 1. Compile the graph with one of the RLX runtime backends (e.g.
//!    `rlx-runtime` + `rlx-cpu`).
//! 2. Bind parameter buffers loaded from a DeePMD serialized model
//!    (see [`crate::weights`]).
//! 3. Bind input buffers — typically `env_mat`, `atype`, `nlist`,
//!    `sw`, `nlist_mask` produced by [`crate::env_mat::make_env_mat`].
//! 4. Run, and compare element-wise to a Python-generated reference
//!    (`deepmd.dpmodel`-side run on identical inputs).
//!
//! This module centralizes the tolerance / shape-check / loader code
//! so individual model-level parity tests stay terse.

use anyhow::{anyhow, bail, Context, Result};
use serde::Deserialize;
use std::path::Path;

/// Golden fixture: a single Python-reference output recorded for
/// a given config + input combination.
///
/// Deserialized via `serde`; with no rename attributes, each field name
/// is also its key in the fixture JSON.
#[derive(Debug, Clone, Deserialize)]
pub struct GoldenOutput {
    /// Output name (e.g. `"atom_energy"`, `"energy"`, `"force"`).
    pub name: String,
    /// Row-major flat tensor.
    pub data: Vec<f32>,
    /// Tensor dimensions for `data`.
    ///
    /// Presumably `data.len()` equals the product of these extents; nothing
    /// here enforces that — TODO confirm against the fixture generator.
    pub shape: Vec<usize>,
}

/// A complete golden fixture: the model it was recorded against, the
/// inputs given to the Python reference, and the outputs it produced.
///
/// Deserialized via `serde` from the fixture JSON (see `GoldenFixture::load`).
#[derive(Debug, Clone, Deserialize)]
pub struct GoldenFixture {
    /// Path inside the serialized DeePMD model file used.
    ///
    /// NOTE(review): whether this is a filesystem path or a key within the
    /// model file is not established here — confirm with the generator.
    pub model_path: String,
    /// Per-input arrays the Python reference saw.
    pub inputs: Vec<GoldenInput>,
    /// Python-reference outputs that parity checks compare against.
    pub outputs: Vec<GoldenOutput>,
}

/// One named input array the Python reference was run on.
///
/// Same field layout as `GoldenOutput`: a name, a flat `data` buffer and
/// its `shape`. Presumably `data` is row-major as well — TODO confirm
/// against the fixture generator.
// NOTE(review): `data` is `f32`, so integer-valued inputs (the module docs
// mention `atype` and `nlist`) would have to be stored as floats in the
// fixture — confirm that is intended.
#[derive(Debug, Clone, Deserialize)]
pub struct GoldenInput {
    /// Input name; presumably matches the graph input it is bound to.
    pub name: String,
    /// Flat tensor values.
    pub data: Vec<f32>,
    /// Tensor dimensions for `data`.
    pub shape: Vec<usize>,
}

impl GoldenFixture {
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let text = std::fs::read_to_string(path)?;
        Ok(serde_json::from_str(&text)?)
    }
}

/// Element-wise absolute / relative comparison against a golden tensor.
///
/// Every pair `(a, b) = (actual[i], expected[i])` must pass, mirroring
/// `numpy.testing.assert_allclose` with its default `equal_nan=True`:
///
/// * NaN matches only NaN;
/// * `±∞` matches only an infinity of the same sign;
/// * finite pairs must satisfy `|a - b| ≤ atol + rtol·|b|`.
///
/// # Returns
///
/// On success, the maximum absolute error over the finite pairs (`0.0`
/// for empty input).
///
/// # Errors
///
/// * the slices have different lengths;
/// * the first index whose pair fails the rules above.
pub fn assert_allclose(actual: &[f32], expected: &[f32], atol: f32, rtol: f32) -> Result<f32> {
    if actual.len() != expected.len() {
        bail!(
            "length mismatch: actual={} expected={}",
            actual.len(),
            expected.len()
        );
    }
    let mut max_abs = 0f32;
    for (i, (&a, &b)) in actual.iter().zip(expected).enumerate() {
        // Non-finite values must match exactly. They need their own branch:
        // with rtol > 0 an infinite `b` makes `tol` infinite, so the finite
        // test below would let `+inf` vs `-inf` (or finite vs inf) pass.
        if !a.is_finite() || !b.is_finite() {
            let same = (a.is_nan() && b.is_nan()) || a == b;
            if !same {
                return Err(anyhow!(
                    "parity failed at index {i}: actual={a} expected={b} \
                     (non-finite mismatch)"
                ));
            }
            continue;
        }
        let abs = (a - b).abs();
        let tol = atol + rtol * b.abs();
        max_abs = max_abs.max(abs);
        if abs > tol {
            return Err(anyhow!(
                "parity failed at index {i}: actual={a} expected={b} \
                 |Δ|={abs} > tol={tol}"
            ));
        }
    }
    Ok(max_abs)
}

/// Check that a tensor's shape is exactly the expected shape.
///
/// # Errors
///
/// Returns an error showing both shapes if they differ in rank or in any
/// extent.
pub fn assert_shape(actual: &[usize], expected: &[usize]) -> Result<()> {
    if actual == expected {
        return Ok(());
    }
    bail!("shape mismatch: actual={actual:?} expected={expected:?}")
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Comparing a tensor with an exact copy passes with ~zero error.
    #[test]
    fn allclose_passes_on_identical() {
        let lhs = [1.0f32, 2.0, 3.0];
        let rhs = [1.0f32, 2.0, 3.0];
        let max_err = assert_allclose(&lhs, &rhs, 1e-6, 1e-6).unwrap();
        assert!(max_err < 1e-6);
    }

    /// Perturbations well below `atol` are accepted.
    #[test]
    fn allclose_passes_within_tol() {
        let lhs = [1.0f32, 2.0, 3.0];
        let rhs = [1.0f32 + 1e-7, 2.0 - 1e-7, 3.0 + 1e-7];
        let outcome = assert_allclose(&lhs, &rhs, 1e-5, 0.0);
        assert!(outcome.is_ok());
    }

    /// A difference far above both tolerances is rejected.
    #[test]
    fn allclose_fails_above_tol() {
        let lhs = [1.0f32, 2.0];
        let rhs = [1.0f32, 3.0];
        let outcome = assert_allclose(&lhs, &rhs, 1e-3, 1e-3);
        assert!(outcome.is_err());
    }

    /// Equal shapes pass; transposed shapes fail.
    #[test]
    fn shape_check() {
        let equal = assert_shape(&[2, 3], &[2, 3]);
        let transposed = assert_shape(&[2, 3], &[3, 2]);
        assert!(equal.is_ok());
        assert!(transposed.is_err());
    }
}