deepmd 0.1.0

DeePMD-kit deep potential models as RLX IR graph builders
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! `se_t` (DeepPot-SE T, three-body) descriptor graph builder.
//!
//! Translated from `DescrptSeT.call` in
//! `deepmd/dpmodel/descriptor/se_t.py`.
//!
//! Topology (per unordered pair of neighbor types `(ti, tj)` with
//! `ti ≤ tj`):
//!
//! ```text
//!     R    = (R_raw - davg[atype]) / dstd[atype]                # [nf, nloc, nnei, 4]
//!     R   *= exclude_mask
//!     R_i  = R[:, :, sec_ti:sec_ti+nt_i, 1:4]                   # xyz only
//!     R_j  = R[:, :, sec_tj:sec_tj+nt_j, 1:4]
//!     env  = R_i · R_jᵀ                                         # [nf, nloc, nt_i, nt_j]
//!     gg   = N_{ti,tj}(env.unsqueeze(-1))                       # [nf, nloc, nt_i, nt_j, ng]
//!     res += (env.unsqueeze(-1) * gg).sum(dims=[2,3]) / (nt_i·nt_j)
//! ```

use anyhow::{bail, Result};
use rlx_ir::infer::GraphExt;
use rlx_ir::op::ReduceOp;
use rlx_ir::{DType, Graph, NodeId, Shape};
use serde::{Deserialize, Serialize};

use crate::nn::{embedding_mlp, scalar_const, ActivationKind, MlpSpec};

/// User-facing configuration of the `se_t` descriptor, (de)serialized with serde.
///
/// Only `rcut`, `rcut_smth` and `sel` are mandatory when deserializing; the
/// remaining fields fall back to the defaults named in their `serde` attributes.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SeTConfig {
    /// Cutoff radius. Not read by the visible graph-building code —
    /// presumably consumed where the environment matrix is computed (TODO confirm).
    pub rcut: f64,
    /// Smoothing-onset radius. Not read by the visible graph-building code —
    /// presumably paired with `rcut` by the env-matrix producer (TODO confirm).
    pub rcut_smth: f64,
    /// Number of neighbor slots reserved per neighbor type. Its length is the
    /// type count and its sum is the total neighbor count used for slicing.
    pub sel: Vec<usize>,
    /// Layer widths handed to the embedding MLP; the last entry is the
    /// descriptor's output width. Defaults to `[24, 48, 96]`.
    #[serde(default = "default_neuron")]
    pub neuron: Vec<usize>,
    /// Forwarded verbatim to the embedding MLP spec. Defaults to `false`.
    #[serde(default)]
    pub resnet_dt: bool,
    /// Activation name, parsed with `ActivationKind::parse` when the graph is
    /// built. Defaults to `"tanh"`.
    #[serde(default = "default_activation")]
    pub activation_function: String,
    /// Optional list of type names. Not read by the visible graph-building
    /// code. Defaults to `None`.
    #[serde(default)]
    pub type_map: Option<Vec<String>>,
}

/// Serde fallback for `SeTConfig::neuron`: three layers of width 24, 48 and 96.
fn default_neuron() -> Vec<usize> {
    [24, 48, 96].to_vec()
}
/// Serde fallback for `SeTConfig::activation_function`.
fn default_activation() -> String {
    String::from("tanh")
}

impl SeTConfig {
    /// Number of neighbor types, i.e. how many entries `sel` has.
    pub fn ntypes(&self) -> usize {
        self.sel.len()
    }

    /// Total number of neighbor slots: the sum over all entries of `sel`.
    pub fn nnei(&self) -> usize {
        self.sel.iter().copied().sum()
    }

    /// Width of the final embedding layer (last entry of `neuron`).
    ///
    /// # Panics
    /// Panics when `neuron` is empty.
    pub fn ng(&self) -> usize {
        self.neuron
            .last()
            .copied()
            .expect("se_t: empty neuron list")
    }

    /// Size of the descriptor's trailing axis; identical to [`Self::ng`].
    pub fn dim_out(&self) -> usize {
        self.ng()
    }

    /// Prefix sums of `sel` with a leading zero, so that type `t` occupies the
    /// half-open slot range `result[t]..result[t + 1]`. The returned vector
    /// has `sel.len() + 1` entries.
    pub fn sel_cumsum(&self) -> Vec<usize> {
        let mut offsets = Vec::with_capacity(self.sel.len() + 1);
        offsets.push(0);
        offsets.extend(self.sel.iter().scan(0usize, |running, &count| {
            *running += count;
            Some(*running)
        }));
        offsets
    }
}

/// Output of [`build_se_t_descriptor`]: the descriptor node together with the
/// size of its trailing feature axis.
pub struct SeTDescriptor {
    /// Descriptor `D` node, shape `[nf, nloc, ng]`.
    pub descriptor: NodeId,
    /// Feature width of `descriptor`; the builder sets it to `cfg.ng()`.
    pub dim_out: usize,
}

/// Appends the `se_t` (three-body) descriptor sub-graph to `g`.
///
/// The raw environment matrix is normalized with the per-type `descriptor.davg`
/// / `descriptor.dstd` parameters, optionally masked, and reduced to its xyz
/// columns. For every neighbor-type pair `(ti, tj)` with `ti <= tj` and a
/// non-empty selection on both sides, the pairwise dot products are fed through
/// the embedding MLP `descriptor.embedding.{ti}.{tj}`, weighted by the dot
/// products themselves, summed over both neighbor axes and averaged by
/// `nt_i * nt_j`. The per-pair results are accumulated into one node.
///
/// # Parameters
/// - `g`: graph that receives the new parameter and op nodes.
/// - `cfg`: descriptor configuration (`sel`, `neuron`, activation, `resnet_dt`).
/// - `env_mat_raw`: un-normalized environment matrix node; must be rank 4 and
///   is expected to be `[nf, nloc, nnei, 4]`.
/// - `atype_loc`: index node used to gather the per-type rows of `davg`/`dstd`
///   along axis 0.
/// - `nf`, `nloc`: static frame / local-atom counts used to declare the shapes
///   of the reshape and reduce outputs.
/// - `exclude_mask`: optional node reshaped to `[nf, nloc, nnei, 1]` and
///   multiplied into the normalized matrix.
///
/// # Returns
/// A [`SeTDescriptor`] whose node is declared as `[nf, nloc, ng]` and whose
/// `dim_out` equals `cfg.ng()`.
///
/// # Errors
/// Returns an error when
/// - the activation name in `cfg` cannot be parsed,
/// - `cfg.neuron` is empty (previously a panic inside `cfg.ng()`),
/// - `env_mat_raw` is not rank 4, or
/// - `cfg.sel` contains no non-zero entry ("empty selection"); this is now
///   reported before any node is added to `g`.
pub fn build_se_t_descriptor(
    g: &mut Graph,
    cfg: &SeTConfig,
    env_mat_raw: NodeId,
    atype_loc: NodeId,
    nf: usize,
    nloc: usize,
    exclude_mask: Option<NodeId>,
) -> Result<SeTDescriptor> {
    let activation = ActivationKind::parse(&cfg.activation_function)?;
    // `cfg.ng()` panics on an empty layer list; surface that as a normal error
    // since the config typically comes from deserialized user input.
    if cfg.neuron.is_empty() {
        bail!("se_t: empty neuron list");
    }
    let ntypes = cfg.ntypes();
    let nnei = cfg.nnei();
    let ng = cfg.ng();
    let sec = cfg.sel_cumsum();

    let rr_shape = g.shape(env_mat_raw).clone();
    if rr_shape.rank() != 4 {
        bail!(
            "se_t: env-matrix input must have rank 4 [nf, nloc, nnei, 4], got rank {}",
            rr_shape.rank()
        );
    }

    // With no neighbor slot at all, every (ti, tj) pair below would be skipped.
    // Fail here, before the graph is touched, instead of after the params and
    // normalization nodes have already been inserted.
    if nnei == 0 {
        bail!("se_t: empty selection");
    }

    let davg = g.param(
        "descriptor.davg",
        Shape::new(&[ntypes, nnei, 4], DType::F32),
    );
    let dstd = g.param(
        "descriptor.dstd",
        Shape::new(&[ntypes, nnei, 4], DType::F32),
    );
    // Pick the statistics row matching each atom's type, then normalize.
    let davg_g = g.gather_(davg, atype_loc, 0);
    let dstd_g = g.gather_(dstd, atype_loc, 0);
    let mut rr = g.sub(env_mat_raw, davg_g);
    rr = g.div(rr, dstd_g);

    if let Some(mask) = exclude_mask {
        // Give the mask a trailing unit axis so it lines up with the 4-wide
        // last axis of `rr` in the multiply.
        let mask_shape = Shape::new(&[nf, nloc, nnei, 1], DType::F32);
        let mask_4d = g.reshape(
            mask,
            vec![nf as i64, nloc as i64, nnei as i64, 1],
            mask_shape,
        );
        rr = g.mul(rr, mask_4d);
    }

    // Drop the radial column; keep the xyz unit-vector columns.
    let rr_xyz = g.narrow_(rr, 3, 1, 3); // [nf, nloc, nnei, 3]

    let mut acc: Option<NodeId> = None;
    for ti in 0..ntypes {
        // Only `tj >= ti`: each unordered type pair is visited once.
        for tj in ti..ntypes {
            let nt_i = sec[ti + 1] - sec[ti];
            let nt_j = sec[tj + 1] - sec[tj];
            if nt_i == 0 || nt_j == 0 {
                // A type without reserved slots contributes nothing (and would
                // divide by zero in the averaging factor).
                continue;
            }
            // [nf, nloc, nt_i, 3]
            let rr_i = g.narrow_(rr_xyz, 2, sec[ti], nt_i);
            // [nf, nloc, nt_j, 3]
            let rr_j = g.narrow_(rr_xyz, 2, sec[tj], nt_j);
            // env = rr_i · rr_jᵀ : [nf, nloc, nt_i, nt_j]
            let rr_j_t = g.transpose_(rr_j, vec![0, 1, 3, 2]);
            let env_ij = g.mm(rr_i, rr_j_t);

            // env_ij[..., None] → [nf, nloc, nt_i, nt_j, 1]
            let env_5d_shape = Shape::new(&[nf, nloc, nt_i, nt_j, 1], DType::F32);
            let env_5d = g.reshape(
                env_ij,
                vec![nf as i64, nloc as i64, nt_i as i64, nt_j as i64, 1],
                env_5d_shape,
            );

            // Embedding MLP over the trailing "1" axis → [nf, nloc, nt_i, nt_j, ng]
            let prefix = format!("descriptor.embedding.{ti}.{tj}");
            let mlp = MlpSpec {
                param_prefix: &prefix,
                in_dim: 1,
                neuron: &cfg.neuron,
                activation,
                resnet_dt: cfg.resnet_dt,
            };
            let gg = embedding_mlp(g, &mlp, env_5d);

            // weighted_sum = sum_{i,j}(env_ij[..., None] * gg) over axes 2,3
            // → [nf, nloc, ng]
            let weighted = g.mul(env_5d, gg);
            let sum_shape = Shape::new(&[nf, nloc, ng], DType::F32);
            let summed = g.reduce(weighted, ReduceOp::Sum, vec![2, 3], false, sum_shape);

            // Average over the nt_i * nt_j neighbor pairs of this type pair.
            let scale = scalar_const(g, 1.0 / (nt_i as f32 * nt_j as f32));
            let scaled = g.mul(summed, scale);

            acc = Some(match acc {
                None => scaled,
                Some(prev) => g.add(prev, scaled),
            });
        }
    }

    // Unreachable after the `nnei == 0` guard above; kept as a defensive check.
    let res = acc.ok_or_else(|| anyhow::anyhow!("se_t: empty selection"))?;
    Ok(SeTDescriptor {
        descriptor: res,
        dim_out: ng,
    })
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a two-type configuration builds without error, reports the
    /// configured output width, and adds more than ten nodes to the graph.
    #[test]
    fn se_t_descriptor_builds() {
        let (nf, nloc) = (1, 4);
        let cfg = SeTConfig {
            sel: vec![20, 40],
            neuron: vec![8, 16, 32],
            activation_function: String::from("tanh"),
            resnet_dt: false,
            type_map: None,
            rcut: 6.0,
            rcut_smth: 0.5,
        };
        let total_neighbors = cfg.nnei();

        let mut graph = Graph::new("se_t_test");
        let env_mat = graph.input(
            "env_mat",
            Shape::new(&[nf, nloc, total_neighbors, 4], DType::F32),
        );
        let atype = graph.input("atype", Shape::new(&[nf, nloc], DType::I32));

        let built = build_se_t_descriptor(&mut graph, &cfg, env_mat, atype, nf, nloc, None)
            .expect("build");

        assert_eq!(built.dim_out, cfg.dim_out());
        assert!(graph.len() > 10);
    }
}