llama-rs 0.16.0

A high-performance Rust implementation of llama.cpp: an LLM inference engine with full GGUF support.
Documentation
//! `ModelSource` trait and concrete impls for GGUF and SafeTensors.

use std::fs;
use std::path::{Path, PathBuf};

use crate::diagnostics::types::{
    MetadataBundle, MetadataValue, TensorDtype, TensorInventory, TensorRecord,
};
use crate::diagnostics::DiagnosticError;

/// Format-agnostic view of a model on disk.
///
/// Implemented for GGUF files and SafeTensors files/directories so the
/// diagnostics layer can inspect either through a single interface.
pub trait ModelSource {
    /// The architecture the model claims to be, if it declares one.
    ///
    /// GGUF reads this from `general.architecture`; SafeTensors takes the
    /// `model_type` field of `config.json`. Returns `None` when the model
    /// does not declare an architecture.
    fn declared_architecture(&self) -> Option<String>;

    /// All model metadata, converted into a flat, format-independent map.
    fn metadata(&self) -> MetadataBundle;

    /// Every tensor the model declares.
    ///
    /// Ordering is up to the implementation (the sharded SafeTensors source,
    /// for instance, sorts tensors by name).
    fn tensors(&self) -> TensorInventory;

    /// Short description of the container format for use in reports, such as
    /// `"GGUF v3"` or `"SafeTensors (4 shards)"`.
    fn format_label(&self) -> String;
}

// =============================================================================
// GGUF
// =============================================================================

/// [`ModelSource`] backed by a single parsed GGUF file.
pub struct GgufSource {
    /// The parsed GGUF container; its header, metadata map and tensor table
    /// are read through `file.data`.
    file: crate::gguf::GgufFile,
}

impl GgufSource {
    /// Parse the GGUF file at `path`.
    ///
    /// # Errors
    ///
    /// Returns [`DiagnosticError::ReadModel`] when the GGUF reader fails. The
    /// message includes the offending path and the reader's error.
    pub fn open(path: impl AsRef<Path>) -> Result<Self, DiagnosticError> {
        let path = path.as_ref();
        match crate::gguf::GgufFile::open(path) {
            Ok(file) => Ok(Self { file }),
            Err(e) => Err(DiagnosticError::ReadModel(format!(
                "open {}: {e}",
                path.display()
            ))),
        }
    }
}

impl ModelSource for GgufSource {
    /// Value of the `general.architecture` metadata string, if present.
    fn declared_architecture(&self) -> Option<String> {
        let arch = self.file.data.get_string("general.architecture")?;
        Some(arch.to_string())
    }

    /// Every GGUF metadata entry, keyed by its original name and converted
    /// via `from_gguf_value`.
    fn metadata(&self) -> MetadataBundle {
        let source = &self.file.data;
        let mut bundle = MetadataBundle::new();
        for (key, value) in &source.metadata {
            let converted = from_gguf_value(value);
            bundle.insert(key.clone(), converted);
        }
        bundle
    }

    /// One record per entry of the GGUF tensor table, in table order.
    fn tensors(&self) -> TensorInventory {
        let table = &self.file.data.tensors;
        table
            .iter()
            .map(|tensor| {
                // NOTE(review): `as usize` truncates if a dimension exceeds
                // usize::MAX (only possible on 32-bit targets) — confirm the
                // dims element type.
                let shape = tensor.dims.iter().map(|&dim| dim as usize).collect();
                TensorRecord {
                    name: tensor.name.clone(),
                    shape,
                    dtype: from_ggml_type(tensor.dtype),
                }
            })
            .collect()
    }

    /// `"GGUF v<N>"`, where `N` is the version from the file header.
    fn format_label(&self) -> String {
        let version = &self.file.data.header.version;
        format!("GGUF v{}", version)
    }
}

/// Translate a GGUF metadata value into the format-agnostic [`MetadataValue`].
///
/// Unsigned integers widen into `UInt`, signed integers into `Int`, floats
/// into `Float` (as `f64`). Arrays convert element by element.
fn from_gguf_value(v: &crate::gguf::MetadataValue) -> MetadataValue {
    use crate::gguf::MetadataValue as Gguf;
    match v {
        // Unsigned widths: lossless widening to u64.
        Gguf::Uint8(n) => MetadataValue::UInt(u64::from(*n)),
        Gguf::Uint16(n) => MetadataValue::UInt(u64::from(*n)),
        Gguf::Uint32(n) => MetadataValue::UInt(u64::from(*n)),
        Gguf::Uint64(n) => MetadataValue::UInt(*n),
        // Signed widths: lossless widening to i64.
        Gguf::Int8(n) => MetadataValue::Int(i64::from(*n)),
        Gguf::Int16(n) => MetadataValue::Int(i64::from(*n)),
        Gguf::Int32(n) => MetadataValue::Int(i64::from(*n)),
        Gguf::Int64(n) => MetadataValue::Int(*n),
        // Floating point: f32 widens exactly to f64.
        Gguf::Float32(f) => MetadataValue::Float(f64::from(*f)),
        Gguf::Float64(f) => MetadataValue::Float(*f),
        Gguf::Bool(flag) => MetadataValue::Bool(*flag),
        Gguf::String(text) => MetadataValue::String(text.clone()),
        Gguf::Array(array) => {
            let items = array.values.iter().map(from_gguf_value).collect();
            MetadataValue::Array(items)
        }
    }
}

/// Map a ggml tensor type onto a [`TensorDtype`].
///
/// `F32` and `F16` map directly. Every other variant is reported as
/// `Quantized`, labelled with the lowercase `Debug` name of the ggml type.
/// NOTE(review): if `GgmlType` also has non-quantized variants (e.g. bf16 or
/// integer types), they are labelled `Quantized` too — confirm that is intended.
fn from_ggml_type(t: crate::gguf::GgmlType) -> TensorDtype {
    match t {
        crate::gguf::GgmlType::F32 => TensorDtype::F32,
        crate::gguf::GgmlType::F16 => TensorDtype::F16,
        quantized => {
            let label = format!("{quantized:?}");
            TensorDtype::Quantized(label.to_lowercase())
        }
    }
}

// =============================================================================
// SafeTensors
// =============================================================================

/// [`ModelSource`] for SafeTensors models.
///
/// Accepts either a single `.safetensors` file or a directory of shards.
/// Everything is read eagerly when the source is opened, so the trait
/// accessors just hand out copies.
pub struct SafetensorsSource {
    /// File or directory path exactly as passed to `open`.
    root: PathBuf,

    /// `model_type` from `config.json`, when present.
    declared_arch: Option<String>,

    /// `config.json` flattened into dotted keys (empty if unavailable).
    metadata: MetadataBundle,

    /// Tensor records collected from the file or shards.
    tensors: TensorInventory,

    /// Report label, e.g. `"SafeTensors"` or `"SafeTensors (N shards)"`.
    label: String,
}

impl SafetensorsSource {
    /// Open a SafeTensors model from a single file or a shard directory.
    ///
    /// A `config.json` is looked up on a best-effort basis: inside the
    /// directory for a sharded model, or next to the file for a single-file
    /// model. It supplies the declared architecture and flattened metadata.
    ///
    /// # Errors
    ///
    /// Returns [`DiagnosticError::ReadModel`] if the SafeTensors reader cannot
    /// open the file or directory. A missing or broken `config.json` is not an
    /// error.
    pub fn open(path: impl AsRef<Path>) -> Result<Self, DiagnosticError> {
        let path = path.as_ref();
        let root = path.to_path_buf();

        if !path.is_dir() {
            // Single-file layout: config.json (if any) sits in the parent dir.
            let single = crate::safetensors::SafeTensorsFile::open(path).map_err(|e| {
                DiagnosticError::ReadModel(format!("open {}: {e}", path.display()))
            })?;
            let (declared_arch, metadata) = match path.parent() {
                Some(parent_dir) => load_config_json(parent_dir),
                None => Default::default(),
            };
            return Ok(Self {
                root,
                declared_arch,
                metadata,
                tensors: enumerate_single(&single),
                label: "SafeTensors".into(),
            });
        }

        // Sharded layout: shards and config.json live inside `path`.
        let sharded = crate::safetensors::ShardedSafeTensors::open(path).map_err(|e| {
            DiagnosticError::ReadModel(format!("open dir {}: {e}", path.display()))
        })?;
        let shard_count = count_shards(path);
        let (declared_arch, metadata) = load_config_json(path);
        Ok(Self {
            root,
            declared_arch,
            metadata,
            tensors: enumerate_sharded(&sharded),
            label: format!("SafeTensors ({} shards)", shard_count),
        })
    }

    /// Path this source was opened from (file or directory).
    pub fn root(&self) -> &Path {
        self.root.as_path()
    }
}

impl ModelSource for SafetensorsSource {
    /// `model_type` captured from `config.json` when the source was opened.
    fn declared_architecture(&self) -> Option<String> {
        self.declared_arch.as_ref().cloned()
    }

    /// Copy of the flattened `config.json` metadata.
    fn metadata(&self) -> MetadataBundle {
        Clone::clone(&self.metadata)
    }

    /// Copy of the tensor inventory gathered at open time.
    fn tensors(&self) -> TensorInventory {
        Clone::clone(&self.tensors)
    }

    /// Label computed at open time (includes the shard count when sharded).
    fn format_label(&self) -> String {
        self.label.to_owned()
    }
}

/// Count the `.safetensors` files directly inside `dir`.
///
/// Only entries that resolve to regular files are counted.
/// `Path::is_file` follows symlinks, so symlinked shard files still count,
/// while a subdirectory whose name happens to end in `.safetensors` does not.
/// An unreadable directory yields 0, and unreadable entries are skipped.
/// In this module the count only feeds the report label, so it never fails.
fn count_shards(dir: &Path) -> usize {
    let Ok(entries) = fs::read_dir(dir) else {
        return 0;
    };
    entries
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .filter(|path| {
            matches!(path.extension(), Some(ext) if ext == "safetensors") && path.is_file()
        })
        .count()
}

/// Build the tensor inventory for a single `.safetensors` file.
///
/// Tensors appear in the order `SafeTensorsFile::tensor_names` yields them.
/// Names for which `tensor_info` returns `None` are skipped. The names are
/// iterated directly instead of first being collected into a temporary
/// `Vec<String>`. Each name is allocated once, when building its record.
fn enumerate_single(file: &crate::safetensors::SafeTensorsFile) -> TensorInventory {
    file.tensor_names()
        .filter_map(|name| {
            let info = file.tensor_info(name)?;
            Some(TensorRecord {
                name: name.to_string(),
                shape: info.shape.clone(),
                dtype: from_st_dtype(info.dtype),
            })
        })
        .collect()
}

/// Build the tensor inventory for a sharded SafeTensors directory.
///
/// Names are sorted before lookup so the output order is deterministic.
/// Names for which `tensor_info` returns `None` are skipped.
fn enumerate_sharded(sharded: &crate::safetensors::ShardedSafeTensors) -> TensorInventory {
    let mut sorted_names = sharded.tensor_names();
    sorted_names.sort();

    sorted_names
        .into_iter()
        .filter_map(|name| {
            let info = sharded.tensor_info(&name)?;
            let shape = info.shape.clone();
            let dtype = from_st_dtype(info.dtype);
            // `name` is owned here, so it moves into the record without a clone.
            Some(TensorRecord { name, shape, dtype })
        })
        .collect()
}

/// Map a SafeTensors dtype onto the matching [`TensorDtype`] variant.
///
/// Every SafeTensors dtype has a one-to-one counterpart, so this never
/// falls back to a catch-all.
fn from_st_dtype(d: crate::safetensors::SafeTensorsDtype) -> TensorDtype {
    use crate::safetensors::SafeTensorsDtype as Dtype;
    match d {
        // Floating point.
        Dtype::F64 => TensorDtype::F64,
        Dtype::F32 => TensorDtype::F32,
        Dtype::F16 => TensorDtype::F16,
        Dtype::BF16 => TensorDtype::BF16,
        // Signed integers.
        Dtype::I64 => TensorDtype::I64,
        Dtype::I32 => TensorDtype::I32,
        Dtype::I16 => TensorDtype::I16,
        Dtype::I8 => TensorDtype::I8,
        // Unsigned bytes and booleans.
        Dtype::U8 => TensorDtype::U8,
        Dtype::Bool => TensorDtype::Bool,
    }
}

/// Read `dir/config.json` and flatten it into a `MetadataBundle`.
///
/// For a single-file model `dir` is the file's parent directory; for a
/// sharded model it is the shard directory. Returns
/// `(declared_arch, metadata)`, where `declared_arch` is the top-level
/// `model_type` string, if any.
///
/// Failures are soft so a malformed config doesn't block diagnostics. Each of
/// the following yields `(None, MetadataBundle::new())`:
/// - a missing or unreadable file;
/// - invalid JSON;
/// - a top-level value that is not a JSON object.
fn load_config_json(dir: &Path) -> (Option<String>, MetadataBundle) {
    let path = dir.join("config.json");
    let Ok(raw) = fs::read_to_string(&path) else {
        return (None, MetadataBundle::new());
    };
    let Ok(value) = serde_json::from_str::<serde_json::Value>(&raw) else {
        return (None, MetadataBundle::new());
    };
    // A config must be a JSON object. Flattening a top-level array or scalar
    // would otherwise store it under the empty key "".
    if !value.is_object() {
        return (None, MetadataBundle::new());
    }

    let declared_arch = value
        .get("model_type")
        .and_then(|v| v.as_str())
        .map(|s| s.to_string());

    let mut out = MetadataBundle::new();
    flatten_json(&value, String::new(), &mut out);
    (declared_arch, out)
}

/// Recursively flatten `value` into `out`, joining nested object keys with
/// `.`. For example, `{"a": {"b": 1}}` is stored as `a.b = 1`.
///
/// Scalars are converted with `convert_scalar_json`, and `null` is skipped.
/// For arrays, only scalar elements are kept: nested arrays, objects and
/// `null`s are dropped. An array with no surviving elements is omitted.
fn flatten_json(value: &serde_json::Value, prefix: String, out: &mut MetadataBundle) {
    if let serde_json::Value::Object(fields) = value {
        for (key, child) in fields {
            let child_prefix = match prefix.as_str() {
                "" => key.clone(),
                parent => format!("{parent}.{key}"),
            };
            flatten_json(child, child_prefix, out);
        }
        return;
    }

    if let serde_json::Value::Array(items) = value {
        let scalars: Vec<_> = items.iter().filter_map(convert_scalar_json).collect();
        if !scalars.is_empty() {
            out.insert(prefix, MetadataValue::Array(scalars));
        }
        return;
    }

    if let Some(scalar) = convert_scalar_json(value) {
        out.insert(prefix, scalar);
    }
}

/// Convert a JSON scalar into a [`MetadataValue`].
///
/// Numbers become `UInt` when they fit in `u64`, otherwise `Int` when they
/// fit in `i64`, otherwise `Float`. `null`, arrays and objects return `None`.
fn convert_scalar_json(value: &serde_json::Value) -> Option<MetadataValue> {
    let converted = match value {
        serde_json::Value::Bool(flag) => MetadataValue::Bool(*flag),
        serde_json::Value::String(text) => MetadataValue::String(text.clone()),
        serde_json::Value::Number(number) => match (number.as_u64(), number.as_i64()) {
            (Some(unsigned), _) => MetadataValue::UInt(unsigned),
            (None, Some(signed)) => MetadataValue::Int(signed),
            (None, None) => return number.as_f64().map(MetadataValue::Float),
        },
        serde_json::Value::Null
        | serde_json::Value::Array(_)
        | serde_json::Value::Object(_) => return None,
    };
    Some(converted)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Run `load_config_json` against a fresh temporary directory. When
    /// `config` is `Some`, it is first written there as `config.json`.
    fn load_from_temp_dir(config: Option<&str>) -> (Option<String>, MetadataBundle) {
        let dir = tempfile::tempdir().unwrap();
        if let Some(contents) = config {
            fs::write(dir.path().join("config.json"), contents).unwrap();
        }
        load_config_json(dir.path())
    }

    #[test]
    fn flattens_nested_json_into_dotted_keys() {
        let json = r#"{
            "model_type": "qwen35moe",
            "qwen35moe": {
                "embedding_length": 4096,
                "attention": { "head_count": 32 }
            },
            "tokenizer": { "tokens": ["a", "b"] }
        }"#;
        let value: serde_json::Value = serde_json::from_str(json).unwrap();

        let mut bundle = MetadataBundle::new();
        flatten_json(&value, String::new(), &mut bundle);

        let expected_scalars = [
            ("model_type", MetadataValue::String("qwen35moe".into())),
            ("qwen35moe.embedding_length", MetadataValue::UInt(4096)),
            ("qwen35moe.attention.head_count", MetadataValue::UInt(32)),
        ];
        for (key, expected) in &expected_scalars {
            assert_eq!(bundle.get(*key), Some(expected), "key {key}");
        }

        let tokens = bundle.get("tokenizer.tokens");
        assert!(
            matches!(tokens, Some(MetadataValue::Array(items)) if items.len() == 2),
            "expected a two-element array under tokenizer.tokens, got {tokens:?}"
        );
    }

    #[test]
    fn missing_config_json_yields_empty_bundle() {
        let (arch, meta) = load_from_temp_dir(None);
        assert!(arch.is_none());
        assert!(meta.is_empty());
    }

    #[test]
    fn malformed_config_json_yields_empty_bundle() {
        let (arch, meta) = load_from_temp_dir(Some("not json"));
        assert!(arch.is_none());
        assert!(meta.is_empty());
    }

    #[test]
    fn config_json_extracts_model_type() {
        let (arch, meta) =
            load_from_temp_dir(Some(r#"{"model_type":"llama","hidden_size":4096}"#));
        assert_eq!(arch.as_deref(), Some("llama"));
        assert_eq!(meta.get("hidden_size"), Some(&MetadataValue::UInt(4096)));
    }
}