use super::{
config::*,
decomposition::QuantumSeasonalDecomposer,
ensemble::QuantumEnsembleManager,
features::QuantumFeatureExtractor,
metrics::{ForecastMetrics, ForecastResult, TrainingHistory},
models::{TimeSeriesModelFactory, TimeSeriesModelTrait},
};
use crate::error::{MLError, Result};
use crate::optimization::OptimizationMethod;
use scirs2_core::ndarray::{s, Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::time::Instant;
/// Quantum time-series forecaster that combines feature extraction, optional
/// seasonal decomposition, a primary model and an optional model ensemble.
///
/// Construct with [`QuantumTimeSeriesForecaster::new`], train with `fit` and
/// forecast with `predict`.
#[derive(Debug, Clone)]
pub struct QuantumTimeSeriesForecaster {
/// Configuration the forecaster was built from.
config: QuantumTimeSeriesConfig,
/// Primary model; `predict` uses it when no ensemble is configured.
model: Box<dyn TimeSeriesModelTrait>,
/// Feature extractor whose statistics are fitted during `fit`.
feature_extractor: QuantumFeatureExtractor,
/// Present only when the configured seasonality settings report seasonality.
seasonal_decomposer: Option<QuantumSeasonalDecomposer>,
/// Present only when `config.ensemble_config` is set; `predict` prefers it
/// over `model`.
ensemble_manager: Option<QuantumEnsembleManager>,
/// Timing and per-epoch records written by `fit`.
training_history: TrainingHistory,
/// Accuracy metrics updated through `update_metrics`.
metrics: ForecastMetrics,
/// Holds the `"trend"` / `"seasonal"` components produced during `fit`, which
/// `predict` adds back onto raw model output.
quantum_state_cache: QuantumStateCache,
/// Memoized `predict` results keyed by context and horizon.
prediction_cache: PredictionCache,
}
/// Bounded key → state-vector store with least-recently-used eviction.
#[derive(Debug, Clone)]
pub struct QuantumStateCache {
/// Stored state vectors by key.
states: HashMap<String, Array1<f64>>,
/// Capacity; storing a new key once this many states are held evicts the
/// front of `access_history` first.
max_size: usize,
/// Keys in recency order: the front is the next eviction candidate, the back
/// the most recently stored or read key.
access_history: VecDeque<String>,
/// Hit/miss counters maintained by `get`.
stats: CacheStatistics,
}
/// Time-limited, size-bounded cache of forecast results.
#[derive(Debug, Clone)]
pub struct PredictionCache {
/// Cached forecasts by key.
predictions: HashMap<String, CachedPrediction>,
/// An entry stops being served once this many whole seconds have elapsed
/// since insertion; lookups then remove it and count a miss.
ttl_seconds: u64,
/// Capacity; inserting a new key at capacity first makes room by evicting.
max_size: usize,
/// Hit/miss counters maintained by `get`.
stats: CacheStatistics,
}
/// A cached forecast together with its insertion metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CachedPrediction {
/// The forecast returned to callers on a cache hit.
pub result: ForecastResult,
/// Insertion time; compared against the cache TTL on lookup.
pub timestamp: std::time::SystemTime,
/// NOTE(review): `PredictionCache::insert` currently always writes 0 here;
/// no input hash is computed.
pub input_hash: u64,
/// NOTE(review): `PredictionCache::insert` currently writes the fixed string
/// "1.0"; it is not tied to the fitted model.
pub model_version: String,
}
/// Hit/miss counters for a cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStatistics {
/// Lookups that returned a usable entry.
pub hits: usize,
/// Lookups that returned nothing.
pub misses: usize,
/// Every lookup, whether it hit or missed.
pub total_accesses: usize,
/// Hit percentage (0–100). It is computed in the snapshots returned by the
/// caches' `get_stats`; the stored value otherwise keeps its initial 0.0.
pub hit_rate: f64,
}
/// Execution settings for running forecasts: mode, parallelism, memory and
/// monitoring.
///
/// NOTE(review): nothing in this file reads a `ForecastingContext`; how each
/// setting is honored is decided by its consumers elsewhere.
#[derive(Debug, Clone)]
pub struct ForecastingContext {
/// Requested execution mode (default: `Sequential`).
pub mode: ExecutionMode,
/// Thread, batch, GPU and load-balancing settings.
pub parallel_config: ParallelConfig,
/// Memory-pool, budget, compression and GC settings.
pub memory_config: MemoryConfig,
/// Logging and telemetry settings.
pub monitoring: MonitoringConfig,
}
/// Requested execution mode. Variant behavior is defined by whatever consumes
/// a [`ForecastingContext`]; nothing in this file interprets it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ExecutionMode {
Sequential,
Parallel,
Distributed,
QuantumAccelerated,
}
/// Parallel execution settings (defaults: 4 threads, batch size 32, no GPU,
/// `RoundRobin` balancing).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelConfig {
/// Requested number of worker threads.
pub num_threads: usize,
/// Requested batch size.
pub batch_size: usize,
/// Whether GPU execution is requested.
pub use_gpu: bool,
/// Requested work-distribution strategy.
pub load_balancing: LoadBalancingStrategy,
}
/// Work-distribution strategy; interpretation is left to the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LoadBalancingStrategy {
RoundRobin,
WorkStealing,
DynamicPartitioning,
QuantumOptimal,
}
/// Memory settings (defaults: pool enabled, budget 1024, no compression,
/// `Adaptive` GC).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
/// Whether a memory pool is requested.
pub use_memory_pool: bool,
/// Memory budget; the `_mb` suffix indicates megabytes.
pub max_memory_mb: usize,
/// Whether compression is requested.
pub use_compression: bool,
/// Requested garbage-collection strategy.
pub gc_strategy: GCStrategy,
}
/// Garbage-collection strategy; interpretation is left to the consumer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GCStrategy {
Aggressive,
Conservative,
Adaptive,
QuantumOptimized,
}
/// Monitoring settings (defaults: enabled, `Info` log level, no telemetry,
/// interval 1000).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MonitoringConfig {
/// Whether monitoring is enabled.
pub enable_monitoring: bool,
/// Requested log level.
pub log_level: LogLevel,
/// Whether telemetry is enabled.
pub enable_telemetry: bool,
/// Metrics collection interval; the `_ms` suffix indicates milliseconds.
pub metrics_interval_ms: u64,
}
/// Log level selectable in [`MonitoringConfig`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LogLevel {
Error,
Warn,
Info,
Debug,
Trace,
}
impl QuantumTimeSeriesForecaster {
    /// Builds a forecaster from `config`.
    ///
    /// Always creates the primary model and the feature extractor. A seasonal
    /// decomposer is created only when `config.seasonality_config` reports
    /// seasonality. An ensemble of `num_models` independently created models is
    /// built only when `config.ensemble_config` is set. The state cache holds
    /// up to 1000 entries. The prediction cache holds up to 100 results, each
    /// valid for 3600 seconds.
    ///
    /// # Errors
    /// Propagates any error from creating the models, the feature extractor or
    /// the seasonal decomposer.
    pub fn new(config: QuantumTimeSeriesConfig) -> Result<Self> {
        let model = TimeSeriesModelFactory::create_model(&config.model_type, config.num_qubits)?;
        let feature_extractor =
            QuantumFeatureExtractor::new(config.feature_config.clone(), config.num_qubits)?;
        let seasonal_decomposer = if config.seasonality_config.has_seasonality() {
            Some(QuantumSeasonalDecomposer::new(
                config.seasonality_config.clone(),
                config.num_qubits,
            )?)
        } else {
            None
        };
        let ensemble_manager = if let Some(ref ensemble_config) = config.ensemble_config {
            let mut manager = QuantumEnsembleManager::new(ensemble_config.clone());
            // Each ensemble member is a separately created model of the configured type.
            let mut ensemble_models = Vec::new();
            for _ in 0..ensemble_config.num_models {
                let ensemble_model =
                    TimeSeriesModelFactory::create_model(&config.model_type, config.num_qubits)?;
                ensemble_models.push(ensemble_model);
            }
            manager.set_models(ensemble_models);
            Some(manager)
        } else {
            None
        };
        let quantum_state_cache = QuantumStateCache::new(1000);
        let prediction_cache = PredictionCache::new(100, 3600);
        Ok(Self {
            config,
            model,
            feature_extractor,
            seasonal_decomposer,
            ensemble_manager,
            training_history: TrainingHistory::new(),
            metrics: ForecastMetrics::new(),
            quantum_state_cache,
            prediction_cache,
        })
    }

    /// Trains the forecaster on `data` (rows are time steps, columns features).
    ///
    /// Pipeline: validate → window into (features, targets) → optional seasonal
    /// decomposition → fit feature statistics → extract features → fit the
    /// primary model and, if configured, the ensemble. The decomposer's trend
    /// and seasonal components replace any stored by a previous fit, so that
    /// `predict` adds back only the current ones. The prediction cache is
    /// cleared because its entries were produced by the previous model.
    ///
    /// NOTE(review): `epochs` and `optimizer` are accepted but not used by this
    /// implementation.
    ///
    /// # Errors
    /// Returns `MLError::DataError` for invalid training data and propagates any
    /// error from decomposition, feature extraction or model fitting.
    pub fn fit(
        &mut self,
        data: &Array2<f64>,
        epochs: usize,
        optimizer: OptimizationMethod,
    ) -> Result<()> {
        let start_time = Instant::now();
        println!("Training quantum time series model...");
        self.validate_training_data(data)?;
        // Every cached forecast came from the model being replaced; serving one
        // after (re)training would silently return stale results.
        self.prediction_cache.clear();
        let (features, targets) = self.prepare_training_data(data)?;
        let (detrended_features, trend, seasonal) =
            if let Some(ref mut decomposer) = self.seasonal_decomposer {
                decomposer.decompose(&features)?
            } else {
                // `features` is not used afterwards, so move it instead of cloning.
                (features, None, None)
            };
        // Fit statistics on a copy so that a failure leaves the current
        // extractor untouched.
        let mut feature_extractor = self.feature_extractor.clone();
        feature_extractor.fit_statistics(&detrended_features)?;
        self.feature_extractor = feature_extractor;
        let quantum_features = self
            .feature_extractor
            .extract_features(&detrended_features)?;
        self.model.fit(&quantum_features, &targets)?;
        if let Some(ref mut ensemble_manager) = self.ensemble_manager {
            ensemble_manager.fit_ensemble(&quantum_features, &targets)?;
        }
        // Drop components left by an earlier fit, so that a decomposition which
        // yields no trend/seasonal part does not keep re-applying stale ones.
        self.quantum_state_cache.clear();
        if let Some(trend) = trend {
            self.quantum_state_cache.store("trend".to_string(), trend);
        }
        if let Some(seasonal) = seasonal {
            self.quantum_state_cache
                .store("seasonal".to_string(), seasonal);
        }
        let training_time = start_time.elapsed();
        self.training_history.training_time = training_time.as_secs_f64();
        // NOTE(review): records a single placeholder epoch with an empty map and zero values.
        self.training_history
            .add_epoch_metrics(HashMap::new(), 0.0, 0.0);
        println!(
            "Training completed in {:.2} seconds",
            training_time.as_secs_f64()
        );
        Ok(())
    }

    /// Forecasts `horizon` steps (default: `config.forecast_horizon`) from
    /// `context`.
    ///
    /// Results are memoized. A repeated call with an identical context and
    /// horizon returns the cached result until its TTL expires or the model is
    /// refitted. Otherwise the ensemble (if configured) or the primary model
    /// predicts, and stored trend/seasonal components are added back. Intervals,
    /// anomalies, confidence scores and an uncertainty score are then attached.
    ///
    /// NOTE(review): `fit` windows and decomposes the data before feature
    /// extraction, but here `context` goes straight to `extract_features`.
    /// Confirm the extractor and model accept raw context of this shape.
    ///
    /// # Errors
    /// Returns `MLError::DataError` for an invalid context and propagates model,
    /// extraction and cache errors.
    pub fn predict(
        &mut self,
        context: &Array2<f64>,
        horizon: Option<usize>,
    ) -> Result<ForecastResult> {
        let forecast_horizon = horizon.unwrap_or(self.config.forecast_horizon);
        let cache_key = self.generate_prediction_cache_key(context, forecast_horizon);
        if let Some(cached_result) = self.prediction_cache.get(&cache_key) {
            return Ok(cached_result.result.clone());
        }
        self.validate_prediction_context(context)?;
        let features = self.feature_extractor.extract_features(context)?;
        let mut predictions = if let Some(ref ensemble_manager) = self.ensemble_manager {
            ensemble_manager.predict_ensemble(&features, forecast_horizon)?
        } else {
            self.model.predict(&features, forecast_horizon)?
        };
        predictions = self.reconstruct_predictions(predictions, forecast_horizon)?;
        let (lower_bound, upper_bound) = self.calculate_prediction_intervals(&predictions)?;
        let anomalies = self.detect_prediction_anomalies(&predictions)?;
        let confidence_scores = self.calculate_confidence_scores(&predictions)?;
        let quantum_uncertainty = self.calculate_quantum_uncertainty(&predictions)?;
        let result = ForecastResult {
            predictions,
            lower_bound,
            upper_bound,
            anomalies,
            confidence_scores,
            quantum_uncertainty,
        };
        self.prediction_cache.insert(cache_key, &result)?;
        Ok(result)
    }

    /// Checks that `data` has at least `window_size + forecast_horizon` rows,
    /// at least one column, and only finite values.
    fn validate_training_data(&self, data: &Array2<f64>) -> Result<()> {
        let (n_samples, n_features) = data.dim();
        let required = self.config.window_size + self.config.forecast_horizon;
        if n_samples < required {
            return Err(MLError::DataError(format!(
                "Insufficient data: need at least {} samples, got {}",
                required, n_samples
            )));
        }
        if n_features == 0 {
            return Err(MLError::DataError(
                "No features in training data".to_string(),
            ));
        }
        if data.iter().any(|value| !value.is_finite()) {
            return Err(MLError::DataError(
                "Training data contains NaN or infinite values".to_string(),
            ));
        }
        Ok(())
    }

    /// Checks that `context` has at least `window_size` rows and only finite
    /// values.
    fn validate_prediction_context(&self, context: &Array2<f64>) -> Result<()> {
        let (n_samples, _) = context.dim();
        if n_samples < self.config.window_size {
            return Err(MLError::DataError(format!(
                "Insufficient context: need at least {} samples, got {}",
                self.config.window_size, n_samples
            )));
        }
        if context.iter().any(|value| !value.is_finite()) {
            return Err(MLError::DataError(
                "Context data contains NaN or infinite values".to_string(),
            ));
        }
        Ok(())
    }

    /// Slides a window over the rows of `data` to build supervised pairs.
    ///
    /// Sample `i` has two parts:
    /// - features: rows `i..i + window_size` flattened row-major, followed by
    ///   one row per configured lag;
    /// - targets: the next `forecast_horizon` rows flattened row-major.
    ///
    /// # Errors
    /// Returns `MLError::DataError` when no complete sample fits, including
    /// the degenerate `window_size + forecast_horizon == 0` configuration.
    fn prepare_training_data(&self, data: &Array2<f64>) -> Result<(Array2<f64>, Array2<f64>)> {
        let window_size = self.config.window_size;
        let horizon = self.config.forecast_horizon;
        let lag_features = &self.config.feature_config.lag_features;
        // A sample spans `window_size + horizon` consecutive rows, so there are
        // `nrows - (span - 1)` of them. `checked_sub` turns `span == 0` into
        // "no samples" instead of an integer-underflow panic in debug builds.
        let num_samples = (window_size + horizon)
            .checked_sub(1)
            .map_or(0, |overlap| data.nrows().saturating_sub(overlap));
        if num_samples == 0 {
            return Err(MLError::DataError(
                "Insufficient data for the specified window size and forecast horizon".to_string(),
            ));
        }
        let num_features = data.ncols();
        let total_features = num_features * (window_size + lag_features.len());
        let mut features = Array2::zeros((num_samples, total_features));
        let mut targets = Array2::zeros((num_samples, horizon * num_features));
        for i in 0..num_samples {
            // Rows [i, window_end) are observed; targets start at `window_end`.
            let window_end = i + window_size;
            let window = data.slice(s![i..window_end, ..]);
            let window_len = window.len();
            features
                .slice_mut(s![i, 0..window_len])
                .assign(&Array1::from_vec(window.iter().cloned().collect::<Vec<f64>>()));
            let mut feature_offset = window_len;
            for &lag in lag_features {
                // Lag `k` reads row `window_end - k` (lag 1 = last observed row).
                // Lag 0 would read the first *target* row and leak the answer
                // into the inputs, so it stays zero like an unavailable lag.
                if lag > 0 && i >= lag {
                    features
                        .slice_mut(s![i, feature_offset..feature_offset + num_features])
                        .assign(&data.row(window_end - lag));
                }
                feature_offset += num_features;
            }
            let target = data.slice(s![window_end..window_end + horizon, ..]);
            targets
                .row_mut(i)
                .assign(&Array1::from_vec(target.iter().cloned().collect::<Vec<f64>>()));
        }
        Ok((features, targets))
    }

    /// Adds the stored trend and seasonal components (if any) back onto
    /// `predictions`.
    fn reconstruct_predictions(
        &mut self,
        mut predictions: Array2<f64>,
        horizon: usize,
    ) -> Result<Array2<f64>> {
        // `get` needs `&mut self` (it updates hit statistics), so each component
        // is cloned out before calling the `&self` helpers.
        if let Some(trend) = self.quantum_state_cache.get("trend").cloned() {
            predictions = self.add_trend_component(predictions, &trend, horizon)?;
        }
        if let Some(seasonal) = self.quantum_state_cache.get("seasonal").cloned() {
            predictions = self.add_seasonal_component(predictions, &seasonal, horizon)?;
        }
        Ok(predictions)
    }

    /// Adds the fitted trend component onto the first `horizon` columns.
    ///
    /// NOTE(review): the shared index `(len - 1 + h) % len` starts at the last
    /// trend value and then wraps to the *first* one. That is questionable for
    /// a trend; confirm the intended semantics of the decomposer's output.
    fn add_trend_component(
        &self,
        predictions: Array2<f64>,
        trend: &Array1<f64>,
        horizon: usize,
    ) -> Result<Array2<f64>> {
        Ok(Self::add_cyclic_component(predictions, trend, horizon))
    }

    /// Adds the fitted seasonal component onto the first `horizon` columns.
    fn add_seasonal_component(
        &self,
        predictions: Array2<f64>,
        seasonal: &Array1<f64>,
        horizon: usize,
    ) -> Result<Array2<f64>> {
        Ok(Self::add_cyclic_component(predictions, seasonal, horizon))
    }

    /// Adds `component[(len - 1 + h) % len]` to column `h` of every row, for
    /// `h < min(horizon, ncols)`.
    ///
    /// An empty component leaves `predictions` unchanged; the modulo would
    /// otherwise divide by zero and panic.
    fn add_cyclic_component(
        mut predictions: Array2<f64>,
        component: &Array1<f64>,
        horizon: usize,
    ) -> Array2<f64> {
        let len = component.len();
        if len == 0 {
            return predictions;
        }
        let steps = horizon.min(predictions.ncols());
        for i in 0..predictions.nrows() {
            for h in 0..steps {
                predictions[[i, h]] += component[(len - 1 + h) % len];
            }
        }
        predictions
    }

    /// Returns symmetric `(lower, upper)` bounds around `predictions`.
    ///
    /// NOTE(review): the spread is a fixed placeholder (σ = 0.1 at z = 1.96,
    /// i.e. ±0.196), not an estimate derived from the model.
    fn calculate_prediction_intervals(
        &self,
        predictions: &Array2<f64>,
    ) -> Result<(Array2<f64>, Array2<f64>)> {
        const ASSUMED_STD_DEV: f64 = 0.1;
        // Two-sided 95% quantile of the standard normal distribution.
        const Z_SCORE: f64 = 1.96;
        let margin = ASSUMED_STD_DEV * Z_SCORE;
        let lower_bound = predictions - margin;
        let upper_bound = predictions + margin;
        Ok((lower_bound, upper_bound))
    }

    /// Flags values whose z-score within their row (sample std, ddof = 1)
    /// exceeds 3. `timestamp` is the flat row-major index `i * ncols + j`.
    fn detect_prediction_anomalies(
        &self,
        predictions: &Array2<f64>,
    ) -> Result<Vec<super::metrics::AnomalyPoint>> {
        let mut anomalies = Vec::new();
        for (i, row) in predictions.rows().into_iter().enumerate() {
            // `std(1.0)` panics on an empty row and is undefined for a single
            // value; rows shorter than 2 cannot contain a z-score outlier.
            if row.len() < 2 {
                continue;
            }
            let mean = row.mean().unwrap_or(0.0);
            let std_dev = row.std(1.0);
            for (j, &value) in row.iter().enumerate() {
                let z_score = if std_dev > 1e-10 {
                    (value - mean).abs() / std_dev
                } else {
                    0.0
                };
                if z_score > 3.0 {
                    anomalies.push(super::metrics::AnomalyPoint {
                        timestamp: i * predictions.ncols() + j,
                        value,
                        anomaly_score: z_score,
                        anomaly_type: super::config::AnomalyType::Point,
                    });
                }
            }
        }
        Ok(anomalies)
    }

    /// Per-column stability score `1 / (1 + std / mean(|x|))`, clamped to
    /// [0, 1]. Lower spread relative to magnitude gives higher confidence.
    fn calculate_confidence_scores(&self, predictions: &Array2<f64>) -> Result<Array1<f64>> {
        let num_rows = predictions.nrows();
        let mut confidence_scores = Array1::zeros(predictions.ncols());
        for j in 0..predictions.ncols() {
            let column = predictions.column(j);
            // `std(1.0)` panics on an empty column; with fewer than two rows
            // there is no spread to measure, so treat it as zero.
            let std_dev = if num_rows >= 2 { column.std(1.0) } else { 0.0 };
            let mean_abs = column.mapv(|x| x.abs()).mean().unwrap_or(1.0);
            let stability = 1.0 / (1.0 + std_dev / mean_abs.max(1e-10));
            confidence_scores[j] = stability.min(1.0).max(0.0);
        }
        Ok(confidence_scores)
    }

    /// Maps the overall population variance of `predictions` to [0, 1] as
    /// `ln(var) / 10`, clamped. Variances ≤ 1 give 0 and variances ≥ e^10
    /// give 1.
    fn calculate_quantum_uncertainty(&self, predictions: &Array2<f64>) -> Result<f64> {
        let variance = predictions.var(0.0);
        let uncertainty = variance.ln().max(0.0) / 10.0;
        Ok(uncertainty.min(1.0))
    }

    /// Builds the prediction-cache key from the context's shape, the horizon,
    /// and a 64-bit hash of every element's exact bit pattern.
    ///
    /// All elements are hashed in row-major order. Contexts that merely share
    /// a shape and a sum (e.g. permuted rows) therefore map to different keys
    /// and cannot be served each other's forecasts.
    fn generate_prediction_cache_key(&self, context: &Array2<f64>, horizon: usize) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::Hasher;

        let mut hasher = DefaultHasher::new();
        for value in context.iter() {
            hasher.write_u64(value.to_bits());
        }
        format!(
            "pred_{}x{}_h{}_{:016x}",
            context.nrows(),
            context.ncols(),
            horizon,
            hasher.finish()
        )
    }

    /// Recomputes the accuracy metrics from `predictions` against `actuals`.
    pub fn update_metrics(
        &mut self,
        predictions: &Array2<f64>,
        actuals: &Array2<f64>,
    ) -> Result<()> {
        self.metrics.calculate_metrics(predictions, actuals)?;
        Ok(())
    }

    /// Accuracy metrics from the last `update_metrics` call.
    pub fn get_metrics(&self) -> &ForecastMetrics {
        &self.metrics
    }

    /// Training history recorded by `fit`.
    pub fn get_training_history(&self) -> &TrainingHistory {
        &self.training_history
    }

    /// Configuration the forecaster was built with.
    pub fn get_config(&self) -> &QuantumTimeSeriesConfig {
        &self.config
    }

    /// Returns `(state-cache stats, prediction-cache stats)`.
    pub fn get_cache_statistics(&self) -> (CacheStatistics, CacheStatistics) {
        (
            self.quantum_state_cache.get_stats(),
            self.prediction_cache.get_stats(),
        )
    }

    /// Empties both caches.
    ///
    /// NOTE: the state cache also holds the trend/seasonal components learned by
    /// `fit`. After this call, `predict` no longer adds them back until the
    /// model is refitted.
    pub fn clear_caches(&mut self) {
        self.quantum_state_cache.clear();
        self.prediction_cache.clear();
    }

    /// Placeholder: only logs `path`. Nothing is written, and `Ok(())` is
    /// always returned.
    pub fn save_state(&self, path: &str) -> Result<()> {
        println!("Saving forecaster state to: {}", path);
        Ok(())
    }

    /// Placeholder: only logs `path`. Nothing is read, and `Ok(())` is always
    /// returned.
    pub fn load_state(&mut self, path: &str) -> Result<()> {
        println!("Loading forecaster state from: {}", path);
        Ok(())
    }
}
impl QuantumStateCache {
    /// Creates an empty cache that holds up to `max_size` states before
    /// evicting.
    pub fn new(max_size: usize) -> Self {
        Self {
            states: HashMap::new(),
            max_size,
            access_history: VecDeque::new(),
            stats: CacheStatistics::new(),
        }
    }

    /// Stores `state` under `key` and marks the key as most recently used.
    ///
    /// Re-storing an existing key overwrites it in place. No other entry is
    /// evicted, and the key's old recency slot is dropped, so `access_history`
    /// holds each key at most once and cannot grow across repeated stores. For
    /// a new key at capacity, the least recently used entry (the front of
    /// `access_history`) is evicted first.
    pub fn store(&mut self, key: String, state: Array1<f64>) {
        if self.states.contains_key(&key) {
            if let Some(pos) = self.access_history.iter().position(|k| *k == key) {
                self.access_history.remove(pos);
            }
        } else if self.states.len() >= self.max_size {
            if let Some(lru_key) = self.access_history.pop_front() {
                self.states.remove(&lru_key);
            }
        }
        self.states.insert(key.clone(), state);
        self.access_history.push_back(key);
    }

    /// Returns the state stored under `key`, counting a hit or a miss. On a
    /// hit, the key moves to the most-recently-used end.
    pub fn get(&mut self, key: &str) -> Option<&Array1<f64>> {
        self.stats.total_accesses += 1;
        if let Some(state) = self.states.get(key) {
            self.stats.hits += 1;
            if let Some(pos) = self.access_history.iter().position(|k| k == key) {
                if let Some(key_owned) = self.access_history.remove(pos) {
                    self.access_history.push_back(key_owned);
                }
            }
            Some(state)
        } else {
            self.stats.misses += 1;
            None
        }
    }

    /// Removes all states and recency information; hit/miss statistics are
    /// kept.
    pub fn clear(&mut self) {
        self.states.clear();
        self.access_history.clear();
    }

    /// Snapshot of the statistics with `hit_rate` as a percentage (0–100); 0
    /// when there have been no accesses.
    pub fn get_stats(&self) -> CacheStatistics {
        let mut stats = self.stats.clone();
        stats.hit_rate = if stats.total_accesses > 0 {
            stats.hits as f64 / stats.total_accesses as f64 * 100.0
        } else {
            0.0
        };
        stats
    }
}
impl PredictionCache {
    /// Creates an empty cache holding up to `max_size` results, each served
    /// for `ttl_seconds` seconds after insertion.
    pub fn new(max_size: usize, ttl_seconds: u64) -> Self {
        Self {
            predictions: HashMap::new(),
            ttl_seconds,
            max_size,
            stats: CacheStatistics::new(),
        }
    }

    /// Looks up `key`, counting the access as a hit or a miss.
    ///
    /// An entry is a hit only while fewer than `ttl_seconds` whole seconds have
    /// elapsed since insertion. Expired entries are removed and reported as
    /// misses. So are entries whose timestamp lies in the future (the clock
    /// moved backwards).
    pub fn get(&mut self, key: &str) -> Option<&CachedPrediction> {
        self.stats.total_accesses += 1;
        let ttl_seconds = self.ttl_seconds;
        let is_valid = self
            .predictions
            .get(key)
            .is_some_and(|cached| Self::is_fresh(cached, ttl_seconds));
        if is_valid {
            self.stats.hits += 1;
            self.predictions.get(key)
        } else {
            self.predictions.remove(key);
            self.stats.misses += 1;
            None
        }
    }

    /// Caches a copy of `result` under `key`, stamped with the current time.
    ///
    /// Overwriting an existing key never evicts anything. When a new key
    /// arrives at capacity, expired entries are dropped first. If the cache is
    /// still full, the entry with the oldest insertion time is evicted, rather
    /// than whichever entry `HashMap` iteration happens to yield.
    ///
    /// # Errors
    /// Currently infallible; the `Result` is part of the existing signature.
    pub fn insert(&mut self, key: String, result: &ForecastResult) -> Result<()> {
        if !self.predictions.contains_key(&key) && self.predictions.len() >= self.max_size {
            self.make_room();
        }
        let cached_prediction = CachedPrediction {
            result: result.clone(),
            timestamp: std::time::SystemTime::now(),
            // NOTE(review): input hashing and model versioning are not implemented yet.
            input_hash: 0,
            model_version: "1.0".to_string(),
        };
        self.predictions.insert(key, cached_prediction);
        Ok(())
    }

    /// Frees capacity by dropping expired entries, then the oldest entry if
    /// the cache is still full.
    fn make_room(&mut self) {
        let ttl_seconds = self.ttl_seconds;
        self.predictions
            .retain(|_, cached| Self::is_fresh(cached, ttl_seconds));
        if self.predictions.len() >= self.max_size {
            let oldest_key = self
                .predictions
                .iter()
                .min_by_key(|(_, cached)| cached.timestamp)
                .map(|(key, _)| key.clone());
            if let Some(oldest_key) = oldest_key {
                self.predictions.remove(&oldest_key);
            }
        }
    }

    /// Whether `cached` is younger than `ttl_seconds` whole seconds. Returns
    /// `false` when its timestamp is later than the current system time.
    fn is_fresh(cached: &CachedPrediction, ttl_seconds: u64) -> bool {
        cached
            .timestamp
            .elapsed()
            .map(|elapsed| elapsed.as_secs() < ttl_seconds)
            .unwrap_or(false)
    }

    /// Removes every cached result; hit/miss statistics are kept.
    pub fn clear(&mut self) {
        self.predictions.clear();
    }

    /// Snapshot of the statistics with `hit_rate` as a percentage (0–100); 0
    /// when there have been no accesses.
    pub fn get_stats(&self) -> CacheStatistics {
        let mut stats = self.stats.clone();
        stats.hit_rate = if stats.total_accesses > 0 {
            stats.hits as f64 / stats.total_accesses as f64 * 100.0
        } else {
            0.0
        };
        stats
    }
}
impl CacheStatistics {
pub fn new() -> Self {
Self {
hits: 0,
misses: 0,
total_accesses: 0,
hit_rate: 0.0,
}
}
}
impl Default for ForecastingContext {
    /// Sequential execution with conservative resource settings:
    /// - 4 threads, batch size 32, no GPU, round-robin balancing;
    /// - memory pool on, 1024 MB budget, no compression, adaptive GC;
    /// - monitoring on at `Info` level, telemetry off, 1000 ms metrics interval.
    fn default() -> Self {
        let parallel_config = ParallelConfig {
            num_threads: 4,
            batch_size: 32,
            use_gpu: false,
            load_balancing: LoadBalancingStrategy::RoundRobin,
        };
        let memory_config = MemoryConfig {
            use_memory_pool: true,
            max_memory_mb: 1024,
            use_compression: false,
            gc_strategy: GCStrategy::Adaptive,
        };
        let monitoring = MonitoringConfig {
            enable_monitoring: true,
            log_level: LogLevel::Info,
            enable_telemetry: false,
            metrics_interval_ms: 1000,
        };
        Self {
            mode: ExecutionMode::Sequential,
            parallel_config,
            memory_config,
            monitoring,
        }
    }
}
/// Builds a forecaster from [`QuantumTimeSeriesConfig::default`].
///
/// # Errors
/// Propagates any error from [`QuantumTimeSeriesForecaster::new`].
pub fn create_default_forecaster() -> Result<QuantumTimeSeriesForecaster> {
    let config = QuantumTimeSeriesConfig::default();
    QuantumTimeSeriesForecaster::new(config)
}

/// Builds a forecaster from the `QuantumTimeSeriesConfig::financial` preset,
/// passing `forecast_horizon` to it.
///
/// # Errors
/// Propagates any error from [`QuantumTimeSeriesForecaster::new`].
pub fn create_financial_forecaster(forecast_horizon: usize) -> Result<QuantumTimeSeriesForecaster> {
    let config = QuantumTimeSeriesConfig::financial(forecast_horizon);
    QuantumTimeSeriesForecaster::new(config)
}

/// Builds a forecaster from the `QuantumTimeSeriesConfig::iot_sensor` preset,
/// passing `sampling_rate` to it.
///
/// # Errors
/// Propagates any error from [`QuantumTimeSeriesForecaster::new`].
pub fn create_iot_forecaster(sampling_rate: usize) -> Result<QuantumTimeSeriesForecaster> {
    let config = QuantumTimeSeriesConfig::iot_sensor(sampling_rate);
    QuantumTimeSeriesForecaster::new(config)
}

/// Builds a forecaster from the `QuantumTimeSeriesConfig::demand_forecasting`
/// preset.
///
/// # Errors
/// Propagates any error from [`QuantumTimeSeriesForecaster::new`].
pub fn create_demand_forecaster() -> Result<QuantumTimeSeriesForecaster> {
    let config = QuantumTimeSeriesConfig::demand_forecasting();
    QuantumTimeSeriesForecaster::new(config)
}