deepmd 0.1.0

DeePMD-kit deep potential models as RLX IR graph builders
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! ZBL composite model graph builders.
//!
//! Translated from
//! `deepmd/dpmodel/atomic_model/pairtab_atomic_model.py` and
//! `linear_atomic_model.DPZBLLinearEnergyAtomicModel`.
//!
//! The full DP+ZBL atomic model is a smooth interpolation between two
//! atomic energies:
//!
//! ```text
//!     E^i = (1 - c(σ_i))·E^i_DP + c(σ_i)·E^i_ZBL
//! ```
//!
//! where `c(σ)` is a quintic smooth weight on the softmin distance `σ`
//! (the typical nearest-neighbor proxy) and `E^i_ZBL` is the pairwise
//! tabulated short-range repulsion summed over neighbors.
//!
//! Three graph builders here:
//!
//! * [`build_pair_tab_energy`] — pair-tabulated per-atom energy.
//! * [`build_zbl_weight`] — quintic smooth interpolation weight from
//!   the softmin distance.
//! * [`combine_dp_zbl`] — applies the weight to combine two pre-built
//!   atomic-energy nodes.

use anyhow::Result;
use rlx_ir::infer::GraphExt;
use rlx_ir::op::{BinaryOp, ReduceOp};
use rlx_ir::{DType, Graph, NodeId, Shape};

use crate::nn::scalar_const;

/// Static description of a pair-tabulated potential, consumed by
/// [`build_pair_tab_energy`].
///
/// The spline table covers `r ∈ [rmin, rmin + nspline·hh)` using
/// `nspline` equal-width buckets, one set per `(type_i, type_j)` pair.
#[derive(Clone, Debug)]
pub struct PairTabConfig {
    /// Number of atom types; the table is indexed by `(type_i, type_j)`.
    pub ntypes: usize,
    /// Number of buckets in the cubic-spline table.
    pub nspline: usize,
    /// Lower edge of the first bucket (`tab_info[0]` in Python).
    pub rmin: f32,
    /// Width of every bucket (`tab_info[1]` in Python).
    pub hh: f32,
    /// Cutoff radius; pair energies at `r ≥ rcut` are forced to zero.
    pub rcut: f32,
}

/// Graph inputs for [`build_pair_tab_energy`].  All four nodes are
/// produced on the host.
pub struct PairTabInputs {
    /// `[nf, nloc, nnei]` f32 pairwise distances (already computed from
    /// coords by the host).  Masked slots are zeroed via `nlist_mask`, but
    /// they must not hold NaN: the mask is applied by multiplication and
    /// `0 · NaN` is still NaN.  Any finite placeholder is safe.
    pub pairwise_rr: NodeId,
    /// `[nf, nloc]` i32 — local-atom types.  Every entry must lie in
    /// `[0, ntypes)` because it feeds the spline-table gather index.
    pub atype_loc: NodeId,
    /// `[nf, nloc, nnei]` i32 — neighbor types.  Masked slots must still
    /// hold a valid type in `[0, ntypes)` (e.g. `0`): the value is used to
    /// compute the gather index before `nlist_mask` zeroes the result.
    pub nei_atype: NodeId,
    /// `[nf, nloc, nnei]` f32 — `1` for real neighbors, `0` for padding.
    /// Multiplied into the per-pair energy.
    pub nlist_mask: NodeId,
}

/// Build the pair-tabulated per-atom energy graph (returns
/// `[nf, nloc, 1]`).
///
/// Params emitted:
///
/// ```text
///     pairtab.coef   # [ntypes, ntypes, nspline, 4]  cubic spline coefficients
/// ```
///
/// The polynomial evaluated per bucket is
/// `a0 + a1·u + a2·u² + a3·u³` with `u = (r - rmin)/hh - idx` clamped
/// to `[0, 1)`.  Energies are halved (the `½` in
/// `E^i = ½ Σ E_{ij}` to avoid double counting), and zeroed out for
/// `r ≥ rcut` and for masked neighbors.
pub fn build_pair_tab_energy(
    g: &mut Graph,
    cfg: &PairTabConfig,
    inputs: PairTabInputs,
    nf: usize,
    nloc: usize,
    nnei: usize,
) -> Result<NodeId> {
    let ntypes = cfg.ntypes;
    let nspline = cfg.nspline;

    // Compute bucket index `idx = floor(uu)`, `uu = (r - rmin)/hh - idx`.
    let rmin = scalar_const(g, cfg.rmin);
    let hi = scalar_const(g, 1.0 / cfg.hh);
    let nspline_minus_one = scalar_const(g, (nspline - 1) as f32);

    let shifted = g.sub(inputs.pairwise_rr, rmin);
    let uu_total = g.mul(shifted, hi); // [nf, nloc, nnei]
    // Clamp uu_total to [0, nspline - 1].
    let zero = scalar_const(g, 0.0);
    let uu_clamped =
        g.binary(BinaryOp::Max, uu_total, zero, g.shape(uu_total).clone());
    let uu_clamped = g.binary(
        BinaryOp::Min,
        uu_clamped,
        nspline_minus_one,
        g.shape(uu_clamped).clone(),
    );
    // idx_f = floor(uu_clamped) → uu_clamped - (uu_clamped - floor(...))
    // The IR may not expose floor directly; cast to i32 to truncate
    // (since uu_clamped ≥ 0 truncation == floor).
    let idx_i32 = g.cast(uu_clamped, DType::I32);
    let idx_f = g.cast(idx_i32, DType::F32);
    let uu = g.sub(uu_clamped, idx_f); // fractional part ∈ [0, 1)

    // Build the linear lookup index for the table.
    // table shape: [ntypes, ntypes, nspline, 4]
    // flat index = ((i_type * ntypes) + j_type) * nspline + idx
    let ntypes_c = scalar_i32(g, ntypes as i32);
    let nspline_c = scalar_i32(g, nspline as i32);
    let i_type_4d = g.reshape(
        inputs.atype_loc,
        vec![nf as i64, nloc as i64, 1],
        Shape::new(&[nf, nloc, 1], DType::I32),
    );
    // Broadcast i_type_4d to [nf, nloc, nnei] via zero add.
    let zero_i32_n = i32_zero_tensor(g, &[nf, nloc, nnei]);
    let i_type_3d = g.binary(
        BinaryOp::Add,
        i_type_4d,
        zero_i32_n,
        Shape::new(&[nf, nloc, nnei], DType::I32),
    );
    let i_scaled = g.binary(
        BinaryOp::Mul,
        i_type_3d,
        ntypes_c,
        g.shape(i_type_3d).clone(),
    );
    let pair_id = g.binary(
        BinaryOp::Add,
        i_scaled,
        inputs.nei_atype,
        g.shape(i_scaled).clone(),
    );
    let pair_id_scaled = g.binary(
        BinaryOp::Mul,
        pair_id,
        nspline_c,
        g.shape(pair_id).clone(),
    );
    let flat_idx = g.binary(
        BinaryOp::Add,
        pair_id_scaled,
        idx_i32,
        g.shape(pair_id_scaled).clone(),
    );

    // Reshape table to [ntypes² · nspline, 4] for a single gather.
    let coef = g.param(
        "pairtab.coef",
        Shape::new(&[ntypes, ntypes, nspline, 4], DType::F32),
    );
    let coef_flat_shape = Shape::new(&[ntypes * ntypes * nspline, 4], DType::F32);
    let coef_flat = g.reshape(
        coef,
        vec![(ntypes * ntypes * nspline) as i64, 4],
        coef_flat_shape,
    );

    // Flatten flat_idx to 1D, gather, reshape back.
    let total = nf * nloc * nnei;
    let flat_idx_1d = g.reshape(
        flat_idx,
        vec![total as i64],
        Shape::new(&[total], DType::I32),
    );
    let gathered = g.gather_(coef_flat, flat_idx_1d, 0); // [total, 4]
    let coef_per_pair = g.reshape(
        gathered,
        vec![nf as i64, nloc as i64, nnei as i64, 4],
        Shape::new(&[nf, nloc, nnei, 4], DType::F32),
    );

    // Evaluate cubic: a0 + a1·u + a2·u² + a3·u³.
    let a0 = g.narrow_(coef_per_pair, 3, 0, 1);
    let a1 = g.narrow_(coef_per_pair, 3, 1, 1);
    let a2 = g.narrow_(coef_per_pair, 3, 2, 1);
    let a3 = g.narrow_(coef_per_pair, 3, 3, 1);
    // Squeeze last dim → [nf, nloc, nnei]
    let squeeze = |g: &mut Graph, x: NodeId| {
        g.reshape(
            x,
            vec![nf as i64, nloc as i64, nnei as i64],
            Shape::new(&[nf, nloc, nnei], DType::F32),
        )
    };
    let a0 = squeeze(g, a0);
    let a1 = squeeze(g, a1);
    let a2 = squeeze(g, a2);
    let a3 = squeeze(g, a3);
    let u2 = g.mul(uu, uu);
    let u3 = g.mul(u2, uu);
    let t1 = g.mul(a1, uu);
    let t2 = g.mul(a2, u2);
    let t3 = g.mul(a3, u3);
    let mut ener = g.add(a0, t1);
    ener = g.add(ener, t2);
    ener = g.add(ener, t3);

    // Zero out beyond rcut.
    let rcut_c = scalar_const(g, cfg.rcut);
    let _ = rcut_c;
    let inside_mask = step_lt(g, inputs.pairwise_rr, cfg.rcut, nf, nloc, nnei);
    let ener = g.mul(ener, inside_mask);

    // Zero out masked neighbors.
    let ener = g.mul(ener, inputs.nlist_mask);

    // Σ over neighbors, halve, reshape to [nf, nloc, 1].
    let sum_shape = Shape::new(&[nf, nloc], DType::F32);
    let summed = g.reduce(ener, ReduceOp::Sum, vec![2], false, sum_shape);
    let half = scalar_const(g, 0.5);
    let halved = g.mul(summed, half);
    let out_shape = Shape::new(&[nf, nloc, 1], DType::F32);
    let out = g.reshape(halved, vec![nf as i64, nloc as i64, 1], out_shape);
    Ok(out)
}

/// Smooth DP/ZBL interpolation weight `c(σ)`; `c = 1` means pure ZBL.
///
/// The weight is the quintic `1 - 10u³ + 15u⁴ - 6u⁵`, where
/// `u = clamp((σ - sw_rmin)/(sw_rmax - sw_rmin), 0, 1)`.  It equals `1` for
/// `σ ≤ sw_rmin` and `0` for `σ ≥ sw_rmax`.  This is taken from
/// `DPZBLLinearEnergyAtomicModel._compute_weight`, which uses the same
/// convention.
///
/// Atoms whose σ is zero (no neighbors, or masked atoms) get `c = 0`, i.e.
/// pure DP.  This mirrors upstream's `sigma != 0` masking of the
/// coefficient; without it such atoms would fall into the `σ < sw_rmin`
/// branch and get pure ZBL.  The zero test is the soft indicator
/// `min(max(σ, 0)·1e30, 1)`, so σ values below ~1e-30 count as (partially)
/// zero.
///
/// The σ input is `[nf, nloc]` — typically the host-computed softmin
/// over neighbor distances, assumed non-negative.  Returns
/// `[nf, nloc, 1]` f32 ready to broadcast against an atomic-energy node.
///
/// `sw_rmax` must be strictly greater than `sw_rmin`; this is checked in
/// debug builds, and otherwise the range inverse is infinite.
pub fn build_zbl_weight(
    g: &mut Graph,
    sigma: NodeId,
    sw_rmin: f32,
    sw_rmax: f32,
    nf: usize,
    nloc: usize,
) -> NodeId {
    debug_assert!(
        sw_rmax > sw_rmin,
        "build_zbl_weight: sw_rmax ({}) must exceed sw_rmin ({})",
        sw_rmax,
        sw_rmin
    );
    let rmin = scalar_const(g, sw_rmin);
    let inv_rng = scalar_const(g, 1.0 / (sw_rmax - sw_rmin));
    let one = scalar_const(g, 1.0);
    let zero = scalar_const(g, 0.0);

    let shifted = g.sub(sigma, rmin);
    let u = g.mul(shifted, inv_rng);
    let u = g.binary(BinaryOp::Max, u, zero, g.shape(u).clone());
    let u = g.binary(BinaryOp::Min, u, one, g.shape(u).clone());
    let u2 = g.mul(u, u);
    let u3 = g.mul(u2, u);
    let u4 = g.mul(u3, u);
    let u5 = g.mul(u4, u);
    let c10 = scalar_const(g, 10.0);
    let c15 = scalar_const(g, 15.0);
    let c6 = scalar_const(g, 6.0);
    let t1 = g.mul(c10, u3);
    let t2 = g.mul(c15, u4);
    let t3 = g.mul(c6, u5);
    // smooth = 1 - 10u³ + 15u⁴ - 6u⁵ on σ ∈ [rmin, rmax]; the clamps above
    // give smooth = 1 for σ < rmin (u = 0) and smooth = 0 for σ ≥ rmax (u = 1).
    let smooth = g.sub(one, t1);
    let smooth = g.add(smooth, t2);
    let smooth = g.sub(smooth, t3);

    // σ == 0 marks atoms with no neighbors; force their weight to 0
    // (pure DP).  Soft indicator: min(max(σ, 0)·1e30, 1).
    let big = scalar_const(g, 1e30);
    let sigma_shape = g.shape(sigma).clone();
    let sigma_pos = g.binary(BinaryOp::Max, sigma, zero, sigma_shape.clone());
    let sigma_big = g.mul(sigma_pos, big);
    let has_sigma = g.binary(BinaryOp::Min, sigma_big, one, sigma_shape);
    let weight = g.mul(smooth, has_sigma);

    let out_shape = Shape::new(&[nf, nloc, 1], DType::F32);
    g.reshape(weight, vec![nf as i64, nloc as i64, 1], out_shape)
}

/// Blend two per-atom energies with the weight `c`:
/// `(1 - c)·E_dp + c·E_zbl`.
///
/// `c = 0` selects pure DP and `c = 1` pure ZBL.  The three inputs are
/// assumed to have broadcast-compatible shapes (e.g. the `[nf, nloc, 1]`
/// outputs of the builders in this module) — confirm at call sites.
pub fn combine_dp_zbl(g: &mut Graph, e_dp: NodeId, e_zbl: NodeId, c: NodeId) -> NodeId {
    let dp_weight = {
        let unit = scalar_const(g, 1.0);
        g.sub(unit, c)
    };
    let dp_term = g.mul(dp_weight, e_dp);
    let zbl_term = g.mul(c, e_zbl);
    g.add(dp_term, zbl_term)
}

/// Add a single-element `[1]` i32 constant node holding `v`, stored as
/// little-endian bytes.
fn scalar_i32(g: &mut Graph, v: i32) -> NodeId {
    let shape = Shape::new(&[1], DType::I32);
    let op = rlx_ir::op::Op::Constant {
        data: Vec::from(v.to_le_bytes()),
    };
    g.add_node(op, Vec::new(), shape)
}

/// Add an all-zero i32 constant node of the given `shape`.  Each element
/// takes 4 bytes.
fn i32_zero_tensor(g: &mut Graph, shape: &[usize]) -> NodeId {
    let elems = shape.iter().copied().product::<usize>();
    let data = vec![0u8; elems * std::mem::size_of::<i32>()];
    g.add_node(
        rlx_ir::op::Op::Constant { data },
        Vec::new(),
        Shape::new(shape, DType::I32),
    )
}

/// Arithmetic step mask for `x < threshold`, produced as f32.
///
/// Evaluates `min(max(threshold - x, 0) · 1e30, 1)` using only
/// sub / neg / max / mul / min nodes:
///
/// * `x ≥ threshold`: the `max` clamps to `0`, giving `0.0`;
/// * `x < threshold` by more than about `1e-30`: the product saturates to
///   `1.0`.
///
/// Gaps smaller than ~1e-30 give a fractional value.  For f32 inputs of
/// ordinary magnitude such gaps are not representable, so the mask is
/// effectively binary.  `nf`, `nloc` and `nnei` are unused; they are
/// kept for call-site compatibility.
fn step_lt(
    g: &mut Graph,
    x: NodeId,
    threshold: f32,
    nf: usize,
    nloc: usize,
    nnei: usize,
) -> NodeId {
    let _ = (nf, nloc, nnei);
    let thr = scalar_const(g, threshold);
    // margin = threshold - x, written as -(x - threshold).
    let below = {
        let x_minus_thr = g.sub(x, thr);
        g.neg(x_minus_thr)
    };
    let zero = scalar_const(g, 0.0);
    let below_shape = g.shape(below).clone();
    let margin = g.binary(BinaryOp::Max, below, zero, below_shape);
    // Blow any positive margin up past 1, then saturate at 1.
    let huge = scalar_const(g, 1e30);
    let amplified = g.mul(margin, huge);
    let one = scalar_const(g, 1.0);
    let amplified_shape = g.shape(amplified).clone();
    g.binary(BinaryOp::Min, amplified, one, amplified_shape)
}

#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::Dim;

    /// The ZBL weight must come out as `[nf, nloc, 1]`.
    #[test]
    fn zbl_weight_builds() {
        let mut g = Graph::new("zbl_weight");
        let sigma = g.input("sigma", Shape::new(&[1, 4], DType::F32));
        let w = build_zbl_weight(&mut g, sigma, 0.5, 1.5, 1, 4);
        let s = g.shape(w);
        assert_eq!(s.dim(0), Dim::Static(1));
        assert_eq!(s.dim(1), Dim::Static(4));
        assert_eq!(s.dim(2), Dim::Static(1));
    }

    /// The pair-tab energy must come out as `[nf, nloc, 1]`.
    #[test]
    fn pair_tab_energy_builds() {
        let mut g = Graph::new("pair_tab");
        let nf = 1;
        let nloc = 2;
        let nnei = 8;
        let cfg = PairTabConfig {
            ntypes: 2,
            nspline: 16,
            rmin: 0.0,
            hh: 0.5,
            rcut: 6.0,
        };
        let inputs = PairTabInputs {
            pairwise_rr: g.input("rr", Shape::new(&[nf, nloc, nnei], DType::F32)),
            atype_loc: g.input("atype", Shape::new(&[nf, nloc], DType::I32)),
            nei_atype: g.input("nei_atype", Shape::new(&[nf, nloc, nnei], DType::I32)),
            nlist_mask: g.input("nlist_mask", Shape::new(&[nf, nloc, nnei], DType::F32)),
        };
        let e = build_pair_tab_energy(&mut g, &cfg, inputs, nf, nloc, nnei).expect("build");
        let s = g.shape(e);
        assert_eq!(s.dim(0), Dim::Static(1));
        assert_eq!(s.dim(1), Dim::Static(2));
        assert_eq!(s.dim(2), Dim::Static(1));
    }

    /// Blending two `[nf, nloc, 1]` energies keeps that shape.
    #[test]
    fn combine_dp_zbl_preserves_energy_shape() {
        let mut g = Graph::new("dp_zbl");
        let sigma = g.input("sigma", Shape::new(&[1, 4], DType::F32));
        let c = build_zbl_weight(&mut g, sigma, 0.5, 1.5, 1, 4);
        let e_dp = g.input("e_dp", Shape::new(&[1, 4, 1], DType::F32));
        let e_zbl = g.input("e_zbl", Shape::new(&[1, 4, 1], DType::F32));
        let e = combine_dp_zbl(&mut g, e_dp, e_zbl, c);
        let s = g.shape(e);
        assert_eq!(s.dim(0), Dim::Static(1));
        assert_eq!(s.dim(1), Dim::Static(4));
        assert_eq!(s.dim(2), Dim::Static(1));
    }
}