use ferrum_types::{FerrumError, Result};
use std::path::{Path, PathBuf};
/// On-disk layout of a model's weights, as classified by [`WeightFormat::detect`].
#[derive(Debug, Clone)]
pub enum WeightFormat {
    /// A HuggingFace-style model directory, recognized by a `config.json`
    /// file at its top level.
    Safetensors { dir: PathBuf },
    /// A single GGUF file, recognized by its `.gguf` extension.
    Gguf { path: PathBuf },
}
impl WeightFormat {
    /// Classifies `path` as one of the supported weight formats.
    ///
    /// Detection only inspects file-system metadata and names; file contents
    /// are never read:
    /// - a file whose extension is `gguf` (ASCII case-insensitive) is
    ///   [`WeightFormat::Gguf`];
    /// - a directory that contains a `config.json` file is
    ///   [`WeightFormat::Safetensors`].
    ///
    /// # Errors
    ///
    /// Returns `FerrumError::model` when `path` matches neither rule. The
    /// message always begins with `Unrecognized weight format at <path>`.
    /// If the path cannot be found at all, a note saying so is appended, so a
    /// mistyped path is not reported as merely an unsupported format.
    pub fn detect(path: &Path) -> Result<Self> {
        let is_gguf_file = path.is_file()
            && matches!(path.extension(), Some(ext) if ext.eq_ignore_ascii_case("gguf"));
        if is_gguf_file {
            return Ok(Self::Gguf {
                path: path.to_owned(),
            });
        }
        if path.is_dir() && path.join("config.json").is_file() {
            return Ok(Self::Safetensors {
                dir: path.to_owned(),
            });
        }
        // `is_file`/`is_dir` return `false` for missing or unreadable paths,
        // which would otherwise be indistinguishable from an unsupported
        // layout in the error below.
        let missing_note = if path.exists() {
            ""
        } else {
            " (note: the path does not exist or is not accessible)"
        };
        Err(FerrumError::model(format!(
            "Unrecognized weight format at {}: expected a `.gguf` file or \
             a HuggingFace safetensors directory containing `config.json`.{}",
            path.display(),
            missing_note
        )))
    }
    /// Returns the path this format was detected from: the model directory
    /// for [`WeightFormat::Safetensors`], the file for [`WeightFormat::Gguf`].
    pub fn path(&self) -> &Path {
        match self {
            Self::Safetensors { dir } => dir,
            Self::Gguf { path } => path,
        }
    }
    /// Returns a short lowercase name for the format: `"safetensors"` or
    /// `"gguf"`.
    pub fn label(&self) -> &'static str {
        match self {
            Self::Safetensors { .. } => "safetensors",
            Self::Gguf { .. } => "gguf",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    #[test]
    fn detect_gguf_by_extension() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("Qwen3-0.6B-Q4_K_M.gguf");
        fs::write(&path, b"GGUF\0\0\0\0").unwrap();
        let fmt = WeightFormat::detect(&path).unwrap();
        assert!(matches!(fmt, WeightFormat::Gguf { .. }));
        assert_eq!(fmt.label(), "gguf");
        assert_eq!(fmt.path(), path.as_path());
    }
    #[test]
    fn detect_gguf_extension_is_case_insensitive() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("model.GGUF");
        fs::write(&path, b"GGUF").unwrap();
        let fmt = WeightFormat::detect(&path).unwrap();
        assert!(matches!(fmt, WeightFormat::Gguf { .. }));
    }
    #[test]
    fn detect_safetensors_dir() {
        let dir = tempfile::tempdir().unwrap();
        fs::write(dir.path().join("config.json"), b"{}").unwrap();
        let fmt = WeightFormat::detect(dir.path()).unwrap();
        assert!(matches!(fmt, WeightFormat::Safetensors { .. }));
        assert_eq!(fmt.label(), "safetensors");
        assert_eq!(fmt.path(), dir.path());
    }
    #[test]
    fn directory_named_gguf_is_not_treated_as_gguf() {
        // Only files qualify as GGUF; a `.gguf`-named directory falls through
        // to the safetensors check.
        let root = tempfile::tempdir().unwrap();
        let model_dir = root.path().join("weights.gguf");
        fs::create_dir(&model_dir).unwrap();
        fs::write(model_dir.join("config.json"), b"{}").unwrap();
        let fmt = WeightFormat::detect(&model_dir).unwrap();
        assert!(matches!(fmt, WeightFormat::Safetensors { .. }));
    }
    #[test]
    fn config_json_must_be_a_file() {
        let dir = tempfile::tempdir().unwrap();
        fs::create_dir(dir.path().join("config.json")).unwrap();
        assert!(WeightFormat::detect(dir.path()).is_err());
    }
    #[test]
    fn detect_unknown_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        let err = WeightFormat::detect(dir.path()).unwrap_err();
        let msg = format!("{}", err);
        assert!(msg.contains("Unrecognized weight format"), "{msg}");
    }
    #[test]
    fn detect_missing_path_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        let missing = dir.path().join("does-not-exist.gguf");
        let err = WeightFormat::detect(&missing).unwrap_err();
        let msg = format!("{}", err);
        assert!(msg.contains("Unrecognized weight format"), "{msg}");
    }
}