Skip to main content

modelexpress_common/
providers.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use anyhow::Result;
5use std::path::PathBuf;
6
7/// Trait for model providers
8/// This trait provides the framework for supporting multiple model providers.
9#[async_trait::async_trait]
10pub trait ModelProviderTrait: Send + Sync {
11    /// Download a model and return the path where it was downloaded
12    async fn download_model(
13        &self,
14        model_name: &str,
15        cache_path: Option<PathBuf>,
16        ignore_weights: bool,
17    ) -> Result<PathBuf>;
18
19    /// Delete a model from the provider's cache
20    /// Returns Ok(()) if the model was successfully deleted or didn't exist
21    async fn delete_model(&self, model_name: &str) -> Result<()>;
22
23    /// Get the full path to the latest model snapshot if it exists
24    /// Returns the path if found, or an error if not found
25    async fn get_model_path(&self, model_name: &str, cache_dir: PathBuf) -> Result<PathBuf>;
26
27    /// Get the provider name for logging
28    fn provider_name(&self) -> &'static str;
29
30    /// Check if a file should be ignored during download
31    /// This allows each provider to specify which files to skip
32    /// Default implementation ignores common repository metadata files
33    fn is_ignored(filename: &str) -> bool
34    where
35        Self: Sized,
36    {
37        const DEFAULT_IGNORED: [&str; 3] = [".gitattributes", ".gitignore", "README.md"];
38        DEFAULT_IGNORED.contains(&filename)
39    }
40
41    /// Check if a file is an image file that should be ignored
42    /// This allows each provider to customize image file detection
43    /// Default implementation recognizes common image file extensions
44    fn is_image(path: &std::path::Path) -> bool
45    where
46        Self: Sized,
47    {
48        path.extension().is_some_and(|ext| {
49            ext.eq_ignore_ascii_case("png")
50                || ext.eq_ignore_ascii_case("jpg")
51                || ext.eq_ignore_ascii_case("jpeg")
52                || ext.eq_ignore_ascii_case("gif")
53                || ext.eq_ignore_ascii_case("webp")
54                || ext.eq_ignore_ascii_case("svg")
55                || ext.eq_ignore_ascii_case("ico")
56                || ext.eq_ignore_ascii_case("bmp")
57                || ext.eq_ignore_ascii_case("tiff")
58                || ext.eq_ignore_ascii_case("tif")
59        })
60    }
61
62    /// Checks if a file is a model weight file
63    fn is_weight_file(filename: &str) -> bool
64    where
65        Self: Sized,
66    {
67        filename.ends_with(".bin")
68            || filename.ends_with(".safetensors")
69            || filename.ends_with(".h5")
70            || filename.ends_with(".msgpack")
71            || filename.ends_with(".ckpt.index")
72    }
73}
74
75pub mod huggingface;
76
77pub use huggingface::HuggingFaceProvider;
78
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Mirrors the extension list of the default `ModelProviderTrait::is_image`.
    const IMAGE_EXTENSIONS: [&str; 10] = [
        "png", "jpg", "jpeg", "gif", "webp", "svg", "ico", "bmp", "tiff", "tif",
    ];

    #[test]
    fn test_is_image_function() {
        // Every supported extension must match in both lower and upper case.
        for ext in IMAGE_EXTENSIONS {
            let lower = format!("test.{ext}");
            let upper = format!("test.{}", ext.to_ascii_uppercase());
            for name in [lower, upper] {
                assert!(
                    HuggingFaceProvider::is_image(Path::new(&name)),
                    "{name} should be recognized as an image"
                );
            }
        }
        // Mixed case is matched as well.
        assert!(HuggingFaceProvider::is_image(Path::new("test.JpEg")));
        // Only the extension matters, not the directory the file lives in.
        assert!(HuggingFaceProvider::is_image(Path::new("assets/logo.png")));

        assert!(!HuggingFaceProvider::is_image(Path::new("test.txt")));
        assert!(!HuggingFaceProvider::is_image(Path::new("test.py")));
        assert!(!HuggingFaceProvider::is_image(Path::new("test")));
        assert!(!HuggingFaceProvider::is_image(Path::new("test.model")));
        // Only the final extension is considered.
        assert!(!HuggingFaceProvider::is_image(Path::new("test.png.bak")));
        // A dot-file such as `.png` has no extension at all.
        assert!(!HuggingFaceProvider::is_image(Path::new(".png")));
    }

    #[test]
    fn test_ignored_files() {
        assert!(HuggingFaceProvider::is_ignored(".gitattributes"));
        assert!(HuggingFaceProvider::is_ignored(".gitignore"));
        assert!(HuggingFaceProvider::is_ignored("README.md"));

        assert!(!HuggingFaceProvider::is_ignored("model.bin"));
        assert!(!HuggingFaceProvider::is_ignored("tokenizer.json"));
        assert!(!HuggingFaceProvider::is_ignored("config.json"));
    }

    #[test]
    fn test_is_weight_file() {
        assert!(HuggingFaceProvider::is_weight_file("model.bin"));
        assert!(HuggingFaceProvider::is_weight_file("model.safetensors"));
        assert!(HuggingFaceProvider::is_weight_file("model.h5"));
        assert!(HuggingFaceProvider::is_weight_file("model.msgpack"));
        assert!(HuggingFaceProvider::is_weight_file("model.ckpt.index"));

        // Conventional weight file names, including one inside a subdirectory.
        assert!(HuggingFaceProvider::is_weight_file("pytorch_model.bin"));
        assert!(HuggingFaceProvider::is_weight_file("tf_model.h5"));
        assert!(HuggingFaceProvider::is_weight_file("flax_model.msgpack"));
        assert!(HuggingFaceProvider::is_weight_file(
            "unet/diffusion_pytorch_model.safetensors"
        ));

        assert!(!HuggingFaceProvider::is_weight_file("tokenizer.json"));
        assert!(!HuggingFaceProvider::is_weight_file("config.json"));
        assert!(!HuggingFaceProvider::is_weight_file("README.md"));
        // Shard index files are JSON metadata, not weights.
        assert!(!HuggingFaceProvider::is_weight_file(
            "model.safetensors.index.json"
        ));
    }
}