augurs_mstl/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use stlrs::MstlResult;
4use tracing::instrument;
5
6use augurs_core::{Forecast, ForecastIntervals, ModelError, Predict};
7
8mod trend;
9
10pub use crate::trend::{FittedTrendModel, NaiveTrend, TrendModel};
11pub use stlrs;
12
13/// Errors that can occur when using this crate.
14#[derive(Debug, thiserror::Error)]
15pub enum Error {
16    /// An error occurred while running the MSTL algorithm.
17    #[error("fitting MSTL: {0}")]
18    MSTL(String),
19    /// An error occurred while running the STL algorithm.
20    #[error("running STL: {0}")]
21    STL(#[from] stlrs::Error),
22    /// An error occurred while fitting or predicting using the trend model.
23    #[error("trend model error: {0}")]
24    TrendModel(Box<dyn std::error::Error + Send + Sync + 'static>),
25}
26
27type Result<T> = std::result::Result<T, Error>;
28
29/// A model that uses the [MSTL] to decompose a time series into trend,
30/// seasonal and remainder components, and then uses a trend model to
31/// forecast the trend component.
32///
33/// [MSTL]: https://arxiv.org/abs/2107.13462
34#[derive(Debug)]
35pub struct MSTLModel<T> {
36    /// Periodicity of the seasonal components.
37    periods: Vec<usize>,
38    mstl_params: stlrs::MstlParams,
39
40    trend_model: T,
41
42    impute: bool,
43}
44
45impl MSTLModel<NaiveTrend> {
46    /// Create a new MSTL model with a naive trend model.
47    ///
48    /// The naive trend model predicts the last value in the training set
49    /// and so is unlikely to be useful for real applications, but it can
50    /// be useful for testing, benchmarking and pedagogy.
51    pub fn naive(periods: Vec<usize>) -> Self {
52        Self::new(periods, NaiveTrend::new())
53    }
54}
55
56impl<T: TrendModel> MSTLModel<T> {
57    /// Return a reference to the trend model.
58    pub fn trend_model(&self) -> &T {
59        &self.trend_model
60    }
61}
62
63impl<T: TrendModel> MSTLModel<T> {
64    /// Create a new MSTL model with the given trend model.
65    pub fn new(periods: Vec<usize>, trend_model: T) -> Self {
66        Self {
67            periods,
68            mstl_params: stlrs::MstlParams::new(),
69            trend_model,
70            impute: false,
71        }
72    }
73
74    /// Set whether to impute missing values in the time series.
75    ///
76    /// If `true`, then missing values will be imputed using
77    /// linear interpolation before fitting the model.
78    pub fn impute(mut self, impute: bool) -> Self {
79        self.impute = impute;
80        self
81    }
82
83    /// Set the parameters for the MSTL algorithm.
84    ///
85    /// This can be used to control the parameters for the inner STL algorithm
86    /// by using [`stlrs::MstlParams`].
87    pub fn mstl_params(mut self, params: stlrs::MstlParams) -> Self {
88        self.mstl_params = params;
89        self
90    }
91
92    /// Fit the model to the given time series.
93    ///
94    /// # Errors
95    ///
96    /// If no periods are specified, or if all periods are greater than
97    /// half the length of the time series, then an error is returned.
98    ///
99    /// Any errors returned by the STL algorithm or trend model
100    /// are also propagated.
101    #[instrument(skip_all)]
102    fn fit_impl(&self, y: &[f64]) -> Result<FittedMSTLModel> {
103        let y: Vec<f32> = y.iter().copied().map(|y| y as f32).collect::<Vec<_>>();
104        let fit = self.mstl_params.fit(&y, &self.periods)?;
105        // Determine the differencing term for the trend component.
106        let trend = fit.trend();
107        let residual = fit.remainder();
108        let deseasonalised = trend
109            .iter()
110            .zip(residual)
111            .map(|(t, r)| (t + r) as f64)
112            .collect::<Vec<_>>();
113        let fitted_trend_model = self
114            .trend_model
115            .fit(&deseasonalised)
116            .map_err(Error::TrendModel)?;
117        tracing::trace!(
118            trend_model = ?self.trend_model,
119            "found best trend model",
120        );
121        Ok(FittedMSTLModel {
122            periods: self.periods.clone(),
123            fit,
124            fitted_trend_model,
125        })
126    }
127}
128
129/// A model that uses the [MSTL] to decompose a time series into trend,
130/// seasonal and remainder components, and then uses a trend model to
131/// forecast the trend component.
132///
133/// [MSTL]: https://arxiv.org/abs/2107.13462
134#[derive(Debug)]
135pub struct FittedMSTLModel {
136    /// Periodicity of the seasonal components.
137    periods: Vec<usize>,
138    fit: MstlResult,
139    fitted_trend_model: Box<dyn FittedTrendModel + Sync + Send>,
140}
141
142impl FittedMSTLModel {
143    /// Return the MSTL fit of the training data.
144    pub fn fit(&self) -> &MstlResult {
145        &self.fit
146    }
147}
148
149impl FittedMSTLModel {
150    fn predict_impl(
151        &self,
152        horizon: usize,
153        level: Option<f64>,
154        forecast: &mut Forecast,
155    ) -> Result<()> {
156        if horizon == 0 {
157            return Ok(());
158        }
159        self.fitted_trend_model
160            .predict_inplace(horizon, level, forecast)
161            .map_err(Error::TrendModel)?;
162        self.add_seasonal_out_of_sample(forecast);
163        Ok(())
164    }
165
166    fn predict_in_sample_impl(&self, level: Option<f64>, forecast: &mut Forecast) -> Result<()> {
167        self.fitted_trend_model
168            .predict_in_sample_inplace(level, forecast)
169            .map_err(Error::TrendModel)?;
170        self.add_seasonal_in_sample(forecast);
171        Ok(())
172    }
173
174    fn add_seasonal_in_sample(&self, trend: &mut Forecast) {
175        self.fit().seasonal().iter().for_each(|component| {
176            let period_contributions = component.iter().zip(trend.point.iter_mut());
177            match &mut trend.intervals {
178                None => period_contributions.for_each(|(c, p)| *p += *c as f64),
179                Some(ForecastIntervals {
180                    ref mut lower,
181                    ref mut upper,
182                    ..
183                }) => {
184                    period_contributions
185                        .zip(lower.iter_mut())
186                        .zip(upper.iter_mut())
187                        .for_each(|(((c, p), l), u)| {
188                            *p += *c as f64;
189                            *l += *c as f64;
190                            *u += *c as f64;
191                        });
192                }
193            }
194        });
195    }
196
197    fn add_seasonal_out_of_sample(&self, trend: &mut Forecast) {
198        self.periods
199            .iter()
200            .zip(self.fit().seasonal())
201            .for_each(|(period, component)| {
202                // For each seasonal period we're going to create a cycle iterator
203                // which will repeat the seasonal component every `period` steps.
204                // We'll zip it up with the trend point estimates and add the
205                // contribution of the seasonal component to the trend.
206                // If there are intervals, we'll also add the contribution to those.
207                let period_contributions = component
208                    .iter()
209                    .copied()
210                    .skip(component.len() - period)
211                    .cycle()
212                    .zip(trend.point.iter_mut());
213                match &mut trend.intervals {
214                    None => period_contributions.for_each(|(c, p)| *p += c as f64),
215                    Some(ForecastIntervals {
216                        ref mut lower,
217                        ref mut upper,
218                        ..
219                    }) => {
220                        period_contributions
221                            .zip(lower.iter_mut())
222                            .zip(upper.iter_mut())
223                            .for_each(|(((c, p), l), u)| {
224                                *p += c as f64;
225                                *l += c as f64;
226                                *u += c as f64;
227                            });
228                    }
229                }
230            });
231    }
232}
233
234impl ModelError for Error {}
235
236impl<T: TrendModel> augurs_core::Fit for MSTLModel<T> {
237    type Fitted = FittedMSTLModel;
238    type Error = Error;
239    fn fit(&self, y: &[f64]) -> Result<Self::Fitted> {
240        self.fit_impl(y)
241    }
242}
243
244impl Predict for FittedMSTLModel {
245    type Error = Error;
246
247    fn predict_inplace(
248        &self,
249        horizon: usize,
250        level: Option<f64>,
251        forecast: &mut Forecast,
252    ) -> Result<()> {
253        self.predict_impl(horizon, level, forecast)
254    }
255
256    fn predict_in_sample_inplace(&self, level: Option<f64>, forecast: &mut Forecast) -> Result<()> {
257        self.predict_in_sample_impl(level, forecast)
258    }
259
260    fn training_data_size(&self) -> usize {
261        self.fit().trend().len()
262    }
263}
264
#[cfg(test)]
mod tests {
    use augurs_core::prelude::*;
    use augurs_testing::{assert_all_close, data::VIC_ELEC};

    use crate::{trend::NaiveTrend, FittedMSTLModel, ForecastIntervals, MSTLModel};

    /// Fit an MSTL model with a naive trend and daily + weekly periods to
    /// the `VIC_ELEC` data, using the explicit STL settings shared by all
    /// tests in this module.
    fn fit_vic_elec() -> FittedMSTLModel {
        let y = VIC_ELEC.clone();

        let mut stl_params = stlrs::params();
        stl_params
            .seasonal_degree(0)
            .seasonal_jump(1)
            .trend_degree(1)
            .trend_jump(1)
            .low_pass_degree(1)
            .inner_loops(2)
            .outer_loops(0);
        let mut mstl_params = stlrs::MstlParams::new();
        mstl_params.stl_params(stl_params);

        let periods = vec![24, 24 * 7];
        MSTLModel::new(periods, NaiveTrend::new())
            .mstl_params(mstl_params)
            .fit(&y)
            .unwrap()
    }

    #[test]
    fn results_match_r() {
        let fit = fit_vic_elec();

        // In-sample predictions: compare the first 12 values against R.
        let expected_in_sample = vec![
            f64::NAN,
            7952.216,
            7269.439,
            6878.110,
            6606.999,
            6402.581,
            6659.523,
            7457.488,
            8111.359,
            8693.762,
            9255.807,
            9870.213,
        ];
        let in_sample = fit.predict_in_sample(0.95).unwrap();
        assert_eq!(in_sample.point.len(), VIC_ELEC.len());
        assert_all_close(&in_sample.point[..12], &expected_in_sample);

        // Out-of-sample predictions: point forecasts and 95% intervals.
        let expected_out_of_sample: Vec<f64> = vec![
            8920.670, 8874.234, 8215.508, 7782.726, 7697.259, 8216.241, 9664.907, 10914.452,
            11536.929, 11664.737,
        ];
        let expected_out_of_sample_lower = vec![
            8700.984, 8563.551, 7835.001, 7343.354, 7206.026, 7678.122, 9083.672, 10293.087,
            10877.871, 10970.029,
        ];
        let expected_out_of_sample_upper = vec![
            9140.356, 9184.917, 8596.016, 8222.098, 8188.491, 8754.359, 10246.141, 11535.818,
            12195.987, 12359.445,
        ];
        let out_of_sample = fit.predict(10, 0.95).unwrap();
        assert_eq!(out_of_sample.point.len(), 10);
        assert_all_close(&out_of_sample.point, &expected_out_of_sample);
        let ForecastIntervals { lower, upper, .. } = out_of_sample.intervals.unwrap();
        assert_eq!(lower.len(), 10);
        assert_eq!(upper.len(), 10);
        assert_all_close(&lower, &expected_out_of_sample_lower);
        assert_all_close(&upper, &expected_out_of_sample_upper);
    }

    #[test]
    fn predict_zero_horizon() {
        // A zero-step forecast must succeed and come back empty, with
        // (empty) intervals still present.
        let forecast = fit_vic_elec().predict(0, 0.95).unwrap();
        assert!(forecast.point.is_empty());
        let ForecastIntervals { lower, upper, .. } = forecast.intervals.unwrap();
        assert!(lower.is_empty());
        assert!(upper.is_empty());
    }
}