deepmd 0.1.0

DeePMD-kit deep potential models as RLX IR graph builders
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Type embedding network — `deepmd/dpmodel/utils/type_embed.py`.
//!
//! Computes a per-type feature table by running a small MLP over the
//! identity matrix:
//!
//! ```text
//!     T = N(I_{ntypes})                    # [ntypes, tebd_dim]
//!     if padding:  T = concat([T, 0], 0)   # [ntypes+1, tebd_dim]
//! ```
//!
//! Downstream descriptors then gather `T[atype]` to produce per-atom
//! embeddings of shape `[nf, nloc, tebd_dim]` (or `[nf, nall, tebd_dim]`
//! for the extended atom set when needed by attention-style flows).

use anyhow::Result;
use rlx_ir::infer::GraphExt;
use rlx_ir::op::Op;
use rlx_ir::{DType, Graph, NodeId, Shape};
use serde::{Deserialize, Serialize};

use crate::nn::{embedding_mlp, ActivationKind, MlpSpec};

/// Configuration of the type-embedding network built by
/// `build_type_embedding`.
///
/// Deserialised with serde; optional fields that are omitted fall back to the
/// defaults noted on each field.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct TypeEmbedConfig {
    /// Number of atom types. Sets the size of the `[ntypes, ntypes]`
    /// identity-matrix input and the number of non-padding table rows.
    pub ntypes: usize,
    /// Layer widths of the embedding MLP; the last entry is the embedding
    /// width returned by `tebd_dim()`. Defaults to `[8]`.
    #[serde(default = "default_neuron")]
    pub neuron: Vec<usize>,
    /// Forwarded to the MLP builder as `MlpSpec::resnet_dt`. Defaults to
    /// `false`.
    #[serde(default)]
    pub resnet_dt: bool,
    /// Activation name, parsed by `ActivationKind::parse` when the graph is
    /// built. Defaults to `"tanh"`.
    #[serde(default = "default_activation")]
    pub activation_function: String,
    /// Concat a zero "ghost" row at index `ntypes` (used for padding /
    /// virtual atoms). Defaults to `false`.
    #[serde(default)]
    pub padding: bool,
    /// Whether the embedding MLP includes a bias term. Upstream
    /// `use_tebd_bias`. Defaults to `false` to match Python's default.
    ///
    /// NOTE(review): the graph builder in this file does not read this flag.
    /// The bias parameter is always emitted, and the weight loader is
    /// expected to zero-fill it when this is `false`. Confirm against the
    /// loader.
    #[serde(default)]
    pub use_tebd_bias: bool,
}

/// Serde default for `TypeEmbedConfig::neuron`: one layer of width 8.
fn default_neuron() -> Vec<usize> {
    Vec::from([8])
}

/// Serde default for `TypeEmbedConfig::activation_function`.
fn default_activation() -> String {
    String::from("tanh")
}

impl TypeEmbedConfig {
    /// Width of each per-type embedding row: the last entry of `neuron`.
    ///
    /// # Panics
    ///
    /// Panics with `"type_embed: empty neuron list"` if `neuron` is empty.
    pub fn tebd_dim(&self) -> usize {
        self.neuron
            .last()
            .copied()
            .expect("type_embed: empty neuron list")
    }

    /// Number of rows in the embedding table: `ntypes`, plus one extra
    /// all-zero row when `padding` is enabled.
    pub fn table_rows(&self) -> usize {
        self.ntypes + usize::from(self.padding)
    }
}

/// Build the type-embedding sub-graph and return the `[table_rows,
/// tebd_dim]` node.  Params emitted on `g`:
///
/// ```text
///     type_embed.layer{l}.{w,b?,idt?}
/// ```
pub fn build_type_embedding(
    g: &mut Graph,
    cfg: &TypeEmbedConfig,
    param_prefix: &str,
) -> Result<NodeId> {
    let activation = ActivationKind::parse(&cfg.activation_function)?;
    let ntypes = cfg.ntypes;
    let tebd = cfg.tebd_dim();

    // Identity matrix as a constant input of shape [ntypes, ntypes].
    let mut id_data = vec![0f32; ntypes * ntypes];
    for i in 0..ntypes {
        id_data[i * ntypes + i] = 1.0;
    }
    let id_bytes: Vec<u8> = id_data.iter().flat_map(|v| v.to_le_bytes()).collect();
    let id = g.add_node(
        Op::Constant { data: id_bytes },
        vec![],
        Shape::new(&[ntypes, ntypes], DType::F32),
    );

    // The Python EmbeddingNet uses an optional bias.  Our MLP helper
    // always emits a bias param; when `use_tebd_bias=false`, we just
    // emit it filled with zero from the weight loader (the graph stays
    // identical, the param sits at zero).  This keeps the param-name
    // layout uniform across configurations and lines up with how the
    // Python serialize() round-trip handles the missing-bias case.
    let prefix = format!("{param_prefix}.embedding");
    let mlp = MlpSpec {
        param_prefix: &prefix,
        in_dim: ntypes,
        neuron: &cfg.neuron,
        activation,
        resnet_dt: cfg.resnet_dt,
    };
    let mut t = embedding_mlp(g, &mlp, id); // [ntypes, tebd]

    if cfg.padding {
        // Append a zero row of shape [1, tebd].
        let zero_bytes: Vec<u8> = vec![0u8; 4 * tebd];
        let zero = g.add_node(
            Op::Constant { data: zero_bytes },
            vec![],
            Shape::new(&[1, tebd], DType::F32),
        );
        let out_shape = Shape::new(&[ntypes + 1, tebd], DType::F32);
        t = g.concat(vec![t, zero], 0, out_shape);
    }
    Ok(t)
}

/// Look up each atom's row of the type-embedding table, i.e. `T[atype]`.
///
/// Gathers along axis 0 of `table`. For an `atype` of shape `[nf, nloc]`,
/// the result is `[nf, nloc, tebd_dim]`.
///
/// `atype` must be an i32 (or i64) node.
pub fn gather_per_atom_embed(g: &mut Graph, table: NodeId, atype: NodeId) -> NodeId {
    // Type ids select rows of the table.
    let row_axis = 0;
    g.gather_(table, atype, row_axis)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: a padded 3-type table gathered with a `[1, 4]` index
    /// tensor yields a rank-3 per-atom embedding.
    #[test]
    fn type_embed_builds() {
        let config = TypeEmbedConfig {
            ntypes: 3,
            neuron: vec![16],
            resnet_dt: false,
            activation_function: String::from("tanh"),
            padding: true,
            use_tebd_bias: false,
        };
        let mut graph = Graph::new("type_embed");
        let table = build_type_embedding(&mut graph, &config, "type_embed").expect("build");

        let atype_shape = Shape::new(&[1, 4], DType::I32);
        let atype = graph.input("atype", atype_shape);
        let per_atom = gather_per_atom_embed(&mut graph, table, atype);

        let rank = graph.shape(per_atom).rank();
        assert_eq!(rank, 3);
        assert!(graph.len() > 5);
    }
}