// quantrs2_device/ml_optimization/validation.rs

//! ML Validation Configuration Types
2
use std::time::Duration;

use serde::{Deserialize, Serialize};
4
5/// ML validation configuration
6#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct MLValidationConfig {
8    /// Validation methods
9    pub validation_methods: Vec<ValidationMethod>,
10    /// Performance metrics
11    pub performance_metrics: Vec<PerformanceMetric>,
12    /// Statistical significance testing
13    pub statistical_testing: bool,
14    /// Robustness testing
15    pub robustness_testing: RobustnessTestingConfig,
16    /// Fairness evaluation
17    pub fairness_evaluation: bool,
18}
19
20/// Validation methods
21#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
22pub enum ValidationMethod {
23    CrossValidation,
24    HoldoutValidation,
25    BootstrapValidation,
26    TimeSeriesValidation,
27    WalkForwardValidation,
28}
29
30/// Performance metrics
31#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
32pub enum PerformanceMetric {
33    Accuracy,
34    Precision,
35    Recall,
36    F1Score,
37    AUC,
38    MAE,
39    MSE,
40    RMSE,
41    R2Score,
42    LogLoss,
43}
44
45/// Robustness testing configuration
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct RobustnessTestingConfig {
48    /// Enable robustness testing
49    pub enable_testing: bool,
50    /// Adversarial testing
51    pub adversarial_testing: bool,
52    /// Distribution shift testing
53    pub distribution_shift_testing: bool,
54    /// Noise sensitivity testing
55    pub noise_sensitivity_testing: bool,
56    /// Fairness testing
57    pub fairness_testing: bool,
58}
59
// Additional config types for QEC compatibility
61
62/// Validation configuration (alias for compatibility)
63pub type ValidationConfig = MLValidationConfig;
64
65/// Inference configuration
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct InferenceConfig {
68    /// Inference batch size
69    pub batch_size: usize,
70    /// Inference timeout
71    pub timeout: std::time::Duration,
72    /// Enable GPU acceleration
73    pub use_gpu: bool,
74    /// Inference precision
75    pub precision: InferencePrecision,
76    /// Caching configuration
77    pub caching: CachingConfig,
78}
79
80/// Model management configuration
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct ModelManagementConfig {
83    /// Model versioning
84    pub versioning: bool,
85    /// Model storage path
86    pub storage_path: String,
87    /// Model lifecycle policy
88    pub lifecycle_policy: ModelLifecyclePolicy,
89    /// Model monitoring
90    pub monitoring: ModelMonitoringConfig,
91    /// Model deployment strategy
92    pub deployment_strategy: DeploymentStrategy,
93}
94
95/// Inference precision options
96#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
97pub enum InferencePrecision {
98    Float32,
99    Float64,
100    Mixed,
101}
102
103/// Caching configuration
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct CachingConfig {
106    /// Enable caching
107    pub enable: bool,
108    /// Cache size limit (MB)
109    pub size_limit_mb: usize,
110    /// Cache expiration time
111    pub expiration: std::time::Duration,
112}
113
114/// Model lifecycle policy
115#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct ModelLifecyclePolicy {
117    /// Maximum model age
118    pub max_age: std::time::Duration,
119    /// Auto-retirement threshold
120    pub retirement_threshold: f64,
121    /// Backup strategy
122    pub backup_strategy: BackupStrategy,
123}
124
125/// Model monitoring configuration
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct ModelMonitoringConfig {
128    /// Enable performance monitoring
129    pub performance_monitoring: bool,
130    /// Enable drift detection
131    pub drift_detection: bool,
132    /// Monitoring frequency
133    pub frequency: std::time::Duration,
134}
135
136/// Deployment strategy
137#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
138pub enum DeploymentStrategy {
139    BlueGreen,
140    Canary,
141    Rolling,
142    Immediate,
143}
144
145/// Backup strategy
146#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
147pub enum BackupStrategy {
148    Daily,
149    Weekly,
150    OnDemand,
151    Never,
152}
153
// Default implementations
155
156impl Default for ValidationConfig {
157    fn default() -> Self {
158        Self {
159            validation_methods: vec![ValidationMethod::CrossValidation],
160            performance_metrics: vec![PerformanceMetric::Accuracy],
161            statistical_testing: true,
162            robustness_testing: RobustnessTestingConfig {
163                enable_testing: true,
164                adversarial_testing: false,
165                distribution_shift_testing: true,
166                noise_sensitivity_testing: true,
167                fairness_testing: false,
168            },
169            fairness_evaluation: false,
170        }
171    }
172}
173
174impl Default for InferenceConfig {
175    fn default() -> Self {
176        Self {
177            batch_size: 32,
178            timeout: std::time::Duration::from_secs(30),
179            use_gpu: false,
180            precision: InferencePrecision::Float32,
181            caching: CachingConfig {
182                enable: true,
183                size_limit_mb: 1024,
184                expiration: std::time::Duration::from_secs(3600),
185            },
186        }
187    }
188}
189
190impl Default for ModelManagementConfig {
191    fn default() -> Self {
192        Self {
193            versioning: true,
194            storage_path: "/tmp/models".to_string(),
195            lifecycle_policy: ModelLifecyclePolicy {
196                max_age: std::time::Duration::from_secs(30 * 24 * 3600), // 30 days
197                retirement_threshold: 0.8,
198                backup_strategy: BackupStrategy::Daily,
199            },
200            monitoring: ModelMonitoringConfig {
201                performance_monitoring: true,
202                drift_detection: true,
203                frequency: std::time::Duration::from_secs(3600),
204            },
205            deployment_strategy: DeploymentStrategy::Rolling,
206        }
207    }
208}