// augurs_prophet/prophet/predict.rs
1use std::collections::HashMap;
2
3use augurs_core::FloatIterExt;
4use itertools::{izip, Itertools};
5use rand::{distributions::Uniform, thread_rng, Rng};
6use statrs::distribution::{Laplace, Normal, Poisson};
7
8use crate::{optimizer::OptimizedParams, Error, GrowthType, Prophet, TimestampSeconds};
9
10use super::prep::{ComponentName, Features, FeaturesFrame, Modes, ProcessedData};
11
/// The prediction for a feature.
///
/// 'Feature' could refer to the forecasts themselves (`yhat`)
/// or any of the other component features which contribute to
/// the final estimate, such as trend, seasonality, seasonalities,
/// regressors or holidays.
///
/// Each vector holds one value per predicted timestamp, in the same
/// order as the `ds` field of [`Predictions`].
#[derive(Debug, Default, Clone)]
pub struct FeaturePrediction {
    /// The point estimate for this feature.
    pub point: Vec<f64>,
    /// The lower estimate for this feature.
    ///
    /// Only present if `uncertainty_samples` was greater than zero
    /// when the model was created.
    pub lower: Option<Vec<f64>>,
    /// The upper estimate for this feature.
    ///
    /// Only present if `uncertainty_samples` was greater than zero
    /// when the model was created.
    pub upper: Option<Vec<f64>>,
}
33
/// Intermediate container for per-component predictions, as produced by
/// [`Prophet::predict_features`].
#[derive(Debug, Default)]
pub(super) struct FeaturePredictions {
    /// Contribution of the additive terms in the model.
    ///
    /// This includes additive seasonalities, holidays and regressors.
    pub(super) additive: FeaturePrediction,
    /// Contribution of the multiplicative terms in the model.
    ///
    /// This includes multiplicative seasonalities, holidays and regressors.
    pub(super) multiplicative: FeaturePrediction,
    /// Mapping from holiday name to the contribution of that holiday.
    pub(super) holidays: HashMap<String, FeaturePrediction>,
    /// Mapping from regressor name to the contribution of that regressor.
    pub(super) regressors: HashMap<String, FeaturePrediction>,
    /// Mapping from seasonality name to the contribution of that seasonality.
    pub(super) seasonalities: HashMap<String, FeaturePrediction>,
}
51
/// Predictions from a Prophet model.
///
/// The `yhat` field contains the forecasts for the input time series.
/// All other fields contain individual components of the model which
/// contribute towards the final `yhat` estimate.
///
/// Certain fields (such as `cap` and `floor`) may be `None` if the
/// model did not use them (e.g. the model was not configured to use
/// logistic trend).
#[derive(Debug, Clone)]
pub struct Predictions {
    /// The timestamps of the forecasts.
    pub ds: Vec<TimestampSeconds>,

    /// Forecasts of the input time series `y`.
    pub yhat: FeaturePrediction,

    /// The trend contribution at each time point.
    pub trend: FeaturePrediction,

    /// The cap for the logistic growth.
    ///
    /// Will only be `Some` if the model used [`GrowthType::Logistic`](crate::GrowthType::Logistic).
    pub cap: Option<Vec<f64>>,
    /// The floor for the logistic growth.
    ///
    /// Will only be `Some` if the model used [`GrowthType::Logistic`](crate::GrowthType::Logistic)
    /// and the floor was provided in the input data.
    pub floor: Option<Vec<f64>>,

    /// The combined contribution of all _additive_ components.
    ///
    /// This includes seasonalities, holidays and regressors if their mode
    /// was configured to be [`FeatureMode::Additive`](crate::FeatureMode::Additive).
    pub additive: FeaturePrediction,

    /// The combined contribution of all _multiplicative_ components.
    ///
    /// This includes seasonalities, holidays and regressors if their mode
    /// was configured to be [`FeatureMode::Multiplicative`](crate::FeatureMode::Multiplicative).
    pub multiplicative: FeaturePrediction,

    /// Mapping from holiday name to that holiday's contribution.
    pub holidays: HashMap<String, FeaturePrediction>,

    /// Mapping from seasonality name to that seasonality's contribution.
    pub seasonalities: HashMap<String, FeaturePrediction>,

    /// Mapping from regressor name to that regressor's contribution.
    pub regressors: HashMap<String, FeaturePrediction>,
}
103
/// Whether to include the historical dates in the future dataframe for predictions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IncludeHistory {
    /// Include the historical dates in the future dataframe.
    Yes,
    /// Do not include the historical dates in the future dataframe.
    No,
}
112
/// Samples drawn from the posterior predictive distribution.
///
/// Both fields are indexed by timestamp first, then by sample: the outer
/// vector has one entry per predicted timestamp, and each inner vector
/// holds every simulated value for that timestamp.
#[derive(Debug)]
pub(super) struct PosteriorPredictiveSamples {
    /// Simulated `yhat` values.
    pub(super) yhat: Vec<Vec<f64>>,
    /// Simulated trend values.
    pub(super) trend: Vec<Vec<f64>>,
}
118
119impl<O> Prophet<O> {
120    /// Predict trend.
121    pub(super) fn predict_trend(
122        &self,
123        t: &[f64],
124        cap: &Option<Vec<f64>>,
125        floor: &[f64],
126        changepoints_t: &[f64],
127        params: &OptimizedParams,
128        y_scale: f64,
129    ) -> Result<FeaturePrediction, Error> {
130        let point = match (self.opts.growth, cap) {
131            (GrowthType::Linear, _) => {
132                Self::piecewise_linear(t, &params.delta, params.k, params.m, changepoints_t)
133                    .zip(floor)
134                    .map(|(trend, flr)| trend * y_scale + flr)
135                    .collect_vec()
136            }
137            (GrowthType::Logistic, Some(cap)) => {
138                Self::piecewise_logistic(t, cap, &params.delta, params.k, params.m, changepoints_t)
139                    .zip(floor)
140                    .map(|(trend, flr)| trend * y_scale + flr)
141                    .collect_vec()
142            }
143            (GrowthType::Logistic, None) => return Err(Error::MissingCap),
144            (GrowthType::Flat, _) => Self::flat_trend(t, params.m)
145                .zip(floor)
146                .map(|(trend, flr)| trend * y_scale + flr)
147                .collect_vec(),
148        };
149        Ok(FeaturePrediction {
150            point,
151            lower: None,
152            upper: None,
153        })
154    }
155
156    fn piecewise_linear<'a>(
157        t: &'a [f64],
158        deltas: &'a [f64],
159        k: f64,
160        m: f64,
161        changepoints_t: &'a [f64],
162    ) -> impl Iterator<Item = f64> + 'a {
163        // `deltas_t` is a contiguous array with the changepoint delta to apply
164        // delta at each time point; it has a stride of `changepoints_t.len()`,
165        // since it's a 2D array in the numpy version.
166        let cp_zipped = deltas.iter().zip(changepoints_t);
167        let deltas_t = cp_zipped
168            .cartesian_product(t)
169            .map(|((delta, cp_t), t)| if cp_t <= t { *delta } else { 0.0 });
170
171        // Repeat each changepoint effect `n` times so we can zip it up.
172        let changepoints_repeated = changepoints_t
173            .iter()
174            .flat_map(|x| std::iter::repeat_n(*x, t.len()));
175        let indexes = (0..t.len()).cycle();
176        // `k_m_t` is a contiguous array where each element contains the rate and offset to
177        // apply at each time point.
178        let k_m_t = izip!(deltas_t, changepoints_repeated, indexes).fold(
179            vec![(k, m); t.len()],
180            |mut acc, (delta, cp_t, idx)| {
181                // Add the changepoint rate to the initial rate.
182                acc[idx].0 += delta;
183                // Add the changepoint offset to the initial offset where applicable.
184                acc[idx].1 += -cp_t * delta;
185                acc
186            },
187        );
188
189        izip!(t, k_m_t).map(|(t, (k, m))| t * k + m)
190    }
191
192    fn piecewise_logistic<'a>(
193        t: &'a [f64],
194        cap: &'a [f64],
195        deltas: &'a [f64],
196        k: f64,
197        m: f64,
198        changepoints_t: &'a [f64],
199    ) -> impl Iterator<Item = f64> + 'a {
200        // Compute offset changes.
201        let k_cum = std::iter::once(k)
202            .chain(deltas.iter().scan(k, |state, delta| {
203                *state += delta;
204                Some(*state)
205            }))
206            .collect_vec();
207        let mut gammas = vec![0.0; changepoints_t.len()];
208        let mut gammas_sum = 0.0;
209        for (i, t_s) in changepoints_t.iter().enumerate() {
210            gammas[i] = (t_s - m - gammas_sum) * (1.0 - k_cum[i] / k_cum[i + 1]);
211            gammas_sum += gammas[i];
212        }
213
214        // Get cumulative rate and offset at each time point.
215        let mut k_t = vec![k; t.len()];
216        let mut m_t = vec![m; t.len()];
217        for (s, t_s) in changepoints_t.iter().enumerate() {
218            for (i, t_i) in t.iter().enumerate() {
219                if t_i >= t_s {
220                    k_t[i] += deltas[s];
221                    m_t[i] += gammas[s];
222                }
223            }
224        }
225
226        izip!(cap, t, k_t, m_t).map(|(cap, t, k, m)| cap / (1.0 + (-k * (t - m)).exp()))
227    }
228
229    /// Evaluate the flat trend function.
230    fn flat_trend(t: &[f64], m: f64) -> impl Iterator<Item = f64> {
231        std::iter::repeat_n(m, t.len())
232    }
233
234    /// Predict seasonality, holidays and added regressors.
235    pub(super) fn predict_features(
236        &self,
237        features: &Features,
238        params: &OptimizedParams,
239        y_scale: f64,
240    ) -> Result<FeaturePredictions, Error> {
241        let Features {
242            features,
243            component_columns,
244            modes,
245            ..
246        } = features;
247        let predict_feature = |col, f: fn(String) -> ComponentName| {
248            Self::predict_components(col, &features.data, &params.beta, y_scale, modes, f)
249        };
250        Ok(FeaturePredictions {
251            additive: Self::predict_feature(
252                &component_columns.additive,
253                &features.data,
254                &params.beta,
255                y_scale,
256                true,
257            ),
258            multiplicative: Self::predict_feature(
259                &component_columns.multiplicative,
260                &features.data,
261                &params.beta,
262                y_scale,
263                false,
264            ),
265            holidays: predict_feature(&component_columns.holidays, ComponentName::Holiday),
266            seasonalities: predict_feature(
267                &component_columns.seasonalities,
268                ComponentName::Seasonality,
269            ),
270            regressors: predict_feature(&component_columns.regressors, ComponentName::Regressor),
271        })
272    }
273
274    fn predict_components(
275        component_columns: &HashMap<String, Vec<i32>>,
276        #[allow(non_snake_case)] X: &[Vec<f64>],
277        beta: &[f64],
278        y_scale: f64,
279        modes: &Modes,
280        make_mode: impl Fn(String) -> ComponentName,
281    ) -> HashMap<String, FeaturePrediction> {
282        component_columns
283            .iter()
284            .map(|(name, component_col)| {
285                (
286                    name.clone(),
287                    Self::predict_feature(
288                        component_col,
289                        X,
290                        beta,
291                        y_scale,
292                        modes.additive.contains(&make_mode(name.clone())),
293                    ),
294                )
295            })
296            .collect()
297    }
298
299    pub(super) fn predict_feature(
300        component_col: &[i32],
301        #[allow(non_snake_case)] X: &[Vec<f64>],
302        beta: &[f64],
303        y_scale: f64,
304        is_additive: bool,
305    ) -> FeaturePrediction {
306        let beta_c = component_col
307            .iter()
308            .copied()
309            .zip(beta)
310            .map(|(x, b)| x as f64 * b)
311            .collect_vec();
312        // Matrix multiply `beta_c` and `x`.
313        let mut point = vec![0.0; X[0].len()];
314        for (feature, b) in izip!(X, beta_c) {
315            for (p, x) in izip!(point.iter_mut(), feature) {
316                *p += b * x;
317            }
318        }
319        if is_additive {
320            point.iter_mut().for_each(|x| *x *= y_scale);
321        }
322        FeaturePrediction {
323            point,
324            lower: None,
325            upper: None,
326        }
327    }
328
329    #[allow(clippy::too_many_arguments)]
330    pub(super) fn predict_uncertainty(
331        &self,
332        df: &ProcessedData,
333        features: &Features,
334        params: &OptimizedParams,
335        changepoints_t: &[f64],
336        yhat: &mut FeaturePrediction,
337        trend: &mut FeaturePrediction,
338        y_scale: f64,
339    ) -> Result<(), Error> {
340        let mut sim_values =
341            self.sample_posterior_predictive(df, features, params, changepoints_t, y_scale)?;
342        let lower_p = 100.0 * (1.0 - *self.opts.interval_width) / 2.0;
343        let upper_p = 100.0 * (1.0 + *self.opts.interval_width) / 2.0;
344
345        let mut yhat_lower = Vec::with_capacity(df.ds.len());
346        let mut yhat_upper = Vec::with_capacity(df.ds.len());
347        let mut trend_lower = Vec::with_capacity(df.ds.len());
348        let mut trend_upper = Vec::with_capacity(df.ds.len());
349
350        for (yhat_samples, trend_samples) in
351            sim_values.yhat.iter_mut().zip(sim_values.trend.iter_mut())
352        {
353            // Sort, since we need to find multiple percentiles.
354            yhat_samples
355                .sort_unstable_by(|a, b| a.partial_cmp(b).expect("found NaN in yhat sample"));
356            trend_samples
357                .sort_unstable_by(|a, b| a.partial_cmp(b).expect("found NaN in yhat sample"));
358            yhat_lower.push(percentile_of_sorted(yhat_samples, lower_p));
359            yhat_upper.push(percentile_of_sorted(yhat_samples, upper_p));
360            trend_lower.push(percentile_of_sorted(trend_samples, lower_p));
361            trend_upper.push(percentile_of_sorted(trend_samples, upper_p));
362        }
363        yhat.lower = Some(yhat_lower);
364        yhat.upper = Some(yhat_upper);
365        trend.lower = Some(trend_lower);
366        trend.upper = Some(trend_upper);
367        Ok(())
368    }
369
    /// Sample posterior predictive values from the model.
    ///
    /// Draws roughly `uncertainty_samples` simulated values of `yhat` and
    /// `trend` for every timestamp in `df`; the caller uses these to compute
    /// uncertainty intervals.
    pub(super) fn sample_posterior_predictive(
        &self,
        df: &ProcessedData,
        features: &Features,
        params: &OptimizedParams,
        changepoints_t: &[f64],
        y_scale: f64,
    ) -> Result<PosteriorPredictiveSamples, Error> {
        // TODO: handle multiple chains.
        let n_iterations = 1;
        // Spread the requested samples across iterations, always taking at
        // least one sample per iteration.
        let samples_per_iter = usize::max(
            1,
            (self.opts.uncertainty_samples as f64 / n_iterations as f64).ceil() as usize,
        );
        let Features {
            features,
            component_columns,
            ..
        } = features;
        // We're going to generate `samples_per_iter * n_iterations` samples
        // for each of the `n` timestamps we want to predict.
        // We'll store these in a nested `Vec<Vec<f64>>`, where the outer
        // vector is indexed by the timestamps and the inner vector is
        // indexed by the samples, since we need to calculate the `p` percentile
        // of the samples for each timestamp.
        let n_timestamps = df.ds.len();
        let n_samples = samples_per_iter * n_iterations;
        let mut sim_values = PosteriorPredictiveSamples {
            yhat: std::iter::repeat_with(|| Vec::with_capacity(n_samples))
                .take(n_timestamps)
                .collect_vec(),
            trend: std::iter::repeat_with(|| Vec::with_capacity(n_samples))
                .take(n_timestamps)
                .collect_vec(),
        };
        // Use temporary buffers to avoid allocating a new Vec for each
        // call to `sample_model`.
        let (mut yhat, mut trend) = (
            Vec::with_capacity(n_timestamps),
            Vec::with_capacity(n_timestamps),
        );
        for i in 0..n_iterations {
            for _ in 0..samples_per_iter {
                // `sample_model` fills `yhat` and `trend` with one sample
                // per timestamp.
                self.sample_model(
                    df,
                    features,
                    params,
                    changepoints_t,
                    &component_columns.additive,
                    &component_columns.multiplicative,
                    y_scale,
                    i,
                    &mut yhat,
                    &mut trend,
                )?;
                // We have to transpose things, unfortunately: the buffers
                // are indexed by timestamp, while `sim_values` groups all
                // samples for a given timestamp together.
                for ((i, yhat), trend) in yhat.iter().enumerate().zip(&trend) {
                    sim_values.yhat[i].push(*yhat);
                    sim_values.trend[i].push(*trend);
                }
            }
        }
        debug_assert_eq!(sim_values.yhat.len(), n_timestamps);
        debug_assert_eq!(sim_values.trend.len(), n_timestamps);
        Ok(sim_values)
    }
437
438    /// Simulate observations from the extrapolated model.
439    #[allow(clippy::too_many_arguments)]
440    fn sample_model(
441        &self,
442        df: &ProcessedData,
443        features: &FeaturesFrame,
444        params: &OptimizedParams,
445        changepoints_t: &[f64],
446        additive: &[i32],
447        multiplicative: &[i32],
448        y_scale: f64,
449        iteration: usize,
450        yhat_tmp: &mut Vec<f64>,
451        trend_tmp: &mut Vec<f64>,
452    ) -> Result<(), Error> {
453        yhat_tmp.clear();
454        trend_tmp.clear();
455        let n = df.ds.len();
456        *trend_tmp =
457            self.sample_predictive_trend(df, params, changepoints_t, y_scale, iteration)?;
458        let beta = &params.beta;
459        let mut xb_a = vec![0.0; n];
460        for (feature, b, a) in izip!(&features.data, beta, additive) {
461            for (p, x) in izip!(&mut xb_a, feature) {
462                *p += x * b * *a as f64;
463            }
464        }
465        xb_a.iter_mut().for_each(|x| *x *= y_scale);
466        let mut xb_m = vec![0.0; n];
467        for (feature, b, m) in izip!(&features.data, beta, multiplicative) {
468            for (p, x) in izip!(&mut xb_m, feature) {
469                *p += x * b * *m as f64;
470            }
471        }
472
473        let sigma = params.sigma_obs;
474        let dist = Normal::new(0.0, *sigma).expect("sigma must be non-negative");
475        let mut rng = thread_rng();
476        let noise = (&mut rng).sample_iter(dist).take(n).map(|x| x * y_scale);
477
478        for yhat in izip!(trend_tmp, &xb_a, &xb_m, noise).map(|(t, a, m, n)| *t * (1.0 + m) + a + n)
479        {
480            yhat_tmp.push(yhat);
481        }
482
483        Ok(())
484    }
485
486    fn sample_predictive_trend(
487        &self,
488        df: &ProcessedData,
489        params: &OptimizedParams,
490        changepoints_t: &[f64],
491        y_scale: f64,
492        _iteration: usize, // This will be used when we implement MCMC predictions.
493    ) -> Result<Vec<f64>, Error> {
494        let deltas = &params.delta;
495
496        let t_max = df.t.iter().copied().nanmax(true);
497
498        let mut rng = thread_rng();
499
500        let n_changes = if t_max > 1.0 {
501            // Sample new changepoints from a Poisson process with rate n_cp on [1, T].
502            let n_cp = changepoints_t.len() as i32;
503            let lambda = n_cp as f64 * (t_max - 1.0);
504            // Lambda should always be positive, so this should never fail.
505            let dist = Poisson::new(lambda).expect("Valid Poisson distribution");
506            rng.sample::<f64, _>(dist).round() as usize
507        } else {
508            0
509        };
510        let changepoints_t_new = if n_changes > 0 {
511            let mut cp_t_new = (&mut rng)
512                .sample_iter(Uniform::new(0.0, t_max - 1.0))
513                .take(n_changes)
514                .map(|x| x + 1.0)
515                .collect_vec();
516            cp_t_new.sort_unstable_by(|a, b| {
517                a.partial_cmp(b)
518                    .expect("uniform distribution should not sample NaNs")
519            });
520            cp_t_new
521        } else {
522            vec![]
523        };
524
525        // Get the empirical scale of the deltas, plus epsilon to avoid NaNs.
526        let mut lambda = deltas.iter().map(|x| x.abs()).nanmean(false) + 1e-8;
527        if lambda.is_nan() {
528            lambda = 1e-8;
529        }
530        // Sample deltas from a Laplace distribution with location 0 and scale lambda.
531        // Lambda should always be positive and non-NaN, checked above.
532        let dist = Laplace::new(0.0, lambda).expect("Valid Laplace distribution");
533        let deltas_new = rng.sample_iter(dist).take(n_changes);
534
535        // Prepend the times and deltas from the history.
536        let all_changepoints_t = changepoints_t
537            .iter()
538            .copied()
539            .chain(changepoints_t_new)
540            .collect_vec();
541        let all_deltas = deltas.iter().copied().chain(deltas_new).collect_vec();
542
543        // Predict the trend.
544        let new_params = OptimizedParams {
545            delta: all_deltas,
546            ..params.clone()
547        };
548        let trend = self.predict_trend(
549            &df.t,
550            &df.cap_scaled,
551            &df.floor,
552            &all_changepoints_t,
553            &new_params,
554            y_scale,
555        )?;
556        Ok(trend.point)
557    }
558}
559
// Taken from the Rust compiler's test suite:
// https://github.com/rust-lang/rust/blob/917b0b6c70f078cb08bbb0080c9379e4487353c3/library/test/src/stats.rs#L258-L280.
//
// Computes the `pct` percentile of an ascending-sorted, non-empty slice
// using linear interpolation between the two nearest samples.
fn percentile_of_sorted(sorted_samples: &[f64], pct: f64) -> f64 {
    assert!(!sorted_samples.is_empty());
    // A single sample is every percentile at once.
    if let [single] = sorted_samples {
        return *single;
    }
    assert!((0.0..=100.0).contains(&pct));
    if pct == 100.0 {
        return sorted_samples[sorted_samples.len() - 1];
    }
    // Fractional rank of the requested percentile within the slice.
    let max_rank = (sorted_samples.len() - 1) as f64;
    let rank = (pct / 100.0) * max_rank;
    let below = rank.floor();
    let fraction = rank - below;
    let idx = below as usize;
    // Interpolate between the samples straddling the rank.
    let lo = sorted_samples[idx];
    let hi = sorted_samples[idx + 1];
    lo + (hi - lo) * fraction
}
583
584#[cfg(test)]
585mod test {
586    use augurs_testing::{assert_all_close, assert_approx_eq};
587    use itertools::Itertools;
588
589    use crate::{
590        optimizer::{mock_optimizer::MockOptimizer, OptimizedParams},
591        testdata::{daily_univariate_ts, train_test_splitn},
592        IncludeHistory, Prophet, ProphetOptions,
593    };
594
    // Values mirror the `piecewise_linear` test in the Python Prophet library.
    #[test]
    fn piecewise_linear() {
        // Single changepoint at t=5 adding 0.5 to the base rate of 1.
        let t = (0..11).map(f64::from).collect_vec();
        let m = 0.0;
        let k = 1.0;
        let deltas = vec![0.5];
        let changepoints_t = vec![5.0];
        let y = Prophet::<()>::piecewise_linear(&t, &deltas, k, m, &changepoints_t).collect_vec();
        let y_true = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.5, 8.0, 9.5, 11.0, 12.5];
        assert_eq!(y, y_true);

        // Evaluating a suffix of the time points yields the matching suffix.
        let y =
            Prophet::<()>::piecewise_linear(&t[8..], &deltas, k, m, &changepoints_t).collect_vec();
        assert_eq!(y, y_true[8..]);

        // This test isn't in the Python version but it's worth having one with multiple
        // changepoints.
        let deltas = vec![0.4, 0.5];
        let changepoints_t = vec![4.0, 8.0];
        let y = Prophet::<()>::piecewise_linear(&t, &deltas, k, m, &changepoints_t).collect_vec();
        let y_true = &[0.0, 1.0, 2.0, 3.0, 4.0, 5.4, 6.8, 8.2, 9.6, 11.5, 13.4];
        for (a, b) in y.iter().zip(y_true) {
            assert_approx_eq!(a, b);
        }
    }
620
    // Values mirror the `piecewise_logistic` test in the Python Prophet library.
    #[test]
    fn piecewise_logistic() {
        // Single changepoint at t=5, constant cap of 10.
        let t = (0..11).map(f64::from).collect_vec();
        let cap = vec![10.0; 11];
        let m = 0.0;
        let k = 1.0;
        let deltas = vec![0.5];
        let changepoints_t = vec![5.0];
        let y = Prophet::<()>::piecewise_logistic(&t, &cap, &deltas, k, m, &changepoints_t)
            .collect_vec();
        let y_true = &[
            5.000000, 7.310586, 8.807971, 9.525741, 9.820138, 9.933071, 9.984988, 9.996646,
            9.999252, 9.999833, 9.999963,
        ];
        for (a, b) in y.iter().zip(y_true) {
            assert_approx_eq!(a, b);
        }

        // Evaluating a suffix of the time points yields the matching suffix.
        let y =
            Prophet::<()>::piecewise_logistic(&t[8..], &cap[8..], &deltas, k, m, &changepoints_t)
                .collect_vec();
        for (a, b) in y.iter().zip(&y_true[8..]) {
            assert_approx_eq!(a, b);
        }

        // This test isn't in the Python version but it's worth having one with multiple
        // changepoints.
        let deltas = vec![0.4, 0.5];
        let changepoints_t = vec![4.0, 8.0];
        let y = Prophet::<()>::piecewise_logistic(&t, &cap, &deltas, k, m, &changepoints_t)
            .collect_vec();
        let y_true = &[
            5., 7.31058579, 8.80797078, 9.52574127, 9.8201379, 9.95503727, 9.98887464, 9.99725422,
            9.99932276, 9.9998987, 9.99998485,
        ];
        for (a, b) in y.iter().zip(y_true) {
            assert_approx_eq!(a, b);
        }
    }
660
    // A flat trend should be the constant offset `m` at every time point.
    #[test]
    fn flat_trend() {
        let t = (0..11).map(f64::from).collect_vec();
        let m = 0.5;
        let y = Prophet::<()>::flat_trend(&t, m).collect_vec();
        assert_all_close(&y, &[0.5; 11]);

        // Evaluating a suffix still yields the constant value.
        let y = Prophet::<()>::flat_trend(&t[8..], m).collect_vec();
        assert_all_close(&y, &[0.5; 3]);
    }
671
    /// This test is extracted from the `fit_predict` test of the Python Prophet
    /// library. Since we don't want to depend on an optimizer, this test patches the
    /// optimized parameters on the Prophet object and runs `predict`, ensuring we
    /// get sensible results.
    ///
    /// There is a similar test in `fit.rs` which ensures the data being sent to
    /// Stan is correct.
    #[test]
    fn predict_absmax() {
        // Hold out the final 30 days of the dataset for evaluation.
        let test_days = 30;
        let (train, test) = train_test_splitn(daily_univariate_ts(), test_days);
        let opts = ProphetOptions {
            scaling: crate::Scaling::AbsMax,
            ..Default::default()
        };
        let opt = MockOptimizer::new();
        let mut prophet = Prophet::new(opts, opt);
        prophet.fit(train.clone(), Default::default()).unwrap();

        // Override optimized params since we don't have a real optimizer.
        // These were obtained from the Python version.
        prophet.optimized = Some(OptimizedParams {
            k: -1.01136,
            m: 0.460947,
            sigma_obs: 0.0451108.try_into().unwrap(),
            beta: vec![
                0.0205064,
                -0.0129451,
                -0.0164735,
                -0.00275837,
                0.00333371,
                0.00599414,
            ],
            delta: vec![
                3.51708e-08,
                1.17925e-09,
                -2.91421e-09,
                2.06189e-01,
                9.06870e-01,
                4.49113e-01,
                1.94664e-03,
                -1.16088e-09,
                -5.75394e-08,
                -7.90284e-06,
                -6.74530e-01,
                -5.70814e-02,
                -4.91360e-08,
                -3.53111e-09,
                1.42645e-08,
                4.50809e-05,
                8.86286e-01,
                1.14535e+00,
                4.40539e-02,
                8.17306e-09,
                -1.57715e-07,
                -5.15430e-01,
                -3.15001e-01,
                1.14429e-08,
                -2.56863e-09,
            ],
            trend: vec![
                0.460947, 0.4566, 0.455151, 0.453703, 0.452254, 0.450805, 0.445009, 0.44356,
                0.442111, 0.440662, 0.436315, 0.434866, 0.433417, 0.431968, 0.430519, 0.426173,
                0.424724, 0.423275, 0.421826, 0.420377, 0.41603, 0.414581, 0.413132, 0.411683,
                0.410234, 0.405887, 0.404438, 0.402989, 0.40154, 0.400092, 0.395745, 0.394296,
                0.391398, 0.389949, 0.385602, 0.384153, 0.382704, 0.381255, 0.379806, 0.375459,
                0.374011, 0.372562, 0.371113, 0.369664, 0.365317, 0.363868, 0.362419, 0.36097,
                0.359521, 0.355174, 0.353725, 0.352276, 0.350827, 0.349378, 0.345032, 0.343583,
                0.342134, 0.340685, 0.339236, 0.334889, 0.33344, 0.331991, 0.330838, 0.329684,
                0.326223, 0.32507, 0.323916, 0.322763, 0.321609, 0.318149, 0.316995, 0.315841,
                0.314688, 0.313534, 0.30892, 0.307767, 0.306613, 0.30546, 0.305897, 0.306042,
                0.306188, 0.306334, 0.306479, 0.306916, 0.307062, 0.307208, 0.307354, 0.307499,
                0.307936, 0.308082, 0.308228, 0.308373, 0.308519, 0.310886, 0.311676, 0.312465,
                0.313254, 0.314043, 0.31641, 0.317199, 0.317989, 0.318778, 0.319567, 0.321934,
                0.322723, 0.323512, 0.324302, 0.325091, 0.327466, 0.328258, 0.32905, 0.329842,
                0.330634, 0.334594, 0.335386, 0.336177, 0.338553, 0.339345, 0.340137, 0.340929,
                0.341721, 0.344097, 0.344888, 0.34568, 0.346472, 0.347264, 0.34964, 0.350432,
                0.351224, 0.352808, 0.355183, 0.355975, 0.356767, 0.357559, 0.358351, 0.360727,
                0.361519, 0.362311, 0.363102, 0.363894, 0.36627, 0.367062, 0.367854, 0.368646,
                0.369438, 0.371813, 0.372605, 0.373397, 0.374189, 0.374981, 0.377357, 0.378941,
                0.379733, 0.380524, 0.3829, 0.384484, 0.385276, 0.386068, 0.388443, 0.389235,
                0.390027, 0.390819, 0.391611, 0.393987, 0.394779, 0.395571, 0.396362, 0.397154,
                0.400322, 0.401114, 0.400939, 0.400765, 0.400242, 0.400067, 0.399893, 0.399718,
                0.399544, 0.39902, 0.398846, 0.398671, 0.398497, 0.398322, 0.397799, 0.397624,
                0.39745, 0.397194, 0.396937, 0.395912, 0.395656, 0.3954, 0.395144, 0.394375,
                0.394119, 0.393862, 0.393606, 0.39335, 0.392581, 0.392325, 0.392069, 0.391812,
                0.391556, 0.390787, 0.390531, 0.390275, 0.390019, 0.389762, 0.388994, 0.388737,
                0.388481, 0.388225, 0.387968, 0.3872, 0.386943, 0.386687, 0.386431, 0.385406,
                0.38515, 0.384893, 0.384637, 0.384381, 0.383612, 0.383356, 0.3831, 0.382843,
                0.382587, 0.381818, 0.381562, 0.381306, 0.38105, 0.380793, 0.380025, 0.379768,
                0.379512, 0.379256, 0.379, 0.378231, 0.377975, 0.377718, 0.377462, 0.377206,
                0.376437, 0.376181, 0.375925, 0.375668, 0.375412, 0.374643, 0.374387, 0.374131,
                0.373875, 0.373619, 0.37285, 0.372594, 0.372338, 0.372081, 0.371825, 0.3708,
                0.370544, 0.370288, 0.370032, 0.369263, 0.369007, 0.370021, 0.371034, 0.372048,
                0.375088, 0.376102, 0.377116, 0.378129, 0.379143, 0.382183, 0.383197, 0.384211,
                0.385224, 0.386238, 0.389278, 0.390292, 0.391305, 0.39396, 0.396614, 0.404578,
                0.407232, 0.409887, 0.415196, 0.423159, 0.425813, 0.428468, 0.431122, 0.433777,
                0.44174, 0.444395, 0.447049, 0.449704, 0.452421, 0.460574, 0.463291, 0.466009,
                0.468727, 0.471444, 0.479597, 0.482314, 0.485032, 0.48775, 0.490467, 0.49862,
                0.501337, 0.504055, 0.506773, 0.50949, 0.517643, 0.520361, 0.523078, 0.525796,
                0.528513, 0.536666, 0.539384, 0.542101, 0.544819, 0.547536, 0.555689, 0.558407,
                0.561124, 0.563842, 0.566559, 0.57743, 0.580147, 0.582865, 0.585582, 0.593735,
                0.596453, 0.59917, 0.601888, 0.604605, 0.612758, 0.615476, 0.618193, 0.620911,
                0.623628, 0.631781, 0.63376, 0.635739, 0.637719, 0.639698, 0.645635, 0.647614,
                0.649593, 0.651572, 0.653552, 0.659489, 0.661468, 0.663447, 0.665426, 0.667406,
                0.673343, 0.674871, 0.676399, 0.677926, 0.679454, 0.684038, 0.685566, 0.687094,
                0.688621, 0.690149, 0.694733, 0.696261, 0.697788, 0.699316, 0.700844, 0.705428,
                0.706956, 0.708483, 0.710011, 0.711539, 0.716123, 0.71765, 0.719178, 0.720706,
                0.722234, 0.726818, 0.728345, 0.729873, 0.731401, 0.732929, 0.737512, 0.73904,
                0.740568, 0.743624, 0.748207, 0.749735, 0.751263, 0.752791, 0.754319, 0.758902,
                0.76043, 0.761958, 0.763486, 0.765014, 0.769597, 0.771125, 0.772653, 0.774181,
                0.775709, 0.780292, 0.78182, 0.784876, 0.786404, 0.790987, 0.792515, 0.795571,
                0.797098, 0.801682, 0.80321, 0.804738, 0.806265, 0.807793, 0.812377, 0.813905,
                0.815433, 0.81696, 0.818488, 0.8246, 0.826127, 0.827655, 0.829183, 0.833767,
                0.835295, 0.836822, 0.83835, 0.839878, 0.844462, 0.845989, 0.847517, 0.849045,
                0.850573, 0.855157, 0.856684, 0.858212, 0.85974, 0.861268, 0.867379, 0.868907,
                0.870435, 0.871963, 0.876546, 0.878074, 0.879602, 0.88113, 0.882658, 0.887241,
                0.888769, 0.890297, 0.891825, 0.893353, 0.897936, 0.899464, 0.900992, 0.90252,
                0.904048, 0.908631, 0.910159, 0.911687, 0.913215, 0.914743, 0.919326, 0.920854,
                0.922382, 0.92391, 0.925437, 0.930021, 0.931549, 0.933077, 0.934604, 0.936132,
                0.940716, 0.942244, 0.943772, 0.945299, 0.946827, 0.951411, 0.952939, 0.954466,
            ],
        });
        // Forecast only the held-out horizon (history excluded).
        let future = prophet
            .make_future_dataframe((test_days as u32).try_into().unwrap(), IncludeHistory::No)
            .unwrap();
        let predictions = prophet.predict(future).unwrap();
        assert_eq!(predictions.yhat.point.len(), test_days);
        // RMSE against the held-out data; the expected value was obtained
        // from the Python version.
        let rmse = (predictions
            .yhat
            .point
            .iter()
            .zip(&test.y)
            .map(|(a, b)| (a - b).powi(2))
            .sum::<f64>()
            / test.y.len() as f64)
            .sqrt();
        assert_approx_eq!(rmse, 10.64, 1e-1);

        // Sanity-check the uncertainty intervals: each bound should bracket
        // the corresponding point estimate.
        let lower = predictions.yhat.lower.as_ref().unwrap();
        let upper = predictions.yhat.upper.as_ref().unwrap();
        assert_eq!(lower.len(), predictions.yhat.point.len());
        for (lower_bound, point_estimate) in lower.iter().zip(&predictions.yhat.point) {
            assert!(
                lower_bound <= point_estimate,
                "Lower bound should be less than the point estimate"
            );
        }
        for (upper_bound, point_estimate) in upper.iter().zip(&predictions.yhat.point) {
            assert!(
                upper_bound >= point_estimate,
                "Upper bound should be greater than the point estimate"
            );
        }
    }
827}