deepmd 0.1.0

DeePMD-kit deep potential models as RLX IR graph builders
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.

//! Weight loading for serialized DeePMD-kit models.
//!
//! DeePMD-kit emits two native serialization formats from
//! `deepmd.dpmodel.utils.serialization.save_dp_model`:
//!
//! 1. **`.yaml` / `.yml`** — pure JSON-/YAML-compatible: every numpy
//!    array is encoded as
//!    `{@class: "np.ndarray", dtype: "float64", value: [...]}`.
//! 2. **`.dp` / `.hlo`** — HDF5 file with a JSON header in
//!    `f.attrs["json"]`; arrays are stored as separate HDF5 datasets
//!    named `variable_NNNN`, referenced by name from the JSON.
//!
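//! This module implements format (1) — the YAML/JSON path — since it's
//! pure-Rust and matches the lossless serialization used by upstream
//! tests.  No YAML parser is bundled, however: `.yaml` / `.yml` files
//! are only accepted when their contents are also valid JSON, and any
//! other YAML is rejected with a hint to convert it first.  Loading the
//! `.dp` HDF5 format additionally needs the `hdf5` crate and is left
//! for a follow-up.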
//!
//! After loading you get a [`WeightTable`] keyed by the *DeePMD param
//! path* (e.g.
//! `descriptor.embeddings.networks.0.layers.0.@variables.matrix`).
//! The [`map_to_graph_keys`] helper rewrites those into the param
//! names this crate's graph builders emit (e.g.
//! `descriptor.embedding.0.layer0.w`), so the result can be fed
//! straight into an RLX runtime's parameter map.

use anyhow::{anyhow, bail, Result};
use std::collections::BTreeMap;
use std::path::Path;

/// Owned f32 tensor.
///
/// `data` is a flat element buffer and `shape` gives the extent of each
/// dimension.  Tensors produced by this module's loader flatten nested
/// arrays depth-first, so `data` is in row-major order and its length
/// equals the product of `shape`.
///
/// Two tensors compare equal when both their shape and their elements
/// match (usual `f32` semantics apply, so a tensor containing `NaN`
/// never equals itself).
#[derive(Debug, Clone, PartialEq)]
pub struct WeightTensor {
    /// Flat element buffer.
    pub data: Vec<f32>,
    /// Extent of each dimension; empty when the loader decoded a bare
    /// scalar rather than a (nested) array.
    pub shape: Vec<usize>,
}

/// Flat map: deepmd param path → tensor.
///
/// Keys are the dot-joined JSON path at which each `np.ndarray` payload
/// was found while walking the parsed model (object keys and array
/// indices both become path segments).
#[derive(Debug, Default, Clone)]
pub struct WeightTable {
    /// Decoded tensors, ordered by their deepmd param path.
    pub tensors: BTreeMap<String, WeightTensor>,
    /// Top-level JSON dict bundled inside the serialized model — useful
    /// for inspecting the model config.
    ///
    /// NOTE(review): the original intent was to hold this *without* the
    /// `@variables` payloads, but `parse_model_value` in this file
    /// stores the entire parsed document here, ndarray payloads
    /// included — confirm which behaviour is wanted.
    pub config: serde_json::Value,
}

impl WeightTable {
    /// Returns the tensor stored under the deepmd param path `key`, or
    /// `None` when no such entry exists.
    pub fn get(&self, key: &str) -> Option<&WeightTensor> {
        let tensors = &self.tensors;
        tensors.get(key)
    }

    /// Stores `t` under `key`.  An entry already present under the same
    /// key is replaced and the previous tensor is dropped.
    pub fn insert(&mut self, key: impl Into<String>, t: WeightTensor) {
        let owned_key: String = key.into();
        let _replaced = self.tensors.insert(owned_key, t);
    }

    /// Iterates over all stored param paths in the map's sorted order.
    pub fn keys(&self) -> impl Iterator<Item = &str> + '_ {
        self.tensors.keys().map(String::as_str)
    }
}

/// Load a `.yaml` / `.yml` / `.json` DeePMD model file.
///
/// The file is read fully into memory, parsed according to its
/// extension and handed to [`parse_model_value`].  The extension is
/// compared ASCII case-insensitively, so `MODEL.JSON` is treated the
/// same as `model.json`.
///
/// # Errors
///
/// Fails when the file cannot be read as UTF-8 text, when the extension
/// is missing or not one of the supported ones, when the contents do
/// not parse, or when an embedded ndarray payload is malformed.
pub fn load_model_file<P: AsRef<Path>>(path: P) -> Result<WeightTable> {
    let path = path.as_ref();
    let text = std::fs::read_to_string(path)?;
    // A missing or non-UTF-8 extension becomes "" and is reported as
    // unsupported below.
    let raw_ext = path.extension().and_then(|s| s.to_str()).unwrap_or("");
    let ext = raw_ext.to_ascii_lowercase();
    let value: serde_json::Value = match ext.as_str() {
        "json" => serde_json::from_str(&text)?,
        "yaml" | "yml" => yaml_to_json(&text)?,
        // Report the extension exactly as it appeared on disk.
        _ => bail!("unsupported model file extension: {raw_ext}"),
    };
    parse_model_value(value)
}

/// Best-effort decoding of a `.yaml` / `.yml` model file.
///
/// No YAML parser is linked in; the text is simply parsed as JSON,
/// which succeeds for files that are written in JSON syntax.
///
/// # Errors
///
/// Returns an error hinting at the missing YAML support when the text
/// is not valid JSON.  The underlying JSON parser message is appended so
/// that a nearly-valid JSON document can still be diagnosed instead of
/// being silently blamed on the file format.
fn yaml_to_json(text: &str) -> Result<serde_json::Value> {
    // We don't pull in serde_yaml as a dep — most DeePMD models in
    // testing are .json or trivially convertible.  If the file looks
    // like YAML, error out with a hint.
    serde_json::from_str(text).map_err(|err| {
        anyhow!(
            "YAML decoding not available — convert to .json first or add the serde_yaml feature \
             (parsing as JSON failed: {err})"
        )
    })
}

/// Build a [`WeightTable`] from an already-parsed model document.
///
/// Every `np.ndarray` payload found while walking `value` is decoded
/// and stored under its dot-joined path; the document itself is kept in
/// [`WeightTable::config`] (in full, payloads included).
///
/// # Errors
///
/// Fails when an ndarray payload is missing its `value`, is ragged, or
/// contains a non-numeric element.
pub fn parse_model_value(value: serde_json::Value) -> Result<WeightTable> {
    let mut table = WeightTable::default();
    // Walk by reference first so the document can then be moved into
    // `config` instead of deep-cloning every weight array.
    walk_collect(&value, String::new(), &mut table)?;
    table.config = value;
    Ok(table)
}

/// Recursively walk `value`, decoding every inline `np.ndarray` payload
/// into `table.tensors`.
///
/// * `value`  — JSON node currently being visited.
/// * `prefix` — dot-joined path of `value` from the document root
///   (empty at the root).
/// * `table`  — destination for decoded tensors.
///
/// Object keys and array indices each contribute one path segment.  An
/// object whose `@class` is `"np.ndarray"` is treated as a leaf: it is
/// decoded and stored under `prefix`, and its children are not visited.
/// Scalars, strings and nulls are ignored.
///
/// # Errors
///
/// Fails when an ndarray object has no `value` field or when its
/// payload cannot be decoded.
fn walk_collect(
    value: &serde_json::Value,
    prefix: String,
    table: &mut WeightTable,
) -> Result<()> {
    use serde_json::Value as V;

    /// Append `segment` to `prefix`, without a leading dot at the root.
    fn join(prefix: &str, segment: &str) -> String {
        if prefix.is_empty() {
            segment.to_owned()
        } else {
            format!("{prefix}.{segment}")
        }
    }

    match value {
        V::Object(map) => {
            // Inline np.ndarray payload encoding: {@class: "np.ndarray", dtype, value}
            if let Some(V::String(cls)) = map.get("@class") {
                if cls == "np.ndarray" {
                    // A missing dtype is treated as float32.
                    let dtype = map
                        .get("dtype")
                        .and_then(|v| v.as_str())
                        .unwrap_or("float32");
                    let payload = map
                        .get("value")
                        .ok_or_else(|| anyhow!("np.ndarray without `value` at {prefix}"))?;
                    let t = decode_ndarray(dtype, payload)?;
                    table.tensors.insert(prefix, t);
                    return Ok(());
                }
            }
            for (k, v) in map {
                walk_collect(v, join(&prefix, k), table)?;
            }
        }
        V::Array(arr) => {
            // Use the same joining rule as for object keys so that a
            // root-level array yields `0.…` rather than `.0.…`.
            for (i, v) in arr.iter().enumerate() {
                walk_collect(v, join(&prefix, &i.to_string()), table)?;
            }
        }
        _ => {}
    }
    Ok(())
}

/// Decode the `value` field of an inline `np.ndarray` payload into an
/// f32 [`WeightTensor`].
///
/// `dtype` is accepted for interface completeness but does not affect
/// decoding: elements are always narrowed to `f32`.
///
/// # Errors
///
/// Fails when the nested array cannot be flattened or when any leaf is
/// not a JSON number.
fn decode_ndarray(dtype: &str, value: &serde_json::Value) -> Result<WeightTensor> {
    // f32 storage regardless; upcasts at execute time
    let _ = dtype;

    let (shape, leaves) = flatten_nested(value)?;
    let data = leaves
        .iter()
        .map(|leaf| match leaf.as_f64() {
            Some(float) => Ok(float as f32),
            // Fallback kept for numbers that expose no f64 view.
            None => match leaf.as_i64() {
                Some(int) => Ok(int as f32),
                None => Err(anyhow!("ndarray element is not a number: {leaf}")),
            },
        })
        .collect::<Result<Vec<f32>>>()?;

    Ok(WeightTensor { data, shape })
}

/// Flatten a (possibly nested) JSON array into its shape and a
/// depth-first list of leaf values.
///
/// A bare non-array value yields an empty shape and a single leaf.
/// Leaves are cloned as-is; they are not checked for being numbers.
///
/// # Errors
///
/// Fails when the nesting is not rectangular: sibling arrays of
/// different lengths, or arrays and scalars mixed at the same depth
/// (e.g. `[[1, 2], 3]`), which would otherwise produce a shape that
/// disagrees with the number of leaves.
fn flatten_nested(value: &serde_json::Value) -> Result<(Vec<usize>, Vec<serde_json::Value>)> {
    use serde_json::Value as V;
    fn recurse(
        v: &V,
        shape: &mut Vec<usize>,
        depth: usize,
        leaf_depth: &mut Option<usize>,
        flat: &mut Vec<V>,
    ) -> Result<()> {
        match v {
            V::Array(arr) => {
                // The first array seen at a depth fixes that dimension;
                // later ones must agree with it.
                if shape.len() == depth {
                    shape.push(arr.len());
                } else if shape[depth] != arr.len() {
                    bail!("ragged ndarray at depth {depth}");
                }
                for el in arr {
                    recurse(el, shape, depth + 1, leaf_depth, flat)?;
                }
            }
            other => {
                // Every leaf must sit at the same nesting depth.
                if *leaf_depth.get_or_insert(depth) != depth {
                    bail!("ragged ndarray at depth {depth}");
                }
                flat.push(other.clone());
            }
        }
        Ok(())
    }
    let mut shape = Vec::new();
    let mut flat = Vec::new();
    let mut leaf_depth = None;
    recurse(value, &mut shape, 0, &mut leaf_depth, &mut flat)?;
    // Leaves must live below the innermost dimension; a shallower leaf
    // means a scalar was mixed in next to arrays.
    if let Some(depth) = leaf_depth {
        if depth != shape.len() {
            bail!("ragged ndarray at depth {depth}");
        }
    }
    Ok((shape, flat))
}

/// Translation table from DeePMD's serialized param paths to the
/// param names used by `crate::descriptor::*` / `crate::fitting::*`.
///
/// Covered today are the `se_e2_a` descriptor (`type_one_side = true`)
/// and the shared-net energy fitting; supporting a new descriptor or
/// fitting variant normally means adding another translator here.
/// Keys that no translator recognises are left out of the result.
///
/// The returned map goes graph-param-name → deepmd-key, so callers can
/// `.remove` the deepmd key from `WeightTable` and bind by graph name.
pub fn map_to_graph_keys(table: &WeightTable) -> BTreeMap<String, String> {
    table
        .tensors
        .keys()
        .filter_map(|deepmd_key| {
            // Descriptor rules take precedence over fitting rules.
            translate_se_e2_a(deepmd_key)
                .or_else(|| translate_ener_fitting(deepmd_key))
                .map(|graph_name| (graph_name, deepmd_key.clone()))
        })
        .collect()
}

/// Translate one serialized `se_e2_a` descriptor key into the graph
/// param name, or return `None` when the key is not recognised.
///
/// Recognised inputs:
///
/// * `descriptor.@variables.{davg|dstd}` → `descriptor.{davg|dstd}`
/// * `descriptor.embeddings.networks.{t}.layers.{l}.@variables.{var}`
///   → `descriptor.embedding.{t}.layer{l}.{w|b|idt}`, where `var` is
///   `matrix`/`w`, `bias`/`b` or `idt`.
///
/// `{t}` and `{l}` must be plain decimal indices; anything else is
/// rejected rather than being copied into the graph name, matching the
/// index validation done for the fitting-net keys.
fn translate_se_e2_a(key: &str) -> Option<String> {
    /// True for a non-empty run of ASCII digits.
    fn is_index(segment: &str) -> bool {
        !segment.is_empty() && segment.bytes().all(|b| b.is_ascii_digit())
    }

    if let Some(rest) = key.strip_prefix("descriptor.@variables.") {
        match rest {
            "davg" => return Some("descriptor.davg".into()),
            "dstd" => return Some("descriptor.dstd".into()),
            _ => {}
        }
    }
    // descriptor.embeddings.networks.{t}.layers.{l}.@variables.{w|b|idt}
    let parts: Vec<&str> = key.split('.').collect();
    match parts.as_slice() {
        ["descriptor", "embeddings", "networks", t, "layers", l, "@variables", var]
            if is_index(t) && is_index(l) =>
        {
            let var = match *var {
                "matrix" | "w" => "w",
                "bias" | "b" => "b",
                "idt" => "idt",
                _ => return None,
            };
            Some(format!("descriptor.embedding.{t}.layer{l}.{var}"))
        }
        _ => None,
    }
}

/// Translate one serialized energy-fitting key into the graph param
/// name, or return `None` when the key is not recognised.
///
/// Recognised inputs:
///
/// * `fitting_net.@variables.{name}` for the statistics listed in
///   `STATS` below → `fitting.{name}`
/// * `fitting_net.nets.networks.0.layers.{l}.@variables.{var}` →
///   `fitting.hidden.layer{l}.{w|b|idt}`, where `var` is `matrix`/`w`,
///   `bias`/`b` or `idt` and `{l}` parses as an unsigned integer.
///
/// Every layer is reported as a hidden layer.  In the deepmd
/// serialization the last layer is the final dense one, so the caller
/// is expected to look at the total layer count and remap the highest
/// index to `final`.
fn translate_ener_fitting(key: &str) -> Option<String> {
    // Per-fitting tensors that keep their name under the `fitting.` prefix.
    const STATS: [&str; 5] = [
        "bias_atom_e",
        "fparam_avg",
        "fparam_inv_std",
        "aparam_avg",
        "aparam_inv_std",
    ];

    if let Some(name) = key.strip_prefix("fitting_net.@variables.") {
        if STATS.contains(&name) {
            return Some(format!("fitting.{name}"));
        }
    }

    // fitting_net.nets.networks.0.layers.{l}.@variables.{w|b|idt}
    let segments: Vec<&str> = key.split('.').collect();
    match segments.as_slice() {
        ["fitting_net", "nets", "networks", "0", "layers", layer, "@variables", var] => {
            let index = layer.parse::<usize>().ok()?;
            let suffix = match *var {
                "matrix" | "w" => "w",
                "bias" | "b" => "b",
                "idt" => "idt",
                _ => return None,
            };
            Some(format!("fitting.hidden.layer{index}.{suffix}"))
        }
        _ => None,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Zero-filled tensor with the given shape.
    fn zeros(shape: &[usize]) -> WeightTensor {
        WeightTensor {
            data: vec![0.0; shape.iter().product()],
            shape: shape.to_vec(),
        }
    }

    /// A nested 2×2 payload is decoded to a row-major tensor stored
    /// under its dot-joined path.
    #[test]
    fn ndarray_roundtrip() {
        let davg = serde_json::json!({
            "@class": "np.ndarray",
            "dtype": "float32",
            "value": [[1.0, 2.0], [3.0, 4.0]]
        });
        let model = serde_json::json!({ "descriptor": { "@variables": { "davg": davg } } });

        let table = parse_model_value(model).unwrap();
        let tensor = table.get("descriptor.@variables.davg").unwrap();

        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(tensor.data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// Embedding weights and descriptor statistics are both mapped to
    /// their graph param names.
    #[test]
    fn key_translation_se_e2_a() {
        const EMBEDDING: &str = "descriptor.embeddings.networks.0.layers.1.@variables.matrix";
        const DAVG: &str = "descriptor.@variables.davg";

        let mut table = WeightTable::default();
        table.insert(EMBEDDING, zeros(&[2, 2]));
        table.insert(DAVG, zeros(&[1, 2]));

        let mapping = map_to_graph_keys(&table);
        let lookup = |graph_name: &str| mapping.get(graph_name).map(String::as_str);

        assert_eq!(lookup("descriptor.embedding.0.layer1.w"), Some(EMBEDDING));
        assert_eq!(lookup("descriptor.davg"), Some(DAVG));
    }
}