use serde::{Deserialize, Serialize};
/// Configuration for ML model validation.
///
/// A `Default` implementation is provided through the `ValidationConfig` alias.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MLValidationConfig {
    /// Validation strategies to apply.
    pub validation_methods: Vec<ValidationMethod>,
    /// Metrics to compute during validation.
    pub performance_metrics: Vec<PerformanceMetric>,
    /// Whether statistical testing is enabled.
    pub statistical_testing: bool,
    /// Per-check switches for robustness testing.
    pub robustness_testing: RobustnessTestingConfig,
    /// Whether fairness evaluation is enabled.
    ///
    /// NOTE(review): overlaps with `robustness_testing.fairness_testing`;
    /// confirm which of the two flags consumers actually read.
    pub fairness_evaluation: bool,
}
/// A validation strategy that can be listed in
/// `MLValidationConfig::validation_methods`.
///
/// Variant order is significant for index-based serde formats; append new
/// variants rather than reordering.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ValidationMethod {
    /// Cross-validation.
    CrossValidation,
    /// Hold-out (train/validation split) validation.
    HoldoutValidation,
    /// Bootstrap resampling validation.
    BootstrapValidation,
    /// Validation tailored to time-series data.
    TimeSeriesValidation,
    /// Walk-forward validation.
    WalkForwardValidation,
}
/// A metric that can be requested via `MLValidationConfig::performance_metrics`.
///
/// Variant names double as their serialized names, so they are kept verbatim
/// (including the all-caps abbreviations).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum PerformanceMetric {
    /// Accuracy.
    Accuracy,
    /// Precision.
    Precision,
    /// Recall.
    Recall,
    /// F1 score.
    F1Score,
    /// Area under a curve; which curve (ROC, PR, ...) is not fixed here.
    AUC,
    /// Mean absolute error.
    MAE,
    /// Mean squared error.
    MSE,
    /// Root mean squared error.
    RMSE,
    /// R² score.
    R2Score,
    /// Logarithmic loss.
    LogLoss,
}
/// Switches for the individual robustness checks run during validation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RobustnessTestingConfig {
    /// Enables robustness testing.
    ///
    /// NOTE(review): presumably a master switch gating the flags below —
    /// confirm against the code that consumes this config.
    pub enable_testing: bool,
    /// Enables adversarial testing.
    pub adversarial_testing: bool,
    /// Enables distribution-shift testing.
    pub distribution_shift_testing: bool,
    /// Enables noise-sensitivity testing.
    pub noise_sensitivity_testing: bool,
    /// Enables fairness testing.
    pub fairness_testing: bool,
}
/// Shorter name for [`MLValidationConfig`]; both refer to the same type.
pub type ValidationConfig = MLValidationConfig;
/// Settings controlling model inference.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceConfig {
    /// Number of items per inference batch.
    pub batch_size: usize,
    /// Inference time limit; whether it applies per request or per batch is
    /// not defined here.
    pub timeout: std::time::Duration,
    /// Whether a GPU should be used.
    pub use_gpu: bool,
    /// Numeric precision used for inference.
    pub precision: InferencePrecision,
    /// Cache settings.
    pub caching: CachingConfig,
}
/// Settings for storing, versioning, monitoring and deploying models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelManagementConfig {
    /// Whether model versioning is enabled.
    pub versioning: bool,
    /// Filesystem location where models are stored.
    pub storage_path: String,
    /// Age, retirement and backup policy.
    pub lifecycle_policy: ModelLifecyclePolicy,
    /// Monitoring settings.
    pub monitoring: ModelMonitoringConfig,
    /// How new model versions are rolled out.
    pub deployment_strategy: DeploymentStrategy,
}
/// Floating-point precision used during inference.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum InferencePrecision {
    /// 32-bit floats.
    Float32,
    /// 64-bit floats.
    Float64,
    /// Mixed precision.
    Mixed,
}
/// Cache settings used by [`InferenceConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachingConfig {
    /// Whether caching is enabled.
    pub enable: bool,
    /// Cache size limit, in megabytes (per the field name).
    pub size_limit_mb: usize,
    /// Expiration period; presumably the lifetime of a cached entry — confirm.
    pub expiration: std::time::Duration,
}
/// Lifecycle rules for stored models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLifecyclePolicy {
    /// Maximum model age.
    pub max_age: std::time::Duration,
    /// Threshold used for retirement decisions.
    ///
    /// NOTE(review): the metric and scale this is compared against are not
    /// defined here — confirm with the consumer.
    pub retirement_threshold: f64,
    /// Backup schedule.
    pub backup_strategy: BackupStrategy,
}
/// Monitoring settings for deployed models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMonitoringConfig {
    /// Whether performance monitoring is enabled.
    pub performance_monitoring: bool,
    /// Whether drift detection is enabled.
    pub drift_detection: bool,
    /// Monitoring interval.
    pub frequency: std::time::Duration,
}
/// Rollout strategy for new model versions.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum DeploymentStrategy {
    /// Blue/green deployment.
    BlueGreen,
    /// Canary deployment.
    Canary,
    /// Rolling deployment.
    Rolling,
    /// Immediate deployment.
    Immediate,
}
/// Backup schedule for stored models.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum BackupStrategy {
    /// Daily backups.
    Daily,
    /// Weekly backups.
    Weekly,
    /// Backups only when explicitly requested.
    OnDemand,
    /// No backups.
    Never,
}
impl Default for ValidationConfig {
    /// Baseline validation setup.
    ///
    /// - cross-validation as the only method, scored by accuracy only;
    /// - statistical testing enabled;
    /// - robustness testing enabled, with distribution-shift and
    ///   noise-sensitivity checks on, adversarial and fairness checks off;
    /// - fairness evaluation disabled.
    fn default() -> Self {
        let robustness = RobustnessTestingConfig {
            enable_testing: true,
            adversarial_testing: false,
            distribution_shift_testing: true,
            noise_sensitivity_testing: true,
            fairness_testing: false,
        };

        let methods = vec![ValidationMethod::CrossValidation];
        let metrics = vec![PerformanceMetric::Accuracy];

        Self {
            validation_methods: methods,
            performance_metrics: metrics,
            statistical_testing: true,
            robustness_testing: robustness,
            fairness_evaluation: false,
        }
    }
}
impl Default for InferenceConfig {
    /// Default inference settings.
    ///
    /// Batches of 32 in `Float32` precision with GPU use disabled and a
    /// 30-second timeout. The cache is enabled, capped at 1024 MB, and has a
    /// one-hour expiration.
    fn default() -> Self {
        use std::time::Duration;

        const ONE_HOUR_SECS: u64 = 60 * 60;

        let caching = CachingConfig {
            enable: true,
            size_limit_mb: 1024,
            expiration: Duration::from_secs(ONE_HOUR_SECS),
        };

        Self {
            batch_size: 32,
            timeout: Duration::from_secs(30),
            use_gpu: false,
            precision: InferencePrecision::Float32,
            caching,
        }
    }
}
impl Default for ModelManagementConfig {
    /// Builds the default model-management settings.
    ///
    /// - Versioning is enabled.
    /// - Models are stored in a `models` directory under the platform
    ///   temporary directory (`std::env::temp_dir()`). That is `/tmp/models`
    ///   on a typical Unix host with `TMPDIR` unset.
    /// - Lifecycle: maximum age of 30 days, retirement threshold `0.8`, daily
    ///   backups.
    /// - Monitoring: performance monitoring and drift detection enabled, with
    ///   a one-hour frequency.
    /// - Deployment: rolling.
    fn default() -> Self {
        use std::time::Duration;

        const SECS_PER_HOUR: u64 = 60 * 60;
        const SECS_PER_DAY: u64 = 24 * SECS_PER_HOUR;

        // Derive the path from the platform temp directory instead of
        // hard-coding the Unix-only `/tmp`. This gives a sensible default on
        // Windows and honours TMPDIR. `storage_path` is a `String`, so a
        // non-UTF-8 temp path falls back to the original literal rather than
        // being lossily mangled.
        // NOTE(review): temp directories may be purged by the OS; a persistent
        // location is probably a better default for stored models.
        let storage_path = std::env::temp_dir()
            .join("models")
            .into_os_string()
            .into_string()
            .unwrap_or_else(|_| String::from("/tmp/models"));

        Self {
            versioning: true,
            storage_path,
            lifecycle_policy: ModelLifecyclePolicy {
                max_age: Duration::from_secs(30 * SECS_PER_DAY),
                retirement_threshold: 0.8,
                backup_strategy: BackupStrategy::Daily,
            },
            monitoring: ModelMonitoringConfig {
                performance_monitoring: true,
                drift_detection: true,
                frequency: Duration::from_secs(SECS_PER_HOUR),
            },
            deployment_strategy: DeploymentStrategy::Rolling,
        }
    }
}