offline_intelligence/model_runtime/
format_detector.rs1use super::runtime_trait::ModelFormat;
6use std::path::Path;
7use tracing::info;
8
/// Stateless helper that infers a [`ModelFormat`] from a model file's path.
///
/// All functionality lives in associated functions (see the `impl` block);
/// the type carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct FormatDetector;
10
11impl FormatDetector {
12 pub fn detect_from_path(path: &Path) -> Option<ModelFormat> {
14 let extension = path.extension()?.to_str()?.to_lowercase();
15
16 let format = if ModelFormat::GGUF.extensions().contains(&extension.as_str()) {
17 Some(ModelFormat::GGUF)
18 } else if ModelFormat::GGML.extensions().contains(&extension.as_str()) {
19 if extension == "ggml" {
21 Some(ModelFormat::GGML)
22 } else if extension == "bin" {
23 if let Some(filename) = path.file_name().and_then(|n| n.to_str()) {
25 if filename.contains("ggml") {
26 Some(ModelFormat::GGML)
27 } else {
28 None }
30 } else {
31 None
32 }
33 } else {
34 None
35 }
36 } else if ModelFormat::ONNX.extensions().contains(&extension.as_str()) {
37 Some(ModelFormat::ONNX)
38 } else if ModelFormat::TensorRT.extensions().contains(&extension.as_str()) {
39 Some(ModelFormat::TensorRT)
40 } else if ModelFormat::Safetensors.extensions().contains(&extension.as_str()) {
41 Some(ModelFormat::Safetensors)
42 } else if ModelFormat::CoreML.extensions().contains(&extension.as_str()) {
43 Some(ModelFormat::CoreML)
44 } else {
45 None
46 };
47
48 if let Some(fmt) = format {
49 info!("Detected model format: {} for file: {}", fmt.name(), path.display());
50 }
51
52 format
53 }
54
55 pub fn supported_extensions() -> Vec<String> {
57 let mut exts = Vec::new();
58 for format in &[
59 ModelFormat::GGUF,
60 ModelFormat::GGML,
61 ModelFormat::ONNX,
62 ModelFormat::TensorRT,
63 ModelFormat::Safetensors,
64 ModelFormat::CoreML,
65 ] {
66 for ext in format.extensions() {
67 exts.push(ext.to_string());
68 }
69 }
70 exts
71 }
72}
73
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    #[test]
    fn test_gguf_detection() {
        let path = PathBuf::from("model.gguf");
        assert_eq!(FormatDetector::detect_from_path(&path), Some(ModelFormat::GGUF));
    }

    #[test]
    fn test_onnx_detection() {
        let path = PathBuf::from("model.onnx");
        assert_eq!(FormatDetector::detect_from_path(&path), Some(ModelFormat::ONNX));
    }

    #[test]
    fn test_tensorrt_detection() {
        let path = PathBuf::from("model.trt");
        assert_eq!(FormatDetector::detect_from_path(&path), Some(ModelFormat::TensorRT));
    }

    #[test]
    fn test_safetensors_detection() {
        let path = PathBuf::from("model.safetensors");
        assert_eq!(FormatDetector::detect_from_path(&path), Some(ModelFormat::Safetensors));
    }

    /// Extensions are lower-cased before matching, so upper-case names work.
    #[test]
    fn test_extension_is_case_insensitive() {
        let path = PathBuf::from("MODEL.GGUF");
        assert_eq!(FormatDetector::detect_from_path(&path), Some(ModelFormat::GGUF));
    }

    #[test]
    fn test_ggml_extension_detection() {
        let path = PathBuf::from("model.ggml");
        assert_eq!(FormatDetector::detect_from_path(&path), Some(ModelFormat::GGML));
    }

    /// A `.bin` file is attributed to GGML when its name carries the `ggml` marker.
    #[test]
    fn test_ggml_bin_with_marker_detection() {
        let path = PathBuf::from("llama-7b-ggml-q4_0.bin");
        assert_eq!(FormatDetector::detect_from_path(&path), Some(ModelFormat::GGML));
    }

    #[test]
    fn test_missing_or_unknown_extension() {
        assert_eq!(FormatDetector::detect_from_path(&PathBuf::from("model")), None);
        assert_eq!(FormatDetector::detect_from_path(&PathBuf::from("model.txt")), None);
        assert_eq!(FormatDetector::detect_from_path(&PathBuf::from("")), None);
    }
}