deepmd 0.1.0

DeePMD-kit deep potential models as RLX IR graph builders
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! `se_e2_a` (DeepPot-SE) descriptor graph builder.
//!
//! Translated from `DescrptSeA.call` in
//! `deepmd/dpmodel/descriptor/se_e2_a.py`.
//!
//! Inputs to the graph:
//!
//! * `env_mat_raw`  — `[nf, nloc, nnei, 4]`, the geometric env matrix
//!   produced by the host (`R = (s, s·r̂)`).
//! * `atype_loc`    — `[nf, nloc]` i32, atom types of the local atoms.
//! * `exclude_mask` — `[nf, nloc, nnei]` f32 (optional). `1` for kept
//!   pairs, `0` for excluded. Caller multiplies the geometric mask in.
//!
//! Graph topology (per the Python `DescrptSeA.call`):
//!
//! ```text
//!     R = (env_mat_raw - davg[atype_loc]) / dstd[atype_loc]
//!     R *= exclude_mask
//!     # type_one_side=true:
//!     for each neighbor type t:
//!         G_t = N_t(R_t[..., 0:1])
//!         g  += G_tᵀ · R_t
//!     # type_one_side=false: per-centre + per-neighbor (ntypes² nets)
//!     for ti in 0..ntypes:
//!         contrib_ti = Σ_j G_{ti,j}ᵀ R_{j_slice}
//!         g += (atype_loc == ti).cast · contrib_ti
//!     g  /= nnei
//!     g_< = g[:axis_neuron, :]
//!     D   = g · g_<ᵀ      # [nf, nloc, ng, M₂]
//! ```

use anyhow::{anyhow, bail, Result};
use rlx_ir::infer::GraphExt;
use rlx_ir::{DType, Graph, NodeId, Shape};

use crate::config::SeAConfig;
use crate::nn::{embedding_mlp, scalar_const, ActivationKind, MlpSpec};

/// Handle returned from [`build_se_a_descriptor`].
///
/// Every field is a plain graph handle or size, so the handle is cheap to
/// copy and can be passed around by value.
#[derive(Clone, Copy)]
pub struct SeADescriptor {
    /// Descriptor `D` node (the flattened `g · g_<ᵀ` product), shape
    /// `[nf, nloc, ng·M₂]`.
    pub descriptor: NodeId,
    /// Rotationally-equivariant single-particle representation (the last
    /// three components of the scaled `g`), shape `[nf, nloc, ng, 3]`.
    pub gr: NodeId,
    /// Output dimension of the descriptor (`ng·M₂`).
    pub dim_out: usize,
}

/// Optional extra graph inputs consumed by [`build_se_a_descriptor`].
///
/// `SeAExtraInputs::default()` means "no extras": every field is `None`.
#[derive(Clone, Copy, Default)]
pub struct SeAExtraInputs {
    /// Pair-exclusion mask of shape `[nf, nloc, nnei]` (f32), multiplied
    /// into the normalized env matrix `R` before the embedding nets run.
    /// `None` disables exclusion entirely.
    pub exclude_mask: Option<NodeId>,
}

/// Build the `se_e2_a` descriptor sub-graph.
///
/// Inputs:
///
/// * `env_mat_raw` is the rank-4 env matrix `[nf, nloc, nnei, 4]`.
/// * `atype_loc` holds the `[nf, nloc]` i32 types of the local atoms.
/// * `nf` / `nloc` are the static sizes used for the internal reshapes.
///
/// Param names emitted on `g`:
///
/// ```text
///     descriptor.davg              # [ntypes, nnei, 4]
///     descriptor.dstd              # [ntypes, nnei, 4]
///     descriptor.embedding.{t}.layer{l}.{w,b,idt?}        (type_one_side=true)
///     descriptor.embedding.{ti}.{tj}.layer{l}.{w,b,idt?}  (type_one_side=false)
/// ```
///
/// # Errors
///
/// Returns an error in any of these cases:
///
/// * `ActivationKind::parse` rejects `cfg.activation_function`.
/// * `env_mat_raw` is not rank 4.
/// * The selection cumsum has fewer than `ntypes + 1` entries.
/// * `axis_neuron` exceeds the embedding width `ng`.
/// * The total selection is empty.
pub fn build_se_a_descriptor(
    g: &mut Graph,
    cfg: &SeAConfig,
    env_mat_raw: NodeId,
    atype_loc: NodeId,
    nf: usize,
    nloc: usize,
    extra: SeAExtraInputs,
) -> Result<SeADescriptor> {
    let activation = ActivationKind::parse(&cfg.activation_function)?;
    let ntypes = cfg.ntypes();
    let ng = cfg.ng();
    let m2 = cfg.axis_neuron;
    let nnei = cfg.nnei();
    let sec = cfg.sel_cumsum();

    // Both per-type builders index `sec[t]` and `sec[t + 1]` for every
    // `t < ntypes`. Fail with a clear error instead of an index panic.
    if sec.len() < ntypes + 1 {
        bail!(
            "se_e2_a: sel cumsum must have at least ntypes + 1 = {} entries, got {}",
            ntypes + 1,
            sec.len()
        );
    }
    // `g_<` keeps the first `axis_neuron` rows of the `ng`-row matrix `g`.
    // A larger value would yield an out-of-range narrow in the graph.
    if m2 > ng {
        bail!(
            "se_e2_a: axis_neuron ({}) must not exceed the embedding width ng ({})",
            m2,
            ng
        );
    }

    let rank = g.shape(env_mat_raw).rank();
    if rank != 4 {
        bail!(
            "se_e2_a: env-matrix input must have rank 4 [nf, nloc, nnei, 4], got rank {}",
            rank
        );
    }

    // ── Normalization: R = (R_raw - davg[atype]) / dstd[atype] ──
    let davg = g.param(
        "descriptor.davg",
        Shape::new(&[ntypes, nnei, 4], DType::F32),
    );
    let dstd = g.param(
        "descriptor.dstd",
        Shape::new(&[ntypes, nnei, 4], DType::F32),
    );
    // Per-atom statistics selected by the centre atom's type.
    let davg_g = g.gather_(davg, atype_loc, 0);
    let dstd_g = g.gather_(dstd, atype_loc, 0);
    let mut rr = g.sub(env_mat_raw, davg_g);
    rr = g.div(rr, dstd_g);

    if let Some(mask) = extra.exclude_mask {
        // mask is [nf, nloc, nnei]; append a unit axis so it broadcasts over
        // the trailing 4-dim of R.
        let mask_shape = g.shape(mask).clone();
        let mut dims: Vec<rlx_ir::Dim> = mask_shape.dims().to_vec();
        dims.push(rlx_ir::Dim::Static(1));
        let mask_4d_shape = Shape::from_dims(&dims, DType::F32);
        let mask_4d = g.reshape(
            mask,
            vec![nf as i64, nloc as i64, nnei as i64, 1],
            mask_4d_shape,
        );
        rr = g.mul(rr, mask_4d);
    }

    if cfg.type_one_side {
        build_descriptor_type_one_side(g, cfg, activation, rr, &sec, nf, nloc, nnei, ng, m2)
    } else {
        build_descriptor_type_two_side(
            g, cfg, activation, rr, atype_loc, &sec, nf, nloc, nnei, ng, m2,
        )
    }
}

fn build_descriptor_type_one_side(
    g: &mut Graph,
    cfg: &SeAConfig,
    activation: ActivationKind,
    rr: NodeId,
    sec: &[usize],
    nf: usize,
    nloc: usize,
    nnei: usize,
    ng: usize,
    m2: usize,
) -> Result<SeADescriptor> {
    let ntypes = cfg.ntypes();
    let mut acc: Option<NodeId> = None;

    for t in 0..ntypes {
        let start = sec[t];
        let len = sec[t + 1] - start;
        if len == 0 {
            continue;
        }

        let rr_t = g.narrow_(rr, 2, start, len); // [nf, nloc, len, 4]
        let ss_t = g.narrow_(rr_t, 3, 0, 1); //     [nf, nloc, len, 1]
        let prefix = format!("descriptor.embedding.{t}");
        let mlp = MlpSpec {
            param_prefix: &prefix,
            in_dim: 1,
            neuron: &cfg.neuron,
            activation,
            resnet_dt: cfg.resnet_dt,
        };
        let gg_t = embedding_mlp(g, &mlp, ss_t); //   [nf, nloc, len, ng]

        let gg_t_t = g.transpose_(gg_t, vec![0, 1, 3, 2]); // [nf, nloc, ng, len]
        let contrib = g.mm(gg_t_t, rr_t); //                 [nf, nloc, ng, 4]

        acc = Some(match acc {
            None => contrib,
            Some(prev) => g.add(prev, contrib),
        });
    }
    let gr_unscaled =
        acc.ok_or_else(|| anyhow!("se_e2_a: empty selection (sum(sel) == 0)"))?;
    Ok(finalize_descriptor(g, gr_unscaled, nf, nloc, nnei, ng, m2))
}

/// `type_one_side = false` branch: one embedding net per (centre, neighbor)
/// type pair (`ntypes²` nets).
///
/// For each centre type `ti` the contribution `Σ_j G_{ti,j}ᵀ · R_{j_slice}`
/// is built. It is masked to the atoms whose `atype_loc == ti` and
/// accumulated. The result is then handed to `finalize_descriptor`.
/// Errors when every selection is empty.
fn build_descriptor_type_two_side(
    g: &mut Graph,
    cfg: &SeAConfig,
    activation: ActivationKind,
    rr: NodeId,
    atype_loc: NodeId,
    sec: &[usize],
    nf: usize,
    nloc: usize,
    nnei: usize,
    ng: usize,
    m2: usize,
) -> Result<SeADescriptor> {
    let ntypes = cfg.ntypes();

    // The neighbor-type slices of R (and their first columns) do not depend
    // on the centre type. Build them once instead of once per (ti, tj) pair.
    // Types with sel == 0 own no columns and are skipped.
    let mut slices: Vec<(usize, NodeId, NodeId)> = Vec::with_capacity(ntypes);
    for tj in 0..ntypes {
        let start = sec[tj];
        let len = sec[tj + 1] - start;
        if len == 0 {
            continue;
        }
        let rr_j = g.narrow_(rr, 2, start, len); // [nf, nloc, len, 4]
        let ss_j = g.narrow_(rr_j, 3, 0, 1); //     [nf, nloc, len, 1]
        slices.push((tj, rr_j, ss_j));
    }

    let mut acc: Option<NodeId> = None;
    for ti in 0..ntypes {
        let mut sum_j: Option<NodeId> = None;
        for &(tj, rr_j, ss_j) in &slices {
            let prefix = format!("descriptor.embedding.{ti}.{tj}");
            let mlp = MlpSpec {
                param_prefix: &prefix,
                in_dim: 1,
                neuron: &cfg.neuron,
                activation,
                resnet_dt: cfg.resnet_dt,
            };
            let gg_ij = embedding_mlp(g, &mlp, ss_j);
            let gg_ij_t = g.transpose_(gg_ij, vec![0, 1, 3, 2]);
            let contrib = g.mm(gg_ij_t, rr_j); //         [nf, nloc, ng, 4]
            sum_j = Some(match sum_j {
                None => contrib,
                Some(prev) => g.add(prev, contrib),
            });
        }
        let Some(contrib_ti) = sum_j else { continue };

        // Keep only the atoms whose centre type is `ti`.
        let mask_4d = centre_type_mask_4d(g, atype_loc, ti, nf, nloc);
        let masked = g.mul(contrib_ti, mask_4d);

        acc = Some(match acc {
            None => masked,
            Some(prev) => g.add(prev, masked),
        });
    }
    let gr_unscaled =
        acc.ok_or_else(|| anyhow!("se_e2_a: empty selection (sum(sel) == 0)"))?;
    Ok(finalize_descriptor(g, gr_unscaled, nf, nloc, nnei, ng, m2))
}

/// Build a `[nf, nloc, 1, 1]` f32 mask that is `1` where `atype_loc == ti`
/// and `0` elsewhere, shaped to broadcast over the trailing `[ng, 4]`.
///
/// The mask is computed as `1 - min(|atype_loc - ti|, 1)`. It uses only
/// Sub / Abs / Min and no equality op. This assumes integer atom types, so
/// any non-zero difference has magnitude ≥ 1.
fn centre_type_mask_4d(
    g: &mut Graph,
    atype_loc: NodeId,
    ti: usize,
    nf: usize,
    nloc: usize,
) -> NodeId {
    let ti_const = type_index_const(g, ti as i32, nf, nloc);
    let diff_shape = g.shape(atype_loc).clone();
    let diff = g.binary(rlx_ir::op::BinaryOp::Sub, atype_loc, ti_const, diff_shape);
    let diff_f32 = g.cast(diff, DType::F32);
    let abs_shape = g.shape(diff_f32).clone();
    let abs = g.activation(rlx_ir::op::Activation::Abs, diff_f32, abs_shape);
    let one = scalar_const(g, 1.0);
    let clipped_shape = g.shape(abs).clone();
    let clipped = g.binary(rlx_ir::op::BinaryOp::Min, abs, one, clipped_shape);
    let mask = g.sub(one, clipped); // [nf, nloc], f32, 1 where atype==ti

    let mask_4d_shape = Shape::new(&[nf, nloc, 1, 1], DType::F32);
    g.reshape(mask, vec![nf as i64, nloc as i64, 1, 1], mask_4d_shape)
}

/// Shared tail of both descriptor branches.
///
/// Steps:
///
/// 1. Scale the accumulated `g` by `1 / nnei`.
/// 2. Form `D = g · g_<ᵀ`, where `g_<` is the first `m2` rows of `g`.
/// 3. Flatten `D` to `[nf, nloc, ng·m2]`.
/// 4. Expose the last three columns of the scaled `g` as `gr`.
fn finalize_descriptor(
    g: &mut Graph,
    gr_unscaled: NodeId,
    nf: usize,
    nloc: usize,
    nnei: usize,
    ng: usize,
    m2: usize,
) -> SeADescriptor {
    // Average over the neighbor slots: g ← g / nnei.
    let scale = scalar_const(g, 1.0 / nnei as f32);
    let gr = g.mul(gr_unscaled, scale); // [nf, nloc, ng, 4]

    // D = g · g_<ᵀ.
    let head = g.narrow_(gr, 2, 0, m2); // [nf, nloc, m2, 4]
    let head_t = g.transpose_(head, vec![0, 1, 3, 2]); // [nf, nloc, 4, m2]
    let gram = g.mm(gr, head_t); // [nf, nloc, ng, m2]

    // Merge the trailing (ng, m2) pair into a single feature axis.
    let dim_out = ng * m2;
    let descriptor = g.reshape(
        gram,
        vec![nf as i64, nloc as i64, dim_out as i64],
        Shape::new(&[nf, nloc, dim_out], DType::F32),
    );

    SeADescriptor {
        descriptor,
        // Drop column 0 and keep columns 1..4: [nf, nloc, ng, 3].
        gr: g.narrow_(gr, 3, 1, 3),
        dim_out,
    }
}

/// Build a constant `[nf, nloc]` i32 tensor whose every element is `value`.
/// It is used for element-wise comparison against `atype_loc`.
fn type_index_const(g: &mut Graph, value: i32, nf: usize, nloc: usize) -> NodeId {
    // Little-endian i32 payload, repeated once per (frame, atom) slot.
    let data = value.to_le_bytes().repeat(nf * nloc);
    let shape = Shape::new(&[nf, nloc], DType::I32);
    g.add_node(rlx_ir::op::Op::Constant { data }, Vec::new(), shape)
}