Skip to main content

augurs_prophet/
prophet.rs

1pub(crate) mod options;
2pub(crate) mod predict;
3pub(crate) mod prep;
4
5use std::{
6    collections::{HashMap, HashSet},
7    num::NonZeroU32,
8    sync::Arc,
9};
10
11use itertools::{izip, Itertools};
12use options::ProphetOptions;
13use prep::{ComponentColumns, Modes, Preprocessed, Scales};
14
15use crate::{
16    forecaster::ProphetForecaster,
17    optimizer::{InitialParams, OptimizeOpts, OptimizedParams, Optimizer},
18    Error, EstimationMode, FeaturePrediction, IncludeHistory, IntervalWidth, PredictionData,
19    Predictions, Regressor, Seasonality, TimestampSeconds, TrainingData,
20};
21
/// The Prophet time series forecasting model.
///
/// Construct a model with [`Prophet::new`], optionally register custom
/// seasonalities and extra regressors, then call `fit` (available when `O`
/// implements `Optimizer`) followed by `predict`.
///
/// `O` is the optimizer used during fitting.
#[derive(Debug, Clone)]
pub struct Prophet<O> {
    /// Options used for fitting and for prediction (e.g. the number of
    /// uncertainty samples and the interval width).
    opts: ProphetOptions,

    /// Extra regressors, keyed by name (registered via `add_regressor`).
    regressors: HashMap<String, Regressor>,

    /// Custom seasonalities, keyed by name (registered via `add_seasonality`).
    seasonalities: HashMap<String, Seasonality>,

    // TODO: move all of the below into a separate struct.
    // That way we minimize the number of fields in this struct
    // and the number of permutations of optional fields,
    // so it's harder to accidentally get into an invalid state.
    /// Scaling factors for the data.
    ///
    /// This is calculated while preparing the training data and is reused by
    /// `predict` to set up new data and to produce predictions.
    scales: Option<Scales>,

    /// The changepoints for the model, as timestamps.
    changepoints: Option<Vec<TimestampSeconds>>,

    /// The changepoints expressed in the same units as the processed data's
    /// `t` column (the time axis passed to trend prediction).
    changepoints_t: Option<Vec<f64>>,

    /// The modes of the components.
    component_modes: Option<Modes>,

    /// The component columns used for training.
    train_component_columns: Option<ComponentColumns>,

    /// The names of the holidays that were seen in the training data.
    train_holiday_names: Option<HashSet<String>>,

    /// The optimizer used by `fit`.
    optimizer: O,

    /// The processed data used for fitting; set by `fit`.
    ///
    /// `predict` falls back to its history when no prediction data is given,
    /// and `make_future_dataframe` reads its history dates.
    processed: Option<Preprocessed>,

    /// The initial parameters passed to optimization; set by `fit`.
    init: Option<InitialParams>,

    /// The optimized parameters, if the model has been fit.
    ///
    /// `is_fitted` reports whether this is present.
    optimized: Option<OptimizedParams>,
}
71
72// All public methods should live in this `impl` even if they call
73// lots of functions in private modules, so that Rustdoc shows them
74// all in a single block.
75impl<O> Prophet<O> {
76    /// Create a new Prophet model with the given options and optimizer.
77    pub fn new(opts: ProphetOptions, optimizer: O) -> Self {
78        Self {
79            opts,
80            regressors: HashMap::new(),
81            seasonalities: HashMap::new(),
82            scales: None,
83            changepoints: None,
84            changepoints_t: None,
85            component_modes: None,
86            train_component_columns: None,
87            train_holiday_names: None,
88            optimizer,
89            processed: None,
90            init: None,
91            optimized: None,
92        }
93    }
94
95    /// Add a custom seasonality to the model.
96    pub fn add_seasonality(
97        &mut self,
98        name: String,
99        seasonality: Seasonality,
100    ) -> Result<&mut Self, Error> {
101        if self.seasonalities.contains_key(&name) {
102            return Err(Error::DuplicateSeasonality(name));
103        }
104        self.seasonalities.insert(name, seasonality);
105        Ok(self)
106    }
107
108    /// Add a regressor to the model.
109    pub fn add_regressor(&mut self, name: String, regressor: Regressor) -> &mut Self {
110        self.regressors.insert(name, regressor);
111        self
112    }
113
114    /// Return `true` if the model has been fit, or `false` if not.
115    pub fn is_fitted(&self) -> bool {
116        self.optimized.is_some()
117    }
118
119    /// Predict using the Prophet model.
120    ///
121    /// # Errors
122    ///
123    /// Returns an error if the model has not been fit.
124    pub fn predict(&self, data: impl Into<Option<PredictionData>>) -> Result<Predictions, Error> {
125        let Self {
126            processed: Some(processed),
127            optimized: Some(params),
128            changepoints_t: Some(changepoints_t),
129            scales: Some(scales),
130            ..
131        } = self
132        else {
133            return Err(Error::ModelNotFit);
134        };
135        let data = data.into();
136        let df = data
137            .map(|data| {
138                let training_data = TrainingData {
139                    n: data.n,
140                    ds: data.ds.clone(),
141                    y: vec![],
142                    cap: data.cap.clone(),
143                    floor: data.floor.clone(),
144                    seasonality_conditions: data.seasonality_conditions.clone(),
145                    x: data.x.clone(),
146                };
147                self.setup_dataframe(training_data, Some(scales.clone()))
148                    .map(|(df, _)| df)
149            })
150            .transpose()?
151            .unwrap_or_else(|| processed.history.clone());
152
153        let mut trend = self.predict_trend(
154            &df.t,
155            &df.cap_scaled,
156            &df.floor,
157            changepoints_t,
158            params,
159            scales.y_scale,
160        )?;
161        let features = self.make_all_features(&df)?;
162        let seasonal_components = self.predict_features(&features, params, scales.y_scale)?;
163
164        let yhat_point = izip!(
165            &trend.point,
166            &seasonal_components.additive.point,
167            &seasonal_components.multiplicative.point
168        )
169        .map(|(t, a, m)| t * (1.0 + m) + a)
170        .collect();
171        let mut yhat = FeaturePrediction {
172            point: yhat_point,
173            lower: None,
174            upper: None,
175        };
176
177        if self.opts.uncertainty_samples > 0 {
178            self.predict_uncertainty(
179                &df,
180                &features,
181                params,
182                changepoints_t,
183                &mut yhat,
184                &mut trend,
185                scales.y_scale,
186            )?;
187        }
188
189        Ok(Predictions {
190            ds: df.ds,
191            yhat,
192            trend,
193            cap: df.cap,
194            floor: scales.logistic_floor.then_some(df.floor),
195            additive: seasonal_components.additive,
196            multiplicative: seasonal_components.multiplicative,
197            holidays: seasonal_components.holidays,
198            seasonalities: seasonal_components.seasonalities,
199            regressors: seasonal_components.regressors,
200        })
201    }
202
203    /// Create dates to use for predictions.
204    ///
205    /// # Parameters
206    ///
207    /// - `horizon`: The number of days to predict forward.
208    /// - `include_history`: Whether to include the historical dates in the
209    ///   future dataframe.
210    ///
211    /// # Errors
212    ///
213    /// Returns an error if the model has not been fit.
214    pub fn make_future_dataframe(
215        &self,
216        horizon: NonZeroU32,
217        include_history: IncludeHistory,
218    ) -> Result<PredictionData, Error> {
219        let Some(Preprocessed { history_dates, .. }) = &self.processed else {
220            return Err(Error::ModelNotFit);
221        };
222        let freq = Self::infer_freq(history_dates)?;
223        let last_date = *history_dates.last().ok_or(Error::NotEnoughData)?;
224        let n = (horizon.get() as u64 + 1) as TimestampSeconds;
225        let dates = (last_date..last_date + n * freq)
226            .step_by(freq as usize)
227            .filter(|ds| *ds > last_date)
228            .take(horizon.get() as usize);
229
230        let ds = if include_history == IncludeHistory::Yes {
231            history_dates.iter().copied().chain(dates).collect()
232        } else {
233            dates.collect()
234        };
235        Ok(PredictionData::new(ds))
236    }
237
238    /// Get a reference to the Prophet options.
239    pub fn opts(&self) -> &ProphetOptions {
240        &self.opts
241    }
242
243    /// Get a mutable reference to the Prophet options.
244    pub fn opts_mut(&mut self) -> &mut ProphetOptions {
245        &mut self.opts
246    }
247
248    /// Set the width of the uncertainty intervals.
249    ///
250    /// The interval width does not affect training, only predictions,
251    /// so this can be called after fitting the model to obtain predictions
252    /// with different levels of uncertainty.
253    pub fn set_interval_width(&mut self, interval_width: IntervalWidth) {
254        self.opts.interval_width = interval_width;
255    }
256
257    fn infer_freq(history_dates: &[TimestampSeconds]) -> Result<TimestampSeconds, Error> {
258        const INFER_N: usize = 5;
259        let get_tried = || {
260            history_dates
261                .iter()
262                .rev()
263                .take(INFER_N)
264                .copied()
265                .collect_vec()
266        };
267        // Calculate diffs between the last 5 dates in the history, and
268        // create a map from diffs to counts.
269        let diff_counts = history_dates
270            .iter()
271            .rev()
272            .take(INFER_N)
273            .tuple_windows()
274            .map(|(a, b)| a - b)
275            .counts();
276        // Find the max count, and return the corresponding diff, provided there
277        // is exactly one diff with that count.
278        let max = diff_counts
279            .values()
280            .copied()
281            .max()
282            .ok_or_else(|| Error::UnableToInferFrequency(get_tried()))?;
283        diff_counts
284            .into_iter()
285            .filter(|(_, v)| *v == max)
286            .map(|(k, _)| k)
287            .exactly_one()
288            .map_err(|_| Error::UnableToInferFrequency(get_tried()))
289    }
290}
291
292impl<O: Optimizer + 'static> Prophet<O> {
293    pub(crate) fn into_dyn_optimizer(self) -> Prophet<Arc<dyn Optimizer + 'static>> {
294        Prophet {
295            optimizer: Arc::new(self.optimizer),
296            opts: self.opts,
297            regressors: self.regressors,
298            optimized: self.optimized,
299            changepoints: self.changepoints,
300            changepoints_t: self.changepoints_t,
301            init: self.init,
302            scales: self.scales,
303            processed: self.processed,
304            seasonalities: self.seasonalities,
305            component_modes: self.component_modes,
306            train_holiday_names: self.train_holiday_names,
307            train_component_columns: self.train_component_columns,
308        }
309    }
310
311    /// Create a new `ProphetForecaster` from this Prophet model.
312    ///
313    /// This requires the data and optimize options to be provided and sets up
314    /// a `ProphetForecaster` ready to be used with the `augurs_forecaster` crate.
315    pub fn into_forecaster(
316        self,
317        data: TrainingData,
318        optimize_opts: OptimizeOpts,
319    ) -> ProphetForecaster {
320        ProphetForecaster::new(self, data, optimize_opts)
321    }
322}
323
324impl<O: Optimizer> Prophet<O> {
325    /// Fit the Prophet model to some training data.
326    pub fn fit(&mut self, data: TrainingData, mut opts: OptimizeOpts) -> Result<(), Error> {
327        let preprocessed = self.preprocess(data)?;
328        let init = preprocessed.calculate_initial_params(&self.opts)?;
329        // TODO: call `sample` if `self.opts.estimation == EstimationMode::Mcmc`.
330        // We'll first need to add a `sample` method to the Optimizer trait, then accept
331        // a different set of options in `opts`, and also update `OptimizedParams` to
332        // include every MCMC sample.
333        if opts.jacobian.is_none() {
334            let use_jacobian = self.opts.estimation == EstimationMode::Map;
335            opts.jacobian = Some(use_jacobian);
336        }
337        self.optimized = Some(
338            self.optimizer
339                .optimize(&init, &preprocessed.data, &opts)
340                .map_err(|e| Error::OptimizationFailed(e.to_string()))?,
341        );
342        self.processed = Some(preprocessed);
343        self.init = Some(init);
344        Ok(())
345    }
346}
347
#[cfg(test)]
mod test_trend {
    // Tests for trend initialisation (`calculate_initial_params`) and for the
    // changepoints computed during preprocessing. The expected numbers are
    // fixed regression values.
    use std::f64::consts::PI;

    use augurs_core::FloatIterExt;
    use augurs_testing::assert_approx_eq;
    use chrono::{NaiveDate, TimeDelta};
    use itertools::Itertools;

    use super::*;
    use crate::{
        optimizer::mock_optimizer::MockOptimizer,
        testdata::{daily_univariate_ts, train_test_split},
        GrowthType, IncludeHistory, Scaling, TrainingData,
    };

    /// Initial `k` and `m` for linear, logistic and flat growth with the
    /// default scaling, using the first 468 points with a constant cap set to
    /// the series maximum.
    #[test]
    fn growth_init() {
        let mut data = daily_univariate_ts().head(468);
        let max = data.y.iter().copied().nanmax(true);
        data = data.with_cap(vec![max; 468]).unwrap();

        let mut opts = ProphetOptions::default();
        let mut prophet = Prophet::new(opts.clone(), MockOptimizer::new());
        let preprocessed = prophet.preprocess(data.clone()).unwrap();
        let init = preprocessed.calculate_initial_params(&opts).unwrap();
        assert_approx_eq!(init.k, 0.3055671);
        assert_approx_eq!(init.m, 0.5307511);

        opts.growth = GrowthType::Logistic;
        let mut prophet = Prophet::new(opts.clone(), MockOptimizer::new());
        let preprocessed = prophet.preprocess(data).unwrap();
        let init = preprocessed.calculate_initial_params(&opts).unwrap();
        assert_approx_eq!(init.k, 1.507925);
        assert_approx_eq!(init.m, -0.08167497);

        // Flat growth reuses the logistic model's preprocessed data; only the
        // options passed to `calculate_initial_params` change.
        opts.growth = GrowthType::Flat;
        let init = preprocessed.calculate_initial_params(&opts).unwrap();
        assert_approx_eq!(init.k, 0.0);
        assert_approx_eq!(init.m, 0.49335657);
    }

    /// Same as `growth_init`, but with `Scaling::MinMax`.
    #[test]
    fn growth_init_minmax() {
        let mut data = daily_univariate_ts().head(468);
        let max = data.y.iter().copied().nanmax(true);
        data = data.with_cap(vec![max; 468]).unwrap();

        let mut opts = ProphetOptions {
            scaling: Scaling::MinMax,
            ..ProphetOptions::default()
        };
        let mut prophet = Prophet::new(opts.clone(), MockOptimizer::new());
        let preprocessed = prophet.preprocess(data.clone()).unwrap();
        let init = preprocessed.calculate_initial_params(&opts).unwrap();
        assert_approx_eq!(init.k, 0.4053406);
        assert_approx_eq!(init.m, 0.3775322);

        opts.growth = GrowthType::Logistic;
        let mut prophet = Prophet::new(opts.clone(), MockOptimizer::new());
        let preprocessed = prophet.preprocess(data).unwrap();
        let init = preprocessed.calculate_initial_params(&opts).unwrap();
        assert_approx_eq!(init.k, 1.782523);
        assert_approx_eq!(init.m, 0.280521);

        // As above, flat growth reuses the logistic preprocessed data.
        opts.growth = GrowthType::Flat;
        let init = preprocessed.calculate_initial_params(&opts).unwrap();
        assert_approx_eq!(init.k, 0.0);
        assert_approx_eq!(init.m, 0.32792770);
    }

    /// End-to-end fit, `make_future_dataframe` and predict with flat growth
    /// and `AbsMax` scaling on 50 days of a sine wave around 30. The test
    /// only checks that nothing errors.
    #[test]
    fn flat_growth_absmax() {
        let opts = ProphetOptions {
            growth: GrowthType::Flat,
            scaling: Scaling::AbsMax,
            ..ProphetOptions::default()
        };
        let mut prophet = Prophet::new(opts, MockOptimizer::new());
        let x = (0..50).map(|x| x as f64 * PI * 2.0 / 50.0);
        let y = x.map(|x| 30.0 + (x * 8.0).sin()).collect_vec();
        // Daily timestamps at midnight UTC starting 2020-01-01.
        let ds = (0..50)
            .map(|x| {
                (NaiveDate::from_ymd_opt(2020, 1, 1).unwrap() + TimeDelta::days(x))
                    .and_hms_opt(0, 0, 0)
                    .unwrap()
                    .and_utc()
                    .timestamp() as TimestampSeconds
            })
            .collect_vec();
        let data = TrainingData::new(ds, y).unwrap();
        prophet.fit(data, Default::default()).unwrap();
        let future = prophet
            .make_future_dataframe(10.try_into().unwrap(), IncludeHistory::Yes)
            .unwrap();
        let _predictions = prophet.predict(future).unwrap();
    }

    /// Preprocessing places `n_changepoints` changepoints within the default
    /// changepoint range (first 80% of the history).
    #[test]
    fn get_changepoints() {
        let (data, _) = train_test_split(daily_univariate_ts(), 0.5);
        let optimizer = MockOptimizer::new();
        let mut prophet = Prophet::new(ProphetOptions::default(), optimizer);
        let preprocessed = prophet.preprocess(data).unwrap();
        let history = preprocessed.history;
        let changepoints_t = prophet.changepoints_t.as_ref().unwrap();
        assert_eq!(changepoints_t.len() as u32, prophet.opts.n_changepoints,);
        // Assert that the earliest changepoint is after the first point.
        assert!(changepoints_t.iter().copied().nanmin(true) > 0.0);
        // Assert that no changepoint exceeds `t` at the 80th-percentile index.
        let cp_idx = (history.ds.len() as f64 * 0.8).ceil() as usize;
        assert!(changepoints_t.iter().copied().nanmax(true) <= history.t[cp_idx]);
        let expected = &[
            0.03504043, 0.06738544, 0.09433962, 0.12938005, 0.16442049, 0.1967655, 0.22371968,
            0.25606469, 0.28301887, 0.3180593, 0.35040431, 0.37735849, 0.41239892, 0.45013477,
            0.48247978, 0.51752022, 0.54447439, 0.57681941, 0.61185984, 0.64150943, 0.67924528,
            0.7115903, 0.74663073, 0.77358491, 0.80592992,
        ];
        for (a, b) in changepoints_t.iter().zip(expected) {
            assert_approx_eq!(a, b);
        }
    }

    /// As `get_changepoints`, but restricting changepoints to the first 40%
    /// of the history via `changepoint_range`.
    #[test]
    fn get_changepoints_range() {
        let (data, _) = train_test_split(daily_univariate_ts(), 0.5);
        let opts = ProphetOptions {
            changepoint_range: 0.4.try_into().unwrap(),
            ..ProphetOptions::default()
        };
        let mut prophet = Prophet::new(opts, MockOptimizer::new());
        let preprocessed = prophet.preprocess(data).unwrap();
        let history = preprocessed.history;
        let changepoints_t = prophet.changepoints_t.as_ref().unwrap();
        assert_eq!(changepoints_t.len() as u32, prophet.opts.n_changepoints,);
        // Assert that the earliest changepoint is after the first point.
        assert!(changepoints_t.iter().copied().nanmin(true) > 0.0);
        // Assert that no changepoint exceeds `t` at the 40th-percentile index
        // (matching `changepoint_range` above).
        let cp_idx = (history.ds.len() as f64 * 0.4).ceil() as usize;
        assert!(changepoints_t.iter().copied().nanmax(true) <= history.t[cp_idx]);
        let expected = &[
            0.01617251, 0.03504043, 0.05121294, 0.06738544, 0.08355795, 0.09433962, 0.11051213,
            0.12938005, 0.14555256, 0.16172507, 0.17789757, 0.18867925, 0.20754717, 0.22371968,
            0.23989218, 0.25606469, 0.2722372, 0.28301887, 0.30188679, 0.3180593, 0.33423181,
            0.35040431, 0.36657682, 0.37735849, 0.393531,
        ];
        for (a, b) in changepoints_t.iter().zip(expected) {
            assert_approx_eq!(a, b);
        }
    }

    /// With `n_changepoints: 0`, preprocessing still yields a single
    /// changepoint, located at `t = 0`.
    #[test]
    fn get_zero_changepoints() {
        let (data, _) = train_test_split(daily_univariate_ts(), 0.5);
        let opts = ProphetOptions {
            n_changepoints: 0,
            ..ProphetOptions::default()
        };
        let mut prophet = Prophet::new(opts, MockOptimizer::new());
        prophet.preprocess(data).unwrap();
        let changepoints_t = prophet.changepoints_t.as_ref().unwrap();
        assert_eq!(changepoints_t.len() as u32, 1);
        assert_eq!(changepoints_t[0], 0.0);
    }

    /// With only 20 data points, requesting 15 changepoints keeps the option
    /// at 15 and produces exactly 15 changepoints.
    #[test]
    fn get_n_changepoints() {
        let data = daily_univariate_ts().head(20);
        let opts = ProphetOptions {
            n_changepoints: 15,
            ..ProphetOptions::default()
        };
        let mut prophet = Prophet::new(opts, MockOptimizer::new());
        prophet.preprocess(data).unwrap();
        let changepoints_t = prophet.changepoints_t.as_ref().unwrap();
        assert_eq!(prophet.opts.n_changepoints, 15);
        assert_eq!(changepoints_t.len() as u32, 15);
    }
}
527
#[cfg(test)]
mod test_seasonal {
    use augurs_testing::assert_approx_eq;

    use super::*;
    use crate::testdata::daily_univariate_ts;

    /// Weekly Fourier terms (period 7, order 3) for the daily test series:
    /// six rows, whose first entries match reference values.
    #[test]
    fn fourier_series_weekly() {
        let series = daily_univariate_ts();
        let terms =
            Prophet::<()>::fourier_series(&series.ds, 7.0.try_into().unwrap(), 3.try_into().unwrap());
        let want = [
            0.7818315, 0.6234898, 0.9749279, -0.2225209, 0.4338837, -0.9009689,
        ];
        assert_eq!(terms.len(), want.len());
        // Compare the first entry of every row against the reference value.
        for (row, reference) in terms.iter().zip(&want) {
            assert_approx_eq!(row[0], reference);
        }
    }

    /// Yearly Fourier terms (period 365.25, order 3) for the daily test
    /// series: six rows, whose first entries match reference values.
    #[test]
    fn fourier_series_yearly() {
        let series = daily_univariate_ts();
        let terms = Prophet::<()>::fourier_series(
            &series.ds,
            365.25.try_into().unwrap(),
            3.try_into().unwrap(),
        );
        let want = [
            0.7006152, -0.7135393, -0.9998330, 0.01827656, 0.7262249, 0.6874572,
        ];
        assert_eq!(terms.len(), want.len());
        // Compare the first entry of every row against the reference value.
        for (row, reference) in terms.iter().zip(&want) {
            assert_approx_eq!(row[0], reference);
        }
    }
}
568
#[cfg(test)]
mod test_custom_seasonal {
    // Tests for custom seasonalities: per-seasonality prior scales and modes,
    // and conditional seasonalities.
    use std::collections::HashMap;

    use chrono::NaiveDate;
    use itertools::Itertools;

    use crate::{
        optimizer::mock_optimizer::MockOptimizer,
        prophet::prep::{FeatureName, Features},
        testdata::daily_univariate_ts,
        FeatureMode, Holiday, HolidayOccurrence, ProphetOptions, Seasonality, SeasonalityOption,
    };

    use super::Prophet;

    /// A custom additive monthly seasonality (prior scale 2) and a holiday
    /// (prior scale 4) are combined with the built-in weekly seasonality
    /// under multiplicative seasonality mode. The test checks the resulting
    /// component columns and prior scales.
    #[test]
    fn custom_prior() {
        let holiday_dates = ["2017-01-02"]
            .iter()
            .map(|s| {
                HolidayOccurrence::for_day(
                    s.parse::<NaiveDate>()
                        .unwrap()
                        .and_hms_opt(0, 0, 0)
                        .unwrap()
                        .and_utc()
                        .timestamp(),
                )
            })
            .collect();

        let opts = ProphetOptions {
            holidays: [(
                "special day".to_string(),
                Holiday::new(holiday_dates).with_prior_scale(4.0.try_into().unwrap()),
            )]
            .into(),
            seasonality_mode: FeatureMode::Multiplicative,
            yearly_seasonality: SeasonalityOption::Manual(false),
            ..Default::default()
        };

        let data = daily_univariate_ts();
        let mut prophet = Prophet::new(opts, MockOptimizer::new());
        prophet
            .add_seasonality(
                "monthly".to_string(),
                Seasonality::new(30.0.try_into().unwrap(), 5.try_into().unwrap())
                    .with_prior_scale(2.0.try_into().unwrap())
                    .with_mode(FeatureMode::Additive),
            )
            .unwrap();
        prophet.fit(data, Default::default()).unwrap();
        prophet.predict(None).unwrap();

        // The built-in weekly seasonality has no explicit mode; the custom
        // monthly one keeps the mode it was given.
        assert_eq!(prophet.seasonalities["weekly"].mode, None);
        assert_eq!(
            prophet.seasonalities["monthly"].mode,
            Some(FeatureMode::Additive)
        );
        let Features {
            features,
            prior_scales,
            component_columns,
            ..
        } = prophet
            .make_all_features(&prophet.processed.as_ref().unwrap().history)
            .unwrap();

        // Monthly: 10 columns, holiday: 1, weekly: 6. Only the monthly
        // columns are additive; weekly + holiday follow the multiplicative
        // global mode (6 + 1 = 7).
        assert_eq!(
            component_columns.seasonalities["monthly"]
                .iter()
                .sum::<i32>(),
            10
        );
        assert_eq!(
            component_columns.holidays["special day"]
                .iter()
                .sum::<i32>(),
            1
        );
        assert_eq!(
            component_columns.seasonalities["weekly"]
                .iter()
                .sum::<i32>(),
            6
        );
        assert_eq!(component_columns.additive.iter().sum::<i32>(), 10);
        assert_eq!(component_columns.multiplicative.iter().sum::<i32>(), 7);

        // Seasonalities are stored in a `HashMap`, so monthly and weekly
        // features may come out in either order; check the layout matching
        // whichever comes first. The holiday column is last in both cases.
        if features.names[0]
            == (FeatureName::Seasonality {
                name: "monthly".to_string(),
                _id: 1,
            })
        {
            assert_eq!(
                component_columns.seasonalities["monthly"],
                &[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
            );
            assert_eq!(
                component_columns.seasonalities["weekly"],
                &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0],
            );
            let expected_prior_scales = [
                2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 10.0, 10.0, 10.0, 10.0, 10.0,
                10.0, 4.0,
            ]
            .map(|x| x.try_into().unwrap());
            assert_eq!(&prior_scales, &expected_prior_scales);
        } else {
            assert_eq!(
                component_columns.seasonalities["monthly"],
                &[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
            );
            assert_eq!(
                component_columns.seasonalities["weekly"],
                &[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            );
            let expected_prior_scales = [
                10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
                2.0, 4.0,
            ]
            .map(|x| x.try_into().unwrap());
            assert_eq!(&prior_scales, &expected_prior_scales);
        }
    }

    /// A seasonality conditioned on a boolean column only contributes where
    /// that column is `true`.
    #[test]
    fn conditional_custom_seasonality() {
        // Set up data: the condition is false for the first 255 rows and
        // true for the next 255.
        let mut data = daily_univariate_ts();
        let condition_col = [[false; 255], [true; 255]].concat();
        let conditions =
            HashMap::from([("is_conditional_week".to_string(), condition_col.clone())]);
        data = data.with_seasonality_conditions(conditions).unwrap();

        // Set up Prophet model with only the custom seasonalities enabled.
        let opts = ProphetOptions {
            yearly_seasonality: SeasonalityOption::Manual(false),
            weekly_seasonality: SeasonalityOption::Manual(false),
            ..Default::default()
        };
        let mut prophet = Prophet::new(opts, MockOptimizer::new());
        prophet
            .add_seasonality(
                "conditional_weekly".to_string(),
                Seasonality::new(7.0.try_into().unwrap(), 3.try_into().unwrap())
                    .with_prior_scale(2.0.try_into().unwrap())
                    .with_condition("is_conditional_week".to_string())
                    .with_mode(FeatureMode::Additive),
            )
            .unwrap()
            .add_seasonality(
                "normal_monthly".to_string(),
                Seasonality::new(30.5.try_into().unwrap(), 5.try_into().unwrap())
                    .with_prior_scale(2.0.try_into().unwrap())
                    .with_mode(FeatureMode::Additive),
            )
            .unwrap();

        prophet.fit(data, Default::default()).unwrap();
        prophet.predict(None).unwrap();

        let Features { features, .. } = prophet
            .make_all_features(&prophet.processed.as_ref().unwrap().history)
            .unwrap();
        // Select only the feature columns belonging to the conditional seasonality.
        let condition_features = features
            .names
            .iter()
            .zip(&features.data)
            .filter(|(name, _)| {
                matches!(name, FeatureName::Seasonality { name, .. } if name == "conditional_weekly")
            })
            .collect_vec();
        // Check that each of the condition features is zero everywhere except
        // where the condition column is true.
        for (_, condition_feature) in condition_features {
            assert_eq!(condition_col.len(), condition_feature.len());
            for (cond, f) in condition_col.iter().zip(condition_feature) {
                assert_eq!(*f != 0.0, *cond);
            }
        }
    }
}
755
#[cfg(test)]
mod test_holidays {
    use chrono::NaiveDate;

    use crate::{
        optimizer::mock_optimizer::MockOptimizer, testdata::daily_univariate_ts, Holiday,
        HolidayOccurrence, Prophet, ProphetOptions,
    };

    /// Build a single-day holiday occurrence at midnight UTC on `date`
    /// (formatted `YYYY-MM-DD`).
    fn occurrence_on(date: &str) -> HolidayOccurrence {
        let midnight = date
            .parse::<NaiveDate>()
            .unwrap()
            .and_hms_opt(0, 0, 0)
            .unwrap();
        HolidayOccurrence::for_day(midnight.and_utc().timestamp())
    }

    /// Fitting and predicting with a recurring one-day holiday succeeds.
    #[test]
    fn fit_predict_holiday() {
        let occurrences = ["2012-10-09", "2013-10-09"]
            .iter()
            .copied()
            .map(occurrence_on)
            .collect();
        let opts = ProphetOptions {
            holidays: [("bens-bday".to_string(), Holiday::new(occurrences))].into(),
            ..Default::default()
        };
        let mut model = Prophet::new(opts, MockOptimizer::new());
        model
            .fit(daily_univariate_ts(), Default::default())
            .unwrap();
        model.predict(None).unwrap();
    }
}
790
#[cfg(test)]
mod test_fit {
    // Tests that check the data `fit` hands to the optimizer, using a mock
    // optimizer that records its call.
    use augurs_core::FloatIterExt;
    use augurs_testing::assert_all_close;
    use itertools::Itertools;

    use crate::{
        optimizer::{mock_optimizer::MockOptimizer, InitialParams},
        testdata::{daily_univariate_ts, train_test_splitn},
        Prophet, ProphetOptions, TrendIndicator,
    };

    /// This test is extracted from the `fit_predict` test of the Python Prophet
    /// library. Since we don't want to depend on an optimizer, this just ensures
    /// that we're correctly getting the data ready for Stan, by recording the data
    /// that's sent to the configured optimizer.
    ///
    /// There is a similar test in `predict.rs` which patches the returned
    /// optimized parameters and ensures predictions look sensible.
    #[test]
    fn fit_absmax() {
        // Hold out the last 30 days; the remaining 480 rows are used for training.
        let test_days = 30;
        let (train, _) = train_test_splitn(daily_univariate_ts(), test_days);
        let opts = ProphetOptions {
            scaling: crate::Scaling::AbsMax,
            ..Default::default()
        };
        let opt = MockOptimizer::new();
        let mut prophet = Prophet::new(opts, opt);
        prophet.fit(train.clone(), Default::default()).unwrap();
        // Make sure our optimizer was called correctly.
        let opt: &MockOptimizer = &prophet.optimizer;
        let call = opt.take_call().unwrap();
        // Initial parameters: zeroed seasonal (`beta`) and changepoint
        // (`delta`) coefficients, plus the initial trend `k`/`m`.
        assert_eq!(
            call.init,
            InitialParams {
                beta: vec![0.0; 6],
                delta: vec![0.0; 25],
                k: 0.29834791059280863,
                m: 0.5307510759405802,
                sigma_obs: 1.0.try_into().unwrap(),
            }
        );
        // 480 observations, 25 changepoints and 6 feature columns.
        assert_eq!(call.data.T, 480);
        assert_eq!(call.data.S, 25);
        assert_eq!(call.data.K, 6);
        assert_eq!(*call.data.tau, 0.05);
        assert_eq!(call.data.trend_indicator, TrendIndicator::Linear);
        // With AbsMax scaling the largest scaled observation is exactly 1.
        assert_eq!(call.data.y.iter().copied().nanmax(true), 1.0);
        assert_all_close(
            &call.data.y[0..5],
            &[0.530751, 0.472442, 0.430376, 0.444259, 0.458559],
        );
        assert_eq!(call.data.t.len(), train.y.len());
        assert_all_close(
            &call.data.t[0..5],
            &[0.0, 0.004298, 0.005731, 0.007163, 0.008596],
        );

        // No cap was supplied, so the cap passed to the optimizer is all zeros.
        assert_eq!(call.data.cap.len(), train.y.len());
        assert_eq!(&call.data.cap, &[0.0; 480]);

        // Every feature column has prior scale 10, is additive (`s_a`) and
        // not multiplicative (`s_m`).
        assert_eq!(
            &call.data.sigmas.iter().map(|x| **x).collect_vec(),
            &[10.0; 6]
        );
        assert_eq!(&call.data.s_a, &[1; 6]);
        assert_eq!(&call.data.s_m, &[0; 6]);
        // `X` holds 6 features for each of the 480 rows; its first six values
        // equal the weekly Fourier terms checked in `fourier_series_weekly`.
        assert_eq!(call.data.X.len(), 6 * 480);
        let first = &call.data.X[..6];
        assert_all_close(
            first,
            &[0.781831, 0.623490, 0.974928, -0.222521, 0.433884, -0.900969],
        );
    }

    // Regression test for https://github.com/grafana/augurs/issues/209.
    // A NaN observation in the training data must not make `fit` panic.
    #[test]
    fn fit_with_nans() {
        let test_days = 30;
        let (mut train, _) = train_test_splitn(daily_univariate_ts(), test_days);
        train.y[10] = f64::NAN;
        let opt = MockOptimizer::new();
        let mut prophet = Prophet::new(Default::default(), opt);
        // Should not panic.
        prophet.fit(train.clone(), Default::default()).unwrap();
    }
}
874        let mut prophet = Prophet::new(Default::default(), opt);
875        // Should not panic.
876        prophet.fit(train.clone(), Default::default()).unwrap();
877    }
878}