1use std::fmt::{self, Write};
6
7use augurs_core::{ForecastIntervals, Predict};
8use itertools::Itertools;
9use nalgebra::{DMatrix, DVector};
10use rand_distr::{Distribution, Normal};
11use tracing::instrument;
12
13use crate::{
14 ets::{Ets, FitState},
15 stat::VarExt,
16 Error,
17};
18
/// The error component of an ETS model.
///
/// Formats as a single letter: `A` (additive) or `M` (multiplicative).
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum ErrorComponent {
    /// Additive errors.
    Additive,
    /// Multiplicative errors.
    Multiplicative,
}

impl fmt::Display for ErrorComponent {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let letter = match self {
            Self::Additive => 'A',
            Self::Multiplicative => 'M',
        };
        f.write_char(letter)
    }
}
36
/// The trend component of an ETS model.
///
/// Formats as a single letter: `N` (none), `A` (additive) or `M` (multiplicative).
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum TrendComponent {
    /// No trend.
    None,
    /// Additive trend.
    Additive,
    /// Multiplicative trend.
    Multiplicative,
}

impl TrendComponent {
    /// Returns `true` for every variant except [`TrendComponent::None`].
    pub fn included(&self) -> bool {
        !matches!(self, Self::None)
    }
}

impl fmt::Display for TrendComponent {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let letter = match self {
            Self::None => 'N',
            Self::Additive => 'A',
            Self::Multiplicative => 'M',
        };
        f.write_char(letter)
    }
}
64
/// The seasonal component of an ETS model.
///
/// Formats as a single letter: `N` (none), `A` (additive) or `M` (multiplicative);
/// the season length is not included in the formatted output.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum SeasonalComponent {
    /// No seasonality.
    None,
    /// Additive seasonality.
    Additive {
        /// Number of observations per season.
        season_length: usize,
    },
    /// Multiplicative seasonality.
    Multiplicative {
        /// Number of observations per season.
        season_length: usize,
    },
}

impl SeasonalComponent {
    /// Returns `true` for every variant except [`SeasonalComponent::None`].
    pub fn included(&self) -> bool {
        !matches!(self, Self::None)
    }

    /// Returns the season length, treating a non-seasonal model as having a season of length 1.
    pub fn season_length(&self) -> usize {
        match *self {
            Self::None => 1,
            Self::Additive { season_length } | Self::Multiplicative { season_length } => {
                season_length
            }
        }
    }
}

impl fmt::Display for SeasonalComponent {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let letter = match self {
            Self::None => 'N',
            Self::Additive { .. } => 'A',
            Self::Multiplicative { .. } => 'M',
        };
        f.write_char(letter)
    }
}
114
115#[derive(Clone, PartialEq, Debug)]
117pub struct UpperLowerBounds {
118 lower: [f64; 4],
119 upper: [f64; 4],
120}
121
122impl UpperLowerBounds {
123 pub fn new(lower: [f64; 4], upper: [f64; 4]) -> Result<Self, Error> {
130 if lower.iter().zip(&upper).any(|(l, u)| l > u) {
131 Err(Error::InconsistentBounds)
132 } else {
133 Ok(Self { lower, upper })
134 }
135 }
136}
137
138impl Default for UpperLowerBounds {
139 fn default() -> Self {
140 Self {
141 lower: [0.0001, 0.0001, 0.0001, 0.8],
142 upper: [0.9999, 0.9999, 0.9999, 0.98],
143 }
144 }
145}
146
147#[derive(Clone, Debug)]
149pub enum Bounds {
150 Admissible,
152 Usual(UpperLowerBounds),
154 Both(UpperLowerBounds),
156}
157
158impl Bounds {
159 fn for_optimizer(
160 &self,
161 opt_params: &OptimizeParams,
162 n_states: usize,
163 ) -> Option<(Vec<f64>, Vec<f64>)> {
164 match self {
165 Self::Admissible => None,
166 Self::Usual(bounds) | Self::Both(bounds) => {
167 let n_params = opt_params.n_included();
168 let mut lower = Vec::with_capacity(n_params + n_states);
169 let mut upper = Vec::with_capacity(n_params + n_states);
170 if opt_params.alpha {
171 lower.push(bounds.lower[0]);
172 upper.push(bounds.upper[0]);
173 }
174 if opt_params.beta {
175 lower.push(bounds.lower[1]);
176 upper.push(bounds.upper[1]);
177 }
178 if opt_params.gamma {
179 lower.push(bounds.lower[2]);
180 upper.push(bounds.upper[2]);
181 }
182 if opt_params.phi {
183 lower.push(bounds.lower[3]);
184 upper.push(bounds.upper[3]);
185 }
186 for _ in 0..n_states {
187 lower.push(f64::NEG_INFINITY);
188 upper.push(f64::INFINITY);
189 }
190 Some((lower, upper))
191 }
192 }
193 }
194}
195
196impl Default for Bounds {
197 fn default() -> Self {
198 Self::Both(UpperLowerBounds::default())
199 }
200}
201
/// The objective minimised when fitting.
///
/// Each variant selects the corresponding statistic of the fit state evaluated by the
/// optimizer's cost function.
#[derive(Debug, Copy, Clone, Default)]
pub enum OptimizationCriteria {
    /// The (negative log-)likelihood-based objective. This is the default.
    #[default]
    Likelihood,
    /// Mean squared error.
    MSE,
    /// Average mean squared error over several forecast steps.
    AMSE,
    /// Residual variance.
    Sigma,
    /// Mean absolute error.
    MAE,
}
219
220#[derive(Debug, Clone, Copy)]
228pub struct ModelType {
229 pub error: ErrorComponent,
231 pub trend: TrendComponent,
233 pub season: SeasonalComponent,
235}
236
237impl fmt::Display for ModelType {
238 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
239 self.error.fmt(f)?;
240 self.trend.fmt(f)?;
241 self.season.fmt(f)?;
242 Ok(())
243 }
244}
245
/// Smoothing parameters of an ETS model.
///
/// A NaN value means "not provided"; when fitting, a starting value is then chosen from
/// the configured bounds.
#[derive(Debug, Clone)]
pub struct Params {
    /// Smoothing parameter for the level.
    pub alpha: f64,
    /// Smoothing parameter for the trend.
    pub beta: f64,
    /// Smoothing parameter for the seasonal component.
    pub gamma: f64,
    /// Damping parameter for the trend.
    pub phi: f64,
}

impl Default for Params {
    /// Every parameter is NaN, i.e. unset.
    fn default() -> Self {
        let unset = f64::NAN;
        Self {
            alpha: unset,
            beta: unset,
            gamma: unset,
            phi: unset,
        }
    }
}
279
/// Flags marking which smoothing parameters are included in the optimizer's input vector.
#[derive(Debug, Default, Clone)]
pub(crate) struct OptimizeParams {
    /// Whether `alpha` is optimized.
    pub alpha: bool,
    /// Whether `beta` is optimized.
    pub beta: bool,
    /// Whether `gamma` is optimized.
    pub gamma: bool,
    /// Whether `phi` is optimized.
    pub phi: bool,
}

impl OptimizeParams {
    /// Number of parameters flagged for optimization (0 to 4).
    pub(crate) fn n_included(&self) -> usize {
        [self.alpha, self.beta, self.gamma, self.phi]
            .iter()
            .filter(|&&flag| flag)
            .count()
    }
}
306
/// Returns `x` unless it is NaN, in which case `default` is returned instead.
fn not_nan_or(x: f64, default: f64) -> f64 {
    Some(x).filter(|value| !value.is_nan()).unwrap_or(default)
}
315
316#[derive(Debug, Clone)]
318pub struct Unfit {
319 model_type: ModelType,
321
322 damped: bool,
326
327 nmse: usize,
333
334 bounds: Bounds,
340
341 params: Params,
346
347 opt_crit: OptimizationCriteria,
351
352 max_iter: usize,
356}
357
358impl Unfit {
359 pub fn new(model_type: ModelType) -> Self {
361 Self {
362 model_type,
363 damped: false,
364 bounds: Bounds::default(),
365 nmse: 3,
366 params: Params::default(),
367 opt_crit: OptimizationCriteria::default(),
368 max_iter: 2_000,
369 }
370 }
371
372 pub fn params(self, params: Params) -> Self {
376 Self { params, ..self }
377 }
378
379 pub fn nmse(self, nmse: usize) -> Self {
381 Self { nmse, ..self }
382 }
383
384 pub fn opt_crit(self, opt_crit: OptimizationCriteria) -> Self {
386 Self { opt_crit, ..self }
387 }
388
389 pub fn max_iterations(self, max_iterations: usize) -> Self {
391 Self {
392 max_iter: max_iterations,
393 ..self
394 }
395 }
396
397 pub fn damped(self, damped: bool) -> Self {
399 Self { damped, ..self }
400 }
401
402 fn select_alpha(lower: &[f64; 4], upper: &[f64; 4], alpha: f64, m: usize) -> f64 {
404 if alpha.is_nan() {
405 let mut alpha = lower[0] + 0.2 * (upper[0] - lower[0]) / m as f64;
406 if !(0.0..=1.0).contains(&alpha) {
407 alpha = lower[0] + 2e-3;
408 }
409 alpha
410 } else {
411 alpha
412 }
413 }
414
415 fn select_beta(
417 lower: &[f64; 4],
418 upper: &mut [f64; 4],
419 trend: TrendComponent,
420 alpha: f64,
421 beta: f64,
422 ) -> f64 {
423 if trend != TrendComponent::None && beta.is_nan() {
424 upper[1] = upper[1].min(alpha);
426 let mut beta = lower[1] + 0.1 * (upper[1] - lower[1]);
427 if beta < 0.0 || beta > alpha {
428 beta = alpha - 1e-3;
429 }
430 beta
431 } else {
432 beta
433 }
434 }
435
436 fn select_gamma(
438 lower: &[f64; 4],
439 upper: &mut [f64; 4],
440 season: SeasonalComponent,
441 alpha: f64,
442 gamma: f64,
443 ) -> f64 {
444 if season != SeasonalComponent::None && gamma.is_nan() {
445 upper[2] = upper[2].min(1.0 - alpha);
446 let mut gamma = lower[2] + 0.05 * (upper[2] - lower[2]);
447 if gamma < 0.0 || gamma > 1.0 - alpha {
448 gamma = 1.0 - alpha - 1e-3;
449 }
450 gamma
451 } else {
452 gamma
453 }
454 }
455
456 fn select_phi(lower: &[f64; 4], upper: &[f64; 4], damped: bool, phi: f64) -> f64 {
458 if damped && phi.is_nan() {
459 let mut phi = lower[3] + 0.99 * (upper[3] - lower[3]);
460 if !(0.0..=1.0).contains(&phi) {
461 phi = upper[3] - 1e-3;
462 }
463 phi
464 } else {
465 phi
466 }
467 }
468
469 fn initial_params(&mut self) -> Params {
471 let (mut dummy_lower, mut dummy_upper) = ([0.0; 4], [1e-3; 4]);
473 let (lower, upper) = match &mut self.bounds {
474 Bounds::Admissible => (&mut dummy_lower, &mut dummy_upper),
475 Bounds::Usual(UpperLowerBounds { lower, upper }) => (lower, upper),
476 Bounds::Both(UpperLowerBounds { lower, upper }) => (lower, upper),
477 };
478 let alpha = Self::select_alpha(
479 lower,
480 upper,
481 self.params.alpha,
482 self.model_type.season.season_length(),
483 );
484 let beta = Self::select_beta(lower, upper, self.model_type.trend, alpha, self.params.beta);
485 let gamma = Self::select_gamma(
486 lower,
487 upper,
488 self.model_type.season,
489 alpha,
490 self.params.gamma,
491 );
492 let phi = Self::select_phi(lower, upper, self.damped, self.params.phi);
493 Params {
494 alpha,
495 beta,
496 gamma,
497 phi,
498 }
499 }
500
501 fn initial_state(&self, y: &[f64]) -> Result<Vec<f64>, Error> {
503 let n = y.len();
504 let (m, y_sa) = if self.model_type.season == SeasonalComponent::None {
505 (1, y.to_vec())
506 } else {
507 unimplemented!("seasonal component not implemented yet")
508 };
551 let max_n = 10.clamp(m, n);
552 match self.model_type.trend {
553 TrendComponent::None => {
554 let l0 = y_sa.iter().take(max_n).sum::<f64>() / max_n as f64;
555 Ok(vec![l0])
556 }
557 _ => {
558 #[allow(non_snake_case)]
559 let X = DMatrix::from_iterator(
560 max_n,
561 2,
562 std::iter::repeat_n(1.0, max_n)
563 .take(max_n)
564 .chain((1..(max_n + 1)).map(|x| x as f64)),
565 );
566 let y = DVector::from_row_slice(&y_sa[..max_n]);
567 let lstsq = lstsq::lstsq(&X, &y, f64::EPSILON).map_err(Error::LeastSquares)?;
568 let (l, b) = (lstsq.solution[0], lstsq.solution[1]);
569 if self.model_type.trend == TrendComponent::Additive {
570 let (mut l0, mut b0) = (l, b);
571 if (l0 + b0).abs() < 1e-8 {
572 l0 *= 1.0 + 1e-3;
573 b0 *= 1.0 + 1e-3;
574 }
575 Ok(vec![l0, b0])
576 } else {
577 let mut l0 = l + b;
578 if l0.abs() < 1e-8 {
579 l0 *= 1.0 + 1e-3;
580 }
581 let mut b0: f64 = (l + 2.0 * b) / l0;
582 let div = if b0.abs() < 1e-8 { 1e-8 } else { b0 };
583 l0 /= div;
584 if b0.abs() > 1e10 {
585 b0 = b0.signum() * 1e10;
586 }
587 if l0 < 1e-8 || b0 < 1e-8 {
588 l0 = y_sa[0].max(1e-3);
590 let div = if y_sa[0].abs() < 1e-8 { 1e-8 } else { y_sa[0] };
591 b0 = (y_sa[1] / div).max(1e-3);
592 }
593 Ok(vec![l0, b0])
594 }
595 }
596 }
597 }
598
599 #[instrument(skip_all)]
601 pub fn fit(mut self, y: &[f64]) -> Result<Model, Error> {
602 self.nmse = self.nmse.min(30);
603 let season_length = self.model_type.season.season_length();
604
605 let n_states = season_length * self.model_type.season.included() as usize
606 + 1
607 + self.model_type.trend.included() as usize;
608
609 let par_noopt = self.params.clone();
611 let par_ = self.initial_params();
612 let alpha = not_nan_or(par_.alpha, par_noopt.alpha);
613 let beta = not_nan_or(par_.beta, par_noopt.beta);
614 let gamma = not_nan_or(par_.gamma, par_noopt.gamma);
615 let phi = not_nan_or(par_.phi, par_noopt.phi);
616 if !check_params(
617 &self.bounds,
618 season_length,
619 Params {
620 alpha,
621 beta,
622 gamma,
623 phi,
624 },
625 ) {
626 return Err(Error::ParamsOutOfRange);
627 }
628
629 let initial_state = self.initial_state(y)?;
630 let param_arr = [alpha, beta, gamma, phi];
631
632 let x0: Vec<_> = param_arr
633 .iter()
634 .copied()
635 .filter(|&x| !x.is_nan())
636 .chain(initial_state.iter().copied())
637 .collect();
638 let np_ = x0.len();
639 if np_ >= y.len() - 1 {
640 return Err(Error::NotEnoughData);
641 }
642 let opt_params = OptimizeParams {
643 alpha: !alpha.is_nan(),
644 beta: !beta.is_nan(),
645 gamma: !gamma.is_nan(),
646 phi: !phi.is_nan(),
647 };
648
649 let params = Params {
650 alpha,
651 beta: if self.model_type.trend.included() {
652 beta
653 } else {
654 0.0
655 },
656 phi: if self.damped { phi } else { 1.0 },
657 gamma: if self.model_type.season.included() {
658 gamma
659 } else {
660 0.0
661 },
662 };
663
664 let opt_bounds = self.bounds.for_optimizer(&opt_params, n_states);
665 let ets = Ets::new(
667 self.model_type,
668 self.damped,
669 self.nmse,
670 n_states,
671 params,
672 opt_params,
673 self.opt_crit,
674 );
675 let mut problem = ETSProblem::new(y, ets);
676 let simplex = self.param_vecs(x0, opt_bounds.as_ref());
678 let best_params = self.nelder_mead(&mut problem, simplex, opt_bounds.as_ref());
680
681 problem.amse.fill(0.0);
683 problem.denom.fill(0.0);
684 let fit = problem.ets.pegels_resid_in(
685 y,
686 &best_params,
687 problem.x,
688 problem.ets.params.clone(),
689 problem.residuals,
690 problem.forecasts,
691 problem.amse,
692 problem.denom,
693 );
694 let sigma_squared = y
695 .iter()
696 .zip(fit.fitted())
697 .map(|(y, f)| (y - f).powi(2))
698 .sum::<f64>()
699 / (y.len() - fit.n_params() - 1) as f64;
700 Ok(Model::new(problem.ets, fit, sigma_squared.sqrt()))
701 }
702
703 #[instrument(skip_all)]
708 fn param_vecs(&self, mut x0: Vec<f64>, bounds: Option<&(Vec<f64>, Vec<f64>)>) -> Vec<Vec<f64>> {
709 if let Some((lower, upper)) = bounds {
710 Self::restrict_to_bounds(&mut x0, lower, upper);
711 }
712 let n = x0.len();
713
714 let mut simplex = vec![x0; n + 1];
715 let diag = simplex
716 .iter_mut()
717 .take(n)
718 .enumerate()
719 .map(|(i, row)| &mut row[i]);
720 for el in diag {
721 if el.abs() < 1e-8 {
722 *el = 1e-4;
723 } else {
724 *el *= 1.05;
725 }
726 }
727 if let Some((lower, upper)) = bounds {
728 for row in simplex.iter_mut() {
729 Self::restrict_to_bounds(row, lower, upper)
730 }
731 }
732 simplex
733 }
734
735 const TOL_STD: f64 = 1e-4;
736
737 #[instrument(skip_all)]
746 fn nelder_mead(
747 &self,
748 problem: &mut ETSProblem<'_>,
749 mut simplex: Vec<Vec<f64>>,
750 bounds: Option<&(Vec<f64>, Vec<f64>)>,
751 ) -> Vec<f64> {
752 let n_u = simplex[0].len();
753 let n = simplex[0].len() as f64;
754
755 let alpha = 1.0;
756 let gamma = 1.0 + 2.0 / n;
757 let rho = 0.75 - 1.0 / (2.0 * n);
758 let sigma = 1.0 - 1.0 / n;
759
760 let mut f_simplex: Vec<_> = simplex.iter().map(|x| problem.cost(x)).collect();
761 let mut costs_sorted: Vec<_> = f_simplex.iter().copied().enumerate().collect();
762 let mut order_f: Vec<_> = costs_sorted.iter().map(|(i, _)| *i).collect();
763 let mut best_idx = order_f[0];
764 let mut x_o: Vec<_>;
765 let mut x_r: Vec<_>;
766 let mut x_e: Vec<_>;
767 let mut x_oc: Vec<_>;
768 let mut x_ic: Vec<_>;
769 for _ in 0..self.max_iter {
770 costs_sorted.clear();
771 costs_sorted.extend(f_simplex.iter().copied().enumerate());
772 costs_sorted.sort_unstable_by(|(_, a), (_, b)| {
773 a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)
774 });
775 order_f.clear();
776 order_f.extend(costs_sorted.iter().map(|(i, _)| *i));
777
778 best_idx = order_f[0];
779 let worst_idx = order_f[order_f.len() - 1];
780 let second_worst_idx = order_f[order_f.len() - 2];
781
782 if f_simplex.std(0) < Self::TOL_STD {
784 break;
785 }
786
787 x_o = vec![0.0; n_u];
789 for x in simplex
790 .iter()
791 .enumerate()
792 .filter_map(|(i, x)| (i != worst_idx).then_some(x))
793 {
794 for (i, el) in x.iter().enumerate() {
795 x_o[i] += el;
796 }
797 }
798 for x in x_o.iter_mut() {
799 *x /= n;
800 }
801
802 x_r = x_o
804 .iter()
805 .zip(&simplex[worst_idx])
806 .map(|(x_0, x)| x_0 + alpha * (x_0 - x))
807 .collect();
808 if let Some((lower, upper)) = &bounds {
809 Self::restrict_to_bounds(&mut x_r, lower, upper);
810 }
811 let f_r = problem.cost(&x_r);
812 if f_simplex[best_idx] <= f_r && f_r < f_simplex[second_worst_idx] {
813 simplex[worst_idx] = x_r;
814 f_simplex[worst_idx] = f_r;
815 continue;
816 }
817
818 if f_r < f_simplex[best_idx] {
820 x_e = x_o
821 .iter()
822 .zip(&x_r)
823 .map(|(x_o, x_r)| x_o + gamma * (x_r - x_o))
824 .collect();
825 if let Some((lower, upper)) = &bounds {
826 Self::restrict_to_bounds(&mut x_e, lower, upper);
827 }
828 let f_e = problem.cost(&x_e);
829 if f_e < f_r {
830 simplex[worst_idx] = x_e;
831 f_simplex[worst_idx] = f_e;
832 } else {
833 simplex[worst_idx] = x_r;
834 f_simplex[worst_idx] = f_r;
835 }
836 continue;
837 }
838
839 if f_simplex[second_worst_idx] <= f_r && f_r < f_simplex[worst_idx] {
841 x_oc = x_o
842 .iter()
843 .zip(&x_r)
844 .map(|(x_o, x_r)| x_o + rho * (x_r - x_o))
845 .collect();
846 if let Some((lower, upper)) = &bounds {
847 Self::restrict_to_bounds(&mut x_oc, lower, upper);
848 }
849 let f_oc = problem.cost(&x_oc);
850 if f_oc <= f_r {
851 simplex[worst_idx] = x_oc;
852 f_simplex[worst_idx] = f_oc;
853 continue;
854 }
855 } else {
856 x_ic = x_o
858 .iter()
859 .zip(&x_r)
860 .map(|(x_o, x_r)| x_o - rho * (x_r - x_o))
861 .collect();
862 if let Some((lower, upper)) = &bounds {
863 Self::restrict_to_bounds(&mut x_ic, lower, upper);
864 }
865 let f_ic = problem.cost(&x_ic);
866 if f_ic < f_simplex[worst_idx] {
867 simplex[worst_idx] = x_ic;
868 f_simplex[worst_idx] = f_ic;
869 continue;
870 }
871 }
872
873 let best = simplex[best_idx].clone();
875 simplex.iter_mut().enumerate().for_each(|(i, x)| {
876 if i != best_idx {
877 x.iter_mut()
878 .zip(&best)
879 .for_each(|(x, x_best)| *x = x_best + sigma * (*x - x_best));
880 if let Some((lower, upper)) = &bounds {
881 Self::restrict_to_bounds(&mut x_r, lower, upper);
882 }
883 f_simplex[i] = problem.cost(x);
884 }
885 });
886 }
887 simplex[best_idx].clone()
888 }
889
890 fn restrict_to_bounds(x0: &mut [f64], lower: &[f64], upper: &[f64]) {
892 x0.iter_mut()
893 .zip(lower)
894 .zip(upper)
895 .for_each(|((x, &l), &u)| {
896 *x = x.clamp(l, u);
897 });
898 }
899}
900
901fn admissible(alpha: f64, mut beta: f64, gamma: f64, mut phi: f64, m: usize) -> bool {
905 const EPSILON: f64 = 1e-8;
906 if phi.is_nan() {
907 phi = 1.0;
908 }
909 if !(0.0..=1.0 + EPSILON).contains(&phi) {
910 return false;
911 }
912 if gamma.is_nan() {
913 if alpha < 1.0 - 1.0 / phi || alpha > 1.0 + 1.0 / phi {
914 return false;
915 }
916 if !beta.is_nan() && (beta < alpha * (phi - 1.0) || beta > (1.0 + phi) * (2.0 - alpha)) {
917 return false;
918 }
919 } else if m > 1 {
920 if beta.is_nan() {
921 beta = 0.0;
922 }
923 if gamma < f64::max(1.0 - 1.0 / phi - alpha, 0.0) || gamma > 1.0 + 1.0 / phi - alpha {
924 return false;
925 }
926 if alpha
927 < 1.0
928 - 1.0 / phi
929 - gamma * (1.0 - m as f64 + phi + phi * m as f64) / (2.0 * phi * m as f64)
930 {
931 return false;
932 }
933 if beta < -(1.0 - phi) * (gamma / m as f64 + alpha) {
934 return false;
935 }
936 let mut p: Vec<f64> = vec![f64::NAN; 2 + m];
937 p[0] = phi * (1.0 - alpha - gamma);
938 p[1] = alpha + beta - alpha * phi + gamma - 1.0;
939 p[2..m].fill(alpha + beta - alpha * phi);
940 p[m..].fill(alpha + beta - phi);
941 p[m + 1] = 1.0;
942 let roots = roots::find_roots_eigen(p);
943 let max_ = roots
944 .into_iter()
945 .fold(f64::NEG_INFINITY, |max_, r| r.abs().max(max_));
946 if max_ > 1.0 + 1e-10 {
947 return false;
948 }
949 }
950 true
951}
952
953pub(crate) struct ETSProblem<'a> {
959 y: &'a [f64],
960 ets: Ets,
961 x: Vec<f64>,
962 residuals: Vec<f64>,
963 forecasts: Vec<f64>,
964 amse: Vec<f64>,
965 denom: Vec<f64>,
966}
967
968impl<'a> ETSProblem<'a> {
969 pub(crate) fn new(y: &'a [f64], ets: Ets) -> Self {
977 let nmse = ets.nmse;
978 let x_len = ets.n_states * (y.len() + 1);
979 Self {
980 y,
981 ets,
982 x: vec![0.0; x_len],
983 residuals: vec![0.0; y.len()],
984 forecasts: vec![0.0; nmse],
985 amse: vec![0.0; nmse],
986 denom: vec![0.0; nmse],
987 }
988 }
989
990 fn cost(&mut self, inputs: &[f64]) -> f64 {
995 let Ets {
996 params,
997 opt_params,
998 opt_crit,
999 n_states,
1000 ..
1001 } = &self.ets;
1002 let mut params = params.clone();
1003
1004 let mut i = 0;
1007 if opt_params.alpha {
1008 params.alpha = inputs[i];
1009 i += 1;
1010 }
1011 if opt_params.beta {
1012 params.beta = inputs[i];
1013 i += 1;
1014 }
1015 if opt_params.gamma {
1016 params.gamma = inputs[i];
1017 i += 1;
1018 }
1019 if opt_params.phi {
1020 params.phi = inputs[i];
1021 i += 1;
1022 }
1023
1024 let state_inputs = &inputs[i..];
1026 self.x.truncate(state_inputs.len());
1027 self.x.copy_from_slice(state_inputs);
1028 self.x.resize(n_states * (self.y.len() + 1), 0.0);
1029 let fit = self.ets.etscalc_in(
1033 self.y,
1034 &mut self.x,
1035 params,
1036 &mut self.residuals,
1037 &mut self.forecasts,
1038 &mut self.amse,
1039 &mut self.denom,
1040 matches!(
1043 opt_crit,
1044 OptimizationCriteria::MSE | OptimizationCriteria::AMSE
1045 ),
1046 );
1047 match opt_crit {
1048 OptimizationCriteria::Likelihood => fit.likelihood(),
1049 OptimizationCriteria::MSE => fit.mse(),
1050 OptimizationCriteria::AMSE => fit.amse(),
1051 OptimizationCriteria::Sigma => fit.sigma_squared(),
1052 OptimizationCriteria::MAE => fit.mae(),
1053 }
1054 }
1055}
1056
1057fn check_params(bounds: &Bounds, season_length: usize, params: Params) -> bool {
1059 let Params {
1060 alpha,
1061 beta,
1062 gamma,
1063 phi,
1064 } = params;
1065 if let Bounds::Usual(UpperLowerBounds {
1066 lower: [lower_a, lower_b, lower_g, lower_p],
1067 upper: [upper_a, upper_b, upper_g, upper_p],
1068 })
1069 | Bounds::Both(UpperLowerBounds {
1070 lower: [lower_a, lower_b, lower_g, lower_p],
1071 upper: [upper_a, upper_b, upper_g, upper_p],
1072 }) = bounds
1073 {
1074 if !(alpha.is_nan() || alpha >= *lower_a && alpha <= *upper_a) {
1075 return false;
1076 }
1077 if !(beta.is_nan() || beta >= *lower_b && beta <= alpha && beta <= *upper_b) {
1078 return false;
1079 }
1080 if !(gamma.is_nan() || gamma >= *lower_g && gamma <= 1.0 - alpha && gamma <= *upper_g) {
1081 return false;
1082 }
1083 if !(phi.is_nan() || phi >= *lower_p && phi <= *upper_p) {
1084 return false;
1085 }
1086 }
1087 if !matches!(bounds, Bounds::Usual(_)) {
1088 return admissible(alpha, beta, gamma, phi, season_length);
1089 }
1090 true
1091}
1092
1093#[derive(Debug, Clone)]
1095pub struct Model {
1096 ets: Ets,
1098
1099 model_fit: FitState,
1101
1102 sigma: f64,
1107}
1108
1109impl Model {
1110 fn new(ets: Ets, fit: FitState, sigma: f64) -> Model {
1111 Self {
1112 ets,
1113 model_fit: fit,
1114 sigma,
1115 }
1116 }
1117
1118 fn pegels_forecast(&self, horizon: usize) -> Vec<f64> {
1119 let mut forecasts = vec![0.0; horizon];
1120 let states = self.model_fit.states().last().unwrap();
1121 let phi = if self.ets.damped {
1122 self.model_fit.params().phi
1123 } else {
1124 1.0
1125 };
1126 let b = if self.ets.model_type.trend.included() {
1127 Some(states[1])
1128 } else {
1129 None
1130 };
1131 self.ets
1132 .forecast(phi, states[0], b, &mut forecasts, horizon);
1133 forecasts
1134 }
1135
1136 pub fn log_likelihood(&self) -> f64 {
1138 -0.5 * self.model_fit.likelihood()
1139 }
1140
1141 pub fn aic(&self) -> f64 {
1143 self.model_fit.likelihood() + 2.0 * self.model_fit.n_params() as f64
1144 }
1145
1146 pub fn aicc(&self) -> f64 {
1148 let n_y = self.model_fit.residuals().len();
1149 let n_params = self.model_fit.n_params() + 1;
1150 let aic = self.aic();
1151 let denom = n_y - n_params - 1;
1152 if denom != 0 {
1153 aic + 2.0 * n_params as f64 * (n_params as f64 + 1.0) / denom as f64
1154 } else {
1155 f64::INFINITY
1156 }
1157 }
1158
1159 pub fn bic(&self) -> f64 {
1161 self.model_fit.likelihood()
1162 + (self.model_fit.n_params() as f64 + 1.0)
1163 * ((self.model_fit.residuals().len() as f64).ln())
1164 }
1165
1166 pub fn mse(&self) -> f64 {
1168 self.model_fit.mse()
1169 }
1170
1171 pub fn amse(&self) -> f64 {
1175 self.model_fit.amse()
1176 }
1177
1178 pub fn model_type(&self) -> ModelType {
1180 self.ets.model_type
1181 }
1182
1183 pub fn damped(&self) -> bool {
1185 self.ets.damped
1186 }
1187}
1188
1189impl Predict for Model {
1190 type Error = Error;
1191
1192 fn predict_in_sample_inplace(
1193 &self,
1194 level: Option<f64>,
1195 forecast: &mut augurs_core::Forecast,
1196 ) -> Result<(), Self::Error> {
1197 forecast.point = self.model_fit.fitted().to_vec();
1198 if let Some(level) = level {
1199 Forecast(forecast).calculate_in_sample_intervals(self.sigma, level);
1200 }
1201 Ok(())
1202 }
1203
1204 fn predict_inplace(
1205 &self,
1206 horizon: usize,
1207 level: Option<f64>,
1208 forecast: &mut augurs_core::Forecast,
1209 ) -> Result<(), Self::Error> {
1210 if horizon == 0 {
1212 return Ok(());
1213 }
1214 forecast.point = self.pegels_forecast(horizon);
1215 if let Some(level) = level {
1216 Forecast(forecast).calculate_intervals(&self.ets, &self.model_fit, horizon, level);
1217 }
1218 Ok(())
1219 }
1220
1221 fn training_data_size(&self) -> usize {
1222 self.model_fit.residuals().len()
1223 }
1224}
1225
1226struct Forecast<'a>(&'a mut augurs_core::Forecast);
1227
1228impl Forecast<'_> {
1229 fn calculate_intervals(&mut self, ets: &Ets, fit: &FitState, horizon: usize, level: f64) {
1231 let sigma = fit.sigma_squared();
1232 let season_length = ets.model_type.season.season_length();
1233 let season_length_f = season_length as f64;
1234
1235 let ModelType {
1236 error,
1237 trend,
1238 season,
1239 } = ets.model_type;
1240 let steps: Vec<_> = (1..(horizon + 1)).map(|x| x as f64).collect();
1241 let hm = ((horizon - 1) as f64 / season_length_f).floor();
1242
1243 let Params {
1244 alpha,
1245 beta,
1246 gamma,
1247 phi,
1248 } = fit.params();
1249
1250 let alpha_2 = alpha.powi(2);
1251 let phi_2 = phi.powi(2);
1252
1253 let exp3 = 2.0 * alpha * (1.0 - phi) + beta * phi;
1254 let (exp1, exp2, exp4, exp5): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) = steps
1255 .iter()
1256 .copied()
1257 .map(|s| {
1258 let phi_s = phi.powi(s as i32);
1259 (
1260 alpha_2 + alpha * beta * s + (1.0 / 6.0) * beta.powi(2) * s * (2.0 * s - 1.0),
1261 (beta * phi * s) / (1.0 - phi).powi(2),
1262 (beta * phi * (1.0 - phi_s)) / ((1.0 - phi).powi(2) * (1.0 - phi_2)),
1263 2.0 * alpha * (1.0 - phi_2) + beta * phi * (1.0 + 2.0 * phi - phi_s),
1264 )
1265 })
1266 .multiunzip();
1267
1268 use {ErrorComponent as EC, SeasonalComponent as SC, TrendComponent as TC};
1269 let (lower, upper) =
1270 match (error, trend, season, ets.damped) {
1271 (EC::Additive, TC::None, SC::None, false) => {
1274 let sigma_h = steps
1275 .iter()
1276 .map(|s| (((s - 1.0) * alpha.powi(2) + 1.0) * sigma).sqrt());
1277 self.compute_intervals(level, sigma_h)
1278 }
1279 (EC::Additive, TC::Additive, SC::None, false) => {
1281 let sigma_h = steps
1282 .iter()
1283 .zip(&exp1)
1284 .map(|(s, e)| ((1.0 + (s - 1.0) * e) * sigma).sqrt());
1285 self.compute_intervals(level, sigma_h)
1286 }
1287 (EC::Additive, TC::Additive, SC::None, true) => {
1289 let sigma_h =
1290 steps
1291 .iter()
1292 .zip(&exp2)
1293 .zip(&exp4)
1294 .zip(&exp5)
1295 .map(|(((s, e2), e4), e5)| {
1296 ((1.0 + alpha_2 * (s - 1.0) + e2 * exp3 - e4 * e5) * sigma).sqrt()
1297 });
1298 self.compute_intervals(level, sigma_h)
1299 }
1300 (EC::Additive, TC::None, SC::Additive { .. }, false) => {
1302 let sigma_h = steps.iter().map(|s| {
1303 ((1.0 + alpha_2 * (s - 1.0) + gamma * hm * (2.0 * alpha * gamma)) * sigma)
1304 .sqrt()
1305 });
1306 self.compute_intervals(level, sigma_h)
1307 }
1308 (EC::Additive, TC::Additive, SC::Additive { .. }, false) => {
1310 let sigma_h = steps.iter().zip(&exp1).map(|(s, e1)| {
1311 let e6 = 2.0 * alpha + gamma + beta * season_length_f * (hm + 1.0);
1312 ((1.0 + (s - 1.0) * e1 * gamma * hm * e6) * sigma).sqrt()
1313 });
1314 self.compute_intervals(level, sigma_h)
1315 }
1316 (EC::Additive, TC::Additive, SC::Additive { season_length }, true) => {
1318 let sigma_h = steps.iter().zip(&exp2).zip(&exp4).zip(&exp5).map(
1319 |(((&s, e2), e4), e5)| {
1320 let phi_s = phi.powi(s as i32);
1321 let e7 = (2.0 * beta * gamma * phi) / ((1.0 - phi) * (1.0 - phi_s));
1322 let e8 = hm * (1.0 - phi_s)
1323 - phi_s * (1.0 - phi.powi(season_length as i32 * hm as i32));
1324 ((1.0 + alpha_2 * (s - 1.0) + e2 * exp3 - e4 * e5
1325 + gamma * hm * (2.0 * alpha + gamma)
1326 + e7 * e8)
1327 * sigma)
1328 .sqrt()
1329 },
1330 );
1331 self.compute_intervals(level, sigma_h)
1332 }
1333 (EC::Multiplicative, TC::None, SC::None, false) => {
1336 let cvals = std::iter::repeat_n(*alpha, horizon);
1337 let sigma_h = self.compute_sigma_h(sigma, cvals, horizon);
1338 self.compute_intervals(level, sigma_h.into_iter())
1339 }
1340 (EC::Multiplicative, TC::Additive, SC::None, false) => {
1342 let cvals = steps.iter().map(|s| alpha + beta * s);
1343 let sigma_h = self.compute_sigma_h(sigma, cvals, horizon);
1344 self.compute_intervals(level, sigma_h.into_iter())
1345 }
1346 (EC::Multiplicative, TC::Additive, SC::None, true) => {
1348 let mut cvals: Vec<_> = vec![f64::NAN; horizon];
1349 for k in 1..(horizon + 1) {
1350 let sum_phi = (1..(k + 1)).map(|j| phi.powi(j as i32)).sum::<f64>();
1351 cvals[k - 1] = alpha + beta * sum_phi;
1352 }
1353 let sigma_h = self.compute_sigma_h(sigma, cvals.into_iter(), horizon);
1354 self.compute_intervals(level, sigma_h.into_iter())
1355 }
1356 (EC::Multiplicative, TC::None, SC::Additive { .. }, false) => todo!(),
1359 (EC::Multiplicative, TC::Additive, SC::Additive { .. }, false) => todo!(),
1361 (EC::Multiplicative, TC::Additive, SC::Additive { .. }, true) => todo!(),
1363 (EC::Multiplicative, _, SC::Multiplicative { .. }, _) => {
1366 unimplemented!(
1367 "Prediction intervals for class 3 models are not implemented yet"
1368 )
1369 }
1370 (_, _, SC::None, _) => {
1373 self.simulate(ets, fit, horizon, level)
1375 }
1376 _ => unimplemented!("Prediction intervals for this model are not implemented yet"),
1378 };
1379 self.0.intervals = Some(ForecastIntervals {
1380 level,
1381 lower,
1382 upper,
1383 });
1384 }
1385
1386 fn compute_intervals(
1391 &self,
1392 level: f64,
1393 sigma_h: impl Iterator<Item = f64>,
1394 ) -> (Vec<f64>, Vec<f64>) {
1395 let z = distrs::Normal::ppf(0.5 + level / 2.0, 0.0, 1.0);
1396 self.0
1397 .point
1398 .iter()
1399 .zip(sigma_h)
1400 .map(|(p, s)| (p - z * s, p + z * s))
1401 .unzip()
1402 }
1403
1404 fn compute_sigma_h(
1407 &self,
1408 sigma: f64,
1409 cvals: impl Iterator<Item = f64>,
1410 horizon: usize,
1411 ) -> Vec<f64> {
1412 let cvals_squared: Vec<_> = cvals.map(|c| c.powi(2)).collect();
1413 let theta =
1414 &self
1416 .0
1417 .point
1418 .iter()
1419 .take(horizon)
1421 .fold(Vec::with_capacity(horizon), |mut acc, p| {
1422 let t = p.powi(2)
1428 + acc
1429 .iter()
1430 .rev()
1431 .zip(&cvals_squared)
1432 .map(|(t, c)| t * c)
1433 .sum::<f64>()
1434 * sigma;
1435 acc.push(t);
1436 acc
1437 });
1438 theta
1439 .iter()
1440 .zip(&self.0.point)
1441 .map(|(t, p)| ((1.0 + sigma) * t - p.powi(2)).sqrt())
1442 .collect()
1443 }
1444
1445 fn simulate(
1446 &self,
1447 ets: &Ets,
1448 fit: &FitState,
1449 horizon: usize,
1450 level: f64,
1451 ) -> (Vec<f64>, Vec<f64>) {
1452 let n_sim = 5000;
1453 let last_state = fit.last_state();
1454 let mut y_path = vec![vec![0.0; horizon]; n_sim];
1455 let params = fit.params();
1456 let beta = if params.beta.is_nan() {
1457 0.0
1458 } else {
1459 params.beta
1460 };
1461 let gamma = if params.gamma.is_nan() {
1462 0.0
1463 } else {
1464 params.gamma
1465 };
1466 let phi = if params.phi.is_nan() { 0.0 } else { params.phi };
1467 let rng = &mut rand::thread_rng();
1468 let normal = Normal::new(0.0, fit.sigma_squared().sqrt()).unwrap();
1469 let mut f = vec![0.0; 10];
1472 for y_path_k in &mut y_path {
1473 let e: Vec<_> = (0..horizon).map(|_| normal.sample(rng)).collect();
1474 ets.etssimulate(
1475 last_state,
1476 Params {
1477 alpha: params.alpha,
1478 beta,
1479 gamma,
1480 phi,
1481 },
1482 &e,
1483 &mut f,
1484 y_path_k,
1485 );
1486 f.iter_mut().for_each(|f| *f = 0.0);
1487 }
1488 y_path
1489 .into_iter()
1490 .map(|mut yhat| {
1491 yhat.sort_by(|a, b| a.partial_cmp(b).unwrap());
1492 (
1493 percentile_of_sorted(&yhat, 0.5 - level / 2.0),
1494 percentile_of_sorted(&yhat, 0.5 + level / 2.0),
1495 )
1496 })
1497 .unzip()
1498 }
1499
1500 fn calculate_in_sample_intervals(&mut self, sigma: f64, level: f64) {
1501 let (lower, upper) = self.compute_intervals(level, std::iter::repeat(sigma));
1502 self.0.intervals = Some(ForecastIntervals {
1503 level,
1504 lower,
1505 upper,
1506 });
1507 }
1508}
1509
/// Linearly interpolated percentile of an ascending-sorted, non-empty slice.
///
/// `pct` is a percentage in `[0, 100]`, not a fraction.
///
/// # Panics
///
/// Panics if `sorted_samples` is empty. Also panics if `pct` is outside `[0, 100]` (or NaN),
/// except that a single-element slice returns its element for any `pct`.
fn percentile_of_sorted(sorted_samples: &[f64], pct: f64) -> f64 {
    assert!(!sorted_samples.is_empty());
    if let [only] = sorted_samples {
        return *only;
    }
    let zero: f64 = 0.0;
    assert!(zero <= pct);
    let hundred = 100_f64;
    assert!(pct <= hundred);

    let last = sorted_samples.len() - 1;
    if pct == hundred {
        return sorted_samples[last];
    }
    let rank = (pct / hundred) * last as f64;
    let floor = rank.floor();
    let weight = rank - floor;
    let idx = floor as usize;
    let (below, above) = (sorted_samples[idx], sorted_samples[idx + 1]);
    below + (above - below) * weight
}
1533
#[cfg(test)]
mod test {
    use augurs_core::prelude::*;
    use augurs_testing::{assert_approx_eq, assert_within_pct, data::AIR_PASSENGERS as AP};

    use crate::model::{
        ErrorComponent, ForecastIntervals, ModelType, SeasonalComponent, TrendComponent, Unfit,
    };

    /// Build a non-seasonal [`ModelType`] from the given error and trend components.
    fn non_seasonal(error: ErrorComponent, trend: TrendComponent) -> ModelType {
        ModelType {
            error,
            trend,
            season: SeasonalComponent::None,
        }
    }

    /// Compare `actual` against `expected` element-wise with `assert_approx_eq!`.
    ///
    /// Pairs are formed with `zip`, so only the shorter of the two lengths is
    /// compared. Each test asserts the full length separately.
    fn assert_all_approx(actual: &[f64], expected: &[f64]) {
        for (got, want) in actual.iter().zip(expected) {
            assert_approx_eq!(got, want);
        }
    }

    #[test]
    fn initial_params() {
        let mut unfit = Unfit::new(non_seasonal(
            ErrorComponent::Additive,
            TrendComponent::None,
        ));
        let params = unfit.initial_params();
        assert_approx_eq!(params.alpha, 0.20006);
        // With no trend or season, the remaining smoothing parameters stay unset.
        for unset in [params.beta, params.gamma, params.phi] {
            assert!(unset.is_nan());
        }
    }

    #[test]
    fn air_passengers_fit_aan() {
        let tail = &AP[AP.len() - 20..];
        let unfit = Unfit::new(non_seasonal(
            ErrorComponent::Additive,
            TrendComponent::Additive,
        ))
        .damped(true);
        let model = unfit.fit(tail).unwrap();
        assert_within_pct!(model.log_likelihood(), -109.6248525790271, 0.01);
        assert_within_pct!(model.aic(), 231.2497051580542, 0.01);
        assert_within_pct!(model.bic(), 237.22409879937817, 0.01);
        assert_within_pct!(model.aicc(), 237.71124361959266, 0.01);
        assert_within_pct!(model.mse(), 2883.47944444736, 0.01);
        assert_within_pct!(model.amse(), 8292.71075580747, 0.01);
    }

    #[test]
    fn air_passengers_fit_man() {
        let unfit = Unfit::new(non_seasonal(
            ErrorComponent::Multiplicative,
            TrendComponent::Additive,
        ));
        let model = unfit.fit(AP).unwrap();
        assert_within_pct!(model.log_likelihood(), -831.4883541595792, 0.01);
        assert_within_pct!(model.aic(), 1672.9767083191584, 0.01);
        assert_within_pct!(model.bic(), 1687.8257748170383, 0.01);
        assert_within_pct!(model.aicc(), 1673.4114909278542, 0.01);
        assert_within_pct!(model.mse(), 1127.443938773091, 0.01);
        assert_within_pct!(model.amse(), 2888.3802507845635, 0.01);
    }

    #[test]
    fn air_passengers_forecast_aan() {
        let tail = &AP[AP.len() - 20..];
        let unfit = Unfit::new(non_seasonal(
            ErrorComponent::Additive,
            TrendComponent::Additive,
        ))
        .damped(true);
        let model = unfit.fit(tail).unwrap();
        let forecasts = model.predict(10, 0.95).unwrap();

        let point_want = [
            432.26645246,
            432.53827337,
            432.75575609,
            432.92976307,
            433.0689853,
            433.18037639,
            433.26949992,
            433.34080727,
            433.39785997,
            433.44350758,
        ];
        assert_eq!(forecasts.point.len(), 10);
        assert_all_approx(&forecasts.point, &point_want);

        let lower_want = [
            301.72457857,
            247.92511851,
            206.64496117,
            171.83062947,
            141.14177344,
            113.38060224,
            87.83698619,
            64.04903959,
            41.69638225,
            20.54598327,
        ];
        let upper_want = [
            562.80832636,
            617.15142823,
            658.86655102,
            694.02889667,
            724.99619716,
            752.98015054,
            778.70201365,
            802.63257495,
            825.09933768,
            846.34103189,
        ];
        let ForecastIntervals { lower, upper, .. } = forecasts.intervals.unwrap();
        assert_eq!(lower.len(), 10);
        assert_all_approx(&lower, &lower_want);
        assert_eq!(upper.len(), 10);
        assert_all_approx(&upper, &upper_want);
    }

    #[test]
    fn air_passengers_forecast_man() {
        let unfit = Unfit::new(non_seasonal(
            ErrorComponent::Multiplicative,
            TrendComponent::Additive,
        ));
        let model = unfit.fit(AP).unwrap();

        // Out-of-sample forecasts.
        let forecasts = model.predict(10, 0.95).unwrap();
        let point_want = [
            436.15668239,
            440.31714837,
            444.47761434,
            448.63808031,
            452.79854629,
            456.95901226,
            461.11947823,
            465.27994421,
            469.44041018,
            473.60087615,
        ];
        assert_eq!(forecasts.point.len(), 10);
        assert_all_approx(&forecasts.point, &point_want);

        let lower_want = [
            345.14145884,
            310.62430297,
            284.42938026,
            262.42886479,
            243.03658151,
            225.44516176,
            209.1784846,
            193.92853297,
            179.48284058,
            165.68775958,
        ];
        let upper_want = [
            527.17190595,
            570.00999376,
            604.52584842,
            634.84729584,
            662.56051106,
            688.47286276,
            713.06047187,
            736.63135545,
            759.39797978,
            781.51399273,
        ];
        let ForecastIntervals { lower, upper, .. } = forecasts.intervals.unwrap();
        assert_eq!(lower.len(), 10);
        assert_all_approx(&lower, &lower_want);
        assert_eq!(upper.len(), 10);
        assert_all_approx(&upper, &upper_want);

        // In-sample fit: full length is checked, values only for the first ten points.
        let in_sample = model.predict_in_sample(0.95).unwrap();
        let fitted_want = [
            110.74681112,
            116.18804955,
            122.18817486,
            136.18835606,
            133.18933724,
            125.18861841,
            139.18739947,
            152.18838061,
            152.18926187,
            140.18884303,
        ];
        assert_eq!(in_sample.point.len(), AP.len());
        assert_all_approx(&in_sample.point, &fitted_want);

        let fitted_lower_want = [
            43.76306764,
            49.20430607,
            55.20443139,
            69.20461258,
            66.20559377,
            58.20487493,
            72.203656,
            85.20463713,
            85.20551839,
            73.20509956,
        ];
        let fitted_upper_want = [
            177.73055459,
            183.17179302,
            189.17191834,
            203.17209954,
            200.17308072,
            192.17236188,
            206.17114295,
            219.17212409,
            219.17300535,
            207.17258651,
        ];
        let ForecastIntervals {
            lower: fitted_lower,
            upper: fitted_upper,
            ..
        } = in_sample.intervals.unwrap();
        assert_eq!(fitted_lower.len(), AP.len());
        assert_all_approx(&fitted_lower, &fitted_lower_want);
        assert_eq!(fitted_upper.len(), AP.len());
        assert_all_approx(&fitted_upper, &fitted_upper_want);
    }

    #[test]
    fn predict_zero_horizon() {
        let unfit = Unfit::new(non_seasonal(
            ErrorComponent::Multiplicative,
            TrendComponent::Additive,
        ));
        let model = unfit.fit(AP).unwrap();
        let forecasts = model.predict(0, 0.95).unwrap();
        assert!(forecasts.point.is_empty());
        // Intervals are still present, just empty.
        let ForecastIntervals { lower, upper, .. } = forecasts.intervals.unwrap();
        assert!(lower.is_empty());
        assert!(upper.is_empty());
    }
}