augurs_prophet/
forecaster.rs

1//! [`Fit`] and [`Predict`] implementations for the Prophet algorithm.
2use std::{cell::RefCell, num::NonZeroU32, sync::Arc};
3
4use augurs_core::{Fit, ModelError, Predict};
5
6use crate::{optimizer::OptimizeOpts, Error, IncludeHistory, Optimizer, Prophet, TrainingData};
7
8impl ModelError for Error {}
9
10/// A forecaster that uses the Prophet algorithm.
11///
12/// This is a wrapper around the [`Prophet`] struct that provides
13/// a simpler API for fitting and predicting. Notably it implements
14/// the [`Fit`] trait from `augurs_core`, so it can be
15/// used with the `augurs` framework (e.g. with the `Forecaster` struct
16/// in the `augurs::forecaster` module).
17#[derive(Debug)]
18pub struct ProphetForecaster {
19    data: TrainingData,
20    model: Prophet<Arc<dyn Optimizer>>,
21    optimize_opts: OptimizeOpts,
22}
23
24impl ProphetForecaster {
25    /// Create a new Prophet forecaster.
26    ///
27    /// # Parameters
28    ///
29    /// - `opts`: The options to use for fitting the model.
30    ///   Note that `uncertainty_samples` will be set to 1000 if it is 0,
31    ///   to facilitate generating prediction intervals.
32    /// - `optimizer`: The optimizer to use for fitting the model.
33    /// - `optimize_opts`: The options to use for optimizing the model.
34    pub fn new<T: Optimizer + 'static>(
35        mut model: Prophet<T>,
36        data: TrainingData,
37        optimize_opts: OptimizeOpts,
38    ) -> Self {
39        let opts = model.opts_mut();
40        if opts.uncertainty_samples == 0 {
41            opts.uncertainty_samples = 1000;
42        }
43        Self {
44            data,
45            model: model.into_dyn_optimizer(),
46            optimize_opts,
47        }
48    }
49}
50
51impl Fit for ProphetForecaster {
52    type Fitted = FittedProphetForecaster;
53    type Error = Error;
54
55    fn fit(&self, y: &[f64]) -> Result<Self::Fitted, Self::Error> {
56        // Use the training data from `self`...
57        let mut training_data = self.data.clone();
58        // ...but replace the `y` column with whatever we're passed
59        // (which may be a transformed version of `y`, if the user is
60        // using `augurs_forecaster`).
61        training_data.y = y.to_vec();
62        let mut fitted_model = self.model.clone();
63        fitted_model.fit(training_data, self.optimize_opts.clone())?;
64        Ok(FittedProphetForecaster {
65            model: RefCell::new(fitted_model),
66            training_n: y.len(),
67        })
68    }
69}
70
71/// A fitted Prophet forecaster.
72#[derive(Debug)]
73pub struct FittedProphetForecaster {
74    model: RefCell<Prophet<Arc<dyn Optimizer>>>,
75    training_n: usize,
76}
77
78impl Predict for FittedProphetForecaster {
79    type Error = Error;
80
81    fn predict_in_sample_inplace(
82        &self,
83        level: Option<f64>,
84        forecast: &mut augurs_core::Forecast,
85    ) -> Result<(), Self::Error> {
86        if let Some(level) = level {
87            self.model
88                .borrow_mut()
89                .set_interval_width(level.try_into()?);
90        }
91        let predictions = self.model.borrow().predict(None)?;
92        forecast.point = predictions.yhat.point;
93        if let Some(intervals) = forecast.intervals.as_mut() {
94            intervals.lower = predictions
95                .yhat
96                .lower
97                // This `expect` is OK because we've set uncertainty_samples > 0 in the
98                // `ProphetForecaster` constructor.
99                .expect("uncertainty_samples should be > 0, this is a bug");
100            intervals.upper = predictions
101                .yhat
102                .upper
103                // This `expect` is OK because we've set uncertainty_samples > 0 in the
104                // `ProphetForecaster` constructor.
105                .expect("uncertainty_samples should be > 0, this is a bug");
106        }
107        Ok(())
108    }
109
110    fn predict_inplace(
111        &self,
112        horizon: usize,
113        level: Option<f64>,
114        forecast: &mut augurs_core::Forecast,
115    ) -> Result<(), Self::Error> {
116        let horizon = match NonZeroU32::try_from(horizon as u32) {
117            Ok(h) => h,
118            // If horizon is 0, short circuit without even trying to predict.
119            Err(_) => return Ok(()),
120        };
121        if let Some(level) = level {
122            self.model
123                .borrow_mut()
124                .set_interval_width(level.try_into()?);
125        }
126        let predictions = {
127            let model = self.model.borrow();
128            let prediction_data = model.make_future_dataframe(horizon, IncludeHistory::No)?;
129            model.predict(prediction_data)?
130        };
131        forecast.point = predictions.yhat.point;
132        if let Some(intervals) = forecast.intervals.as_mut() {
133            intervals.lower = predictions
134                .yhat
135                .lower
136                // This `expect` is OK because we've set uncertainty_samples > 0 in the
137                // `ProphetForecaster` constructor.
138                .expect("uncertainty_samples should be > 0");
139            intervals.upper = predictions
140                .yhat
141                .upper
142                // This `expect` is OK because we've set uncertainty_samples > 0 in the
143                // `ProphetForecaster` constructor.
144                .expect("uncertainty_samples should be > 0");
145        }
146        Ok(())
147    }
148
149    fn training_data_size(&self) -> usize {
150        self.training_n
151    }
152}
153
#[cfg(all(test, feature = "wasmstan"))]
mod test {

    use augurs_core::{Fit, Predict};
    use augurs_testing::assert_all_close;

    use crate::{
        testdata::{daily_univariate_ts, train_test_splitn},
        wasmstan::WasmstanOptimizer,
        IncludeHistory, Prophet,
    };

    use super::ProphetForecaster;

    /// Fitting and predicting through `ProphetForecaster` should give the
    /// same point forecasts as driving a `Prophet` model directly.
    #[test]
    fn forecaster() {
        let test_days = 30;
        let (train, _) = train_test_splitn(daily_univariate_ts(), test_days);

        // Path 1: the `Fit` / `Predict` wrapper.
        let wrapper = ProphetForecaster::new(
            Prophet::new(Default::default(), WasmstanOptimizer::new()),
            train.clone(),
            Default::default(),
        );
        let via_wrapper = wrapper.fit(&train.y).unwrap().predict(30, 0.95).unwrap();

        // Path 2: the raw `Prophet` API.
        let mut direct_model = Prophet::new(Default::default(), WasmstanOptimizer::new());
        direct_model.fit(train, Default::default()).unwrap();
        let future = direct_model
            .make_future_dataframe(30.try_into().unwrap(), IncludeHistory::No)
            .unwrap();
        let direct = direct_model.predict(future).unwrap();

        assert_eq!(direct.yhat.point.len(), via_wrapper.point.len());
        assert_all_close(&direct.yhat.point, &via_wrapper.point);
    }
}