#![allow(clippy::all)]
use rstest::*;
use std::time::Duration;
use trustformers::config_management::ConfigurationManager;
use trustformers::error::TrustformersError;
use trustformers::hub::{get_cache_dir, is_cached, HubOptions};
use trustformers::prelude::*;
#[rstest]
fn test_cache_directory_creation() {
let cache_dir = get_cache_dir();
assert!(
cache_dir.is_ok(),
"Should be able to determine cache directory"
);
let cache_path = cache_dir.expect("operation failed in test");
assert!(cache_path.is_absolute(), "Cache path should be absolute");
}
#[rstest]
fn test_hub_options_default() {
let options = HubOptions::default();
assert_eq!(options.revision, Some("main".to_string()));
assert!(options.parallel_downloads);
assert_eq!(options.max_concurrent_downloads, 4);
assert!(options.enable_resumable_downloads);
assert!(options.enable_delta_compression);
assert!(options.smart_caching);
}
#[rstest]
fn test_model_cache_check() {
let result = is_cached("nonexistent-model-123456", None);
assert!(result.is_ok());
assert!(
!result.expect("operation failed in test"),
"Nonexistent model should not be cached"
);
}
#[rstest]
fn test_config_manager_validation() {
let config_manager = ConfigurationManager::new();
let valid_config = serde_json::json!({
"model_type": "bert",
"hidden_size": 768,
"num_attention_heads": 12,
"num_hidden_layers": 12
});
let result = config_manager.validate_config("model", &valid_config);
assert!(result.is_valid, "Valid config should pass validation");
}
#[rstest]
fn test_error_system_functionality() {
use trustformers::error::TrustformersError;
let error = TrustformersError::model("Test model error", "test-model");
assert!(error.to_string().contains("Test model error"));
}
#[rstest]
#[tokio::test]
async fn test_profiler_basic_functionality() {
use trustformers::profiler::Profiler;
let _profiler = Profiler::new().expect("operation failed in test");
assert!(true); }
#[rstest]
fn test_pipeline_error_handling() {
use trustformers::pipeline::text_classification::TextClassificationPipeline;
let result: Result<TextClassificationPipeline> = Err(TrustformersError::model(
"Model not found for testing",
"nonexistent-model-xyz",
));
assert!(result.is_err(), "Should error for nonexistent model");
if let Err(error) = result {
match error {
TrustformersError::Model { .. } => {
},
TrustformersError::Hub { .. } => {
},
TrustformersError::Core(_) => {
},
_ => panic!("Unexpected error type: {:?}", error),
}
}
}
#[rstest]
fn test_processor_auto_detection() {
use trustformers::processor::Modality;
let modality = Modality::Text;
assert!(matches!(modality, Modality::Text));
}
#[rstest]
fn test_auto_classes_functionality() {
use trustformers::auto::metrics::AutoMetric;
use trustformers::auto_classes::{AutoDataCollator, AutoFeatureExtractor};
let config = serde_json::json!({"model_type": "vit"});
let extractor = AutoFeatureExtractor::for_task("image-classification", &config);
assert!(extractor.is_ok());
let collator = AutoDataCollator::for_task("text-classification", &serde_json::json!({}));
assert!(collator.is_ok());
let metric = AutoMetric::for_task("text-classification");
assert!(metric.is_ok());
}
#[rstest]
fn test_streaming_pipeline_creation() {
use trustformers::pipeline::streaming::StreamConfig;
let _config = StreamConfig {
buffer_size: 4096,
max_concurrent: 4,
backpressure_threshold: 0.8,
timeout_ms: 30000,
enable_partial_results: true,
enable_transformations: true,
batch_size: Some(1024),
flush_interval_ms: 100,
};
let pipeline: Result<()> = Ok(());
assert!(pipeline.is_ok());
}
#[rstest]
fn test_config_migration_system() {
use trustformers::config_management::ConfigurationManager;
let _manager = ConfigurationManager::new();
assert!(true);
}
#[rstest]
fn test_hub_download_configuration() {
use trustformers::hub::{create_download_config_for_scenario, DownloadScenario};
let dev_config = create_download_config_for_scenario(DownloadScenario::FastDevelopment);
assert!(dev_config.parallel_downloads);
assert_eq!(dev_config.max_concurrent, 8);
assert!(!dev_config.verify_checksums);
let prod_config = create_download_config_for_scenario(DownloadScenario::Production);
assert!(prod_config.parallel_downloads);
assert_eq!(prod_config.max_concurrent, 4);
assert!(prod_config.verify_checksums);
let bandwidth_config = create_download_config_for_scenario(DownloadScenario::BandwidthLimited);
assert!(!bandwidth_config.parallel_downloads); assert_eq!(bandwidth_config.max_concurrent, 1);
}
#[rstest]
fn test_dynamic_batching_configuration() {
use trustformers::pipeline::dynamic_batching::DynamicBatchingConfig;
let config = DynamicBatchingConfig {
initial_batch_size: 8,
min_batch_size: 1,
max_batch_size: 32,
target_latency_ms: 100,
max_wait_time_ms: 50,
throughput_threshold: 10.0,
performance_window_size: 10,
adjustment_factor: 1.2,
};
assert!(config.min_batch_size <= config.max_batch_size);
assert!(
config.initial_batch_size >= config.min_batch_size
&& config.initial_batch_size <= config.max_batch_size
);
}
#[rstest]
fn test_caching_system_integration() {
use trustformers::pipeline::advanced_caching::{AdvancedCache, CacheConfig};
let config = CacheConfig {
max_entries: 1000,
max_memory_bytes: 512 * 1024 * 1024, ttl_seconds: 3600,
cleanup_interval_seconds: 300,
lru_eviction_threshold: 0.8,
smart_eviction_threshold: 0.9,
enable_hit_rate_tracking: true,
enable_memory_pressure_monitoring: true,
enable_access_pattern_analysis: true,
};
let _cache: AdvancedCache<Vec<u8>> = AdvancedCache::new(config);
}
mod performance_tests {
use super::*;
use std::time::Instant;
#[rstest]
fn benchmark_config_validation() {
let config_manager = ConfigurationManager::new();
let config = serde_json::json!({
"model_type": "bert",
"hidden_size": 768,
"num_attention_heads": 12,
"num_hidden_layers": 12,
"vocab_size": 30522
});
let start = Instant::now();
for _ in 0..1000 {
let _ = config_manager.validate_config("model", &config);
}
let duration = start.elapsed();
assert!(
duration < Duration::from_millis(100),
"Config validation taking too long: {:?}",
duration
);
}
#[rstest]
fn benchmark_error_creation() {
let start = Instant::now();
for i in 0..10000 {
let _error = TrustformersError::pipeline(format!("Test error {}", i), "benchmark");
}
let duration = start.elapsed();
assert!(
duration < Duration::from_millis(50),
"Error creation taking too long: {:?}",
duration
);
}
}
mod workflow_tests {
use super::*;
#[rstest]
fn test_complete_text_classification_workflow() {
use trustformers::pipeline::text_classification::TextClassificationConfig;
let config = TextClassificationConfig {
max_length: 512,
labels: vec!["NEGATIVE".to_string(), "POSITIVE".to_string()],
return_all_scores: false,
};
assert!(config.max_length > 0);
assert_eq!(config.labels.len(), 2);
}
#[rstest]
#[cfg(feature = "vision")]
fn test_multimodal_pipeline_configuration() {
#[allow(unused_imports)]
use trustformers::pipeline::multimodal::{
AttentionConfig, FusionStrategy, MultiModalConfig, MultiModalPipeline,
};
let config = MultiModalConfig {
max_text_length: 512,
max_image_size: (224, 224),
max_audio_duration: 30.0,
fusion_strategy: FusionStrategy::GatedFusion,
normalize_inputs: true,
attention_config: AttentionConfig::default(),
cross_modal_attention: true,
temperature: 1.0,
top_k: Some(50),
top_p: Some(0.9),
};
assert!(config.max_text_length > 0);
assert!(config.max_image_size.0 > 0 && config.max_image_size.1 > 0);
}
#[rstest]
fn test_conversational_pipeline_workflow() {
use trustformers::pipeline::conversational::{
ConversationalConfig, MemoryConfig, SummarizationConfig, SummarizationStrategy,
};
let memory_config = ConversationalConfig {
max_history_turns: 100,
max_context_tokens: 4096,
enable_summarization: true,
memory_config: MemoryConfig {
enabled: true,
decay_rate: 0.95,
compression_threshold: 0.3,
max_memories: 50000,
persist_important_memories: true,
},
summarization_config: SummarizationConfig {
enabled: true,
trigger_threshold: 800,
target_length: 512,
strategy: SummarizationStrategy::Hybrid,
preserve_recent_turns: 5,
},
..Default::default()
};
assert!(memory_config.max_history_turns > 0);
assert!(
memory_config.memory_config.decay_rate > 0.0
&& memory_config.memory_config.decay_rate <= 1.0
);
assert!(
memory_config.memory_config.compression_threshold >= 0.0
&& memory_config.memory_config.compression_threshold <= 1.0
);
}
}
#[cfg(feature = "async")]
mod async_tests {
use super::*;
use tokio::time::timeout;
#[rstest]
#[tokio::test]
async fn test_async_pipeline_processing() {
let result = timeout(Duration::from_secs(5), async {
tokio::time::sleep(Duration::from_millis(100)).await;
Ok::<_, TrustformersError>(())
})
.await;
assert!(
result.is_ok(),
"Async operation should complete within timeout"
);
assert!(
result.expect("operation failed in test").is_ok(),
"Async operation should succeed"
);
}
#[rstest]
#[tokio::test]
async fn test_concurrent_pipeline_operations() {
use tokio::task::JoinSet;
let mut set = JoinSet::new();
for i in 0..10 {
set.spawn(async move {
tokio::time::sleep(Duration::from_millis(10)).await;
i * 2
});
}
let mut results = Vec::new();
while let Some(res) = set.join_next().await {
results.push(res.expect("operation failed in test"));
}
assert_eq!(
results.len(),
10,
"All concurrent operations should complete"
);
}
}