pub mod cache;
pub mod downloader;
pub mod registry;
pub mod verification;
pub use cache::{CacheConfig, ModelCache};
pub use downloader::{DownloadError, DownloadProgress, ModelDownloader};
pub use registry::{ModelInfo, ModelRegistry, ModelSource};
pub use verification::{Checksum, ModelVerifier};
use crate::error::RusTorchResult;
use crate::model_import::ImportedModel;
use std::path::PathBuf;
/// Entry point for obtaining models by name.
///
/// Model names are resolved through a [`ModelRegistry`]. Files not already cached are
/// downloaded, optionally checksum-verified, stored in a [`ModelCache`], and then imported.
pub struct ModelHub {
// Downloads model files from the URLs listed in the registry, reporting progress.
downloader: ModelDownloader,
// Maps model names to their metadata (download URL, optional checksum).
registry: ModelRegistry,
// On-disk store of downloaded models; also decides where new downloads are written.
cache: ModelCache,
// Checks downloaded files against the registry's expected checksum.
verifier: ModelVerifier,
}
impl ModelHub {
    /// Creates a hub whose model cache lives under the platform cache directory.
    ///
    /// The cache root is `<cache_dir>/rustorch/models`. `<cache_dir>` comes from
    /// `dirs::cache_dir()`, falling back to [`std::env::temp_dir`] when the platform
    /// reports no cache directory.
    ///
    /// # Errors
    ///
    /// Returns an error if the [`ModelCache`] cannot be initialised at that location.
    pub fn new() -> RusTorchResult<Self> {
        let cache_dir = dirs::cache_dir()
            .unwrap_or_else(std::env::temp_dir)
            .join("rustorch")
            .join("models");
        Self::with_cache_dir(cache_dir)
    }

    /// Creates a hub that caches models in `cache_dir`.
    ///
    /// This is the single construction path; [`ModelHub::new`] delegates here.
    ///
    /// # Errors
    ///
    /// Returns an error if the [`ModelCache`] cannot be initialised at `cache_dir`.
    pub fn with_cache_dir<P: Into<PathBuf>>(cache_dir: P) -> RusTorchResult<Self> {
        Ok(Self {
            downloader: ModelDownloader::new(),
            registry: ModelRegistry::default(),
            cache: ModelCache::new(cache_dir)?,
            verifier: ModelVerifier::new(),
        })
    }

    /// Loads `model_name`, downloading and caching it first if necessary.
    ///
    /// Resolution order:
    /// 1. If the cache knows a path for the model and that file exists, it is imported
    ///    directly. No registry lookup or network access happens in this case.
    /// 2. Otherwise:
    ///    - the model is looked up in the registry;
    ///    - it is downloaded to the cache's download path, with progress printed to stdout;
    ///    - it is verified against the registry checksum when one is provided;
    ///    - it is stored in the cache, and the cached copy is imported.
    ///
    /// Imports go through the PyTorch importer.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the model is not in the registry;
    /// - the download fails;
    /// - checksum verification fails (the downloaded file is deleted first);
    /// - caching fails;
    /// - the import fails.
    pub async fn load_model(&mut self, model_name: &str) -> RusTorchResult<ImportedModel> {
        if let Some(cached_path) = self.cache.get_model_path(model_name) {
            // The cache index can outlive the file itself, so confirm it is still on disk.
            if cached_path.exists() {
                println!("Loading cached model: {}", model_name);
                return crate::model_import::pytorch::import_pytorch_model(&cached_path);
            }
        }

        let model_info = self.registry.get_model_info(model_name)?;
        println!("Downloading model: {} from {}", model_name, model_info.url);

        let download_path = self.cache.get_download_path(model_name);
        self.downloader
            .download_with_progress(&model_info.url, &download_path, |progress| {
                println!("Download progress: {:.1}%", progress.percentage());
            })
            .await?;

        if let Some(expected_hash) = &model_info.checksum {
            println!("Verifying model integrity...");
            self.verifier
                .verify_file(&download_path, expected_hash)
                .map_err(|err| {
                    // A file that fails verification is corrupt or tampered with. Remove it
                    // so it is never mistaken for a good download. Removal is best-effort;
                    // the verification error is what the caller needs to see.
                    let _ = std::fs::remove_file(&download_path);
                    err
                })?;
            println!("Model verification successful");
        }

        let cached_path = self.cache.cache_model(model_name, &download_path)?;
        println!("Importing model...");
        crate::model_import::pytorch::import_pytorch_model(&cached_path)
    }

    /// Returns the names of all models known to the registry.
    pub fn list_models(&self) -> Vec<&str> {
        self.registry.list_models()
    }

    /// Returns the registry metadata for `model_name`.
    ///
    /// # Errors
    ///
    /// Propagates the registry's lookup error. Presumably this happens when the name is
    /// not registered.
    pub fn get_model_info(&self, model_name: &str) -> RusTorchResult<&ModelInfo> {
        self.registry.get_model_info(model_name)
    }

    /// Clears the model cache by delegating to the cache's `clear`.
    ///
    /// # Errors
    ///
    /// Propagates any error reported by the cache.
    pub fn clear_cache(&mut self) -> RusTorchResult<()> {
        self.cache.clear()
    }

    /// Returns the cache's statistics tuple.
    ///
    /// Presumably this is `(model_count, total_bytes)`; confirm against `ModelCache::stats`.
    pub fn cache_stats(&self) -> (usize, u64) {
        self.cache.stats()
    }
}
impl Default for ModelHub {
    /// Builds a hub in the default cache location via [`ModelHub::new`].
    ///
    /// # Panics
    ///
    /// Panics if [`ModelHub::new`] returns an error, for example when the model cache
    /// cannot be initialised. Call [`ModelHub::new`] directly to handle that case.
    fn default() -> Self {
        match Self::new() {
            Ok(hub) => hub,
            Err(err) => panic!("Failed to create default ModelHub: {:?}", err),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_model_hub_creation() {
        // Deliberately exercises the default-location constructor, so this one test does
        // touch the real user cache directory.
        let hub = ModelHub::new();
        assert!(hub.is_ok());
    }

    #[test]
    fn test_model_hub_with_custom_cache() {
        let temp_dir = TempDir::new().unwrap();
        let hub = ModelHub::with_cache_dir(temp_dir.path());
        assert!(hub.is_ok());
    }

    #[test]
    fn test_list_models() {
        // Use a throwaway cache dir so the test neither depends on nor modifies the
        // user's real model cache.
        let temp_dir = TempDir::new().unwrap();
        let hub = ModelHub::with_cache_dir(temp_dir.path()).unwrap();
        let models = hub.list_models();
        assert!(!models.is_empty());
        assert!(models.contains(&"resnet18"));
        assert!(models.contains(&"resnet50"));
    }

    #[tokio::test]
    async fn test_model_loading_mock() {
        let temp_dir = TempDir::new().unwrap();
        let mut hub = ModelHub::with_cache_dir(temp_dir.path()).unwrap();
        assert!(hub.list_models().contains(&"resnet18"));
        assert!(hub.get_model_info("resnet18").is_ok());
        // With an empty cache, an unregistered name fails at the registry lookup,
        // before any download is attempted. No network access is needed.
        assert!(hub.load_model("__no_such_model__").await.is_err());
    }
}