1#![doc = include_str!("../README.md")]
2
3use stlrs::MstlResult;
4use tracing::instrument;
5
6use augurs_core::{Forecast, ForecastIntervals, ModelError, Predict};
7
8mod trend;
9
10pub use crate::trend::{FittedTrendModel, NaiveTrend, TrendModel};
11pub use stlrs;
12
13#[derive(Debug, thiserror::Error)]
15pub enum Error {
16 #[error("fitting MSTL: {0}")]
18 MSTL(String),
19 #[error("running STL: {0}")]
21 STL(#[from] stlrs::Error),
22 #[error("trend model error: {0}")]
24 TrendModel(Box<dyn std::error::Error + Send + Sync + 'static>),
25}
26
27type Result<T> = std::result::Result<T, Error>;
28
/// A forecasting model that decomposes a series with MSTL and forecasts the
/// deseasonalised series (MSTL trend + remainder) with a separate trend
/// model of type `T`. Seasonal components are added back onto the trend
/// model's predictions.
#[derive(Debug)]
pub struct MSTLModel<T> {
    /// Seasonal periods passed to `stlrs` when fitting.
    periods: Vec<usize>,
    /// Parameters for the underlying `stlrs` MSTL decomposition.
    mstl_params: stlrs::MstlParams,

    /// Unfitted trend model; it is fitted to the deseasonalised series.
    trend_model: T,

    /// Flag set via the `impute` builder method.
    // NOTE(review): `fit_impl` never reads this flag, so it currently has no
    // effect on fitting — confirm whether imputation is meant to be applied.
    impute: bool,
}
44
45impl MSTLModel<NaiveTrend> {
46 pub fn naive(periods: Vec<usize>) -> Self {
52 Self::new(periods, NaiveTrend::new())
53 }
54}
55
56impl<T: TrendModel> MSTLModel<T> {
57 pub fn trend_model(&self) -> &T {
59 &self.trend_model
60 }
61}
62
63impl<T: TrendModel> MSTLModel<T> {
64 pub fn new(periods: Vec<usize>, trend_model: T) -> Self {
66 Self {
67 periods,
68 mstl_params: stlrs::MstlParams::new(),
69 trend_model,
70 impute: false,
71 }
72 }
73
74 pub fn impute(mut self, impute: bool) -> Self {
79 self.impute = impute;
80 self
81 }
82
83 pub fn mstl_params(mut self, params: stlrs::MstlParams) -> Self {
88 self.mstl_params = params;
89 self
90 }
91
92 #[instrument(skip_all)]
102 fn fit_impl(&self, y: &[f64]) -> Result<FittedMSTLModel> {
103 let y: Vec<f32> = y.iter().copied().map(|y| y as f32).collect::<Vec<_>>();
104 let fit = self.mstl_params.fit(&y, &self.periods)?;
105 let trend = fit.trend();
107 let residual = fit.remainder();
108 let deseasonalised = trend
109 .iter()
110 .zip(residual)
111 .map(|(t, r)| (t + r) as f64)
112 .collect::<Vec<_>>();
113 let fitted_trend_model = self
114 .trend_model
115 .fit(&deseasonalised)
116 .map_err(Error::TrendModel)?;
117 tracing::trace!(
118 trend_model = ?self.trend_model,
119 "found best trend model",
120 );
121 Ok(FittedMSTLModel {
122 periods: self.periods.clone(),
123 fit,
124 fitted_trend_model,
125 })
126 }
127}
128
/// An [`MSTLModel`] that has been fitted to a time series and can produce
/// in-sample and out-of-sample forecasts.
#[derive(Debug)]
pub struct FittedMSTLModel {
    /// Seasonal periods used in the fit. Forecasting zips these with the
    /// seasonal components of `fit`, so they are expected to be in the same
    /// order.
    periods: Vec<usize>,
    /// The MSTL decomposition of the training data.
    fit: MstlResult,
    /// Trend model fitted to the deseasonalised training data.
    fitted_trend_model: Box<dyn FittedTrendModel + Sync + Send>,
}
141
142impl FittedMSTLModel {
143 pub fn fit(&self) -> &MstlResult {
145 &self.fit
146 }
147}
148
149impl FittedMSTLModel {
150 fn predict_impl(
151 &self,
152 horizon: usize,
153 level: Option<f64>,
154 forecast: &mut Forecast,
155 ) -> Result<()> {
156 if horizon == 0 {
157 return Ok(());
158 }
159 self.fitted_trend_model
160 .predict_inplace(horizon, level, forecast)
161 .map_err(Error::TrendModel)?;
162 self.add_seasonal_out_of_sample(forecast);
163 Ok(())
164 }
165
166 fn predict_in_sample_impl(&self, level: Option<f64>, forecast: &mut Forecast) -> Result<()> {
167 self.fitted_trend_model
168 .predict_in_sample_inplace(level, forecast)
169 .map_err(Error::TrendModel)?;
170 self.add_seasonal_in_sample(forecast);
171 Ok(())
172 }
173
174 fn add_seasonal_in_sample(&self, trend: &mut Forecast) {
175 self.fit().seasonal().iter().for_each(|component| {
176 let period_contributions = component.iter().zip(trend.point.iter_mut());
177 match &mut trend.intervals {
178 None => period_contributions.for_each(|(c, p)| *p += *c as f64),
179 Some(ForecastIntervals {
180 ref mut lower,
181 ref mut upper,
182 ..
183 }) => {
184 period_contributions
185 .zip(lower.iter_mut())
186 .zip(upper.iter_mut())
187 .for_each(|(((c, p), l), u)| {
188 *p += *c as f64;
189 *l += *c as f64;
190 *u += *c as f64;
191 });
192 }
193 }
194 });
195 }
196
197 fn add_seasonal_out_of_sample(&self, trend: &mut Forecast) {
198 self.periods
199 .iter()
200 .zip(self.fit().seasonal())
201 .for_each(|(period, component)| {
202 let period_contributions = component
208 .iter()
209 .copied()
210 .skip(component.len() - period)
211 .cycle()
212 .zip(trend.point.iter_mut());
213 match &mut trend.intervals {
214 None => period_contributions.for_each(|(c, p)| *p += c as f64),
215 Some(ForecastIntervals {
216 ref mut lower,
217 ref mut upper,
218 ..
219 }) => {
220 period_contributions
221 .zip(lower.iter_mut())
222 .zip(upper.iter_mut())
223 .for_each(|(((c, p), l), u)| {
224 *p += c as f64;
225 *l += c as f64;
226 *u += c as f64;
227 });
228 }
229 }
230 });
231 }
232}
233
// Opt `Error` into `augurs_core::ModelError`. `Error` is the `Error`
// associated type of the `Fit` and `Predict` impls in this file; presumably
// those traits require this bound — confirm against augurs_core.
impl ModelError for Error {}
235
236impl<T: TrendModel> augurs_core::Fit for MSTLModel<T> {
237 type Fitted = FittedMSTLModel;
238 type Error = Error;
239 fn fit(&self, y: &[f64]) -> Result<Self::Fitted> {
240 self.fit_impl(y)
241 }
242}
243
244impl Predict for FittedMSTLModel {
245 type Error = Error;
246
247 fn predict_inplace(
248 &self,
249 horizon: usize,
250 level: Option<f64>,
251 forecast: &mut Forecast,
252 ) -> Result<()> {
253 self.predict_impl(horizon, level, forecast)
254 }
255
256 fn predict_in_sample_inplace(&self, level: Option<f64>, forecast: &mut Forecast) -> Result<()> {
257 self.predict_in_sample_impl(level, forecast)
258 }
259
260 fn training_data_size(&self) -> usize {
261 self.fit().trend().len()
262 }
263}
264
#[cfg(test)]
mod tests {
    use augurs_core::prelude::*;
    use augurs_testing::{assert_all_close, data::VIC_ELEC};

    use crate::{trend::NaiveTrend, FittedMSTLModel, ForecastIntervals, MSTLModel};

    /// Fit an MSTL model with a naive trend and periods of 24 and 24 * 7 to
    /// `y`, using the STL settings shared by every test in this module
    /// (the ones the R comparison in `results_match_r` relies on).
    fn fit_hourly_model(y: &[f64]) -> FittedMSTLModel {
        let mut stl_params = stlrs::params();
        stl_params
            .seasonal_degree(0)
            .seasonal_jump(1)
            .trend_degree(1)
            .trend_jump(1)
            .low_pass_degree(1)
            .inner_loops(2)
            .outer_loops(0);
        let mut mstl_params = stlrs::MstlParams::new();
        mstl_params.stl_params(stl_params);
        let periods = vec![24, 24 * 7];
        let trend_model = NaiveTrend::new();
        MSTLModel::new(periods, trend_model)
            .mstl_params(mstl_params)
            .fit(y)
            .unwrap()
    }

    #[test]
    fn results_match_r() {
        let y = VIC_ELEC.clone();
        let fit = fit_hourly_model(&y);

        let in_sample = fit.predict_in_sample(0.95).unwrap();
        let expected_in_sample = vec![
            f64::NAN,
            7952.216,
            7269.439,
            6878.110,
            6606.999,
            6402.581,
            6659.523,
            7457.488,
            8111.359,
            8693.762,
            9255.807,
            9870.213,
        ];
        assert_eq!(in_sample.point.len(), y.len());
        assert_all_close(&in_sample.point[..12], &expected_in_sample);

        let out_of_sample = fit.predict(10, 0.95).unwrap();
        let expected_out_of_sample: Vec<f64> = vec![
            8920.670, 8874.234, 8215.508, 7782.726, 7697.259, 8216.241, 9664.907, 10914.452,
            11536.929, 11664.737,
        ];
        let expected_out_of_sample_lower = vec![
            8700.984, 8563.551, 7835.001, 7343.354, 7206.026, 7678.122, 9083.672, 10293.087,
            10877.871, 10970.029,
        ];
        let expected_out_of_sample_upper = vec![
            9140.356, 9184.917, 8596.016, 8222.098, 8188.491, 8754.359, 10246.141, 11535.818,
            12195.987, 12359.445,
        ];
        assert_eq!(out_of_sample.point.len(), 10);
        assert_all_close(&out_of_sample.point, &expected_out_of_sample);
        let ForecastIntervals { lower, upper, .. } = out_of_sample.intervals.unwrap();
        assert_eq!(lower.len(), 10);
        assert_eq!(upper.len(), 10);
        assert_all_close(&lower, &expected_out_of_sample_lower);
        assert_all_close(&upper, &expected_out_of_sample_upper);
    }

    #[test]
    fn predict_zero_horizon() {
        let y = VIC_ELEC.clone();
        let fit = fit_hourly_model(&y);
        let forecast = fit.predict(0, 0.95).unwrap();
        assert!(forecast.point.is_empty());
        let ForecastIntervals { lower, upper, .. } = forecast.intervals.unwrap();
        assert!(lower.is_empty());
        assert!(upper.is_empty());
    }
}