Skip to main content

do_memory_mcp/patterns/predictive/forecasting/
ets_fitting.rs

1use anyhow::Result;
2
3use super::ets_types::{
4    ETSErrorType, ETSForecastResult, ETSModel, ETSModelSpec, ETSSeasonalType, ETSState,
5    ETSTrendType,
6};
7
8impl super::engine::ForecastingEngine {
9    pub(super) fn generate_model_combinations(&self, period: usize) -> Vec<ETSModelSpec> {
10        let mut models = Vec::new();
11
12        let error_types = [ETSErrorType::Additive, ETSErrorType::Multiplicative];
13        let trend_types = [
14            ETSTrendType::None,
15            ETSTrendType::Additive,
16            ETSTrendType::AdditiveDamped,
17        ];
18        let seasonal_types = if period > 0 {
19            vec![
20                ETSSeasonalType::None,
21                ETSSeasonalType::Additive,
22                ETSSeasonalType::Multiplicative,
23            ]
24        } else {
25            vec![ETSSeasonalType::None]
26        };
27
28        for error in &error_types {
29            for trend in &trend_types {
30                for seasonal in &seasonal_types {
31                    models.push(ETSModelSpec {
32                        error: *error,
33                        trend: *trend,
34                        seasonal: *seasonal,
35                    });
36                }
37            }
38        }
39
40        models
41    }
42
43    /// Fit ETS model using Maximum Likelihood Estimation
44    pub(super) fn fit_ets_model(
45        &self,
46        series: &[f64],
47        model_spec: &ETSModelSpec,
48    ) -> Result<ETSForecastResult> {
49        // Initialize parameters
50        let mut model = self.initialize_parameters(series, model_spec)?;
51        let mut state = self.initialize_state(series, &model)?;
52
53        // Optimize parameters using MLE
54        model = self.optimize_parameters(series, model_spec, &state)?;
55
56        // Refit with optimized parameters
57        state = self.refit_with_parameters(series, &model)?;
58
59        // Calculate fitted values
60        let fitted: Vec<f64> = (0..series.len())
61            .map(|i| {
62                let obs_state = ETSState {
63                    level: state.level,
64                    trend: state.trend,
65                    seasonal: state.seasonal.clone(),
66                    last_observation: if i > 0 { series[i - 1] } else { series[0] },
67                    n_obs: i,
68                };
69                self.calculate_fitted_value(&obs_state, &model)
70            })
71            .collect();
72
73        // Calculate fit quality and model metrics
74        let (_, rmse, _mape) = self.calculate_model_metrics(series, &fitted);
75        let fit_quality = self.calculate_ets_fit_quality(series, &fitted, &model);
76
77        // Calculate simplified log-likelihood and AIC
78        let log_likelihood = -rmse * series.len() as f64; // Simplified
79        let aic = series.len() as f64 * (rmse.ln() + 1.0) + 6.0; // Simplified AIC
80
81        Ok(ETSForecastResult {
82            model,
83            state,
84            forecasts: Vec::new(), // Will be filled by caller
85            lower_bounds: Vec::new(),
86            upper_bounds: Vec::new(),
87            fit_quality,
88            aic,
89            log_likelihood,
90        })
91    }
92
93    /// Initialize ETS parameters with heuristics
94    fn initialize_parameters(&self, series: &[f64], model_spec: &ETSModelSpec) -> Result<ETSModel> {
95        let n = series.len();
96
97        // Simple heuristics for initial parameter values
98        let alpha = 0.2;
99        let beta = if matches!(model_spec.trend, ETSTrendType::None) {
100            0.0
101        } else {
102            0.1
103        };
104        let gamma = if matches!(model_spec.seasonal, ETSSeasonalType::None) {
105            0.0
106        } else {
107            0.1
108        };
109        let phi = 0.98;
110
111        // Calculate initial level and trend
112        let initial_level = series[0];
113        let initial_trend = if n > 1 {
114            (series[n - 1] - series[0]) / (n - 1) as f64
115        } else {
116            0.0
117        };
118
119        // Calculate initial seasonal components
120        let mut initial_seasonal = Vec::new();
121        if !matches!(model_spec.seasonal, ETSSeasonalType::None) {
122            let period = self.estimate_period(series);
123            if period > 0 {
124                for i in 0..period {
125                    let indices: Vec<usize> = (i..n).step_by(period).collect();
126                    if !indices.is_empty() {
127                        let seasonal_mean: f64 =
128                            indices.iter().map(|&idx| series[idx]).sum::<f64>()
129                                / indices.len() as f64;
130                        initial_seasonal.push(seasonal_mean - initial_level);
131                    } else {
132                        initial_seasonal.push(0.0);
133                    }
134                }
135            } else {
136                initial_seasonal = vec![0.0];
137            }
138        }
139
140        Ok(ETSModel {
141            error: model_spec.error,
142            trend: model_spec.trend,
143            seasonal: model_spec.seasonal,
144            alpha,
145            beta,
146            gamma,
147            phi,
148            initial_level,
149            initial_trend,
150            initial_seasonal,
151        })
152    }
153
154    /// Initialize ETS state from data
155    fn initialize_state(&self, series: &[f64], model: &ETSModel) -> Result<ETSState> {
156        let n = series.len();
157        let level = model.initial_level;
158        let trend = model.initial_trend;
159
160        let mut seasonal = model.initial_seasonal.clone();
161        if seasonal.is_empty() {
162            seasonal = vec![0.0];
163        }
164
165        Ok(ETSState {
166            level,
167            trend,
168            seasonal,
169            last_observation: if n > 0 { series[n - 1] } else { 0.0 },
170            n_obs: n,
171        })
172    }
173
174    /// Optimize ETS parameters using a simplified BFGS-like approach
175    fn optimize_parameters(
176        &self,
177        series: &[f64],
178        model_spec: &ETSModelSpec,
179        _initial_state: &ETSState,
180    ) -> Result<ETSModel> {
181        // Simplified parameter optimization - in practice, use proper optimization library
182        let mut best_model = self.initialize_parameters(series, model_spec)?;
183        let mut best_log_likelihood = f64::NEG_INFINITY;
184
185        // Grid search over reasonable parameter values
186        let alpha_values = [0.1, 0.2, 0.3, 0.5, 0.7, 0.9];
187        let beta_values = if matches!(model_spec.trend, ETSTrendType::None) {
188            vec![0.0]
189        } else {
190            // Ensure beta stays strictly positive when trend is enabled.
191            vec![0.1, 0.2, 0.3, 0.5]
192        };
193        let gamma_values = if matches!(model_spec.seasonal, ETSSeasonalType::None) {
194            vec![0.0]
195        } else {
196            vec![0.0, 0.1, 0.2, 0.3, 0.5]
197        };
198
199        for &alpha in &alpha_values {
200            for &beta in &beta_values {
201                for &gamma in &gamma_values {
202                    let mut test_model = best_model.clone();
203                    test_model.alpha = alpha;
204                    test_model.beta = beta;
205                    test_model.gamma = gamma;
206
207                    if let Ok(test_state) = self.refit_with_parameters(series, &test_model) {
208                        // Calculate fitted values for this test model
209                        let fitted: Vec<f64> = (0..series.len())
210                            .map(|i| {
211                                let obs_state = ETSState {
212                                    level: test_state.level,
213                                    trend: test_state.trend,
214                                    seasonal: test_state.seasonal.clone(),
215                                    last_observation: if i > 0 { series[i - 1] } else { series[0] },
216                                    n_obs: i,
217                                };
218                                self.calculate_fitted_value(&obs_state, &test_model)
219                            })
220                            .collect();
221                        let (_, rmse, _) = self.calculate_model_metrics(series, &fitted);
222                        let log_likelihood = -rmse * series.len() as f64;
223
224                        if log_likelihood > best_log_likelihood {
225                            best_log_likelihood = log_likelihood;
226                            best_model = test_model;
227                        }
228                    }
229                }
230            }
231        }
232
233        Ok(best_model)
234    }
235
236    /// Refit ETS model with given parameters
237    fn refit_with_parameters(&self, series: &[f64], model: &ETSModel) -> Result<ETSState> {
238        let mut state = self.initialize_state(series, model)?;
239
240        for &observation in series.iter().skip(1) {
241            state = self.update_ets_state(&state, observation, model)?;
242        }
243
244        Ok(state)
245    }
246
247    /// Update ETS state with new observation (for incremental updates)
248    fn update_ets_state(
249        &self,
250        current_state: &ETSState,
251        new_observation: f64,
252        model: &ETSModel,
253    ) -> Result<ETSState> {
254        let mut new_state = current_state.clone();
255
256        // Calculate fitted value
257        let fitted = self.calculate_fitted_value(current_state, model);
258
259        // Calculate residual
260        let residual = match model.error {
261            ETSErrorType::Additive => new_observation - fitted,
262            ETSErrorType::Multiplicative => {
263                if fitted != 0.0 {
264                    new_observation / fitted
265                } else {
266                    0.0
267                }
268            }
269        };
270
271        // Update components
272        new_state.level = model.alpha * residual * self.get_error_multiplier(model)
273            + (1.0 - model.alpha) * (current_state.level + current_state.trend);
274
275        new_state.trend = model.beta * (new_state.level - current_state.level)
276            + (1.0 - model.beta) * self.get_damped_trend(current_state.trend, model.phi);
277
278        if !new_state.seasonal.is_empty() && !matches!(model.seasonal, ETSSeasonalType::None) {
279            let seasonal_index = (new_state.n_obs + 1) % new_state.seasonal.len();
280            let seasonal_factor = match model.seasonal {
281                ETSSeasonalType::Additive => residual * self.get_error_multiplier(model),
282                ETSSeasonalType::Multiplicative => residual,
283                ETSSeasonalType::None => 0.0,
284            };
285
286            new_state.seasonal[seasonal_index] = model.gamma * seasonal_factor
287                + (1.0 - model.gamma) * current_state.seasonal[seasonal_index];
288        }
289
290        new_state.last_observation = new_observation;
291        new_state.n_obs += 1;
292
293        Ok(new_state)
294    }
295}
296
297impl super::engine::ForecastingEngine {
298    pub(super) fn estimate_period(&self, series: &[f64]) -> usize {
299        if series.len() < 20 {
300            return 0;
301        }
302
303        let max_period = series.len() / 4;
304        let mut best_period = 0;
305        let mut best_acf = 0.0;
306
307        for period in 2..=max_period.min(24) {
308            if let Some(acf) = self.calculate_autocorrelation(series, period) {
309                if acf.abs() > best_acf {
310                    best_acf = acf.abs();
311                    best_period = period;
312                }
313            }
314        }
315        best_period
316    }
317
318    pub(super) fn calculate_autocorrelation(&self, series: &[f64], lag: usize) -> Option<f64> {
319        if series.len() <= lag {
320            return None;
321        }
322
323        let n = series.len() - lag;
324        let mean: f64 = series.iter().sum::<f64>() / series.len() as f64;
325
326        let mut numerator = 0.0;
327        let mut denominator = 0.0;
328
329        for i in 0..n {
330            numerator += (series[i] - mean) * (series[i + lag] - mean);
331            denominator += (series[i] - mean).powi(2);
332        }
333
334        if denominator > 0.0 {
335            Some(numerator / denominator)
336        } else {
337            Some(0.0)
338        }
339    }
340
341    pub(super) fn get_error_multiplier(&self, model: &ETSModel) -> f64 {
342        match model.error {
343            ETSErrorType::Additive => 1.0,
344            ETSErrorType::Multiplicative => model.alpha,
345        }
346    }
347
348    pub(super) fn get_damped_trend(&self, trend: f64, phi: f64) -> f64 {
349        trend * phi
350    }
351
352    pub(super) fn calculate_fitted_value(&self, state: &ETSState, model: &ETSModel) -> f64 {
353        let trend_component = match model.trend {
354            ETSTrendType::None => 0.0,
355            ETSTrendType::Additive => state.trend,
356            ETSTrendType::AdditiveDamped => state.trend * model.phi,
357        };
358
359        let seasonal_component =
360            if !state.seasonal.is_empty() && !matches!(model.seasonal, ETSSeasonalType::None) {
361                let seasonal_index = state.n_obs % state.seasonal.len();
362                match model.seasonal {
363                    ETSSeasonalType::Additive => state.seasonal[seasonal_index],
364                    ETSSeasonalType::Multiplicative => 1.0 + state.seasonal[seasonal_index],
365                    ETSSeasonalType::None => 0.0,
366                }
367            } else {
368                1.0
369            };
370
371        match (model.error, model.seasonal) {
372            (ETSErrorType::Additive, ETSSeasonalType::Additive) => {
373                state.level + trend_component + seasonal_component
374            }
375            (ETSErrorType::Additive, ETSSeasonalType::Multiplicative) => {
376                (state.level + trend_component) * seasonal_component
377            }
378            (ETSErrorType::Multiplicative, ETSSeasonalType::Additive) => {
379                (state.level + trend_component) + seasonal_component
380            }
381            (ETSErrorType::Multiplicative, ETSSeasonalType::Multiplicative) => {
382                (state.level + trend_component) * seasonal_component
383            }
384            _ => state.level + trend_component,
385        }
386    }
387}