use crate::learner::StreamingLearner;
use irithyll_core::error::ConfigError;
/// Hyper-parameters for a SNARIMAX model.
///
/// Orders follow the usual SARIMAX naming; construct instances through
/// [`SNARIMAXConfig::builder`], which validates the values before use.
#[derive(Debug, Clone)]
pub struct SNARIMAXConfig {
    /// Non-seasonal autoregressive order (number of past-observation lags).
    pub p: usize,
    /// Non-seasonal moving-average order (number of past-error lags).
    pub q: usize,
    /// Season length in steps; 0 disables all seasonal terms.
    pub seasonal_period: usize,
    /// Seasonal autoregressive order (lags at multiples of the period).
    pub sp: usize,
    /// Seasonal moving-average order (error lags at multiples of the period).
    pub sq: usize,
    /// Number of exogenous input features expected per observation.
    pub n_exogenous: usize,
    /// SGD step size applied to every coefficient update.
    pub learning_rate: f64,
}
impl SNARIMAXConfig {
    /// Starts a builder pre-populated with the defaults: an AR(1) model
    /// with learning rate 0.01 and no MA, seasonal, or exogenous terms.
    pub fn builder() -> SNARIMAXConfigBuilder {
        SNARIMAXConfigBuilder {
            learning_rate: 0.01,
            p: 1,
            q: 0,
            sp: 0,
            sq: 0,
            seasonal_period: 0,
            n_exogenous: 0,
        }
    }
}
/// Fluent builder for [`SNARIMAXConfig`].
///
/// Obtained from [`SNARIMAXConfig::builder`] and finished with `build()`,
/// which performs all validation. Fields mirror [`SNARIMAXConfig`].
#[derive(Debug, Clone)]
pub struct SNARIMAXConfigBuilder {
    p: usize,
    q: usize,
    seasonal_period: usize,
    sp: usize,
    sq: usize,
    n_exogenous: usize,
    learning_rate: f64,
}
impl SNARIMAXConfigBuilder {
    /// Sets the non-seasonal autoregressive order `p`.
    pub fn p(mut self, p: usize) -> Self {
        self.p = p;
        self
    }

    /// Sets the non-seasonal moving-average order `q`.
    pub fn q(mut self, q: usize) -> Self {
        self.q = q;
        self
    }

    /// Sets the seasonal period in steps (0 disables seasonal terms).
    pub fn seasonal_period(mut self, s: usize) -> Self {
        self.seasonal_period = s;
        self
    }

    /// Sets the seasonal autoregressive order `sp`.
    pub fn sp(mut self, sp: usize) -> Self {
        self.sp = sp;
        self
    }

    /// Sets the seasonal moving-average order `sq`.
    pub fn sq(mut self, sq: usize) -> Self {
        self.sq = sq;
        self
    }

    /// Sets the number of exogenous input features.
    pub fn n_exogenous(mut self, n: usize) -> Self {
        self.n_exogenous = n;
        self
    }

    /// Sets the SGD learning rate; must be finite and > 0 at build time.
    pub fn learning_rate(mut self, lr: f64) -> Self {
        self.learning_rate = lr;
        self
    }

    /// Validates the accumulated settings and produces a [`SNARIMAXConfig`].
    ///
    /// # Errors
    /// * `learning_rate` must be finite and strictly positive. The explicit
    ///   `is_finite()` check closes a hole where `NaN` (for which
    ///   `NaN <= 0.0` is `false`) and `+inf` slipped through validation.
    /// * `seasonal_period` must be non-zero whenever `sp` or `sq` is set,
    ///   otherwise the seasonal lags would be undefined.
    /// * At least one of `p`, `q`, `sp`, `sq`, `n_exogenous` must be
    ///   non-zero — otherwise the model could learn nothing but an
    ///   intercept.
    pub fn build(self) -> Result<SNARIMAXConfig, ConfigError> {
        if !self.learning_rate.is_finite() || self.learning_rate <= 0.0 {
            return Err(ConfigError::out_of_range(
                "learning_rate",
                "must be a finite value > 0",
                self.learning_rate,
            ));
        }
        if (self.sp > 0 || self.sq > 0) && self.seasonal_period == 0 {
            return Err(ConfigError::invalid(
                "seasonal_period",
                "must be > 0 when sp or sq is non-zero",
            ));
        }
        if self.p + self.q + self.sp + self.sq + self.n_exogenous == 0 {
            return Err(ConfigError::invalid(
                "p+q+sp+sq+n_exogenous",
                "at least one lag order must be > 0",
            ));
        }
        Ok(SNARIMAXConfig {
            p: self.p,
            q: self.q,
            seasonal_period: self.seasonal_period,
            sp: self.sp,
            sq: self.sq,
            n_exogenous: self.n_exogenous,
            learning_rate: self.learning_rate,
        })
    }
}
/// Owned snapshot of a model's learned coefficients, as returned by
/// `SNARIMAX::coefficients`.
#[derive(Debug, Clone)]
pub struct SNARIMAXCoefficients {
    /// Constant bias term.
    pub intercept: f64,
    /// Non-seasonal AR coefficients (length `p`).
    pub ar: Vec<f64>,
    /// Non-seasonal MA coefficients (length `q`).
    pub ma: Vec<f64>,
    /// Seasonal AR coefficients (length `sp`).
    pub seasonal_ar: Vec<f64>,
    /// Seasonal MA coefficients (length `sq`).
    pub seasonal_ma: Vec<f64>,
    /// Exogenous-feature coefficients (length `n_exogenous`).
    pub exogenous: Vec<f64>,
}
/// Streaming SNARIMAX model: seasonal ARIMA with exogenous inputs,
/// trained one observation at a time by stochastic gradient descent.
///
/// Past observations and past one-step prediction errors live in two
/// fixed-size ring buffers that share a single write cursor.
#[derive(Debug, Clone)]
pub struct SNARIMAX {
    /// Validated hyper-parameters (orders, period, learning rate).
    config: SNARIMAXConfig,
    /// Learned bias term.
    intercept: f64,
    /// Non-seasonal AR weights, one per observation lag 1..=p.
    ar_coeffs: Vec<f64>,
    /// Non-seasonal MA weights, one per error lag 1..=q.
    ma_coeffs: Vec<f64>,
    /// Seasonal AR weights at lags k * seasonal_period, k = 1..=sp.
    sar_coeffs: Vec<f64>,
    /// Seasonal MA weights at error lags k * seasonal_period, k = 1..=sq.
    sma_coeffs: Vec<f64>,
    /// Weights for the exogenous input features.
    exo_coeffs: Vec<f64>,
    /// Ring buffer of past observations.
    y_buffer: Vec<f64>,
    /// Ring buffer of past (clamped) one-step prediction errors.
    e_buffer: Vec<f64>,
    /// Next write slot shared by both buffers; the slot it points at
    /// currently holds the oldest stored value.
    buffer_pos: usize,
    /// Number of observations consumed through training.
    n_samples: u64,
}
impl SNARIMAX {
    /// Creates a model from a validated `config` with every coefficient
    /// at zero and ring buffers sized to hold the largest requested lag.
    pub fn new(config: SNARIMAXConfig) -> Self {
        let buf_cap = Self::compute_buffer_capacity(&config);
        Self {
            ar_coeffs: vec![0.0; config.p],
            ma_coeffs: vec![0.0; config.q],
            sar_coeffs: vec![0.0; config.sp],
            sma_coeffs: vec![0.0; config.sq],
            exo_coeffs: vec![0.0; config.n_exogenous],
            intercept: 0.0,
            // Both buffers share one capacity; `compute_buffer_capacity`
            // guarantees it is at least 1, so the `% cap` index math in
            // the lag helpers never divides by zero.
            y_buffer: vec![0.0; buf_cap],
            e_buffer: vec![0.0; buf_cap],
            buffer_pos: 0,
            n_samples: 0,
            config,
        }
    }
    /// Performs one online update from observation `y` and its exogenous
    /// inputs. Missing exogenous entries are treated as 0.0.
    pub fn train_one(&mut self, y: f64, exogenous: &[f64]) {
        self.train_impl(y, exogenous);
    }
    /// One-step-ahead prediction from the current buffers plus `exogenous`;
    /// does not mutate the model.
    pub fn predict_one(&self, exogenous: &[f64]) -> f64 {
        self.predict_from_buffers(exogenous)
    }
    /// Recursive multi-step forecast over `horizon` steps.
    ///
    /// Works on clones of the internal buffers, so the model itself is
    /// unchanged. Each predicted value is fed back as if it had been
    /// observed, with the corresponding future shock set to 0.0.
    /// NOTE(review): exogenous terms are deliberately absent here — no
    /// future exogenous values are available to this method — so forecasts
    /// for models with `n_exogenous > 0` omit that contribution; confirm
    /// this matches caller expectations.
    pub fn forecast(&self, horizon: usize) -> Vec<f64> {
        let mut results = Vec::with_capacity(horizon);
        let mut y_buf = self.y_buffer.clone();
        let mut e_buf = self.e_buffer.clone();
        let mut pos = self.buffer_pos;
        let cap = y_buf.len();
        if cap == 0 {
            // Defensive: `new` always allocates capacity >= 1, so this
            // branch should be unreachable in practice.
            let pred = self.intercept;
            return vec![pred; horizon];
        }
        for _ in 0..horizon {
            let mut y_hat = self.intercept;
            for i in 0..self.config.p {
                let lag = i + 1;
                // `pos` is the next write slot, so stepping back `lag`
                // slots (mod cap) reads the value from `lag` steps ago.
                let idx = (pos + cap - lag) % cap;
                y_hat += self.ar_coeffs[i] * y_buf[idx];
            }
            for j in 0..self.config.q {
                let lag = j + 1;
                let idx = (pos + cap - lag) % cap;
                y_hat += self.ma_coeffs[j] * e_buf[idx];
            }
            if self.config.seasonal_period > 0 {
                for k in 0..self.config.sp {
                    let lag = (k + 1) * self.config.seasonal_period;
                    // cap >= sp * seasonal_period by construction, so the
                    // guard never skips a term; it only protects the
                    // index arithmetic from underflow.
                    if lag <= cap {
                        let idx = (pos + cap - lag) % cap;
                        y_hat += self.sar_coeffs[k] * y_buf[idx];
                    }
                }
                for l in 0..self.config.sq {
                    let lag = (l + 1) * self.config.seasonal_period;
                    if lag <= cap {
                        let idx = (pos + cap - lag) % cap;
                        y_hat += self.sma_coeffs[l] * e_buf[idx];
                    }
                }
            }
            // Feed the prediction back as a pseudo-observation; future
            // shocks are unknown and assumed zero.
            y_buf[pos] = y_hat;
            e_buf[pos] = 0.0;
            pos = (pos + 1) % cap;
            results.push(y_hat);
        }
        results
    }
    /// Number of observations consumed through training so far.
    #[inline]
    pub fn n_samples_seen(&self) -> u64 {
        self.n_samples
    }
    /// Snapshot of the current coefficient values (cloned out).
    pub fn coefficients(&self) -> SNARIMAXCoefficients {
        SNARIMAXCoefficients {
            intercept: self.intercept,
            ar: self.ar_coeffs.clone(),
            ma: self.ma_coeffs.clone(),
            seasonal_ar: self.sar_coeffs.clone(),
            seasonal_ma: self.sma_coeffs.clone(),
            exogenous: self.exo_coeffs.clone(),
        }
    }
    /// Zeroes all learned state; the configuration is kept.
    pub fn reset(&mut self) {
        self.reset_impl();
    }
    /// Read-only access to the configuration this model was built with.
    #[inline]
    pub fn config(&self) -> &SNARIMAXConfig {
        &self.config
    }
    /// Smallest buffer length that can serve every lag the model reads:
    /// max(p, q, sp * s, sq * s), floored at 1 so modular indexing never
    /// divides by zero.
    fn compute_buffer_capacity(config: &SNARIMAXConfig) -> usize {
        let s = config.seasonal_period;
        let mut cap = config.p;
        if config.q > cap {
            cap = config.q;
        }
        if s > 0 {
            // Deepest seasonal lags requested by the AR/MA seasonal terms.
            let sar_max = config.sp * s;
            let sma_max = config.sq * s;
            if sar_max > cap {
                cap = sar_max;
            }
            if sma_max > cap {
                cap = sma_max;
            }
        }
        if cap == 0 {
            1
        } else {
            cap
        }
    }
    /// One-step-ahead prediction using the current lag buffers plus the
    /// supplied exogenous inputs; missing exogenous entries count as 0.0.
    fn predict_from_buffers(&self, exogenous: &[f64]) -> f64 {
        let mut y_hat = self.intercept;
        for i in 0..self.config.p {
            let lag = i + 1;
            y_hat += self.ar_coeffs[i] * self.get_y_lag(lag);
        }
        for j in 0..self.config.q {
            let lag = j + 1;
            y_hat += self.ma_coeffs[j] * self.get_e_lag(lag);
        }
        if self.config.seasonal_period > 0 {
            let cap = self.y_buffer.len();
            for k in 0..self.config.sp {
                let lag = (k + 1) * self.config.seasonal_period;
                // cap >= sp * seasonal_period by construction; the guard
                // only keeps the lag helpers from underflowing.
                if lag <= cap {
                    y_hat += self.sar_coeffs[k] * self.get_y_lag(lag);
                }
            }
            for l in 0..self.config.sq {
                let lag = (l + 1) * self.config.seasonal_period;
                if lag <= cap {
                    y_hat += self.sma_coeffs[l] * self.get_e_lag(lag);
                }
            }
        }
        for m in 0..self.config.n_exogenous {
            // Tolerate exogenous slices shorter than n_exogenous by
            // substituting 0.0 for the missing features.
            let x = if m < exogenous.len() {
                exogenous[m]
            } else {
                0.0
            };
            y_hat += self.exo_coeffs[m] * x;
        }
        y_hat
    }
    /// Observation made `lag` steps ago. Valid for 1 <= lag <= capacity;
    /// callers uphold that bound (larger lags would underflow the index).
    #[inline]
    fn get_y_lag(&self, lag: usize) -> f64 {
        let cap = self.y_buffer.len();
        // `buffer_pos` is the next write slot, so stepping back `lag`
        // slots (mod cap) lands on the value from `lag` steps ago.
        let idx = (self.buffer_pos + cap - lag) % cap;
        self.y_buffer[idx]
    }
    /// Clamped prediction error from `lag` steps ago; same indexing rules
    /// as `get_y_lag`.
    #[inline]
    fn get_e_lag(&self, lag: usize) -> f64 {
        let cap = self.e_buffer.len();
        let idx = (self.buffer_pos + cap - lag) % cap;
        self.e_buffer[idx]
    }
    /// Writes the newest observation/error pair into the shared write slot
    /// and advances the ring cursor, overwriting the oldest entry.
    #[inline]
    fn push_to_buffers(&mut self, y: f64, error: f64) {
        let cap = self.y_buffer.len();
        let pos = self.buffer_pos;
        self.y_buffer[pos] = y;
        self.e_buffer[pos] = error;
        self.buffer_pos = (pos + 1) % cap;
    }
    /// Single SGD step on the squared one-step prediction error.
    ///
    /// All gradients are computed from the lags as they were *before* this
    /// observation; the buffers are only updated at the end.
    fn train_impl(&mut self, y: f64, exogenous: &[f64]) {
        let y_hat = self.predict_from_buffers(exogenous);
        let raw_error = y - y_hat;
        // Clamp the residual so a single wild observation cannot blow up
        // every coefficient in one update.
        let error = raw_error.clamp(-1e6, 1e6);
        let lr = self.config.learning_rate;
        self.intercept += lr * error;
        // Each coefficient moves by lr * error * (its input) — the SGD
        // direction for squared loss with a linear predictor.
        for i in 0..self.config.p {
            let lag = i + 1;
            let y_lag = self.get_y_lag(lag);
            self.ar_coeffs[i] += lr * error * y_lag;
        }
        for j in 0..self.config.q {
            let lag = j + 1;
            let e_lag = self.get_e_lag(lag);
            self.ma_coeffs[j] += lr * error * e_lag;
        }
        if self.config.seasonal_period > 0 {
            // No `lag <= cap` guard needed: capacity is at least
            // sp * seasonal_period (and sq * seasonal_period) when the
            // period is non-zero, so these lags are always in range.
            for k in 0..self.config.sp {
                let lag = (k + 1) * self.config.seasonal_period;
                let y_lag = self.get_y_lag(lag);
                self.sar_coeffs[k] += lr * error * y_lag;
            }
            for l in 0..self.config.sq {
                let lag = (l + 1) * self.config.seasonal_period;
                let e_lag = self.get_e_lag(lag);
                self.sma_coeffs[l] += lr * error * e_lag;
            }
        }
        for m in 0..self.config.n_exogenous {
            let x = if m < exogenous.len() {
                exogenous[m]
            } else {
                0.0
            };
            self.exo_coeffs[m] += lr * error * x;
        }
        // Record the observation and its (clamped) error, then count it.
        self.push_to_buffers(y, error);
        self.n_samples += 1;
    }
    /// Zeroes every coefficient, both buffers, the cursor, and the sample
    /// counter, returning the model to its freshly-constructed state.
    fn reset_impl(&mut self) {
        self.intercept = 0.0;
        self.ar_coeffs.iter_mut().for_each(|c| *c = 0.0);
        self.ma_coeffs.iter_mut().for_each(|c| *c = 0.0);
        self.sar_coeffs.iter_mut().for_each(|c| *c = 0.0);
        self.sma_coeffs.iter_mut().for_each(|c| *c = 0.0);
        self.exo_coeffs.iter_mut().for_each(|c| *c = 0.0);
        self.y_buffer.iter_mut().for_each(|v| *v = 0.0);
        self.e_buffer.iter_mut().for_each(|v| *v = 0.0);
        self.buffer_pos = 0;
        self.n_samples = 0;
    }
}
impl StreamingLearner for SNARIMAX {
    /// Trains on one `(features, target)` pair. The sample weight is
    /// ignored: every observation gets the same plain-SGD update.
    fn train_one(&mut self, features: &[f64], target: f64, _weight: f64) {
        self.train_impl(target, features);
    }

    /// One-step-ahead prediction treating `features` as exogenous inputs.
    #[inline]
    fn predict(&self, features: &[f64]) -> f64 {
        self.predict_from_buffers(features)
    }

    /// Observations consumed through training.
    #[inline]
    fn n_samples_seen(&self) -> u64 {
        self.n_samples
    }

    /// Clears all learned state; configuration is retained.
    fn reset(&mut self) {
        self.reset_impl();
    }
}
impl crate::automl::DiagnosticSource for SNARIMAX {
    /// SNARIMAX exposes no AutoML configuration diagnostics yet.
    fn config_diagnostics(&self) -> Option<crate::automl::ConfigDiagnostics> {
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal deterministic PRNG (xorshift64) so tests need no external
    /// crates and are reproducible across runs.
    struct Xorshift64 {
        state: u64,
    }

    impl Xorshift64 {
        fn new(seed: u64) -> Self {
            Self {
                // xorshift64 has an all-zero fixed point; never seed with 0.
                state: if seed == 0 { 1 } else { seed },
            }
        }

        fn next_u64(&mut self) -> u64 {
            let mut x = self.state;
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            self.state = x;
            x
        }

        /// Uniform pseudo-random value in `[-amplitude, amplitude]`.
        fn next_f64(&mut self, amplitude: f64) -> f64 {
            let raw = self.next_u64();
            let unit = (raw as f64) / (u64::MAX as f64);
            (unit * 2.0 - 1.0) * amplitude
        }
    }

    #[test]
    fn config_builder_defaults() {
        let config = SNARIMAXConfig::builder().build().unwrap();
        assert_eq!(config.p, 1, "default p should be 1");
        assert_eq!(config.q, 0, "default q should be 0");
        assert_eq!(
            config.seasonal_period, 0,
            "default seasonal_period should be 0"
        );
        assert_eq!(config.sp, 0, "default sp should be 0");
        assert_eq!(config.sq, 0, "default sq should be 0");
        assert_eq!(config.n_exogenous, 0, "default n_exogenous should be 0");
        assert!(
            (config.learning_rate - 0.01).abs() < 1e-12,
            "default learning_rate should be 0.01, got {}",
            config.learning_rate,
        );
    }

    #[test]
    fn simple_ar1_converges() {
        let true_phi = 0.8;
        let mut rng = Xorshift64::new(42);
        let config = SNARIMAXConfig::builder()
            .p(1)
            .q(0)
            .learning_rate(0.05)
            .build()
            .expect("valid config");
        let mut model = SNARIMAX::new(config);
        // Generate a noisy AR(1) series and train online.
        let mut y_prev = 0.0;
        for _ in 0..10_000 {
            let noise = rng.next_f64(0.1);
            let y = true_phi * y_prev + noise;
            model.train_one(y, &[]);
            y_prev = y;
        }
        let coeffs = model.coefficients();
        let learned_phi = coeffs.ar[0];
        // Loose bound: SGD should at least move the coefficient well
        // toward the true value without diverging.
        assert!(
            learned_phi > 0.3 && learned_phi.is_finite(),
            "AR(1) coefficient should converge toward {}, got {}",
            true_phi,
            learned_phi,
        );
    }

    #[test]
    fn predict_one_uses_lags() {
        let config = SNARIMAXConfig::builder()
            .p(2)
            .q(0)
            .learning_rate(0.01)
            .build()
            .expect("valid config");
        let mut model = SNARIMAX::new(config);
        // Train on a simple increasing ramp.
        for i in 0..100 {
            let y = (i as f64) * 0.5;
            model.train_one(y, &[]);
        }
        let pred = model.predict_one(&[]);
        assert!(
            pred.is_finite(),
            "prediction should be finite, got {}",
            pred
        );
        assert!(
            pred.abs() > 1e-6,
            "prediction should be non-zero after training, got {}",
            pred,
        );
        // predict_one must be a pure read: repeated calls agree exactly.
        let pred2 = model.predict_one(&[]);
        assert!(
            (pred - pred2).abs() < 1e-12,
            "consecutive predict_one calls should return the same value: {} vs {}",
            pred,
            pred2,
        );
    }

    #[test]
    fn forecast_multi_step() {
        let config = SNARIMAXConfig::builder()
            .p(2)
            .q(1)
            .learning_rate(0.001)
            .build()
            .expect("valid config");
        let mut model = SNARIMAX::new(config);
        for i in 0..200 {
            model.train_one(i as f64, &[]);
        }
        let horizon = 10;
        let forecast = model.forecast(horizon);
        assert_eq!(
            forecast.len(),
            horizon,
            "forecast should return exactly {} values, got {}",
            horizon,
            forecast.len(),
        );
        for (i, &val) in forecast.iter().enumerate() {
            assert!(
                val.is_finite(),
                "forecast[{}] should be finite, got {}",
                i,
                val,
            );
        }
        // Forecasting works on buffer clones and must not mutate state.
        let n_before = model.n_samples_seen();
        let _ = model.forecast(5);
        assert_eq!(
            model.n_samples_seen(),
            n_before,
            "forecast should not change n_samples_seen",
        );
    }

    #[test]
    fn seasonal_component_works() {
        let seasonal_pattern = [10.0, 20.0, 30.0, 40.0];
        let period = seasonal_pattern.len();
        let mut rng = Xorshift64::new(123);
        let config = SNARIMAXConfig::builder()
            .p(0)
            .q(0)
            .seasonal_period(period)
            .sp(1)
            .sq(0)
            .learning_rate(0.001)
            .build()
            .expect("valid config");
        let mut model = SNARIMAX::new(config);
        // Feed many full cycles of the noisy seasonal pattern.
        for _cycle in 0..2000 {
            for &sp_val in &seasonal_pattern {
                let noise = rng.next_f64(0.5);
                let y = sp_val + noise;
                model.train_one(y, &[]);
            }
        }
        let coeffs = model.coefficients();
        assert!(
            coeffs.seasonal_ar[0].abs() > 0.01,
            "seasonal AR coefficient should be non-zero, got {}",
            coeffs.seasonal_ar[0],
        );
    }

    #[test]
    fn exogenous_input() {
        let mut rng = Xorshift64::new(999);
        let config = SNARIMAXConfig::builder()
            .p(0)
            .q(0)
            .n_exogenous(1)
            .learning_rate(0.001)
            .build()
            .expect("valid config");
        let mut model = SNARIMAX::new(config);
        // Pure regression on one exogenous feature: y = 3x + noise.
        for _ in 0..5000 {
            let x = rng.next_f64(5.0);
            let noise = rng.next_f64(0.1);
            let y = 3.0 * x + noise;
            model.train_one(y, &[x]);
        }
        let coeffs = model.coefficients();
        assert!(
            coeffs.exogenous[0].abs() > 0.1,
            "exogenous coefficient should be non-zero, got {}",
            coeffs.exogenous[0],
        );
        assert!(
            (coeffs.exogenous[0] - 3.0).abs() < 1.0,
            "exogenous coefficient should converge toward 3.0, got {}",
            coeffs.exogenous[0],
        );
    }

    #[test]
    fn streaming_learner_trait() {
        let config = SNARIMAXConfig::builder()
            .p(1)
            .n_exogenous(2)
            .learning_rate(0.01)
            .build()
            .expect("valid config");
        let model = SNARIMAX::new(config);
        // Exercise the model exclusively through the trait object.
        let mut boxed: Box<dyn StreamingLearner> = Box::new(model);
        boxed.train(&[1.0, 2.0], 5.0);
        assert_eq!(boxed.n_samples_seen(), 1);
        boxed.train(&[3.0, 4.0], 10.0);
        assert_eq!(boxed.n_samples_seen(), 2);
        let pred = boxed.predict(&[1.0, 2.0]);
        assert!(
            pred.is_finite(),
            "trait prediction should be finite, got {}",
            pred
        );
        boxed.reset();
        assert_eq!(boxed.n_samples_seen(), 0);
    }

    #[test]
    fn reset_clears_state() {
        let config = SNARIMAXConfig::builder()
            .p(2)
            .q(1)
            .n_exogenous(1)
            .learning_rate(0.01)
            .build()
            .expect("valid config");
        let mut model = SNARIMAX::new(config);
        for i in 0..100 {
            model.train_one(i as f64, &[i as f64 * 0.5]);
        }
        assert_eq!(model.n_samples_seen(), 100);
        // Sanity check: training actually moved some coefficient.
        let coeffs_before = model.coefficients();
        let has_nonzero = coeffs_before.ar.iter().any(|c| c.abs() > 1e-15)
            || coeffs_before.intercept.abs() > 1e-15
            || coeffs_before.exogenous.iter().any(|c| c.abs() > 1e-15);
        assert!(
            has_nonzero,
            "coefficients should be non-zero after training"
        );
        model.reset();
        assert_eq!(
            model.n_samples_seen(),
            0,
            "n_samples should be 0 after reset"
        );
        let coeffs_after = model.coefficients();
        assert!(
            coeffs_after.intercept.abs() < 1e-12,
            "intercept should be zero after reset, got {}",
            coeffs_after.intercept,
        );
        for (i, c) in coeffs_after.ar.iter().enumerate() {
            assert!(
                c.abs() < 1e-12,
                "ar[{}] should be zero after reset, got {}",
                i,
                c,
            );
        }
        for (j, c) in coeffs_after.ma.iter().enumerate() {
            assert!(
                c.abs() < 1e-12,
                "ma[{}] should be zero after reset, got {}",
                j,
                c,
            );
        }
        for (m, c) in coeffs_after.exogenous.iter().enumerate() {
            assert!(
                c.abs() < 1e-12,
                "exo[{}] should be zero after reset, got {}",
                m,
                c,
            );
        }
        // With all state zeroed, the prediction must be exactly zero.
        let pred = model.predict_one(&[0.0]);
        assert!(
            pred.abs() < 1e-12,
            "prediction after reset should be zero, got {}",
            pred,
        );
    }

    #[test]
    fn predict_reads_current_input() {
        let config = SNARIMAXConfig::builder()
            .p(2)
            .q(1)
            .learning_rate(0.01)
            .build()
            .unwrap();
        let mut model = SNARIMAX::new(config);
        for i in 0..50 {
            model.train_one(i as f64 * 0.5, &[]);
        }
        let pred_before = model.predict_one(&[]);
        // An off-trend observation must shift the buffers/coefficients
        // and therefore the next prediction.
        model.train_one(30.0, &[]);
        let pred_after = model.predict_one(&[]);
        assert!(
            pred_before.is_finite(),
            "pre-train predict should be finite, got {pred_before}"
        );
        assert!(
            pred_after.is_finite(),
            "post-train predict should be finite, got {pred_after}"
        );
        assert_ne!(
            pred_before.to_bits(),
            pred_after.to_bits(),
            "SNARIMAX predict must reflect current buffer state: {pred_before} == {pred_after}"
        );
    }
}