// deepmd 0.1.0
//
// DeePMD-kit deep potential models as RLX IR graph builders
// Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! DPA-3 descriptor — type embedding + repflows stack.
//!
//! Translated from `deepmd/dpmodel/descriptor/dpa3.py::DescrptDPA3`.
//! Optionally folds in charge+spin per-frame embeddings when
//! `add_chg_spin_ebd=true`.

use anyhow::Result;
use rlx_ir::infer::GraphExt;
use rlx_ir::{DType, Graph, NodeId, Shape};
use serde::{Deserialize, Serialize};

use crate::descriptor_repflows::{
    build_repflows, RepflowsConfig, RepflowsInputs,
};

/// Configuration for the DPA-3 descriptor graph built by
/// `build_dpa3_descriptor`.
///
/// Deserializable via serde; every field except `repflows` and `ntypes`
/// falls back to a default when it is absent from the input.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DPA3Config {
    /// Settings for the repflows stack; passed by reference to
    /// `build_repflows`. Its `n_dim` is the width of the node embedding.
    pub repflows: RepflowsConfig,
    /// Number of atom types. NOTE(review): presumably sizes the
    /// `[ntypes+1, tebd_dim]` type-embedding table; the builder in this
    /// file does not read it.
    pub ntypes: usize,
    /// Width of the type embedding; also used as the width of the
    /// charge/spin embeddings and of the mix-MLP output. Defaults to 8.
    #[serde(default = "default_tebd_dim")]
    pub tebd_dim: usize,
    /// When true, the per-atom type embedding is concatenated onto the
    /// repflows node embedding along the last axis. Defaults to true.
    #[serde(default = "default_true")]
    pub concat_output_tebd: bool,
    /// Defaults to false. NOTE(review): has no effect on the graph built in
    /// this file — confirm whether it is consumed elsewhere.
    #[serde(default)]
    pub use_loc_mapping: bool,
    /// Enables the per-frame charge/spin embedding that is added to
    /// `node_ebd_ext` before the repflows stack. Defaults to false.
    #[serde(default)]
    pub add_chg_spin_ebd: bool,
    /// Activation used in the chg/spin mix MLP (Python's outer
    /// `activation_function` arg; defaults to "silu" to match DPA-3 defaults).
    #[serde(default = "default_chg_spin_activation")]
    pub activation_function: String,
}

/// Serde fallback for `DPA3Config::activation_function`: the activation
/// name `"silu"`.
fn default_chg_spin_activation() -> String {
    String::from("silu")
}

/// Serde fallback for `DPA3Config::tebd_dim`.
fn default_tebd_dim() -> usize {
    // Type-embedding width used when the config omits `tebd_dim`.
    const DEFAULT_TEBD_DIM: usize = 8;
    DEFAULT_TEBD_DIM
}
/// Serde default helper that yields `true`; referenced by
/// `DPA3Config::concat_output_tebd`.
fn default_true() -> bool {
    true
}

/// Graph nodes consumed by `build_dpa3_descriptor`.
///
/// Apart from `type_embedding` and the three charge/spin fields, every node
/// is handed on to the repflows builder (`node_ebd_ext` first receives the
/// charge/spin embedding when that option is enabled).
pub struct DPA3Inputs {
    /// Forwarded unchanged to the repflows builder as `env_mat_raw`.
    pub env_mat_raw: NodeId,
    /// Per-atom type indices: forwarded to the repflows builder and used as
    /// the row index into `type_embedding` (presumably `[nf, nloc]` —
    /// TODO confirm).
    pub atype_loc: NodeId,
    /// Forwarded unchanged to the repflows builder as `nlist`.
    pub nlist: NodeId,
    /// Forwarded unchanged to the repflows builder as `nlist_mask`.
    pub nlist_mask: NodeId,
    /// Forwarded unchanged to the repflows builder as `sw`.
    pub sw: NodeId,
    /// `[nf, nloc, a_sel, a_sel, 1]` precomputed cosines (host).
    pub angle_input: NodeId,
    /// `[nf, nloc, a_sel, a_sel]` mask for valid angle pairs.
    pub angle_mask: NodeId,
    /// `[ntypes+1, tebd_dim]` type embedding table.
    pub type_embedding: NodeId,
    /// `[nf, nall, n_dim]` extended node embedding (host-produced).
    pub node_ebd_ext: NodeId,
    /// Optional `[nf, 2]` charge/spin per-frame (used when
    /// `add_chg_spin_ebd=true`). Column 0 is read as the charge and
    /// column 1 as the spin.
    pub charge_spin: Option<NodeId>,
    /// `[201, tebd_dim]` charge embedding table (required when
    /// `add_chg_spin_ebd=true`). Indexed with `charge + 100`.
    pub chg_table: Option<NodeId>,
    /// `[101, tebd_dim]` spin embedding table (required when
    /// `add_chg_spin_ebd=true`). Indexed with the spin value directly.
    pub spin_table: Option<NodeId>,
}

/// Output nodes of `build_dpa3_descriptor`.
pub struct DPA3Descriptor {
    /// `[nf, nloc, n_dim (+ tebd_dim)]` final node embedding.
    pub descriptor: NodeId,
    /// `[nf, nloc, e_dim, 3]` equivariant rot matrix.
    pub gr: NodeId,
    /// Width of `descriptor`'s last axis: `repflows.n_dim`, plus `tebd_dim`
    /// when `concat_output_tebd` is set.
    pub dim_out: usize,
}

pub fn build_dpa3_descriptor(
    g: &mut Graph,
    cfg: &DPA3Config,
    inputs: DPA3Inputs,
    nf: usize,
    nloc: usize,
    nall: usize,
) -> Result<DPA3Descriptor> {
    use crate::nn::{scalar_const, ActivationKind, apply_activation};

    // ── add_chg_spin_ebd ──
    // Python (dpa3.py:723-741):
    //   chg_ebd  = chg_table[(charge + 100).astype(int64)]      # [nf, tebd]
    //   spin_ebd = spin_table[spin.astype(int64)]
    //   cs_cat   = concat([chg_ebd, spin_ebd], axis=-1)         # [nf, 2*tebd]
    //   sys_cs_embd = act(mix_cs_mlp(cs_cat))                   # [nf, tebd]
    //   node_ebd_ext = node_ebd_ext + sys_cs_embd[..., None, :] # broadcast over atoms
    let node_ebd_ext_final = if cfg.add_chg_spin_ebd {
        let activation = ActivationKind::parse(&cfg.activation_function)?;
        let cs = inputs.charge_spin.ok_or_else(|| {
            anyhow::anyhow!("add_chg_spin_ebd=True requires charge_spin input")
        })?;
        let chg_table = inputs.chg_table.ok_or_else(|| {
            anyhow::anyhow!("add_chg_spin_ebd=True requires chg_table input")
        })?;
        let spin_table = inputs.spin_table.ok_or_else(|| {
            anyhow::anyhow!("add_chg_spin_ebd=True requires spin_table input")
        })?;
        // charge_idx = charge_spin[:, 0] + 100  (i32 via cast in rlx-cpu f32 arena)
        let charge_f = g.narrow_(cs, 1, 0, 1);   // [nf, 1]
        let charge_f_2d = g.reshape(
            charge_f, vec![nf as i64, 1],
            Shape::new(&[nf, 1], DType::F32),
        );
        let one_hundred = scalar_const(g, 100.0);
        let charge_shifted = g.add(charge_f_2d, one_hundred);
        let charge_idx = g.reshape(
            charge_shifted, vec![nf as i64],
            Shape::new(&[nf], DType::F32),
        );
        let spin_f = g.narrow_(cs, 1, 1, 1); // [nf, 1]
        let spin_idx = g.reshape(
            spin_f, vec![nf as i64],
            Shape::new(&[nf], DType::F32),
        );
        // chg_ebd = chg_table[charge_idx]  [nf, tebd]
        let chg_ebd = g.gather_(chg_table, charge_idx, 0);
        let spin_ebd = g.gather_(spin_table, spin_idx, 0);
        let cs_cat = g.concat(
            vec![chg_ebd, spin_ebd], 1,
            Shape::new(&[nf, 2 * cfg.tebd_dim], DType::F32),
        );
        let mix_w = g.param(
            "mix_cs_mlp.w",
            Shape::new(&[2 * cfg.tebd_dim, cfg.tebd_dim], DType::F32),
        );
        let mix_b = g.param(
            "mix_cs_mlp.b",
            Shape::new(&[cfg.tebd_dim], DType::F32),
        );
        let pre = g.mm(cs_cat, mix_w);
        let pre = g.add(pre, mix_b);
        let sys_cs = apply_activation(g, activation, pre); // [nf, tebd]
        // Broadcast over the atom axis: [nf, 1, tebd] → add to [nf, nall_or_nloc, tebd].
        let node_shape = g.shape(inputs.node_ebd_ext).clone();
        let atom_dim = match node_shape.dim(1) {
            rlx_ir::Dim::Static(d) => d,
            _ => return Err(anyhow::anyhow!("dpa3: node_ebd_ext atom dim must be static")),
        };
        let sys_cs_3d = g.reshape(
            sys_cs, vec![nf as i64, 1, cfg.tebd_dim as i64],
            Shape::new(&[nf, 1, cfg.tebd_dim], DType::F32),
        );
        // Tile [nf, 1, tebd] → [nf, atom_dim, tebd] via concat (cheap for small N).
        let mut tiles = Vec::with_capacity(atom_dim);
        for _ in 0..atom_dim {
            tiles.push(sys_cs_3d);
        }
        let sys_cs_broadcast = g.concat(
            tiles, 1,
            Shape::new(&[nf, atom_dim, cfg.tebd_dim], DType::F32),
        );
        g.add(inputs.node_ebd_ext, sys_cs_broadcast)
    } else {
        inputs.node_ebd_ext
    };

    let rf_inputs = RepflowsInputs {
        node_ebd_ext: node_ebd_ext_final,
        env_mat_raw: inputs.env_mat_raw,
        nlist: inputs.nlist,
        nlist_mask: inputs.nlist_mask,
        sw: inputs.sw,
        angle_input: inputs.angle_input,
        angle_mask: inputs.angle_mask,
        a_sw: None,
        atype_loc: inputs.atype_loc,
    };
    let rf_out = build_repflows(g, &cfg.repflows, rf_inputs, nf, nloc, nall)?;
    let mut node_ebd = rf_out.node_ebd;
    let n_dim = cfg.repflows.n_dim;

    if cfg.concat_output_tebd {
        let centre = g.gather_(inputs.type_embedding, inputs.atype_loc, 0); // [nf, nloc, tebd]
        let cat_shape = Shape::new(&[nf, nloc, n_dim + cfg.tebd_dim], DType::F32);
        node_ebd = g.concat(vec![node_ebd, centre], 2, cat_shape);
    }

    let _ = inputs.charge_spin;
    let _ = cfg.add_chg_spin_ebd;
    let _ = cfg.use_loc_mapping;

    let dim_out = if cfg.concat_output_tebd {
        n_dim + cfg.tebd_dim
    } else {
        n_dim
    };
    Ok(DPA3Descriptor {
        descriptor: node_ebd,
        gr: rf_out.rot_mat,
        dim_out,
    })
}