modelexpress_common/
providers.rs1use anyhow::Result;
5use std::path::PathBuf;
6
7#[async_trait::async_trait]
10pub trait ModelProviderTrait: Send + Sync {
11 async fn download_model(
13 &self,
14 model_name: &str,
15 cache_path: Option<PathBuf>,
16 ignore_weights: bool,
17 ) -> Result<PathBuf>;
18
19 async fn delete_model(&self, model_name: &str) -> Result<()>;
22
23 async fn get_model_path(&self, model_name: &str, cache_dir: PathBuf) -> Result<PathBuf>;
26
27 fn provider_name(&self) -> &'static str;
29
30 fn is_ignored(filename: &str) -> bool
34 where
35 Self: Sized,
36 {
37 const DEFAULT_IGNORED: [&str; 3] = [".gitattributes", ".gitignore", "README.md"];
38 DEFAULT_IGNORED.contains(&filename)
39 }
40
41 fn is_image(path: &std::path::Path) -> bool
45 where
46 Self: Sized,
47 {
48 path.extension().is_some_and(|ext| {
49 ext.eq_ignore_ascii_case("png")
50 || ext.eq_ignore_ascii_case("jpg")
51 || ext.eq_ignore_ascii_case("jpeg")
52 || ext.eq_ignore_ascii_case("gif")
53 || ext.eq_ignore_ascii_case("webp")
54 || ext.eq_ignore_ascii_case("svg")
55 || ext.eq_ignore_ascii_case("ico")
56 || ext.eq_ignore_ascii_case("bmp")
57 || ext.eq_ignore_ascii_case("tiff")
58 || ext.eq_ignore_ascii_case("tif")
59 })
60 }
61
62 fn is_weight_file(filename: &str) -> bool
64 where
65 Self: Sized,
66 {
67 filename.ends_with(".bin")
68 || filename.ends_with(".safetensors")
69 || filename.ends_with(".h5")
70 || filename.ends_with(".msgpack")
71 || filename.ends_with(".ckpt.index")
72 }
73}
74
75pub mod huggingface;
76
77pub use huggingface::HuggingFaceProvider;
78
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Image detection is driven by the extension and ignores ASCII case.
    #[test]
    fn test_is_image_function() {
        let images = [
            "test.png",
            "test.PNG",
            "test.jpg",
            "test.JPG",
            "test.jpeg",
            "test.JPEG",
        ];
        for name in images {
            assert!(
                HuggingFaceProvider::is_image(Path::new(name)),
                "{name} should be detected as an image"
            );
        }

        let non_images = ["test.txt", "test.py", "test", "test.model"];
        for name in non_images {
            assert!(
                !HuggingFaceProvider::is_image(Path::new(name)),
                "{name} should not be detected as an image"
            );
        }
    }

    /// Only the three housekeeping files are ignored.
    #[test]
    fn test_ignored_files() {
        let ignored = [".gitattributes", ".gitignore", "README.md"];
        for name in ignored {
            assert!(
                HuggingFaceProvider::is_ignored(name),
                "{name} should be ignored"
            );
        }

        let kept = ["model.bin", "tokenizer.json", "config.json"];
        for name in kept {
            assert!(
                !HuggingFaceProvider::is_ignored(name),
                "{name} should not be ignored"
            );
        }
    }

    /// Weight files are recognised by suffix.
    #[test]
    fn test_is_weight_file() {
        let weights = [
            "model.bin",
            "model.safetensors",
            "model.h5",
            "model.msgpack",
            "model.ckpt.index",
        ];
        for name in weights {
            assert!(
                HuggingFaceProvider::is_weight_file(name),
                "{name} should be a weight file"
            );
        }

        let others = ["tokenizer.json", "config.json", "README.md"];
        for name in others {
            assert!(
                !HuggingFaceProvider::is_weight_file(name),
                "{name} should not be a weight file"
            );
        }
    }
}