// oxidiviner_moving_average/model.rs

use oxidiviner_core::{Forecaster, ModelEvaluation, ModelOutput, OxiError, Result, TimeSeriesData};
use oxidiviner_math::metrics::{mae, mape, mse, rmse, smape};
use oxidiviner_math::transforms::moving_average;

use crate::error::MAError;

/// Moving Average model for time series forecasting.
///
/// This is a simple model that predicts every future value as the
/// arithmetic mean of the last `window_size` observations seen during fitting.
#[derive(Debug, Clone)]
pub struct MAModel {
    /// Model name, formatted as `MA(<window_size>)` by the constructor.
    name: String,
    /// Window size (number of past observations to average); the constructor rejects `0`.
    window_size: usize,
    /// Last `window_size` values from the training data; `None` until fitted.
    last_values: Option<Vec<f64>>,
    /// Fitted values over the training period (first `window_size - 1` are `NaN`);
    /// `None` until fitted.
    fitted_values: Option<Vec<f64>>,
}

21impl MAModel {
22    /// Creates a new Moving Average model.
23    ///
24    /// # Arguments
25    /// * `window_size` - Number of past observations to average
26    ///
27    /// # Returns
28    /// * `Result<Self>` - A new model instance if parameters are valid
29    pub fn new(window_size: usize) -> std::result::Result<Self, MAError> {
30        // Validate parameters
31        if window_size == 0 {
32            return Err(MAError::InvalidWindowSize(window_size));
33        }
34        
35        let name = format!("MA({})", window_size);
36        
37        Ok(MAModel {
38            name,
39            window_size,
40            last_values: None,
41            fitted_values: None,
42        })
43    }
44    
45    /// Fit the model to the provided time series data.
46    /// This is a convenience method that calls the trait method directly.
47    pub fn fit(&mut self, data: &TimeSeriesData) -> Result<()> {
48        <Self as Forecaster>::fit(self, data)
49    }
50    
51    /// Forecast future values.
52    /// This is a convenience method that calls the trait method directly.
53    pub fn forecast(&self, horizon: usize) -> Result<Vec<f64>> {
54        <Self as Forecaster>::forecast(self, horizon)
55    }
56    
57    /// Evaluate the model on test data.
58    /// This is a convenience method that calls the trait method directly.
59    pub fn evaluate(&self, test_data: &TimeSeriesData) -> Result<ModelEvaluation> {
60        <Self as Forecaster>::evaluate(self, test_data)
61    }
62    
63    /// Generate forecasts and evaluation in a standardized format.
64    /// This is a convenience method that calls the trait method directly.
65    pub fn predict(&self, horizon: usize, test_data: Option<&TimeSeriesData>) -> Result<ModelOutput> {
66        <Self as Forecaster>::predict(self, horizon, test_data)
67    }
68    
69    /// Get the fitted values if available.
70    pub fn fitted_values(&self) -> Option<&Vec<f64>> {
71        self.fitted_values.as_ref()
72    }
73}
74
75impl Forecaster for MAModel {
76    fn name(&self) -> &str {
77        &self.name
78    }
79    
80    fn fit(&mut self, data: &TimeSeriesData) -> Result<()> {
81        if data.is_empty() {
82            return Err(OxiError::from(MAError::EmptyData));
83        }
84        
85        let n = data.values.len();
86        
87        if n < self.window_size {
88            return Err(OxiError::from(MAError::TimeSeriesTooShort {
89                actual: n,
90                expected: self.window_size,
91            }));
92        }
93        
94        // Calculate fitted values
95        let ma_values = moving_average(&data.values, self.window_size);
96        let mut fitted_values = vec![f64::NAN; self.window_size - 1];
97        fitted_values.extend(ma_values);
98        
99        // Store the last window_size values for forecasting
100        self.last_values = Some(data.values[n - self.window_size..].to_vec());
101        self.fitted_values = Some(fitted_values);
102        
103        Ok(())
104    }
105    
106    fn forecast(&self, horizon: usize) -> Result<Vec<f64>> {
107        if let Some(last_values) = &self.last_values {
108            if horizon == 0 {
109                return Err(OxiError::from(MAError::InvalidHorizon(horizon)));
110            }
111            
112            // For MA model, all forecasts are the same value: the average of the last window_size values
113            let forecast_value = last_values.iter().sum::<f64>() / last_values.len() as f64;
114            
115            // Return the same value for all horizons
116            Ok(vec![forecast_value; horizon])
117        } else {
118            Err(OxiError::from(MAError::NotFitted))
119        }
120    }
121    
122    fn evaluate(&self, test_data: &TimeSeriesData) -> Result<ModelEvaluation> {
123        if self.last_values.is_none() {
124            return Err(OxiError::from(MAError::NotFitted));
125        }
126        
127        let forecast = self.forecast(test_data.values.len())?;
128        
129        // Calculate error metrics
130        let mae = mae(&test_data.values, &forecast);
131        let mse = mse(&test_data.values, &forecast);
132        let rmse = rmse(&test_data.values, &forecast);
133        let mape = mape(&test_data.values, &forecast);
134        let smape = smape(&test_data.values, &forecast);
135        
136        Ok(ModelEvaluation {
137            model_name: self.name.clone(),
138            mae,
139            mse,
140            rmse,
141            mape,
142            smape,
143        })
144    }
145    
146    // Use the default predict implementation from the trait
147}
148
#[cfg(test)]
mod tests {
    use super::*;
    use chrono::{DateTime, TimeZone, Utc};

    /// Fixed start time (2020-09-13T12:26:40Z) so fixtures are deterministic.
    const START_TS: i64 = 1_600_000_000;
    /// Spacing between consecutive observations: one day.
    const DAY_SECS: i64 = 86_400;

    /// Builds a daily-spaced series from `values`, starting at `START_TS`.
    fn daily_series(values: Vec<f64>) -> TimeSeriesData {
        let timestamps: Vec<DateTime<Utc>> = (0..values.len() as i64)
            .map(|i| Utc.timestamp_opt(START_TS + i * DAY_SECS, 0).unwrap())
            .collect();
        TimeSeriesData::new(timestamps, values, "test_series").unwrap()
    }

    #[test]
    fn test_ma_forecast() {
        // Linear trend data: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
        let values: Vec<f64> = (1..=10).map(|i| i as f64).collect();
        let time_series = daily_series(values);

        // Create and fit the MA model with window size 3
        let mut model = MAModel::new(3).unwrap();
        model.fit(&time_series).unwrap();

        // Forecast 5 periods ahead
        let forecast_horizon = 5;
        let forecasts = model.forecast(forecast_horizon).unwrap();
        assert_eq!(forecasts.len(), forecast_horizon);

        // All forecasts should be the average of the last 3 values: (8+9+10)/3 = 9
        let expected_forecast = 9.0;
        for forecast in forecasts.iter() {
            assert!((forecast - expected_forecast).abs() < 1e-6);
        }

        // Standardized ModelOutput from predict()
        let output = model.predict(forecast_horizon, None).unwrap();
        assert_eq!(output.model_name, model.name());
        assert_eq!(output.forecasts.len(), forecast_horizon);

        // With evaluation data, metrics should be present
        let output_with_eval = model.predict(forecast_horizon, Some(&time_series)).unwrap();
        assert!(output_with_eval.evaluation.is_some());
        let eval = output_with_eval.evaluation.unwrap();
        assert_eq!(eval.model_name, model.name());
    }

    #[test]
    fn test_ma_fitted_values() {
        let time_series =
            daily_series(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]);

        // MA model with window size 3
        let mut model = MAModel::new(3).unwrap();
        model.fit(&time_series).unwrap();

        let fitted_values = model.fitted_values().unwrap();
        // Fitted values must line up one-to-one with the input.
        assert_eq!(fitted_values.len(), 10);

        // First 2 fitted values should be NaN (window_size - 1)
        assert!(fitted_values[0].is_nan());
        assert!(fitted_values[1].is_nan());

        // MA(3) at index 2 = (1+2+3)/3 = 2
        assert!((fitted_values[2] - 2.0).abs() < 1e-6);
        // MA(3) at index 3 = (2+3+4)/3 = 3
        assert!((fitted_values[3] - 3.0).abs() < 1e-6);
        // MA(3) at index 9 = (8+9+10)/3 = 9
        assert!((fitted_values[9] - 9.0).abs() < 1e-6);
    }

    #[test]
    fn test_ma_rejects_zero_window() {
        assert!(MAModel::new(0).is_err());
    }

    #[test]
    fn test_ma_errors_before_fit() {
        let model = MAModel::new(3).unwrap();
        assert!(model.fitted_values().is_none());
        assert!(model.forecast(1).is_err());
        assert!(model.evaluate(&daily_series(vec![1.0, 2.0])).is_err());
    }

    #[test]
    fn test_ma_rejects_short_series() {
        let mut model = MAModel::new(5).unwrap();
        assert!(model.fit(&daily_series(vec![1.0, 2.0, 3.0])).is_err());
        // A failed fit must leave the model unfitted.
        assert!(model.fitted_values().is_none());
    }

    #[test]
    fn test_ma_rejects_zero_horizon() {
        let mut model = MAModel::new(2).unwrap();
        model.fit(&daily_series(vec![1.0, 2.0, 3.0])).unwrap();
        assert!(model.forecast(0).is_err());
    }
}
230}