Skip to main content

ferrum_models/
weight_format.rs

1//! Dim 3 polymorphism point — weight-format detection for the executor
2//! factory.
3//!
4//! Sibling of `source::ModelFormat`. The difference:
5//!
6//! - [`source::ModelFormat`](crate::source::ModelFormat) is a "what kind
7//!   of files did we just download" hint used for cache classification
8//!   and progress bars. It carries no path.
9//! - [`WeightFormat`] is a "loader recipe" — it carries the resolved
10//!   path AND tells the executor factory which `WeightLoader<B>` to
11//!   instantiate. New formats (AWQ, EXL2, HQQ, ...) plug in by adding a
12//!   variant + a `WeightLoader<B>` impl in `ferrum-quantization`, with
13//!   no special-casing in `LlmExecutorFactory`.
14//!
15//! Replaces the `is_gguf_path` short-circuit in
16//! `ferrum-engine::registry::CandleExecutorFactory` with a real
17//! polymorphism point matching the 5-dim architecture (see
18//! `docs/architecture-refactor-status.md`).
19
20use ferrum_types::{FerrumError, Result};
21use std::path::{Path, PathBuf};
22
/// A weight format together with the on-disk location it was resolved to.
///
/// Values are normally produced by [`WeightFormat::detect`] from a
/// user-supplied path (an HF cache snapshot, a local directory, or a
/// `.gguf` file).
///
/// Equality and hashing compare both the variant and the stored path
/// verbatim; no path canonicalization is performed.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WeightFormat {
    /// HuggingFace safetensors directory: `config.json` + one or more
    /// `.safetensors` shards. The on-disk weights may be plain FP16/BF16
    /// **or** GPTQ-Int4 (`<name>.qweight` tensors); this is decided
    /// per-tensor by `NativeSafetensorsLoader::load_linear`.
    Safetensors {
        /// Directory holding `config.json` and the weight shards.
        dir: PathBuf,
    },

    /// GGUF single-file format (Llama-family / Qwen3-MoE quantized).
    /// Loaded by `ferrum_quantization::gguf::GgufLoader`.
    Gguf {
        /// Path of the `.gguf` file itself.
        path: PathBuf,
    },
    // Future: Awq { dir }, Exl2 { dir }, Hqq { dir } …
}
39
40impl WeightFormat {
41    /// Detect the weight format from a user-supplied path.
42    ///
43    /// - If `path` is a file ending in `.gguf` (case-insensitive)
44    ///   → [`WeightFormat::Gguf`].
45    /// - If `path` is a directory containing `config.json`
46    ///   → [`WeightFormat::Safetensors`].
47    /// - Anything else returns a model error.
48    pub fn detect(path: &Path) -> Result<Self> {
49        if path.is_file()
50            && path
51                .extension()
52                .map(|e| e.eq_ignore_ascii_case("gguf"))
53                .unwrap_or(false)
54        {
55            return Ok(Self::Gguf {
56                path: path.to_owned(),
57            });
58        }
59        if path.is_dir() && path.join("config.json").is_file() {
60            return Ok(Self::Safetensors {
61                dir: path.to_owned(),
62            });
63        }
64        Err(FerrumError::model(format!(
65            "Unrecognized weight format at {}: expected a `.gguf` file or \
66             a HuggingFace safetensors directory containing `config.json`.",
67            path.display()
68        )))
69    }
70
71    /// The on-disk path this format resolved to.
72    pub fn path(&self) -> &Path {
73        match self {
74            Self::Safetensors { dir } => dir,
75            Self::Gguf { path } => path,
76        }
77    }
78
79    /// Short label for logs / telemetry.
80    pub fn label(&self) -> &'static str {
81        match self {
82            Self::Safetensors { .. } => "safetensors",
83            Self::Gguf { .. } => "gguf",
84        }
85    }
86}
87
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// A regular file ending in `.gguf` is detected as GGUF and keeps its path.
    #[test]
    fn detect_gguf_by_extension() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("Qwen3-0.6B-Q4_K_M.gguf");
        fs::write(&path, b"GGUF\0\0\0\0").unwrap();
        let fmt = WeightFormat::detect(&path).unwrap();
        assert!(matches!(fmt, WeightFormat::Gguf { .. }));
        assert_eq!(fmt.label(), "gguf");
        assert_eq!(fmt.path(), path.as_path());
    }

    /// The extension comparison ignores ASCII case, as documented on `detect`.
    #[test]
    fn detect_gguf_extension_is_case_insensitive() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("model.GGUF");
        fs::write(&path, b"GGUF\0\0\0\0").unwrap();
        let fmt = WeightFormat::detect(&path).unwrap();
        assert!(matches!(fmt, WeightFormat::Gguf { .. }));
    }

    /// A directory containing `config.json` is detected as safetensors.
    #[test]
    fn detect_safetensors_dir() {
        let dir = tempfile::tempdir().unwrap();
        fs::write(dir.path().join("config.json"), b"{}").unwrap();
        let fmt = WeightFormat::detect(dir.path()).unwrap();
        assert!(matches!(fmt, WeightFormat::Safetensors { .. }));
        assert_eq!(fmt.label(), "safetensors");
        assert_eq!(fmt.path(), dir.path());
    }

    /// An empty directory matches neither rule.
    #[test]
    fn detect_unknown_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        // Empty dir, no config.json, not a .gguf file.
        let err = WeightFormat::detect(dir.path()).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("Unrecognized weight format"), "{msg}");
    }

    /// A regular file with a non-`gguf` extension is rejected.
    #[test]
    fn detect_rejects_file_without_gguf_extension() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("weights.bin");
        fs::write(&path, b"\0").unwrap();
        let err = WeightFormat::detect(&path).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("Unrecognized weight format"), "{msg}");
    }

    /// A path that does not exist is rejected even if it ends in `.gguf`.
    #[test]
    fn detect_missing_path_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("missing.gguf");
        let err = WeightFormat::detect(&path).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("Unrecognized weight format"), "{msg}");
    }
}