use anyhow::{anyhow, bail, Result};
use std::collections::BTreeMap;
use std::path::Path;
/// A dense tensor of `f32` values together with its shape.
///
/// When produced by the loader in this file, `data` holds the elements of the
/// nested JSON array flattened depth-first (outermost index varying slowest)
/// and `shape` holds the array length found at each nesting level.
#[derive(Clone, Debug)]
pub struct WeightTensor {
    /// Flattened element values.
    pub data: Vec<f32>,
    /// Extent of each dimension, outermost first; empty for a scalar.
    pub shape: Vec<usize>,
}
/// Named collection of tensors plus the JSON document they came from.
///
/// A default-constructed table has no tensors and a JSON `null` config.
#[derive(Clone, Debug, Default)]
pub struct WeightTable {
    /// Tensors keyed by their dotted path; iteration is in sorted key order.
    pub tensors: BTreeMap<String, WeightTensor>,
    /// The JSON value the table was built from (filled in by
    /// `parse_model_value`).
    pub config: serde_json::Value,
}
impl WeightTable {
    /// Looks up the tensor stored under `key`, if any.
    pub fn get(&self, key: &str) -> Option<&WeightTensor> {
        let Self { tensors, .. } = self;
        tensors.get(key)
    }

    /// Stores `t` under `key`, replacing (and dropping) any tensor that was
    /// already registered under the same key.
    pub fn insert(&mut self, key: impl Into<String>, t: WeightTensor) {
        let owned_key: String = key.into();
        self.tensors.insert(owned_key, t);
    }

    /// Iterates over all tensor keys in sorted order, borrowed as `&str`.
    pub fn keys(&self) -> impl Iterator<Item = &str> + '_ {
        self.tensors.keys().map(String::as_str)
    }
}
pub fn load_model_file<P: AsRef<Path>>(path: P) -> Result<WeightTable> {
let text = std::fs::read_to_string(path.as_ref())?;
let ext = path
.as_ref()
.extension()
.and_then(|s| s.to_str())
.unwrap_or("");
let value: serde_json::Value = match ext {
"json" => serde_json::from_str(&text)?,
"yaml" | "yml" => yaml_to_json(&text)?,
other => bail!("unsupported model file extension: {other}"),
};
parse_model_value(value)
}
/// Placeholder YAML decoder.
///
/// No YAML parser is wired in: the text is simply tried as JSON, which only
/// succeeds for YAML documents that happen to be valid JSON as well. Any
/// parse failure is replaced by a fixed message telling the user how to
/// proceed; the underlying JSON error is intentionally not surfaced.
fn yaml_to_json(text: &str) -> Result<serde_json::Value> {
    match serde_json::from_str::<serde_json::Value>(text) {
        Ok(parsed) => Ok(parsed),
        Err(_) => Err(anyhow!(
            "YAML decoding not available — convert to .json first or add the serde_yaml feature"
        )),
    }
}
/// Builds a [`WeightTable`] from an already-decoded model document.
///
/// Every tensor found by `walk_collect` is stored under its dotted path, and
/// the document itself is kept in `WeightTable::config`.
///
/// # Errors
///
/// Propagates any error reported by `walk_collect` while collecting tensors.
pub fn parse_model_value(value: serde_json::Value) -> Result<WeightTable> {
    let mut table = WeightTable::default();
    // Walk a borrow first, then move the document into `config`: this avoids
    // deep-cloning the whole JSON tree (which contains every weight array).
    walk_collect(&value, String::new(), &mut table)?;
    table.config = value;
    Ok(table)
}
/// Recursively walks `value` and registers every serialized ndarray in
/// `table.tensors`.
///
/// * `value`  – the JSON node currently being visited.
/// * `prefix` – dotted path of `value` from the document root (empty at the
///   root).
/// * `table`  – destination for the decoded tensors.
///
/// An object whose `"@class"` entry is the string `"np.ndarray"` is treated
/// as a tensor: its `"value"` entry is decoded with `decode_ndarray` (the
/// `"dtype"` entry defaults to `"float32"` when absent) and stored under
/// `prefix`; the object is not descended into any further. Other objects and
/// arrays are traversed, extending the path with the member name or the
/// array index. Scalars are ignored.
///
/// # Errors
///
/// Fails if an ndarray object has no `"value"` entry or if `decode_ndarray`
/// rejects the payload.
fn walk_collect(
    value: &serde_json::Value,
    prefix: String,
    table: &mut WeightTable,
) -> Result<()> {
    use serde_json::Value as V;

    /// Appends `segment` to the dotted path `prefix`, without emitting a
    /// leading dot when `prefix` is empty.
    fn child_key(prefix: &str, segment: impl std::fmt::Display) -> String {
        if prefix.is_empty() {
            segment.to_string()
        } else {
            format!("{prefix}.{segment}")
        }
    }

    match value {
        V::Object(map) => {
            let is_ndarray =
                matches!(map.get("@class"), Some(V::String(cls)) if cls == "np.ndarray");
            if is_ndarray {
                let dtype = map
                    .get("dtype")
                    .and_then(|v| v.as_str())
                    .unwrap_or("float32");
                let payload = map
                    .get("value")
                    .ok_or_else(|| anyhow!("np.ndarray without `value` at {prefix}"))?;
                let t = decode_ndarray(dtype, payload)?;
                table.tensors.insert(prefix, t);
                return Ok(());
            }
            for (k, v) in map {
                walk_collect(v, child_key(&prefix, k), table)?;
            }
        }
        V::Array(arr) => {
            // Array elements use the same path-joining rule as object
            // members, so a top-level array yields "0.foo" rather than
            // ".0.foo".
            for (i, v) in arr.iter().enumerate() {
                walk_collect(v, child_key(&prefix, i), table)?;
            }
        }
        _ => {}
    }
    Ok(())
}
/// Decodes a (possibly nested) JSON array into a [`WeightTensor`].
///
/// `value` is flattened with `flatten_nested`, and every leaf is converted to
/// `f32`. `dtype` is currently accepted but not consulted: all numeric leaves
/// are narrowed to `f32` regardless of the declared type.
///
/// # Errors
///
/// Propagates errors from `flatten_nested` and fails on the first leaf that
/// is not a JSON number.
fn decode_ndarray(dtype: &str, value: &serde_json::Value) -> Result<WeightTensor> {
    // Kept only so the parameter is not reported as unused.
    let _ = dtype;

    let (shape, leaves) = flatten_nested(value)?;
    let data = leaves
        .iter()
        .map(|v| match v.as_f64() {
            Some(f) => Ok(f as f32),
            None => match v.as_i64() {
                Some(i) => Ok(i as f32),
                None => Err(anyhow!("ndarray element is not a number: {v}")),
            },
        })
        .collect::<Result<Vec<f32>>>()?;

    Ok(WeightTensor { data, shape })
}
/// Flattens a rectangular nested JSON array into `(shape, leaves)`.
///
/// `shape` lists the array length at each nesting level, outermost first; a
/// non-array `value` yields an empty shape and a single leaf. `leaves` holds
/// clones of the non-array elements in depth-first order.
///
/// The shape is taken from the left-most path through the nesting and every
/// other node is then validated against it, so all of the following are
/// rejected:
///
/// * sibling arrays of different lengths (`[[1, 2], [3]]`);
/// * a scalar where a nested array is expected (`[[1, 2], 3]`);
/// * a nested array where a scalar is expected (`[1, [2]]`).
///
/// # Errors
///
/// Returns a "ragged ndarray" error naming the nesting depth at which the
/// structure deviates from the expected shape.
fn flatten_nested(value: &serde_json::Value) -> Result<(Vec<usize>, Vec<serde_json::Value>)> {
    use serde_json::Value as V;

    /// Validates `v` against `shape` starting at `depth` and appends its
    /// leaves to `flat`.
    fn collect(v: &V, shape: &[usize], depth: usize, flat: &mut Vec<V>) -> Result<()> {
        match (v, shape.get(depth)) {
            (V::Array(arr), Some(&expected)) => {
                if arr.len() != expected {
                    bail!("ragged ndarray at depth {depth}");
                }
                for el in arr {
                    collect(el, shape, depth + 1, flat)?;
                }
            }
            (V::Array(_), None) => {
                bail!("ragged ndarray at depth {depth}: unexpected nested array")
            }
            (_, Some(_)) => {
                bail!("ragged ndarray at depth {depth}: expected a nested array")
            }
            (leaf, None) => flat.push(leaf.clone()),
        }
        Ok(())
    }

    // Establish the expected shape by following the first element at every
    // level until a non-array (or an empty array) is reached.
    let mut shape = Vec::new();
    let mut cursor = value;
    while let V::Array(arr) = cursor {
        shape.push(arr.len());
        match arr.first() {
            Some(first) => cursor = first,
            None => break,
        }
    }

    // Deliberately not pre-sized from the shape product: the shape is not
    // trustworthy until validation has finished.
    let mut flat = Vec::new();
    collect(value, &shape, 0, &mut flat)?;
    Ok((shape, flat))
}
/// Builds a lookup from graph-side tensor names to the keys used in `table`.
///
/// Each key of `table.tensors` is offered first to `translate_se_e2_a` and,
/// if that declines, to `translate_ener_fitting`. Keys that neither
/// translator recognises are left out. Keys are visited in sorted order and,
/// should two of them translate to the same graph name, the later one wins.
pub fn map_to_graph_keys(table: &WeightTable) -> BTreeMap<String, String> {
    table
        .tensors
        .keys()
        .fold(BTreeMap::new(), |mut mapping, source_key| {
            let translated = translate_se_e2_a(source_key)
                .or_else(|| translate_ener_fitting(source_key));
            if let Some(graph_key) = translated {
                mapping.insert(graph_key, source_key.clone());
            }
            mapping
        })
}
/// Translates a descriptor key into its graph-side name.
///
/// Recognised inputs:
///
/// * `descriptor.@variables.davg` / `.dstd` → `descriptor.davg` / `.dstd`;
/// * `descriptor.embeddings.networks.<t>.layers.<l>.@variables.<var>` →
///   `descriptor.embedding.<t>.layer<l>.<var>`, where `<var>` is normalised
///   (`matrix`/`w` → `w`, `bias`/`b` → `b`, `idt` → `idt`).
///
/// The `<t>` and `<l>` segments are copied through verbatim; they are not
/// checked to be numeric. Returns `None` for anything else.
fn translate_se_e2_a(key: &str) -> Option<String> {
    // Statistics stored directly under the descriptor's variables.
    match key.strip_prefix("descriptor.@variables.") {
        Some("davg") => return Some("descriptor.davg".into()),
        Some("dstd") => return Some("descriptor.dstd".into()),
        _ => {}
    }

    // Embedding-network layer parameters: exactly eight dotted segments.
    let segments: Vec<&str> = key.split('.').collect();
    match segments.as_slice() {
        ["descriptor", "embeddings", "networks", t, "layers", l, "@variables", raw] => {
            let var = match *raw {
                "matrix" | "w" => "w",
                "bias" | "b" => "b",
                "idt" => "idt",
                _ => return None,
            };
            Some(format!("descriptor.embedding.{t}.layer{l}.{var}"))
        }
        _ => None,
    }
}
/// Translates a fitting-net key into its graph-side name.
///
/// Recognised inputs:
///
/// * `fitting_net.@variables.<name>` for `bias_atom_e`, `fparam_avg`,
///   `fparam_inv_std`, `aparam_avg`, `aparam_inv_std` → `fitting.<name>`;
/// * `fitting_net.nets.networks.0.layers.<l>.@variables.<var>` →
///   `fitting.hidden.layer<l>.<var>`, where `<l>` must parse as `usize`
///   (it is re-printed in canonical form) and `<var>` is normalised
///   (`matrix`/`w` → `w`, `bias`/`b` → `b`, `idt` → `idt`).
///
/// Only network index `0` is handled. Returns `None` for anything else.
fn translate_ener_fitting(key: &str) -> Option<String> {
    // Per-model statistics stored directly under the fitting net's variables.
    let direct = match key.strip_prefix("fitting_net.@variables.") {
        Some("bias_atom_e") => Some("fitting.bias_atom_e"),
        Some("fparam_avg") => Some("fitting.fparam_avg"),
        Some("fparam_inv_std") => Some("fitting.fparam_inv_std"),
        Some("aparam_avg") => Some("fitting.aparam_avg"),
        Some("aparam_inv_std") => Some("fitting.aparam_inv_std"),
        _ => None,
    };
    if let Some(name) = direct {
        return Some(name.to_owned());
    }

    // Hidden-layer parameters: exactly eight dotted segments.
    let segments: Vec<&str> = key.split('.').collect();
    match segments.as_slice() {
        ["fitting_net", "nets", "networks", "0", "layers", layer, "@variables", raw] => {
            let l_idx: usize = layer.parse().ok()?;
            let var = match *raw {
                "matrix" | "w" => "w",
                "bias" | "b" => "b",
                "idt" => "idt",
                _ => return None,
            };
            Some(format!("fitting.hidden.layer{l_idx}.{var}"))
        }
        _ => None,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 2x2 ndarray nested inside the descriptor variables is found, keyed
    /// by its dotted path, and flattened in row order.
    #[test]
    fn ndarray_roundtrip() {
        let davg = serde_json::json!({
            "@class": "np.ndarray",
            "dtype": "float32",
            "value": [[1.0, 2.0], [3.0, 4.0]]
        });
        let model = serde_json::json!({
            "descriptor": { "@variables": { "davg": davg } }
        });

        let table = parse_model_value(model).unwrap();
        let tensor = table.get("descriptor.@variables.davg").unwrap();

        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(tensor.data, vec![1.0, 2.0, 3.0, 4.0]);
    }

    /// Embedding-layer and statistics keys are mapped to their graph names.
    #[test]
    fn key_translation_se_e2_a() {
        const EMBEDDING_KEY: &str = "descriptor.embeddings.networks.0.layers.1.@variables.matrix";
        const DAVG_KEY: &str = "descriptor.@variables.davg";

        let mut table = WeightTable::default();
        table.insert(
            EMBEDDING_KEY,
            WeightTensor { data: vec![0.0; 4], shape: vec![2, 2] },
        );
        table.insert(
            DAVG_KEY,
            WeightTensor { data: vec![0.0; 2], shape: vec![1, 2] },
        );

        let mapping = map_to_graph_keys(&table);

        assert_eq!(
            mapping.get("descriptor.embedding.0.layer1.w").map(String::as_str),
            Some(EMBEDDING_KEY)
        );
        assert_eq!(
            mapping.get("descriptor.davg").map(String::as_str),
            Some(DAVG_KEY)
        );
    }
}