// deepmd 0.1.0
//
// DeePMD-kit deep potential models as RLX IR graph builders.
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! DP-style model graphs for the non-energy output kinds:
//! dipole, polarizability, DOS, and generic invariant property.
//!
//! Each function composes:
//! * a `se_e2_a` descriptor (the default in upstream DPModel),
//! * the matching fitting net,
//! * a sum reduction over `nloc` to produce the global quantity.
//!
//! Mirrors `dipole_model.py`, `polar_model.py`, `dos_model.py`,
//! `property_model.py` from `deepmd.dpmodel.model`.  Forces/virials
//! are derivatives of the same graphs and are exposed elsewhere
//! (`crate::transform_output`).

use anyhow::Result;
use rlx_ir::op::ReduceOp;
use rlx_ir::{DType, Graph, NodeId, Shape};

use crate::config::{DPModelConfig, EnerFittingConfig, SeAConfig};
use crate::descriptor::{build_se_a_descriptor, SeADescriptor, SeAExtraInputs};
use crate::fitting::{
    build_dipole_fitting, build_invar_fitting, build_polar_fitting, DipoleFittingConfig,
    FittingInputs, PolarFittingConfig,
};

/// Graph input nodes shared by every DP-style graph.
pub struct DPInputs {
    /// `env_mat` input node, declared as `[nf, nloc, nnei, 4]` (f32).
    pub env_mat: NodeId,
    /// `atype` input node, declared as `[nf, nloc]` (i32).
    pub atype: NodeId,
}

/// Register the `env_mat` / `atype` graph inputs on `g` and build a
/// `se_e2_a` descriptor on top of them.
///
/// The neighbour axis of `env_mat` is sized by `descriptor_cfg.nnei()`.
/// Returns the descriptor handles together with the two input nodes.
///
/// # Errors
///
/// Propagates any failure from [`build_se_a_descriptor`].
fn build_descriptor(
    g: &mut Graph,
    descriptor_cfg: &SeAConfig,
    nf: usize,
    nloc: usize,
) -> Result<(SeADescriptor, DPInputs)> {
    let env_shape = Shape::new(&[nf, nloc, descriptor_cfg.nnei(), 4], DType::F32);
    let atype_shape = Shape::new(&[nf, nloc], DType::I32);
    // Struct-literal initializers evaluate top to bottom, so inputs are
    // registered in the order `env_mat`, then `atype`.
    let inputs = DPInputs {
        env_mat: g.input("env_mat", env_shape),
        atype: g.input("atype", atype_shape),
    };
    let descriptor = build_se_a_descriptor(
        g,
        descriptor_cfg,
        inputs.env_mat,
        inputs.atype,
        nf,
        nloc,
        SeAExtraInputs::default(),
    )?;
    Ok((descriptor, inputs))
}

// ── dipole ────────────────────────────────────────────────────────────

pub struct DPDipoleGraph {
    /// Per-atom dipole, `[nf, nloc, 3]`.
    pub dipole: NodeId,
    /// Global dipole, `[nf, 3]` (sum over `nloc`).
    pub global_dipole: NodeId,
    pub inputs: DPInputs,
}

pub fn build_dp_dipole_graph(
    descriptor_cfg: &SeAConfig,
    fitting_cfg: &DipoleFittingConfig,
    nf: usize,
    nloc: usize,
) -> Result<(Graph, DPDipoleGraph)> {
    let mut g = Graph::new("dp_dipole");
    let (descriptor, inputs) = build_descriptor(&mut g, descriptor_cfg, nf, nloc)?;
    let gr_4d = reshape_gr_for_fitting(&mut g, descriptor.gr, nf, nloc, descriptor_cfg);
    let fit_inputs = FittingInputs {
        descriptor: descriptor.descriptor,
        atype_loc: inputs.atype,
        fparam: None,
        aparam: None,
        case_embd: None,
        exclude_mask: None,
        gr: Some(gr_4d),
    };
    let out = build_dipole_fitting(&mut g, fitting_cfg, fit_inputs, nf, nloc)?;
    let global = g.reduce(
        out.dipole,
        ReduceOp::Sum,
        vec![1],
        false,
        Shape::new(&[nf, 3], DType::F32),
    );
    Ok((
        g,
        DPDipoleGraph {
            dipole: out.dipole,
            global_dipole: global,
            inputs,
        },
    ))
}

// ── polarizability ────────────────────────────────────────────────────

pub struct DPPolarGraph {
    /// Per-atom polarizability, `[nf, nloc, 3, 3]`.
    pub polar: NodeId,
    /// Global polarizability, `[nf, 3, 3]`.
    pub global_polar: NodeId,
    pub inputs: DPInputs,
}

pub fn build_dp_polar_graph(
    descriptor_cfg: &SeAConfig,
    fitting_cfg: &PolarFittingConfig,
    nf: usize,
    nloc: usize,
) -> Result<(Graph, DPPolarGraph)> {
    let mut g = Graph::new("dp_polar");
    let (descriptor, inputs) = build_descriptor(&mut g, descriptor_cfg, nf, nloc)?;
    let gr_4d = reshape_gr_for_fitting(&mut g, descriptor.gr, nf, nloc, descriptor_cfg);
    let fit_inputs = FittingInputs {
        descriptor: descriptor.descriptor,
        atype_loc: inputs.atype,
        fparam: None,
        aparam: None,
        case_embd: None,
        exclude_mask: None,
        gr: Some(gr_4d),
    };
    let out = build_polar_fitting(&mut g, fitting_cfg, fit_inputs, nf, nloc)?;
    let global = g.reduce(
        out.polar,
        ReduceOp::Sum,
        vec![1],
        false,
        Shape::new(&[nf, 3, 3], DType::F32),
    );
    Ok((
        g,
        DPPolarGraph {
            polar: out.polar,
            global_polar: global,
            inputs,
        },
    ))
}

// ── DOS ───────────────────────────────────────────────────────────────

/// Node handles for a DP density-of-states model built by [`build_dp_dos_graph`].
pub struct DPDosGraph {
    /// Per-atom DOS, `[nf, nloc, numb_dos]`.
    pub atom_dos: NodeId,
    /// Global DOS, `[nf, numb_dos]` (sum over `nloc`).
    pub dos: NodeId,
    /// Graph inputs (`env_mat`, `atype`) the caller must feed.
    pub inputs: DPInputs,
}

/// Build a DP DOS model graph (mirrors upstream `dos_model.py`).
///
/// Topology: `se_e2_a` descriptor → invariant fitting → sum over `nloc`.
/// `cfg.fitting_net.dim_out` plays the role of `numb_dos`. The fitting is
/// built under the `"fitting"` prefix with bias parameter
/// `"fitting.bias_atom_e"`.
///
/// # Errors
///
/// Propagates any failure from the descriptor or fitting builders.
pub fn build_dp_dos_graph(
    cfg: &DPModelConfig,
    nf: usize,
    nloc: usize,
) -> Result<(Graph, DPDosGraph)> {
    let mut graph = Graph::new("dp_dos");
    let (descrpt, inputs) = build_descriptor(&mut graph, &cfg.descriptor, nf, nloc)?;
    let fitting_net = &cfg.fitting_net;

    // Invariant output: no equivariant `gr` input and no optional params.
    let fitting_inputs = FittingInputs {
        descriptor: descrpt.descriptor,
        atype_loc: inputs.atype,
        gr: None,
        fparam: None,
        aparam: None,
        case_embd: None,
        exclude_mask: None,
    };
    let atom_dos = build_invar_fitting(
        &mut graph,
        fitting_net,
        fitting_inputs,
        nf,
        nloc,
        "fitting",
        "fitting.bias_atom_e",
    )?
    .atom_output;

    let summed_shape = Shape::new(&[nf, fitting_net.dim_out], DType::F32);
    let dos = graph.reduce(atom_dos, ReduceOp::Sum, vec![1], false, summed_shape);

    Ok((
        graph,
        DPDosGraph {
            atom_dos,
            dos,
            inputs,
        },
    ))
}

// ── Property ──────────────────────────────────────────────────────────

/// Node handles for a generic invariant-property model built by
/// [`build_dp_property_graph`].
pub struct DPPropertyGraph {
    /// Per-atom property, `[nf, nloc, dim_out]`.
    pub atom_property: NodeId,
    /// Global property, `[nf, dim_out]` (sum over `nloc`).
    pub property: NodeId,
    /// Graph inputs (`env_mat`, `atype`) the caller must feed.
    pub inputs: DPInputs,
}

/// Build a generic invariant-property model graph (mirrors upstream
/// `property_model.py`).
///
/// The topology matches the DOS builder: `se_e2_a` descriptor → invariant
/// fitting → sum over `nloc`. There are two differences:
/// * the graph is named `dp_property`;
/// * the caller supplies `bias_param_name`, which is forwarded to
///   [`build_invar_fitting`] as the fitting's output-bias parameter name.
///   The parameter prefix stays `"fitting"`.
///
/// NOTE(review): upstream DeePMD-kit averages, rather than sums, properties
/// declared `intensive`. This builder always sums, so confirm that callers
/// only use it for extensive properties.
///
/// # Errors
///
/// Propagates any failure from the descriptor or fitting builders.
pub fn build_dp_property_graph(
    cfg: &DPModelConfig,
    bias_param_name: &str,
    nf: usize,
    nloc: usize,
) -> Result<(Graph, DPPropertyGraph)> {
    let mut graph = Graph::new("dp_property");
    let (descrpt, inputs) = build_descriptor(&mut graph, &cfg.descriptor, nf, nloc)?;
    let fitting_net = &cfg.fitting_net;

    let atom_property = build_invar_fitting(
        &mut graph,
        fitting_net,
        FittingInputs {
            descriptor: descrpt.descriptor,
            atype_loc: inputs.atype,
            gr: None,
            fparam: None,
            aparam: None,
            case_embd: None,
            exclude_mask: None,
        },
        nf,
        nloc,
        "fitting",
        bias_param_name,
    )?
    .atom_output;

    let property = graph.reduce(
        atom_property,
        ReduceOp::Sum,
        vec![1],
        false,
        Shape::new(&[nf, fitting_net.dim_out], DType::F32),
    );

    Ok((
        graph,
        DPPropertyGraph {
            atom_property,
            property,
            inputs,
        },
    ))
}

/// Pin the descriptor's `gr` tensor to the static shape `[nf, nloc, ng, 3]`
/// used by the equivariant (dipole / polar) fittings, with `ng = cfg.ng()`.
///
/// `gr` is expected to have this shape already, so the reshape only makes
/// the static dims explicit for downstream shape assertions.
fn reshape_gr_for_fitting(
    g: &mut Graph,
    gr: NodeId,
    nf: usize,
    nloc: usize,
    cfg: &SeAConfig,
) -> NodeId {
    let dims = [nf, nloc, cfg.ng(), 3];
    let target: Vec<i64> = dims.iter().map(|&d| d as i64).collect();
    g.reshape(gr, target, Shape::new(&dims, DType::F32))
}

/// Convenience overload of the energy graph that mirrors the
/// non-energy builders — pure wrapper around
/// [`crate::model::build_dp_energy_graph`] so callers using
/// `model_dp::` can find every kind of DP graph in one place.
pub fn build_dp_energy_graph(
    cfg: &DPModelConfig,
    nf: usize,
    nloc: usize,
) -> Result<(Graph, crate::model::DPEnergyGraph)> {
    crate::model::build_dp_energy_graph(cfg, nf, nloc)
}

/// Build an [`EnerFittingConfig`] whose input width matches `descriptor`.
///
/// `ntypes` and `dim_descrpt` are taken from the descriptor. `neuron` and
/// `dim_out` come from the caller. Everything else uses the fixed defaults
/// below. This is handy for dipole/polar configs, which wrap a base ener
/// config.
pub fn ener_cfg_from(
    descriptor: &SeAConfig,
    neuron: Vec<usize>,
    dim_out: usize,
) -> EnerFittingConfig {
    let ntypes = descriptor.ntypes();
    let dim_descrpt = descriptor.dim_out();
    EnerFittingConfig {
        ntypes,
        dim_descrpt,
        dim_out,
        neuron,
        // Fixed defaults: resnet time-step on, tanh activation, mixed types,
        // and no frame/atom params or case embedding.
        resnet_dt: true,
        activation_function: "tanh".into(),
        mixed_types: true,
        numb_fparam: 0,
        numb_aparam: 0,
        dim_case_embd: 0,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::fitting::{DipoleFittingConfig, PolarFittingConfig};
    use rlx_ir::Dim;

    /// Small two-type `se_e2_a` descriptor shared by every test.
    fn descrpt() -> SeAConfig {
        SeAConfig {
            rcut: 6.0,
            rcut_smth: 0.5,
            sel: vec![46, 92],
            neuron: vec![16, 32, 64],
            axis_neuron: 4,
            resnet_dt: false,
            type_one_side: true,
            activation_function: "tanh".into(),
            type_map: None,
        }
    }

    /// Invariant-fitting (DOS / property) model with `dim_out` outputs per
    /// atom, built on [`descrpt`].
    fn invar_model(dim_out: usize) -> DPModelConfig {
        let descriptor = descrpt();
        let fitting_net = ener_cfg_from(&descriptor, vec![16, 16], dim_out);
        DPModelConfig {
            descriptor,
            fitting_net,
            type_map: None,
        }
    }

    #[test]
    fn dipole_graph_builds() {
        let d = descrpt();
        let mw = d.ng();
        let cfg = DipoleFittingConfig {
            base: ener_cfg_from(&d, vec![16], 1),
            embedding_width: mw,
        };
        let (g, out) = build_dp_dipole_graph(&d, &cfg, 1, 4).expect("build");
        assert!(g.len() > 20);
        // Global dipole is `[nf, 3]`.
        let s = g.shape(out.global_dipole);
        assert_eq!(s.dim(0), Dim::Static(1));
        assert_eq!(s.dim(1), Dim::Static(3));
    }

    #[test]
    fn polar_graph_builds() {
        let d = descrpt();
        let mw = d.ng();
        let cfg = PolarFittingConfig {
            base: ener_cfg_from(&d, vec![16], 1),
            embedding_width: mw,
            fit_diag: false,
            shift_diag: true,
        };
        let (g, out) = build_dp_polar_graph(&d, &cfg, 1, 4).expect("build");
        assert!(g.len() > 25);
        // Global polarizability is `[nf, 3, 3]`.
        let s = g.shape(out.global_polar);
        assert_eq!(s.dim(0), Dim::Static(1));
        assert_eq!(s.dim(1), Dim::Static(3));
        assert_eq!(s.dim(2), Dim::Static(3));
    }

    #[test]
    fn dos_graph_builds() {
        let cfg = invar_model(32); // numb_dos
        let (g, out) = build_dp_dos_graph(&cfg, 1, 4).expect("build");
        assert!(g.len() > 20);
        let s = g.shape(out.dos);
        assert_eq!(s.dim(0), Dim::Static(1));
        assert_eq!(s.dim(1), Dim::Static(32));
    }

    #[test]
    fn property_graph_builds() {
        let cfg = invar_model(3);
        let (g, out) =
            build_dp_property_graph(&cfg, "fitting.bias_atom_e", 1, 4).expect("build");
        assert!(g.len() > 20);
        let s = g.shape(out.property);
        assert_eq!(s.dim(0), Dim::Static(1));
        assert_eq!(s.dim(1), Dim::Static(3));
    }
}