1use crate::learner::StreamingLearner;
8
/// How the seasonal component is combined with the level and trend.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Seasonality {
    /// Seasonal terms are offsets: forecast = level + h * trend + seasonal.
    Additive,
    /// Seasonal terms are factors: forecast = (level + h * trend) * seasonal.
    Multiplicative,
}
21
/// Hyper-parameters of a [`HoltWinters`] model.
///
/// `HoltWintersConfigBuilder::build` rejects smoothing factors that are `<= 0` or
/// `>= 1` and periods below 2. All fields are public, so a value written out as a
/// struct literal bypasses those checks.
#[derive(Debug, Clone)]
pub struct HoltWintersConfig {
    /// Smoothing factor applied to the level.
    pub alpha: f64,
    /// Smoothing factor applied to the trend.
    pub beta: f64,
    /// Smoothing factor applied to the seasonal terms.
    pub gamma: f64,
    /// Number of observations in one seasonal cycle.
    pub period: usize,
    /// Whether seasonal terms are added to or multiplied with the baseline.
    pub seasonality: Seasonality,
}
56
57impl HoltWintersConfig {
58 pub fn builder(period: usize) -> HoltWintersConfigBuilder {
60 HoltWintersConfigBuilder {
61 alpha: 0.3,
62 beta: 0.1,
63 gamma: 0.1,
64 period,
65 seasonality: Seasonality::Additive,
66 }
67 }
68}
69
/// Builder for [`HoltWintersConfig`], obtained from `HoltWintersConfig::builder`.
///
/// The setters store whatever they are given; validation happens in `build`.
#[derive(Debug, Clone)]
pub struct HoltWintersConfigBuilder {
    // Pending level smoothing factor.
    alpha: f64,
    // Pending trend smoothing factor.
    beta: f64,
    // Pending seasonal smoothing factor.
    gamma: f64,
    // Pending seasonal cycle length.
    period: usize,
    // Pending seasonal combination mode.
    seasonality: Seasonality,
}
83
84impl HoltWintersConfigBuilder {
85 pub fn alpha(mut self, alpha: f64) -> Self {
87 self.alpha = alpha;
88 self
89 }
90
91 pub fn beta(mut self, beta: f64) -> Self {
93 self.beta = beta;
94 self
95 }
96
97 pub fn gamma(mut self, gamma: f64) -> Self {
99 self.gamma = gamma;
100 self
101 }
102
103 pub fn seasonality(mut self, seasonality: Seasonality) -> Self {
105 self.seasonality = seasonality;
106 self
107 }
108
109 pub fn build(self) -> Result<HoltWintersConfig, String> {
113 if self.alpha <= 0.0 || self.alpha >= 1.0 {
114 return Err(format!("alpha must be in (0, 1), got {}", self.alpha));
115 }
116 if self.beta <= 0.0 || self.beta >= 1.0 {
117 return Err(format!("beta must be in (0, 1), got {}", self.beta));
118 }
119 if self.gamma <= 0.0 || self.gamma >= 1.0 {
120 return Err(format!("gamma must be in (0, 1), got {}", self.gamma));
121 }
122 if self.period < 2 {
123 return Err(format!("period must be >= 2, got {}", self.period));
124 }
125 Ok(HoltWintersConfig {
126 alpha: self.alpha,
127 beta: self.beta,
128 gamma: self.gamma,
129 period: self.period,
130 seasonality: self.seasonality,
131 })
132 }
133}
134
/// Streaming Holt-Winters (level + trend + seasonal) forecaster.
///
/// The first `config.period` observations are buffered; once the buffer is full
/// the level and seasonal terms are seeded from it and every later observation
/// updates the state incrementally.
#[derive(Debug, Clone)]
pub struct HoltWinters {
    // Smoothing factors, cycle length and seasonality mode.
    config: HoltWintersConfig,
    // Current smoothed level.
    level: f64,
    // Current smoothed per-step trend.
    trend: f64,
    // One seasonal term per position in the cycle (`config.period` entries).
    seasonal: Vec<f64>,
    // Index into `seasonal` used by the next update; advanced modulo the period.
    season_idx: usize,
    // Number of `train_one` calls since construction or the last reset.
    n_samples: u64,
    // Set once the first full period has been buffered and the state seeded.
    initialized: bool,
    // Observations collected before initialization; emptied only by `reset`.
    init_buffer: Vec<f64>,
}
185
186impl HoltWinters {
187 pub fn new(config: HoltWintersConfig) -> Self {
189 let period = config.period;
190 let init_seasonal = match config.seasonality {
191 Seasonality::Additive => vec![0.0; period],
192 Seasonality::Multiplicative => vec![1.0; period],
193 };
194 Self {
195 config,
196 level: 0.0,
197 trend: 0.0,
198 seasonal: init_seasonal,
199 season_idx: 0,
200 n_samples: 0,
201 initialized: false,
202 init_buffer: Vec::with_capacity(period),
203 }
204 }
205
206 pub fn train_one(&mut self, y: f64) {
212 self.n_samples += 1;
213
214 if !self.initialized {
215 self.init_buffer.push(y);
216 if self.init_buffer.len() == self.config.period {
217 self.initialize();
218 }
219 return;
220 }
221
222 self.update(y);
223 }
224
225 pub fn predict_one(&self) -> f64 {
229 if !self.initialized {
230 return 0.0;
231 }
232 self.forecast_step(1)
233 }
234
235 pub fn forecast(&self, horizon: usize) -> Vec<f64> {
240 if !self.initialized || horizon == 0 {
241 return vec![0.0; horizon];
242 }
243 (1..=horizon).map(|h| self.forecast_step(h)).collect()
244 }
245
246 pub fn level(&self) -> f64 {
248 self.level
249 }
250
251 pub fn trend(&self) -> f64 {
253 self.trend
254 }
255
256 pub fn seasonal_factors(&self) -> &[f64] {
258 &self.seasonal
259 }
260
261 pub fn is_initialized(&self) -> bool {
263 self.initialized
264 }
265
266 pub fn n_samples_seen(&self) -> u64 {
268 self.n_samples
269 }
270
271 pub fn reset(&mut self) {
273 let period = self.config.period;
274 self.level = 0.0;
275 self.trend = 0.0;
276 self.seasonal = match self.config.seasonality {
277 Seasonality::Additive => vec![0.0; period],
278 Seasonality::Multiplicative => vec![1.0; period],
279 };
280 self.season_idx = 0;
281 self.n_samples = 0;
282 self.initialized = false;
283 self.init_buffer.clear();
284 }
285
286 fn initialize(&mut self) {
292 let m = self.config.period;
293 let buf = &self.init_buffer;
294
295 let mean: f64 = buf.iter().sum::<f64>() / m as f64;
297 self.level = mean;
298
299 self.trend = 0.0;
301
302 match self.config.seasonality {
304 Seasonality::Additive => {
305 for (i, &b) in buf.iter().enumerate().take(m) {
306 self.seasonal[i] = b - mean;
307 }
308 }
309 Seasonality::Multiplicative => {
310 for (i, &b) in buf.iter().enumerate().take(m) {
311 if mean.abs() < f64::EPSILON {
313 self.seasonal[i] = 1.0;
314 } else {
315 self.seasonal[i] = b / mean;
316 }
317 }
318 }
319 }
320
321 self.initialized = true;
322 self.season_idx = 0;
323
324 let replay: Vec<f64> = buf.clone();
327 for &y in &replay {
328 self.update(y);
329 }
330 }
331
332 fn update(&mut self, y: f64) {
334 let m = self.config.period;
335 let alpha = self.config.alpha;
336 let beta = self.config.beta;
337 let gamma = self.config.gamma;
338
339 let prev_level = self.level;
340 let prev_trend = self.trend;
341 let prev_seasonal = self.seasonal[self.season_idx];
342
343 match self.config.seasonality {
344 Seasonality::Additive => {
345 self.level =
347 alpha * (y - prev_seasonal) + (1.0 - alpha) * (prev_level + prev_trend);
348
349 self.trend = beta * (self.level - prev_level) + (1.0 - beta) * prev_trend;
351
352 self.seasonal[self.season_idx] =
354 gamma * (y - self.level) + (1.0 - gamma) * prev_seasonal;
355 }
356 Seasonality::Multiplicative => {
357 let safe_seasonal = if prev_seasonal.abs() < f64::EPSILON {
359 1.0
360 } else {
361 prev_seasonal
362 };
363
364 self.level =
366 alpha * (y / safe_seasonal) + (1.0 - alpha) * (prev_level + prev_trend);
367
368 self.trend = beta * (self.level - prev_level) + (1.0 - beta) * prev_trend;
370
371 let safe_level = if self.level.abs() < f64::EPSILON {
373 1.0
374 } else {
375 self.level
376 };
377 self.seasonal[self.season_idx] =
378 gamma * (y / safe_level) + (1.0 - gamma) * prev_seasonal;
379 }
380 }
381
382 self.season_idx = (self.season_idx + 1) % m;
384 }
385
386 fn forecast_step(&self, h: usize) -> f64 {
388 let m = self.config.period;
389 let idx = (self.season_idx + (h - 1) % m) % m;
392
393 match self.config.seasonality {
394 Seasonality::Additive => self.level + (h as f64) * self.trend + self.seasonal[idx],
395 Seasonality::Multiplicative => {
396 (self.level + (h as f64) * self.trend) * self.seasonal[idx]
397 }
398 }
399 }
400}
401
402impl StreamingLearner for HoltWinters {
407 fn train_one(&mut self, _features: &[f64], target: f64, _weight: f64) {
408 HoltWinters::train_one(self, target);
409 }
410
411 fn predict(&self, _features: &[f64]) -> f64 {
412 self.predict_one()
413 }
414
415 fn n_samples_seen(&self) -> u64 {
416 self.n_samples
417 }
418
419 fn reset(&mut self) {
420 HoltWinters::reset(self);
421 }
422}
423
impl crate::automl::DiagnosticSource for HoltWinters {
    /// Always returns `None`: this model exposes no configuration diagnostics.
    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
        None
    }
}
433
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Threshold under which a float difference counts as zero.
    const EPS: f64 = 1e-6;

    /// Additive config for `period` with alpha 0.3, beta 0.1 and gamma 0.1.
    fn default_config(period: usize) -> HoltWintersConfig {
        let builder = HoltWintersConfig::builder(period)
            .alpha(0.3)
            .beta(0.1)
            .gamma(0.1);
        builder.build().unwrap()
    }

    /// Trains `hw` on every value of `series`, in order.
    fn feed<I>(hw: &mut HoltWinters, series: I)
    where
        I: IntoIterator<Item = f64>,
    {
        for y in series {
            hw.train_one(y);
        }
    }

    #[test]
    fn constant_series_converges() {
        let val = 42.0;
        let mut hw = HoltWinters::new(default_config(4));

        feed(&mut hw, std::iter::repeat(val).take(100));

        assert!(
            hw.is_initialized(),
            "should be initialized after 100 samples"
        );

        let level = hw.level();
        assert!(
            (level - val).abs() < 1.0,
            "level should converge to {}, got {}",
            val,
            level
        );

        let trend = hw.trend();
        assert!(
            trend.abs() < 1.0,
            "trend should converge to 0, got {}",
            trend
        );
    }

    #[test]
    fn linear_trend_captured() {
        let mut hw = HoltWinters::new(default_config(4));

        // Strictly increasing ramp with slope 2.
        feed(&mut hw, (0..200).map(|t| 2.0 * t as f64));

        assert!(hw.is_initialized());
        let trend = hw.trend();
        assert!(
            trend > 0.0,
            "trend should be positive for increasing series, got {}",
            trend
        );
    }

    #[test]
    fn additive_seasonal_captured() {
        let period = 12;
        let config = HoltWintersConfig::builder(period)
            .alpha(0.3)
            .beta(0.1)
            .gamma(0.3)
            .build()
            .unwrap();
        let mut hw = HoltWinters::new(config);

        // Ten cycles of a sine wave riding on a constant baseline.
        feed(
            &mut hw,
            (0..120).map(|t| 100.0 + 10.0 * (2.0 * PI * t as f64 / period as f64).sin()),
        );

        assert!(hw.is_initialized());

        let factors = hw.seasonal_factors();
        let has_nonzero = factors.iter().any(|s| s.abs() > EPS);
        assert!(
            has_nonzero,
            "additive seasonal factors should be nonzero, got {:?}",
            factors
        );
    }

    #[test]
    fn multiplicative_seasonal_captured() {
        let period = 12;
        let config = HoltWintersConfig::builder(period)
            .alpha(0.3)
            .beta(0.1)
            .gamma(0.3)
            .seasonality(Seasonality::Multiplicative)
            .build()
            .unwrap();
        let mut hw = HoltWinters::new(config);

        // Ten cycles of a baseline scaled by a +/-10% sine wave.
        feed(
            &mut hw,
            (0..120).map(|t| 100.0 * (1.0 + 0.1 * (2.0 * PI * t as f64 / period as f64).sin())),
        );

        assert!(hw.is_initialized());

        let factors = hw.seasonal_factors();
        let has_deviation = factors.iter().any(|s| (s - 1.0).abs() > EPS);
        assert!(
            has_deviation,
            "multiplicative seasonal factors should deviate from 1.0, got {:?}",
            factors
        );
    }

    #[test]
    fn forecast_returns_correct_length() {
        let mut hw = HoltWinters::new(default_config(4));

        // Before any training the forecast still has the requested length.
        let untrained = hw.forecast(5);
        assert_eq!(untrained.len(), 5, "forecast length should match horizon");

        feed(&mut hw, (0..20).map(|t| 100.0 + (t % 4) as f64 * 10.0));

        let trained = hw.forecast(10);
        assert_eq!(trained.len(), 10, "forecast length should match horizon");

        let none = hw.forecast(0);
        assert_eq!(none.len(), 0, "forecast(0) should return empty vec");
    }

    #[test]
    fn forecast_uses_seasonal() {
        let period = 4;
        let config = HoltWintersConfig::builder(period)
            .alpha(0.3)
            .beta(0.01)
            .gamma(0.3)
            .build()
            .unwrap();
        let mut hw = HoltWinters::new(config);

        // Fifty repetitions of a four-step pattern with a slow upward drift.
        let pattern = [10.0, 20.0, 30.0, 15.0];
        for cycle in 0..50 {
            feed(
                &mut hw,
                pattern.iter().map(|&v| 100.0 + v + cycle as f64 * 0.1),
            );
        }

        let fc = hw.forecast(period);
        assert_eq!(fc.len(), period);

        // A flat forecast would mean the seasonal terms were ignored.
        let all_same = fc.windows(2).all(|w| (w[0] - w[1]).abs() < EPS);
        assert!(!all_same, "forecast should show periodicity, got {:?}", fc);
    }

    #[test]
    fn initialization_buffers_first_period() {
        let period = 7;
        let mut hw = HoltWinters::new(default_config(period));

        // `seen` counts the observations consumed so far, including this one.
        for seen in 1..period {
            hw.train_one((seen - 1) as f64);
            assert!(
                !hw.is_initialized(),
                "should not be initialized after {} samples",
                seen
            );
        }

        // The observation that completes the period flips the flag.
        hw.train_one((period - 1) as f64);
        assert!(
            hw.is_initialized(),
            "should be initialized after {} samples",
            period
        );
    }

    #[test]
    fn streaming_learner_trait() {
        let mut hw = HoltWinters::new(default_config(4));

        // Drive the model exclusively through the trait object.
        let learner: &mut dyn StreamingLearner = &mut hw;

        for t in 0..20 {
            let target = 100.0 + (t % 4) as f64 * 10.0;
            learner.train_one(&[], target, 1.0);
        }

        assert_eq!(learner.n_samples_seen(), 20);

        let pred = learner.predict(&[]);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );
        assert!(
            pred > 0.0,
            "prediction should be positive for positive series, got {}",
            pred
        );

        learner.reset();
        assert_eq!(learner.n_samples_seen(), 0);
    }

    #[test]
    fn reset_clears_state() {
        let mut hw = HoltWinters::new(default_config(4));

        feed(&mut hw, (0..20).map(|t| 50.0 + t as f64));

        assert!(hw.is_initialized());
        assert!(hw.n_samples_seen() > 0);

        hw.reset();

        assert!(
            !hw.is_initialized(),
            "should not be initialized after reset"
        );
        assert_eq!(hw.n_samples_seen(), 0, "n_samples should be 0 after reset");
        assert_eq!(hw.level(), 0.0, "level should be 0 after reset");
        assert_eq!(hw.trend(), 0.0, "trend should be 0 after reset");

        // The model must be trainable again after a reset.
        feed(&mut hw, (0..10).map(|t| t as f64 * 5.0));
        assert!(hw.is_initialized());
    }

    #[test]
    fn config_validates() {
        let ok = HoltWintersConfig::builder(4)
            .alpha(0.5)
            .beta(0.5)
            .gamma(0.5)
            .build();
        assert!(ok.is_ok(), "valid config should succeed");

        // Alpha at or beyond either bound of (0, 1).
        let bad_alphas = [
            (0.0, "alpha=0 should fail"),
            (1.0, "alpha=1 should fail"),
            (-0.1, "alpha<0 should fail"),
            (1.5, "alpha>1 should fail"),
        ];
        for &(alpha, why) in &bad_alphas {
            let err = HoltWintersConfig::builder(4).alpha(alpha).build();
            assert!(err.is_err(), "{}", why);
        }

        let err = HoltWintersConfig::builder(4).beta(0.0).build();
        assert!(err.is_err(), "beta=0 should fail");

        let err = HoltWintersConfig::builder(4).gamma(0.0).build();
        assert!(err.is_err(), "gamma=0 should fail");

        // Periods shorter than two observations.
        let bad_periods = [(1, "period=1 should fail"), (0, "period=0 should fail")];
        for &(period, why) in &bad_periods {
            let err = HoltWintersConfig::builder(period).build();
            assert!(err.is_err(), "{}", why);
        }
    }
}