// deepmd 0.1.0 documentation — DeePMD-kit deep potential models as RLX IR graph builders.
//
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Host-side environment matrix construction.
//!
//! Direct Rust translation of `_make_env_mat` and the smooth /
//! exponential switching weights in
//! `deepmd/dpmodel/utils/env_mat.py`.
//!
//! Given extended coordinates, a neighbor list, and switching
//! parameters, produces the geometric environment matrix `R` that
//! every DeePMD descriptor graph consumes as its input.
//!
//! The output layout matches what `crate::descriptor::*` graphs
//! expect when their `env_mat_raw` input is fed.  The descriptor
//! graphs handle the `davg`/`dstd` normalization themselves.

/// Buffers produced by [`make_env_mat`].
///
/// All buffers are flat and row-major over `(frame, local atom, neighbor slot)`.
/// `PartialEq` is derived so complete outputs can be compared directly
/// (e.g. with `assert_eq!` in regression tests).
#[derive(Debug, Clone, PartialEq)]
pub struct EnvMatOutput {
    /// `[nf * nloc * nnei * (4 or 1)]` flat f32 buffer.
    ///
    /// Per slot: `sw / (r + protection)` followed, unless radial-only, by the
    /// three components `sw * d / (r + protection)²` for `d` in `(dx, dy, dz)`.
    pub env_mat: Vec<f32>,
    /// `[nf * nloc * nnei * 3]` masked relative coordinates
    /// (neighbor minus center; left at zero for empty slots).
    pub diff: Vec<f32>,
    /// `[nf * nloc * nnei]` switch weight (zero for empty slots).
    pub sw: Vec<f32>,
    /// `4` for full, `1` for radial-only.
    pub last_dim: usize,
}

/// Smooth polynomial switching weight (`compute_smooth_weight` in Python).
///
/// The distance is first clamped into `[rmin, rmax]`, then mapped through
///
/// ```text
///     u = (r - rmin) / (rmax - rmin)
///     w = u³ (-6u² + 15u - 10) + 1
/// ```
///
/// so that `w = 1` for `r <= rmin` and `w = 0` for `r >= rmax`.
///
/// # Panics
///
/// Panics if `rmin < rmax` does not hold.
pub fn smooth_weight(r: f32, rmin: f32, rmax: f32) -> f32 {
    assert!(rmin < rmax, "rmin must be < rmax");
    let span = rmax - rmin;
    let clamped = f32::clamp(r, rmin, rmax);
    let u = (clamped - rmin) / span;
    let sq = u * u;
    let cube = sq * u;
    let poly = -6.0 * sq + 15.0 * u - 10.0;
    cube * poly + 1.0
}

/// Double-exponential switching weight (`compute_exp_sw` in Python).
///
/// The distance is clamped into `[0, rmax]` and then evaluated as
///
/// ```text
///     w = exp(-exp(20/rmin · (r - rmin)))
/// ```
///
/// # Panics
///
/// Panics if `rmin < rmax` does not hold.
pub fn exp_sw(r: f32, rmin: f32, rmax: f32) -> f32 {
    assert!(rmin < rmax, "rmin must be < rmax");
    let steepness = 20.0 / rmin;
    let clamped = f32::clamp(r, 0.0, rmax);
    let inner = (steepness * (clamped - rmin)).exp();
    (-inner).exp()
}

/// Geometric and switching parameters consumed by [`make_env_mat`].
///
/// `PartialEq` is derived so parameter sets can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EnvMatParams {
    /// Outer cutoff; passed as `rmax` to the switching function.
    pub rcut: f32,
    /// Inner smoothing radius; passed as `rmin` to the switching function.
    /// The switching functions assert `rcut_smth < rcut`.
    pub rcut_smth: f32,
    /// Constant added to the distance before it is used as a denominator
    /// (`1 / (r + protection)` and `1 / (r + protection)²`).
    pub protection: f32,
    /// `true` selects [`exp_sw`], `false` selects [`smooth_weight`].
    pub use_exp_switch: bool,
    /// `true` emits only the radial component (1 value per slot) instead of
    /// the full 4-value row.
    pub radial_only: bool,
}

/// Build the environment matrix on the host.
///
/// For every frame, local atom `i` and neighbor slot `j` this computes the
/// relative vector `d = coord[nlist[i, j]] - coord[i]`, its length `r`, the
/// switching weight `sw`, and the row
/// `sw * [1 / (r + p), dx / (r + p)², dy / (r + p)², dz / (r + p)²]`
/// (only the first entry when `params.radial_only` is set), where `p` is
/// `params.protection`.
///
/// Empty slots (`nlist < 0`) get `sw = 0` and a zero `diff`; their row is the
/// usual expression multiplied by that zero weight. A valid neighbor at exactly
/// zero distance is evaluated with `r = 1.0` instead of `0.0`.
///
/// Arguments:
///
/// * `coord`  — `nf * nall * 3` flat f32 buffer (or `nf * (nall*3)`,
///   accepted equivalently).
/// * `nlist`  — `nf * nloc * nnei` i32 buffer, `-1` for empty slots.
/// * `nf`, `nall`, `nloc`, `nnei` — explicit shapes.
/// * `params` — geometric switch parameters.
///
/// # Panics
///
/// * if `coord` or `nlist` do not have the lengths stated above;
/// * if `nloc > nall` (with at least one frame) or a neighbor index is
///   `>= nall` — previously such input silently read coordinates belonging
///   to the following frame whenever one existed;
/// * if a switching weight is evaluated while `params.rcut_smth < params.rcut`
///   does not hold.
pub fn make_env_mat(
    coord: &[f32],
    nlist: &[i32],
    nf: usize,
    nall: usize,
    nloc: usize,
    nnei: usize,
    params: EnvMatParams,
) -> EnvMatOutput {
    assert_eq!(coord.len(), nf * nall * 3);
    assert_eq!(nlist.len(), nf * nloc * nnei);
    // Center atoms are the first `nloc` entries of each frame; an oversized
    // `nloc` would index past the frame's own coordinates.
    assert!(
        nf == 0 || nloc <= nall,
        "nloc ({}) must not exceed nall ({})",
        nloc,
        nall
    );

    let last_dim = if params.radial_only { 1 } else { 4 };
    let mut env_mat = vec![0f32; nf * nloc * nnei * last_dim];
    let mut diff = vec![0f32; nf * nloc * nnei * 3];
    let mut sw = vec![0f32; nf * nloc * nnei];

    for f in 0..nf {
        let coord_base = f * nall * 3;
        let nlist_base = f * nloc * nnei;
        for i in 0..nloc {
            let center = coord_base + i * 3;
            let xi = coord[center];
            let yi = coord[center + 1];
            let zi = coord[center + 2];
            for j in 0..nnei {
                let slot = nlist_base + i * nnei + j;
                let raw_j = nlist[slot];
                let valid = raw_j >= 0;
                // Empty slots fall back to atom 0 so the arithmetic below stays
                // well-defined; their contribution is zeroed via `weight`.
                let jj = raw_j.max(0) as usize;
                // Keep the lookup inside this frame's coordinate range.
                assert!(
                    jj < nall,
                    "nlist entry {} is out of range for nall = {}",
                    raw_j,
                    nall
                );
                let neighbor = coord_base + jj * 3;
                let dx = coord[neighbor] - xi;
                let dy = coord[neighbor + 1] - yi;
                let dz = coord[neighbor + 2] - zi;
                let r2 = dx * dx + dy * dy + dz * dz;
                // Empty slots and coincident atoms use a unit distance.
                let r = if valid && r2 > 0.0 { r2.sqrt() } else { 1.0 };

                let weight_raw = if params.use_exp_switch {
                    exp_sw(r, params.rcut_smth, params.rcut)
                } else {
                    smooth_weight(r, params.rcut_smth, params.rcut)
                };
                let weight = if valid { weight_raw } else { 0.0 };

                let denom = r + params.protection;
                let t0 = 1.0 / denom;
                let inv_denom2 = 1.0 / (denom * denom);

                // `diff` stays at its zero initialisation for empty slots.
                if valid {
                    let diff_base = slot * 3;
                    diff[diff_base] = dx;
                    diff[diff_base + 1] = dy;
                    diff[diff_base + 2] = dz;
                }
                sw[slot] = weight;

                let em_base = slot * last_dim;
                env_mat[em_base] = t0 * weight;
                if !params.radial_only {
                    env_mat[em_base + 1] = dx * inv_denom2 * weight;
                    env_mat[em_base + 2] = dy * inv_denom2 * weight;
                    env_mat[em_base + 3] = dz * inv_denom2 * weight;
                }
            }
        }
    }

    EnvMatOutput {
        env_mat,
        diff,
        sw,
        last_dim,
    }
}

/// Normalize an environment matrix in place with per-type statistics:
/// every entry becomes `(value - davg[atype]) / dstd[atype]`, the way
/// `EnvMat.call` does after the geometric part.
///
/// * `env_mat` is `nf * nloc * nnei * last_dim` and is mutated in-place.
/// * `atype_loc` is `nf * nloc` i32 (local-atom types); negative types are
///   clamped to type `0`.
/// * `davg`, `dstd` are `[ntypes, nnei, last_dim]` flat buffers.
///
/// # Panics
///
/// Panics if any buffer length disagrees with the given shapes, or if an
/// atom type is `>= ntypes`.
pub fn apply_stats(
    env_mat: &mut [f32],
    atype_loc: &[i32],
    davg: &[f32],
    dstd: &[f32],
    nf: usize,
    nloc: usize,
    nnei: usize,
    last_dim: usize,
    ntypes: usize,
) {
    assert_eq!(env_mat.len(), nf * nloc * nnei * last_dim);
    assert_eq!(atype_loc.len(), nf * nloc);
    assert_eq!(davg.len(), ntypes * nnei * last_dim);
    assert_eq!(dstd.len(), ntypes * nnei * last_dim);

    // `atype_loc` holds one entry per (frame, local atom) in the same order as
    // the rows of `env_mat`, so a flat walk visits atoms exactly as nested
    // frame/atom loops would.
    for (atom, &raw_type) in atype_loc.iter().enumerate() {
        let ty = raw_type.max(0) as usize;
        assert!(ty < ntypes, "atype out of range");

        // One atom owns `nnei * last_dim` contiguous values, and so does one
        // type inside `davg` / `dstd`.
        let row_len = nnei * last_dim;
        let row = &mut env_mat[atom * row_len..(atom + 1) * row_len];
        let mean = &davg[ty * row_len..(ty + 1) * row_len];
        let scale = &dstd[ty * row_len..(ty + 1) * row_len];

        for ((value, &avg), &std) in row.iter_mut().zip(mean).zip(scale) {
            *value = (*value - avg) / std;
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Four atoms: the origin plus one unit step along each axis (`nall = 4`).
    fn unit_axes_coord() -> Vec<f32> {
        vec![
            0.0, 0.0, 0.0, // atom 0
            1.0, 0.0, 0.0, // atom 1
            0.0, 1.0, 0.0, // atom 2
            0.0, 0.0, 1.0, // atom 3
        ]
    }

    /// Polynomial switch between 0.5 and 2.0 with no protection term.
    fn smooth_params(radial_only: bool) -> EnvMatParams {
        EnvMatParams {
            rcut: 2.0,
            rcut_smth: 0.5,
            protection: 0.0,
            use_exp_switch: false,
            radial_only,
        }
    }

    #[test]
    fn env_mat_shapes_match() {
        let nf = 1;
        let nall = 4;
        let nloc = 2;
        let nnei = 3;
        let coord = unit_axes_coord();
        let nlist = vec![1i32, 2, 3, 0, 2, -1];
        let out = make_env_mat(&coord, &nlist, nf, nall, nloc, nnei, smooth_params(false));
        assert_eq!(out.last_dim, 4);
        assert_eq!(out.env_mat.len(), nf * nloc * nnei * 4);
        assert_eq!(out.diff.len(), nf * nloc * nnei * 3);
        assert_eq!(out.sw.len(), nf * nloc * nnei);

        // Slot with nlist=-1 must have all zeros.
        // Frame 0, (i=1, j=2): flat index i * nnei + j.
        let masked_idx = nnei + 2;
        assert_eq!(out.sw[masked_idx], 0.0);
        for k in 0..4 {
            assert_eq!(out.env_mat[masked_idx * 4 + k], 0.0);
        }
        for k in 0..3 {
            assert_eq!(out.diff[masked_idx * 3 + k], 0.0);
        }
    }

    #[test]
    fn env_mat_unit_neighbor_values() {
        // Atom 0 -> atom 1: diff = (1, 0, 0), r = 1 and protection = 0, so the
        // row is exactly weight * [1, 1, 0, 0].
        let nlist = vec![1i32, 2, 3, 0, 2, -1];
        let out = make_env_mat(&unit_axes_coord(), &nlist, 1, 4, 2, 3, smooth_params(false));
        let w = smooth_weight(1.0, 0.5, 2.0);
        assert!(w > 0.0 && w < 1.0);
        assert_eq!(out.sw[0], w);
        assert_eq!(out.diff[..3], [1.0, 0.0, 0.0]);
        assert_eq!(out.env_mat[..4], [w, w, 0.0, 0.0]);
    }

    #[test]
    fn env_mat_radial_only_has_single_column() {
        let nlist = vec![1i32, 2, 3, 0, 2, -1];
        let out = make_env_mat(&unit_axes_coord(), &nlist, 1, 4, 2, 3, smooth_params(true));
        assert_eq!(out.last_dim, 1);
        assert_eq!(out.env_mat.len(), 6);
        assert_eq!(out.env_mat[0], smooth_weight(1.0, 0.5, 2.0));
        // The masked slot (i=1, j=2) stays zero.
        assert_eq!(out.env_mat[5], 0.0);
    }

    #[test]
    fn smooth_weight_clamps() {
        let rmin = 1.0;
        let rmax = 6.0;
        assert_eq!(smooth_weight(0.5, rmin, rmax), 1.0); // r < rmin → u=0 → 1
        assert_eq!(smooth_weight(10.0, rmin, rmax), 0.0); // r > rmax → u=1 → 0
        assert_eq!(smooth_weight(3.5, rmin, rmax), 0.5); // midpoint → u=0.5 → 0.5
    }

    #[test]
    fn exp_sw_clamps() {
        let rmin = 0.5;
        let rmax = 2.0;
        // Distances outside [0, rmax] behave like the nearest bound.
        assert_eq!(exp_sw(-1.0, rmin, rmax), exp_sw(0.0, rmin, rmax));
        assert_eq!(exp_sw(5.0, rmin, rmax), exp_sw(rmax, rmin, rmax));
        // At r == rmin the inner exponent is exp(0) = 1.
        assert_eq!(exp_sw(rmin, rmin, rmax), (-1.0f32).exp());
    }

    #[test]
    fn apply_stats_normalizes_per_type() {
        // nf = 1, nloc = 2, nnei = 1, last_dim = 2, ntypes = 2
        let mut em = vec![1.0f32, 2.0, 3.0, 4.0];
        let atype = [1i32, 0];
        let davg = [0.5f32, 1.0, 1.0, 0.0];
        let dstd = [2.0f32, 4.0, 0.5, 2.0];
        apply_stats(&mut em, &atype, &davg, &dstd, 1, 2, 1, 2, 2);
        assert_eq!(em, vec![0.0, 1.0, 1.25, 0.75]);
    }
}