deepmd 0.1.0

DeePMD-kit deep potential models as RLX IR graph builders
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Hybrid descriptor — concatenates outputs of multiple descriptors.
//!
//! Translated from `DescrptHybrid.call` in
//! `deepmd/dpmodel/descriptor/hybrid.py`.  The original Python class
//! also takes care of slicing one global `nlist` into per-descriptor
//! sub-nlists; that splitting belongs in host preprocessing (because
//! the sliced indices depend on neighbor counts), so in graph form
//! "hybrid" is just a last-axis (`axis=-1`) concat of the descriptor outputs.
//!
//! Callers build each sub-descriptor with its own builder, hand the
//! resulting `NodeId`s to [`concat_descriptor_outputs`], and then feed
//! the combined descriptor into a fitting net.

use anyhow::{anyhow, Result};
use rlx_ir::{DType, Graph, NodeId, Shape};

/// Concatenate a set of descriptor output nodes along the last axis.
///
/// Each input must have shape `[nf, nloc, d_i]` (the standard DeePMD
/// descriptor output layout) with a static last dimension `d_i`.  Returns
/// the concat node together with its total last-axis dimension
/// `sum(d_i)`.
///
/// The output node is always declared with `DType::F32`, regardless of the
/// input nodes' dtypes.
///
/// # Errors
///
/// Returns an error if:
/// * `outputs` is empty;
/// * any input is not rank 3;
/// * any input's last dimension is not static;
/// * any input's first/second dimension is static but differs from
///   `nf` / `nloc` (non-static leading dimensions are accepted unchecked).
pub fn concat_descriptor_outputs(
    g: &mut Graph,
    outputs: &[NodeId],
    nf: usize,
    nloc: usize,
) -> Result<(NodeId, usize)> {
    if outputs.is_empty() {
        return Err(anyhow!("hybrid: at least one sub-descriptor is required"));
    }
    let mut total = 0usize;
    for (idx, &node) in outputs.iter().enumerate() {
        let s = g.shape(node);
        if s.rank() != 3 {
            return Err(anyhow!(
                "hybrid: each sub-descriptor must have rank 3 [nf, nloc, d]; \
                 sub-descriptor {} has rank {}",
                idx,
                s.rank()
            ));
        }
        // The concat node is declared as `[nf, nloc, total]`, so a static
        // leading dim that disagrees with the caller's `nf`/`nloc` would
        // produce an internally inconsistent graph.
        let leading = [(s.dim(0), nf, "nf"), (s.dim(1), nloc, "nloc")];
        for (dim, expected, label) in leading {
            if let rlx_ir::Dim::Static(actual) = dim {
                if actual != expected {
                    return Err(anyhow!(
                        "hybrid: sub-descriptor {} has {} = {}, expected {}",
                        idx,
                        label,
                        actual,
                        expected
                    ));
                }
            }
        }
        let d = match s.dim(2) {
            rlx_ir::Dim::Static(d) => d,
            _ => {
                return Err(anyhow!(
                    "hybrid: sub-descriptor {} last dim must be static",
                    idx
                ))
            }
        };
        total += d;
    }
    let out_shape = Shape::new(&[nf, nloc, total], DType::F32);
    let node = g.concat(outputs.to_vec(), 2, out_shape);
    Ok((node, total))
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Happy path: two sub-descriptors concat to `[nf, nloc, 16 + 24]`.
    #[test]
    fn hybrid_concat_dims_match() {
        let mut g = Graph::new("hybrid");
        let nf = 1;
        let nloc = 4;
        let a = g.input("a", Shape::new(&[nf, nloc, 16], DType::F32));
        let b = g.input("b", Shape::new(&[nf, nloc, 24], DType::F32));
        let (out, total) = concat_descriptor_outputs(&mut g, &[a, b], nf, nloc).unwrap();
        assert_eq!(total, 40);
        let s = g.shape(out);
        assert_eq!(s.rank(), 3);
        assert_eq!(s.dim(0), rlx_ir::Dim::Static(nf));
        assert_eq!(s.dim(1), rlx_ir::Dim::Static(nloc));
        assert_eq!(s.dim(2), rlx_ir::Dim::Static(40));
    }

    /// A single sub-descriptor is a valid (degenerate) hybrid.
    #[test]
    fn hybrid_single_input() {
        let mut g = Graph::new("hybrid_single");
        let a = g.input("a", Shape::new(&[2, 3, 8], DType::F32));
        let (_, total) = concat_descriptor_outputs(&mut g, &[a], 2, 3).unwrap();
        assert_eq!(total, 8);
    }

    /// An empty sub-descriptor list is rejected.
    #[test]
    fn hybrid_rejects_empty_input() {
        let mut g = Graph::new("hybrid_empty");
        assert!(concat_descriptor_outputs(&mut g, &[], 1, 4).is_err());
    }

    /// Inputs that are not rank 3 are rejected.
    #[test]
    fn hybrid_rejects_non_rank3_input() {
        let mut g = Graph::new("hybrid_rank");
        let a = g.input("a", Shape::new(&[4, 16], DType::F32));
        assert!(concat_descriptor_outputs(&mut g, &[a], 1, 4).is_err());
    }
}