use anyhow::Result;
use std::path::PathBuf;
/// Interface for back-ends that manage model files, plus filename-classification
/// helpers shared by all implementors.
///
/// Implementors must be `Send + Sync`. The `async_trait` attribute allows the
/// `async fn` declarations below.
///
/// NOTE(review): the contracts of the async methods are inferred from their
/// signatures only; confirm the details against the concrete implementations.
#[async_trait::async_trait]
pub trait ModelProviderTrait: Send + Sync {
    /// Resolves to a filesystem path for `model_name` on success.
    ///
    /// * `model_name` - identifier of the model, in whatever form the provider uses.
    /// * `cache_path` - optional location for the files; presumably a
    ///   provider-specific default is used when `None` (TODO confirm).
    /// * `ignore_weights` - presumably skips files for which
    ///   [`is_weight_file`](Self::is_weight_file) returns `true` (TODO confirm).
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    async fn download_model(
        &self,
        model_name: &str,
        cache_path: Option<PathBuf>,
        ignore_weights: bool,
    ) -> Result<PathBuf>;

    /// Deletes `model_name` with respect to `cache_dir`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    async fn delete_model(&self, model_name: &str, cache_dir: PathBuf) -> Result<()>;

    /// Resolves to the path associated with `model_name` under `cache_dir`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports.
    async fn get_model_path(&self, model_name: &str, cache_dir: PathBuf) -> Result<PathBuf>;

    /// Returns a static name identifying this provider.
    fn provider_name(&self) -> &'static str;

    /// Returns `true` when `filename` should be ignored.
    ///
    /// Only the final path component is inspected: it is ignored when it starts
    /// with `.` or is exactly `README.md`. If the path has no final component
    /// (for example it is empty or ends in `..`), the whole input string is
    /// tested instead.
    ///
    /// The `Self: Sized` bound excludes this receiver-less function from trait
    /// objects.
    fn is_ignored(filename: &str) -> bool
    where
        Self: Sized,
    {
        const DEFAULT_IGNORED: [&str; 1] = ["README.md"];

        let name = std::path::Path::new(filename)
            .file_name()
            .and_then(|component| component.to_str())
            .unwrap_or(filename);

        name.starts_with('.') || DEFAULT_IGNORED.contains(&name)
    }

    /// Returns `true` when the extension of `path` is a known image extension,
    /// compared ASCII-case-insensitively.
    ///
    /// Paths without an extension are never images.
    fn is_image(path: &std::path::Path) -> bool
    where
        Self: Sized,
    {
        const IMAGE_EXTENSIONS: [&str; 10] = [
            "png", "jpg", "jpeg", "gif", "webp", "svg", "ico", "bmp", "tiff", "tif",
        ];

        path.extension().is_some_and(|ext| {
            IMAGE_EXTENSIONS
                .iter()
                .any(|&known| ext.eq_ignore_ascii_case(known))
        })
    }

    /// Returns `true` when `filename` ends with a known model-weight suffix.
    ///
    /// The comparison is ASCII-case-insensitive, consistent with
    /// [`is_image`](Self::is_image), so `MODEL.SAFETENSORS` is recognised just
    /// like `model.safetensors`.
    fn is_weight_file(filename: &str) -> bool
    where
        Self: Sized,
    {
        const WEIGHT_SUFFIXES: [&str; 5] =
            [".bin", ".safetensors", ".h5", ".msgpack", ".ckpt.index"];

        // Compare on bytes: slicing the `&str` at this offset could land inside a
        // multi-byte character and panic, whereas byte slices have no such limit.
        let name = filename.as_bytes();
        WEIGHT_SUFFIXES.iter().any(|suffix| {
            let suffix = suffix.as_bytes();
            name.len() >= suffix.len()
                && name[name.len() - suffix.len()..].eq_ignore_ascii_case(suffix)
        })
    }
}
pub mod huggingface;
pub use huggingface::HuggingFaceProvider;
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// `is_image` accepts every supported extension regardless of case and
    /// rejects everything else, including names without an extension.
    #[test]
    fn test_is_image_function() {
        let images = [
            "test.png",
            "test.PNG",
            "test.jpg",
            "test.JPG",
            "test.jpeg",
            "test.JPEG",
            "test.gif",
            "test.webp",
            "test.svg",
            "test.ico",
            "test.bmp",
            "test.tiff",
            "test.tif",
            "assets/logo.WebP",
        ];
        for name in images {
            assert!(
                HuggingFaceProvider::is_image(Path::new(name)),
                "expected `{name}` to be detected as an image"
            );
        }

        // `.png` is a dotfile whose whole name is the stem, so it has no extension.
        let non_images = ["test.txt", "test.py", "test", "test.model", ".png"];
        for name in non_images {
            assert!(
                !HuggingFaceProvider::is_image(Path::new(name)),
                "expected `{name}` not to be detected as an image"
            );
        }
    }

    /// Dotfiles and `README.md` are ignored at any directory depth; regular
    /// model files are not.
    #[test]
    fn test_ignored_files() {
        let ignored = [
            ".gitattributes",
            ".gitignore",
            ".gitkeep",
            ".hidden",
            "subdir/.gitkeep",
            "a/b/.hidden",
            "README.md",
            "subdir/README.md",
        ];
        for name in ignored {
            assert!(
                HuggingFaceProvider::is_ignored(name),
                "expected `{name}` to be ignored"
            );
        }

        let kept = ["model.bin", "tokenizer.json", "config.json"];
        for name in kept {
            assert!(
                !HuggingFaceProvider::is_ignored(name),
                "expected `{name}` not to be ignored"
            );
        }
    }

    /// Files ending in a known weight suffix are weights, including inside
    /// subdirectories; look-alike names are not.
    #[test]
    fn test_is_weight_file() {
        let weights = [
            "model.bin",
            "model.safetensors",
            "model.h5",
            "model.msgpack",
            "model.ckpt.index",
            "subdir/model.safetensors",
        ];
        for name in weights {
            assert!(
                HuggingFaceProvider::is_weight_file(name),
                "expected `{name}` to be detected as a weight file"
            );
        }

        let non_weights = [
            "tokenizer.json",
            "config.json",
            "README.md",
            "model.bin.txt",
            "model.ckpt",
        ];
        for name in non_weights {
            assert!(
                !HuggingFaceProvider::is_weight_file(name),
                "expected `{name}` not to be detected as a weight file"
            );
        }
    }
}