1use std::collections::HashMap;
2
3use augurs_core::FloatIterExt;
4use itertools::{izip, Itertools};
5use rand::{distributions::Uniform, thread_rng, Rng};
6use statrs::distribution::{Laplace, Normal, Poisson};
7
8use crate::{optimizer::OptimizedParams, Error, GrowthType, Prophet, TimestampSeconds};
9
10use super::prep::{ComponentName, Features, FeaturesFrame, Modes, ProcessedData};
11
/// A prediction for a single feature (or for the overall forecast).
///
/// Holds one point estimate per timestamp and, when uncertainty intervals
/// were computed for this feature, the matching lower and upper bounds.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct FeaturePrediction {
    /// Point estimates, one per timestamp.
    pub point: Vec<f64>,
    /// Lower bounds of the uncertainty interval, one per timestamp.
    ///
    /// `None` if uncertainty intervals were not computed for this feature.
    pub lower: Option<Vec<f64>>,
    /// Upper bounds of the uncertainty interval, one per timestamp.
    ///
    /// `None` if uncertainty intervals were not computed for this feature.
    pub upper: Option<Vec<f64>>,
}
33
34#[derive(Debug, Default)]
35pub(super) struct FeaturePredictions {
36 pub(super) additive: FeaturePrediction,
40 pub(super) multiplicative: FeaturePrediction,
44 pub(super) holidays: HashMap<String, FeaturePrediction>,
46 pub(super) regressors: HashMap<String, FeaturePrediction>,
48 pub(super) seasonalities: HashMap<String, FeaturePrediction>,
50}
51
52#[derive(Debug, Clone)]
62pub struct Predictions {
63 pub ds: Vec<TimestampSeconds>,
65
66 pub yhat: FeaturePrediction,
68
69 pub trend: FeaturePrediction,
71
72 pub cap: Option<Vec<f64>>,
76 pub floor: Option<Vec<f64>>,
81
82 pub additive: FeaturePrediction,
87
88 pub multiplicative: FeaturePrediction,
93
94 pub holidays: HashMap<String, FeaturePrediction>,
96
97 pub seasonalities: HashMap<String, FeaturePrediction>,
99
100 pub regressors: HashMap<String, FeaturePrediction>,
102}
103
/// Whether historical timestamps should be included alongside future ones,
/// e.g. when calling `Prophet::make_future_dataframe`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IncludeHistory {
    /// Include the historical timestamps.
    Yes,
    /// Only produce future timestamps.
    No,
}
112
113#[derive(Debug)]
114pub(super) struct PosteriorPredictiveSamples {
115 pub(super) yhat: Vec<Vec<f64>>,
116 pub(super) trend: Vec<Vec<f64>>,
117}
118
119impl<O> Prophet<O> {
120 pub(super) fn predict_trend(
122 &self,
123 t: &[f64],
124 cap: &Option<Vec<f64>>,
125 floor: &[f64],
126 changepoints_t: &[f64],
127 params: &OptimizedParams,
128 y_scale: f64,
129 ) -> Result<FeaturePrediction, Error> {
130 let point = match (self.opts.growth, cap) {
131 (GrowthType::Linear, _) => {
132 Self::piecewise_linear(t, ¶ms.delta, params.k, params.m, changepoints_t)
133 .zip(floor)
134 .map(|(trend, flr)| trend * y_scale + flr)
135 .collect_vec()
136 }
137 (GrowthType::Logistic, Some(cap)) => {
138 Self::piecewise_logistic(t, cap, ¶ms.delta, params.k, params.m, changepoints_t)
139 .zip(floor)
140 .map(|(trend, flr)| trend * y_scale + flr)
141 .collect_vec()
142 }
143 (GrowthType::Logistic, None) => return Err(Error::MissingCap),
144 (GrowthType::Flat, _) => Self::flat_trend(t, params.m)
145 .zip(floor)
146 .map(|(trend, flr)| trend * y_scale + flr)
147 .collect_vec(),
148 };
149 Ok(FeaturePrediction {
150 point,
151 lower: None,
152 upper: None,
153 })
154 }
155
156 fn piecewise_linear<'a>(
157 t: &'a [f64],
158 deltas: &'a [f64],
159 k: f64,
160 m: f64,
161 changepoints_t: &'a [f64],
162 ) -> impl Iterator<Item = f64> + 'a {
163 let cp_zipped = deltas.iter().zip(changepoints_t);
167 let deltas_t = cp_zipped
168 .cartesian_product(t)
169 .map(|((delta, cp_t), t)| if cp_t <= t { *delta } else { 0.0 });
170
171 let changepoints_repeated = changepoints_t
173 .iter()
174 .flat_map(|x| std::iter::repeat_n(*x, t.len()));
175 let indexes = (0..t.len()).cycle();
176 let k_m_t = izip!(deltas_t, changepoints_repeated, indexes).fold(
179 vec![(k, m); t.len()],
180 |mut acc, (delta, cp_t, idx)| {
181 acc[idx].0 += delta;
183 acc[idx].1 += -cp_t * delta;
185 acc
186 },
187 );
188
189 izip!(t, k_m_t).map(|(t, (k, m))| t * k + m)
190 }
191
192 fn piecewise_logistic<'a>(
193 t: &'a [f64],
194 cap: &'a [f64],
195 deltas: &'a [f64],
196 k: f64,
197 m: f64,
198 changepoints_t: &'a [f64],
199 ) -> impl Iterator<Item = f64> + 'a {
200 let k_cum = std::iter::once(k)
202 .chain(deltas.iter().scan(k, |state, delta| {
203 *state += delta;
204 Some(*state)
205 }))
206 .collect_vec();
207 let mut gammas = vec![0.0; changepoints_t.len()];
208 let mut gammas_sum = 0.0;
209 for (i, t_s) in changepoints_t.iter().enumerate() {
210 gammas[i] = (t_s - m - gammas_sum) * (1.0 - k_cum[i] / k_cum[i + 1]);
211 gammas_sum += gammas[i];
212 }
213
214 let mut k_t = vec![k; t.len()];
216 let mut m_t = vec![m; t.len()];
217 for (s, t_s) in changepoints_t.iter().enumerate() {
218 for (i, t_i) in t.iter().enumerate() {
219 if t_i >= t_s {
220 k_t[i] += deltas[s];
221 m_t[i] += gammas[s];
222 }
223 }
224 }
225
226 izip!(cap, t, k_t, m_t).map(|(cap, t, k, m)| cap / (1.0 + (-k * (t - m)).exp()))
227 }
228
229 fn flat_trend(t: &[f64], m: f64) -> impl Iterator<Item = f64> {
231 std::iter::repeat_n(m, t.len())
232 }
233
234 pub(super) fn predict_features(
236 &self,
237 features: &Features,
238 params: &OptimizedParams,
239 y_scale: f64,
240 ) -> Result<FeaturePredictions, Error> {
241 let Features {
242 features,
243 component_columns,
244 modes,
245 ..
246 } = features;
247 let predict_feature = |col, f: fn(String) -> ComponentName| {
248 Self::predict_components(col, &features.data, ¶ms.beta, y_scale, modes, f)
249 };
250 Ok(FeaturePredictions {
251 additive: Self::predict_feature(
252 &component_columns.additive,
253 &features.data,
254 ¶ms.beta,
255 y_scale,
256 true,
257 ),
258 multiplicative: Self::predict_feature(
259 &component_columns.multiplicative,
260 &features.data,
261 ¶ms.beta,
262 y_scale,
263 false,
264 ),
265 holidays: predict_feature(&component_columns.holidays, ComponentName::Holiday),
266 seasonalities: predict_feature(
267 &component_columns.seasonalities,
268 ComponentName::Seasonality,
269 ),
270 regressors: predict_feature(&component_columns.regressors, ComponentName::Regressor),
271 })
272 }
273
274 fn predict_components(
275 component_columns: &HashMap<String, Vec<i32>>,
276 #[allow(non_snake_case)] X: &[Vec<f64>],
277 beta: &[f64],
278 y_scale: f64,
279 modes: &Modes,
280 make_mode: impl Fn(String) -> ComponentName,
281 ) -> HashMap<String, FeaturePrediction> {
282 component_columns
283 .iter()
284 .map(|(name, component_col)| {
285 (
286 name.clone(),
287 Self::predict_feature(
288 component_col,
289 X,
290 beta,
291 y_scale,
292 modes.additive.contains(&make_mode(name.clone())),
293 ),
294 )
295 })
296 .collect()
297 }
298
299 pub(super) fn predict_feature(
300 component_col: &[i32],
301 #[allow(non_snake_case)] X: &[Vec<f64>],
302 beta: &[f64],
303 y_scale: f64,
304 is_additive: bool,
305 ) -> FeaturePrediction {
306 let beta_c = component_col
307 .iter()
308 .copied()
309 .zip(beta)
310 .map(|(x, b)| x as f64 * b)
311 .collect_vec();
312 let mut point = vec![0.0; X[0].len()];
314 for (feature, b) in izip!(X, beta_c) {
315 for (p, x) in izip!(point.iter_mut(), feature) {
316 *p += b * x;
317 }
318 }
319 if is_additive {
320 point.iter_mut().for_each(|x| *x *= y_scale);
321 }
322 FeaturePrediction {
323 point,
324 lower: None,
325 upper: None,
326 }
327 }
328
329 #[allow(clippy::too_many_arguments)]
330 pub(super) fn predict_uncertainty(
331 &self,
332 df: &ProcessedData,
333 features: &Features,
334 params: &OptimizedParams,
335 changepoints_t: &[f64],
336 yhat: &mut FeaturePrediction,
337 trend: &mut FeaturePrediction,
338 y_scale: f64,
339 ) -> Result<(), Error> {
340 let mut sim_values =
341 self.sample_posterior_predictive(df, features, params, changepoints_t, y_scale)?;
342 let lower_p = 100.0 * (1.0 - *self.opts.interval_width) / 2.0;
343 let upper_p = 100.0 * (1.0 + *self.opts.interval_width) / 2.0;
344
345 let mut yhat_lower = Vec::with_capacity(df.ds.len());
346 let mut yhat_upper = Vec::with_capacity(df.ds.len());
347 let mut trend_lower = Vec::with_capacity(df.ds.len());
348 let mut trend_upper = Vec::with_capacity(df.ds.len());
349
350 for (yhat_samples, trend_samples) in
351 sim_values.yhat.iter_mut().zip(sim_values.trend.iter_mut())
352 {
353 yhat_samples
355 .sort_unstable_by(|a, b| a.partial_cmp(b).expect("found NaN in yhat sample"));
356 trend_samples
357 .sort_unstable_by(|a, b| a.partial_cmp(b).expect("found NaN in yhat sample"));
358 yhat_lower.push(percentile_of_sorted(yhat_samples, lower_p));
359 yhat_upper.push(percentile_of_sorted(yhat_samples, upper_p));
360 trend_lower.push(percentile_of_sorted(trend_samples, lower_p));
361 trend_upper.push(percentile_of_sorted(trend_samples, upper_p));
362 }
363 yhat.lower = Some(yhat_lower);
364 yhat.upper = Some(yhat_upper);
365 trend.lower = Some(trend_lower);
366 trend.upper = Some(trend_upper);
367 Ok(())
368 }
369
370 pub(super) fn sample_posterior_predictive(
372 &self,
373 df: &ProcessedData,
374 features: &Features,
375 params: &OptimizedParams,
376 changepoints_t: &[f64],
377 y_scale: f64,
378 ) -> Result<PosteriorPredictiveSamples, Error> {
379 let n_iterations = 1;
381 let samples_per_iter = usize::max(
382 1,
383 (self.opts.uncertainty_samples as f64 / n_iterations as f64).ceil() as usize,
384 );
385 let Features {
386 features,
387 component_columns,
388 ..
389 } = features;
390 let n_timestamps = df.ds.len();
397 let n_samples = samples_per_iter * n_iterations;
398 let mut sim_values = PosteriorPredictiveSamples {
399 yhat: std::iter::repeat_with(|| Vec::with_capacity(n_samples))
400 .take(n_timestamps)
401 .collect_vec(),
402 trend: std::iter::repeat_with(|| Vec::with_capacity(n_samples))
403 .take(n_timestamps)
404 .collect_vec(),
405 };
406 let (mut yhat, mut trend) = (
409 Vec::with_capacity(n_timestamps),
410 Vec::with_capacity(n_timestamps),
411 );
412 for i in 0..n_iterations {
413 for _ in 0..samples_per_iter {
414 self.sample_model(
415 df,
416 features,
417 params,
418 changepoints_t,
419 &component_columns.additive,
420 &component_columns.multiplicative,
421 y_scale,
422 i,
423 &mut yhat,
424 &mut trend,
425 )?;
426 for ((i, yhat), trend) in yhat.iter().enumerate().zip(&trend) {
428 sim_values.yhat[i].push(*yhat);
429 sim_values.trend[i].push(*trend);
430 }
431 }
432 }
433 debug_assert_eq!(sim_values.yhat.len(), n_timestamps);
434 debug_assert_eq!(sim_values.trend.len(), n_timestamps);
435 Ok(sim_values)
436 }
437
438 #[allow(clippy::too_many_arguments)]
440 fn sample_model(
441 &self,
442 df: &ProcessedData,
443 features: &FeaturesFrame,
444 params: &OptimizedParams,
445 changepoints_t: &[f64],
446 additive: &[i32],
447 multiplicative: &[i32],
448 y_scale: f64,
449 iteration: usize,
450 yhat_tmp: &mut Vec<f64>,
451 trend_tmp: &mut Vec<f64>,
452 ) -> Result<(), Error> {
453 yhat_tmp.clear();
454 trend_tmp.clear();
455 let n = df.ds.len();
456 *trend_tmp =
457 self.sample_predictive_trend(df, params, changepoints_t, y_scale, iteration)?;
458 let beta = ¶ms.beta;
459 let mut xb_a = vec![0.0; n];
460 for (feature, b, a) in izip!(&features.data, beta, additive) {
461 for (p, x) in izip!(&mut xb_a, feature) {
462 *p += x * b * *a as f64;
463 }
464 }
465 xb_a.iter_mut().for_each(|x| *x *= y_scale);
466 let mut xb_m = vec![0.0; n];
467 for (feature, b, m) in izip!(&features.data, beta, multiplicative) {
468 for (p, x) in izip!(&mut xb_m, feature) {
469 *p += x * b * *m as f64;
470 }
471 }
472
473 let sigma = params.sigma_obs;
474 let dist = Normal::new(0.0, *sigma).expect("sigma must be non-negative");
475 let mut rng = thread_rng();
476 let noise = (&mut rng).sample_iter(dist).take(n).map(|x| x * y_scale);
477
478 for yhat in izip!(trend_tmp, &xb_a, &xb_m, noise).map(|(t, a, m, n)| *t * (1.0 + m) + a + n)
479 {
480 yhat_tmp.push(yhat);
481 }
482
483 Ok(())
484 }
485
486 fn sample_predictive_trend(
487 &self,
488 df: &ProcessedData,
489 params: &OptimizedParams,
490 changepoints_t: &[f64],
491 y_scale: f64,
492 _iteration: usize, ) -> Result<Vec<f64>, Error> {
494 let deltas = ¶ms.delta;
495
496 let t_max = df.t.iter().copied().nanmax(true);
497
498 let mut rng = thread_rng();
499
500 let n_changes = if t_max > 1.0 {
501 let n_cp = changepoints_t.len() as i32;
503 let lambda = n_cp as f64 * (t_max - 1.0);
504 let dist = Poisson::new(lambda).expect("Valid Poisson distribution");
506 rng.sample::<f64, _>(dist).round() as usize
507 } else {
508 0
509 };
510 let changepoints_t_new = if n_changes > 0 {
511 let mut cp_t_new = (&mut rng)
512 .sample_iter(Uniform::new(0.0, t_max - 1.0))
513 .take(n_changes)
514 .map(|x| x + 1.0)
515 .collect_vec();
516 cp_t_new.sort_unstable_by(|a, b| {
517 a.partial_cmp(b)
518 .expect("uniform distribution should not sample NaNs")
519 });
520 cp_t_new
521 } else {
522 vec![]
523 };
524
525 let mut lambda = deltas.iter().map(|x| x.abs()).nanmean(false) + 1e-8;
527 if lambda.is_nan() {
528 lambda = 1e-8;
529 }
530 let dist = Laplace::new(0.0, lambda).expect("Valid Laplace distribution");
533 let deltas_new = rng.sample_iter(dist).take(n_changes);
534
535 let all_changepoints_t = changepoints_t
537 .iter()
538 .copied()
539 .chain(changepoints_t_new)
540 .collect_vec();
541 let all_deltas = deltas.iter().copied().chain(deltas_new).collect_vec();
542
543 let new_params = OptimizedParams {
545 delta: all_deltas,
546 ..params.clone()
547 };
548 let trend = self.predict_trend(
549 &df.t,
550 &df.cap_scaled,
551 &df.floor,
552 &all_changepoints_t,
553 &new_params,
554 y_scale,
555 )?;
556 Ok(trend.point)
557 }
558}
559
/// Linearly interpolated percentile of a sorted, non-empty slice.
///
/// `pct` is on a 0–100 scale. The fractional rank `pct / 100 * (len - 1)` is
/// computed, and the result interpolates between the two samples that
/// bracket it. A single-element slice always yields that element.
///
/// # Panics
///
/// Panics if `sorted_samples` is empty, or if it has more than one element
/// and `pct` is outside `[0, 100]` (including NaN).
fn percentile_of_sorted(sorted_samples: &[f64], pct: f64) -> f64 {
    assert!(!sorted_samples.is_empty());
    let last_idx = sorted_samples.len() - 1;
    if last_idx == 0 {
        return sorted_samples[0];
    }

    let zero = 0.0_f64;
    let hundred = 100_f64;
    assert!(zero <= pct);
    assert!(pct <= hundred);
    if pct == hundred {
        return sorted_samples[last_idx];
    }

    let rank = (pct / hundred) * last_idx as f64;
    let whole = rank.floor();
    let frac = rank - whole;
    let below_idx = whole as usize;
    let (below, above) = (sorted_samples[below_idx], sorted_samples[below_idx + 1]);
    below + (above - below) * frac
}
583
#[cfg(test)]
mod test {
    use augurs_testing::{assert_all_close, assert_approx_eq};
    use itertools::Itertools;

    use crate::{
        optimizer::{mock_optimizer::MockOptimizer, OptimizedParams},
        testdata::{daily_univariate_ts, train_test_splitn},
        IncludeHistory, Prophet, ProphetOptions,
    };

    #[test]
    fn piecewise_linear() {
        let t = (0..11).map(f64::from).collect_vec();
        let m = 0.0;
        let k = 1.0;
        let deltas = vec![0.5];
        let changepoints_t = vec![5.0];
        let y = Prophet::<()>::piecewise_linear(&t, &deltas, k, m, &changepoints_t).collect_vec();
        let y_true = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.5, 8.0, 9.5, 11.0, 12.5];
        assert_eq!(y, y_true);

        // Evaluating on a suffix of `t` must give the matching suffix.
        let y =
            Prophet::<()>::piecewise_linear(&t[8..], &deltas, k, m, &changepoints_t).collect_vec();
        assert_eq!(y, y_true[8..]);

        // Multiple changepoints.
        let deltas = vec![0.4, 0.5];
        let changepoints_t = vec![4.0, 8.0];
        let y = Prophet::<()>::piecewise_linear(&t, &deltas, k, m, &changepoints_t).collect_vec();
        let y_true = &[0.0, 1.0, 2.0, 3.0, 4.0, 5.4, 6.8, 8.2, 9.6, 11.5, 13.4];
        // `zip` stops at the shorter input, so check the lengths explicitly.
        assert_eq!(y.len(), y_true.len());
        for (a, b) in y.iter().zip(y_true) {
            assert_approx_eq!(a, b);
        }
    }

    #[test]
    fn piecewise_logistic() {
        let t = (0..11).map(f64::from).collect_vec();
        let cap = vec![10.0; 11];
        let m = 0.0;
        let k = 1.0;
        let deltas = vec![0.5];
        let changepoints_t = vec![5.0];
        let y = Prophet::<()>::piecewise_logistic(&t, &cap, &deltas, k, m, &changepoints_t)
            .collect_vec();
        let y_true = &[
            5.000000, 7.310586, 8.807971, 9.525741, 9.820138, 9.933071, 9.984988, 9.996646,
            9.999252, 9.999833, 9.999963,
        ];
        assert_eq!(y.len(), y_true.len());
        for (a, b) in y.iter().zip(y_true) {
            assert_approx_eq!(a, b);
        }

        // Evaluating on a suffix of `t` must give the matching suffix.
        let y =
            Prophet::<()>::piecewise_logistic(&t[8..], &cap[8..], &deltas, k, m, &changepoints_t)
                .collect_vec();
        assert_eq!(y.len(), y_true[8..].len());
        for (a, b) in y.iter().zip(&y_true[8..]) {
            assert_approx_eq!(a, b);
        }

        // Multiple changepoints.
        let deltas = vec![0.4, 0.5];
        let changepoints_t = vec![4.0, 8.0];
        let y = Prophet::<()>::piecewise_logistic(&t, &cap, &deltas, k, m, &changepoints_t)
            .collect_vec();
        let y_true = &[
            5., 7.31058579, 8.80797078, 9.52574127, 9.8201379, 9.95503727, 9.98887464, 9.99725422,
            9.99932276, 9.9998987, 9.99998485,
        ];
        assert_eq!(y.len(), y_true.len());
        for (a, b) in y.iter().zip(y_true) {
            assert_approx_eq!(a, b);
        }
    }

    #[test]
    fn flat_trend() {
        let t = (0..11).map(f64::from).collect_vec();
        let m = 0.5;
        let y = Prophet::<()>::flat_trend(&t, m).collect_vec();
        assert_all_close(&y, &[0.5; 11]);

        let y = Prophet::<()>::flat_trend(&t[8..], m).collect_vec();
        assert_all_close(&y, &[0.5; 3]);
    }

    /// End-to-end prediction with `AbsMax` scaling using fixed optimized
    /// parameters: checks the forecast RMSE and that the uncertainty
    /// intervals bracket the point forecast.
    #[test]
    fn predict_absmax() {
        let test_days = 30;
        let (train, test) = train_test_splitn(daily_univariate_ts(), test_days);
        let opts = ProphetOptions {
            scaling: crate::Scaling::AbsMax,
            ..Default::default()
        };
        let opt = MockOptimizer::new();
        let mut prophet = Prophet::new(opts, opt);
        prophet.fit(train.clone(), Default::default()).unwrap();

        prophet.optimized = Some(OptimizedParams {
            k: -1.01136,
            m: 0.460947,
            sigma_obs: 0.0451108.try_into().unwrap(),
            beta: vec![
                0.0205064,
                -0.0129451,
                -0.0164735,
                -0.00275837,
                0.00333371,
                0.00599414,
            ],
            delta: vec![
                3.51708e-08,
                1.17925e-09,
                -2.91421e-09,
                2.06189e-01,
                9.06870e-01,
                4.49113e-01,
                1.94664e-03,
                -1.16088e-09,
                -5.75394e-08,
                -7.90284e-06,
                -6.74530e-01,
                -5.70814e-02,
                -4.91360e-08,
                -3.53111e-09,
                1.42645e-08,
                4.50809e-05,
                8.86286e-01,
                1.14535e+00,
                4.40539e-02,
                8.17306e-09,
                -1.57715e-07,
                -5.15430e-01,
                -3.15001e-01,
                1.14429e-08,
                -2.56863e-09,
            ],
            trend: vec![
                0.460947, 0.4566, 0.455151, 0.453703, 0.452254, 0.450805, 0.445009, 0.44356,
                0.442111, 0.440662, 0.436315, 0.434866, 0.433417, 0.431968, 0.430519, 0.426173,
                0.424724, 0.423275, 0.421826, 0.420377, 0.41603, 0.414581, 0.413132, 0.411683,
                0.410234, 0.405887, 0.404438, 0.402989, 0.40154, 0.400092, 0.395745, 0.394296,
                0.391398, 0.389949, 0.385602, 0.384153, 0.382704, 0.381255, 0.379806, 0.375459,
                0.374011, 0.372562, 0.371113, 0.369664, 0.365317, 0.363868, 0.362419, 0.36097,
                0.359521, 0.355174, 0.353725, 0.352276, 0.350827, 0.349378, 0.345032, 0.343583,
                0.342134, 0.340685, 0.339236, 0.334889, 0.33344, 0.331991, 0.330838, 0.329684,
                0.326223, 0.32507, 0.323916, 0.322763, 0.321609, 0.318149, 0.316995, 0.315841,
                0.314688, 0.313534, 0.30892, 0.307767, 0.306613, 0.30546, 0.305897, 0.306042,
                0.306188, 0.306334, 0.306479, 0.306916, 0.307062, 0.307208, 0.307354, 0.307499,
                0.307936, 0.308082, 0.308228, 0.308373, 0.308519, 0.310886, 0.311676, 0.312465,
                0.313254, 0.314043, 0.31641, 0.317199, 0.317989, 0.318778, 0.319567, 0.321934,
                0.322723, 0.323512, 0.324302, 0.325091, 0.327466, 0.328258, 0.32905, 0.329842,
                0.330634, 0.334594, 0.335386, 0.336177, 0.338553, 0.339345, 0.340137, 0.340929,
                0.341721, 0.344097, 0.344888, 0.34568, 0.346472, 0.347264, 0.34964, 0.350432,
                0.351224, 0.352808, 0.355183, 0.355975, 0.356767, 0.357559, 0.358351, 0.360727,
                0.361519, 0.362311, 0.363102, 0.363894, 0.36627, 0.367062, 0.367854, 0.368646,
                0.369438, 0.371813, 0.372605, 0.373397, 0.374189, 0.374981, 0.377357, 0.378941,
                0.379733, 0.380524, 0.3829, 0.384484, 0.385276, 0.386068, 0.388443, 0.389235,
                0.390027, 0.390819, 0.391611, 0.393987, 0.394779, 0.395571, 0.396362, 0.397154,
                0.400322, 0.401114, 0.400939, 0.400765, 0.400242, 0.400067, 0.399893, 0.399718,
                0.399544, 0.39902, 0.398846, 0.398671, 0.398497, 0.398322, 0.397799, 0.397624,
                0.39745, 0.397194, 0.396937, 0.395912, 0.395656, 0.3954, 0.395144, 0.394375,
                0.394119, 0.393862, 0.393606, 0.39335, 0.392581, 0.392325, 0.392069, 0.391812,
                0.391556, 0.390787, 0.390531, 0.390275, 0.390019, 0.389762, 0.388994, 0.388737,
                0.388481, 0.388225, 0.387968, 0.3872, 0.386943, 0.386687, 0.386431, 0.385406,
                0.38515, 0.384893, 0.384637, 0.384381, 0.383612, 0.383356, 0.3831, 0.382843,
                0.382587, 0.381818, 0.381562, 0.381306, 0.38105, 0.380793, 0.380025, 0.379768,
                0.379512, 0.379256, 0.379, 0.378231, 0.377975, 0.377718, 0.377462, 0.377206,
                0.376437, 0.376181, 0.375925, 0.375668, 0.375412, 0.374643, 0.374387, 0.374131,
                0.373875, 0.373619, 0.37285, 0.372594, 0.372338, 0.372081, 0.371825, 0.3708,
                0.370544, 0.370288, 0.370032, 0.369263, 0.369007, 0.370021, 0.371034, 0.372048,
                0.375088, 0.376102, 0.377116, 0.378129, 0.379143, 0.382183, 0.383197, 0.384211,
                0.385224, 0.386238, 0.389278, 0.390292, 0.391305, 0.39396, 0.396614, 0.404578,
                0.407232, 0.409887, 0.415196, 0.423159, 0.425813, 0.428468, 0.431122, 0.433777,
                0.44174, 0.444395, 0.447049, 0.449704, 0.452421, 0.460574, 0.463291, 0.466009,
                0.468727, 0.471444, 0.479597, 0.482314, 0.485032, 0.48775, 0.490467, 0.49862,
                0.501337, 0.504055, 0.506773, 0.50949, 0.517643, 0.520361, 0.523078, 0.525796,
                0.528513, 0.536666, 0.539384, 0.542101, 0.544819, 0.547536, 0.555689, 0.558407,
                0.561124, 0.563842, 0.566559, 0.57743, 0.580147, 0.582865, 0.585582, 0.593735,
                0.596453, 0.59917, 0.601888, 0.604605, 0.612758, 0.615476, 0.618193, 0.620911,
                0.623628, 0.631781, 0.63376, 0.635739, 0.637719, 0.639698, 0.645635, 0.647614,
                0.649593, 0.651572, 0.653552, 0.659489, 0.661468, 0.663447, 0.665426, 0.667406,
                0.673343, 0.674871, 0.676399, 0.677926, 0.679454, 0.684038, 0.685566, 0.687094,
                0.688621, 0.690149, 0.694733, 0.696261, 0.697788, 0.699316, 0.700844, 0.705428,
                0.706956, 0.708483, 0.710011, 0.711539, 0.716123, 0.71765, 0.719178, 0.720706,
                0.722234, 0.726818, 0.728345, 0.729873, 0.731401, 0.732929, 0.737512, 0.73904,
                0.740568, 0.743624, 0.748207, 0.749735, 0.751263, 0.752791, 0.754319, 0.758902,
                0.76043, 0.761958, 0.763486, 0.765014, 0.769597, 0.771125, 0.772653, 0.774181,
                0.775709, 0.780292, 0.78182, 0.784876, 0.786404, 0.790987, 0.792515, 0.795571,
                0.797098, 0.801682, 0.80321, 0.804738, 0.806265, 0.807793, 0.812377, 0.813905,
                0.815433, 0.81696, 0.818488, 0.8246, 0.826127, 0.827655, 0.829183, 0.833767,
                0.835295, 0.836822, 0.83835, 0.839878, 0.844462, 0.845989, 0.847517, 0.849045,
                0.850573, 0.855157, 0.856684, 0.858212, 0.85974, 0.861268, 0.867379, 0.868907,
                0.870435, 0.871963, 0.876546, 0.878074, 0.879602, 0.88113, 0.882658, 0.887241,
                0.888769, 0.890297, 0.891825, 0.893353, 0.897936, 0.899464, 0.900992, 0.90252,
                0.904048, 0.908631, 0.910159, 0.911687, 0.913215, 0.914743, 0.919326, 0.920854,
                0.922382, 0.92391, 0.925437, 0.930021, 0.931549, 0.933077, 0.934604, 0.936132,
                0.940716, 0.942244, 0.943772, 0.945299, 0.946827, 0.951411, 0.952939, 0.954466,
            ],
        });
        let future = prophet
            .make_future_dataframe((test_days as u32).try_into().unwrap(), IncludeHistory::No)
            .unwrap();
        let predictions = prophet.predict(future).unwrap();
        assert_eq!(predictions.yhat.point.len(), test_days);
        let rmse = (predictions
            .yhat
            .point
            .iter()
            .zip(&test.y)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            / test.y.len() as f64)
            .sqrt();
        assert_approx_eq!(rmse, 10.64, 1e-1);

        let lower = predictions.yhat.lower.as_ref().unwrap();
        let upper = predictions.yhat.upper.as_ref().unwrap();
        // Both bounds must cover every predicted point (`zip` would otherwise
        // silently skip missing entries).
        assert_eq!(lower.len(), predictions.yhat.point.len());
        assert_eq!(upper.len(), predictions.yhat.point.len());
        for (lower_bound, point_estimate) in lower.iter().zip(&predictions.yhat.point) {
            assert!(
                lower_bound <= point_estimate,
                "Lower bound should be less than or equal to the point estimate"
            );
        }
        for (upper_bound, point_estimate) in upper.iter().zip(&predictions.yhat.point) {
            assert!(
                upper_bound >= point_estimate,
                "Upper bound should be greater than or equal to the point estimate"
            );
        }
    }
}