//! `quantrs2_device/ml_optimization/validation.rs`
//!
//! Configuration types for ML model validation, inference, and model management.

use serde::{Deserialize, Serialize};

5#[derive(Debug, Clone, Serialize, Deserialize)]
7pub struct MLValidationConfig {
8 pub validation_methods: Vec<ValidationMethod>,
10 pub performance_metrics: Vec<PerformanceMetric>,
12 pub statistical_testing: bool,
14 pub robustness_testing: RobustnessTestingConfig,
16 pub fairness_evaluation: bool,
18}
19
20#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
22pub enum ValidationMethod {
23 CrossValidation,
24 HoldoutValidation,
25 BootstrapValidation,
26 TimeSeriesValidation,
27 WalkForwardValidation,
28}
29
30#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
32pub enum PerformanceMetric {
33 Accuracy,
34 Precision,
35 Recall,
36 F1Score,
37 AUC,
38 MAE,
39 MSE,
40 RMSE,
41 R2Score,
42 LogLoss,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct RobustnessTestingConfig {
48 pub enable_testing: bool,
50 pub adversarial_testing: bool,
52 pub distribution_shift_testing: bool,
54 pub noise_sensitivity_testing: bool,
56 pub fairness_testing: bool,
58}
59
60pub type ValidationConfig = MLValidationConfig;
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct InferenceConfig {
68 pub batch_size: usize,
70 pub timeout: std::time::Duration,
72 pub use_gpu: bool,
74 pub precision: InferencePrecision,
76 pub caching: CachingConfig,
78}
79
80#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct ModelManagementConfig {
83 pub versioning: bool,
85 pub storage_path: String,
87 pub lifecycle_policy: ModelLifecyclePolicy,
89 pub monitoring: ModelMonitoringConfig,
91 pub deployment_strategy: DeploymentStrategy,
93}
94
95#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
97pub enum InferencePrecision {
98 Float32,
99 Float64,
100 Mixed,
101}
102
103#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct CachingConfig {
106 pub enable: bool,
108 pub size_limit_mb: usize,
110 pub expiration: std::time::Duration,
112}
113
114#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct ModelLifecyclePolicy {
117 pub max_age: std::time::Duration,
119 pub retirement_threshold: f64,
121 pub backup_strategy: BackupStrategy,
123}
124
125#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct ModelMonitoringConfig {
128 pub performance_monitoring: bool,
130 pub drift_detection: bool,
132 pub frequency: std::time::Duration,
134}
135
136#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
138pub enum DeploymentStrategy {
139 BlueGreen,
140 Canary,
141 Rolling,
142 Immediate,
143}
144
145#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
147pub enum BackupStrategy {
148 Daily,
149 Weekly,
150 OnDemand,
151 Never,
152}
153
154impl Default for ValidationConfig {
157 fn default() -> Self {
158 Self {
159 validation_methods: vec![ValidationMethod::CrossValidation],
160 performance_metrics: vec![PerformanceMetric::Accuracy],
161 statistical_testing: true,
162 robustness_testing: RobustnessTestingConfig {
163 enable_testing: true,
164 adversarial_testing: false,
165 distribution_shift_testing: true,
166 noise_sensitivity_testing: true,
167 fairness_testing: false,
168 },
169 fairness_evaluation: false,
170 }
171 }
172}
173
174impl Default for InferenceConfig {
175 fn default() -> Self {
176 Self {
177 batch_size: 32,
178 timeout: std::time::Duration::from_secs(30),
179 use_gpu: false,
180 precision: InferencePrecision::Float32,
181 caching: CachingConfig {
182 enable: true,
183 size_limit_mb: 1024,
184 expiration: std::time::Duration::from_secs(3600),
185 },
186 }
187 }
188}
189
190impl Default for ModelManagementConfig {
191 fn default() -> Self {
192 Self {
193 versioning: true,
194 storage_path: "/tmp/models".to_string(),
195 lifecycle_policy: ModelLifecyclePolicy {
196 max_age: std::time::Duration::from_secs(30 * 24 * 3600), retirement_threshold: 0.8,
198 backup_strategy: BackupStrategy::Daily,
199 },
200 monitoring: ModelMonitoringConfig {
201 performance_monitoring: true,
202 drift_detection: true,
203 frequency: std::time::Duration::from_secs(3600),
204 },
205 deployment_strategy: DeploymentStrategy::Rolling,
206 }
207 }
208}