1use std::{
27 fmt::{self, Write},
28 str::FromStr,
29};
30
31use augurs_core::{Fit, Forecast, Predict};
32
33use crate::{
34 model::{self, Model, OptimizationCriteria, Params, Unfit},
35 Error, Result,
36};
37
/// Which error component(s) an automatic model search should consider.
///
/// Each variant has a single-letter code (`A`, `M` or `Z`) that is used when
/// the spec is parsed from, or displayed as, text.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ErrorSpec {
    /// Restrict the search to an additive error component (`A`).
    Additive,
    /// Restrict the search to a multiplicative error component (`M`).
    Multiplicative,
    /// Try both the additive and the multiplicative error component (`Z`).
    Auto,
}
48
49impl ErrorSpec {
50 fn candidates(&self) -> &[model::ErrorComponent] {
52 match self {
53 Self::Additive => &[model::ErrorComponent::Additive],
54 Self::Multiplicative => &[model::ErrorComponent::Multiplicative],
55 Self::Auto => &[
56 model::ErrorComponent::Additive,
57 model::ErrorComponent::Multiplicative,
58 ],
59 }
60 }
61}
62
63impl fmt::Display for ErrorSpec {
64 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65 match self {
66 Self::Additive => f.write_char('A'),
67 Self::Multiplicative => f.write_char('M'),
68 Self::Auto => f.write_char('Z'),
69 }
70 }
71}
72
73impl TryFrom<char> for ErrorSpec {
74 type Error = Error;
75
76 fn try_from(c: char) -> Result<Self> {
77 match c {
78 'A' => Ok(Self::Additive),
79 'M' => Ok(Self::Multiplicative),
80 'Z' => Ok(Self::Auto),
81 _ => Err(Error::InvalidErrorComponentString(c)),
82 }
83 }
84}
85
/// Which form(s) of a trend or seasonal component a model search should
/// consider.
///
/// Each variant has a single-letter code (`N`, `A`, `M` or `Z`) that is used
/// when the spec is parsed from, or displayed as, text.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ComponentSpec {
    /// The component is left out of the model (`N`).
    None,
    /// Restrict the search to the additive form of the component (`A`).
    Additive,
    /// Restrict the search to the multiplicative form of the component (`M`).
    Multiplicative,
    /// Let the search choose between the available forms (`Z`).
    Auto,
}
98
99impl ComponentSpec {
100 fn is_specified(&self) -> bool {
102 matches!(self, Self::Additive | Self::Multiplicative)
103 }
104
105 fn trend_candidates(&self, auto_multiplicative: bool) -> &[model::TrendComponent] {
107 match (self, auto_multiplicative) {
108 (Self::None, _) => &[],
109 (Self::Additive, _) => &[model::TrendComponent::Additive],
110 (Self::Multiplicative, _) => &[model::TrendComponent::Multiplicative],
111 (Self::Auto, false) => &[model::TrendComponent::None, model::TrendComponent::Additive],
112 (Self::Auto, true) => &[
113 model::TrendComponent::None,
114 model::TrendComponent::Additive,
115 model::TrendComponent::Multiplicative,
116 ],
117 }
118 }
119
120 fn seasonal_candidates(&self, season_length: usize) -> Vec<model::SeasonalComponent> {
122 match self {
123 ComponentSpec::None => vec![model::SeasonalComponent::None],
124 ComponentSpec::Additive => {
125 vec![model::SeasonalComponent::Additive { season_length }]
126 }
127 ComponentSpec::Multiplicative => {
128 vec![model::SeasonalComponent::Multiplicative { season_length }]
129 }
130 ComponentSpec::Auto => vec![
131 model::SeasonalComponent::None,
132 model::SeasonalComponent::Additive { season_length },
133 model::SeasonalComponent::Multiplicative { season_length },
134 ],
135 }
136 }
137}
138
139impl fmt::Display for ComponentSpec {
140 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141 match self {
142 Self::None => f.write_char('N'),
143 Self::Additive => f.write_char('A'),
144 Self::Multiplicative => f.write_char('M'),
145 Self::Auto => f.write_char('Z'),
146 }
147 }
148}
149
150impl TryFrom<char> for ComponentSpec {
151 type Error = Error;
152
153 fn try_from(c: char) -> Result<Self> {
154 match c {
155 'N' => Ok(Self::None),
156 'A' => Ok(Self::Additive),
157 'M' => Ok(Self::Multiplicative),
158 'Z' => Ok(Self::Auto),
159 _ => Err(Error::InvalidComponentString(c)),
160 }
161 }
162}
163
/// How the damped-trend setting is handled during the model search.
#[derive(Clone, Debug, PartialEq, Eq)]
enum Damped {
    /// Try both a damped and an undamped trend.
    Auto,
    /// Only try the given setting.
    Fixed(bool),
}

impl Damped {
    /// Returns the damping flags to try.
    ///
    /// A fixed setting yields just that flag; [`Damped::Auto`] yields `true`
    /// followed by `false`.
    fn candidates(&self) -> &[bool] {
        if let Self::Fixed(flag) = self {
            // Borrow the stored flag as a one-element slice.
            return std::slice::from_ref(flag);
        }
        &[true, false]
    }
}
178
179#[derive(Debug, Clone, Copy)]
181pub struct AutoSpec {
182 pub error: ErrorSpec,
184 pub trend: ComponentSpec,
186 pub seasonal: ComponentSpec,
188}
189
190impl fmt::Display for AutoSpec {
191 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
192 self.error.fmt(f)?;
193 self.trend.fmt(f)?;
194 self.seasonal.fmt(f)?;
195 Ok(())
196 }
197}
198
199impl FromStr for AutoSpec {
200 type Err = Error;
201
202 fn from_str(s: &str) -> Result<Self> {
203 if s.len() != 3 {
204 return Err(Error::InvalidModelSpec(s.to_owned()));
205 }
206 let mut iter = s.chars();
207 let spec = Self {
208 error: ErrorSpec::try_from(iter.next().unwrap())?,
209 trend: ComponentSpec::try_from(iter.next().unwrap())?,
210 seasonal: ComponentSpec::try_from(iter.next().unwrap())?,
211 };
212 use ComponentSpec::*;
213 match spec {
214 Self {
215 error: ErrorSpec::Additive,
216 trend: _,
217 seasonal: Multiplicative,
218 }
219 | Self {
220 error: ErrorSpec::Additive,
221 trend: Multiplicative,
222 seasonal: _,
223 }
224 | Self {
225 error: ErrorSpec::Multiplicative,
226 trend: Multiplicative,
227 seasonal: Multiplicative,
228 } => Err(Error::InvalidModelSpec(s.to_owned())),
229 other => Ok(other),
230 }
231 }
232}
233
234impl TryFrom<&str> for AutoSpec {
235 type Error = Error;
236
237 fn try_from(s: &str) -> Result<Self> {
238 s.parse()
239 }
240}
241
242#[derive(Debug, Clone)]
244pub struct AutoETS {
245 spec: AutoSpec,
247 season_length: usize,
249
250 params: Params,
255
256 damped: Damped,
260
261 allow_multiplicative_trend: bool,
265
266 nmse: usize,
272
273 opt_crit: OptimizationCriteria,
277
278 max_iterations: usize,
282}
283
284impl AutoETS {
285 pub fn new(season_length: usize, spec: impl TryInto<AutoSpec, Error = Error>) -> Result<Self> {
305 let spec = spec.try_into()?;
306 Ok(Self::from_spec(season_length, spec))
307 }
308
309 pub fn from_spec(season_length: usize, spec: AutoSpec) -> Self {
311 let params = Params {
312 alpha: f64::NAN,
313 beta: f64::NAN,
314 gamma: f64::NAN,
315 phi: f64::NAN,
316 };
317 Self {
318 season_length,
319 spec,
320 params,
321 damped: Damped::Auto,
322 allow_multiplicative_trend: false,
323 nmse: 3,
324 opt_crit: OptimizationCriteria::Likelihood,
325 max_iterations: 2_000,
326 }
327 }
328
329 pub fn season_length(&self) -> usize {
331 self.season_length
332 }
333
334 pub fn spec(&self) -> AutoSpec {
336 self.spec
337 }
338
339 pub fn non_seasonal() -> Self {
343 Self::new(1, "ZZN").unwrap()
344 }
345
346 pub fn damped(mut self, damped: bool) -> Result<Self> {
348 if damped && self.spec.trend == ComponentSpec::None {
349 return Err(Error::InvalidModelSpec(format!(
350 "damped trend not allowed for model spec '{}'",
351 self.spec
352 )));
353 }
354 self.damped = Damped::Fixed(damped);
355 Ok(self)
356 }
357
358 pub fn alpha(mut self, alpha: f64) -> Self {
362 self.params.alpha = alpha;
363 self
364 }
365
366 pub fn beta(mut self, beta: f64) -> Self {
370 self.params.beta = beta;
371 self
372 }
373
374 pub fn gamma(mut self, gamma: f64) -> Self {
378 self.params.gamma = gamma;
379 self
380 }
381
382 pub fn phi(mut self, phi: f64) -> Self {
386 self.params.phi = phi;
387 self
388 }
389
390 pub fn allow_multiplicative_trend(mut self, allow: bool) -> Self {
394 self.allow_multiplicative_trend = allow;
395 self
396 }
397
398 fn valid_combination(
403 &self,
404 error: model::ErrorComponent,
405 trend: model::TrendComponent,
406 seasonal: model::SeasonalComponent,
407 damped: bool,
408 data_positive: bool,
409 ) -> bool {
410 use model::{ErrorComponent as EC, SeasonalComponent as SC, TrendComponent as TC};
411 match (error, trend, seasonal, damped) {
412 (_, TC::None, _, true) => false,
414 (EC::Additive, TC::Multiplicative, SC::Multiplicative { .. }, _) => false,
416 (EC::Multiplicative, TC::Multiplicative, SC::Additive { .. }, _) => false,
418 (EC::Multiplicative, _, _, _) if !data_positive => false,
419 (_, _, SC::Multiplicative { .. }, _) if !data_positive => false,
420 (
421 _,
422 _,
423 SC::Additive { season_length: 1 } | SC::Multiplicative { season_length: 1 },
424 _,
425 ) => false,
426 _ => true,
427 }
428 }
429
430 fn candidates(
435 &self,
436 ) -> impl Iterator<
437 Item = (
438 &model::ErrorComponent,
439 &model::TrendComponent,
440 model::SeasonalComponent,
441 &bool,
442 ),
443 > {
444 let error_candidates = self.spec.error.candidates();
445 let trend_candidates = self
446 .spec
447 .trend
448 .trend_candidates(self.allow_multiplicative_trend);
449 let season_candidates = self.spec.seasonal.seasonal_candidates(self.season_length);
450 let damped_candidates = self.damped.candidates();
451
452 itertools::iproduct!(
453 error_candidates,
454 trend_candidates,
455 season_candidates,
456 damped_candidates
457 )
458 }
459}
460
461impl Fit for AutoETS {
462 type Fitted = FittedAutoETS;
463 type Error = Error;
464 fn fit(&self, y: &[f64]) -> Result<Self::Fitted> {
474 let data_positive = y.iter().fold(f64::INFINITY, |a, &b| a.min(b)) > 0.0;
475 if self.spec.error == ErrorSpec::Multiplicative && !data_positive {
476 return Err(Error::InvalidModelSpec(format!(
477 "multiplicative error not allowed for model spec '{}' with non-positive data",
478 self.spec
479 )));
480 }
481
482 let n = y.len();
483 let mut npars = 2; if self.spec.trend.is_specified() {
485 npars += 2; }
487 if self.spec.seasonal.is_specified() {
488 npars += 2; }
490 if n <= npars + 4 {
491 return Err(Error::NotEnoughData);
492 }
493
494 let model = self
495 .candidates()
496 .filter_map(|(&error, &trend, season, &damped)| {
497 if self.valid_combination(error, trend, season, damped, data_positive) {
498 let model = Unfit::new(model::ModelType {
499 error,
500 trend,
501 season,
502 })
503 .damped(damped)
504 .params(self.params.clone())
505 .nmse(self.nmse)
506 .opt_crit(self.opt_crit)
507 .max_iterations(self.max_iterations)
508 .fit(y)
509 .ok()?;
510 if model.aicc().is_nan() {
511 None
512 } else {
513 Some(model)
514 }
515 } else {
516 None
517 }
518 })
519 .min_by(|a, b| {
520 a.aicc()
521 .partial_cmp(&b.aicc())
522 .expect("NaNs have already been filtered from the iterator")
523 })
524 .ok_or(Error::NoModelFound)?;
525 Ok(FittedAutoETS {
526 model,
527 training_data_size: n,
528 })
529 }
530}
531
532#[derive(Debug, Clone)]
536pub struct FittedAutoETS {
537 model: Model,
539
540 training_data_size: usize,
542}
543
544impl FittedAutoETS {
545 pub fn model(&self) -> &Model {
547 &self.model
548 }
549}
550
551impl Predict for FittedAutoETS {
552 type Error = Error;
553
554 fn training_data_size(&self) -> usize {
555 self.training_data_size
556 }
557
558 fn predict_inplace(&self, h: usize, level: Option<f64>, forecast: &mut Forecast) -> Result<()> {
567 self.model.predict_inplace(h, level, forecast)?;
568 Ok(())
569 }
570
571 fn predict_in_sample_inplace(&self, level: Option<f64>, forecast: &mut Forecast) -> Result<()> {
576 self.model.predict_in_sample_inplace(level, forecast)?;
577 Ok(())
578 }
579}
580
#[cfg(test)]
mod test {
    use augurs_core::Fit;
    use augurs_testing::{assert_within_pct, data::AIR_PASSENGERS};

    use super::{AutoETS, AutoSpec};
    use crate::{
        model::{ErrorComponent, SeasonalComponent, TrendComponent},
        Error,
    };

    /// Parses a table of three-letter specs and checks each outcome:
    /// an `N` error letter is an invalid error component, the rejected
    /// combinations are invalid model specs, and everything else parses.
    #[test]
    fn spec_from_str() {
        let cases = [
            "NNN", "NAN", "NAM", "NAZ", "NMN", "NMA", "NMM", "NMZ", "ANN", "AAN", "AAM", "AAZ",
            "AMN", "AMA", "AMM", "AMZ", "MNN", "MAN", "MAM", "MAZ", "MMN", "MMA", "MMM", "MMZ",
            "ZNN", "ZAN", "ZAM", "ZAZ", "ZMN", "ZMA", "ZMM", "ZMZ",
        ];
        for case in cases {
            let spec: Result<AutoSpec, Error> = case.try_into();
            let (error, rest) = case.split_at(1);
            let (trend, seasonal) = rest.split_at(1);
            match (error, trend, seasonal) {
                ("N", _, _) => {
                    assert!(
                        matches!(spec, Err(Error::InvalidErrorComponentString(_))),
                        "{spec:?}, case {case}"
                    );
                }
                ("A", "M", _) | ("A", _, "M") | ("M", "M", "M") => {
                    assert!(
                        matches!(spec, Err(Error::InvalidModelSpec(_))),
                        "{spec:?}, case {case}"
                    );
                }
                _ => {
                    // Report the offending case, like the other branches do;
                    // a bare `is_ok` failure would not say which spec broke.
                    assert!(spec.is_ok(), "{spec:?}, case {case}");
                }
            }
        }
    }

    /// Fits the non-seasonal automatic search to the air passengers data and
    /// checks the selected components and fit statistics.
    #[test]
    fn air_passengers_fit() {
        let auto = AutoETS::new(1, "ZZN").unwrap();
        let fit = auto.fit(AIR_PASSENGERS).expect("fit failed");
        assert_eq!(fit.model.model_type().error, ErrorComponent::Multiplicative);
        assert_eq!(fit.model.model_type().trend, TrendComponent::Additive);
        assert_eq!(fit.model.model_type().season, SeasonalComponent::None);
        assert_within_pct!(fit.model.log_likelihood(), -831.4883541595792, 0.01);
        assert_within_pct!(fit.model.aic(), 1672.9767083191584, 0.01);
    }
}