use crate::MobileConfig;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use trustformers_core::error::{CoreError, Result};
use trustformers_core::TrustformersError;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelManagerConfig {
pub update_server_url: String,
pub api_key: Option<String>,
pub storage_directory: PathBuf,
pub max_storage_size_mb: usize,
pub enable_auto_updates: bool,
pub update_check_interval_seconds: u64,
pub enable_differential_updates: bool,
pub require_signature_verification: bool,
pub download_timeout_seconds: u64,
pub max_concurrent_downloads: usize,
pub enable_compression: bool,
pub download_retry_attempts: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
pub model_id: String,
pub version: String,
pub model_type: String,
pub size_bytes: usize,
pub checksum: String,
pub signature: Option<String>,
pub download_url: String,
pub differential_url: Option<String>,
pub description: String,
pub required_config: MobileConfig,
pub compatibility: ModelCompatibility,
pub release_timestamp: u64,
pub deprecation_timestamp: Option<u64>,
pub tags: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCompatibility {
pub min_android_api: Option<u32>,
pub min_ios_version: Option<String>,
pub required_features: Vec<String>,
pub min_memory_mb: usize,
pub supported_architectures: Vec<String>,
}
#[derive(Debug, Clone)]
pub struct DownloadProgress {
pub model_id: String,
pub total_bytes: usize,
pub downloaded_bytes: usize,
pub download_speed_bps: f64,
pub eta_seconds: f64,
pub status: DownloadStatus,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DownloadStatus {
Pending,
Downloading,
Verifying,
Installing,
Completed,
Failed(String),
Cancelled,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelUpdate {
pub current: ModelMetadata,
pub update: ModelMetadata,
pub update_type: UpdateType,
pub priority: UpdatePriority,
pub update_size_bytes: usize,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UpdateType {
Full,
Differential,
ConfigOnly,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UpdatePriority {
Critical,
High,
Normal,
Low,
}
pub struct ModelManager {
config: ModelManagerConfig,
models: HashMap<String, ModelMetadata>,
downloads: HashMap<String, DownloadProgress>,
storage_stats: StorageStats,
last_update_check: Option<SystemTime>,
http_client: Option<Box<dyn HttpClient>>,
}
#[derive(Debug, Clone)]
pub struct StorageStats {
total_size_bytes: usize,
used_size_bytes: usize,
model_count: usize,
last_cleanup_time: SystemTime,
}
pub trait HttpClient: Send + Sync {
fn download_file(&self, url: &str, destination: &Path) -> Result<()>;
fn download_with_progress(
&self,
url: &str,
destination: &Path,
progress_callback: Box<dyn Fn(usize, usize) + Send + Sync>,
) -> Result<()>;
fn get_metadata(&self, url: &str) -> Result<String>;
}
impl ModelManager {
pub fn new(config: ModelManagerConfig) -> Result<Self> {
config.validate()?;
std::fs::create_dir_all(&config.storage_directory)?;
let storage_stats = StorageStats {
total_size_bytes: config.max_storage_size_mb * 1024 * 1024,
used_size_bytes: 0,
model_count: 0,
last_cleanup_time: SystemTime::now(),
};
let mut manager = Self {
config,
models: HashMap::new(),
downloads: HashMap::new(),
storage_stats,
last_update_check: None,
http_client: None,
};
manager.load_models_from_storage()?;
#[cfg(not(test))]
{
manager.http_client = Some(Box::new(DefaultHttpClient::new(
manager.config.download_timeout_seconds,
)?));
}
Ok(manager)
}
pub fn set_http_client(&mut self, client: Box<dyn HttpClient>) {
self.http_client = Some(client);
}
pub async fn check_for_updates(&mut self) -> Result<Vec<ModelUpdate>> {
if !self.should_check_for_updates() {
return Ok(Vec::new());
}
tracing::info!(
"Checking for model updates from {}",
self.config.update_server_url
);
let updates_url = format!("{}/updates", self.config.update_server_url);
let client = self.http_client.as_ref().ok_or_else(|| {
TrustformersError::runtime_error("HTTP client not initialized".into())
})?;
let response = client.get_metadata(&updates_url)?;
let available_models: Vec<ModelMetadata> = serde_json::from_str(&response)?;
let mut updates = Vec::new();
for available_model in available_models {
if let Some(current_model) = self.models.get(&available_model.model_id) {
if self.is_update_available(current_model, &available_model)? {
let update_type = self.determine_update_type(current_model, &available_model);
let priority = self.determine_update_priority(&available_model);
let update_size =
self.calculate_update_size(current_model, &available_model, &update_type);
updates.push(ModelUpdate {
current: current_model.clone(),
update: available_model,
update_type,
priority,
update_size_bytes: update_size,
});
}
} else {
let update = ModelUpdate {
current: ModelMetadata::default_for_id(&available_model.model_id),
update: available_model.clone(),
update_type: UpdateType::Full,
priority: UpdatePriority::Normal,
update_size_bytes: available_model.size_bytes,
};
updates.push(update);
}
}
self.last_update_check = Some(SystemTime::now());
tracing::info!("Found {} available updates", updates.len());
Ok(updates)
}
pub async fn download_model(
&mut self,
model_id: &str,
progress_callback: Option<Box<dyn Fn(DownloadProgress) + Send + Sync>>,
) -> Result<()> {
let model_metadata = self.get_model_metadata_from_server(model_id).await?;
self.check_compatibility(&model_metadata)?;
self.ensure_storage_space(model_metadata.size_bytes)?;
let mut progress = DownloadProgress {
model_id: model_id.to_string(),
total_bytes: model_metadata.size_bytes,
downloaded_bytes: 0,
download_speed_bps: 0.0,
eta_seconds: 0.0,
status: DownloadStatus::Pending,
};
self.downloads.insert(model_id.to_string(), progress.clone());
if let Some(ref callback) = progress_callback {
callback(progress.clone());
}
progress.status = DownloadStatus::Downloading;
self.downloads.insert(model_id.to_string(), progress.clone());
let temp_path = self.get_temp_download_path(model_id);
let final_path = self.get_model_path(model_id);
let client = self.http_client.as_ref().ok_or_else(|| {
TrustformersError::runtime_error("HTTP client not initialized".into())
})?;
let download_progress_callback = {
let model_id = model_id.to_string();
move |downloaded: usize, total: usize| {
tracing::debug!(
"Download progress: {}/{} bytes for model {}",
downloaded,
total,
model_id
);
}
};
client.download_with_progress(
&model_metadata.download_url,
&temp_path,
Box::new(download_progress_callback),
)?;
progress.status = DownloadStatus::Verifying;
self.downloads.insert(model_id.to_string(), progress.clone());
if let Some(ref callback) = progress_callback {
callback(progress.clone());
}
self.verify_model(&temp_path, &model_metadata)?;
progress.status = DownloadStatus::Installing;
self.downloads.insert(model_id.to_string(), progress.clone());
if let Some(ref callback) = progress_callback {
callback(progress.clone());
}
self.install_model(&temp_path, &final_path, &model_metadata)?;
progress.status = DownloadStatus::Completed;
self.downloads.insert(model_id.to_string(), progress.clone());
if let Some(ref callback) = progress_callback {
callback(progress.clone());
}
self.models.insert(model_id.to_string(), model_metadata);
self.save_models_metadata()?;
let _ = std::fs::remove_file(&temp_path);
self.downloads.remove(model_id);
tracing::info!("Successfully downloaded and installed model: {}", model_id);
Ok(())
}
pub async fn apply_differential_update(&mut self, update: &ModelUpdate) -> Result<()> {
if update.update_type != UpdateType::Differential {
return Err(
TrustformersError::runtime_error("Not a differential update".into()).into(),
);
}
let current_path = self.get_model_path(&update.current.model_id);
let patch_path = self.get_temp_download_path(&format!("{}_patch", update.current.model_id));
let updated_path =
self.get_temp_download_path(&format!("{}_updated", update.current.model_id));
if let Some(diff_url) = &update.update.differential_url {
let client = self.http_client.as_ref().ok_or_else(|| {
TrustformersError::runtime_error("HTTP client not initialized".into())
})?;
client.download_file(diff_url, &patch_path)?;
} else {
return Err(
TrustformersError::runtime_error("No differential URL provided".into()).into(),
);
}
self.apply_binary_patch(¤t_path, &patch_path, &updated_path)?;
self.verify_model(&updated_path, &update.update)?;
let final_path = self.get_model_path(&update.current.model_id);
std::fs::rename(&updated_path, &final_path)?;
self.models.insert(update.current.model_id.clone(), update.update.clone());
self.save_models_metadata()?;
let _ = std::fs::remove_file(&patch_path);
tracing::info!(
"Successfully applied differential update for model: {}",
update.current.model_id
);
Ok(())
}
pub fn remove_model(&mut self, model_id: &str) -> Result<()> {
let model_path = self.get_model_path(model_id);
if model_path.exists() {
std::fs::remove_file(&model_path)?;
}
if let Some(model) = self.models.remove(model_id) {
self.storage_stats.used_size_bytes =
self.storage_stats.used_size_bytes.saturating_sub(model.size_bytes);
self.storage_stats.model_count = self.storage_stats.model_count.saturating_sub(1);
}
self.save_models_metadata()?;
tracing::info!("Removed model: {}", model_id);
Ok(())
}
pub fn list_models(&self) -> Vec<&ModelMetadata> {
self.models.values().collect()
}
pub fn get_model(&self, model_id: &str) -> Option<&ModelMetadata> {
self.models.get(model_id)
}
pub fn get_model_path(&self, model_id: &str) -> PathBuf {
self.config.storage_directory.join(format!("{}.trustformers", model_id))
}
pub fn get_storage_stats(&self) -> &StorageStats {
&self.storage_stats
}
pub fn cleanup_storage(&mut self) -> Result<()> {
let mut models_by_age: Vec<_> = self
.models
.values()
.filter(|model| {
if let Some(deprecation_time) = model.deprecation_timestamp {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.expect("SystemTime should be after UNIX_EPOCH")
.as_secs();
now > deprecation_time
} else {
false
}
})
.map(|model| (model.model_id.clone(), model.release_timestamp))
.collect();
models_by_age.sort_by_key(|(_, timestamp)| *timestamp);
while self.storage_stats.used_size_bytes > self.storage_stats.total_size_bytes * 90 / 100 {
if let Some((model_id, _)) = models_by_age.pop() {
self.remove_model(&model_id)?;
} else {
break;
}
}
self.storage_stats.last_cleanup_time = SystemTime::now();
Ok(())
}
pub fn cancel_download(&mut self, model_id: &str) -> Result<()> {
if let Some(mut progress) = self.downloads.remove(model_id) {
progress.status = DownloadStatus::Cancelled;
let temp_path = self.get_temp_download_path(model_id);
let _ = std::fs::remove_file(&temp_path);
tracing::info!("Cancelled download for model: {}", model_id);
}
Ok(())
}
fn should_check_for_updates(&self) -> bool {
if !self.config.enable_auto_updates {
return false;
}
if let Some(last_check) = self.last_update_check {
let elapsed = SystemTime::now().duration_since(last_check).unwrap_or(Duration::ZERO);
elapsed.as_secs() >= self.config.update_check_interval_seconds
} else {
true
}
}
fn is_update_available(
&self,
current: &ModelMetadata,
available: &ModelMetadata,
) -> Result<bool> {
Ok(available.version != current.version
&& available.release_timestamp > current.release_timestamp)
}
fn determine_update_type(
&self,
current: &ModelMetadata,
available: &ModelMetadata,
) -> UpdateType {
if available.differential_url.is_some() && self.config.enable_differential_updates {
UpdateType::Differential
} else {
UpdateType::Full
}
}
fn determine_update_priority(&self, model: &ModelMetadata) -> UpdatePriority {
if model.tags.contains(&"critical".to_string())
|| model.tags.contains(&"security".to_string())
{
UpdatePriority::Critical
} else if model.tags.contains(&"performance".to_string()) {
UpdatePriority::High
} else {
UpdatePriority::Normal
}
}
fn calculate_update_size(
&self,
current: &ModelMetadata,
available: &ModelMetadata,
update_type: &UpdateType,
) -> usize {
match update_type {
UpdateType::Full => available.size_bytes,
UpdateType::Differential => available.size_bytes / 10, UpdateType::ConfigOnly => 1024, }
}
async fn get_model_metadata_from_server(&self, model_id: &str) -> Result<ModelMetadata> {
let metadata_url = format!(
"{}/models/{}/metadata",
self.config.update_server_url, model_id
);
let client = self.http_client.as_ref().ok_or_else(|| {
TrustformersError::runtime_error("HTTP client not initialized".into())
})?;
let response = client.get_metadata(&metadata_url)?;
let metadata: ModelMetadata = serde_json::from_str(&response)?;
Ok(metadata)
}
fn check_compatibility(&self, model: &ModelMetadata) -> Result<()> {
#[cfg(target_os = "android")]
{
if let Some(min_api) = model.compatibility.min_android_api {
if min_api > 21 {
return Err(TrustformersError::runtime_error(
"Android version incompatible".into(),
)
.into());
}
}
}
#[cfg(target_os = "ios")]
{
if let Some(ref min_ios) = model.compatibility.min_ios_version {
tracing::debug!("Checking iOS compatibility: {}", min_ios);
}
}
let available_memory = self.get_available_memory_mb();
if model.compatibility.min_memory_mb > available_memory {
return Err(TrustformersError::runtime_error(format!(
"Insufficient memory: {} MB required, {} MB available",
model.compatibility.min_memory_mb, available_memory
))
.into());
}
Ok(())
}
fn ensure_storage_space(&mut self, required_bytes: usize) -> Result<()> {
let available_space =
self.storage_stats.total_size_bytes - self.storage_stats.used_size_bytes;
if available_space < required_bytes {
self.cleanup_storage()?;
let available_after_cleanup =
self.storage_stats.total_size_bytes - self.storage_stats.used_size_bytes;
if available_after_cleanup < required_bytes {
return Err(
TrustformersError::runtime_error("Insufficient storage space".into()).into(),
);
}
}
Ok(())
}
fn verify_model(&self, path: &Path, metadata: &ModelMetadata) -> Result<()> {
let file_size = std::fs::metadata(path)?.len() as usize;
if file_size != metadata.size_bytes {
return Err(TrustformersError::runtime_error("Model size mismatch".into()).into());
}
let file_checksum = self.calculate_checksum(path)?;
if file_checksum != metadata.checksum {
return Err(TrustformersError::runtime_error("Model checksum mismatch".into()).into());
}
if self.config.require_signature_verification {
if let Some(ref signature) = metadata.signature {
self.verify_signature(path, signature)?;
} else {
return Err(TrustformersError::runtime_error(
"Model signature required but not provided".into(),
)
.into());
}
}
Ok(())
}
fn install_model(
&mut self,
temp_path: &Path,
final_path: &Path,
metadata: &ModelMetadata,
) -> Result<()> {
std::fs::rename(temp_path, final_path)?;
if self.config.enable_compression {
self.compress_model(final_path)?;
}
self.storage_stats.used_size_bytes += metadata.size_bytes;
self.storage_stats.model_count += 1;
Ok(())
}
fn load_models_from_storage(&mut self) -> Result<()> {
let metadata_path = self.config.storage_directory.join("models_metadata.json");
if metadata_path.exists() {
let metadata_content = std::fs::read_to_string(&metadata_path)?;
let models: HashMap<String, ModelMetadata> = serde_json::from_str(&metadata_content)?;
for (model_id, metadata) in models {
let model_path = self.get_model_path(&model_id);
if model_path.exists() {
self.storage_stats.used_size_bytes += metadata.size_bytes;
self.storage_stats.model_count += 1;
self.models.insert(model_id, metadata);
}
}
}
Ok(())
}
fn save_models_metadata(&self) -> Result<()> {
let metadata_path = self.config.storage_directory.join("models_metadata.json");
let metadata_content = serde_json::to_string_pretty(&self.models)?;
std::fs::write(&metadata_path, metadata_content)?;
Ok(())
}
fn get_temp_download_path(&self, model_id: &str) -> PathBuf {
self.config.storage_directory.join(format!("{}.tmp", model_id))
}
fn calculate_checksum(&self, path: &Path) -> Result<String> {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let content = std::fs::read(path)?;
let mut hasher = DefaultHasher::new();
content.hash(&mut hasher);
Ok(format!("{:x}", hasher.finish()))
}
fn verify_signature(&self, _path: &Path, _signature: &str) -> Result<()> {
Ok(())
}
fn apply_binary_patch(&self, _original: &Path, _patch: &Path, _output: &Path) -> Result<()> {
Ok(())
}
fn compress_model(&self, _path: &Path) -> Result<()> {
Ok(())
}
fn get_available_memory_mb(&self) -> usize {
2048 }
}
impl Default for ModelManagerConfig {
fn default() -> Self {
Self {
update_server_url: "https://models.trustformers.ai".to_string(),
api_key: None,
storage_directory: PathBuf::from("./models"),
max_storage_size_mb: 1024,
enable_auto_updates: true,
update_check_interval_seconds: 3600, enable_differential_updates: true,
require_signature_verification: true,
download_timeout_seconds: 300, max_concurrent_downloads: 2,
enable_compression: true,
download_retry_attempts: 3,
}
}
}
impl ModelManagerConfig {
pub fn validate(&self) -> Result<()> {
if self.update_server_url.is_empty() {
return Err(TrustformersError::config_error(
"Update server URL cannot be empty",
"validate_model_config",
)
.into());
}
if self.max_storage_size_mb == 0 {
return Err(
TrustformersError::config_error("Storage size must be > 0", "validate").into(),
);
}
if self.download_timeout_seconds < 30 {
return Err(
TrustformersError::config_error("Download timeout too short", "validate").into(),
);
}
if self.max_concurrent_downloads == 0 || self.max_concurrent_downloads > 10 {
return Err(TrustformersError::config_error(
"Invalid concurrent download count",
"validate",
)
.into());
}
Ok(())
}
}
impl ModelMetadata {
fn default_for_id(model_id: &str) -> Self {
Self {
model_id: model_id.to_string(),
version: "0.0.0".to_string(),
model_type: "unknown".to_string(),
size_bytes: 0,
checksum: String::new(),
signature: None,
download_url: String::new(),
differential_url: None,
description: String::new(),
required_config: MobileConfig::default(),
compatibility: ModelCompatibility::default(),
release_timestamp: 0,
deprecation_timestamp: None,
tags: Vec::new(),
}
}
}
impl Default for ModelCompatibility {
fn default() -> Self {
Self {
min_android_api: Some(21),
min_ios_version: Some("12.0".to_string()),
required_features: Vec::new(),
min_memory_mb: 512,
supported_architectures: vec!["arm64".to_string(), "x86_64".to_string()],
}
}
}
#[cfg(not(test))]
struct DefaultHttpClient {
timeout: Duration,
}
#[cfg(not(test))]
impl DefaultHttpClient {
fn new(timeout_seconds: u64) -> Result<Self> {
Ok(Self {
timeout: Duration::from_secs(timeout_seconds),
})
}
}
#[cfg(not(test))]
impl HttpClient for DefaultHttpClient {
fn download_file(&self, _url: &str, _destination: &Path) -> Result<()> {
Ok(())
}
fn download_with_progress(
&self,
_url: &str,
_destination: &Path,
_progress_callback: Box<dyn Fn(usize, usize) + Send + Sync>,
) -> Result<()> {
Ok(())
}
fn get_metadata(&self, _url: &str) -> Result<String> {
Ok("{}".to_string())
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{Arc, Mutex};
struct MockHttpClient {
responses: Arc<Mutex<HashMap<String, String>>>,
}
impl MockHttpClient {
fn new() -> Self {
Self {
responses: Arc::new(Mutex::new(HashMap::new())),
}
}
fn set_response(&self, url: &str, response: &str) {
self.responses
.lock()
.expect("lock should not be poisoned")
.insert(url.to_string(), response.to_string());
}
}
impl HttpClient for MockHttpClient {
fn download_file(&self, _url: &str, _destination: &Path) -> Result<()> {
Ok(())
}
fn download_with_progress(
&self,
_url: &str,
_destination: &Path,
progress_callback: Box<dyn Fn(usize, usize) + Send + Sync>,
) -> Result<()> {
progress_callback(100, 100);
Ok(())
}
fn get_metadata(&self, url: &str) -> Result<String> {
let responses = self.responses.lock().expect("lock should not be poisoned");
responses
.get(url)
.cloned()
.ok_or_else(|| {
TrustformersError::runtime_error("URL not found in mock responses".into())
})
.map_err(|e| e.into())
}
}
#[test]
fn test_model_manager_config_validation() {
let config = ModelManagerConfig::default();
assert!(config.validate().is_ok());
let mut invalid_config = config.clone();
invalid_config.update_server_url = String::new();
assert!(invalid_config.validate().is_err());
invalid_config.update_server_url = "https://example.com".to_string();
invalid_config.max_storage_size_mb = 0;
assert!(invalid_config.validate().is_err());
}
#[test]
fn test_model_metadata_creation() {
let metadata = ModelMetadata::default_for_id("test_model");
assert_eq!(metadata.model_id, "test_model");
assert_eq!(metadata.version, "0.0.0");
}
#[tokio::test]
async fn test_model_manager_creation() {
let config = ModelManagerConfig::default();
let result = ModelManager::new(config);
assert!(result.is_ok());
}
#[test]
fn test_model_manager_config_defaults() {
let config = ModelManagerConfig::default();
assert!(!config.update_server_url.is_empty());
assert!(config.api_key.is_none());
assert!(config.max_storage_size_mb > 0);
assert!(config.enable_auto_updates);
assert!(config.enable_differential_updates);
assert!(config.require_signature_verification);
assert!(config.download_timeout_seconds > 0);
assert!(config.max_concurrent_downloads > 0);
assert!(config.download_retry_attempts > 0);
}
#[test]
fn test_model_manager_config_invalid_empty_url() {
let mut config = ModelManagerConfig::default();
config.update_server_url = String::new();
assert!(config.validate().is_err());
}
#[test]
fn test_model_manager_config_invalid_storage() {
let mut config = ModelManagerConfig::default();
config.max_storage_size_mb = 0;
assert!(config.validate().is_err());
}
#[test]
fn test_model_metadata_default_for_id() {
let metadata = ModelMetadata::default_for_id("my_model");
assert_eq!(metadata.model_id, "my_model");
assert_eq!(metadata.version, "0.0.0");
assert!(metadata.signature.is_none());
}
#[test]
fn test_model_compatibility_creation() {
let compat = ModelCompatibility {
min_android_api: Some(30),
min_ios_version: Some("16.0".to_string()),
required_features: vec!["gpu".to_string()],
min_memory_mb: 256,
supported_architectures: vec!["arm64".to_string(), "x86_64".to_string()],
};
assert_eq!(compat.min_android_api, Some(30));
assert_eq!(compat.supported_architectures.len(), 2);
}
#[test]
fn test_download_progress_creation() {
let progress = DownloadProgress {
model_id: "test_model".to_string(),
total_bytes: 1000000,
downloaded_bytes: 500000,
download_speed_bps: 100000.0,
eta_seconds: 5.0,
status: DownloadStatus::Downloading,
};
assert_eq!(progress.downloaded_bytes * 2, progress.total_bytes);
assert!(progress.download_speed_bps > 0.0);
}
#[test]
fn test_model_metadata_with_all_fields() {
let metadata = ModelMetadata {
model_id: "bert_tiny".to_string(),
version: "1.2.0".to_string(),
model_type: "transformer".to_string(),
size_bytes: 50_000_000,
checksum: "abc123def456".to_string(),
signature: Some("sig_data".to_string()),
download_url: "https://models.example.com/bert_tiny.bin".to_string(),
differential_url: Some("https://models.example.com/bert_tiny.diff".to_string()),
description: "Tiny BERT model".to_string(),
required_config: MobileConfig::default(),
compatibility: ModelCompatibility {
min_android_api: None,
min_ios_version: None,
required_features: vec![],
min_memory_mb: 64,
supported_architectures: vec![],
},
release_timestamp: 1700000000,
deprecation_timestamp: None,
tags: vec!["nlp".to_string(), "bert".to_string()],
};
assert_eq!(metadata.model_id, "bert_tiny");
assert!(metadata.signature.is_some());
assert!(metadata.differential_url.is_some());
assert!(metadata.deprecation_timestamp.is_none());
assert_eq!(metadata.tags.len(), 2);
}
#[test]
fn test_model_manager_config_with_api_key() {
let mut config = ModelManagerConfig::default();
config.api_key = Some("test_key_123".to_string());
assert!(config.api_key.is_some());
assert!(config.validate().is_ok());
}
#[test]
fn test_model_manager_config_compression_disabled() {
let mut config = ModelManagerConfig::default();
config.enable_compression = false;
assert!(!config.enable_compression);
assert!(config.validate().is_ok());
}
#[test]
fn test_model_manager_config_auto_updates_disabled() {
let mut config = ModelManagerConfig::default();
config.enable_auto_updates = false;
assert!(!config.enable_auto_updates);
assert!(config.validate().is_ok());
}
#[test]
fn test_model_compatibility_minimal() {
let compat = ModelCompatibility {
min_android_api: None,
min_ios_version: None,
required_features: vec![],
min_memory_mb: 0,
supported_architectures: vec![],
};
assert!(compat.required_features.is_empty());
}
#[test]
fn test_mock_http_client() {
let client = MockHttpClient::new();
client.set_response("https://example.com/api", "{\"status\": \"ok\"}");
let result = client.get_metadata("https://example.com/api");
assert!(result.is_ok());
}
#[test]
fn test_mock_http_client_missing_url() {
let client = MockHttpClient::new();
let result = client.get_metadata("https://nonexistent.com");
assert!(result.is_err());
}
#[test]
fn test_mock_http_client_download() {
let client = MockHttpClient::new();
let tmp_path = std::env::temp_dir().join("test_download");
let result = client.download_file("https://example.com/model", &tmp_path);
assert!(result.is_ok());
}
#[test]
fn test_mock_http_client_download_with_progress() {
let client = MockHttpClient::new();
let tmp_path = std::env::temp_dir().join("test_download_progress");
let result = client.download_with_progress(
"https://example.com/model",
&tmp_path,
Box::new(|downloaded, total| {
let _ = (downloaded, total);
}),
);
assert!(result.is_ok());
}
}