use scirs2_core::ndarray::Array1;
use crate::error::{Result, TimeSeriesError};
/// Upper bound applied to the requested AR and MA orders (`BatsConfig::ar_order` /
/// `BatsConfig::ma_order`) when fitting.
const MAX_ARMA_ORDER: usize = 3;
/// Number of refinement iterations used when estimating MA coefficients.
const ARMA_ITER: usize = 100;
/// Configuration for fitting a [`BatsModel`].
///
/// Every `None` option is decided automatically during fitting; the
/// default configuration leaves everything automatic and uses no seasonality.
#[derive(Debug, Clone, Default)]
pub struct BatsConfig {
    /// Box-Cox transform: `Some(true)` requests it, `Some(false)` disables it,
    /// and `None` applies it only to strictly positive data whose estimated
    /// lambda differs from 1 by more than 0.1.
    pub use_box_cox: Option<bool>,
    /// Include a trend component; `None` decides from the OLS slope of the series.
    pub use_trend: Option<bool>,
    /// Trend damping: `Some(true)` uses phi = 0.98, `Some(false)` uses phi = 1.0,
    /// `None` uses phi = 0.99. Ignored when no trend is used.
    pub use_damped_trend: Option<bool>,
    /// Seasonal periods (each must be >= 2). Empty means no seasonal component.
    pub seasonal_periods: Vec<usize>,
    /// AR order of the error model, capped at `MAX_ARMA_ORDER`; `None` means 0.
    pub ar_order: Option<usize>,
    /// MA order of the error model, capped at `MAX_ARMA_ORDER`; `None` means 0.
    pub ma_order: Option<usize>,
}
/// A single seasonal component kept as a ring of `period` seasonal indices.
///
/// `pos` selects the slot belonging to the current time step; `gamma` is the
/// weight with which a one-step error is folded into that slot.
#[derive(Debug, Clone)]
struct SeasonalState {
    period: usize,
    states: Vec<f64>,
    pos: usize,
    gamma: f64,
}

impl SeasonalState {
    /// Builds a ring starting at slot 0; `initial_values` must hold `period` entries.
    fn new(period: usize, initial_values: Vec<f64>, gamma: f64) -> Self {
        debug_assert_eq!(initial_values.len(), period);
        Self {
            states: initial_values,
            gamma,
            period,
            pos: 0,
        }
    }

    /// Seasonal index for the current time step.
    fn contribution(&self) -> f64 {
        let current = self.pos;
        self.states[current]
    }

    /// Nudges the current slot by `gamma * error`, then moves to the next slot.
    fn update(&mut self, error: f64) {
        let slot = &mut self.states[self.pos];
        *slot += self.gamma * error;
        // `pos` always stays below `period`, so wrapping on equality matches `% period`.
        let next = self.pos + 1;
        self.pos = if next == self.period { 0 } else { next };
    }

    /// Seasonal index `h` steps ahead (`h = 1` is the current slot).
    fn forecast_ahead(&self, h: usize) -> f64 {
        let slot = (self.pos + h - 1) % self.period;
        self.states[slot]
    }
}
/// Rolling state of the ARMA(p, q) error component.
///
/// `d_buf` holds past values of the ARMA component and `eps_buf` past one-step
/// errors. Both are ring buffers whose next write slot is `d_pos` / `eps_pos`,
/// so the slot just before it holds the most recent entry.
#[derive(Debug, Clone)]
struct ArmaState {
    /// AR coefficients; `ar[i]` multiplies the component value `i + 1` steps back.
    ar: Vec<f64>,
    /// MA coefficients; `ma[j]` multiplies the error `j + 1` steps back.
    ma: Vec<f64>,
    d_buf: Vec<f64>,
    eps_buf: Vec<f64>,
    d_pos: usize,
    eps_pos: usize,
}

impl ArmaState {
    /// Creates a zeroed state. Both buffers get at least one slot so the ring
    /// arithmetic never divides by zero, even for a pure AR or pure MA model.
    fn new(ar: Vec<f64>, ma: Vec<f64>) -> Self {
        let p = ar.len().max(1);
        let q = ma.len().max(1);
        Self {
            ar,
            ma,
            d_buf: vec![0.0; p],
            eps_buf: vec![0.0; q],
            d_pos: 0,
            eps_pos: 0,
        }
    }

    /// One-step-ahead ARMA contribution computed from the buffered history.
    fn contribution(&self) -> f64 {
        let buf_p = self.d_buf.len();
        let buf_q = self.eps_buf.len();
        let mut d = 0.0_f64;
        for (i, &coef) in self.ar.iter().enumerate() {
            let idx = (self.d_pos + buf_p - 1 - i) % buf_p;
            d += coef * self.d_buf[idx];
        }
        for (j, &coef) in self.ma.iter().enumerate() {
            let idx = (self.eps_pos + buf_q - 1 - j) % buf_q;
            d += coef * self.eps_buf[idx];
        }
        d
    }

    /// Records the latest error `eps` and ARMA component value `d_current`.
    fn push(&mut self, eps: f64, d_current: f64) {
        // `new` guarantees both buffers are non-empty.
        self.d_buf[self.d_pos] = d_current;
        self.d_pos = (self.d_pos + 1) % self.d_buf.len();
        self.eps_buf[self.eps_pos] = eps;
        self.eps_pos = (self.eps_pos + 1) % self.eps_buf.len();
    }

    /// Forecasts the ARMA component for steps `1..=h`.
    ///
    /// Future errors are taken as zero. Past errors still contribute through
    /// the MA terms for as long as they are within `q` lags, and forecasts feed
    /// back into the AR recursion. Step 1 equals [`ArmaState::contribution`].
    fn forecast_ahead(&self, h: usize) -> Vec<f64> {
        let buf_p = self.d_buf.len();
        let buf_q = self.eps_buf.len();
        // Right-align both histories so the newest value of each sits at
        // `offset - 1`. Otherwise AR and MA lags would read mismatched slots
        // whenever p != q.
        let offset = buf_p.max(buf_q);
        let mut d_hist = vec![0.0_f64; offset + h];
        for i in 0..buf_p {
            d_hist[offset - buf_p + i] = self.d_buf[(self.d_pos + i) % buf_p];
        }
        // Errors beyond `offset - 1` stay zero: their expectation is zero.
        let mut eps_hist = vec![0.0_f64; offset + h];
        for i in 0..buf_q {
            eps_hist[offset - buf_q + i] = self.eps_buf[(self.eps_pos + i) % buf_q];
        }

        let mut result = Vec::with_capacity(h);
        for step in 0..h {
            let idx = offset + step;
            // `idx >= offset >= max(p, q)`, so every lag index below is in range.
            let ar_part: f64 = self
                .ar
                .iter()
                .enumerate()
                .map(|(i, &coef)| coef * d_hist[idx - 1 - i])
                .sum();
            let ma_part: f64 = self
                .ma
                .iter()
                .enumerate()
                .map(|(j, &coef)| coef * eps_hist[idx - 1 - j])
                .sum();
            let val = ar_part + ma_part;
            d_hist[idx] = val;
            result.push(val);
        }
        result
    }
}
/// A fitted BATS model: fixed smoothing parameters plus the final filter
/// states needed to forecast from the end of the training series.
///
/// Created by `BatsModel::fit`; states are kept in the (possibly Box-Cox)
/// transformed scale.
#[derive(Debug, Clone)]
pub struct BatsModel {
// Box-Cox parameter; `None` when the series is modelled untransformed.
lambda: Option<f64>,
// Level smoothing weight.
alpha: f64,
// Trend smoothing weight (0.0 when no trend is used).
beta: f64,
// Trend damping factor (1.0 means undamped or no trend).
phi: f64,
use_trend: bool,
// Seasonal rings as left after the final filtering pass.
seasonal_states: Vec<SeasonalState>,
// ARMA error state as left after the final filtering pass.
arma: ArmaState,
// Final level (transformed scale).
level: f64,
// Final trend (transformed scale).
trend_state: f64,
// One-step fitted values, back-transformed to the original scale.
fitted_vals: Vec<f64>,
// One-step residuals in the transformed scale.
residuals: Vec<f64>,
// Root-mean-square residual (transformed scale), floored at 1e-12.
sigma: f64,
aic: f64,
// Number of observations used for fitting.
n_obs: usize,
}
impl BatsModel {
    /// Fits a BATS model (Box-Cox, ARMA errors, Trend, Seasonal) to `data`.
    ///
    /// Options left as `None` in `config` are chosen automatically:
    /// - The Box-Cox parameter comes from `determine_lambda`.
    /// - The trend is used when the OLS slope of the working series exceeds
    ///   `1e-3` times its largest magnitude.
    /// - The damping factor `phi` is 0.99 when a trend is used, or
    ///   0.98 / 1.0 when damping is forced on / off.
    ///
    /// The smoothing weights are fixed rather than optimised: `alpha = 0.15`,
    /// and `beta = 0.05` when a trend is used. ARMA orders are capped at
    /// `MAX_ARMA_ORDER`.
    ///
    /// # Errors
    ///
    /// * [`TimeSeriesError::InsufficientData`] when `data` has fewer than
    ///   `max(2 * largest_period, 10)` observations.
    /// * [`TimeSeriesError::InvalidParameter`] when a seasonal period is below 2.
    /// * Any error reported while choosing the Box-Cox parameter.
    pub fn fit(data: &[f64], config: BatsConfig) -> Result<Self> {
        let max_period = config.seasonal_periods.iter().copied().max().unwrap_or(0);
        let min_required = (max_period * 2).max(10);
        if data.len() < min_required {
            return Err(TimeSeriesError::InsufficientData {
                message: format!(
                    "BATS requires at least {} observations for the given configuration",
                    min_required
                ),
                required: min_required,
                actual: data.len(),
            });
        }
        if config.seasonal_periods.iter().any(|&p| p < 2) {
            return Err(TimeSeriesError::InvalidParameter {
                name: "seasonal_periods".to_string(),
                message: "All seasonal periods must be >= 2".to_string(),
            });
        }

        let n = data.len();
        let lambda = determine_lambda(data, &config)?;
        let working: Vec<f64> = match lambda {
            Some(lam) => data.iter().map(|&y| box_cox(y, lam)).collect(),
            None => data.to_vec(),
        };

        let use_trend = config.use_trend.unwrap_or_else(|| {
            let slope = ols_slope(&working);
            let max_abs = working.iter().fold(0.0_f64, |a, &b| a.abs().max(b.abs()));
            slope.abs() > 1e-3 * max_abs.max(1e-12)
        });
        let phi = if use_trend {
            match config.use_damped_trend {
                Some(true) => 0.98,
                Some(false) => 1.0,
                None => 0.99,
            }
        } else {
            1.0
        };
        let alpha = 0.15_f64;
        let beta = if use_trend { 0.05_f64 } else { 0.0 };

        // Unset orders mean "no AR / MA terms".
        let p_order = config.ar_order.unwrap_or(0).min(MAX_ARMA_ORDER);
        let q_order = config.ma_order.unwrap_or(0).min(MAX_ARMA_ORDER);

        let initial_seasonal = init_seasonal_states(&working, &config.seasonal_periods);
        let (seasonal_states, arma, level, trend_state, fitted_tf, residuals_tf) = forward_pass(
            &working,
            alpha,
            beta,
            phi,
            use_trend,
            initial_seasonal,
            p_order,
            q_order,
        );

        let n_f = n as f64;
        let resid_var = residuals_tf.iter().map(|&r| r * r).sum::<f64>() / n_f;
        let sigma = resid_var.sqrt().max(1e-12);

        let trend_params = if use_trend { 2 } else { 0 };
        let n_free_params = 1
            + trend_params
            + config.seasonal_periods.len()
            + p_order
            + q_order
            + usize::from(lambda.is_some());

        // Floor the variance (matching the floor on `sigma`) so a perfect fit yields a
        // finite AIC instead of -inf; a NaN variance is deliberately left to propagate.
        let lik_var = if resid_var < 1e-24 { 1e-24 } else { resid_var };
        let mut log_lik = -0.5 * n_f * (1.0 + (2.0 * std::f64::consts::PI * lik_var).ln());
        if let Some(lam) = lambda {
            // Box-Cox Jacobian: expresses the likelihood on the original scale so AIC
            // values stay comparable between transformed and untransformed fits.
            let log_y_sum: f64 = data.iter().map(|&y| y.max(1e-10).ln()).sum();
            log_lik += (lam - 1.0) * log_y_sum;
        }
        let aic = -2.0 * log_lik + 2.0 * n_free_params as f64;

        let fitted_vals: Vec<f64> = match lambda {
            Some(lam) => fitted_tf.iter().map(|&w| inv_box_cox(w, lam)).collect(),
            None => fitted_tf,
        };

        Ok(Self {
            lambda,
            alpha,
            beta,
            phi,
            use_trend,
            seasonal_states,
            arma,
            level,
            trend_state,
            fitted_vals,
            residuals: residuals_tf,
            sigma,
            aic,
            n_obs: n,
        })
    }

    /// Point forecasts for horizons `1..=h`, on the original data scale.
    ///
    /// At step `k`, the forecast is the final level, plus
    /// `(phi + phi^2 + ... + phi^k) * trend`, plus each seasonal index `k`
    /// steps ahead, plus the ARMA forecast. The total is back-transformed
    /// when Box-Cox is active. This implementation always returns `Ok`.
    pub fn forecast(&self, h: usize) -> Result<Array1<f64>> {
        let arma_fcast = self.arma.forecast_ahead(h);
        // Running sum phi + phi^2 + ... + phi^step for the damped trend.
        let mut damped_sum = 0.0_f64;
        let mut forecasts = Vec::with_capacity(h);
        for step in 1..=h {
            let trend_contrib = if self.use_trend {
                damped_sum = damped_sum * self.phi + self.phi;
                damped_sum * self.trend_state
            } else {
                0.0
            };
            let seas_contrib: f64 = self
                .seasonal_states
                .iter()
                .map(|sc| sc.forecast_ahead(step))
                .sum();
            let arma_contrib = arma_fcast.get(step - 1).copied().unwrap_or(0.0);
            let yhat_tf = self.level + trend_contrib + seas_contrib + arma_contrib;
            let yhat = match self.lambda {
                Some(lam) => inv_box_cox(yhat_tf, lam),
                None => yhat_tf,
            };
            forecasts.push(yhat);
        }
        Ok(Array1::from_vec(forecasts))
    }

    /// Point forecasts with `(1 - alpha)` prediction intervals.
    ///
    /// `alpha` here is the interval significance level (e.g. `0.05` for 95%),
    /// not the level smoothing weight. At step `k` the interval half-width is
    /// `z * sigma * sqrt(1 + k * alpha_s^2)`, where `alpha_s` is the model's
    /// level smoothing weight. The interval is built in the transformed scale
    /// and back-transformed when Box-Cox is active.
    ///
    /// Returns `(point, lower, upper)`.
    ///
    /// # Errors
    ///
    /// [`TimeSeriesError::InvalidParameter`] when `alpha` is not in `(0, 1)`.
    pub fn predict(&self, h: usize, alpha: f64) -> Result<(Array1<f64>, Array1<f64>, Array1<f64>)> {
        if !(0.0 < alpha && alpha < 1.0) {
            return Err(TimeSeriesError::InvalidParameter {
                name: "alpha".to_string(),
                message: "alpha must be in the open interval (0, 1)".to_string(),
            });
        }
        let point = self.forecast(h)?;
        let z = normal_quantile(1.0 - alpha / 2.0);
        let mut lower = Vec::with_capacity(h);
        let mut upper = Vec::with_capacity(h);
        for (k, &f) in point.iter().enumerate() {
            let h_var = self.sigma * self.sigma * (1.0 + (k + 1) as f64 * self.alpha * self.alpha);
            let std_h = h_var.sqrt();
            let (lo, hi) = match self.lambda {
                Some(lam) => {
                    let center_tf = box_cox(f.max(1e-10), lam);
                    (
                        inv_box_cox(center_tf - z * std_h, lam),
                        inv_box_cox(center_tf + z * std_h, lam),
                    )
                }
                None => (f - z * std_h, f + z * std_h),
            };
            lower.push(lo);
            upper.push(hi);
        }
        Ok((point, Array1::from_vec(lower), Array1::from_vec(upper)))
    }

    /// Akaike information criterion of the fit.
    pub fn aic(&self) -> f64 {
        self.aic
    }

    /// One-step fitted values on the original data scale.
    pub fn fitted_values(&self) -> &[f64] {
        &self.fitted_vals
    }

    /// Box-Cox parameter, or `None` when no transform was applied.
    pub fn lambda(&self) -> Option<f64> {
        self.lambda
    }

    /// Level smoothing weight.
    pub fn alpha(&self) -> f64 {
        self.alpha
    }

    /// Trend damping factor (1.0 when undamped or without trend).
    pub fn phi(&self) -> f64 {
        self.phi
    }

    /// Whether the model includes a trend component.
    pub fn use_trend(&self) -> bool {
        self.use_trend
    }

    /// Residual standard deviation in the transformed scale (floored at 1e-12).
    pub fn sigma(&self) -> f64 {
        self.sigma
    }

    /// Number of observations used for fitting.
    pub fn n_obs(&self) -> usize {
        self.n_obs
    }
}
/// Box-Cox transform of `y` with parameter `lambda`.
///
/// Inputs are floored at `1e-10` first so the result stays finite. A `lambda`
/// within `1e-10` of zero selects the logarithmic form.
fn box_cox(y: f64, lambda: f64) -> f64 {
    let clamped = y.max(1e-10);
    let near_zero = lambda.abs() < 1e-10;
    if near_zero {
        return clamped.ln();
    }
    (clamped.powf(lambda) - 1.0) / lambda
}
/// Inverse of the Box-Cox transform.
///
/// For `lambda` within `1e-10` of zero this is `exp(w)`. Otherwise it returns
/// `(lambda * w + 1)^(1 / lambda)`, which is only defined while
/// `lambda * w + 1 > 0`. Outside that region the limit of the inverse is returned:
/// - `0.0` for `lambda > 0`: `w` is at or below the image of `y -> 0`.
/// - `+inf` for `lambda < 0`: `w` is at or beyond the asymptote `-1 / lambda`,
///   which the transform only approaches as `y -> inf`.
fn inv_box_cox(w: f64, lambda: f64) -> f64 {
    if lambda.abs() < 1e-10 {
        return w.exp();
    }
    let base = lambda * w + 1.0;
    if base <= 0.0 {
        // Returning 0 for lambda < 0 would collapse huge values to zero and could
        // place an interval's upper bound below its lower bound.
        if lambda > 0.0 {
            0.0
        } else {
            f64::INFINITY
        }
    } else {
        base.powf(1.0 / lambda)
    }
}
fn determine_lambda(data: &[f64], config: &BatsConfig) -> Result<Option<f64>> {
match config.use_box_cox {
Some(false) => Ok(None),
Some(true) => Ok(Some(estimate_box_cox_lambda(data))),
None => {
if data.iter().all(|&v| v > 0.0) {
let lam = estimate_box_cox_lambda(data);
if (lam - 1.0).abs() > 0.1 {
Ok(Some(lam))
} else {
Ok(None)
}
} else {
Ok(None)
}
}
}
}
/// Grid-searches the Box-Cox parameter that maximises the profile log-likelihood.
///
/// Candidates are `-2.0, -1.9, ..., 2.0`. Each is scored by
/// `-n/2 * ln(var) + (lambda - 1) * sum(ln y)`. Candidates whose transformed
/// variance is not positive are skipped, and the first best score wins.
/// Returns `1.0` when `data` has a non-positive value or nothing scores.
fn estimate_box_cox_lambda(data: &[f64]) -> f64 {
    if data.iter().any(|&v| v <= 0.0) {
        return 1.0;
    }
    let count = data.len() as f64;
    // Jacobian term shared by every candidate.
    let log_sum: f64 = data.iter().map(|&y| y.max(1e-10).ln()).sum();

    let score = |lam: f64| -> Option<f64> {
        let w: Vec<f64> = data.iter().map(|&y| box_cox(y, lam)).collect();
        let mu = w.iter().sum::<f64>() / count;
        let var = w.iter().map(|&x| (x - mu).powi(2)).sum::<f64>() / count;
        if var <= 0.0 {
            None
        } else {
            Some(-0.5 * count * var.ln() + (lam - 1.0) * log_sum)
        }
    };

    (-20..=20)
        .map(|k| k as f64 * 0.1)
        .filter_map(|lam| score(lam).map(|ll| (lam, ll)))
        .fold((1.0_f64, f64::NEG_INFINITY), |(best_lam, best_ll), (lam, ll)| {
            if ll > best_ll {
                (lam, ll)
            } else {
                (best_lam, best_ll)
            }
        })
        .0
}
/// Builds one seasonal ring per period from phase-wise mean deviations.
///
/// For each period `m`, slot `k` starts as the average of `y_t - mean(y)` over
/// all `t` with `t % m == k`. The slots are then re-centred to sum to zero.
/// If `data` is shorter than one full period, the ring starts at all zeros.
/// Every ring uses the smoothing weight `0.001`.
fn init_seasonal_states(data: &[f64], periods: &[usize]) -> Vec<SeasonalState> {
    const GAMMA: f64 = 0.001;
    let len = data.len();
    let overall = data.iter().sum::<f64>() / len as f64;
    let mut out = Vec::with_capacity(periods.len());
    for &period in periods {
        let mut profile = vec![0.0_f64; period];
        if len / period > 0 {
            let mut tally = vec![0usize; period];
            for (t, &y) in data.iter().enumerate() {
                let phase = t % period;
                profile[phase] += y - overall;
                tally[phase] += 1;
            }
            for (slot, &cnt) in profile.iter_mut().zip(tally.iter()) {
                if cnt > 0 {
                    *slot /= cnt as f64;
                }
            }
            let centre = profile.iter().sum::<f64>() / period as f64;
            profile.iter_mut().for_each(|v| *v -= centre);
        }
        out.push(SeasonalState::new(period, profile, GAMMA));
    }
    out
}
#[allow(clippy::too_many_arguments)]
fn forward_pass(
working: &[f64],
alpha: f64,
beta: f64,
phi: f64,
use_trend: bool,
mut seasonal_states: Vec<SeasonalState>,
p_order: usize,
q_order: usize,
) -> (Vec<SeasonalState>, ArmaState, f64, f64, Vec<f64>, Vec<f64>) {
let n = working.len();
let mut level = working.iter().take(3.min(n)).sum::<f64>() / 3.0_f64.min(n as f64);
let mut trend_state = if use_trend && n >= 2 {
(working[1] - working[0]).abs() * 0.01 } else {
0.0
};
let ar_init = vec![0.0_f64; p_order];
let ma_init = vec![0.0_f64; q_order];
let mut arma = ArmaState::new(ar_init, ma_init);
let mut fitted_tf = Vec::with_capacity(n);
let mut residuals_tf = Vec::with_capacity(n);
let mut raw_errors = Vec::with_capacity(n);
for t in 0..n {
let trend_contrib = if use_trend { phi * trend_state } else { 0.0 };
let seas_contrib: f64 = seasonal_states.iter().map(|s| s.contribution()).sum();
let arma_contrib = arma.contribution();
let yhat = level + trend_contrib + seas_contrib + arma_contrib;
fitted_tf.push(yhat);
let error = working[t] - yhat;
raw_errors.push(error);
residuals_tf.push(error);
level += trend_contrib + alpha * error;
if use_trend {
trend_state = (1.0 - phi) * 0.0 + phi * trend_state + beta * error;
}
for sc in &mut seasonal_states {
sc.update(error);
}
let d_current = arma.contribution();
arma.push(error, d_current);
}
if p_order > 0 || q_order > 0 {
let ar_coeffs = if p_order > 0 {
yule_walker(&raw_errors, p_order)
} else {
Vec::new()
};
let ma_coeffs = if q_order > 0 {
estimate_ma_coeffs(&raw_errors, q_order)
} else {
Vec::new()
};
let mut arma2 = ArmaState::new(ar_coeffs, ma_coeffs);
let mut level2 = working.iter().take(3.min(n)).sum::<f64>() / 3.0_f64.min(n as f64);
let mut trend2 = if use_trend && n >= 2 {
(working[1] - working[0]).abs() * 0.01
} else {
0.0
};
let mut seas2 = seasonal_states.clone();
let mut fitted2 = Vec::with_capacity(n);
let mut resid2 = Vec::with_capacity(n);
for t in 0..n {
let tc = if use_trend { phi * trend2 } else { 0.0 };
let sc: f64 = seas2.iter().map(|s| s.contribution()).sum();
let ac = arma2.contribution();
let yhat2 = level2 + tc + sc + ac;
fitted2.push(yhat2);
let err = working[t] - yhat2;
resid2.push(err);
level2 += tc + alpha * err;
if use_trend {
trend2 = phi * trend2 + beta * err;
}
for sc_state in &mut seas2 {
sc_state.update(err);
}
let d2 = arma2.contribution();
arma2.push(err, d2);
}
return (seas2, arma2, level2, trend2, fitted2, resid2);
}
(seasonal_states, arma, level, trend_state, fitted_tf, residuals_tf)
}
/// Estimates AR(p) coefficients by solving the Yule-Walker equations.
///
/// Sample autocovariances of the mean-centred series are normalised by lag 0.
/// The resulting Toeplitz system is solved with `gaussian_elimination`.
/// Returns `p` zeros when `p == 0`, there are fewer than `p + 1` points, the
/// variance is numerically zero, or the system is singular.
fn yule_walker(data: &[f64], p: usize) -> Vec<f64> {
    let len = data.len();
    if p == 0 || len < p + 1 {
        return vec![0.0; p];
    }
    let avg = data.iter().sum::<f64>() / len as f64;
    let dev: Vec<f64> = data.iter().map(|&x| x - avg).collect();
    let acov: Vec<f64> = (0..=p)
        .map(|k| (k..len).fold(0.0_f64, |acc, t| acc + dev[t] * dev[t - k]) / len as f64)
        .collect();
    if acov[0].abs() < 1e-14 {
        return vec![0.0; p];
    }
    // Augmented rows: [rho_|i-j| for j in 0..p] followed by rho_{i+1}.
    let mut system: Vec<Vec<f64>> = (0..p)
        .map(|row| {
            let mut line: Vec<f64> = (0..p).map(|col| acov[row.abs_diff(col)] / acov[0]).collect();
            line.push(acov[row + 1] / acov[0]);
            line
        })
        .collect();
    gaussian_elimination(&mut system).unwrap_or_else(|_| vec![0.0; p])
}
/// Estimates MA(q) coefficients from a residual series by moment matching.
///
/// The targets are the sample autocorrelations `rho_k = gamma_k / gamma_0`,
/// `k = 1..=q`. Starting from `theta_k = rho_k`, each of `ARMA_ITER`
/// iterations moves every `theta_k` by `-0.1 * (model_rho_k - rho_k)`.
/// Here `model_rho_k` is the theoretical lag-k autocorrelation of an MA(q)
/// process with the current coefficients. Coefficients are kept in `[-0.99, 0.99]`.
///
/// Returns `q` zeros when `q == 0`, there are fewer than `q + 1` residuals, or
/// the residual variance is numerically zero.
fn estimate_ma_coeffs(residuals: &[f64], q: usize) -> Vec<f64> {
    let n = residuals.len();
    if n < q + 1 || q == 0 {
        return vec![0.0; q];
    }
    let mean = residuals.iter().sum::<f64>() / n as f64;
    let centered: Vec<f64> = residuals.iter().map(|&v| v - mean).collect();
    let gamma: Vec<f64> = (0..=q)
        .map(|lag| (lag..n).fold(0.0_f64, |acc, t| acc + centered[t] * centered[t - lag]) / n as f64)
        .collect();
    if gamma[0].abs() < 1e-14 {
        return vec![0.0; q];
    }
    let target: Vec<f64> = (1..=q).map(|k| gamma[k] / gamma[0]).collect();
    let mut theta: Vec<f64> = target.iter().map(|&r| r.clamp(-0.99, 0.99)).collect();

    for _ in 0..ARMA_ITER {
        // Process variance in units of sigma^2, with theta_0 = 1.
        let variance = 1.0 + theta.iter().map(|t| t * t).sum::<f64>();
        theta = (1..=q)
            .map(|k| {
                // Lag-k autocovariance: sum_{j=0}^{q-k} theta_j * theta_{j+k}, theta_0 = 1.
                let cov = theta[k - 1]
                    + (1..=q - k)
                        .map(|j| theta[j - 1] * theta[j + k - 1])
                        .sum::<f64>();
                // Normalise to an autocorrelation so it is comparable with the target.
                let model_acf = cov / variance;
                (theta[k - 1] - 0.1 * (model_acf - target[k - 1])).clamp(-0.99, 0.99)
            })
            .collect();
    }
    theta
}
/// Solves a linear system given as an augmented matrix, using Gauss-Jordan
/// elimination with partial pivoting.
///
/// Each of the `m` rows holds the coefficients followed by the right-hand side
/// as its last element. The matrix is reduced in place. On success the returned
/// vector holds the last element of every row, which is the solution.
///
/// # Errors
///
/// [`TimeSeriesError::NumericalInstability`] when a pivot's magnitude falls
/// below `1e-14`.
fn gaussian_elimination(mat: &mut [Vec<f64>]) -> Result<Vec<f64>> {
    let m = mat.len();
    for col in 0..m {
        // Partial pivoting: move the row with the largest |entry| in this column up.
        let mut max_row = col;
        let mut max_val = mat[col][col].abs();
        for (row, line) in mat.iter().enumerate().skip(col + 1) {
            let v = line[col].abs();
            if v > max_val {
                max_val = v;
                max_row = row;
            }
        }
        mat.swap(col, max_row);

        let pivot = mat[col][col];
        if pivot.abs() < 1e-14 {
            return Err(TimeSeriesError::NumericalInstability(
                "near-singular matrix in Yule-Walker solve".to_string(),
            ));
        }
        let pivot_inv = 1.0 / pivot;
        for v in &mut mat[col][col..] {
            *v *= pivot_inv;
        }

        // Borrow the pivot row separately from all other rows to eliminate this column.
        let (before, rest) = mat.split_at_mut(col);
        let (pivot_row, after) = rest
            .split_first_mut()
            .expect("col < m, so the pivot row exists");
        for row in before.iter_mut().chain(after.iter_mut()) {
            let factor = row[col];
            for (x, &pv) in row[col..].iter_mut().zip(&pivot_row[col..]) {
                *x -= factor * pv;
            }
        }
    }
    Ok(mat.iter().map(|row| row.last().copied().unwrap_or(0.0)).collect())
}
/// Least-squares slope of `data` against the time index `1, 2, ..., n`.
///
/// Returns `0.0` for fewer than two points or a degenerate time spread.
fn ols_slope(data: &[f64]) -> f64 {
    let count = data.len() as f64;
    if count < 2.0 {
        return 0.0;
    }
    let mid_t = (count + 1.0) / 2.0;
    let avg_y = data.iter().sum::<f64>() / count;
    let (cross, spread) = data
        .iter()
        .enumerate()
        .fold((0.0_f64, 0.0_f64), |(cross, spread), (idx, &y)| {
            let dt = (idx + 1) as f64 - mid_t;
            (cross + dt * (y - avg_y), spread + dt.powi(2))
        });
    match spread.abs() < 1e-14 {
        true => 0.0,
        false => cross / spread,
    }
}
/// Standard normal quantile (inverse CDF) via Acklam's rational approximation.
///
/// Returns `-inf` for `p <= 0`, `+inf` for `p >= 1` and exactly `0.0` at the
/// median. Otherwise a tail formula is used below `0.02425` and above `0.97575`,
/// and a central formula in between.
fn normal_quantile(p: f64) -> f64 {
    if p <= 0.0 {
        return f64::NEG_INFINITY;
    }
    if p >= 1.0 {
        return f64::INFINITY;
    }
    if (p - 0.5).abs() < 1e-15 {
        return 0.0;
    }

    // Coefficients are highest-degree first; denominators carry their trailing 1.0.
    const A: [f64; 6] = [
        -3.969683028665376e+01,
        2.209460984245205e+02,
        -2.759285104469687e+02,
        1.383577518672690e+02,
        -3.066479806614716e+01,
        2.506628277459239e+00,
    ];
    const B: [f64; 6] = [
        -5.447609879822406e+01,
        1.615858368580409e+02,
        -1.556989798598866e+02,
        6.680131188771972e+01,
        -1.328068155288572e+01,
        1.0,
    ];
    const C: [f64; 6] = [
        -7.784894002430293e-03,
        -3.223964580411365e-01,
        -2.400758277161838e+00,
        -2.549732539343734e+00,
        4.374664141464968e+00,
        2.938163982698783e+00,
    ];
    const D: [f64; 5] = [
        7.784695709041462e-03,
        3.224671290700398e-01,
        2.445134137142996e+00,
        3.754408661907416e+00,
        1.0,
    ];
    const P_LOW: f64 = 0.02425;
    let p_high = 1.0 - P_LOW;

    // Horner evaluation: ((c0 * x + c1) * x + c2) ...
    let horner = |coeffs: &[f64], x: f64| coeffs.iter().fold(0.0_f64, |acc, &c| acc * x + c);

    if p < P_LOW {
        let tail = (-2.0 * p.ln()).sqrt();
        horner(&C, tail) / horner(&D, tail)
    } else if p <= p_high {
        let centred = p - 0.5;
        let sq = centred * centred;
        horner(&A, sq) * centred / horner(&B, sq)
    } else {
        let tail = (-2.0 * (1.0 - p).ln()).sqrt();
        -horner(&C, tail) / horner(&D, tail)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `n_cycles` repetitions of a sine wave of length `period`, drifting up 0.3 per cycle.
    fn make_seasonal(n_cycles: usize, period: usize) -> Vec<f64> {
        (0..n_cycles)
            .flat_map(|cycle| {
                (0..period).map(move |i| {
                    let angle = 2.0 * std::f64::consts::PI * i as f64 / period as f64;
                    10.0 + cycle as f64 * 0.3 + 3.0 * angle.sin()
                })
            })
            .collect()
    }

    /// Straight line starting at 1.0 with slope 0.05.
    fn make_trend_data(n: usize) -> Vec<f64> {
        (0..n).map(|i| 1.0 + 0.05 * i as f64).collect()
    }

    /// Thirty strictly positive points of `exp(0.1 * i) + 1`.
    fn make_exponential_data() -> Vec<f64> {
        (1..=30).map(|i| (i as f64 * 0.1).exp() + 1.0).collect()
    }

    /// No Box-Cox, weekly seasonality, everything else automatic.
    fn weekly_config() -> BatsConfig {
        BatsConfig {
            use_box_cox: Some(false),
            seasonal_periods: vec![7],
            ..Default::default()
        }
    }

    /// No Box-Cox, forced trend, no seasonality.
    fn trend_config() -> BatsConfig {
        BatsConfig {
            use_box_cox: Some(false),
            use_trend: Some(true),
            ..Default::default()
        }
    }

    /// Fits `weekly_config` to `cycles` weeks of seasonal data.
    fn fit_weekly(cycles: usize) -> BatsModel {
        BatsModel::fit(&make_seasonal(cycles, 7), weekly_config()).expect("fit")
    }

    #[test]
    fn test_bats_fit_no_seasonality() {
        let data = make_trend_data(20);
        let model = BatsModel::fit(&data, trend_config()).expect("fit should succeed");
        assert_eq!(model.fitted_values().len(), data.len());
        assert!(model.use_trend());
    }

    #[test]
    fn test_bats_fit_single_season() {
        let data = make_seasonal(4, 7);
        let config = BatsConfig {
            use_trend: Some(false),
            ..weekly_config()
        };
        let model = BatsModel::fit(&data, config).expect("fit should succeed");
        assert_eq!(model.fitted_values().len(), data.len());
    }

    #[test]
    fn test_bats_fit_multiple_seasons() {
        let data: Vec<f64> = (0..60)
            .map(|i| {
                let w = 2.0 * std::f64::consts::PI * i as f64;
                5.0 + 2.0 * (w / 7.0).sin() + 1.0 * (w / 30.0).sin()
            })
            .collect();
        let config = BatsConfig {
            use_box_cox: Some(false),
            seasonal_periods: vec![7, 30],
            ..Default::default()
        };
        let model = BatsModel::fit(&data, config).expect("fit should succeed");
        assert_eq!(model.fitted_values().len(), data.len());
        assert!(model.aic().is_finite());
    }

    #[test]
    fn test_bats_forecast_length() {
        let fc = fit_weekly(5).forecast(14).expect("forecast");
        assert_eq!(fc.len(), 14);
    }

    #[test]
    fn test_bats_forecast_finite() {
        let fc = fit_weekly(5).forecast(21).expect("forecast");
        for (i, &f) in fc.iter().enumerate() {
            assert!(f.is_finite(), "forecast[{i}] is not finite: {f}");
        }
    }

    #[test]
    fn test_bats_predict_ci_ordering() {
        let (fc, lower, upper) = fit_weekly(5).predict(14, 0.05).expect("predict");
        assert_eq!(fc.len(), 14);
        for (i, (&lo, &hi)) in lower.iter().zip(upper.iter()).enumerate() {
            assert!(
                lo <= hi,
                "lower must be <= upper at step {i}: {lower_val} > {upper_val}",
                lower_val = lo,
                upper_val = hi
            );
        }
    }

    #[test]
    fn test_bats_predict_invalid_alpha() {
        let model = fit_weekly(4);
        assert!(model.predict(5, 0.0).is_err());
        assert!(model.predict(5, 1.0).is_err());
    }

    #[test]
    fn test_bats_box_cox_roundtrip() {
        for &v in &[1.0_f64, 2.5, 10.0, 100.0] {
            for &lam in &[0.0_f64, 0.5, 1.0, -0.5, 2.0] {
                let recovered = inv_box_cox(box_cox(v, lam), lam);
                assert!(
                    (recovered - v).abs() < 1e-8,
                    "roundtrip failed: v={v}, λ={lam}, recovered={recovered}"
                );
            }
        }
    }

    #[test]
    fn test_bats_auto_box_cox() {
        let config = BatsConfig {
            use_box_cox: None,
            ..Default::default()
        };
        let model = BatsModel::fit(&make_exponential_data(), config).expect("fit");
        let fc = model.forecast(5).expect("forecast");
        for &f in fc.iter() {
            assert!(f.is_finite());
            assert!(f > 0.0, "exponential forecast should be positive");
        }
    }

    #[test]
    fn test_bats_arma_errors() {
        let data = make_seasonal(5, 7);
        let config = BatsConfig {
            ar_order: Some(1),
            ma_order: Some(1),
            ..weekly_config()
        };
        let model = BatsModel::fit(&data, config).expect("fit");
        let fc = model.forecast(7).expect("forecast");
        assert_eq!(fc.len(), 7);
        assert!(fc.iter().all(|f| f.is_finite()));
    }

    #[test]
    fn test_bats_insufficient_data() {
        let config = BatsConfig {
            seasonal_periods: vec![12],
            ..Default::default()
        };
        assert!(BatsModel::fit(&[1.0, 2.0, 3.0], config).is_err());
    }

    #[test]
    fn test_bats_invalid_period() {
        let data: Vec<f64> = (0..20).map(|i| i as f64).collect();
        let config = BatsConfig {
            seasonal_periods: vec![1],
            ..Default::default()
        };
        assert!(BatsModel::fit(&data, config).is_err());
    }

    #[test]
    fn test_bats_aic_finite() {
        assert!(fit_weekly(4).aic().is_finite(), "AIC must be finite");
    }

    #[test]
    fn test_bats_accessors() {
        let data = make_trend_data(20);
        let model = BatsModel::fit(&data, trend_config()).expect("fit");
        assert!(model.alpha() > 0.0 && model.alpha() < 1.0);
        assert!(model.phi() > 0.0 && model.phi() <= 1.0);
        assert_eq!(model.n_obs(), data.len());
        assert_eq!(model.lambda(), None);
        assert!(model.sigma() > 0.0);
    }
}