do_memory_mcp/patterns/predictive/forecasting/
ets_fitting.rs1use anyhow::Result;
2
3use super::ets_types::{
4 ETSErrorType, ETSForecastResult, ETSModel, ETSModelSpec, ETSSeasonalType, ETSState,
5 ETSTrendType,
6};
7
8impl super::engine::ForecastingEngine {
9 pub(super) fn generate_model_combinations(&self, period: usize) -> Vec<ETSModelSpec> {
10 let mut models = Vec::new();
11
12 let error_types = [ETSErrorType::Additive, ETSErrorType::Multiplicative];
13 let trend_types = [
14 ETSTrendType::None,
15 ETSTrendType::Additive,
16 ETSTrendType::AdditiveDamped,
17 ];
18 let seasonal_types = if period > 0 {
19 vec![
20 ETSSeasonalType::None,
21 ETSSeasonalType::Additive,
22 ETSSeasonalType::Multiplicative,
23 ]
24 } else {
25 vec![ETSSeasonalType::None]
26 };
27
28 for error in &error_types {
29 for trend in &trend_types {
30 for seasonal in &seasonal_types {
31 models.push(ETSModelSpec {
32 error: *error,
33 trend: *trend,
34 seasonal: *seasonal,
35 });
36 }
37 }
38 }
39
40 models
41 }
42
43 pub(super) fn fit_ets_model(
45 &self,
46 series: &[f64],
47 model_spec: &ETSModelSpec,
48 ) -> Result<ETSForecastResult> {
49 let mut model = self.initialize_parameters(series, model_spec)?;
51 let mut state = self.initialize_state(series, &model)?;
52
53 model = self.optimize_parameters(series, model_spec, &state)?;
55
56 state = self.refit_with_parameters(series, &model)?;
58
59 let fitted: Vec<f64> = (0..series.len())
61 .map(|i| {
62 let obs_state = ETSState {
63 level: state.level,
64 trend: state.trend,
65 seasonal: state.seasonal.clone(),
66 last_observation: if i > 0 { series[i - 1] } else { series[0] },
67 n_obs: i,
68 };
69 self.calculate_fitted_value(&obs_state, &model)
70 })
71 .collect();
72
73 let (_, rmse, _mape) = self.calculate_model_metrics(series, &fitted);
75 let fit_quality = self.calculate_ets_fit_quality(series, &fitted, &model);
76
77 let log_likelihood = -rmse * series.len() as f64; let aic = series.len() as f64 * (rmse.ln() + 1.0) + 6.0; Ok(ETSForecastResult {
82 model,
83 state,
84 forecasts: Vec::new(), lower_bounds: Vec::new(),
86 upper_bounds: Vec::new(),
87 fit_quality,
88 aic,
89 log_likelihood,
90 })
91 }
92
93 fn initialize_parameters(&self, series: &[f64], model_spec: &ETSModelSpec) -> Result<ETSModel> {
95 let n = series.len();
96
97 let alpha = 0.2;
99 let beta = if matches!(model_spec.trend, ETSTrendType::None) {
100 0.0
101 } else {
102 0.1
103 };
104 let gamma = if matches!(model_spec.seasonal, ETSSeasonalType::None) {
105 0.0
106 } else {
107 0.1
108 };
109 let phi = 0.98;
110
111 let initial_level = series[0];
113 let initial_trend = if n > 1 {
114 (series[n - 1] - series[0]) / (n - 1) as f64
115 } else {
116 0.0
117 };
118
119 let mut initial_seasonal = Vec::new();
121 if !matches!(model_spec.seasonal, ETSSeasonalType::None) {
122 let period = self.estimate_period(series);
123 if period > 0 {
124 for i in 0..period {
125 let indices: Vec<usize> = (i..n).step_by(period).collect();
126 if !indices.is_empty() {
127 let seasonal_mean: f64 =
128 indices.iter().map(|&idx| series[idx]).sum::<f64>()
129 / indices.len() as f64;
130 initial_seasonal.push(seasonal_mean - initial_level);
131 } else {
132 initial_seasonal.push(0.0);
133 }
134 }
135 } else {
136 initial_seasonal = vec![0.0];
137 }
138 }
139
140 Ok(ETSModel {
141 error: model_spec.error,
142 trend: model_spec.trend,
143 seasonal: model_spec.seasonal,
144 alpha,
145 beta,
146 gamma,
147 phi,
148 initial_level,
149 initial_trend,
150 initial_seasonal,
151 })
152 }
153
154 fn initialize_state(&self, series: &[f64], model: &ETSModel) -> Result<ETSState> {
156 let n = series.len();
157 let level = model.initial_level;
158 let trend = model.initial_trend;
159
160 let mut seasonal = model.initial_seasonal.clone();
161 if seasonal.is_empty() {
162 seasonal = vec![0.0];
163 }
164
165 Ok(ETSState {
166 level,
167 trend,
168 seasonal,
169 last_observation: if n > 0 { series[n - 1] } else { 0.0 },
170 n_obs: n,
171 })
172 }
173
174 fn optimize_parameters(
176 &self,
177 series: &[f64],
178 model_spec: &ETSModelSpec,
179 _initial_state: &ETSState,
180 ) -> Result<ETSModel> {
181 let mut best_model = self.initialize_parameters(series, model_spec)?;
183 let mut best_log_likelihood = f64::NEG_INFINITY;
184
185 let alpha_values = [0.1, 0.2, 0.3, 0.5, 0.7, 0.9];
187 let beta_values = if matches!(model_spec.trend, ETSTrendType::None) {
188 vec![0.0]
189 } else {
190 vec![0.1, 0.2, 0.3, 0.5]
192 };
193 let gamma_values = if matches!(model_spec.seasonal, ETSSeasonalType::None) {
194 vec![0.0]
195 } else {
196 vec![0.0, 0.1, 0.2, 0.3, 0.5]
197 };
198
199 for &alpha in &alpha_values {
200 for &beta in &beta_values {
201 for &gamma in &gamma_values {
202 let mut test_model = best_model.clone();
203 test_model.alpha = alpha;
204 test_model.beta = beta;
205 test_model.gamma = gamma;
206
207 if let Ok(test_state) = self.refit_with_parameters(series, &test_model) {
208 let fitted: Vec<f64> = (0..series.len())
210 .map(|i| {
211 let obs_state = ETSState {
212 level: test_state.level,
213 trend: test_state.trend,
214 seasonal: test_state.seasonal.clone(),
215 last_observation: if i > 0 { series[i - 1] } else { series[0] },
216 n_obs: i,
217 };
218 self.calculate_fitted_value(&obs_state, &test_model)
219 })
220 .collect();
221 let (_, rmse, _) = self.calculate_model_metrics(series, &fitted);
222 let log_likelihood = -rmse * series.len() as f64;
223
224 if log_likelihood > best_log_likelihood {
225 best_log_likelihood = log_likelihood;
226 best_model = test_model;
227 }
228 }
229 }
230 }
231 }
232
233 Ok(best_model)
234 }
235
236 fn refit_with_parameters(&self, series: &[f64], model: &ETSModel) -> Result<ETSState> {
238 let mut state = self.initialize_state(series, model)?;
239
240 for &observation in series.iter().skip(1) {
241 state = self.update_ets_state(&state, observation, model)?;
242 }
243
244 Ok(state)
245 }
246
247 fn update_ets_state(
249 &self,
250 current_state: &ETSState,
251 new_observation: f64,
252 model: &ETSModel,
253 ) -> Result<ETSState> {
254 let mut new_state = current_state.clone();
255
256 let fitted = self.calculate_fitted_value(current_state, model);
258
259 let residual = match model.error {
261 ETSErrorType::Additive => new_observation - fitted,
262 ETSErrorType::Multiplicative => {
263 if fitted != 0.0 {
264 new_observation / fitted
265 } else {
266 0.0
267 }
268 }
269 };
270
271 new_state.level = model.alpha * residual * self.get_error_multiplier(model)
273 + (1.0 - model.alpha) * (current_state.level + current_state.trend);
274
275 new_state.trend = model.beta * (new_state.level - current_state.level)
276 + (1.0 - model.beta) * self.get_damped_trend(current_state.trend, model.phi);
277
278 if !new_state.seasonal.is_empty() && !matches!(model.seasonal, ETSSeasonalType::None) {
279 let seasonal_index = (new_state.n_obs + 1) % new_state.seasonal.len();
280 let seasonal_factor = match model.seasonal {
281 ETSSeasonalType::Additive => residual * self.get_error_multiplier(model),
282 ETSSeasonalType::Multiplicative => residual,
283 ETSSeasonalType::None => 0.0,
284 };
285
286 new_state.seasonal[seasonal_index] = model.gamma * seasonal_factor
287 + (1.0 - model.gamma) * current_state.seasonal[seasonal_index];
288 }
289
290 new_state.last_observation = new_observation;
291 new_state.n_obs += 1;
292
293 Ok(new_state)
294 }
295}
296
297impl super::engine::ForecastingEngine {
298 pub(super) fn estimate_period(&self, series: &[f64]) -> usize {
299 if series.len() < 20 {
300 return 0;
301 }
302
303 let max_period = series.len() / 4;
304 let mut best_period = 0;
305 let mut best_acf = 0.0;
306
307 for period in 2..=max_period.min(24) {
308 if let Some(acf) = self.calculate_autocorrelation(series, period) {
309 if acf.abs() > best_acf {
310 best_acf = acf.abs();
311 best_period = period;
312 }
313 }
314 }
315 best_period
316 }
317
318 pub(super) fn calculate_autocorrelation(&self, series: &[f64], lag: usize) -> Option<f64> {
319 if series.len() <= lag {
320 return None;
321 }
322
323 let n = series.len() - lag;
324 let mean: f64 = series.iter().sum::<f64>() / series.len() as f64;
325
326 let mut numerator = 0.0;
327 let mut denominator = 0.0;
328
329 for i in 0..n {
330 numerator += (series[i] - mean) * (series[i + lag] - mean);
331 denominator += (series[i] - mean).powi(2);
332 }
333
334 if denominator > 0.0 {
335 Some(numerator / denominator)
336 } else {
337 Some(0.0)
338 }
339 }
340
341 pub(super) fn get_error_multiplier(&self, model: &ETSModel) -> f64 {
342 match model.error {
343 ETSErrorType::Additive => 1.0,
344 ETSErrorType::Multiplicative => model.alpha,
345 }
346 }
347
348 pub(super) fn get_damped_trend(&self, trend: f64, phi: f64) -> f64 {
349 trend * phi
350 }
351
352 pub(super) fn calculate_fitted_value(&self, state: &ETSState, model: &ETSModel) -> f64 {
353 let trend_component = match model.trend {
354 ETSTrendType::None => 0.0,
355 ETSTrendType::Additive => state.trend,
356 ETSTrendType::AdditiveDamped => state.trend * model.phi,
357 };
358
359 let seasonal_component =
360 if !state.seasonal.is_empty() && !matches!(model.seasonal, ETSSeasonalType::None) {
361 let seasonal_index = state.n_obs % state.seasonal.len();
362 match model.seasonal {
363 ETSSeasonalType::Additive => state.seasonal[seasonal_index],
364 ETSSeasonalType::Multiplicative => 1.0 + state.seasonal[seasonal_index],
365 ETSSeasonalType::None => 0.0,
366 }
367 } else {
368 1.0
369 };
370
371 match (model.error, model.seasonal) {
372 (ETSErrorType::Additive, ETSSeasonalType::Additive) => {
373 state.level + trend_component + seasonal_component
374 }
375 (ETSErrorType::Additive, ETSSeasonalType::Multiplicative) => {
376 (state.level + trend_component) * seasonal_component
377 }
378 (ETSErrorType::Multiplicative, ETSSeasonalType::Additive) => {
379 (state.level + trend_component) + seasonal_component
380 }
381 (ETSErrorType::Multiplicative, ETSSeasonalType::Multiplicative) => {
382 (state.level + trend_component) * seasonal_component
383 }
384 _ => state.level + trend_component,
385 }
386 }
387}