Skip to main content

irithyll/time_series/
holt_winters.rs

1//! Streaming Holt-Winters triple exponential smoothing.
2//!
3//! Online implementation supporting both additive and multiplicative
4//! seasonality. State is O(m) where m is the seasonal period. No past
5//! samples are stored after initialization.
6
7use crate::learner::StreamingLearner;
8
9// ---------------------------------------------------------------------------
10// Seasonality enum
11// ---------------------------------------------------------------------------
12
/// Type of seasonal decomposition.
///
/// Selects how the seasonal component is combined with the level and trend
/// when the model is updated and when forecasts are produced.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Seasonality {
    /// Seasonal effects are added to the level+trend.
    Additive,
    /// Seasonal effects multiply the level+trend.
    Multiplicative,
}
21
22// ---------------------------------------------------------------------------
23// HoltWintersConfig
24// ---------------------------------------------------------------------------
25
26/// Configuration for the Holt-Winters exponential smoothing model.
27///
28/// Use [`HoltWintersConfig::builder`] to construct with validation.
29///
30/// # Example
31///
32/// ```
33/// use irithyll::time_series::HoltWintersConfig;
34///
35/// let config = HoltWintersConfig::builder(12)
36///     .alpha(0.2)
37///     .beta(0.05)
38///     .gamma(0.15)
39///     .build()
40///     .unwrap();
41/// assert_eq!(config.period, 12);
42/// ```
43#[derive(Debug, Clone)]
44pub struct HoltWintersConfig {
45    /// Smoothing parameter for level (0 < alpha < 1).
46    pub alpha: f64,
47    /// Smoothing parameter for trend (0 < beta < 1).
48    pub beta: f64,
49    /// Smoothing parameter for seasonality (0 < gamma < 1).
50    pub gamma: f64,
51    /// Seasonal period (e.g., 12 for monthly, 7 for daily).
52    pub period: usize,
53    /// Seasonality type.
54    pub seasonality: Seasonality,
55}
56
57impl HoltWintersConfig {
58    /// Create a builder for `HoltWintersConfig` with the given seasonal period.
59    pub fn builder(period: usize) -> HoltWintersConfigBuilder {
60        HoltWintersConfigBuilder {
61            alpha: 0.3,
62            beta: 0.1,
63            gamma: 0.1,
64            period,
65            seasonality: Seasonality::Additive,
66        }
67    }
68}
69
70// ---------------------------------------------------------------------------
71// HoltWintersConfigBuilder
72// ---------------------------------------------------------------------------
73
74/// Builder for [`HoltWintersConfig`] with parameter validation.
75#[derive(Debug, Clone)]
76pub struct HoltWintersConfigBuilder {
77    alpha: f64,
78    beta: f64,
79    gamma: f64,
80    period: usize,
81    seasonality: Seasonality,
82}
83
84impl HoltWintersConfigBuilder {
85    /// Set the level smoothing parameter (default 0.3).
86    pub fn alpha(mut self, alpha: f64) -> Self {
87        self.alpha = alpha;
88        self
89    }
90
91    /// Set the trend smoothing parameter (default 0.1).
92    pub fn beta(mut self, beta: f64) -> Self {
93        self.beta = beta;
94        self
95    }
96
97    /// Set the seasonality smoothing parameter (default 0.1).
98    pub fn gamma(mut self, gamma: f64) -> Self {
99        self.gamma = gamma;
100        self
101    }
102
103    /// Set the seasonality type (default [`Seasonality::Additive`]).
104    pub fn seasonality(mut self, seasonality: Seasonality) -> Self {
105        self.seasonality = seasonality;
106        self
107    }
108
109    /// Build the configuration, validating all parameters.
110    ///
111    /// Returns `Err` if any smoothing parameter is not in (0, 1) or period < 2.
112    pub fn build(self) -> Result<HoltWintersConfig, String> {
113        if self.alpha <= 0.0 || self.alpha >= 1.0 {
114            return Err(format!("alpha must be in (0, 1), got {}", self.alpha));
115        }
116        if self.beta <= 0.0 || self.beta >= 1.0 {
117            return Err(format!("beta must be in (0, 1), got {}", self.beta));
118        }
119        if self.gamma <= 0.0 || self.gamma >= 1.0 {
120            return Err(format!("gamma must be in (0, 1), got {}", self.gamma));
121        }
122        if self.period < 2 {
123            return Err(format!("period must be >= 2, got {}", self.period));
124        }
125        Ok(HoltWintersConfig {
126            alpha: self.alpha,
127            beta: self.beta,
128            gamma: self.gamma,
129            period: self.period,
130            seasonality: self.seasonality,
131        })
132    }
133}
134
135// ---------------------------------------------------------------------------
136// HoltWinters
137// ---------------------------------------------------------------------------
138
139/// Streaming Holt-Winters triple exponential smoothing.
140///
141/// Maintains level, trend, and seasonal components updated incrementally
142/// with each observation. Supports both additive and multiplicative
143/// seasonality modes.
144///
145/// The first `period` observations are buffered for initialization. Once a
146/// full season is seen, level is set to the season mean, trend to 0, and
147/// seasonal factors are estimated from the buffered values. The buffer is
148/// then replayed through the update equations.
149///
150/// # Example
151///
152/// ```
153/// use irithyll::time_series::{HoltWinters, HoltWintersConfig};
154///
155/// let config = HoltWintersConfig::builder(4)
156///     .alpha(0.3)
157///     .beta(0.1)
158///     .gamma(0.1)
159///     .build()
160///     .unwrap();
161///
162/// let mut hw = HoltWinters::new(config);
163///
164/// // Feed a few seasons of data
165/// for t in 0..20 {
166///     let seasonal = [10.0, 20.0, 30.0, 15.0][t % 4];
167///     hw.train_one(100.0 + seasonal);
168/// }
169///
170/// // Forecast the next 4 steps
171/// let forecast = hw.forecast(4);
172/// assert_eq!(forecast.len(), 4);
173/// ```
174#[derive(Debug, Clone)]
175pub struct HoltWinters {
176    config: HoltWintersConfig,
177    level: f64,
178    trend: f64,
179    seasonal: Vec<f64>,
180    season_idx: usize,
181    n_samples: u64,
182    initialized: bool,
183    init_buffer: Vec<f64>,
184}
185
186impl HoltWinters {
187    /// Create a new Holt-Winters model from the given configuration.
188    pub fn new(config: HoltWintersConfig) -> Self {
189        let period = config.period;
190        let init_seasonal = match config.seasonality {
191            Seasonality::Additive => vec![0.0; period],
192            Seasonality::Multiplicative => vec![1.0; period],
193        };
194        Self {
195            config,
196            level: 0.0,
197            trend: 0.0,
198            seasonal: init_seasonal,
199            season_idx: 0,
200            n_samples: 0,
201            initialized: false,
202            init_buffer: Vec::with_capacity(period),
203        }
204    }
205
206    /// Update the model with a single observation.
207    ///
208    /// During the initialization phase (first `period` observations), values
209    /// are buffered. Once a full season is collected, the model is initialized
210    /// and all buffered values are replayed.
211    pub fn train_one(&mut self, y: f64) {
212        self.n_samples += 1;
213
214        if !self.initialized {
215            self.init_buffer.push(y);
216            if self.init_buffer.len() == self.config.period {
217                self.initialize();
218            }
219            return;
220        }
221
222        self.update(y);
223    }
224
225    /// One-step-ahead forecast from the current state.
226    ///
227    /// Returns 0.0 if the model has not been initialized yet.
228    pub fn predict_one(&self) -> f64 {
229        if !self.initialized {
230            return 0.0;
231        }
232        self.forecast_step(1)
233    }
234
235    /// Multi-step forecast from the current state.
236    ///
237    /// Returns a `Vec` of length `horizon` with forecasts for steps 1..=horizon.
238    /// Returns an empty `Vec` if the model is not initialized or horizon is 0.
239    pub fn forecast(&self, horizon: usize) -> Vec<f64> {
240        if !self.initialized || horizon == 0 {
241            return vec![0.0; horizon];
242        }
243        (1..=horizon).map(|h| self.forecast_step(h)).collect()
244    }
245
246    /// Current level component.
247    pub fn level(&self) -> f64 {
248        self.level
249    }
250
251    /// Current trend component.
252    pub fn trend(&self) -> f64 {
253        self.trend
254    }
255
256    /// Seasonal factors (one per period position).
257    pub fn seasonal_factors(&self) -> &[f64] {
258        &self.seasonal
259    }
260
261    /// Whether the model has been initialized (a full season has been seen).
262    pub fn is_initialized(&self) -> bool {
263        self.initialized
264    }
265
266    /// Total number of observations processed (including buffered ones).
267    pub fn n_samples_seen(&self) -> u64 {
268        self.n_samples
269    }
270
271    /// Reset the model to its initial untrained state.
272    pub fn reset(&mut self) {
273        let period = self.config.period;
274        self.level = 0.0;
275        self.trend = 0.0;
276        self.seasonal = match self.config.seasonality {
277            Seasonality::Additive => vec![0.0; period],
278            Seasonality::Multiplicative => vec![1.0; period],
279        };
280        self.season_idx = 0;
281        self.n_samples = 0;
282        self.initialized = false;
283        self.init_buffer.clear();
284    }
285
286    // -----------------------------------------------------------------------
287    // Private helpers
288    // -----------------------------------------------------------------------
289
290    /// Initialize level, trend, and seasonal factors from the first season.
291    fn initialize(&mut self) {
292        let m = self.config.period;
293        let buf = &self.init_buffer;
294
295        // Level = mean of first season
296        let mean: f64 = buf.iter().sum::<f64>() / m as f64;
297        self.level = mean;
298
299        // Trend = 0 (single-season initialization)
300        self.trend = 0.0;
301
302        // Seasonal factors from first season
303        match self.config.seasonality {
304            Seasonality::Additive => {
305                for (i, &b) in buf.iter().enumerate().take(m) {
306                    self.seasonal[i] = b - mean;
307                }
308            }
309            Seasonality::Multiplicative => {
310                for (i, &b) in buf.iter().enumerate().take(m) {
311                    // Guard against zero mean to avoid division by zero.
312                    if mean.abs() < f64::EPSILON {
313                        self.seasonal[i] = 1.0;
314                    } else {
315                        self.seasonal[i] = b / mean;
316                    }
317                }
318            }
319        }
320
321        self.initialized = true;
322        self.season_idx = 0;
323
324        // Replay the buffered observations through the update equations.
325        // We clone the buffer since update() borrows &mut self.
326        let replay: Vec<f64> = buf.clone();
327        for &y in &replay {
328            self.update(y);
329        }
330    }
331
332    /// Update level, trend, and seasonal components with a single observation.
333    fn update(&mut self, y: f64) {
334        let m = self.config.period;
335        let alpha = self.config.alpha;
336        let beta = self.config.beta;
337        let gamma = self.config.gamma;
338
339        let prev_level = self.level;
340        let prev_trend = self.trend;
341        let prev_seasonal = self.seasonal[self.season_idx];
342
343        match self.config.seasonality {
344            Seasonality::Additive => {
345                // Level
346                self.level =
347                    alpha * (y - prev_seasonal) + (1.0 - alpha) * (prev_level + prev_trend);
348
349                // Trend
350                self.trend = beta * (self.level - prev_level) + (1.0 - beta) * prev_trend;
351
352                // Seasonal
353                self.seasonal[self.season_idx] =
354                    gamma * (y - self.level) + (1.0 - gamma) * prev_seasonal;
355            }
356            Seasonality::Multiplicative => {
357                // Guard against zero seasonal factor
358                let safe_seasonal = if prev_seasonal.abs() < f64::EPSILON {
359                    1.0
360                } else {
361                    prev_seasonal
362                };
363
364                // Level
365                self.level =
366                    alpha * (y / safe_seasonal) + (1.0 - alpha) * (prev_level + prev_trend);
367
368                // Trend
369                self.trend = beta * (self.level - prev_level) + (1.0 - beta) * prev_trend;
370
371                // Seasonal — guard against zero level
372                let safe_level = if self.level.abs() < f64::EPSILON {
373                    1.0
374                } else {
375                    self.level
376                };
377                self.seasonal[self.season_idx] =
378                    gamma * (y / safe_level) + (1.0 - gamma) * prev_seasonal;
379            }
380        }
381
382        // Advance seasonal index
383        self.season_idx = (self.season_idx + 1) % m;
384    }
385
386    /// Forecast h steps ahead from the current state.
387    fn forecast_step(&self, h: usize) -> f64 {
388        let m = self.config.period;
389        // Seasonal index for h steps ahead:
390        // s_{t-m+((h-1) mod m)+1}
391        let idx = (self.season_idx + (h - 1) % m) % m;
392
393        match self.config.seasonality {
394            Seasonality::Additive => self.level + (h as f64) * self.trend + self.seasonal[idx],
395            Seasonality::Multiplicative => {
396                (self.level + (h as f64) * self.trend) * self.seasonal[idx]
397            }
398        }
399    }
400}
401
402// ---------------------------------------------------------------------------
403// StreamingLearner impl
404// ---------------------------------------------------------------------------
405
406impl StreamingLearner for HoltWinters {
407    fn train_one(&mut self, _features: &[f64], target: f64, _weight: f64) {
408        HoltWinters::train_one(self, target);
409    }
410
411    fn predict(&self, _features: &[f64]) -> f64 {
412        self.predict_one()
413    }
414
415    fn n_samples_seen(&self) -> u64 {
416        self.n_samples
417    }
418
419    fn reset(&mut self) {
420        HoltWinters::reset(self);
421    }
422}
423
424// ---------------------------------------------------------------------------
425// DiagnosticSource impl
426// ---------------------------------------------------------------------------
427
428impl crate::automl::DiagnosticSource for HoltWinters {
429    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
430        None
431    }
432}
433
434// ---------------------------------------------------------------------------
435// Tests
436// ---------------------------------------------------------------------------
437
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Tolerance used for floating-point comparisons in these tests.
    const EPS: f64 = 1e-6;

    /// Additive configuration with alpha=0.3, beta=0.1, gamma=0.1.
    fn default_config(period: usize) -> HoltWintersConfig {
        HoltWintersConfig::builder(period)
            .alpha(0.3)
            .beta(0.1)
            .gamma(0.1)
            .build()
            .unwrap()
    }

    /// A constant series must be reproduced: level stays at the constant and
    /// trend stays at zero, up to rounding.
    #[test]
    fn constant_series_converges() {
        let mut hw = HoltWinters::new(default_config(4));
        let val = 42.0;

        // Feed 100 observations of a constant
        for _ in 0..100 {
            hw.train_one(val);
        }

        assert!(
            hw.is_initialized(),
            "should be initialized after 100 samples"
        );
        assert!(
            (hw.level() - val).abs() < EPS,
            "level should converge to {}, got {}",
            val,
            hw.level()
        );
        assert!(
            hw.trend().abs() < EPS,
            "trend should converge to 0, got {}",
            hw.trend()
        );
    }

    /// A strictly increasing series must produce a positive trend.
    #[test]
    fn linear_trend_captured() {
        let mut hw = HoltWinters::new(default_config(4));

        // Feed y = 2*t
        for t in 0..200 {
            hw.train_one(2.0 * t as f64);
        }

        assert!(hw.is_initialized());
        assert!(
            hw.trend() > 0.0,
            "trend should be positive for increasing series, got {}",
            hw.trend()
        );
    }

    /// Additive mode must learn non-zero seasonal offsets from a sinusoid.
    #[test]
    fn additive_seasonal_captured() {
        let period = 12;
        let config = HoltWintersConfig::builder(period)
            .alpha(0.3)
            .beta(0.1)
            .gamma(0.3)
            .build()
            .unwrap();
        let mut hw = HoltWinters::new(config);

        // Feed y = 100 + 10*sin(2*pi*t/period)
        for t in 0..120 {
            let y = 100.0 + 10.0 * (2.0 * PI * t as f64 / period as f64).sin();
            hw.train_one(y);
        }

        assert!(hw.is_initialized());

        // Seasonal factors should be nonzero (not all zero)
        let factors = hw.seasonal_factors();
        let has_nonzero = factors.iter().any(|s| s.abs() > EPS);
        assert!(
            has_nonzero,
            "additive seasonal factors should be nonzero, got {:?}",
            factors
        );
    }

    /// Multiplicative mode must learn factors that deviate from 1.0.
    #[test]
    fn multiplicative_seasonal_captured() {
        let period = 12;
        let config = HoltWintersConfig::builder(period)
            .alpha(0.3)
            .beta(0.1)
            .gamma(0.3)
            .seasonality(Seasonality::Multiplicative)
            .build()
            .unwrap();
        let mut hw = HoltWinters::new(config);

        // Feed y = 100 * (1 + 0.1*sin(2*pi*t/period))
        for t in 0..120 {
            let y = 100.0 * (1.0 + 0.1 * (2.0 * PI * t as f64 / period as f64).sin());
            hw.train_one(y);
        }

        assert!(hw.is_initialized());

        // Multiplicative factors should not all be 1.0
        let factors = hw.seasonal_factors();
        let has_deviation = factors.iter().any(|s| (s - 1.0).abs() > EPS);
        assert!(
            has_deviation,
            "multiplicative seasonal factors should deviate from 1.0, got {:?}",
            factors
        );
    }

    /// `forecast` returns exactly `horizon` values, zero-filled before
    /// initialization and empty for a horizon of 0.
    #[test]
    fn forecast_returns_correct_length() {
        let mut hw = HoltWinters::new(default_config(4));

        // Before init, forecast should return zeros of correct length
        let f0 = hw.forecast(5);
        assert_eq!(f0.len(), 5, "forecast length should match horizon");
        assert!(
            f0.iter().all(|&v| v == 0.0),
            "uninitialized forecast should be all zeros, got {:?}",
            f0
        );

        // Feed enough to init
        for t in 0..20 {
            hw.train_one(100.0 + (t % 4) as f64 * 10.0);
        }

        let f1 = hw.forecast(10);
        assert_eq!(f1.len(), 10, "forecast length should match horizon");

        let f_empty = hw.forecast(0);
        assert_eq!(f_empty.len(), 0, "forecast(0) should return empty vec");
    }

    /// Forecasts over one period must vary when the data is seasonal.
    #[test]
    fn forecast_uses_seasonal() {
        let period = 4;
        let config = HoltWintersConfig::builder(period)
            .alpha(0.3)
            .beta(0.01)
            .gamma(0.3)
            .build()
            .unwrap();
        let mut hw = HoltWinters::new(config);

        // Feed distinct seasonal pattern
        let pattern = [10.0, 20.0, 30.0, 15.0];
        for cycle in 0..50 {
            for &v in &pattern {
                hw.train_one(100.0 + v + cycle as f64 * 0.1);
            }
        }

        // Forecast one full period
        let fc = hw.forecast(period);
        assert_eq!(fc.len(), period);

        // Forecasted values should not all be the same (seasonality present)
        let all_same = fc.windows(2).all(|w| (w[0] - w[1]).abs() < EPS);
        assert!(!all_same, "forecast should show periodicity, got {:?}", fc);
    }

    /// The model becomes initialized on exactly the `period`-th observation.
    #[test]
    fn initialization_buffers_first_period() {
        let period = 7;
        let mut hw = HoltWinters::new(default_config(period));

        // Feed period-1 samples -- should not be initialized yet
        for t in 0..period - 1 {
            hw.train_one(t as f64);
            assert!(
                !hw.is_initialized(),
                "should not be initialized after {} samples",
                t + 1
            );
        }

        // One more to complete the period
        hw.train_one((period - 1) as f64);
        assert!(
            hw.is_initialized(),
            "should be initialized after {} samples",
            period
        );
    }

    /// Train, predict and reset all work through a `dyn StreamingLearner`.
    #[test]
    fn streaming_learner_trait() {
        let config = default_config(4);
        let mut hw = HoltWinters::new(config);

        // Use through the StreamingLearner interface
        let learner: &mut dyn StreamingLearner = &mut hw;

        // Train
        for t in 0..20 {
            learner.train_one(&[], 100.0 + (t % 4) as f64 * 10.0, 1.0);
        }

        assert_eq!(learner.n_samples_seen(), 20);

        // Predict (features ignored)
        let pred = learner.predict(&[]);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );
        assert!(
            pred > 0.0,
            "prediction should be positive for positive series, got {}",
            pred
        );

        // Reset
        learner.reset();
        assert_eq!(learner.n_samples_seen(), 0);
    }

    /// `reset` returns the model to its untrained state and allows retraining.
    #[test]
    fn reset_clears_state() {
        let mut hw = HoltWinters::new(default_config(4));

        // Train
        for t in 0..20 {
            hw.train_one(50.0 + t as f64);
        }

        assert!(hw.is_initialized());
        assert!(hw.n_samples_seen() > 0);

        // Reset
        hw.reset();

        assert!(
            !hw.is_initialized(),
            "should not be initialized after reset"
        );
        assert_eq!(hw.n_samples_seen(), 0, "n_samples should be 0 after reset");
        assert_eq!(hw.level(), 0.0, "level should be 0 after reset");
        assert_eq!(hw.trend(), 0.0, "trend should be 0 after reset");

        // Should be able to retrain
        for t in 0..10 {
            hw.train_one(t as f64 * 5.0);
        }
        assert!(hw.is_initialized());
    }

    /// The builder accepts in-range parameters and rejects boundary and
    /// out-of-range values.
    #[test]
    fn config_validates() {
        // Valid config
        let ok = HoltWintersConfig::builder(4)
            .alpha(0.5)
            .beta(0.5)
            .gamma(0.5)
            .build();
        assert!(ok.is_ok(), "valid config should succeed");

        // Alpha out of range
        let err = HoltWintersConfig::builder(4).alpha(0.0).build();
        assert!(err.is_err(), "alpha=0 should fail");

        let err = HoltWintersConfig::builder(4).alpha(1.0).build();
        assert!(err.is_err(), "alpha=1 should fail");

        let err = HoltWintersConfig::builder(4).alpha(-0.1).build();
        assert!(err.is_err(), "alpha<0 should fail");

        let err = HoltWintersConfig::builder(4).alpha(1.5).build();
        assert!(err.is_err(), "alpha>1 should fail");

        // Beta out of range
        let err = HoltWintersConfig::builder(4).beta(0.0).build();
        assert!(err.is_err(), "beta=0 should fail");

        // Gamma out of range
        let err = HoltWintersConfig::builder(4).gamma(0.0).build();
        assert!(err.is_err(), "gamma=0 should fail");

        // Period too small
        let err = HoltWintersConfig::builder(1).build();
        assert!(err.is_err(), "period=1 should fail");

        let err = HoltWintersConfig::builder(0).build();
        assert!(err.is_err(), "period=0 should fail");
    }
}