quantrs2_device/mid_circuit_measurements/ml/
mod.rs

1//! Machine learning optimization components
2
3pub mod adaptive;
4pub mod predictor;
5
6use super::config::MLOptimizationConfig;
7use super::results::*;
8use crate::DeviceResult;
9
10pub use adaptive::AdaptiveMeasurementManager;
11pub use predictor::MeasurementPredictor;
12
13/// ML-powered optimizer for mid-circuit measurements
14pub struct MLOptimizer {
15    config: MLOptimizationConfig,
16    model_cache: Option<OptimizationModel>,
17    training_history: Vec<TrainingEpoch>,
18}
19
20impl MLOptimizer {
21    /// Create new ML optimizer
22    pub fn new(config: &MLOptimizationConfig) -> Self {
23        Self {
24            config: config.clone(),
25            model_cache: None,
26            training_history: Vec::new(),
27        }
28    }
29
30    /// Optimize measurement parameters using ML
31    pub async fn optimize_parameters(
32        &mut self,
33        measurement_history: &[MeasurementEvent],
34        current_performance: &PerformanceMetrics,
35    ) -> DeviceResult<OptimizationResult> {
36        if !self.config.enable_ml_optimization {
37            return Ok(OptimizationResult::default());
38        }
39
40        // Extract features from measurement history
41        let features = self.extract_features(measurement_history)?;
42
43        // Train or update model if needed
44        if self.should_retrain(&features, current_performance)? {
45            self.train_model(&features, current_performance).await?;
46        }
47
48        // Generate optimization recommendations
49        let recommendations = self.generate_recommendations(&features)?;
50
51        // Predict performance improvements
52        let predicted_improvements = self.predict_improvements(&recommendations)?;
53
54        Ok(OptimizationResult {
55            recommendations,
56            predicted_improvements,
57            confidence: self.calculate_optimization_confidence()?,
58            model_version: self.get_model_version(),
59        })
60    }
61
62    /// Extract ML features from measurement data
63    fn extract_features(
64        &self,
65        measurement_history: &[MeasurementEvent],
66    ) -> DeviceResult<MLFeatures> {
67        if measurement_history.is_empty() {
68            return Ok(MLFeatures::default());
69        }
70
71        // Statistical features
72        let latencies: Vec<f64> = measurement_history.iter().map(|e| e.latency).collect();
73        let confidences: Vec<f64> = measurement_history.iter().map(|e| e.confidence).collect();
74        let timestamps: Vec<f64> = measurement_history.iter().map(|e| e.timestamp).collect();
75
76        let statistical_features = StatisticalFeatures {
77            mean_latency: latencies.iter().sum::<f64>() / latencies.len() as f64,
78            std_latency: {
79                let mean = latencies.iter().sum::<f64>() / latencies.len() as f64;
80                let variance = latencies.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
81                    / latencies.len() as f64;
82                variance.sqrt()
83            },
84            mean_confidence: confidences.iter().sum::<f64>() / confidences.len() as f64,
85            std_confidence: {
86                let mean = confidences.iter().sum::<f64>() / confidences.len() as f64;
87                let variance = confidences.iter().map(|&x| (x - mean).powi(2)).sum::<f64>()
88                    / confidences.len() as f64;
89                variance.sqrt()
90            },
91            skewness_latency: self.calculate_skewness(&latencies),
92            kurtosis_latency: self.calculate_kurtosis(&latencies),
93        };
94
95        // Temporal features
96        let temporal_features = TemporalFeatures {
97            measurement_rate: measurement_history.len() as f64
98                / (timestamps.last().unwrap_or(&1.0) - timestamps.first().unwrap_or(&0.0)),
99            temporal_autocorrelation: self.calculate_autocorrelation(&latencies, 1),
100            trend_slope: self.calculate_trend_slope(&timestamps, &latencies),
101            periodicity_strength: self.detect_periodicity(&latencies),
102        };
103
104        // Pattern features
105        let pattern_features = PatternFeatures {
106            latency_confidence_correlation: self.calculate_correlation(&latencies, &confidences),
107            measurement_consistency: self.calculate_consistency(&confidences),
108            outlier_ratio: self.calculate_outlier_ratio(&latencies),
109            pattern_complexity: self.calculate_pattern_complexity(&latencies),
110        };
111
112        Ok(MLFeatures {
113            statistical_features,
114            temporal_features,
115            pattern_features,
116            feature_importance: self.calculate_feature_importance()?,
117        })
118    }
119
120    /// Determine if model should be retrained
121    fn should_retrain(
122        &self,
123        features: &MLFeatures,
124        current_performance: &PerformanceMetrics,
125    ) -> DeviceResult<bool> {
126        // Retrain if no model exists
127        if self.model_cache.is_none() {
128            return Ok(true);
129        }
130
131        // Retrain if performance has degraded significantly
132        let performance_threshold = 0.05; // 5% degradation
133        if current_performance.measurement_success_rate < (1.0 - performance_threshold) {
134            return Ok(true);
135        }
136
137        // Retrain periodically based on data volume
138        let training_interval = 100; // Every 100 training epochs
139        if self.training_history.len() % training_interval == 0 {
140            return Ok(true);
141        }
142
143        // Retrain if feature distribution has shifted significantly
144        if let Some(ref model) = self.model_cache {
145            let feature_drift = self.detect_feature_drift(features, &model.training_features)?;
146            if feature_drift > 0.2 {
147                // 20% drift threshold
148                return Ok(true);
149            }
150        }
151
152        Ok(false)
153    }
154
155    /// Train or update the ML model
156    async fn train_model(
157        &mut self,
158        features: &MLFeatures,
159        target_performance: &PerformanceMetrics,
160    ) -> DeviceResult<()> {
161        let training_epoch = TrainingEpoch {
162            epoch_number: self.training_history.len() + 1,
163            features: features.clone(),
164            target_metrics: target_performance.clone(),
165            training_loss: 0.0, // Will be updated during training
166            validation_loss: 0.0,
167            learning_rate: match &self.config.training_config.learning_rate_schedule {
168                crate::mid_circuit_measurements::config::LearningRateSchedule::Constant {
169                    rate,
170                } => *rate,
171                _ => 0.001, // Default fallback
172            },
173        };
174
175        // Simplified training process
176        let model = OptimizationModel {
177            model_type: self
178                .config
179                .model_types
180                .first()
181                .map(|t| format!("{t:?}"))
182                .unwrap_or_else(|| "LinearRegression".to_string()),
183            parameters: self.initialize_model_parameters()?,
184            training_features: features.clone(),
185            model_performance: ModelPerformance {
186                training_accuracy: 0.95,
187                validation_accuracy: 0.92,
188                cross_validation_score: 0.93,
189                overfitting_score: 0.03,
190            },
191            last_updated: std::time::SystemTime::now(),
192        };
193
194        self.model_cache = Some(model);
195        self.training_history.push(training_epoch);
196
197        Ok(())
198    }
199
200    /// Generate optimization recommendations
201    fn generate_recommendations(
202        &self,
203        features: &MLFeatures,
204    ) -> DeviceResult<Vec<OptimizationRecommendation>> {
205        let mut recommendations = Vec::new();
206
207        // Latency optimization
208        if features.statistical_features.mean_latency > 1000.0 {
209            // > 1ms
210            recommendations.push(OptimizationRecommendation {
211                parameter: "measurement_timeout".to_string(),
212                current_value: features.statistical_features.mean_latency,
213                recommended_value: features.statistical_features.mean_latency * 0.8,
214                expected_improvement: 0.15,
215                confidence: 0.85,
216                rationale:
217                    "High average latency detected, reducing timeout may improve performance"
218                        .to_string(),
219            });
220        }
221
222        // Confidence optimization
223        if features.statistical_features.mean_confidence < 0.95 {
224            recommendations.push(OptimizationRecommendation {
225                parameter: "measurement_repetitions".to_string(),
226                current_value: 1.0,
227                recommended_value: 2.0,
228                expected_improvement: 0.1,
229                confidence: 0.75,
230                rationale:
231                    "Low confidence detected, increasing repetitions may improve reliability"
232                        .to_string(),
233            });
234        }
235
236        // Timing optimization
237        if features.temporal_features.measurement_rate < 100.0 {
238            // < 100 Hz
239            recommendations.push(OptimizationRecommendation {
240                parameter: "measurement_frequency".to_string(),
241                current_value: features.temporal_features.measurement_rate,
242                recommended_value: features.temporal_features.measurement_rate * 1.2,
243                expected_improvement: 0.05,
244                confidence: 0.65,
245                rationale: "Low measurement rate, increasing frequency may improve throughput"
246                    .to_string(),
247            });
248        }
249
250        Ok(recommendations)
251    }
252
253    /// Predict performance improvements
254    fn predict_improvements(
255        &self,
256        recommendations: &[OptimizationRecommendation],
257    ) -> DeviceResult<PerformanceImprovements> {
258        let overall_improvement = recommendations
259            .iter()
260            .map(|r| r.expected_improvement * r.confidence)
261            .sum::<f64>()
262            / recommendations.len() as f64;
263
264        Ok(PerformanceImprovements {
265            latency_reduction: overall_improvement * 0.4,
266            confidence_increase: overall_improvement * 0.3,
267            throughput_increase: overall_improvement * 0.2,
268            error_rate_reduction: overall_improvement * 0.1,
269            overall_score_improvement: overall_improvement,
270        })
271    }
272
273    /// Calculate optimization confidence
274    const fn calculate_optimization_confidence(&self) -> DeviceResult<f64> {
275        if let Some(ref model) = self.model_cache {
276            Ok(model.model_performance.validation_accuracy)
277        } else {
278            Ok(0.5) // No model, low confidence
279        }
280    }
281
282    /// Get current model version
283    fn get_model_version(&self) -> String {
284        format!(
285            "v{}.{}",
286            self.training_history.len() / 10,
287            self.training_history.len() % 10
288        )
289    }
290
291    // Helper methods for feature extraction
292    fn calculate_skewness(&self, values: &[f64]) -> f64 {
293        if values.len() < 3 {
294            return 0.0;
295        }
296
297        let mean = values.iter().sum::<f64>() / values.len() as f64;
298        let std_dev = {
299            let variance =
300                values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
301            variance.sqrt()
302        };
303
304        if std_dev > 1e-10 {
305            let skewness = values
306                .iter()
307                .map(|&x| ((x - mean) / std_dev).powi(3))
308                .sum::<f64>()
309                / values.len() as f64;
310            skewness
311        } else {
312            0.0
313        }
314    }
315
316    fn calculate_kurtosis(&self, values: &[f64]) -> f64 {
317        if values.len() < 4 {
318            return 0.0;
319        }
320
321        let mean = values.iter().sum::<f64>() / values.len() as f64;
322        let std_dev = {
323            let variance =
324                values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
325            variance.sqrt()
326        };
327
328        if std_dev > 1e-10 {
329            let kurtosis = values
330                .iter()
331                .map(|&x| ((x - mean) / std_dev).powi(4))
332                .sum::<f64>()
333                / values.len() as f64
334                - 3.0; // Excess kurtosis
335            kurtosis
336        } else {
337            0.0
338        }
339    }
340
341    fn calculate_autocorrelation(&self, values: &[f64], lag: usize) -> f64 {
342        if values.len() <= lag {
343            return 0.0;
344        }
345
346        let n = values.len() - lag;
347        let mean = values.iter().sum::<f64>() / values.len() as f64;
348
349        let numerator: f64 = (0..n)
350            .map(|i| (values[i] - mean) * (values[i + lag] - mean))
351            .sum();
352
353        let denominator: f64 = values.iter().map(|&x| (x - mean).powi(2)).sum();
354
355        if denominator > 1e-10 {
356            numerator / denominator
357        } else {
358            0.0
359        }
360    }
361
362    fn calculate_trend_slope(&self, x: &[f64], y: &[f64]) -> f64 {
363        if x.len() != y.len() || x.len() < 2 {
364            return 0.0;
365        }
366
367        let n = x.len() as f64;
368        let sum_x = x.iter().sum::<f64>();
369        let sum_y = y.iter().sum::<f64>();
370        let sum_xy = x
371            .iter()
372            .zip(y.iter())
373            .map(|(&xi, &yi)| xi * yi)
374            .sum::<f64>();
375        let sum_x2 = x.iter().map(|&xi| xi * xi).sum::<f64>();
376
377        let denominator = n.mul_add(sum_x2, -(sum_x * sum_x));
378        if denominator > 1e-10 {
379            n.mul_add(sum_xy, -(sum_x * sum_y)) / denominator
380        } else {
381            0.0
382        }
383    }
384
385    fn detect_periodicity(&self, values: &[f64]) -> f64 {
386        // Simplified periodicity detection using autocorrelation
387        let max_lag = values.len() / 4;
388        let mut max_autocorr = 0.0;
389
390        for lag in 1..max_lag {
391            let autocorr = self.calculate_autocorrelation(values, lag).abs();
392            max_autocorr = f64::max(max_autocorr, autocorr);
393        }
394
395        max_autocorr
396    }
397
398    fn calculate_correlation(&self, x: &[f64], y: &[f64]) -> f64 {
399        if x.len() != y.len() || x.len() < 2 {
400            return 0.0;
401        }
402
403        let n = x.len() as f64;
404        let mean_x = x.iter().sum::<f64>() / n;
405        let mean_y = y.iter().sum::<f64>() / n;
406
407        let numerator: f64 = x
408            .iter()
409            .zip(y.iter())
410            .map(|(&xi, &yi)| (xi - mean_x) * (yi - mean_y))
411            .sum();
412
413        let sum_sq_x: f64 = x.iter().map(|&xi| (xi - mean_x).powi(2)).sum();
414        let sum_sq_y: f64 = y.iter().map(|&yi| (yi - mean_y).powi(2)).sum();
415
416        let denominator = (sum_sq_x * sum_sq_y).sqrt();
417
418        if denominator > 1e-10 {
419            numerator / denominator
420        } else {
421            0.0
422        }
423    }
424
425    fn calculate_consistency(&self, values: &[f64]) -> f64 {
426        if values.is_empty() {
427            return 1.0;
428        }
429
430        let mean = values.iter().sum::<f64>() / values.len() as f64;
431        let variance =
432            values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
433
434        // Consistency is inverse of coefficient of variation
435        if mean > 1e-10 {
436            1.0 / (variance.sqrt() / mean + 1.0)
437        } else {
438            1.0
439        }
440    }
441
442    fn calculate_outlier_ratio(&self, values: &[f64]) -> f64 {
443        if values.len() < 4 {
444            return 0.0;
445        }
446
447        let mean = values.iter().sum::<f64>() / values.len() as f64;
448        let std_dev = {
449            let variance =
450                values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
451            variance.sqrt()
452        };
453
454        let outlier_count = values
455            .iter()
456            .filter(|&&x| (x - mean).abs() > 2.0 * std_dev)
457            .count();
458
459        outlier_count as f64 / values.len() as f64
460    }
461
462    fn calculate_pattern_complexity(&self, values: &[f64]) -> f64 {
463        // Simple complexity measure based on number of direction changes
464        if values.len() < 3 {
465            return 0.0;
466        }
467
468        let mut changes = 0;
469        for i in 1..(values.len() - 1) {
470            let prev_diff = values[i] - values[i - 1];
471            let curr_diff = values[i + 1] - values[i];
472            if prev_diff * curr_diff < 0.0 {
473                // Sign change
474                changes += 1;
475            }
476        }
477
478        changes as f64 / (values.len() - 2) as f64
479    }
480
481    fn calculate_feature_importance(&self) -> DeviceResult<Vec<FeatureImportance>> {
482        // Simplified feature importance calculation
483        Ok(vec![
484            FeatureImportance {
485                feature_name: "mean_latency".to_string(),
486                importance: 0.3,
487            },
488            FeatureImportance {
489                feature_name: "mean_confidence".to_string(),
490                importance: 0.25,
491            },
492            FeatureImportance {
493                feature_name: "temporal_autocorrelation".to_string(),
494                importance: 0.2,
495            },
496            FeatureImportance {
497                feature_name: "latency_confidence_correlation".to_string(),
498                importance: 0.15,
499            },
500            FeatureImportance {
501                feature_name: "measurement_rate".to_string(),
502                importance: 0.1,
503            },
504        ])
505    }
506
507    fn detect_feature_drift(
508        &self,
509        current_features: &MLFeatures,
510        training_features: &MLFeatures,
511    ) -> DeviceResult<f64> {
512        // Simple drift detection using statistical distance
513        let current_mean_latency = current_features.statistical_features.mean_latency;
514        let training_mean_latency = training_features.statistical_features.mean_latency;
515
516        let latency_drift = if training_mean_latency > 1e-10 {
517            (current_mean_latency - training_mean_latency).abs() / training_mean_latency
518        } else {
519            0.0
520        };
521
522        let current_mean_confidence = current_features.statistical_features.mean_confidence;
523        let training_mean_confidence = training_features.statistical_features.mean_confidence;
524
525        let confidence_drift = if training_mean_confidence > 1e-10 {
526            (current_mean_confidence - training_mean_confidence).abs() / training_mean_confidence
527        } else {
528            0.0
529        };
530
531        Ok(f64::midpoint(latency_drift, confidence_drift))
532    }
533
534    fn initialize_model_parameters(&self) -> DeviceResult<Vec<f64>> {
535        // Initialize simple linear model parameters
536        Ok(vec![1.0, 0.0, 0.5, 0.3, 0.2]) // weights for different features
537    }
538}