1pub(crate) mod options;
2pub(crate) mod predict;
3pub(crate) mod prep;
4
5use std::{
6 collections::{HashMap, HashSet},
7 num::NonZeroU32,
8 sync::Arc,
9};
10
11use itertools::{izip, Itertools};
12use options::ProphetOptions;
13use prep::{ComponentColumns, Modes, Preprocessed, Scales};
14
15use crate::{
16 forecaster::ProphetForecaster,
17 optimizer::{InitialParams, OptimizeOpts, OptimizedParams, Optimizer},
18 Error, EstimationMode, FeaturePrediction, IncludeHistory, IntervalWidth, PredictionData,
19 Predictions, Regressor, Seasonality, TimestampSeconds, TrainingData,
20};
21
22#[derive(Debug, Clone)]
24pub struct Prophet<O> {
25 opts: ProphetOptions,
27
28 regressors: HashMap<String, Regressor>,
30
31 seasonalities: HashMap<String, Seasonality>,
33
34 scales: Option<Scales>,
43
44 changepoints: Option<Vec<TimestampSeconds>>,
46
47 changepoints_t: Option<Vec<f64>>,
49
50 component_modes: Option<Modes>,
52
53 train_component_columns: Option<ComponentColumns>,
55
56 train_holiday_names: Option<HashSet<String>>,
58
59 optimizer: O,
61
62 processed: Option<Preprocessed>,
64
65 init: Option<InitialParams>,
67
68 optimized: Option<OptimizedParams>,
70}
71
72impl<O> Prophet<O> {
76 pub fn new(opts: ProphetOptions, optimizer: O) -> Self {
78 Self {
79 opts,
80 regressors: HashMap::new(),
81 seasonalities: HashMap::new(),
82 scales: None,
83 changepoints: None,
84 changepoints_t: None,
85 component_modes: None,
86 train_component_columns: None,
87 train_holiday_names: None,
88 optimizer,
89 processed: None,
90 init: None,
91 optimized: None,
92 }
93 }
94
95 pub fn add_seasonality(
97 &mut self,
98 name: String,
99 seasonality: Seasonality,
100 ) -> Result<&mut Self, Error> {
101 if self.seasonalities.contains_key(&name) {
102 return Err(Error::DuplicateSeasonality(name));
103 }
104 self.seasonalities.insert(name, seasonality);
105 Ok(self)
106 }
107
108 pub fn add_regressor(&mut self, name: String, regressor: Regressor) -> &mut Self {
110 self.regressors.insert(name, regressor);
111 self
112 }
113
114 pub fn is_fitted(&self) -> bool {
116 self.optimized.is_some()
117 }
118
119 pub fn predict(&self, data: impl Into<Option<PredictionData>>) -> Result<Predictions, Error> {
125 let Self {
126 processed: Some(processed),
127 optimized: Some(params),
128 changepoints_t: Some(changepoints_t),
129 scales: Some(scales),
130 ..
131 } = self
132 else {
133 return Err(Error::ModelNotFit);
134 };
135 let data = data.into();
136 let df = data
137 .map(|data| {
138 let training_data = TrainingData {
139 n: data.n,
140 ds: data.ds.clone(),
141 y: vec![],
142 cap: data.cap.clone(),
143 floor: data.floor.clone(),
144 seasonality_conditions: data.seasonality_conditions.clone(),
145 x: data.x.clone(),
146 };
147 self.setup_dataframe(training_data, Some(scales.clone()))
148 .map(|(df, _)| df)
149 })
150 .transpose()?
151 .unwrap_or_else(|| processed.history.clone());
152
153 let mut trend = self.predict_trend(
154 &df.t,
155 &df.cap_scaled,
156 &df.floor,
157 changepoints_t,
158 params,
159 scales.y_scale,
160 )?;
161 let features = self.make_all_features(&df)?;
162 let seasonal_components = self.predict_features(&features, params, scales.y_scale)?;
163
164 let yhat_point = izip!(
165 &trend.point,
166 &seasonal_components.additive.point,
167 &seasonal_components.multiplicative.point
168 )
169 .map(|(t, a, m)| t * (1.0 + m) + a)
170 .collect();
171 let mut yhat = FeaturePrediction {
172 point: yhat_point,
173 lower: None,
174 upper: None,
175 };
176
177 if self.opts.uncertainty_samples > 0 {
178 self.predict_uncertainty(
179 &df,
180 &features,
181 params,
182 changepoints_t,
183 &mut yhat,
184 &mut trend,
185 scales.y_scale,
186 )?;
187 }
188
189 Ok(Predictions {
190 ds: df.ds,
191 yhat,
192 trend,
193 cap: df.cap,
194 floor: scales.logistic_floor.then_some(df.floor),
195 additive: seasonal_components.additive,
196 multiplicative: seasonal_components.multiplicative,
197 holidays: seasonal_components.holidays,
198 seasonalities: seasonal_components.seasonalities,
199 regressors: seasonal_components.regressors,
200 })
201 }
202
203 pub fn make_future_dataframe(
215 &self,
216 horizon: NonZeroU32,
217 include_history: IncludeHistory,
218 ) -> Result<PredictionData, Error> {
219 let Some(Preprocessed { history_dates, .. }) = &self.processed else {
220 return Err(Error::ModelNotFit);
221 };
222 let freq = Self::infer_freq(history_dates)?;
223 let last_date = *history_dates.last().ok_or(Error::NotEnoughData)?;
224 let n = (horizon.get() as u64 + 1) as TimestampSeconds;
225 let dates = (last_date..last_date + n * freq)
226 .step_by(freq as usize)
227 .filter(|ds| *ds > last_date)
228 .take(horizon.get() as usize);
229
230 let ds = if include_history == IncludeHistory::Yes {
231 history_dates.iter().copied().chain(dates).collect()
232 } else {
233 dates.collect()
234 };
235 Ok(PredictionData::new(ds))
236 }
237
238 pub fn opts(&self) -> &ProphetOptions {
240 &self.opts
241 }
242
243 pub fn opts_mut(&mut self) -> &mut ProphetOptions {
245 &mut self.opts
246 }
247
248 pub fn set_interval_width(&mut self, interval_width: IntervalWidth) {
254 self.opts.interval_width = interval_width;
255 }
256
257 fn infer_freq(history_dates: &[TimestampSeconds]) -> Result<TimestampSeconds, Error> {
258 const INFER_N: usize = 5;
259 let get_tried = || {
260 history_dates
261 .iter()
262 .rev()
263 .take(INFER_N)
264 .copied()
265 .collect_vec()
266 };
267 let diff_counts = history_dates
270 .iter()
271 .rev()
272 .take(INFER_N)
273 .tuple_windows()
274 .map(|(a, b)| a - b)
275 .counts();
276 let max = diff_counts
279 .values()
280 .copied()
281 .max()
282 .ok_or_else(|| Error::UnableToInferFrequency(get_tried()))?;
283 diff_counts
284 .into_iter()
285 .filter(|(_, v)| *v == max)
286 .map(|(k, _)| k)
287 .exactly_one()
288 .map_err(|_| Error::UnableToInferFrequency(get_tried()))
289 }
290}
291
292impl<O: Optimizer + 'static> Prophet<O> {
293 pub(crate) fn into_dyn_optimizer(self) -> Prophet<Arc<dyn Optimizer + 'static>> {
294 Prophet {
295 optimizer: Arc::new(self.optimizer),
296 opts: self.opts,
297 regressors: self.regressors,
298 optimized: self.optimized,
299 changepoints: self.changepoints,
300 changepoints_t: self.changepoints_t,
301 init: self.init,
302 scales: self.scales,
303 processed: self.processed,
304 seasonalities: self.seasonalities,
305 component_modes: self.component_modes,
306 train_holiday_names: self.train_holiday_names,
307 train_component_columns: self.train_component_columns,
308 }
309 }
310
311 pub fn into_forecaster(
316 self,
317 data: TrainingData,
318 optimize_opts: OptimizeOpts,
319 ) -> ProphetForecaster {
320 ProphetForecaster::new(self, data, optimize_opts)
321 }
322}
323
324impl<O: Optimizer> Prophet<O> {
325 pub fn fit(&mut self, data: TrainingData, mut opts: OptimizeOpts) -> Result<(), Error> {
327 let preprocessed = self.preprocess(data)?;
328 let init = preprocessed.calculate_initial_params(&self.opts)?;
329 if opts.jacobian.is_none() {
334 let use_jacobian = self.opts.estimation == EstimationMode::Map;
335 opts.jacobian = Some(use_jacobian);
336 }
337 self.optimized = Some(
338 self.optimizer
339 .optimize(&init, &preprocessed.data, &opts)
340 .map_err(|e| Error::OptimizationFailed(e.to_string()))?,
341 );
342 self.processed = Some(preprocessed);
343 self.init = Some(init);
344 Ok(())
345 }
346}
347
#[cfg(test)]
mod test_trend {
    use std::f64::consts::PI;

    use augurs_core::FloatIterExt;
    use augurs_testing::assert_approx_eq;
    use chrono::{NaiveDate, TimeDelta};
    use itertools::Itertools;

    use super::*;
    use crate::{
        optimizer::mock_optimizer::MockOptimizer,
        testdata::{daily_univariate_ts, train_test_split},
        GrowthType, IncludeHistory, Scaling, TrainingData,
    };

    /// The first 468 days of the daily fixture, capped at the series maximum.
    fn capped_history() -> TrainingData {
        let head = daily_univariate_ts().head(468);
        let cap = head.y.iter().copied().nanmax(true);
        head.with_cap(vec![cap; 468]).unwrap()
    }

    /// Check the initial `(k, m)` estimates for three growth settings, in order:
    /// 1. `opts` as given (default growth);
    /// 2. logistic growth;
    /// 3. flat growth, reusing the data preprocessed for the logistic case.
    fn assert_growth_inits(mut opts: ProphetOptions, expected: [(f64, f64); 3]) {
        let data = capped_history();
        let [(first_k, first_m), (logistic_k, logistic_m), (flat_k, flat_m)] = expected;

        let mut first = Prophet::new(opts.clone(), MockOptimizer::new());
        let init = first
            .preprocess(data.clone())
            .unwrap()
            .calculate_initial_params(&opts)
            .unwrap();
        assert_approx_eq!(init.k, first_k);
        assert_approx_eq!(init.m, first_m);

        opts.growth = GrowthType::Logistic;
        let mut logistic = Prophet::new(opts.clone(), MockOptimizer::new());
        let logistic_prep = logistic.preprocess(data).unwrap();
        let init = logistic_prep.calculate_initial_params(&opts).unwrap();
        assert_approx_eq!(init.k, logistic_k);
        assert_approx_eq!(init.m, logistic_m);

        opts.growth = GrowthType::Flat;
        let init = logistic_prep.calculate_initial_params(&opts).unwrap();
        assert_approx_eq!(init.k, flat_k);
        assert_approx_eq!(init.m, flat_m);
    }

    #[test]
    fn growth_init() {
        assert_growth_inits(
            ProphetOptions::default(),
            [
                (0.3055671, 0.5307511),
                (1.507925, -0.08167497),
                (0.0, 0.49335657),
            ],
        );
    }

    #[test]
    fn growth_init_minmax() {
        let opts = ProphetOptions {
            scaling: Scaling::MinMax,
            ..ProphetOptions::default()
        };
        assert_growth_inits(
            opts,
            [
                (0.4053406, 0.3775322),
                (1.782523, 0.280521),
                (0.0, 0.32792770),
            ],
        );
    }

    #[test]
    fn flat_growth_absmax() {
        let opts = ProphetOptions {
            growth: GrowthType::Flat,
            scaling: Scaling::AbsMax,
            ..ProphetOptions::default()
        };
        let mut prophet = Prophet::new(opts, MockOptimizer::new());

        // 50 daily points starting 2020-01-01: a sinusoid oscillating around 30.
        let start = NaiveDate::from_ymd_opt(2020, 1, 1).unwrap();
        let ds = (0..50)
            .map(|day| {
                let date = start + TimeDelta::days(day);
                date.and_hms_opt(0, 0, 0).unwrap().and_utc().timestamp() as TimestampSeconds
            })
            .collect_vec();
        let y = (0..50)
            .map(|i| {
                let phase = i as f64 * PI * 2.0 / 50.0;
                30.0 + (phase * 8.0).sin()
            })
            .collect_vec();

        prophet
            .fit(TrainingData::new(ds, y).unwrap(), Default::default())
            .unwrap();
        let future = prophet
            .make_future_dataframe(10.try_into().unwrap(), IncludeHistory::Yes)
            .unwrap();
        prophet.predict(future).unwrap();
    }

    /// Preprocess the first half of the daily fixture with `opts` and check
    /// the scaled changepoints.
    ///
    /// `range` is the fraction of history the changepoints may occupy. Every
    /// changepoint must be positive and no later than the history time at
    /// index `ceil(len * range)`.
    fn assert_changepoints(opts: ProphetOptions, range: f64, expected: &[f64]) {
        let (train, _) = train_test_split(daily_univariate_ts(), 0.5);
        let mut prophet = Prophet::new(opts, MockOptimizer::new());
        let history = prophet.preprocess(train).unwrap().history;
        let cps = prophet.changepoints_t.as_ref().unwrap();

        assert_eq!(cps.len() as u32, prophet.opts.n_changepoints);
        assert!(cps.iter().copied().nanmin(true) > 0.0);
        let bound_idx = (history.ds.len() as f64 * range).ceil() as usize;
        assert!(cps.iter().copied().nanmax(true) <= history.t[bound_idx]);
        for (got, want) in cps.iter().zip(expected) {
            assert_approx_eq!(got, want);
        }
    }

    #[test]
    fn get_changepoints() {
        assert_changepoints(
            ProphetOptions::default(),
            0.8,
            &[
                0.03504043, 0.06738544, 0.09433962, 0.12938005, 0.16442049, 0.1967655, 0.22371968,
                0.25606469, 0.28301887, 0.3180593, 0.35040431, 0.37735849, 0.41239892, 0.45013477,
                0.48247978, 0.51752022, 0.54447439, 0.57681941, 0.61185984, 0.64150943, 0.67924528,
                0.7115903, 0.74663073, 0.77358491, 0.80592992,
            ],
        );
    }

    #[test]
    fn get_changepoints_range() {
        let opts = ProphetOptions {
            changepoint_range: 0.4.try_into().unwrap(),
            ..ProphetOptions::default()
        };
        assert_changepoints(
            opts,
            0.4,
            &[
                0.01617251, 0.03504043, 0.05121294, 0.06738544, 0.08355795, 0.09433962, 0.11051213,
                0.12938005, 0.14555256, 0.16172507, 0.17789757, 0.18867925, 0.20754717, 0.22371968,
                0.23989218, 0.25606469, 0.2722372, 0.28301887, 0.30188679, 0.3180593, 0.33423181,
                0.35040431, 0.36657682, 0.37735849, 0.393531,
            ],
        );
    }

    #[test]
    fn get_zero_changepoints() {
        let (train, _) = train_test_split(daily_univariate_ts(), 0.5);
        let opts = ProphetOptions {
            n_changepoints: 0,
            ..ProphetOptions::default()
        };
        let mut prophet = Prophet::new(opts, MockOptimizer::new());
        prophet.preprocess(train).unwrap();

        // With no changepoints requested, a single changepoint at t = 0 remains.
        let cps = prophet.changepoints_t.as_deref().unwrap();
        assert_eq!(cps.len() as u32, 1);
        assert_eq!(cps[0], 0.0);
    }

    #[test]
    fn get_n_changepoints() {
        let opts = ProphetOptions {
            n_changepoints: 15,
            ..ProphetOptions::default()
        };
        let mut prophet = Prophet::new(opts, MockOptimizer::new());
        prophet.preprocess(daily_univariate_ts().head(20)).unwrap();

        let cps = prophet.changepoints_t.as_deref().unwrap();
        assert_eq!(prophet.opts.n_changepoints, 15);
        assert_eq!(cps.len() as u32, 15);
    }
}
527
#[cfg(test)]
mod test_seasonal {
    use augurs_testing::assert_approx_eq;

    use super::*;
    use crate::testdata::daily_univariate_ts;

    /// Weekly Fourier terms (order 3) at the first timestamp.
    #[test]
    fn fourier_series_weekly() {
        let data = daily_univariate_ts();
        let terms =
            Prophet::<()>::fourier_series(&data.ds, 7.0.try_into().unwrap(), 3.try_into().unwrap());
        let want = [
            0.7818315, 0.6234898, 0.9749279, -0.2225209, 0.4338837, -0.9009689,
        ];

        assert_eq!(terms.len(), want.len());
        for (row, expected) in terms.iter().zip(&want) {
            assert_approx_eq!(row[0], expected);
        }
    }

    /// Yearly Fourier terms (order 3) at the first timestamp.
    #[test]
    fn fourier_series_yearly() {
        let data = daily_univariate_ts();
        let terms = Prophet::<()>::fourier_series(
            &data.ds,
            365.25.try_into().unwrap(),
            3.try_into().unwrap(),
        );
        let want = [
            0.7006152, -0.7135393, -0.9998330, 0.01827656, 0.7262249, 0.6874572,
        ];

        assert_eq!(terms.len(), want.len());
        for (row, expected) in terms.iter().zip(&want) {
            assert_approx_eq!(row[0], expected);
        }
    }
}
568
#[cfg(test)]
mod test_custom_seasonal {
    use std::collections::HashMap;

    use chrono::NaiveDate;
    use itertools::Itertools;

    use crate::{
        optimizer::mock_optimizer::MockOptimizer,
        prophet::prep::{FeatureName, Features},
        testdata::daily_univariate_ts,
        FeatureMode, Holiday, HolidayOccurrence, ProphetOptions, Seasonality, SeasonalityOption,
    };

    use super::Prophet;

    /// Custom prior scales and modes on a seasonality and a holiday are
    /// reflected in the generated features.
    #[test]
    fn custom_prior() {
        let holiday_dates = ["2017-01-02"]
            .iter()
            .map(|s| {
                HolidayOccurrence::for_day(
                    s.parse::<NaiveDate>()
                        .unwrap()
                        .and_hms_opt(0, 0, 0)
                        .unwrap()
                        .and_utc()
                        .timestamp(),
                )
            })
            .collect();

        let opts = ProphetOptions {
            holidays: [(
                "special day".to_string(),
                Holiday::new(holiday_dates).with_prior_scale(4.0.try_into().unwrap()),
            )]
            .into(),
            seasonality_mode: FeatureMode::Multiplicative,
            yearly_seasonality: SeasonalityOption::Manual(false),
            ..Default::default()
        };

        let data = daily_univariate_ts();
        let mut prophet = Prophet::new(opts, MockOptimizer::new());
        prophet
            .add_seasonality(
                "monthly".to_string(),
                Seasonality::new(30.0.try_into().unwrap(), 5.try_into().unwrap())
                    .with_prior_scale(2.0.try_into().unwrap())
                    .with_mode(FeatureMode::Additive),
            )
            .unwrap();
        prophet.fit(data, Default::default()).unwrap();
        prophet.predict(None).unwrap();

        assert_eq!(prophet.seasonalities["weekly"].mode, None);
        assert_eq!(
            prophet.seasonalities["monthly"].mode,
            Some(FeatureMode::Additive)
        );
        let Features {
            features,
            prior_scales,
            component_columns,
            ..
        } = prophet
            .make_all_features(&prophet.processed.as_ref().unwrap().history)
            .unwrap();

        assert_eq!(
            component_columns.seasonalities["monthly"]
                .iter()
                .sum::<i32>(),
            10
        );
        assert_eq!(
            component_columns.holidays["special day"]
                .iter()
                .sum::<i32>(),
            1
        );
        assert_eq!(
            component_columns.seasonalities["weekly"]
                .iter()
                .sum::<i32>(),
            6
        );
        assert_eq!(component_columns.additive.iter().sum::<i32>(), 10);
        assert_eq!(component_columns.multiplicative.iter().sum::<i32>(), 7);

        // Feature column order depends on seasonality iteration order, so
        // check whichever layout was produced.
        if features.names[0]
            == (FeatureName::Seasonality {
                name: "monthly".to_string(),
                _id: 1,
            })
        {
            assert_eq!(
                component_columns.seasonalities["monthly"],
                &[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
            );
            assert_eq!(
                component_columns.seasonalities["weekly"],
                &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0],
            );
            let expected_prior_scales = [
                2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 10.0, 10.0, 10.0, 10.0, 10.0,
                10.0, 4.0,
            ]
            .map(|x| x.try_into().unwrap());
            assert_eq!(&prior_scales, &expected_prior_scales);
        } else {
            assert_eq!(
                component_columns.seasonalities["monthly"],
                &[0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
            );
            assert_eq!(
                component_columns.seasonalities["weekly"],
                &[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            );
            let expected_prior_scales = [
                10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
                2.0, 4.0,
            ]
            .map(|x| x.try_into().unwrap());
            assert_eq!(&prior_scales, &expected_prior_scales);
        }
    }

    /// A conditional seasonality's features must be zero exactly where its
    /// condition column is `false`.
    #[test]
    fn conditional_custom_seasonality() {
        let mut data = daily_univariate_ts();
        let condition_col = [[false; 255], [true; 255]].concat();
        let conditions =
            HashMap::from([("is_conditional_week".to_string(), condition_col.clone())]);
        data = data.with_seasonality_conditions(conditions).unwrap();

        let opts = ProphetOptions {
            yearly_seasonality: SeasonalityOption::Manual(false),
            weekly_seasonality: SeasonalityOption::Manual(false),
            ..Default::default()
        };
        let mut prophet = Prophet::new(opts, MockOptimizer::new());
        prophet
            .add_seasonality(
                "conditional_weekly".to_string(),
                Seasonality::new(7.0.try_into().unwrap(), 3.try_into().unwrap())
                    .with_prior_scale(2.0.try_into().unwrap())
                    .with_condition("is_conditional_week".to_string())
                    .with_mode(FeatureMode::Additive),
            )
            .unwrap()
            .add_seasonality(
                "normal_monthly".to_string(),
                Seasonality::new(30.5.try_into().unwrap(), 5.try_into().unwrap())
                    .with_prior_scale(2.0.try_into().unwrap())
                    .with_mode(FeatureMode::Additive),
            )
            .unwrap();

        prophet.fit(data, Default::default()).unwrap();
        prophet.predict(None).unwrap();

        let Features { features, .. } = prophet
            .make_all_features(&prophet.processed.as_ref().unwrap().history)
            .unwrap();
        let condition_features = features
            .names
            .iter()
            .zip(&features.data)
            .filter(|(name, _)| {
                matches!(name, FeatureName::Seasonality { name, .. } if name == "conditional_weekly")
            })
            .collect_vec();
        // A Fourier order of 3 yields 2 * 3 columns. Requiring them to be
        // present stops the loop below from passing vacuously.
        assert_eq!(condition_features.len(), 6);
        for (_, condition_feature) in condition_features {
            assert_eq!(condition_col.len(), condition_feature.len());
            for (cond, f) in condition_col.iter().zip(condition_feature) {
                assert_eq!(*f != 0.0, *cond);
            }
        }
    }
}
755
#[cfg(test)]
mod test_holidays {
    use chrono::NaiveDate;

    use crate::{
        optimizer::mock_optimizer::MockOptimizer, testdata::daily_univariate_ts, Holiday,
        HolidayOccurrence, Prophet, ProphetOptions,
    };

    /// Unix timestamp of midnight UTC on a `YYYY-MM-DD` date.
    fn midnight_utc(date: &str) -> i64 {
        date.parse::<NaiveDate>()
            .unwrap()
            .and_hms_opt(0, 0, 0)
            .unwrap()
            .and_utc()
            .timestamp()
    }

    /// Fitting and predicting with a user-defined holiday succeeds.
    #[test]
    fn fit_predict_holiday() {
        let occurrences = ["2012-10-09", "2013-10-09"]
            .iter()
            .map(|date| HolidayOccurrence::for_day(midnight_utc(date)))
            .collect();
        let holidays = [("bens-bday".to_string(), Holiday::new(occurrences))].into();

        let mut prophet = Prophet::new(
            ProphetOptions {
                holidays,
                ..Default::default()
            },
            MockOptimizer::new(),
        );
        prophet
            .fit(daily_univariate_ts(), Default::default())
            .unwrap();
        prophet.predict(None).unwrap();
    }
}
790
#[cfg(test)]
mod test_fit {
    use augurs_core::FloatIterExt;
    use augurs_testing::assert_all_close;
    use itertools::Itertools;

    use crate::{
        optimizer::{mock_optimizer::MockOptimizer, InitialParams},
        testdata::{daily_univariate_ts, train_test_splitn},
        Prophet, ProphetOptions, TrendIndicator,
    };

    /// With abs-max scaling, the optimizer receives the expected initial
    /// parameters and scaled data.
    #[test]
    fn fit_absmax() {
        let (train, _) = train_test_splitn(daily_univariate_ts(), 30);
        let opts = ProphetOptions {
            scaling: crate::Scaling::AbsMax,
            ..Default::default()
        };
        let mut prophet = Prophet::new(opts, MockOptimizer::new());
        prophet.fit(train.clone(), Default::default()).unwrap();

        let recorded = prophet.optimizer.take_call().unwrap();
        let n_obs = train.y.len();

        assert_eq!(
            recorded.init,
            InitialParams {
                beta: vec![0.0; 6],
                delta: vec![0.0; 25],
                k: 0.29834791059280863,
                m: 0.5307510759405802,
                sigma_obs: 1.0.try_into().unwrap(),
            }
        );

        let data = &recorded.data;
        // Sizes: 480 observations, 25 changepoints, 6 seasonal features.
        assert_eq!(data.T, 480);
        assert_eq!(data.S, 25);
        assert_eq!(data.K, 6);
        assert_eq!(*data.tau, 0.05);
        assert_eq!(data.trend_indicator, TrendIndicator::Linear);

        // The scaled series peaks at exactly 1.
        assert_eq!(data.y.iter().copied().nanmax(true), 1.0);
        assert_all_close(
            &data.y[..5],
            &[0.530751, 0.472442, 0.430376, 0.444259, 0.458559],
        );

        assert_eq!(data.t.len(), n_obs);
        assert_all_close(
            &data.t[..5],
            &[0.0, 0.004298, 0.005731, 0.007163, 0.008596],
        );

        // No cap was supplied, so the optimizer sees zeros.
        assert_eq!(data.cap.len(), n_obs);
        assert_eq!(&data.cap, &[0.0; 480]);

        let sigmas = data.sigmas.iter().map(|s| **s).collect_vec();
        assert_eq!(&sigmas, &[10.0; 6]);
        assert_eq!(&data.s_a, &[1; 6]);
        assert_eq!(&data.s_m, &[0; 6]);

        assert_eq!(data.X.len(), 6 * 480);
        assert_all_close(
            &data.X[..6],
            &[0.781831, 0.623490, 0.974928, -0.222521, 0.433884, -0.900969],
        );
    }

    /// A missing (NaN) observation does not prevent fitting.
    #[test]
    fn fit_with_nans() {
        let (mut train, _) = train_test_splitn(daily_univariate_ts(), 30);
        train.y[10] = f64::NAN;
        let mut prophet = Prophet::new(ProphetOptions::default(), MockOptimizer::new());
        prophet.fit(train, Default::default()).unwrap();
    }
}