Skip to main content

do_memory_mcp/patterns/predictive/forecasting/engine.rs

1use anyhow::{Result, anyhow};
2use std::collections::HashMap;
3use tracing::{debug, info, instrument, warn};
4
5use super::ets_types::{ETSForecastResult, SeasonalityResult};
6use super::types::{ForecastResult, PredictiveConfig};
7
8pub struct ForecastingEngine {
9    config: PredictiveConfig,
10}
11
12impl ForecastingEngine {
13    /// Create a new forecasting engine
14    pub fn new() -> Result<Self> {
15        Self::with_config(PredictiveConfig::default())
16    }
17
18    /// Create a new forecasting engine with custom config
19    pub fn with_config(config: PredictiveConfig) -> Result<Self> {
20        Ok(Self { config })
21    }
22
23    /// Generate forecasts for time series data
24    #[instrument(skip(self, data))]
25    pub fn forecast(&mut self, data: &HashMap<String, Vec<f64>>) -> Result<Vec<ForecastResult>> {
26        let mut results = Vec::new();
27
28        info!("Generating forecasts for {} variables", data.len());
29
30        for (var_name, series) in data {
31            if series.len() < 5 {
32                warn!(
33                    "Skipping forecast for {}: insufficient data points",
34                    var_name
35                );
36                continue;
37            }
38
39            let forecast_result = self.forecast_variable(var_name, series)?;
40            results.push(forecast_result);
41        }
42
43        debug!("Generated {} forecasts", results.len());
44        Ok(results)
45    }
46
47    /// Main ETS forecasting function - replaces placeholder
48    fn forecast_variable(&mut self, variable: &str, series: &[f64]) -> Result<ForecastResult> {
49        // Sample data if too large
50        let data = if series.len() > self.config.reservoir_size {
51            series
52                .iter()
53                .take(self.config.reservoir_size)
54                .copied()
55                .collect()
56        } else {
57            series.to_vec()
58        };
59
60        if data.len() < 2 {
61            return Err(anyhow!("Insufficient data for ETS forecasting"));
62        }
63
64        // Detect seasonality
65        let seasonality = self.detect_seasonality(&data)?;
66        let period = seasonality.period;
67
68        // Try all ETS model combinations and select best
69        let best_result = self.select_and_fit_ets_model(&data, period)?;
70
71        // Generate multi-step forecasts
72        let forecasts = self.forecast_ets(
73            &best_result.model,
74            &best_result.state,
75            self.config.forecast_horizon,
76        )?;
77
78        // Calculate confidence intervals
79        let (lower_bounds, upper_bounds) = self.calculate_confidence_intervals(
80            &best_result.model,
81            &forecasts,
82            &best_result.state,
83            0.95, // 95% confidence
84        );
85
86        Ok(ForecastResult {
87            variable: variable.to_string(),
88            point_forecasts: forecasts,
89            lower_bounds,
90            upper_bounds,
91            fit_quality: best_result.fit_quality,
92            method: format!(
93                "ETS-{}{}{}",
94                best_result.model.error.as_str(),
95                best_result.model.trend.as_str(),
96                best_result.model.seasonal.as_str()
97            ),
98        })
99    }
100
101    /// Calculate forecast fit quality
102    #[allow(dead_code)] // Utility method for forecast quality metrics
103    fn calculate_fit_quality(&self, actual: &[f64], forecast: &[f64]) -> f64 {
104        if actual.len() < 2 || forecast.is_empty() {
105            return 0.0;
106        }
107
108        // Simple MAPE calculation for last few points
109        let n = actual.len().min(forecast.len().min(10));
110        let start_idx = actual.len().saturating_sub(n);
111
112        let mape: f64 = actual[start_idx..]
113            .iter()
114            .zip(&forecast[..n])
115            .map(|(&a, &f)| {
116                if a != 0.0 {
117                    (a - f).abs() / a.abs()
118                } else {
119                    0.0
120                }
121            })
122            .sum::<f64>()
123            / n as f64;
124
125        // Convert MAPE to quality score (lower MAPE = higher quality)
126        (1.0 - mape.min(1.0)).max(0.0)
127    }
128
129    /// Automatic seasonality detection using autocorrelation
130    fn detect_seasonality(&self, series: &[f64]) -> Result<SeasonalityResult> {
131        if series.len() < 10 {
132            return Ok(SeasonalityResult {
133                period: 0,
134                strength: 0.0,
135            });
136        }
137
138        let max_period = (series.len() / 2).min(12); // Limit seasonal periods
139
140        // Collect strengths for each candidate period.
141        let mut strengths: Vec<(usize, f64)> = Vec::new();
142        for period in 2..=max_period {
143            if let Some(strength) = self.calculate_seasonal_strength(series, period) {
144                strengths.push((period, strength));
145            }
146        }
147
148        let Some((_, max_strength)) = strengths
149            .iter()
150            .cloned()
151            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
152        else {
153            return Ok(SeasonalityResult {
154                period: 0,
155                strength: 0.0,
156            });
157        };
158
159        // Prefer realistic short seasonal periods if they are close to the best score.
160        // This reduces autocorrelation artifacts picking period=2 on small synthetic series.
161        let tolerance = 0.02;
162        let mut candidates: Vec<(usize, f64)> = strengths
163            .into_iter()
164            .filter(|(_, s)| *s >= max_strength - tolerance)
165            .collect();
166        candidates.sort_by_key(|a| a.0);
167
168        let (best_period, best_strength) = if let Some((p, s)) = candidates
169            .iter()
170            .find(|(p, _)| (3..=5).contains(p))
171            .copied()
172        {
173            (p, s)
174        } else {
175            // Otherwise choose the smallest period among near-best candidates.
176            candidates[0]
177        };
178
179        Ok(SeasonalityResult {
180            period: if best_strength > 0.1 { best_period } else { 0 },
181            strength: best_strength,
182        })
183    }
184
185    /// Calculate seasonal strength for a given period
186    fn calculate_seasonal_strength(&self, series: &[f64], period: usize) -> Option<f64> {
187        if series.len() < period * 2 {
188            return None;
189        }
190
191        let mut seasonal_means = vec![0.0f64; period];
192        let mut counts = vec![0usize; period];
193
194        for (i, &value) in series.iter().enumerate() {
195            seasonal_means[i % period] += value;
196            counts[i % period] += 1;
197        }
198
199        for (i, &count) in counts.iter().enumerate() {
200            if count > 0 {
201                seasonal_means[i] /= count as f64;
202            }
203        }
204
205        let overall_mean: f64 = series.iter().sum::<f64>() / series.len() as f64;
206        let variance: f64 = series
207            .iter()
208            .map(|&x| (x - overall_mean).powi(2))
209            .sum::<f64>()
210            / series.len() as f64;
211
212        let seasonal_variance: f64 = seasonal_means
213            .iter()
214            .enumerate()
215            .map(|(i, &mean)| {
216                let count = counts[i] as f64;
217                count * (mean - overall_mean).powi(2)
218            })
219            .sum::<f64>()
220            / series.len() as f64;
221
222        if variance > 0.0 {
223            Some((seasonal_variance / variance).sqrt())
224        } else {
225            Some(0.0)
226        }
227    }
228
229    /// Select and fit the best ETS model using information criteria
230    fn select_and_fit_ets_model(&self, series: &[f64], period: usize) -> Result<ETSForecastResult> {
231        if series.len() < 2 {
232            return Err(anyhow!("ETS requires at least 2 observations"));
233        }
234        let models_to_try = self.generate_model_combinations(period);
235        let mut best_result = None;
236        let mut best_aic = f64::INFINITY;
237
238        for model_spec in models_to_try {
239            if let Ok(result) = self.fit_ets_model(series, &model_spec) {
240                if result.aic < best_aic {
241                    best_aic = result.aic;
242                    best_result = Some(result);
243                }
244            }
245        }
246
247        best_result.ok_or_else(|| anyhow!("Failed to fit any ETS model"))
248    }
249}