use crate::learner::StreamingLearner;
/// Form of the seasonal component used by [`HoltWinters`].
///
/// Forecasts are `level + h * trend + s` for `Additive` and
/// `(level + h * trend) * s` for `Multiplicative`, where `s` is the seasonal
/// term of the slot being forecast.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum Seasonality {
/// Seasonal terms are offsets added to the trend line; neutral value `0.0`.
Additive,
/// Seasonal terms are factors multiplying the trend line; neutral value `1.0`.
Multiplicative,
}
/// Hyper-parameters of a [`HoltWinters`] model.
///
/// Usually produced by [`HoltWintersConfig::builder`], whose `build` step
/// enforces the ranges noted on each field. The fields are public, so a value
/// constructed directly is not range-checked.
#[derive(Debug, Clone)]
pub struct HoltWintersConfig {
/// Level smoothing factor; `build` requires `0 < alpha < 1`.
pub alpha: f64,
/// Trend smoothing factor; `build` requires `0 < beta < 1`.
pub beta: f64,
/// Seasonal smoothing factor; `build` requires `0 < gamma < 1`.
pub gamma: f64,
/// Observations per seasonal cycle, also the warm-up length; `build` requires `>= 2`.
pub period: usize,
/// Additive or multiplicative seasonal component.
pub seasonality: Seasonality,
}
impl HoltWintersConfig {
    /// Starts building a configuration for a cycle of `period` observations.
    ///
    /// The returned builder starts from `alpha = 0.3`, `beta = 0.1`,
    /// `gamma = 0.1` and additive seasonality. Nothing is validated here;
    /// range checks happen in [`HoltWintersConfigBuilder::build`].
    pub fn builder(period: usize) -> HoltWintersConfigBuilder {
        // Default smoothing factors for level, trend and seasonal terms.
        let (alpha, beta, gamma) = (0.3, 0.1, 0.1);
        HoltWintersConfigBuilder {
            period,
            seasonality: Seasonality::Additive,
            alpha,
            beta,
            gamma,
        }
    }
}
/// Builder for [`HoltWintersConfig`], created by [`HoltWintersConfig::builder`].
///
/// Starts from `alpha = 0.3`, `beta = 0.1`, `gamma = 0.1` and additive
/// seasonality; values are checked only when `build` is called.
#[derive(Debug, Clone)]
pub struct HoltWintersConfigBuilder {
/// Level smoothing factor.
alpha: f64,
/// Trend smoothing factor.
beta: f64,
/// Seasonal smoothing factor.
gamma: f64,
/// Observations per seasonal cycle, fixed when the builder is created.
period: usize,
/// Additive or multiplicative seasonal component.
seasonality: Seasonality,
}
impl HoltWintersConfigBuilder {
pub fn alpha(mut self, alpha: f64) -> Self {
self.alpha = alpha;
self
}
pub fn beta(mut self, beta: f64) -> Self {
self.beta = beta;
self
}
pub fn gamma(mut self, gamma: f64) -> Self {
self.gamma = gamma;
self
}
pub fn seasonality(mut self, seasonality: Seasonality) -> Self {
self.seasonality = seasonality;
self
}
pub fn build(self) -> Result<HoltWintersConfig, irithyll_core::error::ConfigError> {
use irithyll_core::error::ConfigError;
if self.alpha <= 0.0 || self.alpha >= 1.0 {
return Err(ConfigError::out_of_range(
"alpha",
"must be in (0, 1)",
self.alpha,
));
}
if self.beta <= 0.0 || self.beta >= 1.0 {
return Err(ConfigError::out_of_range(
"beta",
"must be in (0, 1)",
self.beta,
));
}
if self.gamma <= 0.0 || self.gamma >= 1.0 {
return Err(ConfigError::out_of_range(
"gamma",
"must be in (0, 1)",
self.gamma,
));
}
if self.period < 2 {
return Err(ConfigError::out_of_range(
"period",
"must be >= 2",
self.period,
));
}
Ok(HoltWintersConfig {
alpha: self.alpha,
beta: self.beta,
gamma: self.gamma,
period: self.period,
seasonality: self.seasonality,
})
}
}
/// Online Holt-Winters forecaster with level, trend and seasonal smoothing.
///
/// The first `period` observations are only buffered. When the buffer is full,
/// the state is initialised from it, and each later observation is folded in
/// with one exponential-smoothing update. Until initialisation, predictions are `0.0`.
#[derive(Debug, Clone)]
pub struct HoltWinters {
/// Hyper-parameters this model was created with.
config: HoltWintersConfig,
/// Smoothed level (`0.0` before initialisation).
level: f64,
/// Smoothed per-step trend (`0.0` before initialisation).
trend: f64,
/// One seasonal term per slot of the cycle; its length equals `config.period`.
seasonal: Vec<f64>,
/// Slot of the next observation, which is also the slot of the one-step forecast.
season_idx: usize,
/// Observations received since creation or the last reset, warm-up included.
n_samples: u64,
/// Whether the warm-up buffer has been consumed.
initialized: bool,
/// Warm-up observations; they are not cleared after initialisation, only on reset.
init_buffer: Vec<f64>,
}
impl HoltWinters {
pub fn new(config: HoltWintersConfig) -> Self {
let period = config.period;
let init_seasonal = match config.seasonality {
Seasonality::Additive => vec![0.0; period],
Seasonality::Multiplicative => vec![1.0; period],
};
Self {
config,
level: 0.0,
trend: 0.0,
seasonal: init_seasonal,
season_idx: 0,
n_samples: 0,
initialized: false,
init_buffer: Vec::with_capacity(period),
}
}
pub fn train_one(&mut self, y: f64) {
self.n_samples += 1;
if !self.initialized {
self.init_buffer.push(y);
if self.init_buffer.len() == self.config.period {
self.initialize();
}
return;
}
self.update(y);
}
pub fn predict_one(&self) -> f64 {
if !self.initialized {
return 0.0;
}
self.forecast_step(1)
}
pub fn forecast(&self, horizon: usize) -> Vec<f64> {
if !self.initialized || horizon == 0 {
return vec![0.0; horizon];
}
(1..=horizon).map(|h| self.forecast_step(h)).collect()
}
pub fn level(&self) -> f64 {
self.level
}
pub fn trend(&self) -> f64 {
self.trend
}
pub fn seasonal_factors(&self) -> &[f64] {
&self.seasonal
}
pub fn is_initialized(&self) -> bool {
self.initialized
}
pub fn n_samples_seen(&self) -> u64 {
self.n_samples
}
pub fn reset(&mut self) {
let period = self.config.period;
self.level = 0.0;
self.trend = 0.0;
self.seasonal = match self.config.seasonality {
Seasonality::Additive => vec![0.0; period],
Seasonality::Multiplicative => vec![1.0; period],
};
self.season_idx = 0;
self.n_samples = 0;
self.initialized = false;
self.init_buffer.clear();
}
fn initialize(&mut self) {
let m = self.config.period;
let buf = &self.init_buffer;
let mean: f64 = buf.iter().sum::<f64>() / m as f64;
self.level = mean;
self.trend = 0.0;
match self.config.seasonality {
Seasonality::Additive => {
for (i, &b) in buf.iter().enumerate().take(m) {
self.seasonal[i] = b - mean;
}
}
Seasonality::Multiplicative => {
for (i, &b) in buf.iter().enumerate().take(m) {
if mean.abs() < f64::EPSILON {
self.seasonal[i] = 1.0;
} else {
self.seasonal[i] = b / mean;
}
}
}
}
self.initialized = true;
self.season_idx = 0;
let replay: Vec<f64> = buf.clone();
for &y in &replay {
self.update(y);
}
}
fn update(&mut self, y: f64) {
let m = self.config.period;
let alpha = self.config.alpha;
let beta = self.config.beta;
let gamma = self.config.gamma;
let prev_level = self.level;
let prev_trend = self.trend;
let prev_seasonal = self.seasonal[self.season_idx];
match self.config.seasonality {
Seasonality::Additive => {
self.level =
alpha * (y - prev_seasonal) + (1.0 - alpha) * (prev_level + prev_trend);
self.trend = beta * (self.level - prev_level) + (1.0 - beta) * prev_trend;
self.seasonal[self.season_idx] =
gamma * (y - self.level) + (1.0 - gamma) * prev_seasonal;
}
Seasonality::Multiplicative => {
let safe_seasonal = if prev_seasonal.abs() < f64::EPSILON {
1.0
} else {
prev_seasonal
};
self.level =
alpha * (y / safe_seasonal) + (1.0 - alpha) * (prev_level + prev_trend);
self.trend = beta * (self.level - prev_level) + (1.0 - beta) * prev_trend;
let safe_level = if self.level.abs() < f64::EPSILON {
1.0
} else {
self.level
};
self.seasonal[self.season_idx] =
gamma * (y / safe_level) + (1.0 - gamma) * prev_seasonal;
}
}
self.season_idx = (self.season_idx + 1) % m;
}
fn forecast_step(&self, h: usize) -> f64 {
let m = self.config.period;
let idx = (self.season_idx + (h - 1) % m) % m;
match self.config.seasonality {
Seasonality::Additive => self.level + (h as f64) * self.trend + self.seasonal[idx],
Seasonality::Multiplicative => {
(self.level + (h as f64) * self.trend) * self.seasonal[idx]
}
}
}
}
impl StreamingLearner for HoltWinters {
    /// Forwards `target` to the univariate [`HoltWinters::train_one`].
    ///
    /// The model has no exogenous inputs, so `_features` is ignored. `_weight`
    /// is ignored as well, so every observation counts equally.
    fn train_one(&mut self, _features: &[f64], target: f64, _weight: f64) {
        // Fully qualified so it is obvious this targets the inherent method,
        // not this trait method of the same name.
        HoltWinters::train_one(self, target)
    }

    /// One-step-ahead forecast; `_features` is ignored.
    fn predict(&self, _features: &[f64]) -> f64 {
        HoltWinters::predict_one(self)
    }

    /// Total observations seen since creation or the last reset, warm-up included.
    fn n_samples_seen(&self) -> u64 {
        HoltWinters::n_samples_seen(self)
    }

    /// Returns the model to its freshly constructed state (configuration kept).
    fn reset(&mut self) {
        HoltWinters::reset(self)
    }
}
impl crate::automl::DiagnosticSource for HoltWinters {
    /// This model publishes no configuration diagnostics; it always reports `None`.
    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
        Option::None
    }
}
// Unit tests for the Holt-Winters forecaster and its configuration builder.
#[cfg(test)]
mod tests {
use super::*;
use std::f64::consts::PI;
// Threshold for "moved away from the neutral seasonal value" checks.
const EPS: f64 = 1e-6;
// Config with the builder defaults spelled out (alpha 0.3, beta 0.1, gamma 0.1);
// seasonality is left at the builder default (additive).
fn default_config(period: usize) -> HoltWintersConfig {
HoltWintersConfig::builder(period)
.alpha(0.3)
.beta(0.1)
.gamma(0.1)
.build()
.unwrap()
}
// A constant input should settle to level ~= value and trend ~= 0
// (loose tolerance of 1.0).
#[test]
fn constant_series_converges() {
let mut hw = HoltWinters::new(default_config(4));
let val = 42.0;
for _ in 0..100 {
hw.train_one(val);
}
assert!(
hw.is_initialized(),
"should be initialized after 100 samples"
);
assert!(
(hw.level() - val).abs() < 1.0,
"level should converge to {}, got {}",
val,
hw.level()
);
assert!(
hw.trend().abs() < 1.0,
"trend should converge to 0, got {}",
hw.trend()
);
}
// A strictly increasing series (slope 2) should produce a positive trend.
#[test]
fn linear_trend_captured() {
let mut hw = HoltWinters::new(default_config(4));
for t in 0..200 {
hw.train_one(2.0 * t as f64);
}
assert!(hw.is_initialized());
assert!(
hw.trend() > 0.0,
"trend should be positive for increasing series, got {}",
hw.trend()
);
}
// A period-12 sine wave around 100 should leave at least one additive
// seasonal term away from 0.
#[test]
fn additive_seasonal_captured() {
let period = 12;
let config = HoltWintersConfig::builder(period)
.alpha(0.3)
.beta(0.1)
.gamma(0.3)
.build()
.unwrap();
let mut hw = HoltWinters::new(config);
for t in 0..120 {
let y = 100.0 + 10.0 * (2.0 * PI * t as f64 / period as f64).sin();
hw.train_one(y);
}
assert!(hw.is_initialized());
let factors = hw.seasonal_factors();
let has_nonzero = factors.iter().any(|s| s.abs() > EPS);
assert!(
has_nonzero,
"additive seasonal factors should be nonzero, got {:?}",
factors
);
}
// A +/-10% multiplicative sine pattern should leave at least one seasonal
// factor away from 1.
#[test]
fn multiplicative_seasonal_captured() {
let period = 12;
let config = HoltWintersConfig::builder(period)
.alpha(0.3)
.beta(0.1)
.gamma(0.3)
.seasonality(Seasonality::Multiplicative)
.build()
.unwrap();
let mut hw = HoltWinters::new(config);
for t in 0..120 {
let y = 100.0 * (1.0 + 0.1 * (2.0 * PI * t as f64 / period as f64).sin());
hw.train_one(y);
}
assert!(hw.is_initialized());
let factors = hw.seasonal_factors();
let has_deviation = factors.iter().any(|s| (s - 1.0).abs() > EPS);
assert!(
has_deviation,
"multiplicative seasonal factors should deviate from 1.0, got {:?}",
factors
);
}
// forecast(h) returns exactly h values before and after warm-up;
// forecast(0) is empty.
#[test]
fn forecast_returns_correct_length() {
let mut hw = HoltWinters::new(default_config(4));
let f0 = hw.forecast(5);
assert_eq!(f0.len(), 5, "forecast length should match horizon");
for t in 0..20 {
hw.train_one(100.0 + (t % 4) as f64 * 10.0);
}
let f1 = hw.forecast(10);
assert_eq!(f1.len(), 10, "forecast length should match horizon");
let f_empty = hw.forecast(0);
assert_eq!(f_empty.len(), 0, "forecast(0) should return empty vec");
}
// After learning a repeating 4-step pattern with a slight drift, the next
// period of forecasts should not be flat.
#[test]
fn forecast_uses_seasonal() {
let period = 4;
let config = HoltWintersConfig::builder(period)
.alpha(0.3)
.beta(0.01)
.gamma(0.3)
.build()
.unwrap();
let mut hw = HoltWinters::new(config);
let pattern = [10.0, 20.0, 30.0, 15.0];
for cycle in 0..50 {
for &v in &pattern {
hw.train_one(100.0 + v + cycle as f64 * 0.1);
}
}
let fc = hw.forecast(period);
assert_eq!(fc.len(), period);
let all_same = fc.windows(2).all(|w| (w[0] - w[1]).abs() < EPS);
assert!(!all_same, "forecast should show periodicity, got {:?}", fc);
}
// The model stays uninitialised for the first period - 1 samples and
// initialises on the period-th sample.
#[test]
fn initialization_buffers_first_period() {
let period = 7;
let mut hw = HoltWinters::new(default_config(period));
for t in 0..period - 1 {
hw.train_one(t as f64);
assert!(
!hw.is_initialized(),
"should not be initialized after {} samples",
t + 1
);
}
hw.train_one((period - 1) as f64);
assert!(
hw.is_initialized(),
"should be initialized after {} samples",
period
);
}
// Exercises the model through `dyn StreamingLearner`: sample counting,
// a finite positive prediction, and reset.
#[test]
fn streaming_learner_trait() {
let config = default_config(4);
let mut hw = HoltWinters::new(config);
let learner: &mut dyn StreamingLearner = &mut hw;
for t in 0..20 {
learner.train_one(&[], 100.0 + (t % 4) as f64 * 10.0, 1.0);
}
assert_eq!(learner.n_samples_seen(), 20);
let pred = learner.predict(&[]);
assert!(
pred.is_finite(),
"prediction should be finite, got {}",
pred
);
assert!(
pred > 0.0,
"prediction should be positive for positive series, got {}",
pred
);
learner.reset();
assert_eq!(learner.n_samples_seen(), 0);
}
// reset() returns the model to the warm-up state, and it can re-initialise
// afterwards.
#[test]
fn reset_clears_state() {
let mut hw = HoltWinters::new(default_config(4));
for t in 0..20 {
hw.train_one(50.0 + t as f64);
}
assert!(hw.is_initialized());
assert!(hw.n_samples_seen() > 0);
hw.reset();
assert!(
!hw.is_initialized(),
"should not be initialized after reset"
);
assert_eq!(hw.n_samples_seen(), 0, "n_samples should be 0 after reset");
assert_eq!(hw.level(), 0.0, "level should be 0 after reset");
assert_eq!(hw.trend(), 0.0, "trend should be 0 after reset");
for t in 0..10 {
hw.train_one(t as f64 * 5.0);
}
assert!(hw.is_initialized());
}
// The builder accepts values strictly inside (0, 1) and rejects boundary
// and out-of-range smoothing factors, as well as periods below 2.
#[test]
fn config_validates() {
let ok = HoltWintersConfig::builder(4)
.alpha(0.5)
.beta(0.5)
.gamma(0.5)
.build();
assert!(ok.is_ok(), "valid config should succeed");
let err = HoltWintersConfig::builder(4).alpha(0.0).build();
assert!(err.is_err(), "alpha=0 should fail");
let err = HoltWintersConfig::builder(4).alpha(1.0).build();
assert!(err.is_err(), "alpha=1 should fail");
let err = HoltWintersConfig::builder(4).alpha(-0.1).build();
assert!(err.is_err(), "alpha<0 should fail");
let err = HoltWintersConfig::builder(4).alpha(1.5).build();
assert!(err.is_err(), "alpha>1 should fail");
let err = HoltWintersConfig::builder(4).beta(0.0).build();
assert!(err.is_err(), "beta=0 should fail");
let err = HoltWintersConfig::builder(4).gamma(0.0).build();
assert!(err.is_err(), "gamma=0 should fail");
let err = HoltWintersConfig::builder(1).build();
assert!(err.is_err(), "period=1 should fail");
let err = HoltWintersConfig::builder(0).build();
assert!(err.is_err(), "period=0 should fail");
}
}