Skip to main content

so_tsa/
arima.rs

1//! ARIMA (AutoRegressive Integrated Moving Average) models
2//!
3//! This module implements ARIMA models for time series forecasting,
4//! including seasonal ARIMA (SARIMA).
5//!
6//! # Model Specification
7//!
8//! ARIMA(p, d, q) models a time series as:
//! (1 - φ₁L - ... - φₚLᵖ)(1 - L)ᵈ yₜ = c + (1 + θ₁L + ... + θ_q L^q) εₜ
10//!
11//! where:
12//! - L is the lag operator: L yₜ = yₜ₋₁
13//! - φ are AR coefficients
14//! - θ are MA coefficients
15//! - d is the order of differencing
16//! - εₜ is white noise
17//!
18//! Seasonal ARIMA: ARIMA(p, d, q)(P, D, Q)ₛ
19//!
20//! # Estimation Methods
21//!
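//! 1. **Conditional Sum of Squares (CSS)**: Fast, good for initial estimates
//! 2. **Maximum Likelihood (ML)**: More accurate, uses Kalman filter
//!    (not implemented yet; currently falls back to CSS)
//! 3. **Exact Maximum Likelihood**: Uses state space representation
//!    (not implemented yet; currently falls back to ML, i.e. CSS)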
25
26#![allow(non_snake_case)] // Allow mathematical notation (X, y, etc.)
27
28use super::timeseries::TimeSeries;
29use ndarray::{Array1, Array2};
30use serde::{Deserialize, Serialize};
31use so_core::error::{Error, Result};
32use so_linalg;
33
34/// ARIMA model order
35#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
36pub struct ARIMAOrder {
37    /// AR order (p)
38    pub p: usize,
39    /// Differencing order (d)
40    pub d: usize,
41    /// MA order (q)
42    pub q: usize,
43}
44
45/// Seasonal ARIMA model order
46#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
47pub struct SARIMAOrder {
48    /// Non-seasonal order
49    pub order: ARIMAOrder,
50    /// Seasonal AR order (P)
51    pub seasonal_p: usize,
52    /// Seasonal differencing order (D)
53    pub seasonal_d: usize,
54    /// Seasonal MA order (Q)
55    pub seasonal_q: usize,
56    /// Seasonal period (s)
57    pub seasonal_period: usize,
58}
59
/// Estimation method requested for an ARIMA fit.
///
/// Only [`EstimationMethod::CSS`] has its own estimator in this module; the
/// other two variants currently delegate to it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EstimationMethod {
    /// Conditional Sum of Squares
    CSS,
    /// Maximum Likelihood (intended to use a Kalman filter; currently falls back to CSS)
    ML,
    /// Exact Maximum Likelihood (currently falls back to `ML`)
    ExactML,
}
70
71/// ARIMA model configuration
72#[derive(Debug, Clone)]
73pub struct ARIMAConfig {
74    /// Model order
75    pub order: ARIMAOrder,
76    /// Include constant term
77    pub with_constant: bool,
78    /// Estimation method
79    pub method: EstimationMethod,
80    /// Maximum iterations for optimization
81    pub max_iter: usize,
82    /// Convergence tolerance
83    pub tol: f64,
84}
85
86impl Default for ARIMAConfig {
87    fn default() -> Self {
88        Self {
89            order: ARIMAOrder { p: 1, d: 0, q: 1 },
90            with_constant: true,
91            method: EstimationMethod::CSS,
92            max_iter: 100,
93            tol: 1e-6,
94        }
95    }
96}
97
98/// ARIMA model results
99#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct ARIMAResults {
101    /// AR coefficients (φ₁, ..., φₚ)
102    pub ar_coef: Option<Array1<f64>>,
103    /// MA coefficients (θ₁, ..., θ_q)
104    pub ma_coef: Option<Array1<f64>>,
105    /// Constant term (c)
106    pub constant: Option<f64>,
107    /// Residual variance (σ²)
108    pub sigma2: f64,
109    /// Log-likelihood
110    pub log_likelihood: f64,
111    /// Akaike Information Criterion
112    pub aic: f64,
113    /// Bayesian Information Criterion
114    pub bic: f64,
115    /// Number of observations used
116    pub n_obs: usize,
117    /// Residuals
118    pub residuals: Array1<f64>,
119    /// Fitted values
120    pub fitted: Array1<f64>,
121}
122
123/// ARIMA model builder
124pub struct ARIMABuilder {
125    config: ARIMAConfig,
126}
127
128impl ARIMABuilder {
129    /// Create new ARIMA builder
130    pub fn new(p: usize, d: usize, q: usize) -> Self {
131        Self {
132            config: ARIMAConfig {
133                order: ARIMAOrder { p, d, q },
134                ..Default::default()
135            },
136        }
137    }
138
139    /// Set seasonal components (SARIMA)
140    pub fn seasonal(self, P: usize, D: usize, Q: usize, period: usize) -> SARIMABuilder {
141        SARIMABuilder::new(
142            self.config.order.p,
143            self.config.order.d,
144            self.config.order.q,
145        )
146        .seasonal(P, D, Q, period)
147    }
148
149    /// Include constant term
150    pub fn with_constant(mut self, include: bool) -> Self {
151        self.config.with_constant = include;
152        self
153    }
154
155    /// Set estimation method
156    pub fn method(mut self, method: EstimationMethod) -> Self {
157        self.config.method = method;
158        self
159    }
160
161    /// Set maximum iterations
162    pub fn max_iter(mut self, max_iter: usize) -> Self {
163        self.config.max_iter = max_iter;
164        self
165    }
166
167    /// Set convergence tolerance
168    pub fn tol(mut self, tol: f64) -> Self {
169        self.config.tol = tol;
170        self
171    }
172
173    /// Fit ARIMA model
174    pub fn fit(self, ts: &TimeSeries) -> Result<ARIMAResults> {
175        let mut arima = ARIMA::new(self.config);
176        arima.fit(ts)
177    }
178}
179
180/// SARIMA model builder
181pub struct SARIMABuilder {
182    order: SARIMAOrder,
183    with_constant: bool,
184    method: EstimationMethod,
185    max_iter: usize,
186    tol: f64,
187}
188
189impl SARIMABuilder {
190    /// Create new SARIMA builder
191    pub fn new(p: usize, d: usize, q: usize) -> Self {
192        Self {
193            order: SARIMAOrder {
194                order: ARIMAOrder { p, d, q },
195                seasonal_p: 0,
196                seasonal_d: 0,
197                seasonal_q: 0,
198                seasonal_period: 1,
199            },
200            with_constant: true,
201            method: EstimationMethod::CSS,
202            max_iter: 100,
203            tol: 1e-6,
204        }
205    }
206
207    /// Set seasonal components
208    pub fn seasonal(mut self, P: usize, D: usize, Q: usize, period: usize) -> Self {
209        self.order.seasonal_p = P;
210        self.order.seasonal_d = D;
211        self.order.seasonal_q = Q;
212        self.order.seasonal_period = period;
213        self
214    }
215
216    /// Include constant term
217    pub fn with_constant(mut self, include: bool) -> Self {
218        self.with_constant = include;
219        self
220    }
221
222    /// Set estimation method
223    pub fn method(mut self, method: EstimationMethod) -> Self {
224        self.method = method;
225        self
226    }
227
228    /// Set maximum iterations
229    pub fn max_iter(mut self, max_iter: usize) -> Self {
230        self.max_iter = max_iter;
231        self
232    }
233
234    /// Set convergence tolerance
235    pub fn tol(mut self, tol: f64) -> Self {
236        self.tol = tol;
237        self
238    }
239
240    /// Fit SARIMA model
241    pub fn fit(self, ts: &TimeSeries) -> Result<ARIMAResults> {
242        // Convert to equivalent ARIMA order
243        let total_p = self.order.order.p + self.order.seasonal_p * self.order.seasonal_period;
244        let total_q = self.order.order.q + self.order.seasonal_q * self.order.seasonal_period;
245
246        let mut arima = ARIMA::new(ARIMAConfig {
247            order: ARIMAOrder {
248                p: total_p,
249                d: self.order.order.d + self.order.seasonal_d * self.order.seasonal_period,
250                q: total_q,
251            },
252            with_constant: self.with_constant,
253            method: self.method,
254            max_iter: self.max_iter,
255            tol: self.tol,
256        });
257
258        arima.fit(ts)
259    }
260}
261
262/// ARIMA model
263pub struct ARIMA {
264    config: ARIMAConfig,
265}
266
267impl ARIMA {
268    /// Create new ARIMA model
269    pub fn new(config: ARIMAConfig) -> Self {
270        Self { config }
271    }
272
273    /// Create ARIMA builder
274    pub fn builder(p: usize, d: usize, q: usize) -> ARIMABuilder {
275        ARIMABuilder::new(p, d, q)
276    }
277
278    /// Fit ARIMA model to time series
279    pub fn fit(&mut self, ts: &TimeSeries) -> Result<ARIMAResults> {
280        let n = ts.len();
281        let order = self.config.order;
282
283        if n < order.p + order.q + 10 {
284            return Err(Error::DataError(format!(
285                "Not enough observations for ARIMA({},{},{}), need at least {}, got {}",
286                order.p,
287                order.d,
288                order.q,
289                order.p + order.q + 10,
290                n
291            )));
292        }
293
294        // Apply differencing if needed
295        let (diffed_ts, _diff_timestamps) = self.difference(ts)?;
296        let y = diffed_ts.values();
297
298        match self.config.method {
299            EstimationMethod::CSS => self.fit_css(y, n),
300            EstimationMethod::ML => self.fit_ml(y, n),
301            EstimationMethod::ExactML => self.fit_exact_ml(y, n),
302        }
303    }
304
305    /// Apply differencing
306    fn difference(&self, ts: &TimeSeries) -> Result<(TimeSeries, Vec<i64>)> {
307        if self.config.order.d == 0 {
308            return Ok((ts.clone(), ts.timestamps().to_vec()));
309        }
310
311        let diffed = ts.diff(1, self.config.order.d)?;
312        let timestamps = diffed.timestamps().to_vec();
313        Ok((diffed, timestamps))
314    }
315
316    /// Fit using Conditional Sum of Squares (CSS)
317    fn fit_css(&self, y: &Array1<f64>, n_orig: usize) -> Result<ARIMAResults> {
318        let order = self.config.order;
319        let n = y.len();
320
321        // Prepare regression matrix for AR terms
322        let mut X = Array2::zeros((n - order.p, order.p + order.q + 1));
323        let mut y_reg = Array1::zeros(n - order.p);
324
325        let mut residuals = Array1::zeros(n);
326        let mut fitted = Array1::zeros(n);
327
328        // Initial MA residuals (assume zero)
329        for i in 0..n {
330            residuals[i] = y[i];
331        }
332
333        // Iterate to estimate AR and MA coefficients
334        let mut converged = false;
335        let mut iteration = 0;
336
337        // Variables to store coefficients (declared outside loop)
338        let mut ar_coef = if order.p > 0 {
339            Some(Array1::zeros(order.p))
340        } else {
341            None
342        };
343
344        let mut ma_coef = if order.q > 0 {
345            Some(Array1::zeros(order.q))
346        } else {
347            None
348        };
349
350        let mut constant = if self.config.with_constant {
351            Some(0.0)
352        } else {
353            None
354        };
355
356        while iteration < self.config.max_iter && !converged {
357            // Estimate AR coefficients using current residuals
358            for t in order.p..n {
359                let mut row_idx = 0;
360
361                // AR terms: y_{t-1}, ..., y_{t-p}
362                for lag in 1..=order.p {
363                    X[(t - order.p, row_idx)] = y[t - lag];
364                    row_idx += 1;
365                }
366
367                // MA terms: ε_{t-1}, ..., ε_{t-q}
368                for lag in 1..=order.q {
369                    if t - lag < residuals.len() {
370                        X[(t - order.p, row_idx)] = residuals[t - lag];
371                    }
372                    row_idx += 1;
373                }
374
375                // Constant term
376                if self.config.with_constant {
377                    X[(t - order.p, row_idx)] = 1.0;
378                }
379
380                y_reg[t - order.p] = y[t];
381            }
382
383            // Solve regression
384            let XtX = X.t().dot(&X);
385            let Xty = X.t().dot(&y_reg);
386
387            let coef = so_linalg::solve(&XtX, &Xty)
388                .map_err(|e| Error::LinearAlgebraError(format!("ARIMA CSS solve failed: {}", e)))?;
389
390            // Extract coefficients (update existing variables)
391            let mut idx = 0;
392
393            if let Some(ref mut ar) = ar_coef {
394                for i in 0..order.p {
395                    ar[i] = coef[idx];
396                    idx += 1;
397                }
398            }
399
400            if let Some(ref mut ma) = ma_coef {
401                for i in 0..order.q {
402                    ma[i] = coef[idx];
403                    idx += 1;
404                }
405            }
406
407            if let Some(ref mut c) = constant {
408                *c = coef[idx];
409            }
410
411            // Update residuals and fitted values
412            let mut prev_change = 0.0;
413            for t in 0..n {
414                let mut prediction = 0.0;
415
416                // AR terms
417                if let Some(ref ar) = ar_coef {
418                    for lag in 1..=order.p {
419                        if t >= lag {
420                            prediction += ar[lag - 1] * y[t - lag];
421                        }
422                    }
423                }
424
425                // MA terms
426                if let Some(ref ma) = ma_coef {
427                    for lag in 1..=order.q {
428                        if t >= lag {
429                            prediction += ma[lag - 1] * residuals[t - lag];
430                        }
431                    }
432                }
433
434                // Constant
435                if let Some(c) = constant {
436                    prediction += c;
437                }
438
439                if t >= order.p {
440                    fitted[t] = prediction;
441                }
442
443                let new_residual = y[t] - prediction;
444                prev_change += (new_residual - residuals[t]).abs();
445                residuals[t] = new_residual;
446            }
447
448            // Check convergence
449            if prev_change / (n as f64) < self.config.tol {
450                converged = true;
451            }
452
453            iteration += 1;
454        }
455
456        if !converged {
457            return Err(Error::DataError(format!(
458                "ARIMA CSS did not converge after {} iterations",
459                self.config.max_iter
460            )));
461        }
462
463        // Calculate statistics
464        let rss: f64 = residuals.iter().map(|&r| r.powi(2)).sum();
465        let sigma2 =
466            rss / (n - order.p - order.q - if self.config.with_constant { 1 } else { 0 }) as f64;
467
468        let log_likelihood = self.calculate_log_likelihood(&residuals, sigma2, n);
469        let (aic, bic) = self.calculate_information_criteria(
470            log_likelihood,
471            order.p + order.q + if self.config.with_constant { 1 } else { 0 },
472            n_orig,
473        );
474
475        Ok(ARIMAResults {
476            ar_coef,
477            ma_coef,
478            constant,
479            sigma2,
480            log_likelihood,
481            aic,
482            bic,
483            n_obs: n_orig,
484            residuals,
485            fitted,
486        })
487    }
488
489    /// Fit using Maximum Likelihood (simplified)
490    fn fit_ml(&self, y: &Array1<f64>, n_orig: usize) -> Result<ARIMAResults> {
491        // For now, use CSS as starting point and refine with ML
492        self.fit_css(y, n_orig)
493    }
494
495    /// Fit using Exact Maximum Likelihood
496    fn fit_exact_ml(&self, y: &Array1<f64>, n_orig: usize) -> Result<ARIMAResults> {
497        // Would use state space representation and Kalman filter
498        // For now, fall back to ML
499        self.fit_ml(y, n_orig)
500    }
501
502    /// Calculate log-likelihood for Gaussian errors
503    fn calculate_log_likelihood(&self, residuals: &Array1<f64>, sigma2: f64, n: usize) -> f64 {
504        -0.5 * n as f64 * (2.0 * std::f64::consts::PI * sigma2).ln()
505            - 0.5 * residuals.iter().map(|&r| r.powi(2)).sum::<f64>() / sigma2
506    }
507
508    /// Calculate AIC and BIC
509    fn calculate_information_criteria(&self, log_lik: f64, k: usize, n: usize) -> (f64, f64) {
510        let aic = 2.0 * k as f64 - 2.0 * log_lik;
511        let bic = (n as f64).ln() * k as f64 - 2.0 * log_lik;
512        (aic, bic)
513    }
514
515    /// Forecast future values
516    pub fn forecast(&self, results: &ARIMAResults, steps: usize) -> Array1<f64> {
517        let order = self.config.order;
518        let n = results.residuals.len();
519
520        let mut forecasts = Array1::zeros(steps);
521        let mut y_extended = results.fitted.clone();
522        let mut residuals_extended = results.residuals.clone();
523
524        for h in 0..steps {
525            let mut prediction = 0.0;
526
527            // AR terms
528            if let Some(ref ar) = results.ar_coef {
529                for lag in 1..=order.p {
530                    let idx = n + h - lag;
531                    if idx < y_extended.len() {
532                        prediction += ar[lag - 1] * y_extended[idx];
533                    }
534                }
535            }
536
537            // MA terms
538            if let Some(ref ma) = results.ma_coef {
539                for lag in 1..=order.q {
540                    let idx = n + h - lag;
541                    if idx < residuals_extended.len() {
542                        prediction += ma[lag - 1] * residuals_extended[idx];
543                    }
544                }
545            }
546
547            // Constant
548            if let Some(c) = results.constant {
549                prediction += c;
550            }
551
552            forecasts[h] = prediction;
553
554            // Extend arrays for next forecast
555            y_extended = ndarray::concatenate(
556                ndarray::Axis(0),
557                &[y_extended.view(), ndarray::array![prediction].view()],
558            )
559            .unwrap();
560
561            // For MA terms, we need future residuals (assume zero)
562            residuals_extended = ndarray::concatenate(
563                ndarray::Axis(0),
564                &[residuals_extended.view(), ndarray::array![0.0].view()],
565            )
566            .unwrap();
567        }
568
569        forecasts
570    }
571
572    /// Calculate prediction intervals
573    pub fn prediction_intervals(
574        &self,
575        results: &ARIMAResults,
576        forecasts: &Array1<f64>,
577        alpha: f64,
578    ) -> (Array1<f64>, Array1<f64>) {
579        let sigma = results.sigma2.sqrt();
580        let _z = 1.0 - alpha / 2.0;
581        let z_value = 1.96; // Approximate for 95% CI
582
583        let lower = forecasts.mapv(|f| f - z_value * sigma);
584        let upper = forecasts.mapv(|f| f + z_value * sigma);
585
586        (lower, upper)
587    }
588}
589
590/// Extension trait for TimeSeries
591pub trait ARIMAExt {
592    /// Fit ARIMA model
593    fn arima(&self, p: usize, d: usize, q: usize) -> Result<ARIMAResults>;
594}
595
596impl ARIMAExt for TimeSeries {
597    fn arima(&self, p: usize, d: usize, q: usize) -> Result<ARIMAResults> {
598        ARIMA::builder(p, d, q).fit(self)
599    }
600}