#![allow(clippy::result_large_err)]
use pandrs::ml::serving::{
BatchPredictionRequest, DeploymentConfig, HealthCheckConfig, ModelMetadata, ModelServer,
ModelServing, MonitoringConfig, PredictionRequest, ResourceConfig, ScalingConfig, ServerConfig,
};
use pandrs::ml::serving::serialization::{
GenericServingModel, JsonModelSerializer, ModelSerializer, SerializableModel,
SerializationFormat,
};
use pandrs::ml::serving::registry::{
FileSystemModelRegistry, InMemoryModelRegistry, ModelRegistry,
};
use pandrs::ml::serving::deployment::{
DeployedModel, DeploymentManager, DeploymentMetrics, DeploymentStatus,
};
use pandrs::ml::serving::monitoring::{
AlertConfig, AlertSeverity, ComparisonOperator, DefaultMetricsCollector, MetricsCollector,
ModelMonitor,
};
use pandrs::ml::serving::endpoints::{
BatchPredictionEndpoint, HealthEndpoint, ModelInfoEndpoint, PredictionEndpoint,
RequestValidator,
};
use pandrs::ml::serving::server::{HttpModelServer, RateLimiter, RequestContext};
use std::collections::HashMap;
use tempfile::{NamedTempFile, TempDir};
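
/// Builds sample `ModelMetadata` for a two-feature linear-regression test
/// model, with `r2_score` and `mse` metrics attached.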
fn create_test_metadata() -> ModelMetadata {
let mut metadata = ModelMetadata {
name: "test_model".to_string(),
version: "1.0.0".to_string(),
model_type: "linear_regression".to_string(),
feature_names: ["feature1".to_string(), "feature2".to_string()].to_vec(),
target_name: Some("target".to_string()),
description: "Test model for serving".to_string(),
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
metrics: HashMap::new(),
metadata: HashMap::new(),
};
metadata.metrics.insert("r2_score".to_string(), 0.85);
metadata.metrics.insert("mse".to_string(), 0.15);
metadata
}
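
/// Wraps the test metadata in a `SerializableModel` carrying linear
/// coefficients, an intercept, and a standard-scaler preprocessing step.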
fn create_test_serializable_model() -> SerializableModel {
let metadata = create_test_metadata();
let mut parameters = HashMap::new();
parameters.insert("coefficients".to_string(), serde_json::json!([1.5, -0.8]));
parameters.insert("intercept".to_string(), serde_json::json!(2.3));
SerializableModel {
metadata,
parameters,
model_data: serde_json::json!({"type": "linear_regression"}),
preprocessing: Some(serde_json::json!({"scaler": "standard"})),
config: HashMap::new(),
}
}
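
/// Builds a single-row `PredictionRequest` pinned to model version 1.0.0.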
fn create_test_prediction_request() -> PredictionRequest {
let mut data = HashMap::new();
data.insert("feature1".to_string(), serde_json::json!(1.5));
data.insert("feature2".to_string(), serde_json::json!(2.0));
PredictionRequest {
data,
model_version: Some("1.0.0".to_string()),
options: None,
}
}
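
/// Builds a small test `DeploymentConfig`: one CPU core, 1 GiB of memory,
/// one to three instances, and basic health-check and monitoring settings.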
fn create_test_deployment_config() -> DeploymentConfig {
DeploymentConfig {
model_name: "test_model".to_string(),
model_version: "1.0.0".to_string(),
environment: "test".to_string(),
resources: ResourceConfig {
cpu_cores: 1.0,
memory_mb: 1024,
gpu_memory_mb: None,
max_concurrent_requests: 10,
},
scaling: ScalingConfig {
min_instances: 1,
max_instances: 3,
target_cpu_utilization: 0.7,
target_memory_utilization: 0.8,
scale_up_threshold: 0.8,
scale_down_threshold: 0.3,
},
health_check: HealthCheckConfig {
path: "/health".to_string(),
interval_seconds: 30,
timeout_seconds: 5,
failure_threshold: 3,
success_threshold: 2,
},
monitoring: MonitoringConfig {
enable_metrics: true,
enable_logging: true,
enable_tracing: false,
metrics_interval_seconds: 60,
log_level: "info".to_string(),
},
}
}

#[test]
fn test_model_serialization_json() {
let model = create_test_serializable_model();
let serializer = JsonModelSerializer;
    // Round-trip through raw bytes.
    let serialized = serializer.serialize(&model).unwrap();
let deserialized = serializer.deserialize(&serialized).unwrap();
assert_eq!(model.metadata.name, deserialized.metadata.name);
assert_eq!(model.metadata.version, deserialized.metadata.version);
assert_eq!(model.metadata.model_type, deserialized.metadata.model_type);
assert_eq!(model.parameters.len(), deserialized.parameters.len());
    // Round-trip through a temp file on disk.
    let temp_file = NamedTempFile::new().unwrap();
serializer.save(&model, temp_file.path()).unwrap();
let loaded_model = serializer.load(temp_file.path()).unwrap();
assert_eq!(model.metadata.name, loaded_model.get_metadata().name);
assert_eq!(model.metadata.version, loaded_model.get_metadata().version);
}

#[test]
fn test_serialization_factory() {
let formats = [
SerializationFormat::Json,
SerializationFormat::Yaml,
SerializationFormat::Toml,
SerializationFormat::Binary,
];
for format in &formats {
assert_eq!(
format.extension(),
match format {
SerializationFormat::Json => "json",
SerializationFormat::Yaml => "yaml",
SerializationFormat::Toml => "toml",
SerializationFormat::Binary => "bin",
}
);
}
assert_eq!(
SerializationFormat::from_extension("json"),
Some(SerializationFormat::Json)
);
assert_eq!(
SerializationFormat::from_extension("yaml"),
Some(SerializationFormat::Yaml)
);
assert_eq!(
SerializationFormat::from_extension("yml"),
Some(SerializationFormat::Yaml)
);
assert_eq!(
SerializationFormat::from_extension("toml"),
Some(SerializationFormat::Toml)
);
assert_eq!(
SerializationFormat::from_extension("bin"),
Some(SerializationFormat::Binary)
);
assert_eq!(SerializationFormat::from_extension("unknown"), None);
}
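
// Hedged sketch: the assertions above already pair every variant's
// `extension()` with its `from_extension()` result, so the same relationship
// can be stated as a round trip. This assumes only what those assertions
// imply: the two functions remain inverses for the canonical extensions.
#[test]
fn test_serialization_format_extension_round_trip() {
    for format in [
        SerializationFormat::Json,
        SerializationFormat::Yaml,
        SerializationFormat::Toml,
        SerializationFormat::Binary,
    ] {
        assert_eq!(
            SerializationFormat::from_extension(format.extension()),
            Some(format)
        );
    }
}
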
#[test]
fn test_generic_serving_model() {
let serializable_model = create_test_serializable_model();
let serving_model = GenericServingModel::from_serializable(serializable_model).unwrap();
let metadata = serving_model.get_metadata();
assert_eq!(metadata.name, "test_model");
assert_eq!(metadata.version, "1.0.0");
assert_eq!(metadata.model_type, "linear_regression");
let request = create_test_prediction_request();
let response = serving_model.predict(&request).unwrap();
assert!(response.prediction.is_object() || response.prediction.is_number());
assert_eq!(response.model_metadata.name, "test_model");
let batch_request = BatchPredictionRequest {
        data: vec![request.data.clone(), request.data.clone()],
model_version: Some("1.0.0".to_string()),
options: None,
};
let batch_response = serving_model.predict_batch(&batch_request).unwrap();
assert_eq!(batch_response.predictions.len(), 2);
assert_eq!(batch_response.summary.total_predictions, 2);
assert!(batch_response.summary.successful_predictions > 0);
let health_status = serving_model.health_check().unwrap();
assert_eq!(health_status.status, "healthy");
let info = serving_model.info();
assert_eq!(info.metadata.name, "test_model");
}

#[test]
fn test_in_memory_model_registry() {
let mut registry = InMemoryModelRegistry::new();
assert!(registry.list_models().unwrap().is_empty());
assert!(!registry.exists("test_model", "1.0.0"));
let serializable_model = create_test_serializable_model();
let serving_model = GenericServingModel::from_serializable(serializable_model).unwrap();
let boxed_model: Box<dyn ModelServing> = Box::new(serving_model);
registry.register_model(boxed_model).unwrap();
assert!(registry.exists("test_model", "1.0.0"));
let models = registry.list_models().unwrap();
assert_eq!(models.len(), 1);
assert_eq!(models[0].name, "test_model");
assert_eq!(models[0].versions, ["1.0.0"]);
let versions = registry.list_versions("test_model").unwrap();
assert_eq!(versions, ["1.0.0"]);
let latest = registry.get_latest_version("test_model").unwrap();
assert_eq!(latest, "1.0.0");
let default = registry.get_default_version("test_model").unwrap();
assert_eq!(default, "1.0.0");
let metadata = registry.get_metadata("test_model", "1.0.0").unwrap();
assert_eq!(metadata.name, "test_model");
}

#[test]
fn test_filesystem_model_registry() {
let temp_dir = TempDir::new().unwrap();
let mut registry = FileSystemModelRegistry::new(temp_dir.path()).unwrap();
assert!(registry.list_models().unwrap().is_empty());
let serializable_model = create_test_serializable_model();
let serving_model = GenericServingModel::from_serializable(serializable_model).unwrap();
let boxed_model: Box<dyn ModelServing> = Box::new(serving_model);
registry.register_model(boxed_model).unwrap();
let model_dir = temp_dir.path().join("test_model");
assert!(model_dir.exists());
let model_file = model_dir.join("1.0.0.json");
assert!(model_file.exists());
let loaded_model = registry.load_model("test_model", "1.0.0").unwrap();
assert_eq!(loaded_model.get_metadata().name, "test_model");
registry.delete_model("test_model", "1.0.0").unwrap();
assert!(!registry.exists("test_model", "1.0.0"));
}

#[test]
fn test_model_deployment() {
let serializable_model = create_test_serializable_model();
let serving_model = GenericServingModel::from_serializable(serializable_model).unwrap();
let boxed_model: Box<dyn ModelServing> = Box::new(serving_model);
let config = create_test_deployment_config();
let deployed_model = DeployedModel::new(boxed_model, config).unwrap();
let metrics = deployed_model.get_metrics().unwrap();
assert_eq!(metrics.status, DeploymentStatus::Running);
assert_eq!(metrics.active_instances, 1);
let request = create_test_prediction_request();
let _response = deployed_model.predict(&request).unwrap();
    assert!(!deployed_model.should_scale_up().unwrap());
    assert!(!deployed_model.should_scale_down().unwrap());
deployed_model.scale_up().unwrap();
let metrics = deployed_model.get_metrics().unwrap();
assert_eq!(metrics.active_instances, 2);
deployed_model.scale_down().unwrap();
let metrics = deployed_model.get_metrics().unwrap();
assert_eq!(metrics.active_instances, 1);
let health = deployed_model.health_check().unwrap();
assert_eq!(health.status, "healthy");
}

#[test]
fn test_deployment_manager() {
let mut manager = DeploymentManager::new();
assert!(manager.list_deployments().is_empty());
let serializable_model = create_test_serializable_model();
let serving_model = GenericServingModel::from_serializable(serializable_model).unwrap();
let boxed_model: Box<dyn ModelServing> = Box::new(serving_model);
let config = create_test_deployment_config();
manager
.deploy("test_deployment".to_string(), boxed_model, config)
.unwrap();
let deployments = manager.list_deployments();
assert_eq!(deployments.len(), 1);
assert_eq!(deployments[0], "test_deployment");
let deployment = manager.get_deployment("test_deployment").unwrap();
assert_eq!(deployment.get_config().model_name, "test_model");
let metrics = manager.get_deployment_metrics("test_deployment").unwrap();
assert_eq!(metrics.status, DeploymentStatus::Running);
manager.scale_deployment("test_deployment", 2).unwrap();
let metrics = manager.get_deployment_metrics("test_deployment").unwrap();
assert_eq!(metrics.active_instances, 2);
let health_statuses = manager.health_check_all();
assert_eq!(health_statuses.len(), 1);
assert!(health_statuses.contains_key("test_deployment"));
manager.undeploy("test_deployment").unwrap();
assert!(manager.list_deployments().is_empty());
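    // Hedged sketch (left unasserted): undeploying a name that no longer
    // exists presumably returns an error rather than silently succeeding:
    // assert!(manager.undeploy("test_deployment").is_err());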
}

#[test]
fn test_model_monitoring() {
let metadata = create_test_metadata();
let mut monitor = ModelMonitor::new(metadata);
let deployment_metrics = DeploymentMetrics {
status: DeploymentStatus::Running,
active_instances: 1,
cpu_utilization: 0.5,
memory_utilization: 0.6,
request_rate: 10.0,
avg_response_time_ms: 100.0,
error_rate: 0.01,
total_requests: 1000,
successful_requests: 990,
failed_requests: 10,
last_health_check: chrono::Utc::now(),
started_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
};
let alert_config = AlertConfig {
name: "high_latency".to_string(),
description: "Alert when latency is too high".to_string(),
metric: "avg_latency_ms".to_string(),
threshold: 200.0,
operator: ComparisonOperator::GreaterThan,
severity: AlertSeverity::Warning,
evaluation_window_seconds: 300,
consecutive_evaluations: 1,
cooldown_seconds: 600,
enabled: true,
};
monitor.add_alert(alert_config);
monitor.collect_metrics(&deployment_metrics).unwrap();
let recent_metrics = monitor.get_recent_metrics(10);
if !recent_metrics.is_empty() {
assert_eq!(recent_metrics[0].model_name, "test_model");
assert!(recent_metrics[0].latency.avg_latency_ms > 0.0);
}
    if let Some(summary) = monitor.get_metrics_summary(60) {
assert_eq!(summary.window_minutes, 60);
assert!(summary.avg_latency_ms > 0.0);
}
let alert_configs = monitor.get_alert_configs();
assert_eq!(alert_configs.len(), 1);
assert_eq!(alert_configs[0].name, "high_latency");
}

#[test]
fn test_metrics_collector() {
let collector = DefaultMetricsCollector;
let system_metrics = collector.collect_system_metrics().unwrap();
assert!(system_metrics.cpu_usage >= 0.0 && system_metrics.cpu_usage <= 1.0);
assert!(system_metrics.memory_usage > 0);
assert!(system_metrics.memory_available > 0);
let model_metrics = collector.collect_model_metrics("test_model").unwrap();
assert!(model_metrics.model_memory_usage > 0);
assert!(model_metrics.cache_hit_rate >= 0.0 && model_metrics.cache_hit_rate <= 1.0);
}

#[test]
fn test_model_server() {
let config = ServerConfig::default();
let mut server = ModelServer::new(config);
assert!(server.list_models().is_empty());
let serializable_model = create_test_serializable_model();
let serving_model = GenericServingModel::from_serializable(serializable_model).unwrap();
let boxed_model: Box<dyn ModelServing> = Box::new(serving_model);
server
.register_model("test_model".to_string(), boxed_model)
.unwrap();
let models = server.list_models();
assert_eq!(models.len(), 1);
assert_eq!(models[0], "test_model");
let model = server.get_model("test_model").unwrap();
assert_eq!(model.get_metadata().name, "test_model");
server.unregister_model("test_model").unwrap();
assert!(server.list_models().is_empty());
}

#[test]
fn test_prediction_endpoints() {
let config = ServerConfig::default();
let mut server = ModelServer::new(config);
let serializable_model = create_test_serializable_model();
let serving_model = GenericServingModel::from_serializable(serializable_model).unwrap();
let boxed_model: Box<dyn ModelServing> = Box::new(serving_model);
server
.register_model("test_model".to_string(), boxed_model)
.unwrap();
let request = create_test_prediction_request();
let response =
PredictionEndpoint::predict(&server, "test_model", request, Some("req-123".to_string()));
assert!(response.success);
assert!(response.data.is_some());
assert_eq!(response.request_id, Some("req-123".to_string()));
let batch_request = BatchPredictionRequest {
data: vec![
create_test_prediction_request().data,
create_test_prediction_request().data,
],
model_version: Some("1.0.0".to_string()),
options: None,
};
let batch_response = BatchPredictionEndpoint::predict_batch(
&server,
"test_model",
batch_request,
Some("batch-456".to_string()),
);
assert!(batch_response.success);
assert!(batch_response.data.is_some());
let info_response =
ModelInfoEndpoint::get_model_info(&server, "test_model", Some("info-789".to_string()));
assert!(info_response.success);
assert!(info_response.data.is_some());
let health_response =
HealthEndpoint::health_check_model(&server, "test_model", Some("health-999".to_string()));
assert!(health_response.success);
assert!(health_response.data.is_some());
}

#[test]
fn test_request_validation() {
let request_id = RequestValidator::generate_request_id();
assert!(!request_id.is_empty());
assert!(RequestValidator::validate_api_key(None, None));
assert!(RequestValidator::validate_api_key(Some("key"), Some("key")));
assert!(!RequestValidator::validate_api_key(
Some("key1"),
Some("key2")
));
assert!(!RequestValidator::validate_api_key(None, Some("key")));
assert!(RequestValidator::validate_request_size(100, 200));
assert!(!RequestValidator::validate_request_size(300, 200));
let sanitized = RequestValidator::sanitize_model_name("test-model_123!@#$%");
assert_eq!(sanitized, "test-model_123");
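    // Hedged sketch (left unasserted): if `sanitize_model_name` strips every
    // character outside [A-Za-z0-9_-], an all-symbol name would come back
    // empty:
    // assert_eq!(RequestValidator::sanitize_model_name("!@#$%"), "");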
}

#[test]
fn test_rate_limiter() {
let rate_limiter = RateLimiter::new(3, 1);
for i in 0..3 {
assert!(
rate_limiter.check_rate_limit("client1").unwrap(),
"Request {} should be allowed",
i + 1
);
}
assert!(!rate_limiter.check_rate_limit("client1").unwrap());
assert!(rate_limiter.check_rate_limit("client2").unwrap());
    // Three allowed attempts plus the denied fourth are all recorded.
    assert_eq!(rate_limiter.get_request_count("client1").unwrap(), 4);
    assert_eq!(rate_limiter.get_request_count("client2").unwrap(), 1);
}
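
// Hedged sketch: `RateLimiter::new(3, 1)` reads as "3 requests per 1-second
// window", so a client should be admitted again once the window rolls over.
// Marked #[ignore] because the window semantics are assumed rather than
// confirmed, and sleep-based tests are flaky in CI.
#[test]
#[ignore]
fn test_rate_limiter_window_reset() {
    let rate_limiter = RateLimiter::new(1, 1);
    assert!(rate_limiter.check_rate_limit("client").unwrap());
    assert!(!rate_limiter.check_rate_limit("client").unwrap());
    // Wait out the assumed one-second window before retrying.
    std::thread::sleep(std::time::Duration::from_secs(2));
    assert!(rate_limiter.check_rate_limit("client").unwrap());
}
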
#[test]
fn test_http_model_server() {
let config = ServerConfig {
host: "127.0.0.1".to_string(),
port: 8081,
max_request_size: 1024,
request_timeout_seconds: 30,
enable_cors: true,
enable_auth: false,
api_key: None,
};
let mut server = HttpModelServer::new(config);
let serializable_model = create_test_serializable_model();
let serving_model = GenericServingModel::from_serializable(serializable_model).unwrap();
let boxed_model: Box<dyn ModelServing> = Box::new(serving_model);
server
.register_model("test_model".to_string(), boxed_model)
.unwrap();
let request = create_test_prediction_request();
let context = RequestContext::with_id("test-request".to_string());
let response = server.handle_predict("test_model", request, context);
assert_eq!(response.status_code, 200);
assert!(response.body.success);
let context = RequestContext::with_id("info-request".to_string());
let info_response = server.handle_model_info("test_model", context);
assert_eq!(info_response.status_code, 200);
assert!(info_response.body.success);
let context = RequestContext::with_id("health-request".to_string());
let health_response = server.handle_health_check(Some("test_model"), context);
assert_eq!(health_response.status_code, 200);
assert!(health_response.body.success);
let stats = server.get_server_stats().unwrap();
assert!(stats.total_requests > 0);
assert_eq!(stats.active_models, 1);
let routes = server.get_routes();
assert!(!routes.is_empty());
let predict_route = routes
.iter()
.find(|route| route.path.contains("predict") && !route.path.contains("batch"))
.expect("Prediction route should exist");
assert_eq!(predict_route.method, "POST");
assert!(predict_route.body_required);
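    // Hedged sketch (left unasserted): assuming health checks are also listed
    // in the route table, a GET route whose path contains "health" should be
    // present:
    // let health_route = routes
    //     .iter()
    //     .find(|route| route.path.contains("health"))
    //     .expect("Health route should exist");
    // assert_eq!(health_route.method, "GET");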
}

#[test]
fn test_model_serving_factory() {
let serializable_model = create_test_serializable_model();
let serving_model = GenericServingModel::from_serializable(serializable_model).unwrap();
assert_eq!(serving_model.get_metadata().name, "test_model");
assert_eq!(serving_model.get_metadata().version, "1.0.0");
let request = create_test_prediction_request();
let _response = serving_model.predict(&request).unwrap();
let health = serving_model.health_check().unwrap();
assert_eq!(health.status, "healthy");
}

#[test]
fn test_error_handling() {
let config = ServerConfig::default();
let server = ModelServer::new(config);
let request = create_test_prediction_request();
let response = PredictionEndpoint::predict(&server, "nonexistent_model", request, None);
assert!(!response.success);
assert!(response.error.is_some());
assert!(response.error.unwrap().contains("not found"));
let info_response = ModelInfoEndpoint::get_model_info(&server, "nonexistent_model", None);
assert!(!info_response.success);
assert!(info_response.error.is_some());
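    // Hedged extra check: assuming the batch endpoint wraps failures in the
    // same response envelope as the single-prediction endpoint above, an
    // unknown model should fail the same way.
    let batch_request = BatchPredictionRequest {
        data: vec![create_test_prediction_request().data],
        model_version: None,
        options: None,
    };
    let batch_response =
        BatchPredictionEndpoint::predict_batch(&server, "nonexistent_model", batch_request, None);
    assert!(!batch_response.success);
    assert!(batch_response.error.is_some());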
}

#[test]
fn test_comprehensive_workflow() {
let serializable_model = create_test_serializable_model();
let temp_dir = TempDir::new().unwrap();
let model_file = temp_dir.path().join("test_model.json");
let serializer = JsonModelSerializer;
serializer.save(&serializable_model, &model_file).unwrap();
let loaded_serving_model = serializer.load(&model_file).unwrap();
let config = create_test_deployment_config();
let deployed_model = DeployedModel::new(loaded_serving_model, config).unwrap();
let request = create_test_prediction_request();
let response = deployed_model.predict(&request).unwrap();
assert_eq!(response.model_metadata.name, "test_model");
let metrics = deployed_model.get_metrics().unwrap();
assert_eq!(metrics.status, DeploymentStatus::Running);
assert!(metrics.total_requests > 0);
let health = deployed_model.health_check().unwrap();
assert_eq!(health.status, "healthy");
println!("Comprehensive workflow test completed successfully!");
}

#[test]
#[cfg(feature = "serving")]
fn test_serving_feature_integration() {
let config = ServerConfig::default();
let server = ModelServer::new(config);
assert!(server.list_models().is_empty());
let registry = InMemoryModelRegistry::new();
assert!(registry.list_models().unwrap().is_empty());
let metadata = create_test_metadata();
let monitor = ModelMonitor::new(metadata);
assert_eq!(monitor.get_recent_metrics(10).len(), 0);
println!("Serving feature integration test passed!");
}