1use std::{
8 borrow::Cow,
9 fmt::{self, Debug},
10};
11
12use crate::{Forecast, ForecastIntervals};
13
14pub trait TrendModel: Debug {
21 fn name(&self) -> Cow<'_, str>;
23
24 fn fit(
30 &self,
31 y: &[f64],
32 ) -> Result<
33 Box<dyn FittedTrendModel + Sync + Send>,
34 Box<dyn std::error::Error + Send + Sync + 'static>,
35 >;
36}
37
38pub trait FittedTrendModel: Debug {
40 fn predict_inplace(
46 &self,
47 horizon: usize,
48 level: Option<f64>,
49 forecast: &mut Forecast,
50 ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>>;
51
52 fn predict_in_sample_inplace(
60 &self,
61 level: Option<f64>,
62 forecast: &mut Forecast,
63 ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>>;
64
65 fn predict(
78 &self,
79 horizon: usize,
80 level: Option<f64>,
81 ) -> Result<Forecast, Box<dyn std::error::Error + Send + Sync + 'static>> {
82 let mut forecast = level
83 .map(|l| Forecast::with_capacity_and_level(horizon, l))
84 .unwrap_or_else(|| Forecast::with_capacity(horizon));
85 self.predict_inplace(horizon, level, &mut forecast)?;
86 Ok(forecast)
87 }
88
89 fn predict_in_sample(
102 &self,
103 level: Option<f64>,
104 ) -> Result<Forecast, Box<dyn std::error::Error + Send + Sync + 'static>> {
105 let mut forecast = level
106 .zip(self.training_data_size())
107 .map(|(l, c)| Forecast::with_capacity_and_level(c, l))
108 .unwrap_or_else(|| Forecast::with_capacity(0));
109 self.predict_in_sample_inplace(level, &mut forecast)?;
110 Ok(forecast)
111 }
112
113 fn training_data_size(&self) -> Option<usize>;
115}
116
117impl<T: TrendModel + ?Sized> TrendModel for Box<T> {
118 fn name(&self) -> Cow<'_, str> {
119 (**self).name()
120 }
121
122 fn fit(
123 &self,
124 y: &[f64],
125 ) -> Result<
126 Box<dyn FittedTrendModel + Sync + Send>,
127 Box<dyn std::error::Error + Send + Sync + 'static>,
128 > {
129 (**self).fit(y)
130 }
131}
132
133impl<T: FittedTrendModel + ?Sized> FittedTrendModel for Box<T> {
134 fn predict_inplace(
135 &self,
136 horizon: usize,
137 level: Option<f64>,
138 forecast: &mut Forecast,
139 ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
140 (**self).predict_inplace(horizon, level, forecast)
141 }
142
143 fn predict_in_sample_inplace(
144 &self,
145 level: Option<f64>,
146 forecast: &mut Forecast,
147 ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
148 (**self).predict_in_sample_inplace(level, forecast)
149 }
150
151 fn training_data_size(&self) -> Option<usize> {
152 (**self).training_data_size()
153 }
154}
155
/// The naive trend model.
///
/// A default-constructed value has every field set to `None`.
///
/// NOTE(review): in the code visible here nothing ever stores `Some(..)` in
/// these fields, because fitting returns a separate fitted model rather than
/// mutating this one. Confirm whether they are still needed.
#[derive(Default, Clone)]
pub struct NaiveTrend {
    /// Fitted values, if any have been recorded.
    fitted: Option<Vec<f64>>,
    /// Last observed value, if any has been recorded.
    last_value: Option<f64>,
    /// Squared sigma (a variance), if any has been recorded.
    sigma_squared: Option<f64>,
}
164
165impl fmt::Debug for NaiveTrend {
166 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
167 f.debug_struct("NaiveTrend")
168 .field(
169 "y",
170 &self
171 .fitted
172 .as_ref()
173 .map(|y| format!("<omitted vec, length {}>", y.len())),
174 )
175 .field("last_value", &self.last_value)
176 .field("sigma", &self.sigma_squared)
177 .finish()
178 }
179}
180
181impl NaiveTrend {
182 pub const fn new() -> Self {
184 Self {
185 fitted: None,
186 last_value: None,
187 sigma_squared: None,
188 }
189 }
190}
191
192impl TrendModel for NaiveTrend {
193 fn name(&self) -> Cow<'_, str> {
194 Cow::Borrowed("Naive")
195 }
196
197 fn fit(
198 &self,
199 y: &[f64],
200 ) -> Result<
201 Box<dyn FittedTrendModel + Sync + Send>,
202 Box<dyn std::error::Error + Send + Sync + 'static>,
203 > {
204 let last_value = y[y.len() - 1];
205 let fitted: Vec<f64> = std::iter::once(f64::NAN)
206 .chain(y.iter().copied())
207 .take(y.len())
208 .collect();
209 let sigma_squared = y
210 .iter()
211 .zip(&fitted)
212 .filter_map(|(y, f)| {
213 if f.is_nan() {
214 None
215 } else {
216 Some((y - f).powi(2))
217 }
218 })
219 .sum::<f64>()
220 / (y.len() - 1) as f64;
221
222 Ok(Box::new(NaiveTrendFitted {
223 last_value,
224 fitted,
225 sigma_squared,
226 }))
227 }
228}
229
/// The result of fitting the naive trend model.
///
/// Field order is significant for the derived `Debug` output, so it is left
/// exactly as declared.
#[derive(Clone, Debug)]
struct NaiveTrendFitted {
    /// Value used as the point forecast for every future step.
    last_value: f64,
    /// Variance that the prediction interval widths are derived from.
    sigma_squared: f64,
    /// In-sample fitted values; its length is reported as the training data size.
    fitted: Vec<f64>,
}
236
237impl NaiveTrendFitted {
238 fn prediction_intervals(
239 &self,
240 preds: impl Iterator<Item = f64>,
241 level: f64,
242 sigma: impl Iterator<Item = f64>,
243 intervals: &mut ForecastIntervals,
244 ) {
245 intervals.level = level;
246 let z = distrs::Normal::ppf(0.5 + level / 2.0, 0.0, 1.0);
247 (intervals.lower, intervals.upper) = preds
248 .zip(sigma)
249 .map(|(p, s)| (p - z * s, p + z * s))
250 .unzip();
251 }
252}
253
254impl FittedTrendModel for NaiveTrendFitted {
255 fn predict_inplace(
256 &self,
257 horizon: usize,
258 level: Option<f64>,
259 forecast: &mut Forecast,
260 ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
261 forecast.point = vec![self.last_value; horizon];
262 if let Some(level) = level {
263 let sigmas = (1..horizon + 1).map(|step| ((step as f64) * self.sigma_squared).sqrt());
264 let intervals = forecast
265 .intervals
266 .get_or_insert_with(|| ForecastIntervals::with_capacity(level, horizon));
267 self.prediction_intervals(std::iter::repeat(self.last_value), level, sigmas, intervals);
268 }
269 Ok(())
270 }
271
272 fn predict_in_sample_inplace(
273 &self,
274 level: Option<f64>,
275 forecast: &mut Forecast,
276 ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
277 forecast.point.clone_from(&self.fitted);
278 if let Some(level) = level {
279 let intervals = forecast
280 .intervals
281 .get_or_insert_with(|| ForecastIntervals::with_capacity(level, self.fitted.len()));
282 self.prediction_intervals(
283 self.fitted.iter().copied(),
284 level,
285 std::iter::repeat(self.sigma_squared.sqrt()),
286 intervals,
287 );
288 }
289 Ok(())
290 }
291
292 fn training_data_size(&self) -> Option<usize> {
293 Some(self.fitted.len())
294 }
295}