// deepmd 0.1.0 — DeePMD-kit deep potential models as RLX IR graph builders.
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! `se_atten_v2` descriptor — DPA-1 with `tebd_input_mode="strip"` and
//! smooth type-embedding mixing (`smooth_type_embedding=True`) forced on.
//!
//! Translated from
//! `deepmd/dpmodel/descriptor/se_atten_v2.py::DescrptSeAttenV2`.
//! In Python that class is a thin subclass of `DescrptDPA1` that pins those
//! two flags. We mirror it exactly: a [`SeAttenV2Config`] is converted into a
//! [`DPA1Config`] with the same two flags fixed (see
//! [`SeAttenV2Config::to_dpa1`]), and the DPA-1 graph builder is reused.

use anyhow::Result;
use rlx_ir::Graph;

use crate::descriptor_dpa1::{
    build_dpa1_descriptor, DPA1Config, DPA1Descriptor, DPA1Inputs,
};
use crate::descriptor_se_t_tebd::TebdInputMode;
use serde::{Deserialize, Serialize};

/// User-facing configuration for the `se_atten_v2` descriptor.
///
/// Field names mirror the keyword arguments of DeePMD-kit's
/// `DescrptSeAttenV2`. Fields carrying a `#[serde(default ...)]` attribute may
/// be omitted when deserializing. Their fallback values come from the
/// `default_*` functions in this module. `rcut`, `rcut_smth`, `sel` and
/// `ntypes` have no default and are required.
///
/// There are intentionally no fields for the type-embedding input mode or for
/// smooth type-embedding mixing. [`Self::to_dpa1`] always pins them to
/// `TebdInputMode::Strip` and `smooth = true`, so they cannot be configured
/// here.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct SeAttenV2Config {
    /// DeePMD-kit `rcut` (cutoff radius). Required. Forwarded unchanged.
    pub rcut: f64,
    /// DeePMD-kit `rcut_smth` (radius at which smoothing toward `rcut`
    /// starts). Required. Forwarded unchanged.
    pub rcut_smth: f64,
    /// DeePMD-kit `sel` (neighbor-selection counts). Required. Cloned into the
    /// DPA-1 config as-is; how the counts are consumed is decided by the
    /// DPA-1 builder.
    pub sel: Vec<usize>,
    /// Number of atom types. Required.
    pub ntypes: usize,
    /// Hidden-layer widths of the embedding net. Defaults to `[25, 50, 100]`.
    #[serde(default = "default_neuron")]
    pub neuron: Vec<usize>,
    /// DeePMD-kit `axis_neuron` (width of the embedding sub-matrix used to
    /// form the descriptor). Defaults to `8`.
    #[serde(default = "default_axis_neuron")]
    pub axis_neuron: usize,
    /// Type-embedding dimension. Defaults to `8`.
    #[serde(default = "default_tebd_dim")]
    pub tebd_dim: usize,
    /// DeePMD-kit `resnet_dt` (trainable timestep on residual connections).
    /// Defaults to `false`.
    #[serde(default)]
    pub resnet_dt: bool,
    /// Hidden dimension of the attention layers. Defaults to `128`.
    #[serde(default = "default_attn")]
    pub attn: usize,
    /// Number of attention layers. Defaults to `2`.
    #[serde(default = "default_attn_layer")]
    pub attn_layer: usize,
    /// DeePMD-kit `attn_dotr`. Defaults to `true`.
    #[serde(default = "default_true")]
    pub attn_dotr: bool,
    /// Number of attention heads. Defaults to `1`.
    #[serde(default = "default_num_heads")]
    pub num_heads: usize,
    /// DeePMD-kit `normalize` for the attention block. Defaults to `true`.
    #[serde(default = "default_true")]
    pub normalize: bool,
    /// DeePMD-kit `scaling_factor` used when normalizing attention weights.
    /// Defaults to `1.0`.
    #[serde(default = "default_scaling_factor")]
    pub scaling_factor: f32,
    /// Layer-norm epsilon. Defaults to `1e-5`.
    #[serde(default = "default_ln_eps")]
    pub ln_eps: f32,
    /// DeePMD-kit `type_one_side`. Defaults to `false`.
    #[serde(default)]
    pub type_one_side: bool,
    /// Activation function name, passed through as a string. Defaults to
    /// `"tanh"`.
    #[serde(default = "default_activation")]
    pub activation_function: String,
    /// DeePMD-kit `concat_output_tebd` (append the type embedding to the
    /// descriptor output). Defaults to `true`.
    #[serde(default = "default_true")]
    pub concat_output_tebd: bool,
}

// Serde fallbacks for the optional `SeAttenV2Config` fields. The values are
// intended to track the keyword defaults of DeePMD-kit's `DescrptSeAttenV2`,
// so a config that omits a key behaves like the Python reference.

/// Embedding-net widths used when `neuron` is omitted.
fn default_neuron() -> Vec<usize> {
    Vec::from([25, 50, 100])
}

/// Fallback for `axis_neuron`.
fn default_axis_neuron() -> usize {
    8_usize
}

/// Fallback for `tebd_dim` (type-embedding width).
fn default_tebd_dim() -> usize {
    8_usize
}

/// Fallback for `attn` (attention hidden width).
fn default_attn() -> usize {
    128_usize
}

/// Fallback for `attn_layer` (number of attention layers).
fn default_attn_layer() -> usize {
    2_usize
}

/// Fallback for `num_heads`: single-head attention.
fn default_num_heads() -> usize {
    1_usize
}

/// Fallback layer-norm epsilon.
fn default_ln_eps() -> f32 {
    1.0e-5_f32
}

/// Shared fallback for boolean options that are on unless disabled
/// (`attn_dotr`, `normalize`, `concat_output_tebd`).
fn default_true() -> bool {
    true
}

/// Fallback for `scaling_factor`: no extra scaling.
fn default_scaling_factor() -> f32 {
    1.0_f32
}

/// Fallback activation name.
fn default_activation() -> String {
    String::from("tanh")
}

impl SeAttenV2Config {
    /// Converts this config into the equivalent [`DPA1Config`].
    ///
    /// Every field of `self` is copied or cloned across unchanged. The two
    /// switches that define the `se_atten_v2` variant are fixed regardless of
    /// input: `tebd_input_mode` is always [`TebdInputMode::Strip`] and
    /// `smooth` is always `true`.
    ///
    /// The destructuring below names every field with no `..` rest pattern.
    /// Adding a field to `SeAttenV2Config` without forwarding it here
    /// therefore becomes a compile error rather than a silently dropped
    /// setting.
    pub fn to_dpa1(&self) -> DPA1Config {
        let Self {
            rcut,
            rcut_smth,
            sel,
            ntypes,
            neuron,
            axis_neuron,
            tebd_dim,
            resnet_dt,
            attn,
            attn_layer,
            attn_dotr,
            num_heads,
            normalize,
            scaling_factor,
            ln_eps,
            type_one_side,
            activation_function,
            concat_output_tebd,
        } = self;

        DPA1Config {
            // Pinned by the se_atten_v2 variant (strip-mode type embedding,
            // smooth mixing on).
            tebd_input_mode: TebdInputMode::Strip,
            smooth: true,
            // Owned collections are cloned so the source config stays usable.
            sel: sel.clone(),
            neuron: neuron.clone(),
            activation_function: activation_function.clone(),
            // Scalar pass-through settings.
            rcut: *rcut,
            rcut_smth: *rcut_smth,
            ntypes: *ntypes,
            axis_neuron: *axis_neuron,
            tebd_dim: *tebd_dim,
            resnet_dt: *resnet_dt,
            attn: *attn,
            attn_layer: *attn_layer,
            attn_dotr: *attn_dotr,
            num_heads: *num_heads,
            normalize: *normalize,
            scaling_factor: *scaling_factor,
            ln_eps: *ln_eps,
            type_one_side: *type_one_side,
            concat_output_tebd: *concat_output_tebd,
        }
    }
}

/// Builds the `se_atten_v2` descriptor subgraph into `g`.
///
/// This is a thin adapter. `cfg` is converted with
/// [`SeAttenV2Config::to_dpa1`], which pins strip-mode type embedding and
/// `smooth = true`. The call is then delegated to [`build_dpa1_descriptor`],
/// with `inputs`, `nf` and `nloc` passed through untouched.
///
/// `nf` and `nloc` are not interpreted here. DeePMD naming suggests frame
/// count and local-atom count — confirm against `build_dpa1_descriptor`.
///
/// # Errors
///
/// Returns whatever error [`build_dpa1_descriptor`] produces; this wrapper
/// adds no failure modes of its own.
pub fn build_se_atten_v2_descriptor(
    g: &mut Graph,
    cfg: &SeAttenV2Config,
    inputs: DPA1Inputs,
    nf: usize,
    nloc: usize,
) -> Result<DPA1Descriptor> {
    build_dpa1_descriptor(g, &cfg.to_dpa1(), inputs, nf, nloc)
}