// avila_telemetry/models/arima.rs
use crate::forecasting::{ForecastResult, Forecaster};
use crate::{Result, TelemetryError, TimeSeries};

/// Orders of an ARIMA(p, d, q) model.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ARIMAParams {
    /// Autoregressive order: number of lagged values of the differenced series.
    pub p: usize,
    /// Number of first-differencing rounds applied before fitting.
    pub d: usize,
    /// Moving-average order: number of lagged residuals.
    pub q: usize,
}
16
17pub struct ARIMA {
19 params: ARIMAParams,
20 ar_coeffs: Vec<f64>,
21 ma_coeffs: Vec<f64>,
22 fitted: bool,
23 differenced_data: Vec<f64>,
24}
25
26impl ARIMA {
27 pub fn new(p: usize, d: usize, q: usize) -> Self {
29 Self {
30 params: ARIMAParams { p, d, q },
31 ar_coeffs: vec![0.0; p],
32 ma_coeffs: vec![0.0; q],
33 fitted: false,
34 differenced_data: Vec::new(),
35 }
36 }
37
38 fn difference(&self, data: &[f64]) -> Vec<f64> {
40 let mut result = data.to_vec();
41
42 for _ in 0..self.params.d {
43 result = result.windows(2).map(|w| w[1] - w[0]).collect();
44 }
45
46 result
47 }
48
49 fn _integrate(&self, differenced: &[f64], original_data: &[f64]) -> Vec<f64> {
51 let mut result = differenced.to_vec();
52
53 for _ in 0..self.params.d {
54 let mut integrated = Vec::with_capacity(result.len() + 1);
55 let start_value = if self.params.d == 1 {
56 *original_data.last().unwrap()
57 } else {
58 0.0
59 };
60 integrated.push(start_value);
61
62 for &val in &result {
63 integrated.push(integrated.last().unwrap() + val);
64 }
65
66 result = integrated;
67 }
68
69 result
70 }
71
72 fn estimate_ar_coeffs(&mut self, data: &[f64]) -> Result<()> {
74 if data.len() <= self.params.p {
75 return Err(TelemetryError::InsufficientData(
76 "Not enough data for AR estimation".to_string(),
77 ));
78 }
79
80 for i in 0..self.params.p {
82 let lag = i + 1;
83 let mut sum_xy = 0.0;
84 let mut sum_x2 = 0.0;
85
86 for j in lag..data.len() {
87 sum_xy += data[j] * data[j - lag];
88 sum_x2 += data[j - lag] * data[j - lag];
89 }
90
91 self.ar_coeffs[i] = if sum_x2 != 0.0 { sum_xy / sum_x2 } else { 0.0 };
92 }
93
94 Ok(())
95 }
96
97 fn estimate_ma_coeffs(&mut self, _residuals: &[f64]) -> Result<()> {
99 for i in 0..self.params.q {
101 self.ma_coeffs[i] = 0.1 * (i as f64 + 1.0) / (self.params.q as f64);
102 }
103 Ok(())
104 }
105}
106
107impl Forecaster for ARIMA {
108 fn fit(&mut self, ts: &TimeSeries) -> Result<()> {
109 if ts.len() < self.params.p + self.params.d + self.params.q + 1 {
110 return Err(TelemetryError::InsufficientData(
111 "Insufficient data for ARIMA model".to_string(),
112 ));
113 }
114
115 self.differenced_data = self.difference(&ts.values);
117
118 let differenced_clone = self.differenced_data.clone();
120
121 self.estimate_ar_coeffs(&differenced_clone)?;
123
124 let residuals = self.differenced_data.clone();
126
127 self.estimate_ma_coeffs(&residuals)?;
129
130 self.fitted = true;
131 Ok(())
132 }
133
134 fn forecast(&self, steps: usize) -> Result<ForecastResult> {
135 if !self.fitted {
136 return Err(TelemetryError::ModelError(
137 "Model must be fitted before forecasting".to_string(),
138 ));
139 }
140
141 let mut predictions = Vec::with_capacity(steps);
142 let mut history = self.differenced_data.clone();
143
144 for _ in 0..steps {
145 let mut pred = 0.0;
146
147 for (i, &coeff) in self.ar_coeffs.iter().enumerate() {
149 if i < history.len() {
150 pred += coeff * history[history.len() - 1 - i];
151 }
152 }
153
154 predictions.push(pred);
155 history.push(pred);
156 }
157
158 Ok(ForecastResult {
160 predictions,
161 lower_bound: None,
162 upper_bound: None,
163 confidence: 0.75,
164 })
165 }
166
167 fn forecast_with_confidence(
168 &self,
169 steps: usize,
170 confidence_level: f64,
171 ) -> Result<ForecastResult> {
172 let base_forecast = self.forecast(steps)?;
173
174 let std_dev = 1.0; let z_score = if confidence_level >= 0.95 {
177 1.96
178 } else {
179 1.645
180 };
181 let margin = z_score * std_dev;
182
183 let lower_bound = base_forecast
184 .predictions
185 .iter()
186 .map(|&v| v - margin)
187 .collect();
188
189 let upper_bound = base_forecast
190 .predictions
191 .iter()
192 .map(|&v| v + margin)
193 .collect();
194
195 Ok(ForecastResult {
196 predictions: base_forecast.predictions,
197 lower_bound: Some(lower_bound),
198 upper_bound: Some(upper_bound),
199 confidence: confidence_level,
200 })
201 }
202}
203
204#[cfg(test)]
205mod tests {
206 use super::*;
207
208 #[test]
209 fn test_arima_creation() {
210 let model = ARIMA::new(1, 1, 1);
211 assert_eq!(model.params.p, 1);
212 assert_eq!(model.params.d, 1);
213 assert_eq!(model.params.q, 1);
214 }
215
216 #[test]
217 fn test_arima_fit_and_forecast() {
218 let data: Vec<f64> = (1..=50).map(|x| x as f64).collect();
219 let ts = TimeSeries::new(data);
220
221 let mut model = ARIMA::new(2, 1, 1);
222 model.fit(&ts).unwrap();
223
224 let forecast = model.forecast(5).unwrap();
225 assert_eq!(forecast.predictions.len(), 5);
226 }
227}