use crate::error::{Result, TrustformersError};
use reqwest;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, SystemTime};
use tokio::fs as async_fs;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::time::{interval, sleep};
use tokio_stream::StreamExt;
use trustformers_core::errors::TrustformersError as CoreTrustformersError;
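/// Configuration for the local hub mirror: where models are stored, how often the
/// mirror syncs with the remote hub, and how downloads are retried and throttled.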
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MirrorConfig {
pub storage_path: PathBuf,
pub remote_hub_url: String,
pub sync_interval: Duration,
pub max_storage_size_gb: f64,
pub compression_enabled: bool,
pub auto_cleanup: bool,
    pub cleanup_threshold: f64,
    pub parallel_downloads: usize,
pub retry_attempts: u32,
pub retry_delay: Duration,
pub bandwidth_limit_mbps: Option<f64>,
    pub priority_models: Vec<String>,
}
impl Default for MirrorConfig {
fn default() -> Self {
Self {
storage_path: PathBuf::from("./hub_mirror"),
remote_hub_url: "https://hub.trustformers.ai".to_string(),
            sync_interval: Duration::from_secs(3600),
            max_storage_size_gb: 100.0,
compression_enabled: true,
auto_cleanup: true,
cleanup_threshold: 0.8,
parallel_downloads: 4,
retry_attempts: 3,
retry_delay: Duration::from_secs(5),
bandwidth_limit_mbps: None,
priority_models: Vec::new(),
}
}
}
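/// A model that has been downloaded into the local mirror, together with the
/// bookkeeping needed for integrity checks and LRU-style cleanup.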
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedModel {
pub model_id: String,
pub version: String,
pub local_path: PathBuf,
pub remote_url: String,
pub cached_at: SystemTime,
pub last_accessed: SystemTime,
pub access_count: u64,
pub file_size: u64,
pub checksum: String,
pub metadata: ModelMetadata,
pub is_priority: bool,
pub download_complete: bool,
}
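/// Descriptive metadata fetched from the remote hub for a specific model version.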
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
pub name: String,
pub description: Option<String>,
pub architecture: String,
pub task: String,
pub language: Option<String>,
pub license: Option<String>,
pub tags: Vec<String>,
pub performance_metrics: HashMap<String, f64>,
pub size_mb: f64,
pub dependencies: Vec<String>,
}
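/// Aggregate counters describing cache effectiveness, downloads, and sync health.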
#[derive(Debug, Clone, Default)]
pub struct MirrorStats {
pub total_models: usize,
pub total_size_gb: f64,
pub cache_hits: u64,
pub cache_misses: u64,
pub downloads_completed: u64,
pub downloads_failed: u64,
pub last_sync: Option<SystemTime>,
pub sync_errors: u64,
pub bandwidth_saved_gb: f64,
pub average_download_speed_mbps: f64,
}
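/// Progress snapshot for a single queued or in-flight model download.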
#[derive(Debug, Clone)]
pub struct DownloadProgress {
pub model_id: String,
pub version: String,
pub bytes_downloaded: u64,
pub total_bytes: u64,
pub progress_percent: f64,
pub download_speed_mbps: f64,
pub eta_seconds: u64,
pub status: DownloadStatus,
}
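/// Lifecycle state of a model download in the mirror's queue.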
#[derive(Debug, Clone, PartialEq)]
pub enum DownloadStatus {
Queued,
Downloading,
Completed,
Failed(String),
Cancelled,
}
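/// Local mirror of the remote model hub.
///
/// Models are downloaded on demand, cached under `MirrorConfig::storage_path`,
/// fingerprinted with a SHA-256 checksum, and periodically reconciled with the
/// remote hub by a background sync task. A minimal usage sketch (the model id
/// and version are illustrative, and the calls must run inside a tokio runtime):
///
/// ```ignore
/// let mut mirror = HubMirror::new(MirrorConfig::default())?;
/// mirror.initialize().await?;
/// let model_path = mirror.get_model("bert-base", Some("1.0")).await?;
/// mirror.shutdown().await?;
/// ```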
pub struct HubMirror {
config: MirrorConfig,
cache: Arc<RwLock<HashMap<String, CachedModel>>>,
stats: Arc<Mutex<MirrorStats>>,
download_queue: Arc<RwLock<HashMap<String, DownloadProgress>>>,
http_client: reqwest::Client,
sync_handle: Option<tokio::task::JoinHandle<()>>,
}
impl HubMirror {
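    /// Creates a mirror with the given configuration, preparing the storage
    /// directory and the HTTP client. No cache is loaded and no background
    /// task is started until [`HubMirror::initialize`] is called.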
pub fn new(config: MirrorConfig) -> Result<Self> {
fs::create_dir_all(&config.storage_path).map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to create storage directory: {}",
e
)))
})?;
let client_builder = reqwest::Client::builder()
.timeout(Duration::from_secs(300))
.user_agent("TrustformeRS-Mirror/1.0");
if let Some(bandwidth_limit) = config.bandwidth_limit_mbps {
tracing::info!("Bandwidth limit set to {} Mbps", bandwidth_limit);
}
let http_client = client_builder.build().map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to create HTTP client: {}",
e
)))
})?;
let mirror = Self {
config,
cache: Arc::new(RwLock::new(HashMap::new())),
stats: Arc::new(Mutex::new(MirrorStats::default())),
download_queue: Arc::new(RwLock::new(HashMap::new())),
http_client,
sync_handle: None,
};
Ok(mirror)
}
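    /// Loads the persisted cache index, starts the background sync task, and
    /// runs an initial cleanup pass if the cache is over its size threshold.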
pub async fn initialize(&mut self) -> Result<()> {
self.load_cache().await?;
self.start_background_sync().await?;
self.cleanup_if_needed().await?;
tracing::info!(
"Hub mirror initialized with {} cached models",
self.cache.read().expect("lock should not be poisoned").len()
);
Ok(())
}
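    /// Resolves a model to a local path, serving it from the cache when the
    /// file is already present and falling back to a download on a cache miss.
    /// `version` defaults to `"latest"` when `None`.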
pub async fn get_model(&self, model_id: &str, version: Option<&str>) -> Result<PathBuf> {
let version = version.unwrap_or("latest");
let cache_key = format!("{}:{}", model_id, version);
{
let mut cache = self.cache.write().expect("lock should not be poisoned");
if let Some(cached_model) = cache.get_mut(&cache_key) {
if cached_model.download_complete && cached_model.local_path.exists() {
cached_model.last_accessed = SystemTime::now();
cached_model.access_count += 1;
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.cache_hits += 1;
tracing::debug!("Cache hit for {}:{}", model_id, version);
return Ok(cached_model.local_path.clone());
}
}
}
{
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.cache_misses += 1;
}
tracing::info!("Cache miss for {}:{}, downloading...", model_id, version);
self.download_model(model_id, version).await
}
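    /// Downloads a model from the remote hub, records it in the cache index,
    /// and persists the index. If the same model is already being downloaded,
    /// this waits for that download instead of starting a second one.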
async fn download_model(&self, model_id: &str, version: &str) -> Result<PathBuf> {
let cache_key = format!("{}:{}", model_id, version);
{
let queue = self.download_queue.read().expect("lock should not be poisoned");
if let Some(progress) = queue.get(&cache_key) {
if progress.status == DownloadStatus::Downloading {
return self.wait_for_download(&cache_key).await;
}
}
}
let download_url = format!(
"{}/models/{}/versions/{}/download",
self.config.remote_hub_url, model_id, version
);
let metadata = self.fetch_model_metadata(model_id, version).await?;
{
let mut queue = self.download_queue.write().expect("lock should not be poisoned");
queue.insert(
cache_key.clone(),
DownloadProgress {
model_id: model_id.to_string(),
version: version.to_string(),
bytes_downloaded: 0,
total_bytes: (metadata.size_mb * 1024.0 * 1024.0) as u64,
progress_percent: 0.0,
download_speed_mbps: 0.0,
eta_seconds: 0,
status: DownloadStatus::Queued,
},
);
}
let local_path = self
.config
.storage_path
.join("models")
.join(model_id)
.join(version)
.join("model.safetensors");
let parent_dir = local_path.parent().ok_or_else(|| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to get parent directory for path: {}",
local_path.display()
)))
})?;
fs::create_dir_all(parent_dir).map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to create model directory: {}",
e
)))
})?;
let download_result =
self.download_with_progress(&download_url, &local_path, &cache_key).await;
match download_result {
Ok(()) => {
let file_size = local_path
.metadata()
.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to get file metadata: {}",
e
)))
})?
.len();
let checksum = self.calculate_file_hash(&local_path).await?;
let cached_model = CachedModel {
model_id: model_id.to_string(),
version: version.to_string(),
local_path: local_path.clone(),
remote_url: download_url,
cached_at: SystemTime::now(),
last_accessed: SystemTime::now(),
access_count: 1,
file_size,
checksum,
metadata,
is_priority: self.config.priority_models.contains(&model_id.to_string()),
download_complete: true,
};
{
let mut cache = self.cache.write().expect("lock should not be poisoned");
cache.insert(cache_key.clone(), cached_model);
}
{
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.downloads_completed += 1;
stats.total_models =
self.cache.read().expect("lock should not be poisoned").len();
stats.total_size_gb += file_size as f64 / (1024.0 * 1024.0 * 1024.0);
}
{
let mut queue =
self.download_queue.write().expect("lock should not be poisoned");
queue.remove(&cache_key);
}
self.save_cache().await?;
tracing::info!(
"Successfully downloaded {}:{} ({} MB)",
model_id,
version,
file_size / 1024 / 1024
);
Ok(local_path)
},
Err(e) => {
{
let mut queue =
self.download_queue.write().expect("lock should not be poisoned");
if let Some(progress) = queue.get_mut(&cache_key) {
progress.status = DownloadStatus::Failed(e.to_string());
}
}
{
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.downloads_failed += 1;
}
Err(e)
},
}
}
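    /// Wraps `attempt_download` with the configured retry policy, sleeping
    /// `retry_delay` between attempts.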
async fn download_with_progress(
&self,
url: &str,
local_path: &Path,
cache_key: &str,
) -> Result<()> {
let mut attempt = 0;
while attempt < self.config.retry_attempts {
attempt += 1;
match self.attempt_download(url, local_path, cache_key).await {
Ok(()) => return Ok(()),
Err(e) => {
if attempt < self.config.retry_attempts {
tracing::warn!(
"Download attempt {} failed for {}: {}. Retrying in {:?}...",
attempt,
cache_key,
e,
self.config.retry_delay
);
sleep(self.config.retry_delay).await;
} else {
return Err(e);
}
},
}
}
        Err(TrustformersError::Core(CoreTrustformersError::other(format!(
            "Download failed after {} attempts",
            self.config.retry_attempts
        ))))
}
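    /// Streams the model file to disk, updating the download queue entry with
    /// progress, speed, and ETA, and throttling if a bandwidth limit is set.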
async fn attempt_download(&self, url: &str, local_path: &Path, cache_key: &str) -> Result<()> {
{
let mut queue = self.download_queue.write().expect("lock should not be poisoned");
if let Some(progress) = queue.get_mut(cache_key) {
progress.status = DownloadStatus::Downloading;
}
}
let response = self.http_client.get(url).send().await.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to start download: {}",
e
)))
})?;
if !response.status().is_success() {
return Err(TrustformersError::Core(CoreTrustformersError::other(
format!("Download failed with status: {}", response.status()),
)));
}
let total_size = response.content_length().unwrap_or(0);
let mut file = async_fs::File::create(local_path).await.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to create file: {}",
e
)))
})?;
let mut stream = response.bytes_stream();
let mut downloaded = 0u64;
let start_time = std::time::Instant::now();
while let Some(chunk) = stream.next().await {
let chunk = chunk.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to read chunk: {}",
e
)))
})?;
file.write_all(&chunk).await.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to write chunk: {}",
e
)))
})?;
downloaded += chunk.len() as u64;
let elapsed = start_time.elapsed().as_secs_f64();
let speed_mbps = if elapsed > 0.0 {
(downloaded as f64 / elapsed) / (1024.0 * 1024.0)
} else {
0.0
};
let progress_percent = if total_size > 0 {
(downloaded as f64 / total_size as f64) * 100.0
} else {
0.0
};
            let eta_seconds = if speed_mbps > 0.0 && total_size > 0 {
                // saturating_sub guards against servers that under-report Content-Length
                (total_size.saturating_sub(downloaded) as f64 / (speed_mbps * 1024.0 * 1024.0))
                    as u64
} else {
0
};
{
let mut queue = self.download_queue.write().expect("lock should not be poisoned");
if let Some(progress) = queue.get_mut(cache_key) {
progress.bytes_downloaded = downloaded;
progress.progress_percent = progress_percent;
progress.download_speed_mbps = speed_mbps;
progress.eta_seconds = eta_seconds;
}
}
            // Throttle so the average rate stays at or below the configured limit
            // (interpreted as MiB/s, matching the speed calculation above): sleep
            // until real elapsed time catches up with the minimum time the limit
            // allows for the bytes written so far.
            if let Some(limit_mbps) = self.config.bandwidth_limit_mbps {
                let min_elapsed = downloaded as f64 / (limit_mbps * 1024.0 * 1024.0);
                if elapsed < min_elapsed {
                    sleep(Duration::from_secs_f64(min_elapsed - elapsed)).await;
                }
            }
}
file.flush().await.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to flush file: {}",
e
)))
})?;
Ok(())
}
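    /// Polls the download queue until the in-flight download for `cache_key`
    /// finishes, then returns the cached path or the recorded failure.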
async fn wait_for_download(&self, cache_key: &str) -> Result<PathBuf> {
let mut interval = interval(Duration::from_millis(100));
loop {
interval.tick().await;
let queue = self.download_queue.read().expect("lock should not be poisoned");
if let Some(progress) = queue.get(cache_key) {
match &progress.status {
DownloadStatus::Completed => {
drop(queue);
let cache = self.cache.read().expect("lock should not be poisoned");
if let Some(cached_model) = cache.get(cache_key) {
return Ok(cached_model.local_path.clone());
} else {
return Err(TrustformersError::Core(CoreTrustformersError::other(
"Download completed but model not in cache".to_string(),
)));
}
},
DownloadStatus::Failed(error) => {
return Err(TrustformersError::Core(CoreTrustformersError::other(
format!("Download failed: {}", error),
)));
},
DownloadStatus::Cancelled => {
return Err(TrustformersError::Core(CoreTrustformersError::other(
"Download was cancelled".to_string(),
)));
},
_ => {
continue;
},
}
} else {
let cache = self.cache.read().expect("lock should not be poisoned");
if let Some(cached_model) = cache.get(cache_key) {
return Ok(cached_model.local_path.clone());
} else {
return Err(TrustformersError::Core(CoreTrustformersError::other(
"Download not found in queue or cache".to_string(),
)));
}
}
}
}
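    /// Fetches and deserializes the metadata document for a model version from
    /// the remote hub.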
async fn fetch_model_metadata(&self, model_id: &str, version: &str) -> Result<ModelMetadata> {
let metadata_url = format!(
"{}/models/{}/versions/{}/metadata",
self.config.remote_hub_url, model_id, version
);
let response = self.http_client.get(&metadata_url).send().await.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to fetch metadata: {}",
e
)))
})?;
if !response.status().is_success() {
return Err(TrustformersError::Core(CoreTrustformersError::other(
format!("Failed to fetch metadata: {}", response.status()),
)));
}
let metadata: ModelMetadata = response.json().await.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to parse metadata: {}",
e
)))
})?;
Ok(metadata)
}
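    /// Spawns the background task that periodically reconciles the local cache
    /// with the remote hub at `sync_interval`.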
async fn start_background_sync(&mut self) -> Result<()> {
let cache = self.cache.clone();
let stats = self.stats.clone();
let config = self.config.clone();
let http_client = self.http_client.clone();
let handle = tokio::spawn(async move {
let mut sync_interval = interval(config.sync_interval);
loop {
sync_interval.tick().await;
if let Err(e) = Self::sync_with_remote(&cache, &stats, &config, &http_client).await
{
tracing::error!("Background sync failed: {}", e);
let mut stats_lock = stats.lock().expect("lock should not be poisoned");
stats_lock.sync_errors += 1;
}
}
});
self.sync_handle = Some(handle);
Ok(())
}
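    /// One sync pass: fetches the remote model list, logs cached models that
    /// have a newer remote version, and records the sync time in the stats.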
async fn sync_with_remote(
cache: &Arc<RwLock<HashMap<String, CachedModel>>>,
stats: &Arc<Mutex<MirrorStats>>,
config: &MirrorConfig,
http_client: &reqwest::Client,
) -> Result<()> {
tracing::info!("Starting background sync with remote hub");
let models_url = format!("{}/models", config.remote_hub_url);
let response = http_client.get(&models_url).send().await.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to fetch models list: {}",
e
)))
})?;
if !response.status().is_success() {
return Err(TrustformersError::Core(CoreTrustformersError::other(
format!("Failed to fetch models list: {}", response.status()),
)));
}
let remote_models: Vec<RemoteModelInfo> = response.json().await.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to parse models list: {}",
e
)))
})?;
let mut updates_found = 0;
{
let cache_read = cache.read().expect("lock should not be poisoned");
for cached_model in cache_read.values() {
if let Some(remote_model) =
remote_models.iter().find(|m| m.model_id == cached_model.model_id)
{
if remote_model.latest_version != cached_model.version {
tracing::info!(
"Update available for {}: {} -> {}",
cached_model.model_id,
cached_model.version,
remote_model.latest_version
);
updates_found += 1;
}
}
}
}
{
let mut stats_lock = stats.lock().expect("lock should not be poisoned");
stats_lock.last_sync = Some(SystemTime::now());
}
tracing::info!("Sync completed. Found {} updates available", updates_found);
Ok(())
}
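    /// Triggers a cache cleanup when the cached data exceeds
    /// `max_storage_size_gb * cleanup_threshold`.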
async fn cleanup_if_needed(&self) -> Result<()> {
let current_size = self.calculate_total_size().await?;
let max_size = self.config.max_storage_size_gb * 1024.0 * 1024.0 * 1024.0;
if current_size > max_size * self.config.cleanup_threshold {
tracing::info!(
"Cache cleanup triggered. Current size: {:.2} GB, Max: {:.2} GB",
current_size / (1024.0 * 1024.0 * 1024.0),
self.config.max_storage_size_gb
);
self.cleanup_cache().await?;
}
Ok(())
}
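    /// Evicts least-recently-used, non-priority models until the cache drops to
    /// roughly 70% of the configured maximum size.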
async fn cleanup_cache(&self) -> Result<()> {
let mut models_to_remove = Vec::new();
let max_size = self.config.max_storage_size_gb * 1024.0 * 1024.0 * 1024.0;
let target_size = max_size * 0.7;
{
let cache = self.cache.read().expect("lock should not be poisoned");
let mut cache_items: Vec<_> = cache.values().collect();
cache_items.sort_by(|a, b| {
match (a.is_priority, b.is_priority) {
(true, false) => std::cmp::Ordering::Greater,
(false, true) => std::cmp::Ordering::Less,
                    _ => a.last_accessed.cmp(&b.last_accessed),
                }
});
            // Compute the current size from the guard we already hold; calling
            // calculate_total_size() here would re-lock the cache and hold this
            // read guard across an await point.
            let mut current_size: f64 = cache.values().map(|m| m.file_size as f64).sum();
for model in cache_items {
if current_size <= target_size || model.is_priority {
break;
}
models_to_remove.push((model.model_id.clone(), model.version.clone()));
current_size -= model.file_size as f64;
}
}
for (model_id, version) in models_to_remove {
self.remove_model(&model_id, &version).await?;
tracing::info!("Removed {}:{} during cleanup", model_id, version);
}
Ok(())
}
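    /// Removes a cached model's file and its cache entry, then persists the
    /// updated index.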
pub async fn remove_model(&self, model_id: &str, version: &str) -> Result<()> {
let cache_key = format!("{}:{}", model_id, version);
let local_path = {
let cache = self.cache.read().expect("lock should not be poisoned");
cache.get(&cache_key).map(|m| m.local_path.clone())
};
if let Some(path) = local_path {
if path.exists() {
async_fs::remove_file(&path).await.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to remove file: {}",
e
)))
})?;
}
{
let mut cache = self.cache.write().expect("lock should not be poisoned");
cache.remove(&cache_key);
}
{
let mut stats = self.stats.lock().expect("lock should not be poisoned");
stats.total_models = self.cache.read().expect("lock should not be poisoned").len();
}
self.save_cache().await?;
}
Ok(())
}
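    /// Returns a snapshot of the mirror's statistics.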
pub fn get_stats(&self) -> MirrorStats {
self.stats.lock().expect("lock should not be poisoned").clone()
}
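    /// Returns progress snapshots for all queued and in-flight downloads.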
pub fn get_download_progress(&self) -> Vec<DownloadProgress> {
self.download_queue
.read()
.expect("lock should not be poisoned")
.values()
.cloned()
.collect()
}
async fn calculate_total_size(&self) -> Result<f64> {
let cache = self.cache.read().expect("lock should not be poisoned");
Ok(cache.values().map(|m| m.file_size as f64).sum())
}
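    /// Computes the SHA-256 digest of a file by streaming it in 8 KiB chunks.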
async fn calculate_file_hash(&self, path: &Path) -> Result<String> {
let mut file = async_fs::File::open(path).await.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to open file for hashing: {}",
e
)))
})?;
let mut hasher = Sha256::new();
let mut buffer = [0u8; 8192];
loop {
let n = file.read(&mut buffer).await.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to read file for hashing: {}",
e
)))
})?;
if n == 0 {
break;
}
hasher.update(&buffer[..n]);
}
Ok(format!("{:x}", hasher.finalize()))
}
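    /// Loads the cache index from `cache.json`, dropping entries whose files no
    /// longer exist on disk.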
async fn load_cache(&self) -> Result<()> {
let cache_file = self.config.storage_path.join("cache.json");
if cache_file.exists() {
let content = async_fs::read_to_string(&cache_file).await.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to read cache file: {}",
e
)))
})?;
let cached_models: HashMap<String, CachedModel> = serde_json::from_str(&content)
.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to parse cache file: {}",
e
)))
})?;
let mut valid_models = HashMap::new();
for (key, model) in cached_models {
if model.local_path.exists() {
valid_models.insert(key, model);
}
}
{
let mut cache = self.cache.write().expect("lock should not be poisoned");
*cache = valid_models;
}
tracing::info!(
"Loaded {} models from cache",
self.cache.read().expect("lock should not be poisoned").len()
);
}
Ok(())
}
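    /// Persists the cache index to `cache.json` in the storage directory.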
async fn save_cache(&self) -> Result<()> {
let cache_file = self.config.storage_path.join("cache.json");
let cache = self.cache.read().expect("lock should not be poisoned");
let content = serde_json::to_string_pretty(&*cache).map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to serialize cache: {}",
e
)))
})?;
async_fs::write(&cache_file, content).await.map_err(|e| {
TrustformersError::Core(CoreTrustformersError::other(format!(
"Failed to write cache file: {}",
e
)))
})?;
Ok(())
}
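    /// Stops the background sync task and persists the cache index.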
pub async fn shutdown(&mut self) -> Result<()> {
if let Some(handle) = self.sync_handle.take() {
handle.abort();
}
self.save_cache().await?;
tracing::info!("Hub mirror shut down gracefully");
Ok(())
}
}
#[derive(Debug, Deserialize)]
#[allow(dead_code)] // size_mb and updated_at come from the hub payload but are not consumed yet
struct RemoteModelInfo {
model_id: String,
latest_version: String,
size_mb: f64,
updated_at: String,
}
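/// Process-wide mirror instance used by the free helper functions below.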
static HUB_MIRROR: std::sync::OnceLock<Arc<tokio::sync::Mutex<HubMirror>>> =
std::sync::OnceLock::new();
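/// Initializes the process-wide hub mirror. Fails if it was already initialized.
///
/// A minimal usage sketch (the model id is illustrative):
///
/// ```ignore
/// init_hub_mirror(MirrorConfig::default()).await?;
/// let path = get_model_from_mirror("bert-base", None).await?;
/// ```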
pub async fn init_hub_mirror(config: MirrorConfig) -> Result<()> {
let mut mirror = HubMirror::new(config)?;
mirror.initialize().await?;
HUB_MIRROR.set(Arc::new(tokio::sync::Mutex::new(mirror))).map_err(|_| {
TrustformersError::Core(CoreTrustformersError::other(
"Hub mirror already initialized".to_string(),
))
})?;
Ok(())
}
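/// Returns a handle to the global mirror, or an error if it has not been initialized.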
pub fn get_hub_mirror() -> Result<Arc<tokio::sync::Mutex<HubMirror>>> {
HUB_MIRROR.get().cloned().ok_or_else(|| {
TrustformersError::Core(CoreTrustformersError::other(
"Hub mirror not initialized".to_string(),
))
})
}
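/// Resolves a model through the global mirror, downloading it on a cache miss.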
pub async fn get_model_from_mirror(model_id: &str, version: Option<&str>) -> Result<PathBuf> {
let mirror = get_hub_mirror()?;
let mirror_lock = mirror.lock().await;
mirror_lock.get_model(model_id, version).await
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[tokio::test]
async fn test_mirror_creation() {
let temp_dir = TempDir::new().expect("failed to create temp dir");
let config = MirrorConfig {
storage_path: temp_dir.path().to_path_buf(),
..Default::default()
};
let mirror = HubMirror::new(config).expect("operation failed in test");
assert_eq!(
mirror.cache.read().expect("lock should not be poisoned").len(),
0
);
}
#[tokio::test]
async fn test_cache_operations() {
let temp_dir = TempDir::new().expect("failed to create temp dir");
let config = MirrorConfig {
storage_path: temp_dir.path().to_path_buf(),
..Default::default()
};
let mirror = HubMirror::new(config).expect("operation failed in test");
mirror.save_cache().await.expect("async operation failed");
mirror.load_cache().await.expect("async operation failed");
}
#[test]
fn test_download_progress() {
let progress = DownloadProgress {
model_id: "test-model".to_string(),
version: "1.0".to_string(),
bytes_downloaded: 50,
total_bytes: 100,
progress_percent: 50.0,
download_speed_mbps: 10.0,
eta_seconds: 5,
status: DownloadStatus::Downloading,
};
assert_eq!(progress.progress_percent, 50.0);
assert_eq!(progress.status, DownloadStatus::Downloading);
}
#[test]
fn test_mirror_config() {
let config = MirrorConfig::default();
assert_eq!(config.max_storage_size_gb, 100.0);
assert_eq!(config.parallel_downloads, 4);
assert!(config.auto_cleanup);
}
#[test]
fn test_mirror_config_default_values() {
let config = MirrorConfig::default();
assert_eq!(config.storage_path, PathBuf::from("./hub_mirror"));
assert_eq!(config.remote_hub_url, "https://hub.trustformers.ai");
assert_eq!(config.sync_interval, Duration::from_secs(3600));
assert!(config.compression_enabled);
assert!((config.cleanup_threshold - 0.8).abs() < f64::EPSILON);
assert_eq!(config.retry_attempts, 3);
assert_eq!(config.retry_delay, Duration::from_secs(5));
assert!(config.bandwidth_limit_mbps.is_none());
assert!(config.priority_models.is_empty());
}
#[test]
fn test_mirror_stats_default() {
let stats = MirrorStats::default();
assert_eq!(stats.total_models, 0);
assert!((stats.total_size_gb - 0.0).abs() < f64::EPSILON);
assert_eq!(stats.cache_hits, 0);
assert_eq!(stats.cache_misses, 0);
assert_eq!(stats.downloads_completed, 0);
assert_eq!(stats.downloads_failed, 0);
assert!(stats.last_sync.is_none());
}
#[test]
fn test_download_status_equality() {
assert_eq!(DownloadStatus::Queued, DownloadStatus::Queued);
assert_eq!(DownloadStatus::Downloading, DownloadStatus::Downloading);
assert_eq!(DownloadStatus::Completed, DownloadStatus::Completed);
assert_ne!(DownloadStatus::Queued, DownloadStatus::Completed);
}
#[test]
fn test_download_status_failed_variant() {
let status = DownloadStatus::Failed("timeout".to_string());
assert_ne!(status, DownloadStatus::Completed);
if let DownloadStatus::Failed(msg) = &status {
assert_eq!(msg, "timeout");
}
}
#[test]
fn test_download_progress_creation() {
let progress = DownloadProgress {
model_id: "gpt2".to_string(),
version: "1.0".to_string(),
bytes_downloaded: 100_000,
total_bytes: 500_000,
progress_percent: 20.0,
download_speed_mbps: 50.0,
eta_seconds: 8,
status: DownloadStatus::Downloading,
};
assert_eq!(progress.model_id, "gpt2");
assert!((progress.progress_percent - 20.0).abs() < f64::EPSILON);
}
#[test]
fn test_model_metadata_creation() {
let metadata = ModelMetadata {
name: "BERT Base".to_string(),
description: Some("BERT base uncased model".to_string()),
architecture: "transformer".to_string(),
task: "text-classification".to_string(),
language: Some("en".to_string()),
license: Some("Apache-2.0".to_string()),
tags: vec!["nlp".to_string(), "transformer".to_string()],
performance_metrics: {
let mut m = HashMap::new();
m.insert("accuracy".to_string(), 0.9);
m
},
size_mb: 438.0,
dependencies: vec![],
};
assert_eq!(metadata.name, "BERT Base");
assert!(metadata.description.is_some());
assert_eq!(metadata.tags.len(), 2);
assert!((metadata.size_mb - 438.0).abs() < f64::EPSILON);
}
#[test]
fn test_cached_model_creation() {
let cached = CachedModel {
model_id: "test_model".to_string(),
version: "1.0".to_string(),
local_path: PathBuf::from("/tmp/test_model"),
remote_url: "https://example.com/model".to_string(),
cached_at: SystemTime::now(),
last_accessed: SystemTime::now(),
access_count: 5,
file_size: 1_000_000,
checksum: "sha256_hash".to_string(),
metadata: ModelMetadata {
name: "Test".to_string(),
description: None,
architecture: "test".to_string(),
task: "test".to_string(),
language: None,
license: None,
tags: vec![],
performance_metrics: HashMap::new(),
size_mb: 1.0,
dependencies: vec![],
},
is_priority: false,
download_complete: true,
};
assert_eq!(cached.access_count, 5);
assert!(cached.download_complete);
}
#[test]
fn test_mirror_config_custom() {
let config = MirrorConfig {
storage_path: PathBuf::from("/data/models"),
remote_hub_url: "https://custom-hub.example.com".to_string(),
sync_interval: Duration::from_secs(7200),
max_storage_size_gb: 500.0,
compression_enabled: false,
auto_cleanup: false,
cleanup_threshold: 0.95,
parallel_downloads: 8,
retry_attempts: 5,
retry_delay: Duration::from_secs(10),
bandwidth_limit_mbps: Some(100.0),
priority_models: vec!["model_a".to_string(), "model_b".to_string()],
};
assert_eq!(config.parallel_downloads, 8);
assert_eq!(config.priority_models.len(), 2);
assert!(config.bandwidth_limit_mbps.is_some());
}
#[test]
fn test_download_progress_complete() {
let progress = DownloadProgress {
model_id: "model_x".to_string(),
version: "2.0".to_string(),
bytes_downloaded: 1000,
total_bytes: 1000,
progress_percent: 100.0,
download_speed_mbps: 25.0,
eta_seconds: 0,
status: DownloadStatus::Completed,
};
assert_eq!(progress.bytes_downloaded, progress.total_bytes);
assert!((progress.progress_percent - 100.0).abs() < f64::EPSILON);
}
#[test]
fn test_download_progress_queued() {
let progress = DownloadProgress {
model_id: "model_y".to_string(),
version: "1.0".to_string(),
bytes_downloaded: 0,
total_bytes: 500_000_000,
progress_percent: 0.0,
download_speed_mbps: 0.0,
eta_seconds: 0,
status: DownloadStatus::Queued,
};
assert_eq!(progress.bytes_downloaded, 0);
assert_eq!(progress.status, DownloadStatus::Queued);
}
#[test]
fn test_model_metadata_no_optional_fields() {
let metadata = ModelMetadata {
name: "Minimal".to_string(),
description: None,
architecture: "cnn".to_string(),
task: "classification".to_string(),
language: None,
license: None,
tags: vec![],
performance_metrics: HashMap::new(),
size_mb: 10.0,
dependencies: vec![],
};
assert!(metadata.description.is_none());
assert!(metadata.language.is_none());
assert!(metadata.license.is_none());
}
#[test]
fn test_model_metadata_with_dependencies() {
let metadata = ModelMetadata {
name: "Complex".to_string(),
description: None,
architecture: "transformer".to_string(),
task: "generation".to_string(),
language: None,
license: None,
tags: vec![],
performance_metrics: HashMap::new(),
size_mb: 100.0,
dependencies: vec!["tokenizer-v2".to_string(), "vocab-en".to_string()],
};
assert_eq!(metadata.dependencies.len(), 2);
}
#[test]
fn test_mirror_stats_with_activity() {
let stats = MirrorStats {
total_models: 25,
total_size_gb: 50.5,
cache_hits: 1000,
cache_misses: 200,
downloads_completed: 30,
downloads_failed: 2,
last_sync: Some(SystemTime::now()),
sync_errors: 1,
bandwidth_saved_gb: 100.0,
average_download_speed_mbps: 50.0,
};
let hit_rate = stats.cache_hits as f64 / (stats.cache_hits + stats.cache_misses) as f64;
assert!(hit_rate > 0.8);
}
#[test]
fn test_cached_model_priority_flag() {
let cached = CachedModel {
model_id: "priority_model".to_string(),
version: "1.0".to_string(),
local_path: PathBuf::from("/tmp/priority"),
remote_url: "https://example.com".to_string(),
cached_at: SystemTime::now(),
last_accessed: SystemTime::now(),
access_count: 0,
file_size: 100,
checksum: "hash".to_string(),
metadata: ModelMetadata {
name: "Priority".to_string(),
description: None,
architecture: "test".to_string(),
task: "test".to_string(),
language: None,
license: None,
tags: vec![],
performance_metrics: HashMap::new(),
size_mb: 0.1,
dependencies: vec![],
},
is_priority: true,
download_complete: true,
};
assert!(cached.is_priority);
}
#[tokio::test]
async fn test_mirror_save_load_roundtrip() {
let temp_dir = TempDir::new().expect("failed to create temp dir");
let config = MirrorConfig {
storage_path: temp_dir.path().to_path_buf(),
..Default::default()
};
let mirror = HubMirror::new(config).expect("operation failed in test");
mirror.save_cache().await.expect("save failed");
mirror.load_cache().await.expect("load failed");
let cache = mirror.cache.read().expect("lock should not be poisoned");
assert_eq!(cache.len(), 0);
}
#[tokio::test]
async fn test_mirror_initial_stats() {
let temp_dir = TempDir::new().expect("failed to create temp dir");
let config = MirrorConfig {
storage_path: temp_dir.path().to_path_buf(),
..Default::default()
};
let mirror = HubMirror::new(config).expect("operation failed in test");
let stats = mirror.stats.lock().expect("lock should not be poisoned");
assert_eq!(stats.total_models, 0);
assert_eq!(stats.cache_hits, 0);
}
}