use super::{
    config::*,
    decomposition::QuantumSeasonalDecomposer,
    ensemble::QuantumEnsembleManager,
    features::QuantumFeatureExtractor,
    metrics::{ForecastMetrics, ForecastResult, TrainingHistory},
    models::{TimeSeriesModelFactory, TimeSeriesModelTrait},
};
use crate::error::{MLError, Result};
use crate::optimization::OptimizationMethod;
use ndarray::{s, Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::time::Instant;

/// Main quantum time series forecaster.
#[derive(Debug, Clone)]
pub struct QuantumTimeSeriesForecaster {
    /// Forecaster configuration
    config: QuantumTimeSeriesConfig,

    /// Underlying quantum time series model
    model: Box<dyn TimeSeriesModelTrait>,

    /// Quantum feature extractor
    feature_extractor: QuantumFeatureExtractor,

    /// Optional seasonal decomposer
    seasonal_decomposer: Option<QuantumSeasonalDecomposer>,

    /// Optional ensemble manager
    ensemble_manager: Option<QuantumEnsembleManager>,

    /// Training history
    training_history: TrainingHistory,

    /// Forecast quality metrics
    metrics: ForecastMetrics,

    /// Cache for quantum states (e.g., trend and seasonal components)
    quantum_state_cache: QuantumStateCache,

    /// Cache for recent predictions
    prediction_cache: PredictionCache,
}

/// LRU cache for quantum states.
#[derive(Debug, Clone)]
pub struct QuantumStateCache {
    /// Cached states keyed by name
    states: HashMap<String, Array1<f64>>,

    /// Maximum number of cached states
    max_size: usize,

    /// Access order, least recently used first
    access_history: VecDeque<String>,

    /// Hit/miss statistics
    stats: CacheStatistics,
}

/// TTL-bounded cache for forecast results.
#[derive(Debug, Clone)]
pub struct PredictionCache {
    /// Cached predictions keyed by a context fingerprint
    predictions: HashMap<String, CachedPrediction>,

    /// Time-to-live for cached entries, in seconds
    ttl_seconds: u64,

    /// Maximum number of cached predictions
    max_size: usize,

    /// Hit/miss statistics
    stats: CacheStatistics,
}

/// A cached forecast result with its provenance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedPrediction {
    /// The cached forecast result
    pub result: ForecastResult,

    /// When the entry was created
    pub timestamp: std::time::SystemTime,

    /// Hash of the input that produced the result
    pub input_hash: u64,

    /// Version of the model that produced the result
    pub model_version: String,
}

/// Cache access statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStatistics {
    /// Number of cache hits
    pub hits: usize,

    /// Number of cache misses
    pub misses: usize,

    /// Total number of accesses
    pub total_accesses: usize,

    /// Hit rate as a percentage of total accesses
    pub hit_rate: f64,
}

/// Runtime context for forecasting workloads.
#[derive(Debug, Clone)]
pub struct ForecastingContext {
    /// Execution mode
    pub mode: ExecutionMode,

    /// Parallel execution configuration
    pub parallel_config: ParallelConfig,

    /// Memory management configuration
    pub memory_config: MemoryConfig,

    /// Monitoring configuration
    pub monitoring: MonitoringConfig,
}

/// How forecasting work is executed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExecutionMode {
    /// Single-threaded execution
    Sequential,

    /// Multi-threaded execution
    Parallel,

    /// Execution distributed across nodes
    Distributed,

    /// Execution on quantum hardware or accelerators
    QuantumAccelerated,
}

/// Configuration for parallel execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelConfig {
    /// Number of worker threads
    pub num_threads: usize,

    /// Batch size per worker
    pub batch_size: usize,

    /// Whether to use GPU acceleration
    pub use_gpu: bool,

    /// Load balancing strategy
    pub load_balancing: LoadBalancingStrategy,
}

/// Strategy for distributing work across workers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LoadBalancingStrategy {
    RoundRobin,
    WorkStealing,
    DynamicPartitioning,
    QuantumOptimal,
}

/// Configuration for memory management.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
    /// Whether to allocate from a memory pool
    pub use_memory_pool: bool,

    /// Maximum memory budget in megabytes
    pub max_memory_mb: usize,

    /// Whether to compress cached data
    pub use_compression: bool,

    /// Garbage collection strategy
    pub gc_strategy: GCStrategy,
}

/// Garbage collection strategy.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GCStrategy {
    Aggressive,
    Conservative,
    Adaptive,
    QuantumOptimized,
}

/// Configuration for monitoring and telemetry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
    /// Whether monitoring is enabled
    pub enable_monitoring: bool,

    /// Logging verbosity
    pub log_level: LogLevel,

    /// Whether telemetry export is enabled
    pub enable_telemetry: bool,

    /// Interval between metric samples, in milliseconds
    pub metrics_interval_ms: u64,
}

/// Logging verbosity levels.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LogLevel {
    Error,
    Warn,
    Info,
    Debug,
    Trace,
}

impl QuantumTimeSeriesForecaster {
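    /// Creates a forecaster from a configuration, wiring up the model,
    /// feature extractor, optional seasonal decomposer, optional ensemble,
    /// and the internal caches.
    ///
    /// A minimal usage sketch (assumes `QuantumTimeSeriesConfig::default()`
    /// yields a complete, valid configuration):
    ///
    /// ```ignore
    /// let config = QuantumTimeSeriesConfig::default();
    /// let forecaster = QuantumTimeSeriesForecaster::new(config)?;
    /// ```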
    pub fn new(config: QuantumTimeSeriesConfig) -> Result<Self> {
        // Create the primary model
        let model = TimeSeriesModelFactory::create_model(&config.model_type, config.num_qubits)?;

        // Create the quantum feature extractor
        let feature_extractor =
            QuantumFeatureExtractor::new(config.feature_config.clone(), config.num_qubits)?;

        // Create a seasonal decomposer only if seasonality is configured
        let seasonal_decomposer = if config.seasonality_config.has_seasonality() {
            Some(QuantumSeasonalDecomposer::new(
                config.seasonality_config.clone(),
                config.num_qubits,
            )?)
        } else {
            None
        };

        // Create an ensemble manager if ensembling is configured
        let ensemble_manager = if let Some(ref ensemble_config) = config.ensemble_config {
            let mut manager = QuantumEnsembleManager::new(ensemble_config.clone());

            // Create the member models of the ensemble
            let mut ensemble_models = Vec::new();
            for _ in 0..ensemble_config.num_models {
                let ensemble_model =
                    TimeSeriesModelFactory::create_model(&config.model_type, config.num_qubits)?;
                ensemble_models.push(ensemble_model);
            }

            manager.set_models(ensemble_models);
            Some(manager)
        } else {
            None
        };

        // Initialize the caches
        let quantum_state_cache = QuantumStateCache::new(1000);
        let prediction_cache = PredictionCache::new(100, 3600);

        Ok(Self {
            config,
            model,
            feature_extractor,
            seasonal_decomposer,
            ensemble_manager,
            training_history: TrainingHistory::new(),
            metrics: ForecastMetrics::new(),
            quantum_state_cache,
            prediction_cache,
        })
    }

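    /// Fits the forecaster on historical data.
    ///
    /// `data` must contain at least `window_size + forecast_horizon` rows.
    /// A minimal usage sketch; the data shape and epoch count below are
    /// illustrative assumptions, and the `OptimizationMethod` variant must
    /// be one supported by the crate:
    ///
    /// ```ignore
    /// use ndarray::Array2;
    /// let data = Array2::<f64>::zeros((200, 3)); // (num_samples, num_features)
    /// let optimizer = /* any OptimizationMethod variant */;
    /// forecaster.fit(&data, 100, optimizer)?;
    /// ```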
    pub fn fit(
        &mut self,
        data: &Array2<f64>,
        epochs: usize,
        optimizer: OptimizationMethod,
    ) -> Result<()> {
        let start_time = Instant::now();
        println!("Training quantum time series model...");

        // Validate the input data
        self.validate_training_data(data)?;

        // Build sliding-window features and targets
        let (features, targets) = self.prepare_training_data(data)?;

        // Optionally decompose into trend and seasonal components
        let (detrended_features, trend, seasonal) =
            if let Some(ref mut decomposer) = self.seasonal_decomposer {
                decomposer.decompose(&features)?
            } else {
                (features.clone(), None, None)
            };

        // Fit feature statistics on the (de-trended) training features
        let mut feature_extractor = self.feature_extractor.clone();
        feature_extractor.fit_statistics(&detrended_features)?;
        self.feature_extractor = feature_extractor;

        // Extract quantum features
        let quantum_features = self
            .feature_extractor
            .extract_features(&detrended_features)?;

        // Train the primary model
        // NOTE: `epochs` and `optimizer` are part of the public signature but
        // are not forwarded by this simplified training loop.
        self.model.fit(&quantum_features, &targets)?;

        // Train the ensemble, if configured
        if let Some(ref mut ensemble_manager) = self.ensemble_manager {
            ensemble_manager.fit_ensemble(&quantum_features, &targets)?;
        }

        // Cache the decomposition components for prediction-time reconstruction
        if let Some(trend) = trend {
            self.quantum_state_cache.store("trend".to_string(), trend);
        }
        if let Some(seasonal) = seasonal {
            self.quantum_state_cache
                .store("seasonal".to_string(), seasonal);
        }

        // Record training time
        let training_time = start_time.elapsed();
        self.training_history.training_time = training_time.as_secs_f64();
        self.training_history
            .add_epoch_metrics(HashMap::new(), 0.0, 0.0);

        println!(
            "Training completed in {:.2} seconds",
            training_time.as_secs_f64()
        );
        Ok(())
    }

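    /// Generates forecasts for the given context window.
    ///
    /// Results are served from the prediction cache when an identical
    /// context/horizon pair was recently predicted. A minimal usage sketch
    /// (the context shape is an illustrative assumption; it needs at least
    /// `window_size` rows):
    ///
    /// ```ignore
    /// let context = Array2::<f64>::zeros((64, 3));
    /// let result = forecaster.predict(&context, Some(10))?;
    /// println!("{:?}", result.predictions);
    /// ```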
    pub fn predict(
        &mut self,
        context: &Array2<f64>,
        horizon: Option<usize>,
    ) -> Result<ForecastResult> {
        let forecast_horizon = horizon.unwrap_or(self.config.forecast_horizon);

        // Check the prediction cache first
        let cache_key = self.generate_prediction_cache_key(context, forecast_horizon);
        if let Some(cached_result) = self.prediction_cache.get(&cache_key) {
            return Ok(cached_result.result.clone());
        }

        // Validate the context window
        self.validate_prediction_context(context)?;

        // Extract quantum features from the context
        let features = self.feature_extractor.extract_features(context)?;

        // Predict with the ensemble if configured, otherwise the primary model
        let mut predictions = if let Some(ref ensemble_manager) = self.ensemble_manager {
            ensemble_manager.predict_ensemble(&features, forecast_horizon)?
        } else {
            self.model.predict(&features, forecast_horizon)?
        };

        // Add back cached trend and seasonal components
        predictions = self.reconstruct_predictions(predictions, forecast_horizon)?;

        // Derive interval bounds, anomalies, and confidence estimates
        let (lower_bound, upper_bound) = self.calculate_prediction_intervals(&predictions)?;

        let anomalies = self.detect_prediction_anomalies(&predictions)?;

        let confidence_scores = self.calculate_confidence_scores(&predictions)?;

        let quantum_uncertainty = self.calculate_quantum_uncertainty(&predictions)?;

        let result = ForecastResult {
            predictions,
            lower_bound,
            upper_bound,
            anomalies,
            confidence_scores,
            quantum_uncertainty,
        };

        // Cache the result for subsequent identical queries
        self.prediction_cache.insert(cache_key, &result)?;

        Ok(result)
    }

    fn validate_training_data(&self, data: &Array2<f64>) -> Result<()> {
        let (n_samples, n_features) = data.dim();

        if n_samples < self.config.window_size + self.config.forecast_horizon {
            return Err(MLError::DataError(format!(
                "Insufficient data: need at least {} samples, got {}",
                self.config.window_size + self.config.forecast_horizon,
                n_samples
            )));
        }

        if n_features == 0 {
            return Err(MLError::DataError(
                "No features in training data".to_string(),
            ));
        }

        // Reject NaN and infinite values
        for value in data.iter() {
            if !value.is_finite() {
                return Err(MLError::DataError(
                    "Training data contains NaN or infinite values".to_string(),
                ));
            }
        }

        Ok(())
    }

    fn validate_prediction_context(&self, context: &Array2<f64>) -> Result<()> {
        let (n_samples, _) = context.dim();

        if n_samples < self.config.window_size {
            return Err(MLError::DataError(format!(
                "Insufficient context: need at least {} samples, got {}",
                self.config.window_size, n_samples
            )));
        }

        // Reject NaN and infinite values
        for value in context.iter() {
            if !value.is_finite() {
                return Err(MLError::DataError(
                    "Context data contains NaN or infinite values".to_string(),
                ));
            }
        }

        Ok(())
    }

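    /// Builds sliding-window training pairs from raw data.
    ///
    /// Each row of `features` is a flattened window of `window_size`
    /// observations plus any configured lagged rows; each row of `targets`
    /// is the flattened next `forecast_horizon` observations. A small worked
    /// shape example (illustrative numbers, not defaults):
    ///
    /// ```ignore
    /// // data: 10 rows x 2 cols, window_size = 4, forecast_horizon = 2, no lags
    /// // num_samples    = 10 - (4 + 2 - 1) = 5
    /// // features shape = (5, 2 * 4) = (5, 8)
    /// // targets shape  = (5, 2 * 2) = (5, 4)
    /// ```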
    fn prepare_training_data(&self, data: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
        // Number of complete window/horizon pairs available
        let num_samples = data
            .nrows()
            .saturating_sub(self.config.window_size + self.config.forecast_horizon - 1);

        if num_samples == 0 {
            return Err(MLError::DataError(
                "Insufficient data for the specified window size and forecast horizon".to_string(),
            ));
        }

        let num_features = data.ncols();
        let total_features = num_features
            * (self.config.window_size + self.config.feature_config.lag_features.len());

        let mut features = Array2::zeros((num_samples, total_features));
        let mut targets = Array2::zeros((num_samples, self.config.forecast_horizon * num_features));

        for i in 0..num_samples {
            // Flatten the observation window into the feature row
            let window_start = i;
            let window_end = i + self.config.window_size;
            let window_data = data.slice(s![window_start..window_end, ..]);

            let flat_window: Vec<f64> = window_data.iter().cloned().collect();
            let flat_window_len = flat_window.len();
            features
                .slice_mut(s![i, 0..flat_window_len])
                .assign(&Array1::from_vec(flat_window));

            // Append lagged rows after the window (left zero-filled when the
            // lag would reach before the start of the data)
            let mut feature_offset = flat_window_len;
            for &lag in &self.config.feature_config.lag_features {
                if i >= lag {
                    let lag_data = data.row(i + self.config.window_size - lag);
                    features
                        .slice_mut(s![i, feature_offset..feature_offset + num_features])
                        .assign(&lag_data);
                }
                feature_offset += num_features;
            }

            // Flatten the forecast horizon into the target row
            let target_start = i + self.config.window_size;
            let target_end = target_start + self.config.forecast_horizon;
            let target_data = data.slice(s![target_start..target_end, ..]);
            let flat_target: Vec<f64> = target_data.iter().cloned().collect();
            targets.row_mut(i).assign(&Array1::from_vec(flat_target));
        }

        Ok((features, targets))
    }

    fn reconstruct_predictions(
        &mut self,
        mut predictions: Array2<f64>,
        horizon: usize,
    ) -> Result<Array2<f64>> {
        // Add back the trend component cached during training
        if let Some(trend) = self.quantum_state_cache.get("trend") {
            let trend = trend.clone();
            predictions = self.add_trend_component(predictions, &trend, horizon)?;
        }

        // Add back the seasonal component cached during training
        if let Some(seasonal) = self.quantum_state_cache.get("seasonal") {
            let seasonal = seasonal.clone();
            predictions = self.add_seasonal_component(predictions, &seasonal, horizon)?;
        }

        Ok(predictions)
    }

    fn add_trend_component(
        &self,
        mut predictions: Array2<f64>,
        trend: &Array1<f64>,
        horizon: usize,
    ) -> Result<Array2<f64>> {
        let trend_len = trend.len();
        if trend_len == 0 {
            // Nothing to add; also guards against modulo-by-zero below
            return Ok(predictions);
        }

        // Extrapolate by cycling forward from the last observed trend value
        for i in 0..predictions.nrows() {
            for h in 0..horizon.min(predictions.ncols()) {
                let trend_idx = (trend_len.saturating_sub(1) + h) % trend_len;
                predictions[[i, h]] += trend[trend_idx];
            }
        }

        Ok(predictions)
    }

    fn add_seasonal_component(
        &self,
        mut predictions: Array2<f64>,
        seasonal: &Array1<f64>,
        horizon: usize,
    ) -> Result<Array2<f64>> {
        let seasonal_len = seasonal.len();
        if seasonal_len == 0 {
            // Nothing to add; also guards against modulo-by-zero below
            return Ok(predictions);
        }

        // Extrapolate by cycling forward through the seasonal pattern
        for i in 0..predictions.nrows() {
            for h in 0..horizon.min(predictions.ncols()) {
                let seasonal_idx = (seasonal_len.saturating_sub(1) + h) % seasonal_len;
                predictions[[i, h]] += seasonal[seasonal_idx];
            }
        }

        Ok(predictions)
    }

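    /// Computes symmetric prediction intervals around the point forecasts
    /// as `pred ± z * sigma`, currently with fixed z = 1.96 (the 95% normal
    /// quantile) and a placeholder sigma = 0.1:
    ///
    /// ```ignore
    /// // margin = 0.1 * 1.96 = 0.196
    /// // lower  = prediction - 0.196
    /// // upper  = prediction + 0.196
    /// ```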
    fn calculate_prediction_intervals(
        &self,
        predictions: &Array2<f64>,
    ) -> Result<(Array2<f64>, Array2<f64>)> {
        let std_dev = 0.1; // Placeholder standard deviation
        let z_score = 1.96; // 95% confidence level
        let margin = std_dev * z_score;

        let lower_bound = predictions - margin;
        let upper_bound = predictions + margin;

        Ok((lower_bound, upper_bound))
    }

    fn detect_prediction_anomalies(
        &self,
        predictions: &Array2<f64>,
    ) -> Result<Vec<super::metrics::AnomalyPoint>> {
        let mut anomalies = Vec::new();

        // Flag values more than three standard deviations from the row mean
        for (i, row) in predictions.rows().into_iter().enumerate() {
            let mean = row.mean().unwrap_or(0.0);
            let std = row.std(1.0);

            for (j, &value) in row.iter().enumerate() {
                let z_score = if std > 1e-10 {
                    (value - mean).abs() / std
                } else {
                    0.0
                };

                if z_score > 3.0 {
                    anomalies.push(super::metrics::AnomalyPoint {
                        timestamp: i * predictions.ncols() + j,
                        value,
                        anomaly_score: z_score,
                        anomaly_type: super::config::AnomalyType::Point,
                    });
                }
            }
        }

        Ok(anomalies)
    }

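    /// Scores per-step confidence from cross-sample stability.
    ///
    /// Confidence for each horizon step is `1 / (1 + cv)` clamped to [0, 1],
    /// where `cv = std / mean(|x|)` is a coefficient-of-variation-style
    /// dispersion measure over that column:
    ///
    /// ```ignore
    /// // std = 0, any mean    -> confidence = 1.0 (perfectly stable)
    /// // std = mean(|x|)      -> confidence = 0.5
    /// // std >> mean(|x|)     -> confidence -> 0.0
    /// ```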
    fn calculate_confidence_scores(&self, predictions: &Array2<f64>) -> Result<Array1<f64>> {
        let mut confidence_scores = Array1::zeros(predictions.ncols());

        // One confidence score per forecast step (column)
        for j in 0..predictions.ncols() {
            let column = predictions.column(j);
            let std = column.std(1.0);
            let mean_abs = column.mapv(|x| x.abs()).mean().unwrap_or(1.0);

            // Higher dispersion relative to magnitude lowers confidence
            let stability = 1.0 / (1.0 + std / mean_abs.max(1e-10));
            confidence_scores[j] = stability.min(1.0).max(0.0);
        }

        Ok(confidence_scores)
    }

    fn calculate_quantum_uncertainty(&self, predictions: &Array2<f64>) -> Result<f64> {
        // Heuristic: map prediction variance onto [0, 1] via a log scale
        let variance = predictions.var(0.0);
        let uncertainty = variance.ln().max(0.0) / 10.0;

        Ok(uncertainty.min(1.0))
    }

    fn generate_prediction_cache_key(&self, context: &Array2<f64>, horizon: usize) -> String {
        // Cheap fingerprint from shape, horizon, and element sum.
        // NOTE: distinct contexts with equal shape and sum will collide.
        format!(
            "pred_{}x{}_h{}_{:.6}",
            context.nrows(),
            context.ncols(),
            horizon,
            context.sum()
        )
    }

    pub fn update_metrics(
        &mut self,
        predictions: &Array2<f64>,
        actuals: &Array2<f64>,
    ) -> Result<()> {
        self.metrics.calculate_metrics(predictions, actuals)?;
        Ok(())
    }

    pub fn get_metrics(&self) -> &ForecastMetrics {
        &self.metrics
    }

    pub fn get_training_history(&self) -> &TrainingHistory {
        &self.training_history
    }

    pub fn get_config(&self) -> &QuantumTimeSeriesConfig {
        &self.config
    }

    pub fn get_cache_statistics(&self) -> (CacheStatistics, CacheStatistics) {
        (
            self.quantum_state_cache.get_stats(),
            self.prediction_cache.get_stats(),
        )
    }

    pub fn clear_caches(&mut self) {
        self.quantum_state_cache.clear();
        self.prediction_cache.clear();
    }

    /// Placeholder: state persistence is not yet implemented.
    pub fn save_state(&self, path: &str) -> Result<()> {
        println!("Saving forecaster state to: {}", path);
        Ok(())
    }

    /// Placeholder: state loading is not yet implemented.
    pub fn load_state(&mut self, path: &str) -> Result<()> {
        println!("Loading forecaster state from: {}", path);
        Ok(())
    }
}

impl QuantumStateCache {
    pub fn new(max_size: usize) -> Self {
        Self {
            states: HashMap::new(),
            max_size,
            access_history: VecDeque::new(),
            stats: CacheStatistics::new(),
        }
    }

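    /// Stores a state, evicting the least recently used entry when full.
    ///
    /// A minimal usage sketch:
    ///
    /// ```ignore
    /// let mut cache = QuantumStateCache::new(2);
    /// cache.store("a".to_string(), Array1::zeros(4));
    /// cache.store("b".to_string(), Array1::zeros(4));
    /// cache.store("c".to_string(), Array1::zeros(4)); // evicts "a"
    /// assert!(cache.get("a").is_none());
    /// ```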
    pub fn store(&mut self, key: String, state: Array1<f64>) {
        // Refresh the position of an existing key instead of duplicating it
        if let Some(pos) = self.access_history.iter().position(|k| k == &key) {
            self.access_history.remove(pos);
        } else if self.states.len() >= self.max_size {
            // Evict the least recently used entry when at capacity
            if let Some(lru_key) = self.access_history.pop_front() {
                self.states.remove(&lru_key);
            }
        }

        self.states.insert(key.clone(), state);
        self.access_history.push_back(key);
    }

    pub fn get(&mut self, key: &str) -> Option<&Array1<f64>> {
        self.stats.total_accesses += 1;

        if let Some(state) = self.states.get(key) {
            self.stats.hits += 1;

            // Move the key to the back of the access history (most recent)
            if let Some(pos) = self.access_history.iter().position(|k| k == key) {
                if let Some(key_owned) = self.access_history.remove(pos) {
                    self.access_history.push_back(key_owned);
                }
            }

            Some(state)
        } else {
            self.stats.misses += 1;
            None
        }
    }

    pub fn clear(&mut self) {
        self.states.clear();
        self.access_history.clear();
    }

    pub fn get_stats(&self) -> CacheStatistics {
        // Report the hit rate as a percentage of total accesses
        let mut stats = self.stats.clone();
        stats.hit_rate = if stats.total_accesses > 0 {
            stats.hits as f64 / stats.total_accesses as f64 * 100.0
        } else {
            0.0
        };
        stats
    }
}

impl PredictionCache {
    pub fn new(max_size: usize, ttl_seconds: u64) -> Self {
        Self {
            predictions: HashMap::new(),
            ttl_seconds,
            max_size,
            stats: CacheStatistics::new(),
        }
    }

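    /// Returns a cached prediction if present and not expired.
    ///
    /// Entries older than `ttl_seconds` count as misses and are removed.
    /// A minimal usage sketch (`result` is a hypothetical `ForecastResult`):
    ///
    /// ```ignore
    /// let mut cache = PredictionCache::new(10, 1); // 1-second TTL
    /// cache.insert("key".to_string(), &result)?;
    /// assert!(cache.get("key").is_some());
    /// std::thread::sleep(std::time::Duration::from_secs(2));
    /// assert!(cache.get("key").is_none()); // expired
    /// ```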
    pub fn get(&mut self, key: &str) -> Option<&CachedPrediction> {
        self.stats.total_accesses += 1;

        // Check whether the entry exists and is still within its TTL
        let is_valid = if let Some(cached) = self.predictions.get(key) {
            if let Ok(elapsed) = cached.timestamp.elapsed() {
                elapsed.as_secs() < self.ttl_seconds
            } else {
                false
            }
        } else {
            false
        };

        if is_valid {
            self.stats.hits += 1;
            self.predictions.get(key)
        } else {
            // Drop the expired entry, if any
            self.predictions.remove(key);
            self.stats.misses += 1;
            None
        }
    }

    pub fn insert(&mut self, key: String, result: &ForecastResult) -> Result<()> {
        // Evict an arbitrary entry when at capacity (HashMap iteration order
        // is unspecified, so this is neither strictly LRU nor FIFO)
        if self.predictions.len() >= self.max_size {
            if let Some(first_key) = self.predictions.keys().next().cloned() {
                self.predictions.remove(&first_key);
            }
        }

        let cached_prediction = CachedPrediction {
            result: result.clone(),
            timestamp: std::time::SystemTime::now(),
            input_hash: 0, // Placeholder; the key already fingerprints the input
            model_version: "1.0".to_string(),
        };

        self.predictions.insert(key, cached_prediction);
        Ok(())
    }

    pub fn clear(&mut self) {
        self.predictions.clear();
    }

    pub fn get_stats(&self) -> CacheStatistics {
        // Report the hit rate as a percentage of total accesses
        let mut stats = self.stats.clone();
        stats.hit_rate = if stats.total_accesses > 0 {
            stats.hits as f64 / stats.total_accesses as f64 * 100.0
        } else {
            0.0
        };
        stats
    }
}

impl CacheStatistics {
    pub fn new() -> Self {
        Self {
            hits: 0,
            misses: 0,
            total_accesses: 0,
            hit_rate: 0.0,
        }
    }
}

impl Default for ForecastingContext {
    fn default() -> Self {
        Self {
            mode: ExecutionMode::Sequential,
            parallel_config: ParallelConfig {
                num_threads: 4,
                batch_size: 32,
                use_gpu: false,
                load_balancing: LoadBalancingStrategy::RoundRobin,
            },
            memory_config: MemoryConfig {
                use_memory_pool: true,
                max_memory_mb: 1024,
                use_compression: false,
                gc_strategy: GCStrategy::Adaptive,
            },
            monitoring: MonitoringConfig {
                enable_monitoring: true,
                log_level: LogLevel::Info,
                enable_telemetry: false,
                metrics_interval_ms: 1000,
            },
        }
    }
}

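/// Creates a forecaster with the default configuration.
///
/// The sibling helpers below cover common presets. A minimal usage sketch
/// (the parameter values are illustrative; the preset semantics are defined
/// by `QuantumTimeSeriesConfig`):
///
/// ```ignore
/// let general = create_default_forecaster()?;
/// let finance = create_financial_forecaster(30)?; // 30-step horizon
/// let sensor = create_iot_forecaster(100)?;       // sampling rate of 100
/// let demand = create_demand_forecaster()?;
/// ```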
pub fn create_default_forecaster() -> Result<QuantumTimeSeriesForecaster> {
    QuantumTimeSeriesForecaster::new(QuantumTimeSeriesConfig::default())
}

/// Creates a forecaster preset for financial series with the given horizon.
pub fn create_financial_forecaster(forecast_horizon: usize) -> Result<QuantumTimeSeriesForecaster> {
    QuantumTimeSeriesForecaster::new(QuantumTimeSeriesConfig::financial(forecast_horizon))
}

/// Creates a forecaster preset for IoT sensor streams.
pub fn create_iot_forecaster(sampling_rate: usize) -> Result<QuantumTimeSeriesForecaster> {
    QuantumTimeSeriesForecaster::new(QuantumTimeSeriesConfig::iot_sensor(sampling_rate))
}

/// Creates a forecaster preset for demand forecasting.
pub fn create_demand_forecaster() -> Result<QuantumTimeSeriesForecaster> {
    QuantumTimeSeriesForecaster::new(QuantumTimeSeriesConfig::demand_forecasting())
}
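
// A hedged smoke-test sketch exercising only the self-contained cache logic
// defined in this module; it assumes no external fixtures or configuration.
#[cfg(test)]
mod tests {
    use super::*;

    // hit_rate is reported as a percentage of total accesses.
    #[test]
    fn quantum_state_cache_tracks_hits_and_misses() {
        let mut cache = QuantumStateCache::new(2);
        cache.store("trend".to_string(), Array1::zeros(4));

        assert!(cache.get("trend").is_some()); // hit
        assert!(cache.get("seasonal").is_none()); // miss

        let stats = cache.get_stats();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 1);
        assert!((stats.hit_rate - 50.0).abs() < 1e-9);
    }

    // Inserting past capacity drops the least recently used key.
    #[test]
    fn quantum_state_cache_evicts_least_recently_used() {
        let mut cache = QuantumStateCache::new(2);
        cache.store("a".to_string(), Array1::zeros(1));
        cache.store("b".to_string(), Array1::zeros(1));
        cache.store("c".to_string(), Array1::zeros(1)); // evicts "a"

        assert!(cache.get("a").is_none());
        assert!(cache.get("b").is_some());
        assert!(cache.get("c").is_some());
    }
}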