augurs_ets/
auto.rs

//! Automated ETS model selection.
//!
//! This module contains the [`AutoETS`] type, which can be used to automatically
//! select the best ETS model for a given time series.
//!
//! The search specification is controlled by the [`AutoSpec`] type. As a
//! convenience, [`AutoSpec`] implements [`FromStr`], so it can be parsed from a
//! string using the same format as R's `ets` function.
//!
//! # Example
//!
//! ```
//! use augurs_core::prelude::*;
//! use augurs_ets::{AutoETS, AutoSpec};
//!
//! // Create an `AutoETS` instance from a specification string.
//! // The `"ZZN"` specification means that the search should consider models
//! // with additive or multiplicative error, no trend or an additive (optionally
//! // damped) trend, and no seasonal component. Multiplicative trends are only
//! // tried when enabled with `AutoETS::allow_multiplicative_trend`.
//! let auto = AutoETS::new(1, "ZZN").expect("ZZN is a valid specification");
//! let data = (1..10).map(|x| x as f64).collect::<Vec<_>>();
//! let model = auto.fit(&data).expect("fit succeeds");
//! assert_eq!(&model.model().model_type().to_string(), "AAN");
//! ```
25
26use std::{
27    fmt::{self, Write},
28    str::FromStr,
29};
30
31use augurs_core::{Fit, Forecast, Predict};
32
33use crate::{
34    model::{self, Model, OptimizationCriteria, Params, Unfit},
35    Error, Result,
36};
37
38/// Error component search specification.
39#[derive(Debug, Clone, Copy, Eq, PartialEq)]
40pub enum ErrorSpec {
41    /// Only consider additive error models.
42    Additive,
43    /// Only consider multiplicative error models.
44    Multiplicative,
45    /// Consider both additive and multiplicative error models.
46    Auto,
47}
48
49impl ErrorSpec {
50    /// Returns the error component candidates for this specification.
51    fn candidates(&self) -> &[model::ErrorComponent] {
52        match self {
53            Self::Additive => &[model::ErrorComponent::Additive],
54            Self::Multiplicative => &[model::ErrorComponent::Multiplicative],
55            Self::Auto => &[
56                model::ErrorComponent::Additive,
57                model::ErrorComponent::Multiplicative,
58            ],
59        }
60    }
61}
62
63impl fmt::Display for ErrorSpec {
64    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
65        match self {
66            Self::Additive => f.write_char('A'),
67            Self::Multiplicative => f.write_char('M'),
68            Self::Auto => f.write_char('Z'),
69        }
70    }
71}
72
73impl TryFrom<char> for ErrorSpec {
74    type Error = Error;
75
76    fn try_from(c: char) -> Result<Self> {
77        match c {
78            'A' => Ok(Self::Additive),
79            'M' => Ok(Self::Multiplicative),
80            'Z' => Ok(Self::Auto),
81            _ => Err(Error::InvalidErrorComponentString(c)),
82        }
83    }
84}
85
86/// Trend and seasonal component search specification.
87#[derive(Debug, Clone, Copy, Eq, PartialEq)]
88pub enum ComponentSpec {
89    /// Only consider models without this component.
90    None,
91    /// Only consider additive models.
92    Additive,
93    /// Only consider multiplicative models.
94    Multiplicative,
95    /// Consider both additive and multiplicative models.
96    Auto,
97}
98
99impl ComponentSpec {
100    /// Returns `true` if this specification is not `None`.
101    fn is_specified(&self) -> bool {
102        matches!(self, Self::Additive | Self::Multiplicative)
103    }
104
105    /// Returns the trend component candidates for this specification.
106    fn trend_candidates(&self, auto_multiplicative: bool) -> &[model::TrendComponent] {
107        match (self, auto_multiplicative) {
108            (Self::None, _) => &[],
109            (Self::Additive, _) => &[model::TrendComponent::Additive],
110            (Self::Multiplicative, _) => &[model::TrendComponent::Multiplicative],
111            (Self::Auto, false) => &[model::TrendComponent::None, model::TrendComponent::Additive],
112            (Self::Auto, true) => &[
113                model::TrendComponent::None,
114                model::TrendComponent::Additive,
115                model::TrendComponent::Multiplicative,
116            ],
117        }
118    }
119
120    /// Returns the seasonal component candidates for this specification.
121    fn seasonal_candidates(&self, season_length: usize) -> Vec<model::SeasonalComponent> {
122        match self {
123            ComponentSpec::None => vec![model::SeasonalComponent::None],
124            ComponentSpec::Additive => {
125                vec![model::SeasonalComponent::Additive { season_length }]
126            }
127            ComponentSpec::Multiplicative => {
128                vec![model::SeasonalComponent::Multiplicative { season_length }]
129            }
130            ComponentSpec::Auto => vec![
131                model::SeasonalComponent::None,
132                model::SeasonalComponent::Additive { season_length },
133                model::SeasonalComponent::Multiplicative { season_length },
134            ],
135        }
136    }
137}
138
139impl fmt::Display for ComponentSpec {
140    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
141        match self {
142            Self::None => f.write_char('N'),
143            Self::Additive => f.write_char('A'),
144            Self::Multiplicative => f.write_char('M'),
145            Self::Auto => f.write_char('Z'),
146        }
147    }
148}
149
150impl TryFrom<char> for ComponentSpec {
151    type Error = Error;
152
153    fn try_from(c: char) -> Result<Self> {
154        match c {
155            'N' => Ok(Self::None),
156            'A' => Ok(Self::Additive),
157            'M' => Ok(Self::Multiplicative),
158            'Z' => Ok(Self::Auto),
159            _ => Err(Error::InvalidComponentString(c)),
160        }
161    }
162}
163
164#[derive(Debug, Clone, Eq, PartialEq)]
165enum Damped {
166    Auto,
167    Fixed(bool),
168}
169
170impl Damped {
171    fn candidates(&self) -> &[bool] {
172        match self {
173            Self::Auto => &[true, false],
174            Self::Fixed(x) => std::slice::from_ref(x),
175        }
176    }
177}
178
179/// Auto model search specification.
180#[derive(Debug, Clone, Copy)]
181pub struct AutoSpec {
182    /// The types of error components to consider.
183    pub error: ErrorSpec,
184    /// The types of trend components to consider.
185    pub trend: ComponentSpec,
186    /// The types of seasonal components to consider.
187    pub seasonal: ComponentSpec,
188}
189
190impl fmt::Display for AutoSpec {
191    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
192        self.error.fmt(f)?;
193        self.trend.fmt(f)?;
194        self.seasonal.fmt(f)?;
195        Ok(())
196    }
197}
198
199impl FromStr for AutoSpec {
200    type Err = Error;
201
202    fn from_str(s: &str) -> Result<Self> {
203        if s.len() != 3 {
204            return Err(Error::InvalidModelSpec(s.to_owned()));
205        }
206        let mut iter = s.chars();
207        let spec = Self {
208            error: ErrorSpec::try_from(iter.next().unwrap())?,
209            trend: ComponentSpec::try_from(iter.next().unwrap())?,
210            seasonal: ComponentSpec::try_from(iter.next().unwrap())?,
211        };
212        use ComponentSpec::*;
213        match spec {
214            Self {
215                error: ErrorSpec::Additive,
216                trend: _,
217                seasonal: Multiplicative,
218            }
219            | Self {
220                error: ErrorSpec::Additive,
221                trend: Multiplicative,
222                seasonal: _,
223            }
224            | Self {
225                error: ErrorSpec::Multiplicative,
226                trend: Multiplicative,
227                seasonal: Multiplicative,
228            } => Err(Error::InvalidModelSpec(s.to_owned())),
229            other => Ok(other),
230        }
231    }
232}
233
234impl TryFrom<&str> for AutoSpec {
235    type Error = Error;
236
237    fn try_from(s: &str) -> Result<Self> {
238        s.parse()
239    }
240}
241
242/// Automatic ETS model selection.
243#[derive(Debug, Clone)]
244pub struct AutoETS {
245    /// The model search specification.
246    spec: AutoSpec,
247    /// The number of observations per unit of time.
248    season_length: usize,
249
250    /// Explicit parameters to use when fitting the model.
251    ///
252    /// If any of these are `None` then they will be estimated
253    /// as part of the model fitting procedure.
254    params: Params,
255
256    /// Whether to use a damped trend.
257    ///
258    /// Defaults to trying both damped and non-damped trends.
259    damped: Damped,
260
261    /// Whether to allow multiplicative trend during automatic model selection.
262    ///
263    /// Defaults to `false`.
264    allow_multiplicative_trend: bool,
265
266    /// Number of steps over which to calculate the average MSE.
267    ///
268    /// Will be constrained to the range `[1, 30]`.
269    ///
270    /// Defaults to `3`.
271    nmse: usize,
272
273    /// The optimization criterion to use.
274    ///
275    /// Defaults to [`OptimizationCriteria::Likelihood`].
276    opt_crit: OptimizationCriteria,
277
278    /// The maximum number of iterations to use during optimization.
279    ///
280    /// Defaults to `2_000`.
281    max_iterations: usize,
282}
283
284impl AutoETS {
285    /// Create a new `AutoETS` model with the given period and model search specification string.
286    ///
287    /// The specification string should be of the form `XXX` where the first character is the error
288    /// component, the second is the trend component, and the third is the seasonal component. The
289    /// possible values for each component are:
290    ///
291    /// - `N` for no component
292    /// - `A` for additive
293    /// - `M` for multiplicative
294    /// - `Z` for automatic
295    ///
296    /// Using `Z` for any component will cause the model to try all possible values for that component.
297    /// For example, `ZAZ` will try all possible error and seasonal components, but only additive
298    /// trend components.
299    ///
300    /// # Errors
301    ///
302    /// An error will be returned if the specification string is not of the correct length or contains
303    /// invalid characters.
304    pub fn new(season_length: usize, spec: impl TryInto<AutoSpec, Error = Error>) -> Result<Self> {
305        let spec = spec.try_into()?;
306        Ok(Self::from_spec(season_length, spec))
307    }
308
309    /// Create a new `AutoETS` model with the given period and model search specification.
310    pub fn from_spec(season_length: usize, spec: AutoSpec) -> Self {
311        let params = Params {
312            alpha: f64::NAN,
313            beta: f64::NAN,
314            gamma: f64::NAN,
315            phi: f64::NAN,
316        };
317        Self {
318            season_length,
319            spec,
320            params,
321            damped: Damped::Auto,
322            allow_multiplicative_trend: false,
323            nmse: 3,
324            opt_crit: OptimizationCriteria::Likelihood,
325            max_iterations: 2_000,
326        }
327    }
328
329    /// Get the season length of the model.
330    pub fn season_length(&self) -> usize {
331        self.season_length
332    }
333
334    /// Get the search specification.
335    pub fn spec(&self) -> AutoSpec {
336        self.spec
337    }
338
339    /// Create a new `AutoETS` model search without any seasonal components.
340    ///
341    /// Equivalent to `AutoETS::new(1, "ZZN")`.
342    pub fn non_seasonal() -> Self {
343        Self::new(1, "ZZN").unwrap()
344    }
345
346    /// Fix the search to consider only damped or undamped trend.
347    pub fn damped(mut self, damped: bool) -> Result<Self> {
348        if damped && self.spec.trend == ComponentSpec::None {
349            return Err(Error::InvalidModelSpec(format!(
350                "damped trend not allowed for model spec '{}'",
351                self.spec
352            )));
353        }
354        self.damped = Damped::Fixed(damped);
355        Ok(self)
356    }
357
358    /// Set the value of `alpha` to use when fitting the model.
359    ///
360    /// See the docs for [`Params::alpha`] for more details on `alpha`.
361    pub fn alpha(mut self, alpha: f64) -> Self {
362        self.params.alpha = alpha;
363        self
364    }
365
366    /// Set the value of `beta` to use when fitting the model.
367    ///
368    /// See the docs for [`Params::beta`] for more details on `beta`.
369    pub fn beta(mut self, beta: f64) -> Self {
370        self.params.beta = beta;
371        self
372    }
373
374    /// Set the value of `gamma` to use when fitting the model.
375    ///
376    /// See the docs for [`Params::gamma`] for more details on `gamma`.
377    pub fn gamma(mut self, gamma: f64) -> Self {
378        self.params.gamma = gamma;
379        self
380    }
381
382    /// Set the value of `phi` to use when fitting the model.
383    ///
384    /// See the docs for [`Params::phi`] for more details on `phi`.
385    pub fn phi(mut self, phi: f64) -> Self {
386        self.params.phi = phi;
387        self
388    }
389
390    /// Include models with multiplicative trend during automatic model selection.
391    ///
392    /// By default, models with multiplicative trend are excluded from the search space.
393    pub fn allow_multiplicative_trend(mut self, allow: bool) -> Self {
394        self.allow_multiplicative_trend = allow;
395        self
396    }
397
398    /// Check whether a model combination is valid.
399    ///
400    /// Note that we currently enforce the 'restricted' mode of R's `ets` package
401    /// which disallows models with infinite variance.
402    fn valid_combination(
403        &self,
404        error: model::ErrorComponent,
405        trend: model::TrendComponent,
406        seasonal: model::SeasonalComponent,
407        damped: bool,
408        data_positive: bool,
409    ) -> bool {
410        use model::{ErrorComponent as EC, SeasonalComponent as SC, TrendComponent as TC};
411        match (error, trend, seasonal, damped) {
412            // Can't have no trend and damped trend.
413            (_, TC::None, _, true) => false,
414            // Restricted mode disallows additive error with multiplicative trend and seasonality.
415            (EC::Additive, TC::Multiplicative, SC::Multiplicative { .. }, _) => false,
416            // Restricted mode disallows multiplicative error with multiplicative trend and additive seasonality;
417            (EC::Multiplicative, TC::Multiplicative, SC::Additive { .. }, _) => false,
418            (EC::Multiplicative, _, _, _) if !data_positive => false,
419            (_, _, SC::Multiplicative { .. }, _) if !data_positive => false,
420            (
421                _,
422                _,
423                SC::Additive { season_length: 1 } | SC::Multiplicative { season_length: 1 },
424                _,
425            ) => false,
426            _ => true,
427        }
428    }
429
430    /// Return an iterator over all model combinations.
431    ///
432    /// Note that this does not check that the model combinations are valid;
433    /// some knowledge of the data is required for that.
434    fn candidates(
435        &self,
436    ) -> impl Iterator<
437        Item = (
438            &model::ErrorComponent,
439            &model::TrendComponent,
440            model::SeasonalComponent,
441            &bool,
442        ),
443    > {
444        let error_candidates = self.spec.error.candidates();
445        let trend_candidates = self
446            .spec
447            .trend
448            .trend_candidates(self.allow_multiplicative_trend);
449        let season_candidates = self.spec.seasonal.seasonal_candidates(self.season_length);
450        let damped_candidates = self.damped.candidates();
451
452        itertools::iproduct!(
453            error_candidates,
454            trend_candidates,
455            season_candidates,
456            damped_candidates
457        )
458    }
459}
460
461impl Fit for AutoETS {
462    type Fitted = FittedAutoETS;
463    type Error = Error;
464    /// Search for the best model, fitting it to the data.
465    ///
466    /// The model is stored on the `AutoETS` struct and can be retrieved with
467    /// the `model` method. It is also returned by this function.
468    ///
469    /// # Errors
470    ///
471    /// If no model can be found, or if any parameters are invalid, this function
472    /// returns an error.
473    fn fit(&self, y: &[f64]) -> Result<Self::Fitted> {
474        let data_positive = y.iter().fold(f64::INFINITY, |a, &b| a.min(b)) > 0.0;
475        if self.spec.error == ErrorSpec::Multiplicative && !data_positive {
476            return Err(Error::InvalidModelSpec(format!(
477                "multiplicative error not allowed for model spec '{}' with non-positive data",
478                self.spec
479            )));
480        }
481
482        let n = y.len();
483        let mut npars = 2; // alpha + l0
484        if self.spec.trend.is_specified() {
485            npars += 2; // beta + b0
486        }
487        if self.spec.seasonal.is_specified() {
488            npars += 2; // gamma + s
489        }
490        if n <= npars + 4 {
491            return Err(Error::NotEnoughData);
492        }
493
494        let model = self
495            .candidates()
496            .filter_map(|(&error, &trend, season, &damped)| {
497                if self.valid_combination(error, trend, season, damped, data_positive) {
498                    let model = Unfit::new(model::ModelType {
499                        error,
500                        trend,
501                        season,
502                    })
503                    .damped(damped)
504                    .params(self.params.clone())
505                    .nmse(self.nmse)
506                    .opt_crit(self.opt_crit)
507                    .max_iterations(self.max_iterations)
508                    .fit(y)
509                    .ok()?;
510                    if model.aicc().is_nan() {
511                        None
512                    } else {
513                        Some(model)
514                    }
515                } else {
516                    None
517                }
518            })
519            .min_by(|a, b| {
520                a.aicc()
521                    .partial_cmp(&b.aicc())
522                    .expect("NaNs have already been filtered from the iterator")
523            })
524            .ok_or(Error::NoModelFound)?;
525        Ok(FittedAutoETS {
526            model,
527            training_data_size: n,
528        })
529    }
530}
531
532/// A fitted [`AutoETS`] model.
533///
534/// This type can be used to obtain predictions using the [`Predict`] trait.
535#[derive(Debug, Clone)]
536pub struct FittedAutoETS {
537    /// The model that was selected.
538    model: Model,
539
540    /// The number of observations in the training data.
541    training_data_size: usize,
542}
543
544impl FittedAutoETS {
545    /// Get the model that was selected.
546    pub fn model(&self) -> &Model {
547        &self.model
548    }
549}
550
551impl Predict for FittedAutoETS {
552    type Error = Error;
553
554    fn training_data_size(&self) -> usize {
555        self.training_data_size
556    }
557
558    /// Predict the next `horizon` values using the best model, optionally including
559    /// prediction intervals at the specified level.
560    ///
561    /// `level` should be a float between 0 and 1 representing the confidence level.
562    ///
563    /// # Errors
564    ///
565    /// This function will return an error if no model has been fit yet (using [`AutoETS::fit`]).
566    fn predict_inplace(&self, h: usize, level: Option<f64>, forecast: &mut Forecast) -> Result<()> {
567        self.model.predict_inplace(h, level, forecast)?;
568        Ok(())
569    }
570
571    /// Return the in-sample predictions using the best model, optionally including
572    /// prediction intervals at the specified level.
573    ///
574    /// `level` should be a float between 0 and 1 representing the confidence level.`
575    fn predict_in_sample_inplace(&self, level: Option<f64>, forecast: &mut Forecast) -> Result<()> {
576        self.model.predict_in_sample_inplace(level, forecast)?;
577        Ok(())
578    }
579}
580
581#[cfg(test)]
582mod test {
583    use augurs_core::Fit;
584    use augurs_testing::{assert_within_pct, data::AIR_PASSENGERS};
585
586    use super::{AutoETS, AutoSpec};
587    use crate::{
588        model::{ErrorComponent, SeasonalComponent, TrendComponent},
589        Error,
590    };
591
592    #[test]
593    fn spec_from_str() {
594        let cases = [
595            "NNN", "NAN", "NAM", "NAZ", "NMN", "NMA", "NMM", "NMZ", "ANN", "AAN", "AAM", "AAZ",
596            "AMN", "AMA", "AMM", "AMZ", "MNN", "MAN", "MAM", "MAZ", "MMN", "MMA", "MMM", "MMZ",
597            "ZNN", "ZAN", "ZAM", "ZAZ", "ZMN", "ZMA", "ZMM", "ZMZ",
598        ];
599        for case in cases {
600            let spec: Result<AutoSpec, Error> = case.try_into();
601            let (error, rest) = case.split_at(1);
602            let (trend, seasonal) = rest.split_at(1);
603            match (error, trend, seasonal) {
604                ("N", _, _) => {
605                    assert!(
606                        matches!(spec, Err(Error::InvalidErrorComponentString(_))),
607                        "{spec:?}, case {case}"
608                    );
609                }
610                ("A", "M", _) | ("A", _, "M") | ("M", "M", "M") => {
611                    assert!(
612                        matches!(spec, Err(Error::InvalidModelSpec(_))),
613                        "{spec:?}, case {case}"
614                    );
615                }
616                _ => {
617                    assert!(spec.is_ok());
618                }
619            }
620        }
621    }
622
623    #[test]
624    fn air_passengers_fit() {
625        let auto = AutoETS::new(1, "ZZN").unwrap();
626        let fit = auto.fit(AIR_PASSENGERS).expect("fit failed");
627        assert_eq!(fit.model.model_type().error, ErrorComponent::Multiplicative);
628        assert_eq!(fit.model.model_type().trend, TrendComponent::Additive);
629        assert_eq!(fit.model.model_type().season, SeasonalComponent::None);
630        assert_within_pct!(fit.model.log_likelihood(), -831.4883541595792, 0.01);
631        assert_within_pct!(fit.model.aic(), 1672.9767083191584, 0.01);
632    }
633}