Skip to main content

offline_intelligence/model_runtime/
format_detector.rs

//! Format Detector
//!
//! Automatically detects a model's format from its file path: primarily the
//! file extension, with the file name used to disambiguate `.bin` files.
4
5use super::runtime_trait::ModelFormat;
6use std::path::Path;
7use tracing::info;
8
9pub struct FormatDetector;
10
11impl FormatDetector {
12    /// Detect model format from file extension
13    pub fn detect_from_path(path: &Path) -> Option<ModelFormat> {
14        let extension = path.extension()?.to_str()?.to_lowercase();
15        
16        let format = if ModelFormat::GGUF.extensions().contains(&extension.as_str()) {
17            Some(ModelFormat::GGUF)
18        } else if ModelFormat::GGML.extensions().contains(&extension.as_str()) {
19            // Need to disambiguate .bin files (could be GGML or other)
20            if extension == "ggml" {
21                Some(ModelFormat::GGML)
22            } else if extension == "bin" {
23                // Check filename for hints
24                if let Some(filename) = path.file_name().and_then(|n| n.to_str()) {
25                    if filename.contains("ggml") {
26                        Some(ModelFormat::GGML)
27                    } else {
28                        None // Ambiguous .bin file
29                    }
30                } else {
31                    None
32                }
33            } else {
34                None
35            }
36        } else if ModelFormat::ONNX.extensions().contains(&extension.as_str()) {
37            Some(ModelFormat::ONNX)
38        } else if ModelFormat::TensorRT.extensions().contains(&extension.as_str()) {
39            Some(ModelFormat::TensorRT)
40        } else if ModelFormat::Safetensors.extensions().contains(&extension.as_str()) {
41            Some(ModelFormat::Safetensors)
42        } else if ModelFormat::CoreML.extensions().contains(&extension.as_str()) {
43            Some(ModelFormat::CoreML)
44        } else {
45            None
46        };
47
48        if let Some(fmt) = format {
49            info!("Detected model format: {} for file: {}", fmt.name(), path.display());
50        }
51
52        format
53    }
54
55    /// List all supported extensions
56    pub fn supported_extensions() -> Vec<String> {
57        let mut exts = Vec::new();
58        for format in &[
59            ModelFormat::GGUF,
60            ModelFormat::GGML,
61            ModelFormat::ONNX,
62            ModelFormat::TensorRT,
63            ModelFormat::Safetensors,
64            ModelFormat::CoreML,
65        ] {
66            for ext in format.extensions() {
67                exts.push(ext.to_string());
68            }
69        }
70        exts
71    }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_gguf_detection() {
        let path = PathBuf::from("model.gguf");
        assert_eq!(FormatDetector::detect_from_path(&path), Some(ModelFormat::GGUF));
    }

    #[test]
    fn test_onnx_detection() {
        let path = PathBuf::from("model.onnx");
        assert_eq!(FormatDetector::detect_from_path(&path), Some(ModelFormat::ONNX));
    }

    #[test]
    fn test_tensorrt_detection() {
        let path = PathBuf::from("model.trt");
        assert_eq!(FormatDetector::detect_from_path(&path), Some(ModelFormat::TensorRT));
    }

    #[test]
    fn test_safetensors_detection() {
        let path = PathBuf::from("model.safetensors");
        assert_eq!(FormatDetector::detect_from_path(&path), Some(ModelFormat::Safetensors));
    }

    #[test]
    fn test_extension_match_is_case_insensitive() {
        let path = PathBuf::from("MODEL.GGUF");
        assert_eq!(FormatDetector::detect_from_path(&path), Some(ModelFormat::GGUF));
    }

    #[test]
    fn test_nested_path_detection() {
        let path = PathBuf::from("models/llama/model.onnx");
        assert_eq!(FormatDetector::detect_from_path(&path), Some(ModelFormat::ONNX));
    }

    #[test]
    fn test_missing_extension_returns_none() {
        assert_eq!(FormatDetector::detect_from_path(&PathBuf::from("model")), None);
        // A leading dot marks a hidden file, not an extension.
        assert_eq!(FormatDetector::detect_from_path(&PathBuf::from(".gguf")), None);
    }

    #[test]
    fn test_unknown_extension_returns_none() {
        let path = PathBuf::from("model.xyz");
        assert_eq!(FormatDetector::detect_from_path(&path), None);
    }

    #[test]
    fn test_supported_extensions_cover_detected_formats() {
        let exts = FormatDetector::supported_extensions();
        for expected in ["gguf", "onnx", "trt", "safetensors"] {
            assert!(
                exts.iter().any(|ext| ext.as_str() == expected),
                "missing extension: {}",
                expected
            );
        }
    }
}