deepmd 0.1.0

DeePMD-kit deep potential models as RLX IR graph builders
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Shared neural-network building blocks: activations, dense layers,
//! embedding/fitting MLPs.  Translated from
//! `deepmd/dpmodel/utils/network.py`.

use anyhow::{bail, Result};
use rlx_ir::infer::GraphExt;
use rlx_ir::op::{Activation, BinaryOp, Op};
use rlx_ir::{DType, Graph, NodeId, Shape};

/// Activation kinds recognized in DeePMD config (matches
/// `get_activation_fn` in `utils/network.py`).
///
/// Config strings are turned into a kind with [`ActivationKind::parse`];
/// a kind is lowered onto a graph with [`apply_activation`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ActivationKind {
    /// Hyperbolic tangent (config spelling `tanh`).
    Tanh,
    /// Rectified linear unit, `max(x, 0)` (config spelling `relu`).
    Relu,
    /// `min(max(x, 0), 6)` (config spelling `relu6`).
    Relu6,
    /// Both `gelu` and `gelu_tf` map here — Python uses the tanh
    /// approximation in both cases.
    Gelu,
    /// Logistic function `1 / (1 + exp(-x))` (config spelling `sigmoid`).
    Sigmoid,
    /// `log(1 + exp(x))` (config spelling `softplus`).
    Softplus,
    /// SiLU, lowered through the IR's `silu` builder (config spelling `silu`).
    Silu,
    /// Identity: no graph nodes are added.  Config spellings `none` and
    /// `linear`.
    Linear,
}

impl ActivationKind {
    /// Parse a DeePMD activation name, ignoring ASCII case.
    ///
    /// Accepted spellings: `tanh`, `relu`, `relu6`, `gelu` / `gelu_tf`,
    /// `sigmoid`, `softplus`, `silu`, and `none` / `linear` for the identity.
    ///
    /// # Errors
    ///
    /// Returns an error for any other name.  The message quotes the input
    /// exactly as the caller supplied it (Debug-formatted), so the original
    /// spelling and any stray whitespace in the config remain visible.
    pub fn parse(s: &str) -> Result<Self> {
        let lowered = s.to_ascii_lowercase();
        Ok(match lowered.as_str() {
            "tanh" => Self::Tanh,
            "relu" => Self::Relu,
            "relu6" => Self::Relu6,
            "gelu" | "gelu_tf" => Self::Gelu,
            "sigmoid" => Self::Sigmoid,
            "softplus" => Self::Softplus,
            "silu" => Self::Silu,
            "none" | "linear" => Self::Linear,
            // Report the raw input, not the lowercased copy used for matching.
            _ => bail!("unsupported activation: {s:?}"),
        })
    }
}

/// Append the subgraph computing activation `kind` element-wise on `x`, and
/// return its output node.
///
/// `Linear` adds no nodes and returns `x` itself.  Tanh, ReLU, GELU (tanh
/// approximation) and SiLU use the IR's dedicated builders.  Sigmoid,
/// softplus and ReLU6 are composed from primitive ops.
pub fn apply_activation(g: &mut Graph, kind: ActivationKind, x: NodeId) -> NodeId {
    match kind {
        ActivationKind::Tanh => g.tanh(x),
        ActivationKind::Relu => g.relu(x),
        ActivationKind::Gelu => g.gelu_approx(x),
        ActivationKind::Silu => g.silu(x),
        ActivationKind::Linear => x,
        ActivationKind::Sigmoid => {
            // σ(x) = 1 / (1 + exp(-x)).  The IR has no direct sigmoid
            // activation, so synthesize it.  A single `1` constant node
            // serves both the denominator and the numerator.
            let one = scalar_const(g, 1.0);
            let neg = g.neg(x);
            let exp = g.exp(neg);
            let denom = g.add(exp, one);
            g.div(one, denom)
        }
        ActivationKind::Softplus => {
            // softplus(x) = log(1 + exp(x)) = m + log(exp(-m) + exp(x - m)),
            // with m = max(x, 0).
            //
            // Why the shifted form: the naive one evaluates exp(x), which in
            // f32 overflows to +inf once x exceeds ~88.7.  Here both
            // exponents (-m and x - m) are <= 0, so neither can overflow.
            // Because the shift m cancels analytically, the derivative does
            // not depend on how `relu` is differentiated at 0.
            let m = g.relu(x);
            let neg_m = g.neg(m);
            let exp_neg_m = g.exp(neg_m);
            let shifted = g.add(x, neg_m);
            let exp_shifted = g.exp(shifted);
            let sum = g.add(exp_neg_m, exp_shifted);
            let sum_shape = g.shape(sum).clone();
            let log = g.activation(Activation::Log, sum, sum_shape);
            g.add(m, log)
        }
        ActivationKind::Relu6 => {
            // min(max(x, 0), 6)
            let zero = scalar_const(g, 0.0);
            let six = scalar_const(g, 6.0);
            let s = g.shape(x).clone();
            let pos = g.binary(BinaryOp::Max, x, zero, s.clone());
            g.binary(BinaryOp::Min, pos, six, s)
        }
    }
}

/// Add a constant node holding the single `f32` value `v`.
///
/// The node has shape `[1]` (rank 1, one element) and no inputs.  Its
/// payload is the little-endian encoding of `v`.
pub fn scalar_const(g: &mut Graph, v: f32) -> NodeId {
    let payload: Vec<u8> = Vec::from(v.to_le_bytes());
    let single = Shape::new(&[1], DType::F32);
    g.add_node(Op::Constant { data: payload }, Vec::new(), single)
}

/// Allocate an `f32` constant tensor of arbitrary shape.
///
/// `data` is stored in the order given, each value encoded as 4
/// little-endian bytes, in an input-less `Op::Constant` node with the
/// declared `shape`.
///
/// # Panics
///
/// Panics if `data.len()` differs from the number of elements `shape`
/// describes (the product of its dims; 1 for an empty shape).  Such a
/// constant's byte payload would not match its declared shape.
pub fn tensor_const(g: &mut Graph, data: &[f32], shape: &[usize]) -> NodeId {
    let numel: usize = shape.iter().product();
    assert_eq!(
        data.len(),
        numel,
        "tensor_const: {} values supplied for shape {:?} ({} elements)",
        data.len(),
        shape,
        numel
    );

    let mut bytes = Vec::with_capacity(data.len() * std::mem::size_of::<f32>());
    for v in data {
        bytes.extend_from_slice(&v.to_le_bytes());
    }
    g.add_node(
        Op::Constant { data: bytes },
        vec![],
        Shape::new(shape, DType::F32),
    )
}

/// One dense layer with optional bias / activation / `resnet_dt` / skip
/// connection — the building block of every DeePMD MLP
/// (`NativeLayer.call` in `utils/network.py`).
///
/// ```text
///     y = act(x · W + b)
///     if resnet_dt: y *= idt
///     if resnet and Wo == Wi:        y += x
///     elif resnet and Wo == 2*Wi:    y += concat([x, x], -1)
/// ```
pub struct DenseLayerSpec<'a> {
    pub param_prefix: &'a str,
    pub in_dim: usize,
    pub out_dim: usize,
    pub activation: ActivationKind,
    pub use_bias: bool,
    pub resnet_dt: bool,
    pub resnet: bool,
}

/// Append one dense layer described by `spec` to the graph, applied to `x`,
/// and return its output node.
///
/// Parameters registered:
/// * `{prefix}.w`, shape `[in_dim, out_dim]`, always;
/// * `{prefix}.b`, shape `[out_dim]`, when `use_bias`;
/// * `{prefix}.idt`, shape `[out_dim]`, when `resnet_dt`.
///
/// With `resnet` set, the skip connection is `x` when `out_dim == in_dim`
/// and `concat([x, x], last_axis)` when `out_dim == 2 * in_dim`.  For any
/// other width no skip is added, matching the `if`/`elif` of the reference
/// pseudo-code.
pub fn dense_layer(g: &mut Graph, spec: &DenseLayerSpec, x: NodeId) -> NodeId {
    let w = g.param(
        format!("{}.w", spec.param_prefix),
        Shape::new(&[spec.in_dim, spec.out_dim], DType::F32),
    );
    let mm = g.mm(x, w);

    let pre_act = if spec.use_bias {
        let b = g.param(
            format!("{}.b", spec.param_prefix),
            Shape::new(&[spec.out_dim], DType::F32),
        );
        g.add(mm, b)
    } else {
        mm
    };

    let mut y = apply_activation(g, spec.activation, pre_act);

    if spec.resnet_dt {
        let idt = g.param(
            format!("{}.idt", spec.param_prefix),
            Shape::new(&[spec.out_dim], DType::F32),
        );
        y = g.mul(y, idt);
    }

    if spec.resnet {
        if spec.out_dim == spec.in_dim {
            y = g.add(y, x);
        } else if spec.out_dim == 2 * spec.in_dim {
            // concat([x, x], -1).  The result's last dim is statically
            // known from the spec (2 * in_dim == out_dim).  Deriving it by
            // doubling x's last dim would leave a non-static dim un-doubled
            // and declare the concat node with half its real width.
            let x_shape = g.shape(x).clone();
            let last_axis = x_shape.rank() - 1;
            let mut dims: Vec<rlx_ir::Dim> = x_shape.dims().to_vec();
            dims[last_axis] = rlx_ir::Dim::Static(spec.out_dim);
            let doubled_shape = Shape::from_dims(&dims, x_shape.dtype());
            let xx = g.concat(vec![x, x], last_axis, doubled_shape);
            y = g.add(y, xx);
        }
    }
    y
}

/// Stack of `DenseLayerSpec` — the canonical DeePMD embedding-net
/// (`EmbeddingNet`) and fitting-net trunk.  Each layer has `resnet=true`.
pub struct MlpSpec<'a> {
    pub param_prefix: &'a str,
    pub in_dim: usize,
    pub neuron: &'a [usize],
    pub activation: ActivationKind,
    pub resnet_dt: bool,
}

/// Chain one biased, `resnet=true` dense layer per entry of `spec.neuron`,
/// starting from `x_in`, and return the last layer's output.
///
/// Layer `i` maps width `neuron[i-1]` (or `spec.in_dim` for `i == 0`) to
/// `neuron[i]`, with parameters under `{param_prefix}.layer{i}`.  An empty
/// `neuron` adds nothing and returns `x_in`.
pub fn embedding_mlp(g: &mut Graph, spec: &MlpSpec, x_in: NodeId) -> NodeId {
    let (out, _) = spec.neuron.iter().enumerate().fold(
        (x_in, spec.in_dim),
        |(h, width), (idx, &next_width)| {
            let prefix = format!("{}.layer{}", spec.param_prefix, idx);
            let layer = DenseLayerSpec {
                param_prefix: &prefix,
                in_dim: width,
                out_dim: next_width,
                activation: spec.activation,
                use_bias: true,
                resnet_dt: spec.resnet_dt,
                resnet: true,
            };
            (dense_layer(g, &layer, h), next_width)
        },
    );
    out
}

/// Fitting net: the hidden stack built by [`embedding_mlp`], then one biased
/// linear projection to `final_dim` outputs whose parameters live under
/// `final_prefix`.  The projection has no activation, no `idt` scaling and
/// no skip connection.
pub fn fitting_mlp(
    g: &mut Graph,
    spec: &MlpSpec,
    x_in: NodeId,
    final_dim: usize,
    final_prefix: &str,
) -> NodeId {
    // Width entering the projection: the last hidden layer's output, or the
    // raw input width when there are no hidden layers.
    let trunk_width = spec.neuron.last().copied().unwrap_or(spec.in_dim);
    let trunk = embedding_mlp(g, spec, x_in);

    let projection = DenseLayerSpec {
        param_prefix: final_prefix,
        in_dim: trunk_width,
        out_dim: final_dim,
        activation: ActivationKind::Linear,
        use_bias: true,
        resnet_dt: false,
        resnet: false,
    };
    dense_layer(g, &projection, trunk)
}