use std::collections::BTreeMap;
use serde::{Deserialize, Serialize};
/// A dynamically typed metadata value.
///
/// Serde encodes it in adjacently tagged form: the variant name (converted to
/// snake_case) is stored under `kind` and the payload under `value`.
// NOTE(review): with `rename_all = "snake_case"`, `UInt` presumably becomes
// `u_int` on the wire — confirm that is the intended tag.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
pub enum MetadataValue {
/// A boolean.
Bool(bool),
/// A signed 64-bit integer.
Int(i64),
/// An unsigned 64-bit integer.
UInt(u64),
/// A 64-bit floating-point number.
Float(f64),
/// An owned UTF-8 string.
String(String),
/// A list of nested values; elements may be of different variants.
Array(Vec<MetadataValue>),
}
impl MetadataValue {
    /// Interprets this value as a non-negative integer, if it holds one exactly.
    ///
    /// Returns `Some` for:
    /// - `UInt` — always;
    /// - `Int` — when the value is non-negative;
    /// - `Float` — when the value is finite, non-negative, has no fractional
    ///   part, and is strictly below 2^64 (i.e. fits in a `u64`).
    ///
    /// Every other value yields `None`.
    pub fn as_symbol_value(&self) -> Option<u64> {
        // 2^64, which is exactly representable as an `f64`. A float-to-int
        // `as` cast saturates, so without this bound a value such as `1e30`
        // would silently turn into `u64::MAX` instead of being rejected.
        const U64_RANGE_END: f64 = 18_446_744_073_709_551_616.0;

        match self {
            MetadataValue::UInt(v) => Some(*v),
            MetadataValue::Int(v) if *v >= 0 => Some(*v as u64),
            MetadataValue::Float(f)
                if f.is_finite() && *f >= 0.0 && *f < U64_RANGE_END && f.fract() == 0.0 =>
            {
                // Lossless: the guard proved `f` is an integer inside the u64 range.
                Some(*f as u64)
            }
            // Listed explicitly (no `_`) so that adding a variant forces a
            // decision here at compile time.
            MetadataValue::Int(_)
            | MetadataValue::Float(_)
            | MetadataValue::Bool(_)
            | MetadataValue::String(_)
            | MetadataValue::Array(_) => None,
        }
    }

    /// Renders a compact, single-value summary for display.
    ///
    /// - Scalars use their `Display` form.
    /// - Strings are wrapped in double quotes. A string longer than 80
    ///   characters is cut to its first 77 characters, followed by `…` and the
    ///   total character count, e.g. `"abc…" (120 chars)`.
    /// - Arrays are summarised by their element count only.
    ///
    /// Lengths are measured in Unicode scalar values (`char`s), not bytes, so
    /// truncation never splits a multi-byte character and never panics.
    pub fn display_short(&self) -> String {
        /// Longest string (in chars) that is shown in full.
        const MAX_CHARS: usize = 80;
        /// Number of leading chars kept when a string is truncated.
        const KEPT_CHARS: usize = 77;

        match self {
            MetadataValue::Bool(b) => b.to_string(),
            MetadataValue::Int(i) => i.to_string(),
            MetadataValue::UInt(u) => u.to_string(),
            MetadataValue::Float(f) => f.to_string(),
            MetadataValue::String(s) => {
                // A char at index MAX_CHARS exists only when the string has
                // more than MAX_CHARS chars; this walks at most 81 chars.
                if s.chars().nth(MAX_CHARS).is_none() {
                    return format!("\"{s}\"");
                }
                // Byte offset where the (KEPT_CHARS + 1)-th char starts; always
                // a char boundary, so the slice below cannot panic.
                let cut = s
                    .char_indices()
                    .nth(KEPT_CHARS)
                    .map_or(s.len(), |(idx, _)| idx);
                format!("\"{}…\" ({} chars)", &s[..cut], s.chars().count())
            }
            MetadataValue::Array(v) => format!("[{} elements]", v.len()),
        }
    }
}
/// A set of metadata values keyed by string name.
///
/// Backed by a `BTreeMap`, so iteration visits keys in sorted order.
pub type MetadataBundle = BTreeMap<String, MetadataValue>;
/// The data type recorded for a tensor (see the `dtype` field of `TensorRecord`).
///
/// Serde encodes it in adjacently tagged form: the variant name (converted to
/// snake_case) is stored under `kind`, and the string payload of
/// `Quantized`/`Other` under `value`.
// NOTE(review): with `rename_all = "snake_case"`, `BF16` presumably becomes
// `b_f16` on the wire — confirm that is the intended tag.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
pub enum TensorDtype {
F64,
F32,
F16,
BF16,
I64,
I32,
I16,
I8,
U8,
Bool,
/// A quantized type identified by the carried string; `label` upper-cases it
/// (e.g. `"q4_k"` is shown as `"Q4_K"`).
Quantized(String),
/// Any other type, identified by the carried string; `label` returns it verbatim.
Other(String),
}
impl TensorDtype {
pub fn label(&self) -> String {
match self {
TensorDtype::F64 => "F64".into(),
TensorDtype::F32 => "F32".into(),
TensorDtype::F16 => "F16".into(),
TensorDtype::BF16 => "BF16".into(),
TensorDtype::I64 => "I64".into(),
TensorDtype::I32 => "I32".into(),
TensorDtype::I16 => "I16".into(),
TensorDtype::I8 => "I8".into(),
TensorDtype::U8 => "U8".into(),
TensorDtype::Bool => "Bool".into(),
TensorDtype::Quantized(q) => q.to_uppercase(),
TensorDtype::Other(s) => s.clone(),
}
}
}
/// Descriptive record for a single tensor: its name, shape and dtype.
///
/// Holds no tensor data itself.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TensorRecord {
/// The tensor's name.
pub name: String,
/// The tensor's shape, one entry per dimension.
pub shape: Vec<usize>,
/// The tensor's data type.
pub dtype: TensorDtype,
}
/// An ordered list of tensor records.
pub type TensorInventory = Vec<TensorRecord>;
#[cfg(test)]
mod tests {
    use super::*;

    /// `as_symbol_value` accepts only values that are exact non-negative integers.
    #[test]
    fn metadata_value_as_symbol_value() {
        assert_eq!(MetadataValue::UInt(4096).as_symbol_value(), Some(4096));
        assert_eq!(MetadataValue::Int(32).as_symbol_value(), Some(32));
        assert_eq!(MetadataValue::Int(-1).as_symbol_value(), None);
        assert_eq!(MetadataValue::Float(8.0).as_symbol_value(), Some(8));
        assert_eq!(MetadataValue::Float(8.5).as_symbol_value(), None);
        assert_eq!(MetadataValue::String("x".into()).as_symbol_value(), None);
        assert_eq!(MetadataValue::Bool(true).as_symbol_value(), None);
    }

    /// `display_short` prints scalars plainly, quotes strings, abbreviates
    /// long strings, and summarises arrays by length.
    #[test]
    fn metadata_value_display_short() {
        assert_eq!(MetadataValue::Bool(false).display_short(), "false");
        assert_eq!(MetadataValue::Int(-3).display_short(), "-3");
        assert_eq!(MetadataValue::UInt(42).display_short(), "42");
        assert_eq!(MetadataValue::String("hi".into()).display_short(), "\"hi\"");
        assert_eq!(
            MetadataValue::Array(vec![MetadataValue::UInt(1), MetadataValue::UInt(2)])
                .display_short(),
            "[2 elements]"
        );

        let long: String = "x".repeat(120);
        let rendered = MetadataValue::String(long).display_short();
        assert!(rendered.starts_with('"'), "got: {rendered}");
        assert!(rendered.ends_with("(120 chars)"), "got: {rendered}");
    }

    /// `label` covers a built-in, a quantized and a free-form dtype.
    #[test]
    fn tensor_dtype_label_covers_known_formats() {
        assert_eq!(TensorDtype::F16.label(), "F16");
        assert_eq!(TensorDtype::Quantized("q4_k".into()).label(), "Q4_K");
        assert_eq!(TensorDtype::Other("custom".into()).label(), "custom");
    }

    /// Every `MetadataValue` variant survives a JSON round trip unchanged.
    #[test]
    fn metadata_value_round_trips_through_json() {
        let cases = [
            MetadataValue::Bool(true),
            MetadataValue::Int(-7),
            MetadataValue::UInt(42),
            // Deliberately not `3.14`: that literal trips clippy's
            // `approx_constant` lint. `2.5` is also exactly representable in
            // binary, so the equality check below is not precision-sensitive.
            MetadataValue::Float(2.5),
            MetadataValue::String("hello".into()),
            MetadataValue::Array(vec![
                MetadataValue::UInt(1),
                MetadataValue::UInt(2),
            ]),
        ];
        for mv in cases {
            let json = serde_json::to_string(&mv).unwrap();
            let back: MetadataValue = serde_json::from_str(&json).unwrap();
            assert_eq!(mv, back);
        }
    }

    /// A `TensorRecord` survives a JSON round trip unchanged.
    #[test]
    fn tensor_record_round_trips_through_json() {
        let rec = TensorRecord {
            name: "blk.0.attn_q.weight".into(),
            shape: vec![4096, 4096],
            dtype: TensorDtype::F16,
        };
        let json = serde_json::to_string(&rec).unwrap();
        let back: TensorRecord = serde_json::from_str(&json).unwrap();
        assert_eq!(rec, back);
    }
}