augurs_prophet/
forecaster.rs1use std::{cell::RefCell, num::NonZeroU32, sync::Arc};
3
4use augurs_core::{Fit, ModelError, Predict};
5
6use crate::{optimizer::OptimizeOpts, Error, IncludeHistory, Optimizer, Prophet, TrainingData};
7
8impl ModelError for Error {}
9
10#[derive(Debug)]
18pub struct ProphetForecaster {
19 data: TrainingData,
20 model: Prophet<Arc<dyn Optimizer>>,
21 optimize_opts: OptimizeOpts,
22}
23
24impl ProphetForecaster {
25 pub fn new<T: Optimizer + 'static>(
35 mut model: Prophet<T>,
36 data: TrainingData,
37 optimize_opts: OptimizeOpts,
38 ) -> Self {
39 let opts = model.opts_mut();
40 if opts.uncertainty_samples == 0 {
41 opts.uncertainty_samples = 1000;
42 }
43 Self {
44 data,
45 model: model.into_dyn_optimizer(),
46 optimize_opts,
47 }
48 }
49}
50
51impl Fit for ProphetForecaster {
52 type Fitted = FittedProphetForecaster;
53 type Error = Error;
54
55 fn fit(&self, y: &[f64]) -> Result<Self::Fitted, Self::Error> {
56 let mut training_data = self.data.clone();
58 training_data.y = y.to_vec();
62 let mut fitted_model = self.model.clone();
63 fitted_model.fit(training_data, self.optimize_opts.clone())?;
64 Ok(FittedProphetForecaster {
65 model: RefCell::new(fitted_model),
66 training_n: y.len(),
67 })
68 }
69}
70
71#[derive(Debug)]
73pub struct FittedProphetForecaster {
74 model: RefCell<Prophet<Arc<dyn Optimizer>>>,
75 training_n: usize,
76}
77
78impl Predict for FittedProphetForecaster {
79 type Error = Error;
80
81 fn predict_in_sample_inplace(
82 &self,
83 level: Option<f64>,
84 forecast: &mut augurs_core::Forecast,
85 ) -> Result<(), Self::Error> {
86 if let Some(level) = level {
87 self.model
88 .borrow_mut()
89 .set_interval_width(level.try_into()?);
90 }
91 let predictions = self.model.borrow().predict(None)?;
92 forecast.point = predictions.yhat.point;
93 if let Some(intervals) = forecast.intervals.as_mut() {
94 intervals.lower = predictions
95 .yhat
96 .lower
97 .expect("uncertainty_samples should be > 0, this is a bug");
100 intervals.upper = predictions
101 .yhat
102 .upper
103 .expect("uncertainty_samples should be > 0, this is a bug");
106 }
107 Ok(())
108 }
109
110 fn predict_inplace(
111 &self,
112 horizon: usize,
113 level: Option<f64>,
114 forecast: &mut augurs_core::Forecast,
115 ) -> Result<(), Self::Error> {
116 let horizon = match NonZeroU32::try_from(horizon as u32) {
117 Ok(h) => h,
118 Err(_) => return Ok(()),
120 };
121 if let Some(level) = level {
122 self.model
123 .borrow_mut()
124 .set_interval_width(level.try_into()?);
125 }
126 let predictions = {
127 let model = self.model.borrow();
128 let prediction_data = model.make_future_dataframe(horizon, IncludeHistory::No)?;
129 model.predict(prediction_data)?
130 };
131 forecast.point = predictions.yhat.point;
132 if let Some(intervals) = forecast.intervals.as_mut() {
133 intervals.lower = predictions
134 .yhat
135 .lower
136 .expect("uncertainty_samples should be > 0");
139 intervals.upper = predictions
140 .yhat
141 .upper
142 .expect("uncertainty_samples should be > 0");
145 }
146 Ok(())
147 }
148
149 fn training_data_size(&self) -> usize {
150 self.training_n
151 }
152}
153
#[cfg(all(test, feature = "wasmstan"))]
mod test {
    use augurs_core::{Fit, Predict};
    use augurs_testing::assert_all_close;

    use crate::{
        testdata::{daily_univariate_ts, train_test_splitn},
        wasmstan::WasmstanOptimizer,
        IncludeHistory, Prophet,
    };

    use super::ProphetForecaster;

    /// Fitting and predicting through the `augurs_core` `Fit`/`Predict`
    /// wrapper must give the same point forecasts as driving `Prophet`
    /// directly on the same training data.
    #[test]
    fn forecaster() {
        let test_days = 30;
        let (train, _) = train_test_splitn(daily_univariate_ts(), test_days);

        // Path 1: the augurs-core wrapper.
        let wrapped = ProphetForecaster::new(
            Prophet::new(Default::default(), WasmstanOptimizer::new()),
            train.clone(),
            Default::default(),
        );
        let wrapped_forecast = wrapped.fit(&train.y).unwrap().predict(30, 0.95).unwrap();

        // Path 2: Prophet used directly.
        let mut direct = Prophet::new(Default::default(), WasmstanOptimizer::new());
        direct.fit(train, Default::default()).unwrap();
        let future = direct
            .make_future_dataframe(30.try_into().unwrap(), IncludeHistory::No)
            .unwrap();
        let direct_forecast = direct.predict(future).unwrap();

        assert_eq!(
            direct_forecast.yhat.point.len(),
            wrapped_forecast.point.len()
        );
        assert_all_close(&direct_forecast.yhat.point, &wrapped_forecast.point);
    }
}