// augurs_forecaster/forecaster.rs
use augurs_core::{Fit, Forecast, Predict};

use crate::{Data, Error, Pipeline, Result, Transformer};

5#[derive(Debug)]
12pub struct Forecaster<M: Fit> {
13 model: M,
14 fitted: Option<M::Fitted>,
15
16 pipeline: Pipeline,
17}
18
19impl<M> Forecaster<M>
20where
21 M: Fit,
22 M::Fitted: Predict,
23{
24 pub fn new(model: M) -> Self {
26 Self {
27 model,
28 fitted: None,
29 pipeline: Pipeline::default(),
30 }
31 }
32
33 pub fn with_transformers(mut self, transformers: Vec<Box<dyn Transformer>>) -> Self {
35 self.pipeline = Pipeline::new(transformers);
36 self
37 }
38
39 pub fn fit<D: Data + Clone>(&mut self, y: D) -> Result<()> {
41 let mut y = y.as_slice().to_vec();
42 self.pipeline.fit_transform(&mut y)?;
43 self.fitted = Some(self.model.fit(&y).map_err(|e| Error::Fit {
44 source: Box::new(e) as _,
45 })?);
46 Ok(())
47 }
48
49 fn fitted(&self) -> Result<&M::Fitted> {
50 self.fitted.as_ref().ok_or(Error::ModelNotYetFit)
51 }
52
53 pub fn predict(&self, horizon: usize, level: impl Into<Option<f64>>) -> Result<Forecast> {
56 let mut untransformed =
57 self.fitted()?
58 .predict(horizon, level.into())
59 .map_err(|e| Error::Predict {
60 source: Box::new(e) as _,
61 })?;
62 self.pipeline
63 .inverse_transform_forecast(&mut untransformed)?;
64 Ok(untransformed)
65 }
66
67 pub fn predict_in_sample(&self, level: impl Into<Option<f64>>) -> Result<Forecast> {
70 let mut untransformed = self
71 .fitted()?
72 .predict_in_sample(level.into())
73 .map_err(|e| Error::Predict {
74 source: Box::new(e) as _,
75 })?;
76 self.pipeline
77 .inverse_transform_forecast(&mut untransformed)?;
78 Ok(untransformed)
79 }
80}
81
#[cfg(test)]
mod test {
    use augurs::mstl::{MSTLModel, NaiveTrend};
    use augurs_testing::assert_all_close;

    use crate::transforms::{BoxCox, LinearInterpolator, Logit, MinMaxScaler, YeoJohnson};

    use super::*;

    /// Number of steps forecast by every test in this module.
    const HORIZON: usize = 4;

    /// Builds an `MSTLModel` with `vec![2]` (presumably one seasonal period of 2)
    /// and a `NaiveTrend`, wraps it in a `Forecaster` using `transformers`, fits it
    /// to `data` and returns its `HORIZON`-step forecast with `level` set to `None`.
    fn fit_and_forecast<D: Data + Clone>(
        data: D,
        transformers: Vec<Box<dyn Transformer>>,
    ) -> Forecast {
        let mstl = MSTLModel::new(vec![2], NaiveTrend::new());
        let mut forecaster = Forecaster::new(mstl).with_transformers(transformers);
        forecaster.fit(data).unwrap();
        forecaster.predict(HORIZON, None).unwrap()
    }

    #[test]
    fn test_forecaster() {
        // Interpolate, min-max scale, then logit-transform before fitting.
        let forecast = fit_and_forecast(
            &[1.0_f64, 2.0, 3.0, 4.0, 5.0],
            vec![
                LinearInterpolator::new().boxed(),
                MinMaxScaler::new().boxed(),
                Logit::new().boxed(),
            ],
        );
        assert_all_close(&forecast.point, &[5.0; 4]);
    }

    #[test]
    fn test_forecaster_power_positive() {
        // Box-Cox on strictly positive input.
        let forecast = fit_and_forecast(
            &[1.0_f64, 2.0, 3.0, 4.0, 5.0],
            vec![BoxCox::new().boxed()],
        );
        let expected = [
            5.084499064884572,
            5.000000030329821,
            5.084499064884572,
            5.000000030329821,
        ];
        assert_all_close(&forecast.point, &expected);
    }

    #[test]
    fn test_forecaster_power_non_positive() {
        // Yeo-Johnson on input that includes a zero.
        let forecast = fit_and_forecast(
            &[0.0, 2.0, 3.0, 4.0, 5.0],
            vec![YeoJohnson::new().boxed()],
        );
        let expected = [
            5.205557727170964,
            5.000000132803496,
            5.205557727170964,
            5.000000132803496,
        ];
        assert_all_close(&forecast.point, &expected);
    }
}