use crate::error::{Result, TrustformersError};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fs::File;
use std::io::Read;
use std::path::{Path, PathBuf};
use std::time::SystemTime;
use uuid::Uuid;
/// Descriptive metadata for a single model, mirroring the fields of a hub
/// model card (see `HubIntegration::get_hub_model_info`, which populates it).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Hub identifier, e.g. "bert-base-uncased".
    pub model_id: String,
    /// Source library (e.g. "transformers"); used as the pack `framework`.
    pub library_name: Option<String>,
    /// Task tag (e.g. "text-generation"); drives `infer_model_type`.
    pub pipeline_tag: Option<String>,
    pub tags: Vec<String>,
    /// Free-form configuration values copied from the model card.
    pub config: HashMap<String, serde_json::Value>,
    pub downloads: Option<u64>,
    pub likes: Option<u64>,
    pub created_at: Option<String>,
    pub updated_at: Option<String>,
    pub author: Option<String>,
    pub description: Option<String>,
    pub license: Option<String>,
    pub task: Option<String>,
    pub language: Vec<String>,
    pub dataset: Vec<String>,
    pub model_type: Option<String>,
    pub architecture: Option<String>,
}
/// Persistent description of one `.tfpack` archive, stored both as a
/// `<pack_id>.metadata.json` sidecar and inside the manager's registry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelPackMetadata {
    /// UUID assigned at creation time.
    pub pack_id: String,
    pub name: String,
    pub description: String,
    pub version: String,
    pub created_at: SystemTime,
    pub created_by: String,
    /// Size in bytes of the compressed archive on disk.
    pub total_size: u64,
    pub models: Vec<PackedModelInfo>,
    pub dependencies: Vec<String>,
    pub target_platforms: Vec<String>,
    /// SHA-256 hex digest of the archive file, used by `verify_pack_integrity`.
    pub checksum: String,
    /// compressed_size / original_size as computed in `create_pack`.
    pub compression_ratio: f64,
}
/// Per-model entry inside a `ModelPackMetadata`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PackedModelInfo {
    pub model_id: String,
    pub name: String,
    pub version: String,
    /// Size in bytes before compression (currently an estimate in `create_pack`).
    pub original_size: u64,
    /// Estimated from the pack-wide compression ratio in `create_pack`.
    pub compressed_size: u64,
    pub model_type: ModelType,
    /// Source framework, e.g. "transformers".
    pub framework: String,
    pub precision: PrecisionType,
    /// Flattened model-card fields; see `extract_metadata_from_model_info`.
    pub metadata: HashMap<String, String>,
}
/// Broad task category for a packed model, derived from the hub pipeline tag.
/// NOTE(review): `infer_model_type` never yields `Multimodal` — unknown tags
/// fall back to `TextGeneration`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ModelType {
    TextGeneration,
    TextClassification,
    ImageClassification,
    SpeechRecognition,
    Translation,
    Summarization,
    QuestionAnswering,
    Multimodal,
}
/// Numeric precision the packed weights are stored in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PrecisionType {
    FP32,
    FP16,
    INT8,
    INT4,
    /// Mixed precision (different layers at different widths).
    Mixed,
}
/// Options controlling how `OfflineModelPackManager::create_pack` builds an
/// archive. See the `Default` impl for the stock settings.
#[derive(Debug, Clone)]
pub struct PackCreationConfig {
    /// Compression level, 0..=9 (gzip-style).
    pub compression_level: u8,
    pub include_cache: bool,
    pub include_examples: bool,
    pub include_documentation: bool,
    /// Platform names the pack targets (e.g. "linux", "windows").
    pub target_platforms: Vec<String>,
    /// Upper bound on pack size in bytes; NOTE(review): recorded but not
    /// enforced anywhere in this file — confirm enforcement elsewhere.
    pub max_pack_size: Option<u64>,
    /// NOTE(review): only serialized into pack metadata; splitting itself is
    /// not implemented in this file.
    pub split_large_packs: bool,
}
impl Default for PackCreationConfig {
    /// Stock settings: balanced compression, examples and docs included,
    /// all three desktop platforms targeted, 2 GiB cap with splitting enabled.
    fn default() -> Self {
        let target_platforms = ["linux", "windows", "macos"]
            .iter()
            .map(|platform| platform.to_string())
            .collect();
        Self {
            compression_level: 6,
            include_cache: false,
            include_examples: true,
            include_documentation: true,
            target_platforms,
            max_pack_size: Some(2 * 1024 * 1024 * 1024),
            split_large_packs: true,
        }
    }
}
/// Manages `.tfpack` archives under a base directory, keeping an in-memory
/// registry that is persisted to `registry.json`.
pub struct OfflineModelPackManager {
    // Root directory holding archives, metadata sidecars and "installed/".
    base_path: PathBuf,
    // pack_id -> metadata; loaded from registry.json in `new`.
    registry: HashMap<String, ModelPackMetadata>,
}
impl OfflineModelPackManager {
/// Creates a manager rooted at `base_path`, creating the directory if needed
/// and loading any previously persisted registry.
pub fn new(base_path: impl AsRef<Path>) -> Result<Self> {
    // Materialize the storage root up front so later pack writes cannot
    // fail on a missing directory.
    let root = base_path.as_ref().to_path_buf();
    std::fs::create_dir_all(&root)?;
    let mut this = Self {
        base_path: root,
        registry: HashMap::new(),
    };
    // Pick up packs recorded by a previous run, if any.
    this.load_registry()?;
    Ok(this)
}
/// Builds a new `.tfpack` archive for `model_ids`, records its metadata and
/// registers it. Returns the freshly generated pack id.
///
/// Fix: the new pack is now persisted via `save_registry()`, matching
/// `install_pack`/`remove_pack`; previously a created pack only lived in the
/// in-memory registry and was lost on restart.
pub async fn create_pack(
    &mut self,
    name: String,
    description: String,
    model_ids: Vec<String>,
    config: PackCreationConfig,
) -> Result<String> {
    let pack_id = Uuid::new_v4().to_string();
    let pack_path = self.base_path.join(format!("{}.tfpack", pack_id));
    let mut models = Vec::new();
    let mut total_original_size = 0u64;
    for model_id in &model_ids {
        let model_info = self.get_model_info(model_id).await?;
        // Placeholder sizing (512 MiB per model) until real weights are packed.
        let estimated_size = 1024 * 1024 * 512;
        total_original_size += estimated_size;
        models.push(PackedModelInfo {
            model_id: model_id.clone(),
            name: model_info.model_id.clone(),
            version: "latest".to_string(),
            original_size: estimated_size,
            // Filled in below once the pack-wide compression ratio is known.
            compressed_size: 0,
            model_type: self.infer_model_type(&model_info),
            framework: model_info
                .library_name
                .clone()
                .unwrap_or_else(|| "transformers".to_string()),
            precision: PrecisionType::FP32,
            metadata: self.extract_metadata_from_model_info(&model_info),
        });
    }
    let compressed_size =
        self.create_compressed_archive(&model_ids, &pack_path, &config).await?;
    // Guard against division by zero for an empty model list.
    let compression_ratio = if total_original_size > 0 {
        compressed_size as f64 / total_original_size as f64
    } else {
        1.0
    };
    // Apportion the compressed size to each model via the pack-wide ratio.
    for model in &mut models {
        model.compressed_size = (model.original_size as f64 * compression_ratio) as u64;
    }
    let checksum = self.calculate_file_checksum(&pack_path)?;
    let metadata = ModelPackMetadata {
        pack_id: pack_id.clone(),
        name: name.clone(),
        description,
        version: "1.0.0".to_string(),
        created_at: SystemTime::now(),
        created_by: "trustformers".to_string(),
        total_size: compressed_size,
        models,
        dependencies: Vec::new(),
        target_platforms: config.target_platforms.clone(),
        checksum,
        compression_ratio,
    };
    self.save_pack_metadata(&metadata)?;
    self.registry.insert(pack_id.clone(), metadata);
    // Persist the registry so the new pack survives a restart (consistent
    // with install_pack/remove_pack).
    self.save_registry()?;
    Ok(pack_id)
}
/// Verifies and installs a pack archive from `pack_path` into
/// `base_path/installed/<pack_id>`, registering it and persisting the registry.
///
/// Fix: the metadata struct is moved into the registry instead of being
/// cloned wholesale (the previous `metadata.clone()` copied every model entry).
pub async fn install_pack(&mut self, pack_path: impl AsRef<Path>) -> Result<String> {
    let pack_path = pack_path.as_ref();
    let metadata = self.load_pack_metadata(pack_path)?;
    // Refuse to install an archive whose bytes do not match the recorded checksum.
    self.verify_pack_integrity(pack_path, &metadata)?;
    let pack_id = metadata.pack_id.clone();
    let install_path = self.base_path.join("installed").join(&pack_id);
    std::fs::create_dir_all(&install_path)?;
    self.extract_pack(pack_path, &install_path).await?;
    self.registry.insert(pack_id.clone(), metadata);
    self.save_registry()?;
    Ok(pack_id)
}
/// Returns every pack currently known to the registry, in arbitrary order.
pub fn list_packs(&self) -> Vec<&ModelPackMetadata> {
    self.registry.iter().map(|(_, meta)| meta).collect()
}
/// Looks up a pack by id; `None` if the id is not registered.
pub fn get_pack_info(&self, pack_id: &str) -> Option<&ModelPackMetadata> {
    self.registry.get(pack_id)
}
/// Removes a pack from the registry and deletes its installed files and
/// archive if present. Unknown pack ids are a silent no-op.
pub async fn remove_pack(&mut self, pack_id: &str) -> Result<()> {
    let metadata = match self.registry.remove(pack_id) {
        Some(meta) => meta,
        None => return Ok(()),
    };
    let install_path = self.base_path.join("installed").join(&metadata.pack_id);
    if install_path.exists() {
        tokio::fs::remove_dir_all(&install_path).await?;
    }
    let archive_path = self.base_path.join(format!("{}.tfpack", pack_id));
    if archive_path.exists() {
        tokio::fs::remove_file(&archive_path).await?;
    }
    // Persist the registry so the removal survives a restart.
    self.save_registry()?;
    Ok(())
}
/// Creates a pack from one of the predefined curated model sets, delegating
/// to `create_pack` with a canned name, description and model list.
pub async fn create_curated_pack(
    &mut self,
    pack_type: CuratedPackType,
    config: PackCreationConfig,
) -> Result<String> {
    // Each curated pack maps to a fixed (name, description, model list) triple.
    let (name, description, model_ids) = match pack_type {
        CuratedPackType::NLP => (
            "NLP Essentials".to_string(),
            "Essential models for natural language processing tasks".to_string(),
            vec![
                "bert-base-uncased".to_string(),
                "gpt2".to_string(),
                "distilbert-base-uncased".to_string(),
                "roberta-base".to_string(),
            ],
        ),
        CuratedPackType::Vision => (
            "Computer Vision Pack".to_string(),
            "Essential models for computer vision tasks".to_string(),
            vec![
                "vit-base-patch16-224".to_string(),
                "resnet-50".to_string(),
                "clip-vit-base-patch32".to_string(),
            ],
        ),
        CuratedPackType::Multimodal => (
            "Multimodal AI Pack".to_string(),
            "Models for cross-modal understanding and generation".to_string(),
            vec![
                "clip-vit-base-patch32".to_string(),
                "blip-image-captioning-base".to_string(),
                "layoutlm-base-uncased".to_string(),
            ],
        ),
        CuratedPackType::EdgeOptimized => (
            "Edge Deployment Pack".to_string(),
            "Optimized models for edge and mobile deployment".to_string(),
            vec![
                "distilbert-base-uncased".to_string(),
                "mobilenet-v2".to_string(),
                "efficientnet-b0".to_string(),
            ],
        ),
    };
    self.create_pack(name, description, model_ids, config).await
}
/// Replaces an existing pack with a new one containing its models plus
/// `additional_models`, then removes the old pack. Returns the new pack id.
///
/// Fix: duplicate model ids are now skipped — previously a model appearing
/// both in the existing pack and in `additional_models` was packed twice.
pub async fn update_pack(
    &mut self,
    pack_id: &str,
    additional_models: Vec<String>,
) -> Result<String> {
    let existing_metadata = self
        .registry
        .get(pack_id)
        .ok_or_else(|| {
            TrustformersError::file_not_found(format!("Pack {} not found", pack_id))
        })?
        .clone();
    // Start from the models already in the pack, then append new ones,
    // preserving order and skipping duplicates.
    let mut all_models: Vec<String> = existing_metadata
        .models
        .iter()
        .map(|m| m.model_id.clone())
        .collect();
    for model_id in additional_models {
        if !all_models.contains(&model_id) {
            all_models.push(model_id);
        }
    }
    let new_pack_id = self
        .create_pack(
            format!("{} (Updated)", existing_metadata.name),
            existing_metadata.description,
            all_models,
            PackCreationConfig::default(),
        )
        .await?;
    // Remove the old pack only after the replacement was created successfully.
    self.remove_pack(pack_id).await?;
    Ok(new_pack_id)
}
/// Returns model info for `model_id`.
///
/// NOTE(review): this is a hardcoded placeholder — every field except
/// `model_id` is canned (no hub lookup, fixed downloads/likes). Real hub
/// resolution lives in `HubIntegration::get_hub_model_info`.
async fn get_model_info(&self, model_id: &str) -> Result<ModelInfo> {
    Ok(ModelInfo {
        model_id: model_id.to_string(),
        pipeline_tag: Some("text-generation".to_string()),
        library_name: Some("transformers".to_string()),
        tags: vec![],
        config: HashMap::new(),
        downloads: Some(1000),
        likes: Some(50),
        created_at: None,
        updated_at: None,
        author: None,
        description: None,
        license: None,
        task: None,
        language: vec![],
        dataset: vec![],
        model_type: None,
        architecture: None,
    })
}
/// Writes a gzip-compressed tar archive at `output_path` containing a pack
/// metadata file plus one config.json per model, and returns the final
/// archive size in bytes.
///
/// Fix: the gzip encoder now uses `config.compression_level` — the level was
/// previously hardcoded to 6, silently ignoring the caller's configuration.
async fn create_compressed_archive(
    &self,
    model_ids: &[String],
    output_path: &Path,
    config: &PackCreationConfig,
) -> Result<u64> {
    use oxiarc_archive::tar::TarWriter;
    use oxiarc_deflate::streaming::GzipStreamEncoder;
    let file = File::create(output_path)?;
    // Honor the caller-requested compression level (0..=9).
    let encoder = GzipStreamEncoder::new(file, config.compression_level.into());
    let mut tar_writer = TarWriter::new(encoder);
    // Pack-level manifest describing what the archive contains.
    let metadata = serde_json::json!({
        "version": "1.0",
        "compression": format!("{:?}", config.compression_level),
        "models": model_ids.len(),
        "created": std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs(),
        "split_large_packs": config.split_large_packs,
        "model_ids": model_ids
    });
    let metadata_content = serde_json::to_string_pretty(&metadata)?;
    tar_writer
        .add_file_with_mode("pack_metadata.json", metadata_content.as_bytes(), 0o644)
        .map_err(|e| TrustformersError::invalid_input_simple(e.to_string()))?;
    let mut total_size = 0u64;
    for model_id in model_ids {
        // Placeholder per-model config entry; real weight files are not
        // packed by this routine yet.
        let model_config = serde_json::json!({
            "model_id": model_id,
            "type": "transformers",
            "format": "safetensors",
            "architecture": "auto-detected"
        });
        let config_content = serde_json::to_string_pretty(&model_config)?;
        let model_path = format!("models/{}/config.json", model_id);
        let content_len = config_content.len() as u64;
        tar_writer
            .add_file_with_mode(&model_path, config_content.as_bytes(), 0o644)
            .map_err(|e| TrustformersError::invalid_input_simple(e.to_string()))?;
        total_size += content_len;
    }
    // Flush the tar stream, then finalize the gzip trailer.
    let encoder = tar_writer
        .into_inner()
        .map_err(|e| TrustformersError::invalid_input_simple(e.to_string()))?;
    encoder
        .finish()
        .map_err(|e| TrustformersError::invalid_input_simple(e.to_string()))?;
    // Report the on-disk (compressed) size, not the uncompressed total.
    let final_size = output_path.metadata()?.len();
    Ok(final_size)
}
/// Streams `file_path` through SHA-256 and returns the lowercase hex digest.
fn calculate_file_checksum(&self, file_path: &Path) -> Result<String> {
    let mut file = File::open(file_path)?;
    let mut hasher = Sha256::new();
    // Fixed-size chunks keep memory flat even for multi-gigabyte packs.
    let mut chunk = [0u8; 8192];
    loop {
        let n = file.read(&mut chunk)?;
        if n == 0 {
            // EOF: finalize and render as hex.
            return Ok(format!("{:x}", hasher.finalize()));
        }
        hasher.update(&chunk[..n]);
    }
}
/// Writes the pack's metadata sidecar as `<pack_id>.metadata.json` next to
/// the archives under the base path.
fn save_pack_metadata(&self, metadata: &ModelPackMetadata) -> Result<()> {
    let sidecar_path = self
        .base_path
        .join(format!("{}.metadata.json", metadata.pack_id));
    serde_json::to_writer_pretty(File::create(sidecar_path)?, metadata)?;
    Ok(())
}
/// Maps a hub pipeline tag onto our coarse `ModelType` buckets.
fn infer_model_type(&self, model_info: &ModelInfo) -> ModelType {
    match model_info.pipeline_tag.as_deref() {
        Some("text-generation") => ModelType::TextGeneration,
        Some("text-classification") => ModelType::TextClassification,
        Some("image-classification") => ModelType::ImageClassification,
        Some("automatic-speech-recognition") => ModelType::SpeechRecognition,
        Some("translation") => ModelType::Translation,
        Some("summarization") => ModelType::Summarization,
        Some("question-answering") => ModelType::QuestionAnswering,
        // Missing or unrecognized tags default to text generation; note this
        // means `ModelType::Multimodal` is never produced here.
        _ => ModelType::TextGeneration,
    }
}
/// Loads the `<stem>.metadata.json` sidecar stored next to a pack archive.
fn load_pack_metadata(&self, pack_path: &Path) -> Result<ModelPackMetadata> {
    let stem = pack_path.file_stem().ok_or_else(|| {
        TrustformersError::invalid_input_simple("Invalid pack file name".to_string())
    })?;
    let sidecar =
        pack_path.with_file_name(format!("{}.metadata.json", stem.to_string_lossy()));
    // Guard-clause style: bail out early when the sidecar is missing.
    if !sidecar.exists() {
        return Err(TrustformersError::invalid_input_simple(
            "Pack metadata not found".to_string(),
        ));
    }
    let metadata: ModelPackMetadata = serde_json::from_reader(File::open(sidecar)?)?;
    Ok(metadata)
}
/// Recomputes the archive's SHA-256 digest and compares it against the
/// checksum recorded in the metadata.
fn verify_pack_integrity(&self, pack_path: &Path, metadata: &ModelPackMetadata) -> Result<()> {
    let actual = self.calculate_file_checksum(pack_path)?;
    if actual == metadata.checksum {
        Ok(())
    } else {
        Err(TrustformersError::invalid_input_simple(
            "Pack checksum mismatch".to_string(),
        ))
    }
}
/// Extracts a gzip'd tar pack archive into `extract_path` and writes an
/// extraction manifest (`manifest.json`).
///
/// Fix (security): archive entry names are untrusted. The previous
/// sanitization only stripped a leading "./" or "/", so an entry such as
/// "../../x" could escape `extract_path` (path traversal). Entries with any
/// non-normal path component are now skipped.
async fn extract_pack(&self, pack_path: &Path, extract_path: &Path) -> Result<()> {
    use oxiarc_archive::tar::TarStreamReader;
    use oxiarc_deflate::streaming::GzipStreamDecoder;
    use std::io::Read as _;
    // ustar typeflag bytes for the entry kinds materialized on disk; some
    // writers use NUL instead of '0' for regular files.
    const TAR_REGULAR_FILE: u8 = b'0';
    const TAR_REGULAR_FILE_ALT: u8 = 0;
    const TAR_DIRECTORY: u8 = b'5';
    std::fs::create_dir_all(extract_path)?;
    let file = File::open(pack_path)?;
    let decoder = GzipStreamDecoder::new(file);
    let mut stream = TarStreamReader::new(decoder);
    while let Some(mut entry) = stream
        .next_entry()
        .map_err(|e| TrustformersError::invalid_input_simple(e.to_string()))?
    {
        let entry_name = entry.header.name.clone();
        let typeflag = entry.header.typeflag;
        let sanitized = entry_name.trim_start_matches("./").trim_start_matches('/');
        // SECURITY: skip any entry whose path contains "..", a root, or a
        // prefix component, so extraction stays strictly inside extract_path.
        let relative = Path::new(sanitized);
        if relative
            .components()
            .any(|component| !matches!(component, std::path::Component::Normal(_)))
        {
            continue;
        }
        let dest = extract_path.join(relative);
        match typeflag {
            TAR_DIRECTORY => {
                std::fs::create_dir_all(&dest)?;
            },
            TAR_REGULAR_FILE | TAR_REGULAR_FILE_ALT => {
                if let Some(parent) = dest.parent() {
                    std::fs::create_dir_all(parent)?;
                }
                let mut out_file = File::create(&dest)?;
                let mut buf = Vec::new();
                entry
                    .read_to_end(&mut buf)
                    .map_err(|e| TrustformersError::invalid_input_simple(e.to_string()))?;
                std::io::Write::write_all(&mut out_file, &buf)?;
            },
            // Links and other entry kinds are intentionally ignored.
            _ => {},
        }
    }
    // Build the extraction manifest: reuse the pack's own metadata when it
    // was present in the archive, otherwise synthesize a minimal one.
    let metadata_path = extract_path.join("pack_metadata.json");
    let manifest = if metadata_path.exists() {
        let metadata_content = std::fs::read_to_string(&metadata_path)?;
        let mut metadata: serde_json::Value = serde_json::from_str(&metadata_content)?;
        metadata["extraction_time"] = serde_json::json!(std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs());
        metadata
    } else {
        serde_json::json!({
            "extraction_time": std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
            "pack_source": pack_path.display().to_string()
        })
    };
    let manifest_path = extract_path.join("manifest.json");
    std::fs::write(manifest_path, serde_json::to_string_pretty(&manifest)?)?;
    Ok(())
}
/// Loads `registry.json` from the base path into memory.
fn load_registry(&mut self) -> Result<()> {
    let registry_path = self.base_path.join("registry.json");
    // A missing registry just means a fresh install.
    if !registry_path.exists() {
        return Ok(());
    }
    let file = File::open(registry_path)?;
    // Corrupt JSON is tolerated (best effort) rather than failing startup.
    self.registry = serde_json::from_reader(file).unwrap_or_default();
    Ok(())
}
/// Persists the full pack index as pretty-printed JSON at
/// `<base_path>/registry.json`.
fn save_registry(&self) -> Result<()> {
    let out = File::create(self.base_path.join("registry.json"))?;
    serde_json::to_writer_pretty(out, &self.registry)?;
    Ok(())
}
/// Flattens the interesting `ModelInfo` fields into the string-to-string map
/// stored in `PackedModelInfo::metadata`. Absent/empty fields are skipped.
fn extract_metadata_from_model_info(&self, model_info: &ModelInfo) -> HashMap<String, String> {
    let mut metadata = HashMap::new();
    // Optional string fields, copied verbatim when present.
    let optional_strings = [
        ("author", &model_info.author),
        ("description", &model_info.description),
        ("license", &model_info.license),
        ("created_at", &model_info.created_at),
        ("updated_at", &model_info.updated_at),
        ("task", &model_info.task),
        ("architecture", &model_info.architecture),
        ("model_type", &model_info.model_type),
        ("pipeline_tag", &model_info.pipeline_tag),
    ];
    for (key, value) in optional_strings {
        if let Some(v) = value {
            metadata.insert(key.to_string(), v.clone());
        }
    }
    // Numeric popularity counters become their decimal representations.
    if let Some(downloads) = model_info.downloads {
        metadata.insert("downloads".to_string(), downloads.to_string());
    }
    if let Some(likes) = model_info.likes {
        metadata.insert("likes".to_string(), likes.to_string());
    }
    // List-valued fields are joined into a single comma-separated entry.
    let list_fields = [
        ("language", &model_info.language),
        ("datasets", &model_info.dataset),
        ("tags", &model_info.tags),
    ];
    for (key, values) in list_fields {
        if !values.is_empty() {
            metadata.insert(key.to_string(), values.join(", "));
        }
    }
    // Config entries are stored under a `config_` prefix. Plain scalars are
    // rendered without JSON quoting; anything else via `Value::to_string()`.
    for (key, value) in &model_info.config {
        let rendered = match value {
            serde_json::Value::String(s) => s.clone(),
            serde_json::Value::Number(n) => n.to_string(),
            serde_json::Value::Bool(b) => b.to_string(),
            other => other.to_string(),
        };
        metadata.insert(format!("config_{}", key), rendered);
    }
    metadata
}
}
/// Predefined model bundles understood by `create_curated_pack`.
#[derive(Debug, Clone)]
pub enum CuratedPackType {
    NLP,
    Vision,
    Multimodal,
    /// Smaller models suitable for edge/mobile deployment.
    EdgeOptimized,
}
impl OfflineModelPackManager {
/// Builds the NLP curated pack at maximum compression with examples and
/// documentation kept, for local development use.
pub async fn create_development_pack(&mut self) -> Result<String> {
    let config = PackCreationConfig {
        compression_level: 9,
        include_examples: true,
        include_documentation: true,
        ..Default::default()
    };
    self.create_curated_pack(CuratedPackType::NLP, config).await
}
/// Builds the edge-optimized curated pack for a single target platform:
/// maximum compression, extras stripped, size capped at 1 GiB.
pub async fn create_production_pack(&mut self, target_platform: String) -> Result<String> {
    let config = PackCreationConfig {
        compression_level: 9,
        include_cache: false,
        include_examples: false,
        include_documentation: false,
        target_platforms: vec![target_platform],
        max_pack_size: Some(1024 * 1024 * 1024),
        ..Default::default()
    };
    self.create_curated_pack(CuratedPackType::EdgeOptimized, config)
        .await
}
}
/// Bridges the offline pack manager to the model hub (`crate::hub`), used to
/// resolve model cards and download files when building packs.
pub struct HubIntegration {
    /// Options forwarded to every `crate::hub` call.
    pub hub_options: crate::hub::HubOptions,
}
impl HubIntegration {
/// Creates a hub integration, falling back to default options when `None`.
pub fn new(options: Option<crate::hub::HubOptions>) -> Self {
    let hub_options = options.unwrap_or_default();
    Self { hub_options }
}
/// Downloads a model's config from the hub, validates its metadata is
/// resolvable, and adds the model to the given pack via `update_pack`.
///
/// Fix: the resolved model info was previously bound to an unused local
/// (`model_info`); the call is kept for its validation/error-propagation
/// effect but its result is no longer bound.
pub async fn download_model_to_pack(
    &self,
    pack_manager: &mut OfflineModelPackManager,
    model_id: &str,
    pack_id: &str,
) -> Result<()> {
    // Fetch the model's config first so a missing/unreachable model fails
    // before the pack is touched.
    let _model_path = crate::hub::download_file_from_hub(
        model_id,
        "config.json",
        Some(self.hub_options.clone()),
    )
    .map_err(|e| TrustformersError::io_error(format!("Hub download failed: {}", e)))?;
    // Validate that the model card resolves; `update_pack` re-resolves the
    // info it needs, so the value itself is not used here.
    self.get_hub_model_info(model_id).await?;
    let additional_models = vec![model_id.to_string()];
    pack_manager.update_pack(pack_id, additional_models).await?;
    Ok(())
}
/// Creates a pack named after a hub collection, first checking that every
/// model in the collection resolves against the hub.
pub async fn create_pack_from_hub_collection(
    &self,
    pack_manager: &mut OfflineModelPackManager,
    collection_name: &str,
    model_ids: Vec<String>,
    config: PackCreationConfig,
) -> Result<String> {
    // Any unresolvable model aborts the whole collection up front.
    for model_id in &model_ids {
        let _ = self.get_hub_model_info(model_id).await?;
    }
    let name = format!("Hub Collection: {}", collection_name);
    let description = format!(
        "Model pack created from Hub collection: {}",
        collection_name
    );
    pack_manager
        .create_pack(name, description, model_ids, config)
        .await
}
/// Resolves model info from the hub model card; on any hub failure it falls
/// back to a canned placeholder (fixed downloads/likes, text-generation tag)
/// rather than propagating the error.
async fn get_hub_model_info(&self, model_id: &str) -> Result<ModelInfo> {
    match crate::hub::load_model_card_from_hub(model_id, Some(self.hub_options.clone())) {
        Ok(model_card) => {
            // Map the model card onto our ModelInfo; fields the card does
            // not carry stay None/empty.
            Ok(ModelInfo {
                model_id: model_id.to_string(),
                library_name: Some("transformers".to_string()),
                pipeline_tag: model_card.pipeline_tag.clone(),
                tags: model_card.tags.unwrap_or_default(),
                config: model_card.extra.into_iter().collect(),
                downloads: None,
                likes: None,
                created_at: None,
                updated_at: None,
                author: None,
                description: None,
                license: model_card.license,
                // The pipeline tag doubles as the task label.
                task: model_card.pipeline_tag,
                language: model_card.language.unwrap_or_default(),
                dataset: model_card.datasets.unwrap_or_default(),
                model_type: None,
                architecture: None,
            })
        },
        Err(_) => {
            // Offline/unreachable hub: return placeholder info so pack
            // creation can proceed without network access.
            Ok(ModelInfo {
                model_id: model_id.to_string(),
                pipeline_tag: Some("text-generation".to_string()),
                library_name: Some("transformers".to_string()),
                tags: vec![],
                config: HashMap::new(),
                downloads: Some(1000),
                likes: Some(50),
                created_at: None,
                updated_at: None,
                author: None,
                description: None,
                license: None,
                task: None,
                language: vec![],
                dataset: vec![],
                model_type: None,
                architecture: None,
            })
        },
    }
}
}
impl OfflineModelPackManager {
/// Convenience constructor returning a ready-to-use manager together with a
/// `HubIntegration` configured from `hub_options`.
pub fn with_hub_integration(
    base_path: impl AsRef<Path>,
    hub_options: Option<crate::hub::HubOptions>,
) -> Result<(Self, HubIntegration)> {
    Ok((Self::new(base_path)?, HubIntegration::new(hub_options)))
}
/// Creates a pack after validating every model against the hub.
///
/// Fix: the previous version built a full `Vec<PackedModelInfo>`
/// (`enhanced_models`) plus a running size total and then discarded both —
/// `create_pack` recomputes per-model info internally. The hub lookups are
/// kept for their validation/error-propagation effect; the dead work is gone.
pub async fn create_pack_from_hub(
    &mut self,
    hub_integration: &HubIntegration,
    name: String,
    description: String,
    model_ids: Vec<String>,
    config: PackCreationConfig,
) -> Result<String> {
    // Fail before any archive is written if a model cannot be resolved.
    for model_id in &model_ids {
        hub_integration.get_hub_model_info(model_id).await?;
    }
    self.create_pack(name, description, model_ids, config).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::env;
/// Returns a unique scratch directory path for one test.
///
/// Fix: the previous implementation derived the name from the process id
/// alone, so every test in a single `cargo test` run received the SAME path
/// and parallel tests could create/delete each other's directories. A
/// per-call atomic counter now makes each returned path unique.
fn temp_dir_path() -> std::path::PathBuf {
    use std::sync::atomic::{AtomicU64, Ordering};
    static NEXT_ID: AtomicU64 = AtomicU64::new(0);
    let unique = NEXT_ID.fetch_add(1, Ordering::Relaxed);
    let mut path = env::temp_dir();
    path.push(format!(
        "trustformers_test_{}_{}",
        std::process::id(),
        unique
    ));
    path
}
// Sanity-checks the plain-struct fields of ModelPackMetadata.
#[test]
fn test_model_pack_metadata_fields() {
    let metadata = ModelPackMetadata {
        pack_id: "test-pack-id".to_string(),
        name: "Test Pack".to_string(),
        description: "A test model pack".to_string(),
        version: "1.0.0".to_string(),
        created_at: SystemTime::now(),
        created_by: "trustformers".to_string(),
        total_size: 1024 * 1024,
        models: vec![],
        dependencies: vec![],
        target_platforms: vec!["linux".to_string()],
        checksum: "abc123".to_string(),
        compression_ratio: 0.75,
    };
    assert_eq!(metadata.pack_id, "test-pack-id");
    assert_eq!(metadata.name, "Test Pack");
    assert!(!metadata.version.is_empty(), "version should not be empty");
    assert!(
        metadata.compression_ratio > 0.0,
        "compression_ratio should be positive"
    );
}
// Checks a hand-built metadata value has a positive compression ratio.
// NOTE(review): only asserts positivity, not an upper bound, despite the name.
#[test]
fn test_model_pack_metadata_compression_ratio_bounded() {
    let metadata = ModelPackMetadata {
        pack_id: "id1".to_string(),
        name: "Pack".to_string(),
        description: "desc".to_string(),
        version: "1.0.0".to_string(),
        created_at: SystemTime::now(),
        created_by: "test".to_string(),
        total_size: 512,
        models: vec![],
        dependencies: vec![],
        target_platforms: vec![],
        checksum: "abc".to_string(),
        compression_ratio: 0.65,
    };
    assert!(
        metadata.compression_ratio > 0.0,
        "compression_ratio should be positive"
    );
}
// Constructs a PackedModelInfo and checks size invariants on the literals.
#[test]
fn test_packed_model_info_construction() {
    let info = PackedModelInfo {
        model_id: "bert-base-uncased".to_string(),
        name: "BERT Base Uncased".to_string(),
        version: "latest".to_string(),
        original_size: 1024 * 1024 * 440,
        compressed_size: 1024 * 1024 * 320,
        model_type: ModelType::TextClassification,
        framework: "transformers".to_string(),
        precision: PrecisionType::FP32,
        metadata: HashMap::new(),
    };
    assert_eq!(info.model_id, "bert-base-uncased");
    assert!(
        info.compressed_size <= info.original_size,
        "compressed_size should not exceed original_size after compression"
    );
}
// Enumerates every ModelType variant to catch accidental removals.
#[test]
fn test_packed_model_info_model_type_variants() {
    let types = [
        ModelType::TextGeneration,
        ModelType::TextClassification,
        ModelType::ImageClassification,
        ModelType::SpeechRecognition,
        ModelType::Translation,
        ModelType::Summarization,
        ModelType::QuestionAnswering,
        ModelType::Multimodal,
    ];
    assert_eq!(types.len(), 8, "should have 8 ModelType variants");
}
// Validates the invariants of PackCreationConfig::default().
#[test]
fn test_pack_creation_config_default() {
    let config = PackCreationConfig::default();
    assert!(
        config.compression_level <= 9,
        "compression_level should be in [0,9]"
    );
    assert!(
        !config.target_platforms.is_empty(),
        "target_platforms should not be empty by default"
    );
    assert!(
        config.max_pack_size.is_some(),
        "default max_pack_size should be set"
    );
    let max_size = config.max_pack_size.expect("max_pack_size should be set");
    assert!(max_size > 0, "max_pack_size should be positive");
}
// Exercises every compression level 0..=9 via struct-update syntax.
#[test]
fn test_pack_creation_config_compression_level_range() {
    for level in 0u8..=9 {
        let config = PackCreationConfig {
            compression_level: level,
            ..PackCreationConfig::default()
        };
        assert!(
            config.compression_level <= 9,
            "compression_level {} should be valid (0-9)",
            config.compression_level
        );
    }
}
// `new` must create the base directory on disk.
#[test]
fn test_offline_pack_manager_new_creates_directory() {
    let path = temp_dir_path();
    let _manager = OfflineModelPackManager::new(&path)
        .expect("OfflineModelPackManager::new should succeed");
    assert!(path.exists(), "base directory should be created");
    std::fs::remove_dir_all(&path).ok();
}
// NOTE(review): despite the name, this does not assert emptiness — it only
// exercises list_packs() without panicking.
#[test]
fn test_offline_pack_manager_list_packs_initially_empty() {
    let path = temp_dir_path();
    let manager = OfflineModelPackManager::new(&path)
        .expect("OfflineModelPackManager::new should succeed");
    let packs = manager.list_packs();
    let _ = packs.len();
    std::fs::remove_dir_all(&path).ok();
}
// Looking up an unknown pack id must return None, not panic.
#[test]
fn test_offline_pack_manager_get_pack_info_missing_returns_none() {
    let path = temp_dir_path();
    let manager = OfflineModelPackManager::new(&path)
        .expect("OfflineModelPackManager::new should succeed");
    let info = manager.get_pack_info("non-existent-pack-id");
    assert!(
        info.is_none(),
        "get_pack_info on missing pack should return None"
    );
    std::fs::remove_dir_all(&path).ok();
}
// End-to-end: create_pack succeeds and yields a non-empty id.
#[tokio::test]
async fn test_create_pack_returns_pack_id() {
    let path = temp_dir_path();
    let mut manager = OfflineModelPackManager::new(&path)
        .expect("OfflineModelPackManager::new should succeed");
    let config = PackCreationConfig::default();
    let pack_id = manager
        .create_pack(
            "Test Pack".to_string(),
            "A test pack for unit testing".to_string(),
            vec!["gpt2".to_string()],
            config,
        )
        .await
        .expect("create_pack should succeed");
    assert!(
        !pack_id.is_empty(),
        "create_pack should return non-empty pack_id"
    );
    std::fs::remove_dir_all(&path).ok();
}
// A created pack must be visible through list_packs().
#[tokio::test]
async fn test_create_pack_registers_in_list() {
    let path = temp_dir_path();
    let mut manager = OfflineModelPackManager::new(&path)
        .expect("OfflineModelPackManager::new should succeed");
    let config = PackCreationConfig::default();
    let pack_id = manager
        .create_pack(
            "Listed Pack".to_string(),
            "Pack that should appear in listing".to_string(),
            vec!["bert-base-uncased".to_string()],
            config,
        )
        .await
        .expect("create_pack should succeed");
    let packs = manager.list_packs();
    let found = packs.iter().any(|p| p.pack_id == pack_id);
    assert!(found, "newly created pack should appear in list_packs()");
    std::fs::remove_dir_all(&path).ok();
}
// The stored metadata must carry name, checksum, size and model entries.
#[tokio::test]
async fn test_create_pack_metadata_has_model_info() {
    let path = temp_dir_path();
    let mut manager = OfflineModelPackManager::new(&path)
        .expect("OfflineModelPackManager::new should succeed");
    let config = PackCreationConfig::default();
    let pack_id = manager
        .create_pack(
            "Metadata Test Pack".to_string(),
            "Testing metadata fields".to_string(),
            vec!["gpt2".to_string(), "bert-base-uncased".to_string()],
            config,
        )
        .await
        .expect("create_pack should succeed");
    let info = manager
        .get_pack_info(&pack_id)
        .expect("pack should be retrievable after creation");
    assert_eq!(info.name, "Metadata Test Pack");
    assert!(
        !info.checksum.is_empty(),
        "pack should have a non-empty integrity checksum"
    );
    assert!(info.total_size > 0, "pack should have positive total_size");
    assert!(
        !info.models.is_empty(),
        "pack should contain model information"
    );
    std::fs::remove_dir_all(&path).ok();
}
// Two consecutive create_pack calls must produce distinct (UUID) ids.
#[tokio::test]
async fn test_create_pack_pack_id_is_unique() {
    let path = temp_dir_path();
    let mut manager = OfflineModelPackManager::new(&path)
        .expect("OfflineModelPackManager::new should succeed");
    let config = PackCreationConfig::default();
    let id1 = manager
        .create_pack(
            "Pack A".to_string(),
            "First pack".to_string(),
            vec!["gpt2".to_string()],
            config.clone(),
        )
        .await
        .expect("first create_pack should succeed");
    let id2 = manager
        .create_pack(
            "Pack B".to_string(),
            "Second pack".to_string(),
            vec!["bert-base-uncased".to_string()],
            config,
        )
        .await
        .expect("second create_pack should succeed");
    assert_ne!(id1, id2, "each created pack should have a unique pack_id");
    std::fs::remove_dir_all(&path).ok();
}
// Every PrecisionType variant must serialize through serde_json.
#[test]
fn test_precision_type_variants_serializable() {
    let types = [
        PrecisionType::FP32,
        PrecisionType::FP16,
        PrecisionType::INT8,
        PrecisionType::INT4,
        PrecisionType::Mixed,
    ];
    for precision in &types {
        let serialized =
            serde_json::to_string(precision).expect("PrecisionType should be serializable");
        assert!(
            !serialized.is_empty(),
            "serialized precision should not be empty"
        );
    }
}
// Smoke-tests direct ModelInfo construction and field access.
#[test]
fn test_model_info_construction() {
    let info = ModelInfo {
        model_id: "test/model".to_string(),
        library_name: Some("transformers".to_string()),
        pipeline_tag: Some("text-generation".to_string()),
        tags: vec!["nlp".to_string()],
        config: HashMap::new(),
        downloads: Some(5000),
        likes: Some(200),
        created_at: None,
        updated_at: None,
        author: Some("test-author".to_string()),
        description: Some("A test model".to_string()),
        license: Some("apache-2.0".to_string()),
        task: Some("text-generation".to_string()),
        language: vec!["en".to_string()],
        dataset: vec![],
        model_type: None,
        architecture: None,
    };
    assert_eq!(info.model_id, "test/model");
    assert_eq!(info.pipeline_tag.as_deref(), Some("text-generation"));
    assert_eq!(info.downloads, Some(5000));
}
}