// deepmd 0.1.0

// DeePMD-kit deep potential models as RLX IR graph builders
// Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! `se_r` (radial-only DeepPot-SE) descriptor graph builder.
//!
//! Translated from `DescrptSeR.call` in
//! `deepmd/dpmodel/descriptor/se_r.py`.
//!
//! Inputs:
//!
//! * `env_mat_raw` — `[nf, nloc, nnei, 1]` radial-only env matrix from
//!   the host (`R = s(r_ji)`, no unit-vector axes).
//! * `atype_loc`   — `[nf, nloc]` i32 atom types of local atoms.
//! * `exclude_mask` — `[nf, nloc, nnei]` f32 (optional).
//!
//! Topology:
//!
//! ```text
//!     R = (R_raw - davg[atype_loc]) / dstd[atype_loc]
//!     R *= exclude_mask
//!     for each neighbor type t:
//!         G_t = N_t(R_t)                       # [nf, nloc, sel[t], ng]
//!         G_t = mean over neighbors           # [nf, nloc, ng]
//!         out += G_t * (sel[t] / nnei)
//!     out *= 1/5
//! ```
//!
//! Note `se_r` only supports `type_one_side = true` upstream — we
//! mirror that restriction.

use anyhow::{bail, Result};
use rlx_ir::infer::GraphExt;
use rlx_ir::op::ReduceOp;
use rlx_ir::{DType, Graph, NodeId, Shape};
use serde::{Deserialize, Serialize};

use crate::nn::{embedding_mlp, scalar_const, ActivationKind, MlpSpec};

/// Configuration for the `se_r` descriptor.
///
/// Round-trips through serde. `rcut`, `rcut_smth` and `sel` carry no serde
/// default; the remaining fields fall back to the defaults noted on each
/// field.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SeRConfig {
    /// NOTE(review): presumably the neighbor cutoff radius. Not read by the
    /// code visible in this file — confirm against the env-matrix producer.
    pub rcut: f64,
    /// NOTE(review): presumably the radius where the smooth switching
    /// function starts. Not read by the code visible in this file — confirm.
    pub rcut_smth: f64,
    /// Per-neighbor-type slot counts. Its length is the number of types
    /// ([`SeRConfig::ntypes`]) and its sum is the total neighbor count
    /// ([`SeRConfig::nnei`]).
    pub sel: Vec<usize>,
    /// Layer widths handed to the embedding MLP; the last entry is the
    /// descriptor width ([`SeRConfig::ng`]). Defaults to `[24, 48, 96]`.
    #[serde(default = "default_embedding_neuron")]
    pub neuron: Vec<usize>,
    /// Forwarded unchanged to the embedding MLP spec. Defaults to `false`.
    #[serde(default)]
    pub resnet_dt: bool,
    /// Activation name, parsed with `ActivationKind::parse` when the graph
    /// is built. Defaults to `"tanh"`.
    #[serde(default = "default_activation")]
    pub activation_function: String,
    /// NOTE(review): presumably one element name per atom type. Not read by
    /// the code visible in this file — confirm. Defaults to `None`.
    #[serde(default)]
    pub type_map: Option<Vec<String>>,
}

/// Serde fallback for [`SeRConfig::neuron`]: three embedding layers of
/// widths 24, 48 and 96.
fn default_embedding_neuron() -> Vec<usize> {
    Vec::from([24, 48, 96])
}
/// Serde fallback for [`SeRConfig::activation_function`].
fn default_activation() -> String {
    String::from("tanh")
}

impl SeRConfig {
    /// Number of neighbor types, i.e. the number of entries in `sel`.
    pub fn ntypes(&self) -> usize {
        self.sel.len()
    }

    /// Total number of neighbor slots: the sum of all `sel` entries.
    pub fn nnei(&self) -> usize {
        self.sel.iter().copied().sum()
    }

    /// Width of the last embedding layer.
    ///
    /// # Panics
    ///
    /// Panics when `neuron` is empty.
    pub fn ng(&self) -> usize {
        self.neuron
            .last()
            .copied()
            .expect("se_r: empty neuron list")
    }

    /// Width of the descriptor output; identical to [`SeRConfig::ng`].
    pub fn dim_out(&self) -> usize {
        self.ng()
    }

    /// Exclusive prefix sums of `sel`, with the grand total appended.
    ///
    /// The result has `sel.len() + 1` entries; entry `t` is the first
    /// neighbor slot of type `t` and entry `t + 1` is one past its last.
    pub fn sel_cumsum(&self) -> Vec<usize> {
        let mut offsets = Vec::with_capacity(self.sel.len() + 1);
        offsets.push(0);
        offsets.extend(self.sel.iter().scan(0usize, |total, &count| {
            *total += count;
            Some(*total)
        }));
        offsets
    }
}

/// Handle returned from [`build_se_r_descriptor`].
///
/// Carries the graph node of the finished descriptor together with its
/// feature width so callers can size downstream layers.
pub struct SeRDescriptor {
    /// Descriptor `D` node, shape `[nf, nloc, ng]`. This is the node
    /// produced after the final `0.2` rescale.
    pub descriptor: NodeId,
    /// Output dimension (= `ng`), i.e. the last embedding layer width taken
    /// from the config at build time.
    pub dim_out: usize,
}

/// Builds the `se_r` (radial-only) descriptor sub-graph inside `g`.
///
/// The emitted graph:
///
/// 1. registers the parameters `descriptor.davg` and `descriptor.dstd`
///    (both `[ntypes, nnei, 1]`, f32), gathers them along axis 0 with
///    `atype_loc`, and normalizes `env_mat_raw` as `(raw - davg) / dstd`;
/// 2. if `exclude_mask` is given, reshapes it to `[nf, nloc, nnei, 1]` and
///    multiplies it into the normalized matrix;
/// 3. for every neighbor type `t` with `sel[t] > 0`, narrows axis 2 to that
///    type's slot range, runs the embedding MLP with parameter prefix
///    `descriptor.embedding.{t}`, sums over the neighbor axis and scales by
///    `1 / nnei`; the per-type results are added together;
/// 4. multiplies the accumulated result by `0.2`.
///
/// # Arguments
///
/// * `g` — graph that receives the new nodes.
/// * `cfg` — descriptor configuration (`sel`, `neuron`, activation, ...).
/// * `env_mat_raw` — rank-4 env-matrix node, `[nf, nloc, nnei, 1]`.
/// * `atype_loc` — node used as the gather index into `davg` / `dstd`.
/// * `nf`, `nloc` — frame and local-atom counts used for explicit shapes.
/// * `exclude_mask` — optional mask node, reshaped to `[nf, nloc, nnei, 1]`.
///
/// # Errors
///
/// Returns an error, in this order of precedence, when
///
/// * `cfg.activation_function` is rejected by `ActivationKind::parse`;
/// * `cfg.neuron` is empty;
/// * `env_mat_raw` does not have rank 4;
/// * `sum(cfg.sel) == 0`.
///
/// The configuration and rank checks all run before the first node is added,
/// so `g` is left untouched when any of them fails.
pub fn build_se_r_descriptor(
    g: &mut Graph,
    cfg: &SeRConfig,
    env_mat_raw: NodeId,
    atype_loc: NodeId,
    nf: usize,
    nloc: usize,
    exclude_mask: Option<NodeId>,
) -> Result<SeRDescriptor> {
    let activation = ActivationKind::parse(&cfg.activation_function)?;
    let ntypes = cfg.ntypes();
    let nnei = cfg.nnei();
    // `cfg.ng()` panics on an empty layer list; report it as an error here
    // because the config typically comes from user-supplied input.
    let ng = match cfg.neuron.last() {
        Some(&width) => width,
        None => bail!("se_r: empty neuron list"),
    };
    let sec = cfg.sel_cumsum();

    let rr_shape = g.shape(env_mat_raw).clone();
    if rr_shape.rank() != 4 {
        bail!(
            "se_r: env-matrix input must have rank 4 [nf, nloc, nnei, 1], got rank {}",
            rr_shape.rank()
        );
    }

    // Reject an empty selection before registering any parameter so a failed
    // build does not leave orphan nodes in the graph.
    if nnei == 0 {
        bail!("se_r: empty selection (sum(sel) == 0)");
    }

    // Per-type normalization statistics, selected per atom via `atype_loc`.
    let davg = g.param(
        "descriptor.davg",
        Shape::new(&[ntypes, nnei, 1], DType::F32),
    );
    let dstd = g.param(
        "descriptor.dstd",
        Shape::new(&[ntypes, nnei, 1], DType::F32),
    );
    let davg_g = g.gather_(davg, atype_loc, 0);
    let dstd_g = g.gather_(dstd, atype_loc, 0);
    let mut rr = g.sub(env_mat_raw, davg_g);
    rr = g.div(rr, dstd_g);

    if let Some(mask) = exclude_mask {
        // Append a trailing unit axis so the mask lines up with `rr`.
        let mask_shape = Shape::new(&[nf, nloc, nnei, 1], DType::F32);
        let mask_4d = g.reshape(
            mask,
            vec![nf as i64, nloc as i64, nnei as i64, 1],
            mask_shape,
        );
        rr = g.mul(rr, mask_4d);
    }

    let inv_nnei = 1.0 / nnei as f32;
    let mut acc: Option<NodeId> = None;
    // Consecutive pairs of the prefix sums delimit each type's slot range.
    for (t, bounds) in sec.windows(2).enumerate() {
        let start = bounds[0];
        let len = bounds[1] - start;
        if len == 0 {
            // A type without neighbor slots contributes nothing.
            continue;
        }
        let rr_t = g.narrow_(rr, 2, start, len); // [nf, nloc, len, 1]
        let prefix = format!("descriptor.embedding.{t}");
        let mlp = MlpSpec {
            param_prefix: &prefix,
            in_dim: 1,
            neuron: &cfg.neuron,
            activation,
            resnet_dt: cfg.resnet_dt,
        };
        let gg_t = embedding_mlp(g, &mlp, rr_t); // [nf, nloc, len, ng]
        // Σ over the nnei axis. Python does mean(gg)*sel[t]/nnei, but
        // sel[t] == len so the factors collapse to sum_g * (1/nnei).
        let mean_shape = Shape::new(&[nf, nloc, ng], DType::F32);
        let sum_g = g.reduce(gg_t, ReduceOp::Sum, vec![2], false, mean_shape);
        let scale = scalar_const(g, inv_nnei);
        let scaled = g.mul(sum_g, scale);
        acc = Some(match acc {
            None => scaled,
            Some(prev) => g.add(prev, scaled),
        });
    }

    // With `nnei > 0` at least one type contributed, so this cannot fail;
    // kept as a defensive check rather than an unwrap.
    let xyz_scatter =
        acc.ok_or_else(|| anyhow::anyhow!("se_r: empty selection (sum(sel) == 0)"))?;

    // res = xyz_scatter / 5 (matches the Python `res_rescale = 1/5`).
    let five_inv = scalar_const(g, 0.2);
    let res = g.mul(xyz_scatter, five_inv);

    Ok(SeRDescriptor {
        descriptor: res,
        dim_out: ng,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a two-type configuration builds without error, reports
    /// the configured output width, and adds a non-trivial number of nodes.
    #[test]
    fn se_r_descriptor_builds() {
        let (nf, nloc) = (1, 8);
        let cfg = SeRConfig {
            rcut: 6.0,
            rcut_smth: 0.5,
            sel: vec![46, 92],
            neuron: vec![16, 32, 64],
            resnet_dt: false,
            activation_function: "tanh".into(),
            type_map: None,
        };

        let mut g = Graph::new("se_r_test");
        let env_shape = Shape::new(&[nf, nloc, cfg.nnei(), 1], DType::F32);
        let env_mat = g.input("env_mat", env_shape);
        let atype_shape = Shape::new(&[nf, nloc], DType::I32);
        let atype = g.input("atype", atype_shape);

        let built = build_se_r_descriptor(&mut g, &cfg, env_mat, atype, nf, nloc, None);
        let out = built.expect("build");

        assert_eq!(out.dim_out, cfg.dim_out());
        assert!(g.len() > 10);
    }
}