/// How the seasonal component is combined with the level-plus-trend value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Seasonality {
    /// The seasonal component is added to the level-plus-trend value.
    Additive,
    /// The seasonal component multiplies the level-plus-trend value.
    Multiplicative,
}
/// Component series produced by a Holt-Winters smoothing pass.
///
/// `HoltWinters::smooth` fills every vector with one entry per input
/// observation.
#[derive(Debug, Clone, PartialEq)]
pub struct HoltWintersResult {
    /// Smoothed level at each time step.
    pub level: Vec<f64>,
    /// Smoothed per-step trend at each time step.
    pub trend: Vec<f64>,
    /// Seasonal component at each time step (an offset for additive
    /// seasonality, a factor for multiplicative seasonality).
    pub seasonal: Vec<f64>,
    /// In-sample fitted value at each time step.
    pub fitted: Vec<f64>,
}
impl HoltWintersResult {
    /// Forecasts the value `h` steps after the last smoothed observation.
    ///
    /// The forecast extrapolates the final level by `h` times the final
    /// trend and combines it with a seasonal component taken from the last
    /// `period` entries of `seasonal`, cycling through them as `h` grows
    /// (`h = 1` uses the first entry of that final cycle, `h = period` the
    /// last one, `h = period + 1` the first one again).
    ///
    /// `h = 0` is accepted and yields the final level combined with the
    /// seasonal component of the last observation.
    ///
    /// `period` and `seasonality` should match the values the result was
    /// produced with.
    ///
    /// # Panics
    ///
    /// Panics if `level` or `trend` is empty, if `period` is zero, or if
    /// `seasonal` holds fewer than `period` entries.
    pub fn forecast(&self, h: usize, period: usize, seasonality: Seasonality) -> f64 {
        let last_level = *self.level.last().expect("level must be non-empty");
        let last_trend = *self.trend.last().expect("trend must be non-empty");

        assert!(period > 0, "period must be at least 1");
        let final_cycle_start = self
            .seasonal
            .len()
            .checked_sub(period)
            .expect("seasonal must hold at least one full period");

        // Equivalent to `(h - 1) % period` for h >= 1, but cannot underflow
        // for h == 0 and cannot overflow for very large h.
        let offset = (h % period + period - 1) % period;
        let season = self.seasonal[final_cycle_start + offset];

        let base = last_level + h as f64 * last_trend;
        match seasonality {
            Seasonality::Additive => base + season,
            Seasonality::Multiplicative => base * season,
        }
    }
}
/// Configuration for Holt-Winters (triple exponential) smoothing.
///
/// Instances are created through `HoltWinters::new`, which validates every
/// parameter.
#[derive(Debug, Clone, PartialEq)]
pub struct HoltWinters {
    /// Smoothing weight for the level update, strictly between 0 and 1.
    alpha: f64,
    /// Smoothing weight for the trend update, strictly between 0 and 1.
    beta: f64,
    /// Smoothing weight for the seasonal update, strictly between 0 and 1.
    gamma: f64,
    /// Number of observations in one seasonal cycle (at least 2).
    period: usize,
    /// Whether the seasonal component is added or multiplied.
    seasonality: Seasonality,
}
impl HoltWinters {
    /// Builds a smoother from its three smoothing weights and the season
    /// length.
    ///
    /// Returns `None` unless `alpha`, `beta` and `gamma` each lie strictly
    /// between 0 and 1 (NaN and infinities are rejected) and `period` is at
    /// least 2.
    pub fn new(
        alpha: f64,
        beta: f64,
        gamma: f64,
        period: usize,
        seasonality: Seasonality,
    ) -> Option<Self> {
        // NaN fails both comparisons and infinities fail one of them, so
        // this single predicate also rejects every non-finite weight.
        let in_open_unit = |weight: f64| weight > 0.0 && weight < 1.0;
        let weights_ok = in_open_unit(alpha) && in_open_unit(beta) && in_open_unit(gamma);
        if !weights_ok || period < 2 {
            return None;
        }
        Some(Self {
            alpha,
            beta,
            gamma,
            period,
            seasonality,
        })
    }

    /// Returns the number of observations in one seasonal cycle.
    pub fn period(&self) -> usize {
        self.period
    }

    /// Returns how the seasonal component is combined with level and trend.
    pub fn seasonality(&self) -> Seasonality {
        self.seasonality
    }

    /// Runs the smoothing recursions over `data` and returns the level,
    /// trend, seasonal and fitted series, each as long as `data`.
    ///
    /// Initialization uses the first two periods: the initial level is the
    /// mean of the first period, the initial trend is the average per-step
    /// change between the first and second period, and the initial seasonal
    /// components are the first-period observations relative to the initial
    /// level. For the first period, level and trend hold those initial
    /// values. From the second period on, each fitted value is the
    /// one-step-ahead prediction built from the previous level and trend and
    /// the seasonal component one period earlier.
    ///
    /// Returns `None` when:
    /// * `data` holds fewer than `2 * period` observations,
    /// * any observation is NaN or infinite, or
    /// * seasonality is multiplicative and any observation is not strictly
    ///   positive.
    pub fn smooth(&self, data: &[f64]) -> Option<HoltWintersResult> {
        let m = self.period;
        let n = data.len();

        // `checked_mul` keeps an absurdly large period from overflowing the
        // length requirement; such a period can never be satisfied anyway.
        if m.checked_mul(2).map_or(true, |min_len| n < min_len) {
            return None;
        }
        // A single NaN or infinity would propagate through every later
        // value, so reject it up front. This also closes the gap where NaN
        // slipped past the positivity check below (`NaN <= 0.0` is false).
        if data.iter().any(|x| !x.is_finite()) {
            return None;
        }
        if self.seasonality == Seasonality::Multiplicative && data.iter().any(|&x| x <= 0.0) {
            return None;
        }

        let m_f = m as f64;
        let l0 = data[..m].iter().sum::<f64>() / m_f;
        let t0 = data[..m]
            .iter()
            .zip(&data[m..2 * m])
            .map(|(first, second)| (second - first) / m_f)
            .sum::<f64>()
            / m_f;

        // Entries from index `m` on are overwritten by the recursion below.
        let mut level = vec![l0; n];
        let mut trend = vec![t0; n];
        let mut seasonal = vec![0.0; n];
        let mut fitted = vec![0.0; n];

        for ((s, f), &y) in seasonal.iter_mut().zip(fitted.iter_mut()).zip(&data[..m]) {
            *s = self.remove_component(y, l0);
            *f = self.apply_seasonal(l0, *s);
        }

        for t in m..n {
            let y = data[t];
            let s_prev = seasonal[t - m];
            // One-step-ahead level prediction from the previous state.
            let predicted = level[t - 1] + trend[t - 1];

            let l = self.alpha * self.remove_component(y, s_prev) + (1.0 - self.alpha) * predicted;
            let b = self.beta * (l - level[t - 1]) + (1.0 - self.beta) * trend[t - 1];
            let s = self.gamma * self.remove_component(y, l) + (1.0 - self.gamma) * s_prev;

            level[t] = l;
            trend[t] = b;
            seasonal[t] = s;
            fitted[t] = self.apply_seasonal(predicted, s_prev);
        }

        Some(HoltWintersResult {
            level,
            trend,
            seasonal,
            fitted,
        })
    }

    /// Combines a level-plus-trend value with a seasonal component.
    fn apply_seasonal(&self, base: f64, season: f64) -> f64 {
        match self.seasonality {
            Seasonality::Additive => base + season,
            Seasonality::Multiplicative => base * season,
        }
    }

    /// Strips `component` from `observation` (subtracts or divides).
    fn remove_component(&self, observation: f64, component: f64) -> f64 {
        match self.seasonality {
            Seasonality::Additive => observation - component,
            Seasonality::Multiplicative => observation / component,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Linear trend (`100 + 2t`) plus a repeating additive offset of period 4.
    fn seasonal_additive_data() -> Vec<f64> {
        const OFFSETS: [f64; 4] = [10.0, -5.0, -5.0, 0.0];
        let mut series = Vec::with_capacity(24);
        for step in 0..24_usize {
            let baseline = 100.0 + 2.0 * step as f64;
            series.push(baseline + OFFSETS[step % 4]);
        }
        series
    }

    /// Linear trend (`100 + 2t`) scaled by a repeating factor of period 4.
    fn seasonal_multiplicative_data() -> Vec<f64> {
        const FACTORS: [f64; 4] = [1.2, 0.8, 0.9, 1.1];
        let mut series = Vec::with_capacity(24);
        for step in 0..24_usize {
            let baseline = 100.0 + 2.0 * step as f64;
            series.push(baseline * FACTORS[step % 4]);
        }
        series
    }

    /// Builds a period-4 smoother with the given `(alpha, beta, gamma)` and
    /// runs it over `data`, panicking if either step is rejected.
    fn fit(weights: (f64, f64, f64), seasonality: Seasonality, data: &[f64]) -> HoltWintersResult {
        let (alpha, beta, gamma) = weights;
        let model = HoltWinters::new(alpha, beta, gamma, 4, seasonality).unwrap();
        model.smooth(data).unwrap()
    }

    /// Every additive output series is as long as the input.
    #[test]
    fn test_hw_additive_basic() {
        let result = fit(
            (0.3, 0.1, 0.3),
            Seasonality::Additive,
            &seasonal_additive_data(),
        );
        for series in [
            &result.level,
            &result.trend,
            &result.seasonal,
            &result.fitted,
        ] {
            assert_eq!(series.len(), 24);
        }
    }

    /// Additive forecasts stay in a plausible range.
    #[test]
    fn test_hw_additive_forecast() {
        let result = fit(
            (0.3, 0.1, 0.3),
            Seasonality::Additive,
            &seasonal_additive_data(),
        );
        let f1 = result.forecast(1, 4, Seasonality::Additive);
        let f4 = result.forecast(4, 4, Seasonality::Additive);
        assert!(f1 > 100.0, "forecast(1) = {f1}");
        assert!(f4 > f1 - 20.0, "forecast(4) = {f4}");
    }

    /// Multiplicative smoothing accepts positive data and keeps its length.
    #[test]
    fn test_hw_multiplicative_basic() {
        let result = fit(
            (0.3, 0.1, 0.3),
            Seasonality::Multiplicative,
            &seasonal_multiplicative_data(),
        );
        assert_eq!(result.level.len(), 24);
        assert_eq!(result.fitted.len(), 24);
    }

    /// After two warm-up periods the fitted values track the data closely.
    #[test]
    fn test_hw_fitted_approximates_data() {
        let data = seasonal_additive_data();
        let result = fit((0.5, 0.3, 0.5), Seasonality::Additive, &data);
        let total_relative_error: f64 = result
            .fitted
            .iter()
            .zip(&data)
            .skip(8)
            .map(|(fit, actual)| ((fit - actual) / actual).abs())
            .sum();
        let mape = total_relative_error / 16.0;
        assert!(
            mape < 0.10,
            "mean absolute percentage error = {mape}, expected < 10%"
        );
    }

    /// The largest seasonal component of the last cycle sits at phase 0.
    #[test]
    fn test_hw_seasonal_pattern_detected() {
        let result = fit(
            (0.3, 0.1, 0.5),
            Seasonality::Additive,
            &seasonal_additive_data(),
        );
        let last_cycle = result.seasonal[20..24].to_vec();
        let max_idx = last_cycle
            .iter()
            .enumerate()
            .max_by(|lhs, rhs| lhs.1.partial_cmp(rhs.1).unwrap())
            .map(|(position, _)| position)
            .unwrap();
        assert_eq!(max_idx, 0, "highest seasonal at wrong position");
    }

    /// Two full periods are the minimum accepted input length.
    #[test]
    fn test_hw_insufficient_data() {
        let model = HoltWinters::new(0.3, 0.1, 0.3, 4, Seasonality::Additive).unwrap();
        let too_short = [1.0; 7];
        let just_enough = [1.0; 8];
        assert!(model.smooth(&too_short).is_none());
        assert!(model.smooth(&just_enough).is_some());
    }

    /// Multiplicative smoothing refuses data containing a negative value.
    #[test]
    fn test_hw_multiplicative_rejects_negative() {
        let model = HoltWinters::new(0.3, 0.1, 0.3, 4, Seasonality::Multiplicative).unwrap();
        let data = [1.0, 2.0, -1.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        assert!(model.smooth(&data).is_none());
    }

    /// Out-of-range weights and a period below 2 are rejected.
    #[test]
    fn test_hw_invalid_params() {
        let rejected = [
            (0.0, 0.5, 0.5, 4),
            (0.5, 1.0, 0.5, 4),
            (0.5, 0.5, 0.0, 4),
            (0.5, 0.5, 0.5, 1),
        ];
        for (alpha, beta, gamma, period) in rejected {
            assert!(HoltWinters::new(alpha, beta, gamma, period, Seasonality::Additive).is_none());
        }
    }

    /// The final trend estimate lands near the true slope of 2 per step.
    #[test]
    fn test_hw_trend_detected() {
        let result = fit(
            (0.3, 0.3, 0.3),
            Seasonality::Additive,
            &seasonal_additive_data(),
        );
        let last_trend = result.trend[23];
        assert!(
            last_trend > 1.0 && last_trend < 4.0,
            "trend = {last_trend}, expected ~2.0"
        );
    }

    /// The final level estimate lands near the deseasonalized series value.
    #[test]
    fn test_hw_level_tracks_mean() {
        let result = fit(
            (0.3, 0.1, 0.3),
            Seasonality::Additive,
            &seasonal_additive_data(),
        );
        let last_level = result.level[23];
        assert!(
            (last_level - 146.0).abs() < 10.0,
            "level = {last_level}, expected ~146"
        );
    }
}