augurs_ets/
model.rs

//! A single model of the ETS family.
//!
//! This module contains [`Unfit`], the specification of a single ETS model,
//! which is fit to data to produce a [`Model`].
4
5use std::fmt::{self, Write};
6
7use augurs_core::{ForecastIntervals, Predict};
8use itertools::Itertools;
9use nalgebra::{DMatrix, DVector};
10use rand_distr::{Distribution, Normal};
11use tracing::instrument;
12
13use crate::{
14    ets::{Ets, FitState},
15    stat::VarExt,
16    Error,
17};
18
/// How the error term enters the model.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum ErrorComponent {
    /// Errors enter additively (`A` in the model code).
    Additive,
    /// Errors enter multiplicatively (`M` in the model code).
    Multiplicative,
}

impl fmt::Display for ErrorComponent {
    /// Writes the single-letter code used in ETS model names.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let code = match self {
            ErrorComponent::Additive => 'A',
            ErrorComponent::Multiplicative => 'M',
        };
        f.write_char(code)
    }
}
36
/// How (and whether) a trend enters the model.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum TrendComponent {
    /// The model has no trend (`N` in the model code).
    None,
    /// The trend enters additively (`A` in the model code).
    Additive,
    /// The trend enters multiplicatively (`M` in the model code).
    Multiplicative,
}

impl TrendComponent {
    /// Returns `true` unless this is [`TrendComponent::None`].
    pub fn included(&self) -> bool {
        !matches!(self, TrendComponent::None)
    }
}

impl fmt::Display for TrendComponent {
    /// Writes the single-letter code used in ETS model names.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let code = match self {
            TrendComponent::None => 'N',
            TrendComponent::Additive => 'A',
            TrendComponent::Multiplicative => 'M',
        };
        f.write_char(code)
    }
}
64
/// The type of seasonal component included in the model.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum SeasonalComponent {
    /// No seasonal component (`N` in the model code).
    None,
    /// Additive seasonal component (`A` in the model code).
    Additive {
        /// The number of observations in a seasonal cycle.
        ///
        /// This was called `m` in the original `ets` R code.
        season_length: usize,
    },
    /// Multiplicative seasonal component (`M` in the model code).
    Multiplicative {
        /// The number of observations in a seasonal cycle.
        ///
        /// This was called `m` in the original `ets` R code.
        season_length: usize,
    },
}

impl SeasonalComponent {
    /// Whether this component will be included in a model, i.e. whether it is
    /// anything other than [`SeasonalComponent::None`].
    pub fn included(&self) -> bool {
        !matches!(self, SeasonalComponent::None)
    }

    /// The number of observations in a seasonal cycle.
    ///
    /// This is `1` for [`SeasonalComponent::None`], otherwise the variant's
    /// `season_length`.
    pub fn season_length(&self) -> usize {
        match *self {
            SeasonalComponent::None => 1,
            SeasonalComponent::Additive { season_length }
            | SeasonalComponent::Multiplicative { season_length } => season_length,
        }
    }
}

impl fmt::Display for SeasonalComponent {
    /// Writes the single-letter code used in ETS model names.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let code = match self {
            SeasonalComponent::None => 'N',
            SeasonalComponent::Additive { .. } => 'A',
            SeasonalComponent::Multiplicative { .. } => 'M',
        };
        f.write_char(code)
    }
}
114
115/// The upper and lower bounds to use with [`Bounds::Usual`] and [`Bounds::Both`].
116#[derive(Clone, PartialEq, Debug)]
117pub struct UpperLowerBounds {
118    lower: [f64; 4],
119    upper: [f64; 4],
120}
121
122impl UpperLowerBounds {
123    /// Create a new set of bounds.
124    ///
125    /// # Errors
126    ///
127    /// Returns an error if any of the lower bounds are greater than the
128    /// corresponding upper bounds.
129    pub fn new(lower: [f64; 4], upper: [f64; 4]) -> Result<Self, Error> {
130        if lower.iter().zip(&upper).any(|(l, u)| l > u) {
131            Err(Error::InconsistentBounds)
132        } else {
133            Ok(Self { lower, upper })
134        }
135    }
136}
137
138impl Default for UpperLowerBounds {
139    fn default() -> Self {
140        Self {
141            lower: [0.0001, 0.0001, 0.0001, 0.8],
142            upper: [0.9999, 0.9999, 0.9999, 0.98],
143        }
144    }
145}
146
147/// The type of parameter space to impose.
148#[derive(Clone, Debug)]
149pub enum Bounds {
150    /// All parameters must lie in the admissible space.
151    Admissible,
152    /// All parameters must lie between specified lower and upper bounds.
153    Usual(UpperLowerBounds),
154    /// The intersection of `Admissible` and `Usual`. This is the default.
155    Both(UpperLowerBounds),
156}
157
158impl Bounds {
159    fn for_optimizer(
160        &self,
161        opt_params: &OptimizeParams,
162        n_states: usize,
163    ) -> Option<(Vec<f64>, Vec<f64>)> {
164        match self {
165            Self::Admissible => None,
166            Self::Usual(bounds) | Self::Both(bounds) => {
167                let n_params = opt_params.n_included();
168                let mut lower = Vec::with_capacity(n_params + n_states);
169                let mut upper = Vec::with_capacity(n_params + n_states);
170                if opt_params.alpha {
171                    lower.push(bounds.lower[0]);
172                    upper.push(bounds.upper[0]);
173                }
174                if opt_params.beta {
175                    lower.push(bounds.lower[1]);
176                    upper.push(bounds.upper[1]);
177                }
178                if opt_params.gamma {
179                    lower.push(bounds.lower[2]);
180                    upper.push(bounds.upper[2]);
181                }
182                if opt_params.phi {
183                    lower.push(bounds.lower[3]);
184                    upper.push(bounds.upper[3]);
185                }
186                for _ in 0..n_states {
187                    lower.push(f64::NEG_INFINITY);
188                    upper.push(f64::INFINITY);
189                }
190                Some((lower, upper))
191            }
192        }
193    }
194}
195
196impl Default for Bounds {
197    fn default() -> Self {
198        Self::Both(UpperLowerBounds::default())
199    }
200}
201
/// The criterion optimized when fitting the model.
///
/// Defaults to [`OptimizationCriteria::Likelihood`].
#[derive(Debug, Copy, Clone, Default)]
pub enum OptimizationCriteria {
    /// Log-likelihood (the default).
    #[default]
    Likelihood,
    /// Mean squared error.
    MSE,
    /// Average mean squared error over the first `nmse` forecast horizons.
    AMSE,
    /// Standard deviation of the residuals.
    Sigma,
    /// Mean absolute error.
    MAE,
}
219
220/// The type of ETS model.
221///
222/// ETS models are defined by the type of error, trend, and seasonal components
223/// included in the model. These components can be excluded, included additively,
224/// or included multiplicatively. Some combinations of components are not
225/// allowed due to identifiability issues; these will be excluded
226/// from the search space of [`crate::AutoETS`].
227#[derive(Debug, Clone, Copy)]
228pub struct ModelType {
229    /// The type of error component.
230    pub error: ErrorComponent,
231    /// The type of trend component.
232    pub trend: TrendComponent,
233    /// The type of seasonal component.
234    pub season: SeasonalComponent,
235}
236
237impl fmt::Display for ModelType {
238    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
239        self.error.fmt(f)?;
240        self.trend.fmt(f)?;
241        self.season.fmt(f)?;
242        Ok(())
243    }
244}
245
/// The smoothing and damping parameters of an ETS model.
///
/// A value of `f64::NAN` means the parameter is unspecified; see
/// [`Params::default`].
#[derive(Debug, Clone)]
pub struct Params {
    /// Smoothing parameter for the level.
    ///
    /// With `alpha = 0` the level never changes over time; with `alpha = 1`
    /// the level updates much like a random walk.
    pub alpha: f64,
    /// Smoothing parameter for the slope.
    ///
    /// With `beta = 0` the slope never changes over time; with `beta = 1` the
    /// slope keeps no memory of past slopes.
    pub beta: f64,
    /// Smoothing parameter for the seasonal pattern.
    ///
    /// With `gamma = 0` the seasonal pattern never changes over time; with
    /// `gamma = 1` it keeps no memory of past seasonal periods.
    pub gamma: f64,
    /// Damping parameter for the slope.
    ///
    /// With `phi = 0` the slope is damped away immediately (no slope); with
    /// `phi = 1` the slope is not damped at all.
    pub phi: f64,
}

impl Default for Params {
    /// Every parameter unspecified (`NaN`), leaving them to the optimizer.
    fn default() -> Self {
        let unspecified = f64::NAN;
        Self {
            alpha: unspecified,
            beta: unspecified,
            gamma: unspecified,
            phi: unspecified,
        }
    }
}
279
/// Flags selecting which smoothing parameters the Nelder-Mead optimizer
/// estimates.
///
/// A parameter the caller specified explicitly is meant to be excluded from
/// the search and held at the given value; the remaining ones are optimized.
/// When the caller fixes nothing, every parameter relevant to the model is
/// optimized (`gamma` only for seasonal models, `phi` only for damped trends,
/// and so on).
#[derive(Debug, Default, Clone)]
pub(crate) struct OptimizeParams {
    /// Whether `alpha` is optimized.
    pub alpha: bool,
    /// Whether `beta` is optimized.
    pub beta: bool,
    /// Whether `gamma` is optimized.
    pub gamma: bool,
    /// Whether `phi` is optimized.
    pub phi: bool,
}

impl OptimizeParams {
    /// The number of parameters flagged for optimization.
    pub(crate) fn n_included(&self) -> usize {
        [self.alpha, self.beta, self.gamma, self.phi]
            .iter()
            .filter(|&&flag| flag)
            .count()
    }
}
306
/// Returns `x` unless it is NaN, in which case `default` is returned instead.
fn not_nan_or(x: f64, default: f64) -> f64 {
    Some(x).filter(|value| !value.is_nan()).unwrap_or(default)
}
315
316/// An ETS model that has not been fit.
317#[derive(Debug, Clone)]
318pub struct Unfit {
319    /// The type of model to be used.
320    model_type: ModelType,
321
322    /// Whether or not the model uses a damped trend.
323    ///
324    /// Defaults to `false`.
325    damped: bool,
326
327    /// Number of steps over which to calculate the average MSE.
328    ///
329    /// Will be constrained to the range [1, 30].
330    ///
331    /// Defaults to 3.
332    nmse: usize,
333
334    /// The bounds on parameters.
335    ///
336    /// Defaults to [`Bounds::Both`] with lower limits of
337    /// `[0.0001, 0.0001, 0.0001, 0.8]` and upper limits of
338    /// `[0.9999, 0.9999, 0.9999, 0.98]`.
339    bounds: Bounds,
340
341    /// The parameters of the model.
342    ///
343    /// Defaults to [`Params::default()`], meaning the parameters will be
344    /// determined and optimized by the optimizer.
345    params: Params,
346
347    /// Optimization criteria to use.
348    ///
349    /// Defaults to [`OptimizationCriteria::Likelihood`].
350    opt_crit: OptimizationCriteria,
351
352    /// Maximum number of iterations to use in the optimizer.
353    ///
354    /// Defaults to 2,000.
355    max_iter: usize,
356}
357
358impl Unfit {
359    /// Creates a new ETS model with the given type.
360    pub fn new(model_type: ModelType) -> Self {
361        Self {
362            model_type,
363            damped: false,
364            bounds: Bounds::default(),
365            nmse: 3,
366            params: Params::default(),
367            opt_crit: OptimizationCriteria::default(),
368            max_iter: 2_000,
369        }
370    }
371
372    /// Set the parameters of the model.
373    ///
374    /// To leave parameters unspecified, leave them set to `f64::NAN`.
375    pub fn params(self, params: Params) -> Self {
376        Self { params, ..self }
377    }
378
379    /// Set the number of steps over which to calculate the average MSE.
380    pub fn nmse(self, nmse: usize) -> Self {
381        Self { nmse, ..self }
382    }
383
384    /// Set the optimization criteria to use.
385    pub fn opt_crit(self, opt_crit: OptimizationCriteria) -> Self {
386        Self { opt_crit, ..self }
387    }
388
389    /// Set the maximum number of iterations to use in the optimizer.
390    pub fn max_iterations(self, max_iterations: usize) -> Self {
391        Self {
392            max_iter: max_iterations,
393            ..self
394        }
395    }
396
397    /// Set the model to use a damped trend or not.
398    pub fn damped(self, damped: bool) -> Self {
399        Self { damped, ..self }
400    }
401
402    /// Select a sensible initial value for the `alpha` parameter.
403    fn select_alpha(lower: &[f64; 4], upper: &[f64; 4], alpha: f64, m: usize) -> f64 {
404        if alpha.is_nan() {
405            let mut alpha = lower[0] + 0.2 * (upper[0] - lower[0]) / m as f64;
406            if !(0.0..=1.0).contains(&alpha) {
407                alpha = lower[0] + 2e-3;
408            }
409            alpha
410        } else {
411            alpha
412        }
413    }
414
415    /// Select a sensible initial value for the `beta` parameter.
416    fn select_beta(
417        lower: &[f64; 4],
418        upper: &mut [f64; 4],
419        trend: TrendComponent,
420        alpha: f64,
421        beta: f64,
422    ) -> f64 {
423        if trend != TrendComponent::None && beta.is_nan() {
424            // Ensure beta < alpha.
425            upper[1] = upper[1].min(alpha);
426            let mut beta = lower[1] + 0.1 * (upper[1] - lower[1]);
427            if beta < 0.0 || beta > alpha {
428                beta = alpha - 1e-3;
429            }
430            beta
431        } else {
432            beta
433        }
434    }
435
436    /// Select a sensible initial value for the `gamma` parameter.
437    fn select_gamma(
438        lower: &[f64; 4],
439        upper: &mut [f64; 4],
440        season: SeasonalComponent,
441        alpha: f64,
442        gamma: f64,
443    ) -> f64 {
444        if season != SeasonalComponent::None && gamma.is_nan() {
445            upper[2] = upper[2].min(1.0 - alpha);
446            let mut gamma = lower[2] + 0.05 * (upper[2] - lower[2]);
447            if gamma < 0.0 || gamma > 1.0 - alpha {
448                gamma = 1.0 - alpha - 1e-3;
449            }
450            gamma
451        } else {
452            gamma
453        }
454    }
455
456    /// Select a sensible initial value for the `phi` parameter.
457    fn select_phi(lower: &[f64; 4], upper: &[f64; 4], damped: bool, phi: f64) -> f64 {
458        if damped && phi.is_nan() {
459            let mut phi = lower[3] + 0.99 * (upper[3] - lower[3]);
460            if !(0.0..=1.0).contains(&phi) {
461                phi = upper[3] - 1e-3;
462            }
463            phi
464        } else {
465            phi
466        }
467    }
468
469    /// Initialize the parameters for the model.
470    fn initial_params(&mut self) -> Params {
471        // These dummy parameters aren't used, they're just here to placate the borrow checker.
472        let (mut dummy_lower, mut dummy_upper) = ([0.0; 4], [1e-3; 4]);
473        let (lower, upper) = match &mut self.bounds {
474            Bounds::Admissible => (&mut dummy_lower, &mut dummy_upper),
475            Bounds::Usual(UpperLowerBounds { lower, upper }) => (lower, upper),
476            Bounds::Both(UpperLowerBounds { lower, upper }) => (lower, upper),
477        };
478        let alpha = Self::select_alpha(
479            lower,
480            upper,
481            self.params.alpha,
482            self.model_type.season.season_length(),
483        );
484        let beta = Self::select_beta(lower, upper, self.model_type.trend, alpha, self.params.beta);
485        let gamma = Self::select_gamma(
486            lower,
487            upper,
488            self.model_type.season,
489            alpha,
490            self.params.gamma,
491        );
492        let phi = Self::select_phi(lower, upper, self.damped, self.params.phi);
493        Params {
494            alpha,
495            beta,
496            gamma,
497            phi,
498        }
499    }
500
501    /// Initialize the state for the model.
502    fn initial_state(&self, y: &[f64]) -> Result<Vec<f64>, Error> {
503        let n = y.len();
504        let (m, y_sa) = if self.model_type.season == SeasonalComponent::None {
505            (1, y.to_vec())
506        } else {
507            unimplemented!("seasonal component not implemented yet")
508            // if n < 4 {
509            //     return Err(Error::NotEnoughData);
510            // }
511            // let y_d = if n < 3 * self.m {
512            //     let fourier_y = fourier(self.y, &[self.m], &[1]);
513            //     // TODO: remove these copies.
514            //     let mut fourier_X = DMatrix::from_element(n, 4, f64::NAN);
515            //     fourier_X.set_column(0, &DVector::from_element(n, 1.0));
516            //     fourier_X.set_column(1, &DVector::from_iterator(n, (0..n).map(|x| x as f64)));
517            //     fourier_X.set_column(2, &fourier_y.column(0));
518            //     fourier_X.set_column(3, &fourier_y.column(1));
519            //     let coefs = lstsq(&fourier_X, &self.y, 1e-6)?;
520            //     if self.season == ComponentSpec::Additive {
521            //         let mut y_d = self.y.clone();
522            //         for (i, &x) in fourier_X.column(2).iter().enumerate() {
523            //             y_d[i] -= coefs[2] * x;
524            //         }
525            //         for (i, &x) in fourier_X.column(3).iter().enumerate() {
526            //             y_d[i] -= coefs[3] * x;
527            //         }
528            //         y_d
529            //     } else {
530            //         let mut y_d = self.y.clone();
531            //         for (i, &x) in fourier_X.column(2).iter().enumerate() {
532            //             y_d[i] /= coefs[2] * x;
533            //         }
534            //         for (i, &x) in fourier_X.column(3).iter().enumerate() {
535            //             y_d[i] /= coefs[3] * x;
536            //         }
537            //         y_d
538            //     }
539            // } else {
540            //     seasonal_decompose(
541            //         self.y,
542            //         self.m,
543            //         if self.season == ComponentSpec::Additive {
544            //             ModelType::Additive
545            //         } else {
546            //             ModelType::Multiplicative
547            //         },
548            //     )
549            // };
550        };
551        let max_n = 10.clamp(m, n);
552        match self.model_type.trend {
553            TrendComponent::None => {
554                let l0 = y_sa.iter().take(max_n).sum::<f64>() / max_n as f64;
555                Ok(vec![l0])
556            }
557            _ => {
558                #[allow(non_snake_case)]
559                let X = DMatrix::from_iterator(
560                    max_n,
561                    2,
562                    std::iter::repeat_n(1.0, max_n)
563                        .take(max_n)
564                        .chain((1..(max_n + 1)).map(|x| x as f64)),
565                );
566                let y = DVector::from_row_slice(&y_sa[..max_n]);
567                let lstsq = lstsq::lstsq(&X, &y, f64::EPSILON).map_err(Error::LeastSquares)?;
568                let (l, b) = (lstsq.solution[0], lstsq.solution[1]);
569                if self.model_type.trend == TrendComponent::Additive {
570                    let (mut l0, mut b0) = (l, b);
571                    if (l0 + b0).abs() < 1e-8 {
572                        l0 *= 1.0 + 1e-3;
573                        b0 *= 1.0 + 1e-3;
574                    }
575                    Ok(vec![l0, b0])
576                } else {
577                    let mut l0 = l + b;
578                    if l0.abs() < 1e-8 {
579                        l0 *= 1.0 + 1e-3;
580                    }
581                    let mut b0: f64 = (l + 2.0 * b) / l0;
582                    let div = if b0.abs() < 1e-8 { 1e-8 } else { b0 };
583                    l0 /= div;
584                    if b0.abs() > 1e10 {
585                        b0 = b0.signum() * 1e10;
586                    }
587                    if l0 < 1e-8 || b0 < 1e-8 {
588                        // simple linear approximation didn't work
589                        l0 = y_sa[0].max(1e-3);
590                        let div = if y_sa[0].abs() < 1e-8 { 1e-8 } else { y_sa[0] };
591                        b0 = (y_sa[1] / div).max(1e-3);
592                    }
593                    Ok(vec![l0, b0])
594                }
595            }
596        }
597    }
598
599    /// Fit the ETS model to the data, returning a fitted [`Model`].
600    #[instrument(skip_all)]
601    pub fn fit(mut self, y: &[f64]) -> Result<Model, Error> {
602        self.nmse = self.nmse.min(30);
603        let season_length = self.model_type.season.season_length();
604
605        let n_states = season_length * self.model_type.season.included() as usize
606            + 1
607            + self.model_type.trend.included() as usize;
608
609        // Store the original parameters.
610        let par_noopt = self.params.clone();
611        let par_ = self.initial_params();
612        let alpha = not_nan_or(par_.alpha, par_noopt.alpha);
613        let beta = not_nan_or(par_.beta, par_noopt.beta);
614        let gamma = not_nan_or(par_.gamma, par_noopt.gamma);
615        let phi = not_nan_or(par_.phi, par_noopt.phi);
616        if !check_params(
617            &self.bounds,
618            season_length,
619            Params {
620                alpha,
621                beta,
622                gamma,
623                phi,
624            },
625        ) {
626            return Err(Error::ParamsOutOfRange);
627        }
628
629        let initial_state = self.initial_state(y)?;
630        let param_arr = [alpha, beta, gamma, phi];
631
632        let x0: Vec<_> = param_arr
633            .iter()
634            .copied()
635            .filter(|&x| !x.is_nan())
636            .chain(initial_state.iter().copied())
637            .collect();
638        let np_ = x0.len();
639        if np_ >= y.len() - 1 {
640            return Err(Error::NotEnoughData);
641        }
642        let opt_params = OptimizeParams {
643            alpha: !alpha.is_nan(),
644            beta: !beta.is_nan(),
645            gamma: !gamma.is_nan(),
646            phi: !phi.is_nan(),
647        };
648
649        let params = Params {
650            alpha,
651            beta: if self.model_type.trend.included() {
652                beta
653            } else {
654                0.0
655            },
656            phi: if self.damped { phi } else { 1.0 },
657            gamma: if self.model_type.season.included() {
658                gamma
659            } else {
660                0.0
661            },
662        };
663
664        let opt_bounds = self.bounds.for_optimizer(&opt_params, n_states);
665        // Construct the problem.
666        let ets = Ets::new(
667            self.model_type,
668            self.damped,
669            self.nmse,
670            n_states,
671            params,
672            opt_params,
673            self.opt_crit,
674        );
675        let mut problem = ETSProblem::new(y, ets);
676        // Set up the input simplex for Nelder-Mead.
677        let simplex = self.param_vecs(x0, opt_bounds.as_ref());
678        // Run Nelder-Mead.
679        let best_params = self.nelder_mead(&mut problem, simplex, opt_bounds.as_ref());
680
681        // Rerun the model with the best parameters.
682        problem.amse.fill(0.0);
683        problem.denom.fill(0.0);
684        let fit = problem.ets.pegels_resid_in(
685            y,
686            &best_params,
687            problem.x,
688            problem.ets.params.clone(),
689            problem.residuals,
690            problem.forecasts,
691            problem.amse,
692            problem.denom,
693        );
694        let sigma_squared = y
695            .iter()
696            .zip(fit.fitted())
697            .map(|(y, f)| (y - f).powi(2))
698            .sum::<f64>()
699            / (y.len() - fit.n_params() - 1) as f64;
700        Ok(Model::new(problem.ets, fit, sigma_squared.sqrt()))
701    }
702
703    /// Generate the initial simplex.
704    ///
705    /// The original article suggested a simplex where an initial point is given
706    /// as x0 with the others generated a fixed step along each dimension in turn.
707    #[instrument(skip_all)]
708    fn param_vecs(&self, mut x0: Vec<f64>, bounds: Option<&(Vec<f64>, Vec<f64>)>) -> Vec<Vec<f64>> {
709        if let Some((lower, upper)) = bounds {
710            Self::restrict_to_bounds(&mut x0, lower, upper);
711        }
712        let n = x0.len();
713
714        let mut simplex = vec![x0; n + 1];
715        let diag = simplex
716            .iter_mut()
717            .take(n)
718            .enumerate()
719            .map(|(i, row)| &mut row[i]);
720        for el in diag {
721            if el.abs() < 1e-8 {
722                *el = 1e-4;
723            } else {
724                *el *= 1.05;
725            }
726        }
727        if let Some((lower, upper)) = bounds {
728            for row in simplex.iter_mut() {
729                Self::restrict_to_bounds(row, lower, upper)
730            }
731        }
732        simplex
733    }
734
735    const TOL_STD: f64 = 1e-4;
736
737    /// Run the Nelder-Mead algorithm.
738    ///
739    /// This is a custom implementation of the Nelder-Mead algorithm, which is
740    /// based on the implementation in the `statsforecast` Python package.
741    /// It implements bounds checks and a custom stopping criterion.
742    ///
743    /// It could be generalised by making `problem` a generic type but I can't
744    /// see that being needed.
745    #[instrument(skip_all)]
746    fn nelder_mead(
747        &self,
748        problem: &mut ETSProblem<'_>,
749        mut simplex: Vec<Vec<f64>>,
750        bounds: Option<&(Vec<f64>, Vec<f64>)>,
751    ) -> Vec<f64> {
752        let n_u = simplex[0].len();
753        let n = simplex[0].len() as f64;
754
755        let alpha = 1.0;
756        let gamma = 1.0 + 2.0 / n;
757        let rho = 0.75 - 1.0 / (2.0 * n);
758        let sigma = 1.0 - 1.0 / n;
759
760        let mut f_simplex: Vec<_> = simplex.iter().map(|x| problem.cost(x)).collect();
761        let mut costs_sorted: Vec<_> = f_simplex.iter().copied().enumerate().collect();
762        let mut order_f: Vec<_> = costs_sorted.iter().map(|(i, _)| *i).collect();
763        let mut best_idx = order_f[0];
764        let mut x_o: Vec<_>;
765        let mut x_r: Vec<_>;
766        let mut x_e: Vec<_>;
767        let mut x_oc: Vec<_>;
768        let mut x_ic: Vec<_>;
769        for _ in 0..self.max_iter {
770            costs_sorted.clear();
771            costs_sorted.extend(f_simplex.iter().copied().enumerate());
772            costs_sorted.sort_unstable_by(|(_, a), (_, b)| {
773                a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
774            });
775            order_f.clear();
776            order_f.extend(costs_sorted.iter().map(|(i, _)| *i));
777
778            best_idx = order_f[0];
779            let worst_idx = order_f[order_f.len() - 1];
780            let second_worst_idx = order_f[order_f.len() - 2];
781
782            // Check stopping criteria.
783            if f_simplex.std(0) < Self::TOL_STD {
784                break;
785            }
786
787            // Calculate centroid except argmax f_simplex.
788            x_o = vec![0.0; n_u];
789            for x in simplex
790                .iter()
791                .enumerate()
792                .filter_map(|(i, x)| (i != worst_idx).then_some(x))
793            {
794                for (i, el) in x.iter().enumerate() {
795                    x_o[i] += el;
796                }
797            }
798            for x in x_o.iter_mut() {
799                *x /= n;
800            }
801
802            // Step 2: Reflection, Compute reflected point
803            x_r = x_o
804                .iter()
805                .zip(&simplex[worst_idx])
806                .map(|(x_0, x)| x_0 + alpha * (x_0 - x))
807                .collect();
808            if let Some((lower, upper)) = &bounds {
809                Self::restrict_to_bounds(&mut x_r, lower, upper);
810            }
811            let f_r = problem.cost(&x_r);
812            if f_simplex[best_idx] <= f_r && f_r < f_simplex[second_worst_idx] {
813                simplex[worst_idx] = x_r;
814                f_simplex[worst_idx] = f_r;
815                continue;
816            }
817
818            // Step 3: Expansion, reflected point is the best point so far
819            if f_r < f_simplex[best_idx] {
820                x_e = x_o
821                    .iter()
822                    .zip(&x_r)
823                    .map(|(x_o, x_r)| x_o + gamma * (x_r - x_o))
824                    .collect();
825                if let Some((lower, upper)) = &bounds {
826                    Self::restrict_to_bounds(&mut x_e, lower, upper);
827                }
828                let f_e = problem.cost(&x_e);
829                if f_e < f_r {
830                    simplex[worst_idx] = x_e;
831                    f_simplex[worst_idx] = f_e;
832                } else {
833                    simplex[worst_idx] = x_r;
834                    f_simplex[worst_idx] = f_r;
835                }
836                continue;
837            }
838
839            // Step 4: outside Contraction
840            if f_simplex[second_worst_idx] <= f_r && f_r < f_simplex[worst_idx] {
841                x_oc = x_o
842                    .iter()
843                    .zip(&x_r)
844                    .map(|(x_o, x_r)| x_o + rho * (x_r - x_o))
845                    .collect();
846                if let Some((lower, upper)) = &bounds {
847                    Self::restrict_to_bounds(&mut x_oc, lower, upper);
848                }
849                let f_oc = problem.cost(&x_oc);
850                if f_oc <= f_r {
851                    simplex[worst_idx] = x_oc;
852                    f_simplex[worst_idx] = f_oc;
853                    continue;
854                }
855            } else {
856                // Step 5: inside contraction
857                x_ic = x_o
858                    .iter()
859                    .zip(&x_r)
860                    .map(|(x_o, x_r)| x_o - rho * (x_r - x_o))
861                    .collect();
862                if let Some((lower, upper)) = &bounds {
863                    Self::restrict_to_bounds(&mut x_ic, lower, upper);
864                }
865                let f_ic = problem.cost(&x_ic);
866                if f_ic < f_simplex[worst_idx] {
867                    simplex[worst_idx] = x_ic;
868                    f_simplex[worst_idx] = f_ic;
869                    continue;
870                }
871            }
872
873            // Step 6: shrink
874            let best = simplex[best_idx].clone();
875            simplex.iter_mut().enumerate().for_each(|(i, x)| {
876                if i != best_idx {
877                    x.iter_mut()
878                        .zip(&best)
879                        .for_each(|(x, x_best)| *x = x_best + sigma * (*x - x_best));
880                    if let Some((lower, upper)) = &bounds {
881                        Self::restrict_to_bounds(&mut x_r, lower, upper);
882                    }
883                    f_simplex[i] = problem.cost(x);
884                }
885            });
886        }
887        simplex[best_idx].clone()
888    }
889
890    /// Restrict `x0` to the bounds given by `lower` and `upper`.
891    fn restrict_to_bounds(x0: &mut [f64], lower: &[f64], upper: &[f64]) {
892        x0.iter_mut()
893            .zip(lower)
894            .zip(upper)
895            .for_each(|((x, &l), &u)| {
896                *x = x.clamp(l, u);
897            });
898    }
899}
900
901// This was generated by ChatGPT, we should probably check it...
902// In particular the `roots` part is unclear since the `roots` crate only returns real roots,
903// but the R/Python implementations reference complex roots too.
904fn admissible(alpha: f64, mut beta: f64, gamma: f64, mut phi: f64, m: usize) -> bool {
905    const EPSILON: f64 = 1e-8;
906    if phi.is_nan() {
907        phi = 1.0;
908    }
909    if !(0.0..=1.0 + EPSILON).contains(&phi) {
910        return false;
911    }
912    if gamma.is_nan() {
913        if alpha < 1.0 - 1.0 / phi || alpha > 1.0 + 1.0 / phi {
914            return false;
915        }
916        if !beta.is_nan() && (beta < alpha * (phi - 1.0) || beta > (1.0 + phi) * (2.0 - alpha)) {
917            return false;
918        }
919    } else if m > 1 {
920        if beta.is_nan() {
921            beta = 0.0;
922        }
923        if gamma < f64::max(1.0 - 1.0 / phi - alpha, 0.0) || gamma > 1.0 + 1.0 / phi - alpha {
924            return false;
925        }
926        if alpha
927            < 1.0
928                - 1.0 / phi
929                - gamma * (1.0 - m as f64 + phi + phi * m as f64) / (2.0 * phi * m as f64)
930        {
931            return false;
932        }
933        if beta < -(1.0 - phi) * (gamma / m as f64 + alpha) {
934            return false;
935        }
936        let mut p: Vec<f64> = vec![f64::NAN; 2 + m];
937        p[0] = phi * (1.0 - alpha - gamma);
938        p[1] = alpha + beta - alpha * phi + gamma - 1.0;
939        p[2..m].fill(alpha + beta - alpha * phi);
940        p[m..].fill(alpha + beta - phi);
941        p[m + 1] = 1.0;
942        let roots = roots::find_roots_eigen(p);
943        let max_ = roots
944            .into_iter()
945            .fold(f64::NEG_INFINITY, |max_, r| r.abs().max(max_));
946        if max_ > 1.0 + 1e-10 {
947            return false;
948        }
949    }
950    true
951}
952
953/// A 'problem' for the Nelder-Mead algorithm.
954///
955/// This just groups together and holds several pieces of data that are used in the
956/// cost function called by the Nelder-Mead algorithm. It saves us from having to
957/// pass around a bunch of arguments to the Nelder-Mead function.
958pub(crate) struct ETSProblem<'a> {
959    y: &'a [f64],
960    ets: Ets,
961    x: Vec<f64>,
962    residuals: Vec<f64>,
963    forecasts: Vec<f64>,
964    amse: Vec<f64>,
965    denom: Vec<f64>,
966}
967
968impl<'a> ETSProblem<'a> {
969    /// Create a new problem.
970    ///
971    /// The `y` argument is the time series to fit.
972    /// The `ets` argument is the ETS model to fit.
973    ///
974    /// The returned problem is ready to be passed to the Nelder-Mead algorithm.
975    /// Each of the vectors in the problem is pre-allocated to the correct size.
976    pub(crate) fn new(y: &'a [f64], ets: Ets) -> Self {
977        let nmse = ets.nmse;
978        let x_len = ets.n_states * (y.len() + 1);
979        Self {
980            y,
981            ets,
982            x: vec![0.0; x_len],
983            residuals: vec![0.0; y.len()],
984            forecasts: vec![0.0; nmse],
985            amse: vec![0.0; nmse],
986            denom: vec![0.0; nmse],
987        }
988    }
989
990    /// Calculate the cost function.
991    ///
992    /// The first `self.n_states` elements of `param` are the initial values of the parameters.
993    /// The remaining elements are the initial state.
994    fn cost(&mut self, inputs: &[f64]) -> f64 {
995        let Ets {
996            params,
997            opt_params,
998            opt_crit,
999            n_states,
1000            ..
1001        } = &self.ets;
1002        let mut params = params.clone();
1003
1004        // If we're optimizing params, they'll be included the inputs to the
1005        // optimizer, so use them to override the defaults.
1006        let mut i = 0;
1007        if opt_params.alpha {
1008            params.alpha = inputs[i];
1009            i += 1;
1010        }
1011        if opt_params.beta {
1012            params.beta = inputs[i];
1013            i += 1;
1014        }
1015        if opt_params.gamma {
1016            params.gamma = inputs[i];
1017            i += 1;
1018        }
1019        if opt_params.phi {
1020            params.phi = inputs[i];
1021            i += 1;
1022        }
1023
1024        // The remaining parameters are the initial state.
1025        let state_inputs = &inputs[i..];
1026        self.x.truncate(state_inputs.len());
1027        self.x.copy_from_slice(state_inputs);
1028        self.x.resize(n_states * (self.y.len() + 1), 0.0);
1029        // TODO: add extra state for seasonality?
1030
1031        // Calculate the cost.
1032        let fit = self.ets.etscalc_in(
1033            self.y,
1034            &mut self.x,
1035            params,
1036            &mut self.residuals,
1037            &mut self.forecasts,
1038            &mut self.amse,
1039            &mut self.denom,
1040            // We only need to update the AMSE if we're optimizing using
1041            // AMSE-based criteria.
1042            matches!(
1043                opt_crit,
1044                OptimizationCriteria::MSE | OptimizationCriteria::AMSE
1045            ),
1046        );
1047        match opt_crit {
1048            OptimizationCriteria::Likelihood => fit.likelihood(),
1049            OptimizationCriteria::MSE => fit.mse(),
1050            OptimizationCriteria::AMSE => fit.amse(),
1051            OptimizationCriteria::Sigma => fit.sigma_squared(),
1052            OptimizationCriteria::MAE => fit.mae(),
1053        }
1054    }
1055}
1056
1057/// Check that the parameters are within the bounds.
1058fn check_params(bounds: &Bounds, season_length: usize, params: Params) -> bool {
1059    let Params {
1060        alpha,
1061        beta,
1062        gamma,
1063        phi,
1064    } = params;
1065    if let Bounds::Usual(UpperLowerBounds {
1066        lower: [lower_a, lower_b, lower_g, lower_p],
1067        upper: [upper_a, upper_b, upper_g, upper_p],
1068    })
1069    | Bounds::Both(UpperLowerBounds {
1070        lower: [lower_a, lower_b, lower_g, lower_p],
1071        upper: [upper_a, upper_b, upper_g, upper_p],
1072    }) = bounds
1073    {
1074        if !(alpha.is_nan() || alpha >= *lower_a && alpha <= *upper_a) {
1075            return false;
1076        }
1077        if !(beta.is_nan() || beta >= *lower_b && beta <= alpha && beta <= *upper_b) {
1078            return false;
1079        }
1080        if !(gamma.is_nan() || gamma >= *lower_g && gamma <= 1.0 - alpha && gamma <= *upper_g) {
1081            return false;
1082        }
1083        if !(phi.is_nan() || phi >= *lower_p && phi <= *upper_p) {
1084            return false;
1085        }
1086    }
1087    if !matches!(bounds, Bounds::Usual(_)) {
1088        return admissible(alpha, beta, gamma, phi, season_length);
1089    }
1090    true
1091}
1092
1093/// A fitted ETS model.
1094#[derive(Debug, Clone)]
1095pub struct Model {
1096    /// The original model.
1097    ets: Ets,
1098
1099    /// The fitted model state, parameters and likelihood.
1100    model_fit: FitState,
1101
1102    /// The standard error of the residuals.
1103    ///
1104    /// This is used when calculating prediction intervals for in-sample
1105    /// predictions.
1106    sigma: f64,
1107}
1108
1109impl Model {
1110    fn new(ets: Ets, fit: FitState, sigma: f64) -> Model {
1111        Self {
1112            ets,
1113            model_fit: fit,
1114            sigma,
1115        }
1116    }
1117
1118    fn pegels_forecast(&self, horizon: usize) -> Vec<f64> {
1119        let mut forecasts = vec![0.0; horizon];
1120        let states = self.model_fit.states().last().unwrap();
1121        let phi = if self.ets.damped {
1122            self.model_fit.params().phi
1123        } else {
1124            1.0
1125        };
1126        let b = if self.ets.model_type.trend.included() {
1127            Some(states[1])
1128        } else {
1129            None
1130        };
1131        self.ets
1132            .forecast(phi, states[0], b, &mut forecasts, horizon);
1133        forecasts
1134    }
1135
1136    /// The log-likelihood of the model.
1137    pub fn log_likelihood(&self) -> f64 {
1138        -0.5 * self.model_fit.likelihood()
1139    }
1140
1141    /// The Akaike Information Criterion (AIC) of the model.
1142    pub fn aic(&self) -> f64 {
1143        self.model_fit.likelihood() + 2.0 * self.model_fit.n_params() as f64
1144    }
1145
1146    /// The corrected Akaike Information Criterion (AICC) of the model.
1147    pub fn aicc(&self) -> f64 {
1148        let n_y = self.model_fit.residuals().len();
1149        let n_params = self.model_fit.n_params() + 1;
1150        let aic = self.aic();
1151        let denom = n_y - n_params - 1;
1152        if denom != 0 {
1153            aic + 2.0 * n_params as f64 * (n_params as f64 + 1.0) / denom as f64
1154        } else {
1155            f64::INFINITY
1156        }
1157    }
1158
1159    /// The Bayesian Information Criterion (BIC) of the model.
1160    pub fn bic(&self) -> f64 {
1161        self.model_fit.likelihood()
1162            + (self.model_fit.n_params() as f64 + 1.0)
1163                * ((self.model_fit.residuals().len() as f64).ln())
1164    }
1165
1166    /// The mean squared error (MSE) of the model.
1167    pub fn mse(&self) -> f64 {
1168        self.model_fit.mse()
1169    }
1170
1171    /// The average mean squared error (AMSE) of the model.
1172    ///
1173    /// This is the average of the MSE over the number of forecasting horizons (`nmse`).
1174    pub fn amse(&self) -> f64 {
1175        self.model_fit.amse()
1176    }
1177
1178    /// The model type.
1179    pub fn model_type(&self) -> ModelType {
1180        self.ets.model_type
1181    }
1182
1183    /// Whether the model uses damped trend.
1184    pub fn damped(&self) -> bool {
1185        self.ets.damped
1186    }
1187}
1188
1189impl Predict for Model {
1190    type Error = Error;
1191
1192    fn predict_in_sample_inplace(
1193        &self,
1194        level: Option<f64>,
1195        forecast: &mut augurs_core::Forecast,
1196    ) -> Result<(), Self::Error> {
1197        forecast.point = self.model_fit.fitted().to_vec();
1198        if let Some(level) = level {
1199            Forecast(forecast).calculate_in_sample_intervals(self.sigma, level);
1200        }
1201        Ok(())
1202    }
1203
1204    fn predict_inplace(
1205        &self,
1206        horizon: usize,
1207        level: Option<f64>,
1208        forecast: &mut augurs_core::Forecast,
1209    ) -> Result<(), Self::Error> {
1210        // Short-circuit if horizon is zero.
1211        if horizon == 0 {
1212            return Ok(());
1213        }
1214        forecast.point = self.pegels_forecast(horizon);
1215        if let Some(level) = level {
1216            Forecast(forecast).calculate_intervals(&self.ets, &self.model_fit, horizon, level);
1217        }
1218        Ok(())
1219    }
1220
1221    fn training_data_size(&self) -> usize {
1222        self.model_fit.residuals().len()
1223    }
1224}
1225
1226struct Forecast<'a>(&'a mut augurs_core::Forecast);
1227
1228impl Forecast<'_> {
1229    /// Calculate the prediction intervals for the forecast.
1230    fn calculate_intervals(&mut self, ets: &Ets, fit: &FitState, horizon: usize, level: f64) {
1231        let sigma = fit.sigma_squared();
1232        let season_length = ets.model_type.season.season_length();
1233        let season_length_f = season_length as f64;
1234
1235        let ModelType {
1236            error,
1237            trend,
1238            season,
1239        } = ets.model_type;
1240        let steps: Vec<_> = (1..(horizon + 1)).map(|x| x as f64).collect();
1241        let hm = ((horizon - 1) as f64 / season_length_f).floor();
1242
1243        let Params {
1244            alpha,
1245            beta,
1246            gamma,
1247            phi,
1248        } = fit.params();
1249
1250        let alpha_2 = alpha.powi(2);
1251        let phi_2 = phi.powi(2);
1252
1253        let exp3 = 2.0 * alpha * (1.0 - phi) + beta * phi;
1254        let (exp1, exp2, exp4, exp5): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) = steps
1255            .iter()
1256            .copied()
1257            .map(|s| {
1258                let phi_s = phi.powi(s as i32);
1259                (
1260                    alpha_2 + alpha * beta * s + (1.0 / 6.0) * beta.powi(2) * s * (2.0 * s - 1.0),
1261                    (beta * phi * s) / (1.0 - phi).powi(2),
1262                    (beta * phi * (1.0 - phi_s)) / ((1.0 - phi).powi(2) * (1.0 - phi_2)),
1263                    2.0 * alpha * (1.0 - phi_2) + beta * phi * (1.0 + 2.0 * phi - phi_s),
1264                )
1265            })
1266            .multiunzip();
1267
1268        use {ErrorComponent as EC, SeasonalComponent as SC, TrendComponent as TC};
1269        let (lower, upper) =
1270            match (error, trend, season, ets.damped) {
1271                // Class 1 models.
1272                // ANN
1273                (EC::Additive, TC::None, SC::None, false) => {
1274                    let sigma_h = steps
1275                        .iter()
1276                        .map(|s| (((s - 1.0) * alpha.powi(2) + 1.0) * sigma).sqrt());
1277                    self.compute_intervals(level, sigma_h)
1278                }
1279                // AAN
1280                (EC::Additive, TC::Additive, SC::None, false) => {
1281                    let sigma_h = steps
1282                        .iter()
1283                        .zip(&exp1)
1284                        .map(|(s, e)| ((1.0 + (s - 1.0) * e) * sigma).sqrt());
1285                    self.compute_intervals(level, sigma_h)
1286                }
1287                // AAdN
1288                (EC::Additive, TC::Additive, SC::None, true) => {
1289                    let sigma_h =
1290                        steps
1291                            .iter()
1292                            .zip(&exp2)
1293                            .zip(&exp4)
1294                            .zip(&exp5)
1295                            .map(|(((s, e2), e4), e5)| {
1296                                ((1.0 + alpha_2 * (s - 1.0) + e2 * exp3 - e4 * e5) * sigma).sqrt()
1297                            });
1298                    self.compute_intervals(level, sigma_h)
1299                }
1300                // ANA
1301                (EC::Additive, TC::None, SC::Additive { .. }, false) => {
1302                    let sigma_h = steps.iter().map(|s| {
1303                        ((1.0 + alpha_2 * (s - 1.0) + gamma * hm * (2.0 * alpha * gamma)) * sigma)
1304                            .sqrt()
1305                    });
1306                    self.compute_intervals(level, sigma_h)
1307                }
1308                // AAA
1309                (EC::Additive, TC::Additive, SC::Additive { .. }, false) => {
1310                    let sigma_h = steps.iter().zip(&exp1).map(|(s, e1)| {
1311                        let e6 = 2.0 * alpha + gamma + beta * season_length_f * (hm + 1.0);
1312                        ((1.0 + (s - 1.0) * e1 * gamma * hm * e6) * sigma).sqrt()
1313                    });
1314                    self.compute_intervals(level, sigma_h)
1315                }
1316                // AAdA
1317                (EC::Additive, TC::Additive, SC::Additive { season_length }, true) => {
1318                    let sigma_h = steps.iter().zip(&exp2).zip(&exp4).zip(&exp5).map(
1319                        |(((&s, e2), e4), e5)| {
1320                            let phi_s = phi.powi(s as i32);
1321                            let e7 = (2.0 * beta * gamma * phi) / ((1.0 - phi) * (1.0 - phi_s));
1322                            let e8 = hm * (1.0 - phi_s)
1323                                - phi_s * (1.0 - phi.powi(season_length as i32 * hm as i32));
1324                            ((1.0 + alpha_2 * (s - 1.0) + e2 * exp3 - e4 * e5
1325                                + gamma * hm * (2.0 * alpha + gamma)
1326                                + e7 * e8)
1327                                * sigma)
1328                                .sqrt()
1329                        },
1330                    );
1331                    self.compute_intervals(level, sigma_h)
1332                }
1333                // Class 2 models.
1334                // MNN
1335                (EC::Multiplicative, TC::None, SC::None, false) => {
1336                    let cvals = std::iter::repeat_n(*alpha, horizon);
1337                    let sigma_h = self.compute_sigma_h(sigma, cvals, horizon);
1338                    self.compute_intervals(level, sigma_h.into_iter())
1339                }
1340                // MAN
1341                (EC::Multiplicative, TC::Additive, SC::None, false) => {
1342                    let cvals = steps.iter().map(|s| alpha + beta * s);
1343                    let sigma_h = self.compute_sigma_h(sigma, cvals, horizon);
1344                    self.compute_intervals(level, sigma_h.into_iter())
1345                }
1346                // MAdN
1347                (EC::Multiplicative, TC::Additive, SC::None, true) => {
1348                    let mut cvals: Vec<_> = vec![f64::NAN; horizon];
1349                    for k in 1..(horizon + 1) {
1350                        let sum_phi = (1..(k + 1)).map(|j| phi.powi(j as i32)).sum::<f64>();
1351                        cvals[k - 1] = alpha + beta * sum_phi;
1352                    }
1353                    let sigma_h = self.compute_sigma_h(sigma, cvals.into_iter(), horizon);
1354                    self.compute_intervals(level, sigma_h.into_iter())
1355                }
1356                // TODO: all below models, once we do seasonality.
1357                // MNA
1358                (EC::Multiplicative, TC::None, SC::Additive { .. }, false) => todo!(),
1359                // MAA
1360                (EC::Multiplicative, TC::Additive, SC::Additive { .. }, false) => todo!(),
1361                // MAdA
1362                (EC::Multiplicative, TC::Additive, SC::Additive { .. }, true) => todo!(),
1363                // Class 3 models.
1364                // Anything with multiplicative error and seasonality?
1365                (EC::Multiplicative, _, SC::Multiplicative { .. }, _) => {
1366                    unimplemented!(
1367                        "Prediction intervals for class 3 models are not implemented yet"
1368                    )
1369                }
1370                // Class 4 or 5 models without seasonality.
1371                // In future we should also handle those with seasonality.
1372                (_, _, SC::None, _) => {
1373                    // Simulate.
1374                    self.simulate(ets, fit, horizon, level)
1375                }
1376                // Any other models aren't yet implemented.
1377                _ => unimplemented!("Prediction intervals for this model are not implemented yet"),
1378            };
1379        self.0.intervals = Some(ForecastIntervals {
1380            level,
1381            lower,
1382            upper,
1383        });
1384    }
1385
1386    /// Compute the prediction intervals for a given level.
1387    ///
1388    /// `level` should be a number between 0 and 1.
1389    /// `sigma_h` is the standard deviation of the residuals.
1390    fn compute_intervals(
1391        &self,
1392        level: f64,
1393        sigma_h: impl Iterator<Item = f64>,
1394    ) -> (Vec<f64>, Vec<f64>) {
1395        let z = distrs::Normal::ppf(0.5 + level / 2.0, 0.0, 1.0);
1396        self.0
1397            .point
1398            .iter()
1399            .zip(sigma_h)
1400            .map(|(p, s)| (p - z * s, p + z * s))
1401            .unzip()
1402    }
1403
1404    /// Compute the standard deviations of the residuals given the model's
1405    /// overall standard deviation and some critical values.
1406    fn compute_sigma_h(
1407        &self,
1408        sigma: f64,
1409        cvals: impl Iterator<Item = f64>,
1410        horizon: usize,
1411    ) -> Vec<f64> {
1412        let cvals_squared: Vec<_> = cvals.map(|c| c.powi(2)).collect();
1413        let theta =
1414            // Iterate over each point estimate, up to `horizon`.
1415            &self
1416                .0
1417                .point
1418                .iter()
1419                // `point` should always have length == horizon, but `take` just in case
1420                .take(horizon)
1421                .fold(Vec::with_capacity(horizon), |mut acc, p| {
1422                    // For each point estimate, accumulate a vec of
1423                    // errors so far, by iterating the current accumulator,
1424                    // zipping with the reversed critical values, and multiplying.
1425                    // Sum the totals up until this point, then multiply with sigma
1426                    // and add that onto the accumulator.
1427                    let t = p.powi(2)
1428                        + acc
1429                            .iter()
1430                            .rev()
1431                            .zip(&cvals_squared)
1432                            .map(|(t, c)| t * c)
1433                            .sum::<f64>()
1434                            * sigma;
1435                    acc.push(t);
1436                    acc
1437                });
1438        theta
1439            .iter()
1440            .zip(&self.0.point)
1441            .map(|(t, p)| ((1.0 + sigma) * t - p.powi(2)).sqrt())
1442            .collect()
1443    }
1444
1445    fn simulate(
1446        &self,
1447        ets: &Ets,
1448        fit: &FitState,
1449        horizon: usize,
1450        level: f64,
1451    ) -> (Vec<f64>, Vec<f64>) {
1452        let n_sim = 5000;
1453        let last_state = fit.last_state();
1454        let mut y_path = vec![vec![0.0; horizon]; n_sim];
1455        let params = fit.params();
1456        let beta = if params.beta.is_nan() {
1457            0.0
1458        } else {
1459            params.beta
1460        };
1461        let gamma = if params.gamma.is_nan() {
1462            0.0
1463        } else {
1464            params.gamma
1465        };
1466        let phi = if params.phi.is_nan() { 0.0 } else { params.phi };
1467        let rng = &mut rand::thread_rng();
1468        let normal = Normal::new(0.0, fit.sigma_squared().sqrt()).unwrap();
1469        // Use the same `f` vector for each simulation to avoid re-allocating.
1470        // For some reason statsforecast uses a length of 10 for `f`?
1471        let mut f = vec![0.0; 10];
1472        for y_path_k in &mut y_path {
1473            let e: Vec<_> = (0..horizon).map(|_| normal.sample(rng)).collect();
1474            ets.etssimulate(
1475                last_state,
1476                Params {
1477                    alpha: params.alpha,
1478                    beta,
1479                    gamma,
1480                    phi,
1481                },
1482                &e,
1483                &mut f,
1484                y_path_k,
1485            );
1486            f.iter_mut().for_each(|f| *f = 0.0);
1487        }
1488        y_path
1489            .into_iter()
1490            .map(|mut yhat| {
1491                yhat.sort_by(|a, b| a.partial_cmp(b).unwrap());
1492                (
1493                    percentile_of_sorted(&yhat, 0.5 - level / 2.0),
1494                    percentile_of_sorted(&yhat, 0.5 + level / 2.0),
1495                )
1496            })
1497            .unzip()
1498    }
1499
1500    fn calculate_in_sample_intervals(&mut self, sigma: f64, level: f64) {
1501        let (lower, upper) = self.compute_intervals(level, std::iter::repeat(sigma));
1502        self.0.intervals = Some(ForecastIntervals {
1503            level,
1504            lower,
1505            upper,
1506        });
1507    }
1508}
1509
// Taken from the Rust compiler's test suite:
// https://github.com/rust-lang/rust/blob/917b0b6c70f078cb08bbb0080c9379e4487353c3/library/test/src/stats.rs#L258-L280.
/// Linearly interpolated percentile of an ascending-sorted slice.
///
/// `pct` is a percentage in `[0, 100]`, not a fraction. A single-element
/// slice returns its only element regardless of `pct`.
///
/// # Panics
///
/// Panics if `sorted_samples` is empty. For slices with more than one element,
/// it also panics if `pct` is outside `[0, 100]` or is NaN.
fn percentile_of_sorted(sorted_samples: &[f64], pct: f64) -> f64 {
    assert!(!sorted_samples.is_empty());
    if let [only] = sorted_samples {
        return *only;
    }
    let zero: f64 = 0.0;
    assert!(zero <= pct);
    let hundred = 100_f64;
    assert!(pct <= hundred);
    let last_idx = sorted_samples.len() - 1;
    if pct == hundred {
        return sorted_samples[last_idx];
    }
    // Fractional position of the percentile between index 0 and `last_idx`.
    let rank = pct / hundred * last_idx as f64;
    let base = rank.floor();
    let frac = rank - base;
    let below = sorted_samples[base as usize];
    let above = sorted_samples[base as usize + 1];
    below + (above - below) * frac
}
1533
1534#[cfg(test)]
1535mod test {
1536    use augurs_core::prelude::*;
1537    use augurs_testing::{assert_approx_eq, assert_within_pct, data::AIR_PASSENGERS as AP};
1538
1539    use crate::model::{
1540        ErrorComponent, ForecastIntervals, ModelType, SeasonalComponent, TrendComponent, Unfit,
1541    };
1542
1543    #[test]
1544    fn initial_params() {
1545        let mut unfit = Unfit::new(ModelType {
1546            error: ErrorComponent::Additive,
1547            trend: TrendComponent::None,
1548            season: SeasonalComponent::None,
1549        });
1550        let initial_params = unfit.initial_params();
1551        assert_approx_eq!(initial_params.alpha, 0.20006);
1552        assert!(initial_params.beta.is_nan());
1553        assert!(initial_params.gamma.is_nan());
1554        assert!(initial_params.phi.is_nan());
1555    }
1556
1557    #[test]
1558    fn air_passengers_fit_aan() {
1559        let unfit = Unfit::new(ModelType {
1560            error: ErrorComponent::Additive,
1561            trend: TrendComponent::Additive,
1562            season: SeasonalComponent::None,
1563        })
1564        .damped(true);
1565        let model = unfit.fit(&AP[AP.len() - 20..]).unwrap();
1566        assert_within_pct!(model.log_likelihood(), -109.6248525790271, 0.01);
1567        assert_within_pct!(model.aic(), 231.2497051580542, 0.01);
1568        assert_within_pct!(model.bic(), 237.22409879937817, 0.01);
1569        assert_within_pct!(model.aicc(), 237.71124361959266, 0.01);
1570        assert_within_pct!(model.mse(), 2883.47944444736, 0.01);
1571        assert_within_pct!(model.amse(), 8292.71075580747, 0.01);
1572    }
1573
1574    #[test]
1575    fn air_passengers_fit_man() {
1576        let unfit = Unfit::new(ModelType {
1577            error: ErrorComponent::Multiplicative,
1578            trend: TrendComponent::Additive,
1579            season: SeasonalComponent::None,
1580        });
1581        let model = unfit.fit(AP).unwrap();
1582        assert_within_pct!(model.log_likelihood(), -831.4883541595792, 0.01);
1583        assert_within_pct!(model.aic(), 1672.9767083191584, 0.01);
1584        assert_within_pct!(model.bic(), 1687.8257748170383, 0.01);
1585        assert_within_pct!(model.aicc(), 1673.4114909278542, 0.01);
1586        assert_within_pct!(model.mse(), 1127.443938773091, 0.01);
1587        assert_within_pct!(model.amse(), 2888.3802507845635, 0.01);
1588    }
1589
1590    #[test]
1591    fn air_passengers_forecast_aan() {
1592        let unfit = Unfit::new(ModelType {
1593            error: ErrorComponent::Additive,
1594            trend: TrendComponent::Additive,
1595            season: SeasonalComponent::None,
1596        })
1597        .damped(true);
1598        let model = unfit.fit(&AP[AP.len() - 20..]).unwrap();
1599        let forecasts = model.predict(10, 0.95).unwrap();
1600        let expected_p = [
1601            432.26645246,
1602            432.53827337,
1603            432.75575609,
1604            432.92976307,
1605            433.0689853,
1606            433.18037639,
1607            433.26949992,
1608            433.34080727,
1609            433.39785997,
1610            433.44350758,
1611        ];
1612        assert_eq!(forecasts.point.len(), 10);
1613        for (actual, expected) in forecasts.point.iter().zip(expected_p.iter()) {
1614            assert_approx_eq!(actual, expected);
1615        }
1616
1617        let expected_l = [
1618            301.72457857,
1619            247.92511851,
1620            206.64496117,
1621            171.83062947,
1622            141.14177344,
1623            113.38060224,
1624            87.83698619,
1625            64.04903959,
1626            41.69638225,
1627            20.54598327,
1628        ];
1629        let ForecastIntervals { lower, upper, .. } = forecasts.intervals.unwrap();
1630        assert_eq!(lower.len(), 10);
1631        for (actual, expected) in lower.iter().zip(expected_l.iter()) {
1632            assert_approx_eq!(actual, expected);
1633        }
1634        let expected_u = [
1635            562.80832636,
1636            617.15142823,
1637            658.86655102,
1638            694.02889667,
1639            724.99619716,
1640            752.98015054,
1641            778.70201365,
1642            802.63257495,
1643            825.09933768,
1644            846.34103189,
1645        ];
1646        assert_eq!(upper.len(), 10);
1647        for (actual, expected) in upper.iter().zip(expected_u.iter()) {
1648            assert_approx_eq!(actual, expected);
1649        }
1650    }
1651
1652    #[test]
1653    fn air_passengers_forecast_man() {
1654        let unfit = Unfit::new(ModelType {
1655            error: ErrorComponent::Multiplicative,
1656            trend: TrendComponent::Additive,
1657            season: SeasonalComponent::None,
1658        });
1659        let model = unfit.fit(AP).unwrap();
1660        let forecasts = model.predict(10, 0.95).unwrap();
1661        let expected_p = [
1662            436.15668239,
1663            440.31714837,
1664            444.47761434,
1665            448.63808031,
1666            452.79854629,
1667            456.95901226,
1668            461.11947823,
1669            465.27994421,
1670            469.44041018,
1671            473.60087615,
1672        ];
1673        assert_eq!(forecasts.point.len(), 10);
1674        for (actual, expected) in forecasts.point.iter().zip(expected_p.iter()) {
1675            assert_approx_eq!(actual, expected);
1676        }
1677
1678        let expected_l = [
1679            345.14145884,
1680            310.62430297,
1681            284.42938026,
1682            262.42886479,
1683            243.03658151,
1684            225.44516176,
1685            209.1784846,
1686            193.92853297,
1687            179.48284058,
1688            165.68775958,
1689        ];
1690        let ForecastIntervals { lower, upper, .. } = forecasts.intervals.unwrap();
1691        assert_eq!(lower.len(), 10);
1692        for (actual, expected) in lower.iter().zip(expected_l.iter()) {
1693            assert_approx_eq!(actual, expected);
1694        }
1695        let expected_u = [
1696            527.17190595,
1697            570.00999376,
1698            604.52584842,
1699            634.84729584,
1700            662.56051106,
1701            688.47286276,
1702            713.06047187,
1703            736.63135545,
1704            759.39797978,
1705            781.51399273,
1706        ];
1707        assert_eq!(upper.len(), 10);
1708        for (actual, expected) in upper.iter().zip(expected_u.iter()) {
1709            assert_approx_eq!(actual, expected);
1710        }
1711
1712        // For in-sample data, just check that the first 10 values match.
1713        let in_sample = model.predict_in_sample(0.95).unwrap();
1714        let expected_p = [
1715            110.74681112,
1716            116.18804955,
1717            122.18817486,
1718            136.18835606,
1719            133.18933724,
1720            125.18861841,
1721            139.18739947,
1722            152.18838061,
1723            152.18926187,
1724            140.18884303,
1725        ];
1726        assert_eq!(in_sample.point.len(), AP.len());
1727        for (actual, expected) in in_sample.point.iter().zip(expected_p.iter()) {
1728            assert_approx_eq!(actual, expected);
1729        }
1730
1731        let ForecastIntervals { lower, upper, .. } = in_sample.intervals.unwrap();
1732        let expected_l = [
1733            43.76306764,
1734            49.20430607,
1735            55.20443139,
1736            69.20461258,
1737            66.20559377,
1738            58.20487493,
1739            72.203656,
1740            85.20463713,
1741            85.20551839,
1742            73.20509956,
1743        ];
1744        assert_eq!(lower.len(), AP.len());
1745        for (actual, expected) in lower.iter().zip(expected_l.iter()) {
1746            assert_approx_eq!(actual, expected);
1747        }
1748        let expected_u = [
1749            177.73055459,
1750            183.17179302,
1751            189.17191834,
1752            203.17209954,
1753            200.17308072,
1754            192.17236188,
1755            206.17114295,
1756            219.17212409,
1757            219.17300535,
1758            207.17258651,
1759        ];
1760        assert_eq!(upper.len(), AP.len());
1761        for (actual, expected) in upper.iter().zip(expected_u.iter()) {
1762            assert_approx_eq!(actual, expected);
1763        }
1764    }
1765
1766    #[test]
1767    fn predict_zero_horizon() {
1768        let unfit = Unfit::new(ModelType {
1769            error: ErrorComponent::Multiplicative,
1770            trend: TrendComponent::Additive,
1771            season: SeasonalComponent::None,
1772        });
1773        let model = unfit.fit(AP).unwrap();
1774        let forecasts = model.predict(0, 0.95).unwrap();
1775        assert!(forecasts.point.is_empty());
1776        let ForecastIntervals { lower, upper, .. } = forecasts.intervals.unwrap();
1777        assert!(lower.is_empty());
1778        assert!(upper.is_empty());
1779    }
1780}