oxidiviner_moving_average/
model.rs1use oxidiviner_core::{Forecaster, ModelEvaluation, ModelOutput, OxiError, Result, TimeSeriesData};
2use oxidiviner_math::metrics::{mae, mse, rmse, mape, smape};
3use oxidiviner_math::transforms::moving_average;
4use crate::error::MAError;
5
/// Simple moving-average forecaster.
///
/// After fitting, every forecast step is the mean of the last `window_size`
/// observations seen by `fit`.
#[derive(Debug, Clone)]
pub struct MAModel {
    /// Display name, formatted as `MA(<window_size>)` by `new`.
    name: String,
    /// Number of trailing observations averaged per forecast (`new` rejects 0).
    window_size: usize,
    /// The final `window_size` observations from the last successful fit;
    /// `None` until the model has been fitted.
    last_values: Option<Vec<f64>>,
    /// In-sample moving averages from the last successful fit, with `NaN`
    /// padding for the first `window_size - 1` positions; `None` until fitted.
    fitted_values: Option<Vec<f64>>,
}
20
21impl MAModel {
22 pub fn new(window_size: usize) -> std::result::Result<Self, MAError> {
30 if window_size == 0 {
32 return Err(MAError::InvalidWindowSize(window_size));
33 }
34
35 let name = format!("MA({})", window_size);
36
37 Ok(MAModel {
38 name,
39 window_size,
40 last_values: None,
41 fitted_values: None,
42 })
43 }
44
45 pub fn fit(&mut self, data: &TimeSeriesData) -> Result<()> {
48 <Self as Forecaster>::fit(self, data)
49 }
50
51 pub fn forecast(&self, horizon: usize) -> Result<Vec<f64>> {
54 <Self as Forecaster>::forecast(self, horizon)
55 }
56
57 pub fn evaluate(&self, test_data: &TimeSeriesData) -> Result<ModelEvaluation> {
60 <Self as Forecaster>::evaluate(self, test_data)
61 }
62
63 pub fn predict(&self, horizon: usize, test_data: Option<&TimeSeriesData>) -> Result<ModelOutput> {
66 <Self as Forecaster>::predict(self, horizon, test_data)
67 }
68
69 pub fn fitted_values(&self) -> Option<&Vec<f64>> {
71 self.fitted_values.as_ref()
72 }
73}
74
75impl Forecaster for MAModel {
76 fn name(&self) -> &str {
77 &self.name
78 }
79
80 fn fit(&mut self, data: &TimeSeriesData) -> Result<()> {
81 if data.is_empty() {
82 return Err(OxiError::from(MAError::EmptyData));
83 }
84
85 let n = data.values.len();
86
87 if n < self.window_size {
88 return Err(OxiError::from(MAError::TimeSeriesTooShort {
89 actual: n,
90 expected: self.window_size,
91 }));
92 }
93
94 let ma_values = moving_average(&data.values, self.window_size);
96 let mut fitted_values = vec![f64::NAN; self.window_size - 1];
97 fitted_values.extend(ma_values);
98
99 self.last_values = Some(data.values[n - self.window_size..].to_vec());
101 self.fitted_values = Some(fitted_values);
102
103 Ok(())
104 }
105
106 fn forecast(&self, horizon: usize) -> Result<Vec<f64>> {
107 if let Some(last_values) = &self.last_values {
108 if horizon == 0 {
109 return Err(OxiError::from(MAError::InvalidHorizon(horizon)));
110 }
111
112 let forecast_value = last_values.iter().sum::<f64>() / last_values.len() as f64;
114
115 Ok(vec![forecast_value; horizon])
117 } else {
118 Err(OxiError::from(MAError::NotFitted))
119 }
120 }
121
122 fn evaluate(&self, test_data: &TimeSeriesData) -> Result<ModelEvaluation> {
123 if self.last_values.is_none() {
124 return Err(OxiError::from(MAError::NotFitted));
125 }
126
127 let forecast = self.forecast(test_data.values.len())?;
128
129 let mae = mae(&test_data.values, &forecast);
131 let mse = mse(&test_data.values, &forecast);
132 let rmse = rmse(&test_data.values, &forecast);
133 let mape = mape(&test_data.values, &forecast);
134 let smape = smape(&test_data.values, &forecast);
135
136 Ok(ModelEvaluation {
137 model_name: self.name.clone(),
138 mae,
139 mse,
140 rmse,
141 mape,
142 smape,
143 })
144 }
145
146 }
148
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};

    /// Builds a daily-spaced series named `test_series`, starting now.
    fn daily_series(values: Vec<f64>) -> TimeSeriesData {
        let start = Utc::now().timestamp();
        let timestamps: Vec<DateTime<Utc>> = (0..values.len() as i64)
            .map(|i| Utc.timestamp_opt(start + i * 86400, 0).unwrap())
            .collect();
        TimeSeriesData::new(timestamps, values, "test_series").unwrap()
    }

    /// The series 1.0, 2.0, ..., 10.0 on consecutive days.
    fn one_to_ten() -> TimeSeriesData {
        daily_series((1..=10).map(|i| i as f64).collect())
    }

    #[test]
    fn test_ma_forecast() {
        let time_series = one_to_ten();
        let mut model = MAModel::new(3).unwrap();
        model.fit(&time_series).unwrap();

        let forecast_horizon = 5;
        let forecasts = model.forecast(forecast_horizon).unwrap();
        assert_eq!(forecasts.len(), forecast_horizon);

        // Mean of the last three observations (8, 9, 10).
        let expected_forecast = 9.0;
        for forecast in &forecasts {
            assert!((forecast - expected_forecast).abs() < 1e-6);
        }

        let output = model.predict(forecast_horizon, None).unwrap();
        assert_eq!(output.model_name, model.name());
        assert_eq!(output.forecasts.len(), forecast_horizon);

        let output_with_eval = model.predict(forecast_horizon, Some(&time_series)).unwrap();
        assert!(output_with_eval.evaluation.is_some());
        let eval = output_with_eval.evaluation.unwrap();
        assert_eq!(eval.model_name, model.name());
    }

    #[test]
    fn test_ma_fitted_values() {
        let mut model = MAModel::new(3).unwrap();
        model.fit(&one_to_ten()).unwrap();

        let fitted_values = model.fitted_values().unwrap();
        // One fitted value per input observation.
        assert_eq!(fitted_values.len(), 10);

        // The first window_size - 1 positions have no full window.
        assert!(fitted_values[0].is_nan());
        assert!(fitted_values[1].is_nan());

        assert!((fitted_values[2] - 2.0).abs() < 1e-6);
        assert!((fitted_values[3] - 3.0).abs() < 1e-6);
        assert!((fitted_values[9] - 9.0).abs() < 1e-6);
    }

    #[test]
    fn test_ma_rejects_zero_window() {
        assert!(matches!(MAModel::new(0), Err(MAError::InvalidWindowSize(0))));
    }

    #[test]
    fn test_ma_forecast_before_fit_fails() {
        let model = MAModel::new(3).unwrap();
        assert!(model.forecast(1).is_err());
        assert!(model.fitted_values().is_none());
    }

    #[test]
    fn test_ma_zero_horizon_fails() {
        let mut model = MAModel::new(3).unwrap();
        model.fit(&one_to_ten()).unwrap();
        assert!(model.forecast(0).is_err());
    }

    #[test]
    fn test_ma_series_shorter_than_window_fails() {
        let mut model = MAModel::new(20).unwrap();
        assert!(model.fit(&one_to_ten()).is_err());
        // A failed fit must leave the model unfitted.
        assert!(model.forecast(1).is_err());
    }
}