deepmd 0.1.0

DeePMD-kit deep potential models as RLX IR graph builders
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Spin model graph helpers.
//!
//! Translated from `deepmd/dpmodel/model/spin_model.py::SpinModel`.
//!
//! The spin model wraps a backbone DP atomic model by doubling the
//! atom count: each real atom of type `t < ntypes_real` is mirrored
//! by a virtual atom of type `t + ntypes_real` whose position is
//! `coord + spin * scale[t]`.  The backbone model then runs as
//! usual; its per-atom outputs are split back into a "real" half and
//! a "spin" half, the latter giving the magnetic-moment gradient.
//!
//! In graph form, the spin block is mostly bookkeeping: a few
//! `concat`/`gather`/`add`/`mul` ops to assemble the doubled inputs from
//! `coord`, `atype`, `spin`, and a per-type virtual-scale table.

use anyhow::Result;
use rlx_ir::infer::GraphExt;
use rlx_ir::{DType, Graph, NodeId, Shape};

/// Configuration for the spin-model input expansion.
#[derive(Clone, Debug)]
pub struct SpinConfig {
    /// How many *real* atom types exist. Each real type `t` gets a
    /// virtual mirror type `t + ntypes_real`, so the backbone model
    /// sees `2*ntypes_real` types in total.
    pub ntypes_real: usize,
}

/// Outputs of [`build_spin_inputs`].
///
/// Every field is a graph node whose atom axis (axis 1) holds the
/// `nloc` real atoms first, followed by their `nloc` virtual mirrors.
pub struct SpinExpanded {
    /// `[nf, 2*nloc, 3]` doubled coordinates: real positions, then
    /// `coord + spin_dist` for the virtual atoms.
    pub coord_spin: NodeId,
    /// `[nf, 2*nloc]` doubled atom types: real types, then
    /// `atype + ntypes_real` for the virtual atoms.
    pub atype_spin: NodeId,
    /// `[nf, 2*nloc, 3]` coordinate correction for the virial term —
    /// `0` for real atoms, `-spin_dist` for virtual atoms.  Used by
    /// the upstream virial transform; see `make_model.py`.
    pub coord_corr: NodeId,
}

/// Compose the doubled (coord, atype, coord_corr) inputs from the
/// raw per-atom (coord, atype, spin) and a per-type
/// `virtual_scale_mask` param.
///
/// Params emitted:
///
/// ```text
///     spin.virtual_scale_mask    # [2 * ntypes_real] (typically 0 for virtual rows)
/// ```
pub fn build_spin_inputs(
    g: &mut Graph,
    cfg: &SpinConfig,
    coord: NodeId,
    atype: NodeId,
    spin: NodeId,
    nf: usize,
    nloc: usize,
) -> Result<SpinExpanded> {
    let two_nt = 2 * cfg.ntypes_real;

    // ── virtual atom types: atype + ntypes_real ──
    let ntypes_const = i32_filled(g, cfg.ntypes_real as i32, &[nf, nloc]);
    let virtual_atype = g.add(atype, ntypes_const);
    let atype_spin_shape = Shape::new(&[nf, 2 * nloc], DType::I32);
    let atype_spin = g.concat(vec![atype, virtual_atype], 1, atype_spin_shape);

    // ── per-atom virtual scale: vsm[atype] ──
    let vsm = g.param(
        "spin.virtual_scale_mask",
        Shape::new(&[two_nt], DType::F32),
    );
    let scale_per_atom = g.gather_(vsm, atype, 0); // [nf, nloc]
    let scale_3d_shape = Shape::new(&[nf, nloc, 1], DType::F32);
    let scale_3d = g.reshape(
        scale_per_atom,
        vec![nf as i64, nloc as i64, 1],
        scale_3d_shape,
    );

    // ── spin_dist = spin * scale ──
    let spin_dist = g.mul(spin, scale_3d); // [nf, nloc, 3]
    let virtual_coord = g.add(coord, spin_dist);

    let coord_spin_shape = Shape::new(&[nf, 2 * nloc, 3], DType::F32);
    let coord_spin = g.concat(vec![coord, virtual_coord], 1, coord_spin_shape);

    // ── coord_corr: zeros for real, -spin_dist for virtual ──
    let zero_real = zero_tensor(g, &[nf, nloc, 3]);
    let neg_spin_dist = g.neg(spin_dist);
    let coord_corr_shape = Shape::new(&[nf, 2 * nloc, 3], DType::F32);
    let coord_corr = g.concat(vec![zero_real, neg_spin_dist], 1, coord_corr_shape);

    Ok(SpinExpanded {
        coord_spin,
        atype_spin,
        coord_corr,
    })
}

/// Undo the atom-axis doubling of [`build_spin_inputs`] for a per-atom
/// backbone output.
///
/// `output` is expected to be `[nf, 2*nloc, dim]`. It is sliced along
/// axis 1 into:
/// - the first `nloc` rows (real atoms), and
/// - the last `nloc` rows (virtual/spin mirrors, the magnetic half).
///
/// Both halves are `[nf, nloc, dim]`. Returns `(real, magnetic)`.
///
/// `nf` and `dim` are not used to build the slices; they appear to be
/// kept only for signature symmetry.
pub fn split_spin_output(
    g: &mut Graph,
    output: NodeId,
    nf: usize,
    nloc: usize,
    dim: usize,
) -> (NodeId, NodeId) {
    // Only the atom axis is sliced, so the other extents are not needed.
    let _ = (nf, dim);
    (
        g.narrow_(output, 1, 0, nloc),
        g.narrow_(output, 1, nloc, nloc),
    )
}

/// Emit an `I32` constant node of the given `shape` with every element
/// set to `value`, stored as little-endian bytes.
fn i32_filled(g: &mut Graph, value: i32, shape: &[usize]) -> NodeId {
    let count = shape.iter().product::<usize>();
    // One 4-byte little-endian pattern, repeated once per element.
    let data = value.to_le_bytes().repeat(count);
    let out_shape = Shape::new(shape, DType::I32);
    g.add_node(rlx_ir::op::Op::Constant { data }, vec![], out_shape)
}

/// Emit an `F32` constant node of the given `shape` filled with zeros.
/// An all-zero byte pattern is `0.0` in IEEE-754 single precision.
fn zero_tensor(g: &mut Graph, shape: &[usize]) -> NodeId {
    let elems: usize = shape.iter().copied().product();
    let data = vec![0u8; elems * std::mem::size_of::<f32>()];
    g.add_node(
        rlx_ir::op::Op::Constant { data },
        vec![],
        Shape::new(shape, DType::F32),
    )
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn spin_inputs_build() {
        let mut g = Graph::new("spin");
        let nf = 1;
        let nloc = 4;
        let cfg = SpinConfig { ntypes_real: 2 };
        let coord = g.input("coord", Shape::new(&[nf, nloc, 3], DType::F32));
        let atype = g.input("atype", Shape::new(&[nf, nloc], DType::I32));
        let spin = g.input("spin", Shape::new(&[nf, nloc, 3], DType::F32));
        let out = build_spin_inputs(&mut g, &cfg, coord, atype, spin, nf, nloc)
            .expect("build");

        // All three outputs keep the frame axis and double the atom axis.
        for node in [out.coord_spin, out.atype_spin, out.coord_corr] {
            let s = g.shape(node);
            assert_eq!(s.dim(0), rlx_ir::Dim::Static(nf));
            assert_eq!(s.dim(1), rlx_ir::Dim::Static(2 * nloc));
        }
        // Coordinate-like outputs keep the trailing xyz axis.
        assert_eq!(g.shape(out.coord_spin).dim(2), rlx_ir::Dim::Static(3));
        assert_eq!(g.shape(out.coord_corr).dim(2), rlx_ir::Dim::Static(3));
    }

    #[test]
    fn spin_output_splits() {
        let mut g = Graph::new("spin_split");
        let nf = 1;
        let nloc = 4;
        // Use dim != 1 so the trailing axis is distinguishable.
        let dim = 3;
        let backbone =
            g.input("out", Shape::new(&[nf, 2 * nloc, dim], DType::F32));
        let (real, virt) = split_spin_output(&mut g, backbone, nf, nloc, dim);
        for half in [real, virt] {
            assert_eq!(g.shape(half).dim(1), rlx_ir::Dim::Static(nloc));
            assert_eq!(g.shape(half).dim(2), rlx_ir::Dim::Static(dim));
        }
    }
}