deepmd 0.1.0

DeePMD-kit deep potential models as RLX IR graph builders
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Per-atom → global output transforms and gradient builders.
//!
//! Translated from `deepmd/dpmodel/model/transform_output.py` and the
//! force/virial paths threaded through
//! `deepmd/dpmodel/model/make_model.py`.
//!
//! Three concerns live here:
//!
//! * [`reduce_atomic_sum`] / [`reduce_atomic_mean`] — sum or mean of a
//!   per-atom tensor along the `nloc` axis (matches the
//!   `reducible` + `intensive` flags on `OutputVariableDef`).
//! * [`build_force_grad_graph`] — intended to use [`rlx-autodiff`](https://docs.rs/rlx-autodiff)
//!   to emit a graph that produces `dE / d(env_mat)`; the caller then applies the
//!   chain rule `dE/dr = dE/d(env_mat) · d(env_mat)/dr` on the host.  Currently a
//!   placeholder that returns an unchanged copy of the forward graph.
//! * [`build_virial_via_position`] — `Ξ = -Σ_i F_i ⊗ r_i`; expects
//!   the host to pass `forces` and `positions` graph inputs of
//!   matching shape.

use anyhow::Result;
use rlx_ir::infer::GraphExt;
use rlx_ir::op::ReduceOp;
use rlx_ir::{DType, Graph, NodeId, Shape};

/// Σ over the `nloc` axis of an `[nf, nloc, *]` tensor.
pub fn reduce_atomic_sum(g: &mut Graph, atomic: NodeId, atom_axis: usize) -> NodeId {
    let in_shape = g.shape(atomic).clone();
    let mut out_dims: Vec<rlx_ir::Dim> = in_shape.dims().to_vec();
    out_dims.remove(atom_axis);
    let out_shape = Shape::from_dims(&out_dims, in_shape.dtype());
    g.reduce(atomic, ReduceOp::Sum, vec![atom_axis], false, out_shape)
}

/// Mean over the `nloc` axis (the `intensive` path in DeePMD).
pub fn reduce_atomic_mean(g: &mut Graph, atomic: NodeId, atom_axis: usize) -> NodeId {
    let in_shape = g.shape(atomic).clone();
    let mut out_dims: Vec<rlx_ir::Dim> = in_shape.dims().to_vec();
    out_dims.remove(atom_axis);
    let out_shape = Shape::from_dims(&out_dims, in_shape.dtype());
    g.reduce(atomic, ReduceOp::Mean, vec![atom_axis], false, out_shape)
}

/// Builds a backward graph that computes the gradient of a scalar
/// (singleton-reduced) target with respect to a set of forward
/// inputs/params.  Currently a placeholder — wire up
/// [`rlx-autodiff`](https://docs.rs/rlx-autodiff) once force/virial
/// training paths are enabled in this crate:
///
/// ```rust,ignore
/// let mut prepared = forward.clone();
/// prepared.set_outputs(vec![target]);
/// rlx_autodiff::grad(&prepared, wrt)
/// ```
pub fn build_force_grad_graph(forward: &Graph, target: NodeId, wrt: &[NodeId]) -> Graph {
    let _ = (target, wrt);
    forward.clone()
}

/// Compute the virial tensor `Ξ_{αβ} = -Σ_i F_iα · r_iβ`.
///
/// `forces` and `positions` must both have shape `[nf, nloc, 3]`.
/// Returns a `[nf, 3, 3]` node.
///
/// All intermediate and output shapes carry the dtype of `forces`;
/// `positions` is expected to share that dtype.
pub fn build_virial_via_position(
    g: &mut Graph,
    forces: NodeId,
    positions: NodeId,
    nf: usize,
    nloc: usize,
) -> NodeId {
    // Follow the input dtype instead of assuming F32, matching
    // `reduce_atomic_sum` / `reduce_atomic_mean`.
    let in_shape = g.shape(forces).clone();
    let (nf_i, nloc_i) = (nf as i64, nloc as i64);

    // F[..., :, None] · r[..., None, :] → [nf, nloc, 3, 3], then Σ_i
    let f_4d_shape = Shape::new(&[nf, nloc, 3, 1], in_shape.dtype());
    let f_4d = g.reshape(forces, vec![nf_i, nloc_i, 3, 1], f_4d_shape);
    let r_4d_shape = Shape::new(&[nf, nloc, 1, 3], in_shape.dtype());
    let r_4d = g.reshape(positions, vec![nf_i, nloc_i, 1, 3], r_4d_shape);
    let outer = g.mul(f_4d, r_4d); // [nf, nloc, 3, 3]
    // Sum over the atom axis (axis 1) → [nf, 3, 3], then negate.
    let sum_shape = Shape::new(&[nf, 3, 3], in_shape.dtype());
    let sum = g.reduce(outer, ReduceOp::Sum, vec![1], false, sum_shape);
    g.neg(sum)
}

/// Helper: build a scalar total-energy node by summing `[nf, nloc, 1]`
/// down to `[1]` — the typical `target` for autodiff in MD.
///
/// The energies of all frames are summed into the single output element.
/// The output keeps the dtype of `atomic_energy`.
///
/// `nf` is currently unused and is kept only for API compatibility.
///
/// # Errors
///
/// Never returns `Err` at present; the `Result` is part of the public signature.
pub fn build_total_energy(g: &mut Graph, atomic_energy: NodeId, nf: usize) -> Result<NodeId> {
    let _ = nf;
    // Capture the dtype up front so the scalar shape matches the input
    // instead of assuming F32.
    let in_shape = g.shape(atomic_energy).clone();
    // atomic_energy: [nf, nloc, 1] → reduce nloc → [nf, 1]
    let per_frame = reduce_atomic_sum(g, atomic_energy, 1);
    // [nf, 1] → reduce nf → [1]
    let scalar_shape = Shape::new(&[1], in_shape.dtype());
    Ok(g.reduce(per_frame, ReduceOp::Sum, vec![0], false, scalar_shape))
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn reductions_match_axis() {
        let mut g = Graph::new("reduce");
        let x = g.input("x", Shape::new(&[2, 4, 3], DType::F32));
        let s = reduce_atomic_sum(&mut g, x, 1);
        assert_eq!(g.shape(s).rank(), 2);
        assert_eq!(g.shape(s).dim(0), rlx_ir::Dim::Static(2));
        assert_eq!(g.shape(s).dim(1), rlx_ir::Dim::Static(3));
    }

    #[test]
    fn mean_reduction_matches_axis() {
        let mut g = Graph::new("reduce_mean");
        let x = g.input("x", Shape::new(&[2, 4, 3], DType::F32));
        let m = reduce_atomic_mean(&mut g, x, 1);
        let s = g.shape(m);
        assert_eq!(s.rank(), 2);
        assert_eq!(s.dim(0), rlx_ir::Dim::Static(2));
        assert_eq!(s.dim(1), rlx_ir::Dim::Static(3));
    }

    #[test]
    fn virial_shape_matches() {
        let mut g = Graph::new("virial");
        let nf = 1;
        let nloc = 4;
        let f = g.input("f", Shape::new(&[nf, nloc, 3], DType::F32));
        let r = g.input("r", Shape::new(&[nf, nloc, 3], DType::F32));
        let v = build_virial_via_position(&mut g, f, r, nf, nloc);
        let s = g.shape(v);
        assert_eq!(s.rank(), 3);
        assert_eq!(s.dim(0), rlx_ir::Dim::Static(1));
        assert_eq!(s.dim(1), rlx_ir::Dim::Static(3));
        assert_eq!(s.dim(2), rlx_ir::Dim::Static(3));
    }

    #[test]
    fn total_energy_is_single_element() -> Result<()> {
        let mut g = Graph::new("energy");
        let e = g.input("e", Shape::new(&[2, 4, 1], DType::F32));
        let total = build_total_energy(&mut g, e, 2)?;
        let s = g.shape(total);
        assert_eq!(s.rank(), 1);
        assert_eq!(s.dim(0), rlx_ir::Dim::Static(1));
        Ok(())
    }
}