do_memory_mcp/patterns/predictive/forecasting/
engine.rs1use anyhow::{Result, anyhow};
2use std::collections::HashMap;
3use tracing::{debug, info, instrument, warn};
4
5use super::ets_types::{ETSForecastResult, SeasonalityResult};
6use super::types::{ForecastResult, PredictiveConfig};
7
8#[allow(dead_code)]
9pub struct ForecastingEngine {
10 config: PredictiveConfig,
11}
12
13impl ForecastingEngine {
14 pub fn new() -> Result<Self> {
16 Self::with_config(PredictiveConfig::default())
17 }
18
19 pub fn with_config(config: PredictiveConfig) -> Result<Self> {
21 Ok(Self { config })
22 }
23
24 #[instrument(skip(self, data))]
26 pub fn forecast(&mut self, data: &HashMap<String, Vec<f64>>) -> Result<Vec<ForecastResult>> {
27 let mut results = Vec::new();
28
29 info!("Generating forecasts for {} variables", data.len());
30
31 for (var_name, series) in data {
32 if series.len() < 5 {
33 warn!(
34 "Skipping forecast for {}: insufficient data points",
35 var_name
36 );
37 continue;
38 }
39
40 let forecast_result = self.forecast_variable(var_name, series)?;
41 results.push(forecast_result);
42 }
43
44 debug!("Generated {} forecasts", results.len());
45 Ok(results)
46 }
47
48 fn forecast_variable(&mut self, variable: &str, series: &[f64]) -> Result<ForecastResult> {
50 let data = if series.len() > self.config.reservoir_size {
52 series
53 .iter()
54 .take(self.config.reservoir_size)
55 .copied()
56 .collect()
57 } else {
58 series.to_vec()
59 };
60
61 if data.len() < 2 {
62 return Err(anyhow!("Insufficient data for ETS forecasting"));
63 }
64
65 let seasonality = self.detect_seasonality(&data)?;
67 let period = seasonality.period;
68
69 let best_result = self.select_and_fit_ets_model(&data, period)?;
71
72 let forecasts = self.forecast_ets(
74 &best_result.model,
75 &best_result.state,
76 self.config.forecast_horizon,
77 )?;
78
79 let (lower_bounds, upper_bounds) = self.calculate_confidence_intervals(
81 &best_result.model,
82 &forecasts,
83 &best_result.state,
84 0.95, );
86
87 Ok(ForecastResult {
88 variable: variable.to_string(),
89 point_forecasts: forecasts,
90 lower_bounds,
91 upper_bounds,
92 fit_quality: best_result.fit_quality,
93 method: format!(
94 "ETS-{}{}{}",
95 best_result.model.error.as_str(),
96 best_result.model.trend.as_str(),
97 best_result.model.seasonal.as_str()
98 ),
99 })
100 }
101
102 #[allow(dead_code)]
104 fn calculate_fit_quality(&self, actual: &[f64], forecast: &[f64]) -> f64 {
105 if actual.len() < 2 || forecast.is_empty() {
106 return 0.0;
107 }
108
109 let n = actual.len().min(forecast.len().min(10));
111 let start_idx = actual.len().saturating_sub(n);
112
113 let mape: f64 = actual[start_idx..]
114 .iter()
115 .zip(&forecast[..n])
116 .map(|(&a, &f)| {
117 if a != 0.0 {
118 (a - f).abs() / a.abs()
119 } else {
120 0.0
121 }
122 })
123 .sum::<f64>()
124 / n as f64;
125
126 (1.0 - mape.min(1.0)).max(0.0)
128 }
129
130 fn detect_seasonality(&self, series: &[f64]) -> Result<SeasonalityResult> {
132 if series.len() < 10 {
133 return Ok(SeasonalityResult {
134 period: 0,
135 strength: 0.0,
136 });
137 }
138
139 let max_period = (series.len() / 2).min(12); let mut strengths: Vec<(usize, f64)> = Vec::new();
143 for period in 2..=max_period {
144 if let Some(strength) = self.calculate_seasonal_strength(series, period) {
145 strengths.push((period, strength));
146 }
147 }
148
149 let Some((_, max_strength)) = strengths
150 .iter()
151 .cloned()
152 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
153 else {
154 return Ok(SeasonalityResult {
155 period: 0,
156 strength: 0.0,
157 });
158 };
159
160 let tolerance = 0.02;
163 let mut candidates: Vec<(usize, f64)> = strengths
164 .into_iter()
165 .filter(|(_, s)| *s >= max_strength - tolerance)
166 .collect();
167 candidates.sort_by(|a, b| a.0.cmp(&b.0));
168
169 let (best_period, best_strength) = if let Some((p, s)) = candidates
170 .iter()
171 .find(|(p, _)| (3..=5).contains(p))
172 .copied()
173 {
174 (p, s)
175 } else {
176 candidates[0]
178 };
179
180 Ok(SeasonalityResult {
181 period: if best_strength > 0.1 { best_period } else { 0 },
182 strength: best_strength,
183 })
184 }
185
186 fn calculate_seasonal_strength(&self, series: &[f64], period: usize) -> Option<f64> {
188 if series.len() < period * 2 {
189 return None;
190 }
191
192 let mut seasonal_means = vec![0.0f64; period];
193 let mut counts = vec![0usize; period];
194
195 for (i, &value) in series.iter().enumerate() {
196 seasonal_means[i % period] += value;
197 counts[i % period] += 1;
198 }
199
200 for (i, &count) in counts.iter().enumerate() {
201 if count > 0 {
202 seasonal_means[i] /= count as f64;
203 }
204 }
205
206 let overall_mean: f64 = series.iter().sum::<f64>() / series.len() as f64;
207 let variance: f64 = series
208 .iter()
209 .map(|&x| (x - overall_mean).powi(2))
210 .sum::<f64>()
211 / series.len() as f64;
212
213 let seasonal_variance: f64 = seasonal_means
214 .iter()
215 .enumerate()
216 .map(|(i, &mean)| {
217 let count = counts[i] as f64;
218 count * (mean - overall_mean).powi(2)
219 })
220 .sum::<f64>()
221 / series.len() as f64;
222
223 if variance > 0.0 {
224 Some((seasonal_variance / variance).sqrt())
225 } else {
226 Some(0.0)
227 }
228 }
229
230 fn select_and_fit_ets_model(&self, series: &[f64], period: usize) -> Result<ETSForecastResult> {
232 if series.len() < 2 {
233 return Err(anyhow!("ETS requires at least 2 observations"));
234 }
235 let models_to_try = self.generate_model_combinations(period);
236 let mut best_result = None;
237 let mut best_aic = f64::INFINITY;
238
239 for model_spec in models_to_try {
240 if let Ok(result) = self.fit_ets_model(series, &model_spec) {
241 if result.aic < best_aic {
242 best_aic = result.aic;
243 best_result = Some(result);
244 }
245 }
246 }
247
248 best_result.ok_or_else(|| anyhow!("Failed to fit any ETS model"))
249 }
250}