augurs_forecaster/
forecaster.rs

1use augurs_core::{Fit, Forecast, Predict};
2
3use crate::{Data, Error, Pipeline, Result, Transformer};
4
5/// A high-level API to fit and predict time series forecasting models.
6///
7/// The `Forecaster` type allows you to combine a model with a set of
8/// transformations and fit it to a time series, then use the fitted model to
9/// make predictions. The predictions are back-transformed using the inverse of
10/// the transformations applied to the input data.
11#[derive(Debug)]
12pub struct Forecaster<M: Fit> {
13    model: M,
14    fitted: Option<M::Fitted>,
15
16    pipeline: Pipeline,
17}
18
19impl<M> Forecaster<M>
20where
21    M: Fit,
22    M::Fitted: Predict,
23{
24    /// Create a new `Forecaster` with the given model.
25    pub fn new(model: M) -> Self {
26        Self {
27            model,
28            fitted: None,
29            pipeline: Pipeline::default(),
30        }
31    }
32
33    /// Set the transformations to be applied to the input data.
34    pub fn with_transformers(mut self, transformers: Vec<Box<dyn Transformer>>) -> Self {
35        self.pipeline = Pipeline::new(transformers);
36        self
37    }
38
39    /// Fit the model to the given time series.
40    pub fn fit<D: Data + Clone>(&mut self, y: D) -> Result<()> {
41        let mut y = y.as_slice().to_vec();
42        self.pipeline.fit_transform(&mut y)?;
43        self.fitted = Some(self.model.fit(&y).map_err(|e| Error::Fit {
44            source: Box::new(e) as _,
45        })?);
46        Ok(())
47    }
48
49    fn fitted(&self) -> Result<&M::Fitted> {
50        self.fitted.as_ref().ok_or(Error::ModelNotYetFit)
51    }
52
53    /// Predict the next `horizon` values, optionally including prediction
54    /// intervals at the given level.
55    pub fn predict(&self, horizon: usize, level: impl Into<Option<f64>>) -> Result<Forecast> {
56        let mut untransformed =
57            self.fitted()?
58                .predict(horizon, level.into())
59                .map_err(|e| Error::Predict {
60                    source: Box::new(e) as _,
61                })?;
62        self.pipeline
63            .inverse_transform_forecast(&mut untransformed)?;
64        Ok(untransformed)
65    }
66
67    /// Produce in-sample forecasts, optionally including prediction intervals
68    /// at the given level.
69    pub fn predict_in_sample(&self, level: impl Into<Option<f64>>) -> Result<Forecast> {
70        let mut untransformed = self
71            .fitted()?
72            .predict_in_sample(level.into())
73            .map_err(|e| Error::Predict {
74                source: Box::new(e) as _,
75            })?;
76        self.pipeline
77            .inverse_transform_forecast(&mut untransformed)?;
78        Ok(untransformed)
79    }
80}
81
#[cfg(test)]
mod test {

    use augurs::mstl::{MSTLModel, NaiveTrend};
    use augurs_testing::assert_all_close;

    use crate::transforms::{BoxCox, LinearInterpolator, Logit, MinMaxScaler, YeoJohnson};

    use super::*;

    /// Fit an MSTL model (period 2, naive trend) to `data` through the given
    /// transformers, and return the point forecasts for the next 4 steps.
    fn fit_and_predict<D: Data + Clone>(
        data: D,
        transformers: Vec<Box<dyn Transformer>>,
    ) -> Vec<f64> {
        let model = MSTLModel::new(vec![2], NaiveTrend::new());
        let mut forecaster = Forecaster::new(model).with_transformers(transformers);
        forecaster.fit(data).unwrap();
        forecaster.predict(4, None).unwrap().point
    }

    #[test]
    fn test_predict_before_fit() {
        let model = MSTLModel::new(vec![2], NaiveTrend::new());
        let forecaster = Forecaster::new(model);
        assert!(matches!(
            forecaster.predict(4, None),
            Err(Error::ModelNotYetFit)
        ));
        assert!(matches!(
            forecaster.predict_in_sample(None),
            Err(Error::ModelNotYetFit)
        ));
    }

    #[test]
    fn test_forecaster() {
        let data = &[1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let transformers = vec![
            LinearInterpolator::new().boxed(),
            MinMaxScaler::new().boxed(),
            Logit::new().boxed(),
        ];
        let point = fit_and_predict(data, transformers);
        assert_all_close(&point, &[5.0, 5.0, 5.0, 5.0]);
    }

    #[test]
    fn test_forecaster_power_positive() {
        let data = &[1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let transformers = vec![BoxCox::new().boxed()];
        let point = fit_and_predict(data, transformers);
        assert_all_close(
            &point,
            &[
                5.084499064884572,
                5.000000030329821,
                5.084499064884572,
                5.000000030329821,
            ],
        );
    }

    #[test]
    fn test_forecaster_power_non_positive() {
        let data = &[0.0_f64, 2.0, 3.0, 4.0, 5.0];
        let transformers = vec![YeoJohnson::new().boxed()];
        let point = fit_and_predict(data, transformers);
        assert_all_close(
            &point,
            &[
                5.205557727170964,
                5.000000132803496,
                5.205557727170964,
                5.000000132803496,
            ],
        );
    }
}