use std::fs;
use std::path::{Path, PathBuf};
use crate::diagnostics::types::{
MetadataBundle, MetadataValue, TensorDtype, TensorInventory, TensorRecord,
};
use crate::diagnostics::DiagnosticError;
/// Common read-only view over a model container.
///
/// Implemented in this file by [`GgufSource`] and [`SafetensorsSource`]. Every
/// method returns an owned value, so callers never borrow from the source.
pub trait ModelSource {
/// Architecture name declared by the model itself, or `None` when the
/// source does not declare one.
fn declared_architecture(&self) -> Option<String>;
/// Key/value metadata exposed by the source.
fn metadata(&self) -> MetadataBundle;
/// Inventory of the tensors in the source (name, shape and dtype per tensor).
fn tensors(&self) -> TensorInventory;
/// Human-readable label for the container format, e.g. `"SafeTensors"`.
fn format_label(&self) -> String;
}
/// [`ModelSource`] backed by a GGUF file.
pub struct GgufSource {
// Handle obtained from `crate::gguf::GgufFile::open`; the trait methods read
// its `data` field (metadata, tensor list and header).
file: crate::gguf::GgufFile,
}
impl GgufSource {
    /// Opens the GGUF file located at `path`.
    ///
    /// # Errors
    ///
    /// Returns [`DiagnosticError::ReadModel`] containing the path and the
    /// underlying error text when `crate::gguf::GgufFile::open` fails.
    pub fn open(path: impl AsRef<Path>) -> Result<Self, DiagnosticError> {
        let location = path.as_ref();
        match crate::gguf::GgufFile::open(location) {
            Ok(file) => Ok(Self { file }),
            Err(e) => Err(DiagnosticError::ReadModel(format!(
                "open {}: {e}",
                location.display()
            ))),
        }
    }
}
impl ModelSource for GgufSource {
    /// Returns the `general.architecture` metadata string, if the file has one.
    fn declared_architecture(&self) -> Option<String> {
        let data = &self.file.data;
        data.get_string("general.architecture")
            .map(|arch| arch.to_string())
    }

    /// Copies every GGUF metadata entry into a fresh [`MetadataBundle`],
    /// converting each value with `from_gguf_value`.
    fn metadata(&self) -> MetadataBundle {
        let mut bundle = MetadataBundle::new();
        (&self.file.data.metadata)
            .into_iter()
            .for_each(|(key, value)| {
                bundle.insert(key.clone(), from_gguf_value(value));
            });
        bundle
    }

    /// Builds one [`TensorRecord`] per tensor, in the order the tensors are
    /// stored in `self.file.data.tensors`.
    fn tensors(&self) -> TensorInventory {
        let infos = &self.file.data.tensors;
        infos
            .iter()
            .map(|tensor| {
                // NOTE(review): `as usize` silently truncates if the dimension
                // type is wider than `usize` on the target — confirm the
                // element type of `dims`.
                let shape = tensor.dims.iter().map(|&dim| dim as usize).collect();
                TensorRecord {
                    name: tensor.name.clone(),
                    shape,
                    dtype: from_ggml_type(tensor.dtype),
                }
            })
            .collect()
    }

    /// Returns `"GGUF v<N>"`, where `<N>` is the version read from the header.
    fn format_label(&self) -> String {
        let version = &self.file.data.header.version;
        format!("GGUF v{}", version)
    }
}
/// Converts a GGUF metadata value into the diagnostics [`MetadataValue`].
///
/// Unsigned integers of every width become `UInt`, signed integers become
/// `Int`, and both float widths become `Float`. Arrays are converted element
/// by element, recursively.
fn from_gguf_value(v: &crate::gguf::MetadataValue) -> MetadataValue {
    use crate::gguf::MetadataValue as Src;
    match v {
        // Unsigned integers, widened losslessly to u64.
        Src::Uint8(n) => MetadataValue::UInt(u64::from(*n)),
        Src::Uint16(n) => MetadataValue::UInt(u64::from(*n)),
        Src::Uint32(n) => MetadataValue::UInt(u64::from(*n)),
        Src::Uint64(n) => MetadataValue::UInt(*n),
        // Signed integers, widened losslessly to i64.
        Src::Int8(n) => MetadataValue::Int(i64::from(*n)),
        Src::Int16(n) => MetadataValue::Int(i64::from(*n)),
        Src::Int32(n) => MetadataValue::Int(i64::from(*n)),
        Src::Int64(n) => MetadataValue::Int(*n),
        // Floating point, widened to f64.
        Src::Float32(f) => MetadataValue::Float(f64::from(*f)),
        Src::Float64(f) => MetadataValue::Float(*f),
        Src::Bool(flag) => MetadataValue::Bool(*flag),
        Src::String(text) => MetadataValue::String(text.clone()),
        Src::Array(array) => {
            let converted = array
                .values
                .iter()
                .map(|item| from_gguf_value(item))
                .collect();
            MetadataValue::Array(converted)
        }
    }
}
/// Maps a GGML tensor type onto the diagnostics [`TensorDtype`].
///
/// `F32` and `F16` map to their direct counterparts. Every other type is
/// reported as `Quantized`, labelled with the lowercased `Debug` name of the
/// variant.
///
/// NOTE(review): if `GgmlType` has non-quantized variants beyond F32/F16
/// (e.g. BF16 or plain integers), they are also labelled `Quantized` —
/// confirm that this is intended.
fn from_ggml_type(t: crate::gguf::GgmlType) -> TensorDtype {
    use crate::gguf::GgmlType as Ggml;
    match t {
        Ggml::F32 => TensorDtype::F32,
        Ggml::F16 => TensorDtype::F16,
        other => {
            let label = format!("{other:?}").to_lowercase();
            TensorDtype::Quantized(label)
        }
    }
}
/// [`ModelSource`] backed by a SafeTensors file or a directory of shards.
///
/// All fields are computed once in `SafetensorsSource::open`; the trait
/// methods only hand out clones of them.
pub struct SafetensorsSource {
// Path that was passed to `open` (either the directory or the single file).
root: PathBuf,
// `model_type` string read from a neighbouring `config.json`, if any.
declared_arch: Option<String>,
// Flattened contents of `config.json` (empty when missing or unparsable).
metadata: MetadataBundle,
// Tensor inventory enumerated at open time.
tensors: TensorInventory,
// Format label returned by `format_label`.
label: String,
}
impl SafetensorsSource {
    /// Opens a SafeTensors model from `path`.
    ///
    /// * If `path` is a directory it is opened as a sharded model. The label
    ///   reports the number of `*.safetensors` entries found by scanning the
    ///   directory, and `config.json` is read from the directory itself.
    /// * Otherwise `path` is opened as a single file, and `config.json` is
    ///   looked up in the file's parent directory when `Path::parent` yields
    ///   one.
    ///
    /// A missing or malformed `config.json` is not an error; it only results
    /// in no declared architecture and empty metadata.
    ///
    /// # Errors
    ///
    /// Returns [`DiagnosticError::ReadModel`] when the underlying SafeTensors
    /// reader fails to open the directory or file.
    pub fn open(path: impl AsRef<Path>) -> Result<Self, DiagnosticError> {
        let path = path.as_ref();

        let (declared_arch, metadata, tensors, label) = if path.is_dir() {
            let sharded = match crate::safetensors::ShardedSafeTensors::open(path) {
                Ok(opened) => opened,
                Err(e) => {
                    return Err(DiagnosticError::ReadModel(format!(
                        "open dir {}: {e}",
                        path.display()
                    )));
                }
            };
            let shard_count = count_shards(path);
            let (arch, config) = load_config_json(path);
            let inventory = enumerate_sharded(&sharded);
            (
                arch,
                config,
                inventory,
                format!("SafeTensors ({} shards)", shard_count),
            )
        } else {
            let single = match crate::safetensors::SafeTensorsFile::open(path) {
                Ok(opened) => opened,
                Err(e) => {
                    return Err(DiagnosticError::ReadModel(format!(
                        "open {}: {e}",
                        path.display()
                    )));
                }
            };
            let (arch, config) = match path.parent() {
                Some(dir) => load_config_json(dir),
                None => Default::default(),
            };
            let inventory = enumerate_single(&single);
            (arch, config, inventory, String::from("SafeTensors"))
        };

        Ok(Self {
            root: path.to_path_buf(),
            declared_arch,
            metadata,
            tensors,
            label,
        })
    }

    /// Path this source was opened from (the directory or the single file).
    pub fn root(&self) -> &Path {
        self.root.as_path()
    }
}
impl ModelSource for SafetensorsSource {
    /// Returns a copy of the `model_type` captured when the source was opened.
    fn declared_architecture(&self) -> Option<String> {
        self.declared_arch.as_deref().map(str::to_owned)
    }

    /// Returns a copy of the metadata captured when the source was opened.
    fn metadata(&self) -> MetadataBundle {
        let snapshot = &self.metadata;
        snapshot.clone()
    }

    /// Returns a copy of the tensor inventory captured when the source was
    /// opened.
    fn tensors(&self) -> TensorInventory {
        let snapshot = &self.tensors;
        snapshot.clone()
    }

    /// Returns the format label chosen when the source was opened.
    fn format_label(&self) -> String {
        self.label.as_str().to_owned()
    }
}
/// Counts the entries directly inside `dir` whose extension is exactly
/// `safetensors`.
///
/// Best-effort: returns `0` when the directory cannot be read, and entries
/// that fail to be read are skipped.
fn count_shards(dir: &Path) -> usize {
    let Ok(entries) = fs::read_dir(dir) else {
        return 0;
    };

    let mut shards = 0;
    // `flatten` skips entries that yielded an I/O error.
    for entry in entries.flatten() {
        let is_shard = entry
            .path()
            .extension()
            .map_or(false, |ext| ext == "safetensors");
        if is_shard {
            shards += 1;
        }
    }
    shards
}
/// Builds the tensor inventory for a single SafeTensors file.
///
/// Tensor names are collected and sorted before lookup, so the inventory
/// order is deterministic and uses the same ordering as `enumerate_sharded`,
/// rather than following whatever order `tensor_names()` yields. Names for
/// which `tensor_info` returns `None` are skipped.
fn enumerate_single(file: &crate::safetensors::SafeTensorsFile) -> TensorInventory {
    let mut names: Vec<String> = file.tensor_names().map(|n| n.to_string()).collect();
    names.sort();
    names
        .into_iter()
        .filter_map(|name| {
            let info = file.tensor_info(&name)?;
            // Read everything needed from `info` first so the owned name can
            // be moved into the record without a clone.
            let shape = info.shape.clone();
            let dtype = from_st_dtype(info.dtype);
            Some(TensorRecord { name, shape, dtype })
        })
        .collect()
}
/// Builds the tensor inventory for a sharded SafeTensors model.
///
/// Tensor names are sorted before lookup, so the inventory is ordered by
/// name. Names for which `tensor_info` returns `None` are skipped.
fn enumerate_sharded(sharded: &crate::safetensors::ShardedSafeTensors) -> TensorInventory {
    let mut sorted_names = sharded.tensor_names();
    sorted_names.sort();

    sorted_names
        .into_iter()
        .filter_map(|tensor_name| {
            let info = sharded.tensor_info(&tensor_name)?;
            let shape = info.shape.clone();
            let dtype = from_st_dtype(info.dtype);
            Some(TensorRecord {
                name: tensor_name,
                shape,
                dtype,
            })
        })
        .collect()
}
/// Maps a SafeTensors dtype onto the diagnostics [`TensorDtype`] variant of
/// the same name.
///
/// The match is exhaustive with no catch-all arm, so adding a variant to
/// `SafeTensorsDtype` produces a compile error here.
fn from_st_dtype(d: crate::safetensors::SafeTensorsDtype) -> TensorDtype {
    use crate::safetensors::SafeTensorsDtype as Src;
    match d {
        Src::Bool => TensorDtype::Bool,
        Src::U8 => TensorDtype::U8,
        Src::I8 => TensorDtype::I8,
        Src::I16 => TensorDtype::I16,
        Src::I32 => TensorDtype::I32,
        Src::I64 => TensorDtype::I64,
        Src::BF16 => TensorDtype::BF16,
        Src::F16 => TensorDtype::F16,
        Src::F32 => TensorDtype::F32,
        Src::F64 => TensorDtype::F64,
    }
}
/// Reads `<dir>/config.json` and returns its `model_type` together with the
/// flattened contents.
///
/// Best-effort by design: when the file cannot be read or is not valid JSON
/// the result is `(None, <empty bundle>)` rather than an error. The first
/// tuple element is the top-level `model_type` value when it is a string.
fn load_config_json(dir: &Path) -> (Option<String>, MetadataBundle) {
    let config_path = dir.join("config.json");
    let parsed = fs::read_to_string(&config_path)
        .ok()
        .and_then(|text| serde_json::from_str::<serde_json::Value>(&text).ok());
    let root = match parsed {
        Some(root) => root,
        None => return (None, MetadataBundle::new()),
    };

    let declared_arch = root
        .get("model_type")
        .and_then(serde_json::Value::as_str)
        .map(str::to_owned);

    let mut bundle = MetadataBundle::new();
    flatten_json(&root, String::new(), &mut bundle);
    (declared_arch, bundle)
}
fn flatten_json(value: &serde_json::Value, prefix: String, out: &mut MetadataBundle) {
match value {
serde_json::Value::Object(map) => {
for (k, v) in map {
let key = if prefix.is_empty() {
k.clone()
} else {
format!("{prefix}.{k}")
};
flatten_json(v, key, out);
}
}
serde_json::Value::Array(arr) => {
let values = arr.iter().filter_map(convert_scalar_json).collect::<Vec<_>>();
if !values.is_empty() {
out.insert(prefix, MetadataValue::Array(values));
}
}
other => {
if let Some(scalar) = convert_scalar_json(other) {
out.insert(prefix, scalar);
}
}
}
}
fn convert_scalar_json(value: &serde_json::Value) -> Option<MetadataValue> {
match value {
serde_json::Value::Null => None,
serde_json::Value::Bool(b) => Some(MetadataValue::Bool(*b)),
serde_json::Value::Number(n) => {
if let Some(u) = n.as_u64() {
Some(MetadataValue::UInt(u))
} else if let Some(i) = n.as_i64() {
Some(MetadataValue::Int(i))
} else {
n.as_f64().map(MetadataValue::Float)
}
}
serde_json::Value::String(s) => Some(MetadataValue::String(s.clone())),
serde_json::Value::Array(_) | serde_json::Value::Object(_) => None,
}
}
// Unit tests for the `config.json` helpers (`flatten_json` and
// `load_config_json`). The filesystem tests use a `tempfile` temporary
// directory.
#[cfg(test)]
mod tests {
use super::*;
// Nested objects become dotted keys; arrays of scalars are kept as arrays.
#[test]
fn flattens_nested_json_into_dotted_keys() {
let value: serde_json::Value = serde_json::from_str(
r#"{
"model_type": "qwen35moe",
"qwen35moe": {
"embedding_length": 4096,
"attention": { "head_count": 32 }
},
"tokenizer": { "tokens": ["a", "b"] }
}"#,
)
.unwrap();
let mut bundle = MetadataBundle::new();
flatten_json(&value, String::new(), &mut bundle);
assert_eq!(
bundle.get("model_type"),
Some(&MetadataValue::String("qwen35moe".into()))
);
// Non-negative integers are expected as `UInt`, not `Int`.
assert_eq!(
bundle.get("qwen35moe.embedding_length"),
Some(&MetadataValue::UInt(4096))
);
assert_eq!(
bundle.get("qwen35moe.attention.head_count"),
Some(&MetadataValue::UInt(32))
);
match bundle.get("tokenizer.tokens").unwrap() {
MetadataValue::Array(v) => {
assert_eq!(v.len(), 2);
}
other => panic!("expected array, got {other:?}"),
}
}
// A directory without config.json is not an error: no arch, no metadata.
#[test]
fn missing_config_json_yields_empty_bundle() {
let tmp = tempfile::tempdir().unwrap();
let (arch, meta) = load_config_json(tmp.path());
assert!(arch.is_none());
assert!(meta.is_empty());
}
// An unparsable config.json is treated the same as a missing one.
#[test]
fn malformed_config_json_yields_empty_bundle() {
let tmp = tempfile::tempdir().unwrap();
fs::write(tmp.path().join("config.json"), "not json").unwrap();
let (arch, meta) = load_config_json(tmp.path());
assert!(arch.is_none());
assert!(meta.is_empty());
}
// `model_type` is returned as the declared architecture and the remaining
// top-level keys land in the bundle.
#[test]
fn config_json_extracts_model_type() {
let tmp = tempfile::tempdir().unwrap();
fs::write(
tmp.path().join("config.json"),
r#"{"model_type":"llama","hidden_size":4096}"#,
)
.unwrap();
let (arch, meta) = load_config_json(tmp.path());
assert_eq!(arch.as_deref(), Some("llama"));
assert_eq!(meta.get("hidden_size"), Some(&MetadataValue::UInt(4096)));
}
}