deepmd 0.1.0

DeePMD-kit deep potential models as RLX IR graph builders
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! DPA-2 descriptor — repinit + g1 shape transform + repformers stack.
//!
//! Translated from `deepmd/dpmodel/descriptor/dpa2.py::DescrptDPA2`.
//!
//! ```text
//!     g1_ext_tebd = TypeEmbedding[atype_ext]              # [nf, nall, tebd_dim]
//!     g1 = repinit(env_mat, g1_ext_tebd, …)               # [nf, nloc, ng_repinit]
//!     (g1 = concat(g1, repinit_three_body) if use_three_body)
//!     g1 = g1_shape_tranform(g1)                          # → [nf, nloc, ng1_repformer]
//!     (if add_tebd_to_repinit_out: g1 += tebd_transform(centre_tebd))
//!     g1, g2, h2, rot_mat, sw = repformers(env_mat_rf, g1_ext, …)
//!     if concat_output_tebd:
//!         g1 = concat([g1, centre_tebd], -1)
//! ```
//!
//! The repinit block is a DPA-1 descriptor block built with no attention
//! layers (`attn_layer = 0`); we reuse
//! [`crate::descriptor_dpa1::build_dpa1_descriptor`].

use anyhow::Result;
use rlx_ir::infer::GraphExt;
use rlx_ir::{DType, Graph, NodeId, Shape};
use serde::{Deserialize, Serialize};

use crate::descriptor_dpa1::{
    build_dpa1_descriptor, DPA1Config, DPA1Inputs,
};
use crate::descriptor_repformers::{
    build_repformers, RepformersConfig, RepformersInputs,
};
use crate::descriptor_se_t_tebd::TebdInputMode;
use crate::nn::{embedding_mlp, ActivationKind, MlpSpec};

/// Hyper-parameters for [`build_dpa2_descriptor`], (de)serialized with serde.
///
/// Fields carrying a `#[serde(default ...)]` attribute may be omitted from
/// the serialized config.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct DPA2Config {
    /// Settings for the repinit (DPA-1-style) block; its `use_three_body`
    /// flag also switches on the optional three-body block.
    pub repinit: RepinitArgs,
    /// Repformer stack configuration.  Its `g1_dim` is the width `g1` is
    /// mapped to (by `g1_shape_tranform`) before entering the repformers.
    pub repformers: RepformersConfig,
    /// Number of atom types, forwarded to the repinit and three-body blocks.
    pub ntypes: usize,
    /// Width of the type embedding (default 8).
    #[serde(default = "default_tebd_dim")]
    pub tebd_dim: usize,
    /// When true (the default), the centre-atom type embedding is
    /// concatenated to the final descriptor, so
    /// `dim_out = repformers.g1_dim + tebd_dim`.
    #[serde(default = "default_true")]
    pub concat_output_tebd: bool,
    /// When true, a `tebd_transform` projection of the centre-atom type
    /// embedding is added to `g1` after the shape transform (default false).
    #[serde(default)]
    pub add_tebd_to_repinit_out: bool,
    /// Optional second repinit block for three-body (DPA-2.5).
    /// Must be set when `repinit.use_three_body` is true.
    #[serde(default)]
    pub repinit_three_body: Option<RepinitArgs>,
    /// Activation name parsed by `ActivationKind::parse` (default `"tanh"`).
    /// It is shared by the repinit block, the three-body block and the
    /// `tebd_transform` projection.
    #[serde(default = "default_activation")]
    pub activation_function: String,
}

/// Hyper-parameters for a repinit block.
///
/// This type is used in two places:
/// * `DPA2Config::repinit`, which is built as a DPA-1 block with attention
///   disabled and reads every field here.
/// * `DPA2Config::repinit_three_body`, which is built as an se_t_tebd block.
///   That builder only reads `rcut`, `rcut_smth`, `nsel`, `neuron`,
///   `tebd_input_mode` and `resnet_dt`.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct RepinitArgs {
    /// Cut-off radius, forwarded as the block's `rcut`.
    pub rcut: f64,
    /// Smoothing cut-off radius, forwarded as the block's `rcut_smth`.
    pub rcut_smth: f64,
    /// Neighbour selection size, passed on as the single-entry `sel = [nsel]`.
    pub nsel: usize,
    /// Embedding-net layer widths (default `[25, 50, 100]`).
    #[serde(default = "default_neuron")]
    pub neuron: Vec<usize>,
    /// Forwarded as `axis_neuron` to the DPA-1 config (default 16).
    #[serde(default = "default_axis_neuron")]
    pub axis_neuron: usize,
    /// How the type embedding is fed to the block
    /// (defaults to `TebdInputMode::default()`).
    #[serde(default)]
    pub tebd_input_mode: TebdInputMode,
    /// Forwarded as `type_one_side` to the DPA-1 config (default false).
    #[serde(default)]
    pub type_one_side: bool,
    /// Forwarded as `resnet_dt` to the block config (default false).
    #[serde(default)]
    pub resnet_dt: bool,
    /// Only meaningful on `DPA2Config::repinit`.  When true, the three-body
    /// block is built and its output is concatenated to the repinit output.
    #[serde(default)]
    pub use_three_body: bool,
}

/// Serde default for `RepinitArgs::neuron`: three embedding layers.
fn default_neuron() -> Vec<usize> {
    [25, 50, 100].to_vec()
}
/// Serde default for `RepinitArgs::axis_neuron`.
fn default_axis_neuron() -> usize {
    16_usize
}
/// Serde default for `DPA2Config::tebd_dim`.
fn default_tebd_dim() -> usize {
    8_usize
}
/// Serde default for boolean flags that are on unless disabled.
fn default_true() -> bool {
    true
}
/// Serde default for `DPA2Config::activation_function`.
fn default_activation() -> String {
    String::from("tanh")
}

/// Graph inputs consumed by [`build_dpa2_descriptor`].
///
/// Shapes marked "presumably" are not checked in this file.  Confirm them
/// against the code that builds these nodes.
pub struct DPA2Inputs {
    /// Raw environment matrix for the repinit block (presumably
    /// `[nf, nloc, repinit.nsel, 4]`, by analogy with the three-body input).
    pub env_mat_raw_repinit: NodeId,
    /// Raw environment matrix for the repformer stack.
    pub env_mat_raw_repformer: NodeId,
    /// `[nf, nloc]` local atom types.  They index `type_embedding` to obtain
    /// the centre-atom embedding.
    pub atype_loc: NodeId,
    /// Neighbour atom types for the repinit block.
    pub nei_atype: NodeId,
    /// Repinit neighbour list.  `build_dpa2_descriptor` does not consume it;
    /// the repinit block is given only the mask.
    pub nlist_repinit: NodeId,
    /// Repinit neighbour-list mask (presumably marks valid neighbour slots).
    pub nlist_mask_repinit: NodeId,
    /// Repinit switching weights.
    pub sw_repinit: NodeId,
    /// Repformer neighbour list.
    pub nlist_repformer: NodeId,
    /// Repformer neighbour-list mask.
    pub nlist_mask_repformer: NodeId,
    /// Repformer switching weights.
    pub sw_repformer: NodeId,
    /// `[ntypes+1, tebd_dim]` type embedding table.
    pub type_embedding: NodeId,
    /// `[nf, nall, tebd_dim]` per-extended-atom embedding (gather of
    /// `type_embedding[atype_ext]`).  `build_dpa2_descriptor` does not
    /// consume it; `g1_ext` is derived from `g1` via `mapping` instead.
    pub g1_ext_tebd: NodeId,
    /// `[nf, nall]` i32 mapping that resolves each extended atom to its
    /// owning local atom (`mapping[f, i] ∈ [0, nloc)`).  Used to build
    /// `g1_ext` from `g1` after the shape transform; only read when
    /// `nloc != nall`.
    pub mapping: NodeId,
    /// Optional three-body inputs (used when `cfg.repinit.use_three_body=true`).
    /// `[nf, nloc, three_body_nsel, 4]` env mat for the three-body block.
    pub env_mat_raw_three_body: Option<NodeId>,
    /// `[nf, nloc, three_body_nsel]` i32 nei type indices for three-body.
    pub nei_atype_three_body: Option<NodeId>,
    /// `[nf, nloc, three_body_nsel]` f32 switching weight for three-body.
    pub sw_three_body: Option<NodeId>,
}

/// Graph nodes produced by [`build_dpa2_descriptor`].
pub struct DPA2Descriptor {
    /// `[nf, nloc, dim_out]` final per-atom embedding (`g1`).
    pub descriptor: NodeId,
    /// `[nf, nloc, ng2, 3]` equivariant rep from repformer's `rot_mat`.
    pub gr: NodeId,
    /// Debug: repinit output, captured before the optional three-body
    /// concat and before the shape transform.  The builder always sets it
    /// to `Some`.
    pub debug_repinit: Option<NodeId>,
    /// Debug: g1 after shape transform (and optional tebd add), fed to
    /// repformer (pre-gather).  The builder always sets it to `Some`.
    pub debug_g1_post_xform: Option<NodeId>,
    /// Feature width of `descriptor`.  It equals `repformers.g1_dim`, plus
    /// `tebd_dim` when `concat_output_tebd` is set.
    pub dim_out: usize,
}

/// Build the DPA-2 descriptor graph.
///
/// Pipeline (translated from `DescrptDPA2`):
/// 1. **repinit**: a DPA-1 block with attention disabled (see
///    `repinit_to_dpa1`), producing `g1` of shape `[nf, nloc, *]`.
/// 2. **three-body** (optional, `cfg.repinit.use_three_body`): an
///    se_t_tebd-style block whose output is concatenated to `g1` along the
///    feature axis.
/// 3. **`g1_shape_tranform`**: a bias-free linear map to
///    `cfg.repformers.g1_dim`.  It is the identity when the widths already
///    match.
/// 4. **tebd add** (optional, `cfg.add_tebd_to_repinit_out`): adds a
///    projection of the centre-atom type embedding.
/// 5. **repformers**: `g1` is mapped onto the extended atoms via
///    `inputs.mapping` (only when `nloc != nall`), then fed to the stack.
/// 6. **tebd concat** (optional, `cfg.concat_output_tebd`): appends the
///    centre-atom type embedding to the output features.
///
/// `nf`, `nloc` and `nall` are the static frame, local-atom and
/// extended-atom counts used to size intermediate shapes.
/// `inputs.g1_ext_tebd` and `inputs.nlist_repinit` are not consumed.
///
/// # Errors
/// Returns an error in any of these cases:
/// * `cfg.activation_function` does not parse.
/// * A sub-block builder fails.
/// * `cfg.repinit.use_three_body` is set but `cfg.repinit_three_body` or
///   any three-body input is missing.
/// * The last dimension of `g1` is not static.
pub fn build_dpa2_descriptor(
    g: &mut Graph,
    cfg: &DPA2Config,
    inputs: DPA2Inputs,
    nf: usize,
    nloc: usize,
    nall: usize,
) -> Result<DPA2Descriptor> {
    let activation = ActivationKind::parse(&cfg.activation_function)?;
    let repinit_cfg = repinit_to_dpa1(cfg, &cfg.repinit);
    let type_embedding = inputs.type_embedding;
    let atype_loc = inputs.atype_loc;
    // Centre-atom type embedding `type_embedding[atype_loc]`
    // ([nf, nloc, tebd_dim]).  Both the repinit-output add and the final
    // concat need it, so it is gathered lazily and at most once instead of
    // emitting two identical gather nodes.
    let mut centre_tebd = None;

    // ── repinit (DPA-1 style block) ──
    let dpa1_inputs = DPA1Inputs {
        env_mat_raw: inputs.env_mat_raw_repinit,
        atype_loc,
        nei_atype: inputs.nei_atype,
        nlist_mask: inputs.nlist_mask_repinit,
        sw: inputs.sw_repinit,
        type_embedding,
    };
    let repinit_out = build_dpa1_descriptor(g, &repinit_cfg, dpa1_inputs, nf, nloc)?;
    let mut g1 = repinit_out.descriptor; // [nf, nloc, ng_repinit·M2]
    let dbg_repinit = Some(g1);

    // ── (optional) three-body repinit (se_t-style block) ──
    if cfg.repinit.use_three_body {
        use crate::descriptor_se_t_tebd::{
            build_se_t_tebd_descriptor, SeTTebdConfig, SeTTebdInputs,
        };
        let three_args = cfg.repinit_three_body.as_ref().ok_or_else(|| {
            anyhow::anyhow!(
                "use_three_body=True requires cfg.repinit_three_body to be set"
            )
        })?;
        let env_three = inputs.env_mat_raw_three_body.ok_or_else(|| {
            anyhow::anyhow!("use_three_body=True requires env_mat_raw_three_body")
        })?;
        let nei_three = inputs.nei_atype_three_body.ok_or_else(|| {
            anyhow::anyhow!("use_three_body=True requires nei_atype_three_body")
        })?;
        let sw_three = inputs.sw_three_body.ok_or_else(|| {
            anyhow::anyhow!("use_three_body=True requires sw_three_body")
        })?;
        let three_cfg = SeTTebdConfig {
            rcut: three_args.rcut,
            rcut_smth: three_args.rcut_smth,
            sel: vec![three_args.nsel],
            ntypes: cfg.ntypes,
            neuron: three_args.neuron.clone(),
            tebd_dim: cfg.tebd_dim,
            tebd_input_mode: three_args.tebd_input_mode,
            resnet_dt: three_args.resnet_dt,
            smooth: true,
            activation_function: cfg.activation_function.clone(),
            param_prefix: "repinit_three_body".into(),
        };
        let three_inputs = SeTTebdInputs {
            env_mat_raw: env_three,
            atype_loc,
            nei_atype: nei_three,
            type_embedding,
            sw: Some(sw_three),
            exclude_mask: None,
        };
        let three_out = build_se_t_tebd_descriptor(g, &three_cfg, three_inputs, nf, nloc)?;
        let main_dim = dpa2_static_dim(g.shape(g1).dim(2))?;
        let cat_shape = Shape::new(
            &[nf, nloc, main_dim + three_out.dim_out],
            DType::F32,
        );
        g1 = g.concat(vec![g1, three_out.descriptor], 2, cat_shape);
    }

    // ── g1 shape transform: linear from g1_in_dim → ng1 (repformer) ──
    // Python uses NativeLayer(in_dim, out_dim, bias=False).  If the
    // dims already match it uses `Identity()` (no params).  Mirror
    // both branches.
    let ng1 = cfg.repformers.g1_dim;
    let g1_in_dim = dpa2_static_dim(g.shape(g1).dim(2))?;
    if g1_in_dim != ng1 {
        let xform_w = g.param(
            "g1_shape_tranform.w",
            Shape::new(&[g1_in_dim, ng1], DType::F32),
        );
        g1 = g.mm(g1, xform_w);
    }

    if cfg.add_tebd_to_repinit_out {
        let centre =
            *centre_tebd.get_or_insert_with(|| g.gather_(type_embedding, atype_loc, 0));
        // NOTE(review): upstream defines `tebd_transform` as a single
        // bias-free linear layer with no activation.  Confirm that
        // `embedding_mlp` with this one-layer spec matches that (no bias,
        // no activation) before relying on parity.
        let trans = MlpSpec {
            param_prefix: "tebd_transform",
            in_dim: cfg.tebd_dim,
            neuron: &[ng1],
            activation,
            resnet_dt: false,
        };
        let tebd_proj = embedding_mlp(g, &trans, centre);
        g1 = g.add(g1, tebd_proj);
    }
    let dbg_g1_xform = Some(g1);

    // ── repformer block ──
    // The repformer expects g1_ext [nf, nall, ng1].  Build it via
    // `gather(g1, mapping, axis=1)` — mapping[f, i] tells which local
    // atom owns each extended position.  Python does:
    //     mapping_ext = tile(mapping[..., None], (1, 1, ng1))
    //     g1_ext = take_along_axis(g1, mapping_ext, axis=1)
    // which is exactly an axis-1 gather.
    let has_ghosts = nloc != nall;
    let g1_ext = if has_ghosts {
        // The rlx-cpu axis-1 gather kernel treats `outer = nf` and ids
        // per-batch, so a [nf, nall] index against a [nf, nloc, ng1]
        // table yields exactly [nf, nall, ng1].
        g.gather_(g1, inputs.mapping, 1)
    } else {
        // Single-rank, no ghosts — g1 already has shape [nf, nall, ng1].
        g1
    };

    let rf_inputs = RepformersInputs {
        g1_ext,
        env_mat_raw: inputs.env_mat_raw_repformer,
        nlist: inputs.nlist_repformer,
        nlist_mask: inputs.nlist_mask_repformer,
        sw: inputs.sw_repformer,
        atype_loc,
        mapping: has_ghosts.then_some(inputs.mapping),
    };
    let rf_out = build_repformers(g, &cfg.repformers, rf_inputs, nf, nloc, nall)?;

    // Single source of truth for the output width: it is used both for the
    // concat shape and for the returned `dim_out`.
    let dim_out = if cfg.concat_output_tebd {
        ng1 + cfg.tebd_dim
    } else {
        ng1
    };
    let mut descriptor = rf_out.g1;
    if cfg.concat_output_tebd {
        let centre =
            *centre_tebd.get_or_insert_with(|| g.gather_(type_embedding, atype_loc, 0));
        let cat_shape = Shape::new(&[nf, nloc, dim_out], DType::F32);
        descriptor = g.concat(vec![descriptor, centre], 2, cat_shape);
    }

    Ok(DPA2Descriptor {
        descriptor,
        gr: rf_out.rot_mat,
        debug_repinit: dbg_repinit,
        debug_g1_post_xform: dbg_g1_xform,
        dim_out,
    })
}

/// Return the size of a statically-known `g1` dimension.  Fails with the
/// DPA-2 "g1 last dim must be static" error otherwise.
fn dpa2_static_dim(dim: rlx_ir::Dim) -> Result<usize> {
    match dim {
        rlx_ir::Dim::Static(n) => Ok(n),
        _ => Err(anyhow::anyhow!("dpa2: g1 last dim must be static")),
    }
}

fn repinit_to_dpa1(top: &DPA2Config, ri: &RepinitArgs) -> DPA1Config {
    DPA1Config {
        rcut: ri.rcut,
        rcut_smth: ri.rcut_smth,
        sel: vec![ri.nsel],
        ntypes: top.ntypes,
        neuron: ri.neuron.clone(),
        axis_neuron: ri.axis_neuron,
        tebd_dim: top.tebd_dim,
        tebd_input_mode: ri.tebd_input_mode,
        resnet_dt: ri.resnet_dt,
        attn: 16,
        attn_layer: 0, // repinit uses no attention; pure se_atten-style embedding
        attn_dotr: false,
        num_heads: 1,
        normalize: false,
        scaling_factor: 1.0,
        ln_eps: 1e-5,
        smooth: true,
        type_one_side: ri.type_one_side,
        activation_function: top.activation_function.clone(),
        concat_output_tebd: false,
    }
}