do_memory_mcp/patterns/predictive/forecasting/
engine.rs1use anyhow::{Result, anyhow};
2use std::collections::HashMap;
3use tracing::{debug, info, instrument, warn};
4
5use super::ets_types::{ETSForecastResult, SeasonalityResult};
6use super::types::{ForecastResult, PredictiveConfig};
7
8pub struct ForecastingEngine {
9 config: PredictiveConfig,
10}
11
12impl ForecastingEngine {
13 pub fn new() -> Result<Self> {
15 Self::with_config(PredictiveConfig::default())
16 }
17
18 pub fn with_config(config: PredictiveConfig) -> Result<Self> {
20 Ok(Self { config })
21 }
22
23 #[instrument(skip(self, data))]
25 pub fn forecast(&mut self, data: &HashMap<String, Vec<f64>>) -> Result<Vec<ForecastResult>> {
26 let mut results = Vec::new();
27
28 info!("Generating forecasts for {} variables", data.len());
29
30 for (var_name, series) in data {
31 if series.len() < 5 {
32 warn!(
33 "Skipping forecast for {}: insufficient data points",
34 var_name
35 );
36 continue;
37 }
38
39 let forecast_result = self.forecast_variable(var_name, series)?;
40 results.push(forecast_result);
41 }
42
43 debug!("Generated {} forecasts", results.len());
44 Ok(results)
45 }
46
47 fn forecast_variable(&mut self, variable: &str, series: &[f64]) -> Result<ForecastResult> {
49 let data = if series.len() > self.config.reservoir_size {
51 series
52 .iter()
53 .take(self.config.reservoir_size)
54 .copied()
55 .collect()
56 } else {
57 series.to_vec()
58 };
59
60 if data.len() < 2 {
61 return Err(anyhow!("Insufficient data for ETS forecasting"));
62 }
63
64 let seasonality = self.detect_seasonality(&data)?;
66 let period = seasonality.period;
67
68 let best_result = self.select_and_fit_ets_model(&data, period)?;
70
71 let forecasts = self.forecast_ets(
73 &best_result.model,
74 &best_result.state,
75 self.config.forecast_horizon,
76 )?;
77
78 let (lower_bounds, upper_bounds) = self.calculate_confidence_intervals(
80 &best_result.model,
81 &forecasts,
82 &best_result.state,
83 0.95, );
85
86 Ok(ForecastResult {
87 variable: variable.to_string(),
88 point_forecasts: forecasts,
89 lower_bounds,
90 upper_bounds,
91 fit_quality: best_result.fit_quality,
92 method: format!(
93 "ETS-{}{}{}",
94 best_result.model.error.as_str(),
95 best_result.model.trend.as_str(),
96 best_result.model.seasonal.as_str()
97 ),
98 })
99 }
100
101 #[allow(dead_code)] fn calculate_fit_quality(&self, actual: &[f64], forecast: &[f64]) -> f64 {
104 if actual.len() < 2 || forecast.is_empty() {
105 return 0.0;
106 }
107
108 let n = actual.len().min(forecast.len().min(10));
110 let start_idx = actual.len().saturating_sub(n);
111
112 let mape: f64 = actual[start_idx..]
113 .iter()
114 .zip(&forecast[..n])
115 .map(|(&a, &f)| {
116 if a != 0.0 {
117 (a - f).abs() / a.abs()
118 } else {
119 0.0
120 }
121 })
122 .sum::<f64>()
123 / n as f64;
124
125 (1.0 - mape.min(1.0)).max(0.0)
127 }
128
129 fn detect_seasonality(&self, series: &[f64]) -> Result<SeasonalityResult> {
131 if series.len() < 10 {
132 return Ok(SeasonalityResult {
133 period: 0,
134 strength: 0.0,
135 });
136 }
137
138 let max_period = (series.len() / 2).min(12); let mut strengths: Vec<(usize, f64)> = Vec::new();
142 for period in 2..=max_period {
143 if let Some(strength) = self.calculate_seasonal_strength(series, period) {
144 strengths.push((period, strength));
145 }
146 }
147
148 let Some((_, max_strength)) = strengths
149 .iter()
150 .cloned()
151 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
152 else {
153 return Ok(SeasonalityResult {
154 period: 0,
155 strength: 0.0,
156 });
157 };
158
159 let tolerance = 0.02;
162 let mut candidates: Vec<(usize, f64)> = strengths
163 .into_iter()
164 .filter(|(_, s)| *s >= max_strength - tolerance)
165 .collect();
166 candidates.sort_by_key(|a| a.0);
167
168 let (best_period, best_strength) = if let Some((p, s)) = candidates
169 .iter()
170 .find(|(p, _)| (3..=5).contains(p))
171 .copied()
172 {
173 (p, s)
174 } else {
175 candidates[0]
177 };
178
179 Ok(SeasonalityResult {
180 period: if best_strength > 0.1 { best_period } else { 0 },
181 strength: best_strength,
182 })
183 }
184
185 fn calculate_seasonal_strength(&self, series: &[f64], period: usize) -> Option<f64> {
187 if series.len() < period * 2 {
188 return None;
189 }
190
191 let mut seasonal_means = vec![0.0f64; period];
192 let mut counts = vec![0usize; period];
193
194 for (i, &value) in series.iter().enumerate() {
195 seasonal_means[i % period] += value;
196 counts[i % period] += 1;
197 }
198
199 for (i, &count) in counts.iter().enumerate() {
200 if count > 0 {
201 seasonal_means[i] /= count as f64;
202 }
203 }
204
205 let overall_mean: f64 = series.iter().sum::<f64>() / series.len() as f64;
206 let variance: f64 = series
207 .iter()
208 .map(|&x| (x - overall_mean).powi(2))
209 .sum::<f64>()
210 / series.len() as f64;
211
212 let seasonal_variance: f64 = seasonal_means
213 .iter()
214 .enumerate()
215 .map(|(i, &mean)| {
216 let count = counts[i] as f64;
217 count * (mean - overall_mean).powi(2)
218 })
219 .sum::<f64>()
220 / series.len() as f64;
221
222 if variance > 0.0 {
223 Some((seasonal_variance / variance).sqrt())
224 } else {
225 Some(0.0)
226 }
227 }
228
229 fn select_and_fit_ets_model(&self, series: &[f64], period: usize) -> Result<ETSForecastResult> {
231 if series.len() < 2 {
232 return Err(anyhow!("ETS requires at least 2 observations"));
233 }
234 let models_to_try = self.generate_model_combinations(period);
235 let mut best_result = None;
236 let mut best_aic = f64::INFINITY;
237
238 for model_spec in models_to_try {
239 if let Ok(result) = self.fit_ets_model(series, &model_spec) {
240 if result.aic < best_aic {
241 best_aic = result.aic;
242 best_result = Some(result);
243 }
244 }
245 }
246
247 best_result.ok_or_else(|| anyhow!("Failed to fit any ETS model"))
248 }
249}