Skip to main content

do_memory_mcp/patterns/predictive/forecasting/engine.rs

1use anyhow::{Result, anyhow};
2use std::collections::HashMap;
3use tracing::{debug, info, instrument, warn};
4
5use super::ets_types::{ETSForecastResult, SeasonalityResult};
6use super::types::{ForecastResult, PredictiveConfig};
7
8#[allow(dead_code)]
9pub struct ForecastingEngine {
10    config: PredictiveConfig,
11}
12
13impl ForecastingEngine {
14    /// Create a new forecasting engine
15    pub fn new() -> Result<Self> {
16        Self::with_config(PredictiveConfig::default())
17    }
18
19    /// Create a new forecasting engine with custom config
20    pub fn with_config(config: PredictiveConfig) -> Result<Self> {
21        Ok(Self { config })
22    }
23
24    /// Generate forecasts for time series data
25    #[instrument(skip(self, data))]
26    pub fn forecast(&mut self, data: &HashMap<String, Vec<f64>>) -> Result<Vec<ForecastResult>> {
27        let mut results = Vec::new();
28
29        info!("Generating forecasts for {} variables", data.len());
30
31        for (var_name, series) in data {
32            if series.len() < 5 {
33                warn!(
34                    "Skipping forecast for {}: insufficient data points",
35                    var_name
36                );
37                continue;
38            }
39
40            let forecast_result = self.forecast_variable(var_name, series)?;
41            results.push(forecast_result);
42        }
43
44        debug!("Generated {} forecasts", results.len());
45        Ok(results)
46    }
47
48    /// Main ETS forecasting function - replaces placeholder
49    fn forecast_variable(&mut self, variable: &str, series: &[f64]) -> Result<ForecastResult> {
50        // Sample data if too large
51        let data = if series.len() > self.config.reservoir_size {
52            series
53                .iter()
54                .take(self.config.reservoir_size)
55                .copied()
56                .collect()
57        } else {
58            series.to_vec()
59        };
60
61        if data.len() < 2 {
62            return Err(anyhow!("Insufficient data for ETS forecasting"));
63        }
64
65        // Detect seasonality
66        let seasonality = self.detect_seasonality(&data)?;
67        let period = seasonality.period;
68
69        // Try all ETS model combinations and select best
70        let best_result = self.select_and_fit_ets_model(&data, period)?;
71
72        // Generate multi-step forecasts
73        let forecasts = self.forecast_ets(
74            &best_result.model,
75            &best_result.state,
76            self.config.forecast_horizon,
77        )?;
78
79        // Calculate confidence intervals
80        let (lower_bounds, upper_bounds) = self.calculate_confidence_intervals(
81            &best_result.model,
82            &forecasts,
83            &best_result.state,
84            0.95, // 95% confidence
85        );
86
87        Ok(ForecastResult {
88            variable: variable.to_string(),
89            point_forecasts: forecasts,
90            lower_bounds,
91            upper_bounds,
92            fit_quality: best_result.fit_quality,
93            method: format!(
94                "ETS-{}{}{}",
95                best_result.model.error.as_str(),
96                best_result.model.trend.as_str(),
97                best_result.model.seasonal.as_str()
98            ),
99        })
100    }
101
102    /// Calculate forecast fit quality
103    #[allow(dead_code)]
104    fn calculate_fit_quality(&self, actual: &[f64], forecast: &[f64]) -> f64 {
105        if actual.len() < 2 || forecast.is_empty() {
106            return 0.0;
107        }
108
109        // Simple MAPE calculation for last few points
110        let n = actual.len().min(forecast.len().min(10));
111        let start_idx = actual.len().saturating_sub(n);
112
113        let mape: f64 = actual[start_idx..]
114            .iter()
115            .zip(&forecast[..n])
116            .map(|(&a, &f)| {
117                if a != 0.0 {
118                    (a - f).abs() / a.abs()
119                } else {
120                    0.0
121                }
122            })
123            .sum::<f64>()
124            / n as f64;
125
126        // Convert MAPE to quality score (lower MAPE = higher quality)
127        (1.0 - mape.min(1.0)).max(0.0)
128    }
129
130    /// Automatic seasonality detection using autocorrelation
131    fn detect_seasonality(&self, series: &[f64]) -> Result<SeasonalityResult> {
132        if series.len() < 10 {
133            return Ok(SeasonalityResult {
134                period: 0,
135                strength: 0.0,
136            });
137        }
138
139        let max_period = (series.len() / 2).min(12); // Limit seasonal periods
140
141        // Collect strengths for each candidate period.
142        let mut strengths: Vec<(usize, f64)> = Vec::new();
143        for period in 2..=max_period {
144            if let Some(strength) = self.calculate_seasonal_strength(series, period) {
145                strengths.push((period, strength));
146            }
147        }
148
149        let Some((_, max_strength)) = strengths
150            .iter()
151            .cloned()
152            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
153        else {
154            return Ok(SeasonalityResult {
155                period: 0,
156                strength: 0.0,
157            });
158        };
159
160        // Prefer realistic short seasonal periods if they are close to the best score.
161        // This reduces autocorrelation artifacts picking period=2 on small synthetic series.
162        let tolerance = 0.02;
163        let mut candidates: Vec<(usize, f64)> = strengths
164            .into_iter()
165            .filter(|(_, s)| *s >= max_strength - tolerance)
166            .collect();
167        candidates.sort_by(|a, b| a.0.cmp(&b.0));
168
169        let (best_period, best_strength) = if let Some((p, s)) = candidates
170            .iter()
171            .find(|(p, _)| (3..=5).contains(p))
172            .copied()
173        {
174            (p, s)
175        } else {
176            // Otherwise choose the smallest period among near-best candidates.
177            candidates[0]
178        };
179
180        Ok(SeasonalityResult {
181            period: if best_strength > 0.1 { best_period } else { 0 },
182            strength: best_strength,
183        })
184    }
185
186    /// Calculate seasonal strength for a given period
187    fn calculate_seasonal_strength(&self, series: &[f64], period: usize) -> Option<f64> {
188        if series.len() < period * 2 {
189            return None;
190        }
191
192        let mut seasonal_means = vec![0.0f64; period];
193        let mut counts = vec![0usize; period];
194
195        for (i, &value) in series.iter().enumerate() {
196            seasonal_means[i % period] += value;
197            counts[i % period] += 1;
198        }
199
200        for (i, &count) in counts.iter().enumerate() {
201            if count > 0 {
202                seasonal_means[i] /= count as f64;
203            }
204        }
205
206        let overall_mean: f64 = series.iter().sum::<f64>() / series.len() as f64;
207        let variance: f64 = series
208            .iter()
209            .map(|&x| (x - overall_mean).powi(2))
210            .sum::<f64>()
211            / series.len() as f64;
212
213        let seasonal_variance: f64 = seasonal_means
214            .iter()
215            .enumerate()
216            .map(|(i, &mean)| {
217                let count = counts[i] as f64;
218                count * (mean - overall_mean).powi(2)
219            })
220            .sum::<f64>()
221            / series.len() as f64;
222
223        if variance > 0.0 {
224            Some((seasonal_variance / variance).sqrt())
225        } else {
226            Some(0.0)
227        }
228    }
229
230    /// Select and fit the best ETS model using information criteria
231    fn select_and_fit_ets_model(&self, series: &[f64], period: usize) -> Result<ETSForecastResult> {
232        if series.len() < 2 {
233            return Err(anyhow!("ETS requires at least 2 observations"));
234        }
235        let models_to_try = self.generate_model_combinations(period);
236        let mut best_result = None;
237        let mut best_aic = f64::INFINITY;
238
239        for model_spec in models_to_try {
240            if let Ok(result) = self.fit_ets_model(series, &model_spec) {
241                if result.aic < best_aic {
242                    best_aic = result.aic;
243                    best_result = Some(result);
244                }
245            }
246        }
247
248        best_result.ok_or_else(|| anyhow!("Failed to fit any ETS model"))
249    }
250}