use anyhow::Result;
use tracing::debug;
use super::ets_types::{ETSErrorType, ETSModel, ETSSeasonalType, ETSState, ETSTrendType};
impl super::engine::ForecastingEngine {
    /// Generates ETS point forecasts for steps `1..=horizon` from the fitted `state`.
    ///
    /// * Trend contribution at step `h`: `0` with no trend, `trend * h` for an additive
    ///   trend, and `trend * (phi + phi^2 + ... + phi^h)` for a damped additive trend.
    /// * Seasonal contribution: read cyclically from `state.seasonal`. This happens only
    ///   when the model is seasonal and the seasonal vector is non-empty. The value is
    ///   added for additive seasonality, or applied as `(level + trend) * (1 + s)` for
    ///   multiplicative seasonality.
    /// * The error type does not change the point forecast: both error variants use the
    ///   same formula for a given seasonal type.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`.
    pub(super) fn forecast_ets(
        &self,
        model: &ETSModel,
        state: &ETSState,
        horizon: usize,
    ) -> Result<Vec<f64>> {
        let seasonal_period = state.seasonal.len();
        let use_seasonal =
            seasonal_period > 0 && !matches!(model.seasonal, ETSSeasonalType::None);

        let mut forecasts = Vec::with_capacity(horizon);
        // Running value of phi + phi^2 + ... + phi^h for the damped trend. Updating it
        // once per step keeps the loop O(horizon) instead of re-summing the whole series
        // at every step. The terms are added in the same order, so results are unchanged.
        let mut damped_sum = 0.0;
        for h in 1..=horizon {
            let trend_component = match model.trend {
                ETSTrendType::None => 0.0,
                ETSTrendType::Additive => state.trend * h as f64,
                ETSTrendType::AdditiveDamped => {
                    damped_sum += model.phi.powi(h as i32);
                    state.trend * damped_sum
                }
            };

            let seasonal_component = if use_seasonal {
                // NOTE(review): assumes `state.seasonal` is indexed by absolute observation
                // position modulo the period. Confirm against the state-update code.
                state.seasonal[(state.n_obs + h - 1) % seasonal_period]
            } else {
                0.0
            };

            let level_and_trend = state.level + trend_component;
            let forecast = match (model.error, model.seasonal) {
                (
                    ETSErrorType::Additive | ETSErrorType::Multiplicative,
                    ETSSeasonalType::Additive,
                ) => level_and_trend + seasonal_component,
                // NOTE(review): `(1 + s)` presumably means multiplicative seasonal factors
                // are stored as deviations from 1. Confirm against the fitting code.
                (
                    ETSErrorType::Additive | ETSErrorType::Multiplicative,
                    ETSSeasonalType::Multiplicative,
                ) => level_and_trend * (1.0 + seasonal_component),
                _ => level_and_trend,
            };
            forecasts.push(forecast);
        }
        debug!("Generated {} forecasts", forecasts.len());
        Ok(forecasts)
    }

    /// Computes symmetric prediction intervals around `forecasts`.
    ///
    /// The half-width at step `h` (1-based) is `z * sigma * sqrt(h)`, where:
    /// * `sigma` is the square root of `estimate_residual_variance(state)`;
    /// * `z` is the two-sided standard-normal critical value for `confidence`
    ///   (about 1.96 for `0.95`, about 2.576 for `0.99`).
    ///
    /// A `confidence` outside the open interval `(0, 1)`, including NaN, falls back to
    /// `0.95`.
    ///
    /// NOTE(review): the `sqrt(h)` growth is a heuristic. ETS prediction variance generally
    /// depends on the model's smoothing parameters, and `_model` is not used yet.
    ///
    /// Returns `(lower_bounds, upper_bounds)`, each the same length as `forecasts`.
    pub(super) fn calculate_confidence_intervals(
        &self,
        _model: &ETSModel,
        forecasts: &[f64],
        state: &ETSState,
        confidence: f64,
    ) -> (Vec<f64>, Vec<f64>) {
        /// Inverse CDF (quantile function) of the standard normal distribution for
        /// `p` in `(0, 1)`.
        ///
        /// Uses Acklam's rational approximation, which has a relative error of about 1.15e-9.
        fn standard_normal_quantile(p: f64) -> f64 {
            const A: [f64; 6] = [
                -3.969683028665376e+01,
                2.209460984245205e+02,
                -2.759285104469687e+02,
                1.383577518672690e+02,
                -3.066479806614716e+01,
                2.506628277459239e+00,
            ];
            const B: [f64; 5] = [
                -5.447609879822406e+01,
                1.615858368580409e+02,
                -1.556989798598866e+02,
                6.680131188771972e+01,
                -1.328068155288572e+01,
            ];
            const C: [f64; 6] = [
                -7.784894002430293e-03,
                -3.223964580411365e-01,
                -2.400758277161838e+00,
                -2.549671236036590e+00,
                4.374664141464968e+00,
                2.938163982698783e+00,
            ];
            const D: [f64; 4] = [
                7.784695709041462e-03,
                3.224671290700398e-01,
                2.445134137142996e+00,
                3.754408661907416e+00,
            ];
            // Break points between the lower-tail, central, and upper-tail regions.
            const P_LOW: f64 = 0.02425;
            const P_HIGH: f64 = 1.0 - P_LOW;

            if p < P_LOW {
                let q = (-2.0 * p.ln()).sqrt();
                (((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
                    / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
            } else if p <= P_HIGH {
                let q = p - 0.5;
                let r = q * q;
                (((((A[0] * r + A[1]) * r + A[2]) * r + A[3]) * r + A[4]) * r + A[5]) * q
                    / (((((B[0] * r + B[1]) * r + B[2]) * r + B[3]) * r + B[4]) * r + 1.0)
            } else {
                let q = (-2.0 * (1.0 - p).ln()).sqrt();
                -(((((C[0] * q + C[1]) * q + C[2]) * q + C[3]) * q + C[4]) * q + C[5])
                    / ((((D[0] * q + D[1]) * q + D[2]) * q + D[3]) * q + 1.0)
            }
        }

        const DEFAULT_CONFIDENCE: f64 = 0.95;
        // The comparisons are false for NaN, so NaN also takes the fallback.
        let confidence = if confidence > 0.0 && confidence < 1.0 {
            confidence
        } else {
            debug!(
                "Invalid confidence level {}, falling back to {}",
                confidence, DEFAULT_CONFIDENCE
            );
            DEFAULT_CONFIDENCE
        };
        // Two-sided interval: put (1 - confidence) / 2 of the probability in each tail.
        let z_score = standard_normal_quantile((1.0 + confidence) / 2.0);
        let residual_std = self.estimate_residual_variance(state).sqrt();

        forecasts
            .iter()
            .enumerate()
            .map(|(i, &f)| {
                let half_width = z_score * residual_std * ((i + 1) as f64).sqrt();
                (f - half_width, f + half_width)
            })
            .unzip()
    }

    /// Returns the residual variance used to size prediction intervals.
    ///
    /// NOTE(review): placeholder. It always returns `0.1` and does not look at `_state`.
    pub(super) fn estimate_residual_variance(&self, _state: &ETSState) -> f64 {
        0.1
    }

    /// Scores how well `fitted` matches `actual`; higher is better.
    ///
    /// Returns the negated AIC, `-(n * ln(SSE / n) + 2k)`.
    /// * `n` is the number of positions present in both slices. Extra trailing values in
    ///   the longer slice are ignored.
    /// * `k` is currently fixed at 3. `_model` is not yet used to derive the real
    ///   parameter count.
    ///
    /// Returns `f64::NEG_INFINITY` when there are no paired observations, so an empty fit
    /// never wins a higher-is-better comparison. A perfect fit (SSE = 0) yields
    /// `f64::INFINITY`.
    pub(super) fn calculate_ets_fit_quality(
        &self,
        actual: &[f64],
        fitted: &[f64],
        _model: &ETSModel,
    ) -> f64 {
        const NUM_PARAMS: f64 = 3.0;

        let (count, sse) = actual
            .iter()
            .zip(fitted)
            .fold((0_usize, 0.0_f64), |(count, sse), (a, f)| {
                (count + 1, sse + (a - f).powi(2))
            });
        if count == 0 {
            return f64::NEG_INFINITY;
        }

        let n = count as f64;
        let aic = n * (sse / n).ln() + 2.0 * NUM_PARAMS;
        -aic
    }

    /// Computes `(MAE, RMSE, MAPE)` between `actual` and `fitted`.
    ///
    /// Only positions present in both slices are used.
    ///
    /// MAPE is expressed in percent. It skips points whose actual value has magnitude
    /// `<= 1e-10` and averages over the remaining points. It is `0.0` when no point
    /// qualifies.
    ///
    /// All three metrics are NaN when there are no paired observations.
    pub(super) fn calculate_model_metrics(
        &self,
        actual: &[f64],
        fitted: &[f64],
    ) -> (f64, f64, f64) {
        let mut count = 0_usize;
        let mut abs_error_sum = 0.0_f64;
        let mut squared_error_sum = 0.0_f64;
        let mut pct_error_sum = 0.0_f64;
        let mut pct_count = 0_usize;

        for (&a, &f) in actual.iter().zip(fitted) {
            let error = a - f;
            count += 1;
            abs_error_sum += error.abs();
            squared_error_sum += error.powi(2);
            // Percentage error is undefined for (near-)zero actuals; leave those out of
            // both the sum and the MAPE denominator.
            if a.abs() > 1e-10 {
                pct_error_sum += (error.abs() / a.abs()) * 100.0;
                pct_count += 1;
            }
        }

        if count == 0 {
            return (f64::NAN, f64::NAN, f64::NAN);
        }

        let n = count as f64;
        let mae = abs_error_sum / n;
        let rmse = (squared_error_sum / n).sqrt();
        let mape = if pct_count == 0 {
            0.0
        } else {
            pct_error_sum / pct_count as f64
        };
        (mae, rmse, mape)
    }
}