augurs_mstl/
trend.rs

1//! Trend models.
2//!
3//! Contains the [`TrendModel`] trait and an implementation of a basic
4//! naive trend model.
5// TODO: decide where this should live. Perhaps it's more general than just MSTL?
6
7use std::{
8    borrow::Cow,
9    fmt::{self, Debug},
10};
11
12use crate::{Forecast, ForecastIntervals};
13
14/// A trend model.
15///
16/// Trend models are used to model the trend component of a time series.
17/// Examples implemented in other languages include ARIMA, Theta and ETS.
18///
19/// You can implement this trait for your own trend models.
20pub trait TrendModel: Debug {
21    /// Return the name of the trend model.
22    fn name(&self) -> Cow<'_, str>;
23
24    /// Fit the model to the given time series.
25    ///
26    /// This method is called once before any calls to `predict` or `predict_in_sample`.
27    ///
28    /// Implementations should store any state required for prediction in the struct itself.
29    fn fit(
30        &self,
31        y: &[f64],
32    ) -> Result<
33        Box<dyn FittedTrendModel + Sync + Send>,
34        Box<dyn std::error::Error + Send + Sync + 'static>,
35    >;
36}
37
38/// A fitted trend model.
39pub trait FittedTrendModel: Debug {
40    /// Produce a forecast for the next `horizon` time points.
41    ///
42    /// The `level` parameter specifies the confidence level for the prediction intervals.
43    /// Where possible, implementations should provide prediction intervals
44    /// alongside the point forecasts if `level` is not `None`.
45    fn predict_inplace(
46        &self,
47        horizon: usize,
48        level: Option<f64>,
49        forecast: &mut Forecast,
50    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>>;
51
52    /// Produce in-sample predictions.
53    ///
54    /// In-sample predictions are used to assess the fit of the model to the training data.
55    ///
56    /// The `level` parameter specifies the confidence level for the prediction intervals.
57    /// Where possible, implementations should provide prediction intervals
58    /// alongside the point forecasts if `level` is not `None`.
59    fn predict_in_sample_inplace(
60        &self,
61        level: Option<f64>,
62        forecast: &mut Forecast,
63    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>>;
64
65    /// Return the n-ahead predictions for the given horizon.
66    ///
67    /// The predictions are point forecasts and optionally include
68    /// prediction intervals at the specified `level`.
69    ///
70    /// `level` should be a float between 0 and 1 representing the
71    /// confidence level of the prediction intervals. If `None` then
72    /// no prediction intervals are returned.
73    ///
74    /// # Errors
75    ///
76    /// Any errors returned by the trend model are propagated.
77    fn predict(
78        &self,
79        horizon: usize,
80        level: Option<f64>,
81    ) -> Result<Forecast, Box<dyn std::error::Error + Send + Sync + 'static>> {
82        let mut forecast = level
83            .map(|l| Forecast::with_capacity_and_level(horizon, l))
84            .unwrap_or_else(|| Forecast::with_capacity(horizon));
85        self.predict_inplace(horizon, level, &mut forecast)?;
86        Ok(forecast)
87    }
88
89    /// Return the in-sample predictions.
90    ///
91    /// The predictions are point forecasts and optionally include
92    /// prediction intervals at the specified `level`.
93    ///
94    /// `level` should be a float between 0 and 1 representing the
95    /// confidence level of the prediction intervals. If `None` then
96    /// no prediction intervals are returned.
97    ///
98    /// # Errors
99    ///
100    /// Any errors returned by the trend model are propagated.
101    fn predict_in_sample(
102        &self,
103        level: Option<f64>,
104    ) -> Result<Forecast, Box<dyn std::error::Error + Send + Sync + 'static>> {
105        let mut forecast = level
106            .zip(self.training_data_size())
107            .map(|(l, c)| Forecast::with_capacity_and_level(c, l))
108            .unwrap_or_else(|| Forecast::with_capacity(0));
109        self.predict_in_sample_inplace(level, &mut forecast)?;
110        Ok(forecast)
111    }
112
113    /// Return the number of training data points used to fit the model.
114    fn training_data_size(&self) -> Option<usize>;
115}
116
117impl<T: TrendModel + ?Sized> TrendModel for Box<T> {
118    fn name(&self) -> Cow<'_, str> {
119        (**self).name()
120    }
121
122    fn fit(
123        &self,
124        y: &[f64],
125    ) -> Result<
126        Box<dyn FittedTrendModel + Sync + Send>,
127        Box<dyn std::error::Error + Send + Sync + 'static>,
128    > {
129        (**self).fit(y)
130    }
131}
132
133impl<T: FittedTrendModel + ?Sized> FittedTrendModel for Box<T> {
134    fn predict_inplace(
135        &self,
136        horizon: usize,
137        level: Option<f64>,
138        forecast: &mut Forecast,
139    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
140        (**self).predict_inplace(horizon, level, forecast)
141    }
142
143    fn predict_in_sample_inplace(
144        &self,
145        level: Option<f64>,
146        forecast: &mut Forecast,
147    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
148        (**self).predict_in_sample_inplace(level, forecast)
149    }
150
151    fn training_data_size(&self) -> Option<usize> {
152        (**self).training_data_size()
153    }
154}
155
/// A naive trend model: every out-of-sample prediction repeats the final
/// observation of the training series.
///
/// Fitting is done via [`TrendModel::fit`], which takes `&self` and returns a
/// separate fitted model, so the optional fields below are not filled in by
/// fitting.
#[derive(Clone, Default)]
pub struct NaiveTrend {
    /// Residual variance, if known.
    sigma_squared: Option<f64>,
    /// Final training observation, if known.
    last_value: Option<f64>,
    /// In-sample fitted values, if known.
    fitted: Option<Vec<f64>>,
}
164
165impl fmt::Debug for NaiveTrend {
166    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167        f.debug_struct("NaiveTrend")
168            .field(
169                "y",
170                &self
171                    .fitted
172                    .as_ref()
173                    .map(|y| format!("<omitted vec, length {}>", y.len())),
174            )
175            .field("last_value", &self.last_value)
176            .field("sigma", &self.sigma_squared)
177            .finish()
178    }
179}
180
181impl NaiveTrend {
182    /// Create a new naive trend model.
183    pub const fn new() -> Self {
184        Self {
185            fitted: None,
186            last_value: None,
187            sigma_squared: None,
188        }
189    }
190}
191
192impl TrendModel for NaiveTrend {
193    fn name(&self) -> Cow<'_, str> {
194        Cow::Borrowed("Naive")
195    }
196
197    fn fit(
198        &self,
199        y: &[f64],
200    ) -> Result<
201        Box<dyn FittedTrendModel + Sync + Send>,
202        Box<dyn std::error::Error + Send + Sync + 'static>,
203    > {
204        let last_value = y[y.len() - 1];
205        let fitted: Vec<f64> = std::iter::once(f64::NAN)
206            .chain(y.iter().copied())
207            .take(y.len())
208            .collect();
209        let sigma_squared = y
210            .iter()
211            .zip(&fitted)
212            .filter_map(|(y, f)| {
213                if f.is_nan() {
214                    None
215                } else {
216                    Some((y - f).powi(2))
217                }
218            })
219            .sum::<f64>()
220            / (y.len() - 1) as f64;
221
222        Ok(Box::new(NaiveTrendFitted {
223            last_value,
224            fitted,
225            sigma_squared,
226        }))
227    }
228}
229
/// State produced by fitting a naive trend model.
#[derive(Debug, Clone)]
struct NaiveTrendFitted {
    /// Final observation of the training series; every out-of-sample point
    /// forecast equals this value.
    last_value: f64,
    /// Residual variance of the one-step in-sample fit, used to size
    /// prediction intervals.
    sigma_squared: f64,
    /// In-sample fitted values: each is the preceding observation, with
    /// `NaN` in the first position.
    fitted: Vec<f64>,
}
236
237impl NaiveTrendFitted {
238    fn prediction_intervals(
239        &self,
240        preds: impl Iterator<Item = f64>,
241        level: f64,
242        sigma: impl Iterator<Item = f64>,
243        intervals: &mut ForecastIntervals,
244    ) {
245        intervals.level = level;
246        let z = distrs::Normal::ppf(0.5 + level / 2.0, 0.0, 1.0);
247        (intervals.lower, intervals.upper) = preds
248            .zip(sigma)
249            .map(|(p, s)| (p - z * s, p + z * s))
250            .unzip();
251    }
252}
253
254impl FittedTrendModel for NaiveTrendFitted {
255    fn predict_inplace(
256        &self,
257        horizon: usize,
258        level: Option<f64>,
259        forecast: &mut Forecast,
260    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
261        forecast.point = vec![self.last_value; horizon];
262        if let Some(level) = level {
263            let sigmas = (1..horizon + 1).map(|step| ((step as f64) * self.sigma_squared).sqrt());
264            let intervals = forecast
265                .intervals
266                .get_or_insert_with(|| ForecastIntervals::with_capacity(level, horizon));
267            self.prediction_intervals(std::iter::repeat(self.last_value), level, sigmas, intervals);
268        }
269        Ok(())
270    }
271
272    fn predict_in_sample_inplace(
273        &self,
274        level: Option<f64>,
275        forecast: &mut Forecast,
276    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
277        forecast.point.clone_from(&self.fitted);
278        if let Some(level) = level {
279            let intervals = forecast
280                .intervals
281                .get_or_insert_with(|| ForecastIntervals::with_capacity(level, self.fitted.len()));
282            self.prediction_intervals(
283                self.fitted.iter().copied(),
284                level,
285                std::iter::repeat(self.sigma_squared.sqrt()),
286                intervals,
287            );
288        }
289        Ok(())
290    }
291
292    fn training_data_size(&self) -> Option<usize> {
293        Some(self.fitted.len())
294    }
295}