use scirs2_core::ndarray::ArrayStatCompat;
use scirs2_core::ndarray::{Array1, Array2, ScalarOperand};
use scirs2_core::numeric::{Float, FromPrimitive, NumCast};
use scirs2_linalg::solve;
use std::fmt::Debug;
use super::common::box_cox_transform;
use crate::error::{Result, TimeSeriesError};
use statrs::statistics::Statistics;
/// Configuration for [`tbats_decomposition`].
#[derive(Debug, Clone)]
pub struct TBATSOptions {
/// Apply a Box-Cox transform to the series before decomposing it.
pub use_box_cox: bool,
/// Fixed Box-Cox λ. When `None` and `use_box_cox` is set, λ is estimated from the data;
/// estimation is only permitted when at least one seasonal period is given.
pub box_cox_lambda: Option<f64>,
/// Include a trend (slope) state in the model.
pub use_trend: bool,
/// Request a damped trend; the estimated parameters then carry a damping factor `phi`.
pub use_damped_trend: bool,
/// Seasonal periods, in observations. Each must be greater than 1; non-integer periods are allowed.
pub seasonal_periods: Vec<f64>,
/// Fourier harmonics per seasonal period (same length as `seasonal_periods`, each >= 1).
/// When `None`, `min(floor(period / 2), 3)` (but at least 1) is used for each period.
pub fourier_terms: Option<Vec<usize>>,
/// Number of AR coefficients (currently initialised to zero rather than estimated).
pub ar_order: usize,
/// Number of MA coefficients (currently initialised to zero rather than estimated).
pub ma_order: usize,
/// NOTE(review): not read anywhere in the visible code of this module — confirm intended use.
pub auto_arma: bool,
/// NOTE(review): not read anywhere in the visible code of this module — confirm intended use.
pub use_parallel: bool,
/// NOTE(review): not read anywhere in the visible code of this module — confirm intended use.
pub max_iterations: usize,
/// NOTE(review): not read anywhere in the visible code of this module — confirm intended use.
pub tolerance: f64,
}
impl Default for TBATSOptions {
    /// Baseline configuration: undamped trend enabled, no Box-Cox transform, no seasonal
    /// periods, automatic Fourier-term selection, zero AR/MA orders, `max_iterations = 100`
    /// and `tolerance = 1e-6`.
    fn default() -> Self {
        let max_iterations = 100;
        let tolerance = 1e-6;
        Self {
            seasonal_periods: Vec::new(),
            fourier_terms: None,
            use_trend: true,
            use_damped_trend: false,
            use_box_cox: false,
            box_cox_lambda: None,
            ar_order: 0,
            ma_order: 0,
            auto_arma: false,
            use_parallel: false,
            max_iterations,
            tolerance,
        }
    }
}
/// Output of [`tbats_decomposition`]. Component series are on the scale the model was
/// fitted on, i.e. the Box-Cox transformed scale when `use_box_cox` was set.
#[derive(Debug, Clone)]
pub struct TBATSResult<F> {
/// Trend state over time; all zeros when `use_trend` is false.
pub trend: Array1<F>,
/// One reconstructed Fourier seasonal series per seasonal period.
pub seasonal_components: Vec<Array1<F>>,
/// Observation minus level, trend and seasonal components at each time step.
pub residuals: Array1<F>,
/// Smoothed level over time.
pub level: Array1<F>,
/// Copy of the input series.
pub original: Array1<F>,
/// The Box-Cox transformed series; `Some` only when `use_box_cox` was set.
pub transformed: Option<Array1<F>>,
/// Model parameters used for the fit.
pub parameters: TBATSParameters,
/// Gaussian log-likelihood of the residuals.
pub log_likelihood: f64,
}
/// Parameters of a fitted TBATS model.
#[derive(Debug, Clone)]
pub struct TBATSParameters {
/// Box-Cox λ, if one is recorded.
pub lambda: Option<f64>,
/// Level smoothing weight.
pub alpha: f64,
/// Trend smoothing parameter; `Some` only when a trend is modelled.
pub beta: Option<f64>,
/// Trend damping factor; `Some` only for a damped trend.
pub phi: Option<f64>,
/// One smoothing parameter per seasonal period; `Some` when at least one period is modelled.
pub gamma: Option<Vec<f64>>,
/// For each seasonal period, one `(a, b)` pair per harmonic `j = 1..=k`; the seasonal value
/// at time `t` is the sum of `a * sin(2πjt / period) + b * cos(2πjt / period)`.
pub fourier_coefficients: Vec<Vec<(f64, f64)>>,
/// AR coefficients (length `ar_order`).
pub ar_coefficients: Vec<f64>,
/// MA coefficients (length `ma_order`).
pub ma_coefficients: Vec<f64>,
/// Estimated residual variance σ².
pub sigma_squared: f64,
}
/// Decompose a time series with a simplified TBATS-style model.
///
/// The series is optionally Box-Cox transformed. Seasonal patterns are modelled with
/// Fourier harmonics (one set per entry of `options.seasonal_periods`), and the result holds
/// the level, trend, seasonal components and residuals on the transformed scale.
///
/// `result.parameters.lambda` is the Box-Cox λ that was actually applied. It is `Some` only
/// when `use_box_cox` is set, including when λ was estimated from the data.
///
/// # Errors
///
/// Returns `TimeSeriesError::DecompositionError` when:
/// - the series has fewer than 3 points;
/// - Box-Cox is requested without a λ and without seasonal periods;
/// - a seasonal period is not greater than 1 (NaN included);
/// - `fourier_terms` has the wrong length or contains a zero.
///
/// Errors from the Box-Cox transform and from the fitting steps are propagated.
#[allow(dead_code)]
pub fn tbats_decomposition<F>(ts: &Array1<F>, options: &TBATSOptions) -> Result<TBATSResult<F>>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum + ScalarOperand + NumCast,
{
    if ts.len() < 3 {
        return Err(TimeSeriesError::DecompositionError(
            "Time series must have at least 3 points for TBATS decomposition".to_string(),
        ));
    }
    if options.seasonal_periods.is_empty()
        && options.use_box_cox
        && options.box_cox_lambda.is_none()
    {
        return Err(TimeSeriesError::DecompositionError(
            "When using Box-Cox transformation with no seasonal periods, lambda must be specified"
                .to_string(),
        ));
    }
    // NaN compares false with everything, so reject it explicitly: it would otherwise
    // propagate NaN frequencies through every Fourier term.
    if options
        .seasonal_periods
        .iter()
        .any(|p| p.is_nan() || *p <= 1.0)
    {
        return Err(TimeSeriesError::DecompositionError(
            "Seasonal periods must be greater than 1".to_string(),
        ));
    }
    if let Some(ref terms) = options.fourier_terms {
        if terms.len() != options.seasonal_periods.len() {
            return Err(TimeSeriesError::DecompositionError(
                "Number of Fourier terms must match number of seasonal periods".to_string(),
            ));
        }
        if terms.contains(&0) {
            return Err(TimeSeriesError::DecompositionError(
                "Number of Fourier terms must be at least 1 for each seasonal period"
                    .to_string(),
            ));
        }
    }

    let (transformed_ts, lambda) = if options.use_box_cox {
        let lambda = options
            .box_cox_lambda
            .unwrap_or_else(|| estimate_box_cox_lambda(ts));
        let transformed = box_cox_transform(ts, lambda)?;
        (transformed, Some(lambda))
    } else {
        (ts.clone(), None)
    };

    // Default: up to three harmonics per period, never fewer than one.
    let fourier_terms: Vec<usize> = match &options.fourier_terms {
        Some(terms) => terms.clone(),
        None => options
            .seasonal_periods
            .iter()
            .map(|&p| std::cmp::min((p / 2.0).floor() as usize, 3).max(1))
            .collect(),
    };

    // NOTE(review): the initial state vector is built here but never passed to the later
    // stages — confirm whether it should seed the state-space pass.
    let state_size = calculate_state_size(options, &fourier_terms);
    let mut state = Array1::zeros(state_size);
    initialize_state(&mut state, &transformed_ts, options, &fourier_terms)?;

    let mut parameters = estimate_parameters(&transformed_ts, options, &fourier_terms)?;
    // `estimate_parameters` only sees `options.box_cox_lambda`, which is `None` when λ was
    // estimated above; record the λ that was actually applied to the series.
    parameters.lambda = lambda;

    let components =
        apply_state_space_model(&transformed_ts, &parameters, options, &fourier_terms)?;
    let (level, trend, seasonal_components, residuals, log_likelihood) = extract_components(
        &transformed_ts,
        &components,
        &parameters,
        options,
        &fourier_terms,
    )?;

    Ok(TBATSResult {
        trend,
        seasonal_components,
        residuals,
        level,
        original: ts.clone(),
        // `lambda` is `Some` exactly when the Box-Cox transform was applied.
        transformed: lambda.map(|_| transformed_ts),
        parameters,
        log_likelihood,
    })
}
/// Pick a Box-Cox λ from the series' coefficient of variation (population std / mean).
///
/// Returns `1.0` when the variance or the mean is not positive. Otherwise returns `0.0` (log
/// transform) when the coefficient of variation exceeds `0.3`, and `0.5` (square root)
/// otherwise.
#[allow(dead_code)]
fn estimate_box_cox_lambda<F>(ts: &Array1<F>) -> f64
where
    F: Float + FromPrimitive + Debug,
{
    let spread = ts.var(F::zero());
    let centre = ts.mean_or(F::zero());
    // Equivalent to the positive-variance-and-mean test; NaN also falls through to here.
    if !(spread > F::zero() && centre > F::zero()) {
        return 1.0;
    }
    let coefficient_of_variation = (spread.sqrt() / centre).to_f64().unwrap_or(1.0);
    if coefficient_of_variation > 0.3 {
        0.0
    } else {
        0.5
    }
}
/// Length of the state vector: one level slot, one trend slot when `options.use_trend` is
/// set, and a sin/cos pair (`2 * k` slots) for each seasonal period with `k` harmonics.
#[allow(dead_code)]
fn calculate_state_size(options: &TBATSOptions, fourierterms: &[usize]) -> usize {
    let trend_slots = usize::from(options.use_trend);
    let seasonal_slots: usize = fourierterms.iter().map(|&k| 2 * k).sum();
    1 + trend_slots + seasonal_slots
}
/// Fill `state` with starting values, using the same layout as `calculate_state_size`.
///
/// - Index 0 (level) is set to the first observation.
/// - When `options.use_trend` is set, the next slot holds the mean of up to the first four
///   one-step differences.
/// - Each seasonal period then gets `2 * k` slots, all set to `0.01 * (season index + 1)`.
///
/// Panics if `ts` is empty or `state` is shorter than the layout requires.
#[allow(dead_code)]
fn initialize_state<F>(
    state: &mut Array1<F>,
    ts: &Array1<F>,
    options: &TBATSOptions,
    fourier_terms: &[usize],
) -> Result<()>
where
    F: Float + FromPrimitive + Debug,
{
    let len = ts.len();
    let mut pos = 0;
    state[pos] = ts[0];
    pos += 1;

    if options.use_trend {
        if len > 1 {
            let window = (len - 1).min(4);
            let total_change = (0..window).fold(F::zero(), |acc, i| acc + (ts[i + 1] - ts[i]));
            state[pos] = total_change / F::from_usize(window).expect("Operation failed");
        }
        // The trend slot is reserved even when it cannot be estimated.
        pos += 1;
    }

    let harmonic_counts = options
        .seasonal_periods
        .iter()
        .zip(fourier_terms)
        .map(|(_, &k)| k);
    for (season, k) in harmonic_counts.enumerate() {
        for _ in 0..2 * k {
            state[pos] = F::from_f64(0.01 * (season as f64 + 1.0)).expect("Operation failed");
            pos += 1;
        }
    }
    Ok(())
}
/// Assemble the model parameters for `ts`.
///
/// The smoothing parameters are fixed constants:
/// - `alpha = 0.1`;
/// - `beta = 0.01` when a trend is used;
/// - `phi = 0.98` for a damped trend;
/// - `gamma = 0.001` per seasonal period.
///
/// The Fourier coefficients are fitted by least squares, and the residual variance is
/// computed from them. AR/MA coefficients are zero vectors of the configured orders, and
/// `lambda` is copied from `options.box_cox_lambda`.
///
/// Errors from the Fourier fit are propagated.
#[allow(dead_code)]
fn estimate_parameters<F>(
    ts: &Array1<F>,
    options: &TBATSOptions,
    fourier_terms: &[usize],
) -> Result<TBATSParameters>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum + ScalarOperand + NumCast,
{
    let fourier_coefficients = estimate_fourier_coefficients(ts, options, fourier_terms)?;
    let sigma_squared = estimate_residual_variance(ts, &fourier_coefficients, options);
    let season_count = options.seasonal_periods.len();

    Ok(TBATSParameters {
        lambda: options.box_cox_lambda,
        alpha: 0.1,
        beta: options.use_trend.then_some(0.01),
        phi: options.use_damped_trend.then_some(0.98),
        gamma: (season_count > 0).then(|| vec![0.001; season_count]),
        fourier_coefficients,
        ar_coefficients: vec![0.0; options.ar_order],
        ma_coefficients: vec![0.0; options.ma_order],
        sigma_squared,
    })
}
/// Fit Fourier coefficients for each seasonal period by ridge-regularised least squares.
///
/// Each period is fitted separately against `ts`. The design matrix has an intercept column
/// followed by `sin(2πjt / period)` and `cos(2πjt / period)` for harmonics `j = 1..=k`.
/// The intercept absorbs the series level, so the level does not leak into the harmonic
/// coefficients. This matters when the sample does not span a whole number of cycles,
/// where sin/cos columns are not orthogonal to a constant. The intercept estimate itself is
/// discarded.
///
/// Returns, per period, one `(sin coefficient, cos coefficient)` pair per harmonic.
///
/// # Errors
///
/// Propagates failures of the linear solve.
#[allow(dead_code)]
fn estimate_fourier_coefficients<F>(
    ts: &Array1<F>,
    options: &TBATSOptions,
    fourier_terms: &[usize],
) -> Result<Vec<Vec<(f64, f64)>>>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum + ScalarOperand + NumCast,
{
    let n_obs = ts.len();
    // Small ridge term keeps X'X invertible when a harmonic column is (near) zero,
    // e.g. sin(πt) for period 2.
    let ridge = F::from_f64(1e-6).expect("Failed to convert constant to float");
    let mut all_coefficients = Vec::with_capacity(options.seasonal_periods.len());

    for (&period, &k) in options.seasonal_periods.iter().zip(fourier_terms.iter()) {
        let n_cols = 2 * k + 1;
        // Angular frequency of each harmonic depends only on `j`, not on `t`.
        let freqs: Vec<F> = (0..k)
            .map(|j| {
                F::from_f64(2.0 * std::f64::consts::PI * (j + 1) as f64 / period)
                    .expect("Operation failed")
            })
            .collect();

        // Column layout: 0 = intercept, 1 + 2j = sin of harmonic j+1, 2 + 2j = cos.
        let mut design_matrix = Array2::<F>::zeros((n_obs, n_cols));
        for t in 0..n_obs {
            let t_f = F::from_usize(t).expect("Operation failed");
            design_matrix[[t, 0]] = F::one();
            for (j, &freq) in freqs.iter().enumerate() {
                design_matrix[[t, 1 + 2 * j]] = Float::sin(freq * t_f);
                design_matrix[[t, 2 + 2 * j]] = Float::cos(freq * t_f);
            }
        }

        let mut xtx = design_matrix.t().dot(&design_matrix);
        let xty = design_matrix.t().dot(ts);
        for i in 0..n_cols {
            xtx[[i, i]] = xtx[[i, i]] + ridge;
        }
        let coeffs = solve_regularized_least_squares(&xtx, &xty)?;

        let seasonal_coeffs: Vec<(f64, f64)> = (0..k)
            .map(|j| {
                (
                    coeffs[1 + 2 * j].to_f64().unwrap_or(0.0),
                    coeffs[2 + 2 * j].to_f64().unwrap_or(0.0),
                )
            })
            .collect();
        all_coefficients.push(seasonal_coeffs);
    }
    Ok(all_coefficients)
}
/// Population variance of `ts` after subtracting the fitted Fourier seasonal terms.
///
/// The deseasonalised series is centred on its own mean before squaring, so the result is a
/// variance rather than a raw second moment. This keeps the series level from inflating σ².
/// Returns `1.0` for an empty series.
#[allow(dead_code)]
fn estimate_residual_variance<F>(
    ts: &Array1<F>,
    fourier_coefficients: &[Vec<(f64, f64)>],
    options: &TBATSOptions,
) -> f64
where
    F: Float + FromPrimitive + Debug + std::iter::Sum,
{
    let n = ts.len();
    if n == 0 {
        // Avoid 0/0; match the fallback used when the conversion below fails.
        return 1.0;
    }

    // (angular frequency, sin coefficient, cos coefficient) for every harmonic of every
    // period, converted once instead of once per time step.
    let harmonics: Vec<(F, F, F)> = options
        .seasonal_periods
        .iter()
        .zip(fourier_coefficients.iter())
        .flat_map(|(&period, coeffs)| {
            coeffs.iter().enumerate().map(move |(j, &(a, b))| {
                (
                    F::from_f64(2.0 * std::f64::consts::PI * (j + 1) as f64 / period)
                        .expect("Operation failed"),
                    F::from_f64(a).expect("Operation failed"),
                    F::from_f64(b).expect("Operation failed"),
                )
            })
        })
        .collect();

    let deseasonalized: Vec<F> = ts
        .iter()
        .enumerate()
        .map(|(t, &y)| {
            let t_f = F::from_usize(t).expect("Operation failed");
            let seasonal = harmonics.iter().fold(F::zero(), |acc, &(freq, a, b)| {
                acc + a * Float::sin(freq * t_f) + b * Float::cos(freq * t_f)
            });
            y - seasonal
        })
        .collect();

    let n_f = F::from_usize(n).expect("Operation failed");
    let mean = deseasonalized.iter().copied().sum::<F>() / n_f;
    let sum_sq = deseasonalized.iter().fold(F::zero(), |acc, &r| {
        let d = r - mean;
        acc + d * d
    });
    (sum_sq / n_f).to_f64().unwrap_or(1.0)
}
/// Run the state recursion over `ts` and return one state row per time step.
///
/// Row 0 starts as zeros with the level (column 0) set to `ts[0]`. At each later step:
/// - every column is carried forward from the previous row;
/// - the level is then replaced by the exponential-smoothing update
///   `alpha * ts[t] + (1 - alpha) * level[t-1]`.
///
/// Only the level evolves; all other columns stay at their row-0 value of zero.
/// Panics if `ts` is empty.
#[allow(dead_code)]
fn apply_state_space_model<F>(
    ts: &Array1<F>,
    parameters: &TBATSParameters,
    options: &TBATSOptions,
    fourier_terms: &[usize],
) -> Result<Array2<F>>
where
    F: Float + FromPrimitive + Debug + NumCast,
{
    let steps = ts.len();
    let width = calculate_state_size(options, fourier_terms);
    let mut states = Array2::zeros((steps, width));
    states[[0, 0]] = ts[0];

    for t in 1..steps {
        let weight = F::from(parameters.alpha).expect("Failed to convert to float");
        let carried = states.row(t - 1).to_owned();
        states.row_mut(t).assign(&carried);
        states[[t, 0]] = weight * ts[t] + (F::one() - weight) * carried[0];
    }
    Ok(states)
}
/// Return type of `extract_components`:
/// `(level, trend, seasonal_components, residuals, log_likelihood)`.
type TBATSComponentsResult<F> = Result<(Array1<F>, Array1<F>, Vec<Array1<F>>, Array1<F>, f64)>;
/// Split the state matrix into named components and score the fit.
///
/// - `level` is state column 0.
/// - `trend` is column 1 when `options.use_trend` is set, zeros otherwise.
/// - Each seasonal component is reconstructed from its Fourier coefficients.
/// - Residuals are `ts - (level + trend + seasonal)`.
///
/// The returned log-likelihood is the Gaussian log-likelihood of the residuals at the
/// maximum-likelihood variance `σ² = RSS / n`. At that variance the quadratic term
/// `RSS / (2σ²)` equals `n / 2` exactly, so the closed form
/// `-n/2 * (ln(2πσ²) + 1)` is used. Evaluating `RSS / σ²` directly would give 0/0 = NaN for
/// a perfect fit, whereas the closed form yields +∞.
///
/// Panics if `states` has fewer columns than the configured layout or if
/// `parameters.fourier_coefficients` has fewer entries than there are seasonal periods.
#[allow(dead_code)]
fn extract_components<F>(
    ts: &Array1<F>,
    states: &Array2<F>,
    parameters: &TBATSParameters,
    options: &TBATSOptions,
    fourier_terms: &[usize],
) -> TBATSComponentsResult<F>
where
    F: Float + FromPrimitive + Debug + std::iter::Sum + NumCast,
{
    let n = ts.len();
    let level = states.column(0).to_owned();
    let trend = if options.use_trend {
        states.column(1).to_owned()
    } else {
        Array1::zeros(n)
    };

    let mut seasonal_components = Vec::with_capacity(options.seasonal_periods.len());
    for (s_idx, (&period, &_k)) in options
        .seasonal_periods
        .iter()
        .zip(fourier_terms.iter())
        .enumerate()
    {
        let coeffs = &parameters.fourier_coefficients[s_idx];
        // Convert each harmonic's frequency and coefficients once, not once per time step.
        let harmonics: Vec<(F, F, F)> = coeffs
            .iter()
            .enumerate()
            .map(|(j, &(a, b))| {
                (
                    F::from_f64(2.0 * std::f64::consts::PI * (j + 1) as f64 / period)
                        .expect("Operation failed"),
                    F::from_f64(a).expect("Operation failed"),
                    F::from_f64(b).expect("Operation failed"),
                )
            })
            .collect();
        let seasonal = Array1::from_shape_fn(n, |t| {
            let t_f = F::from_usize(t).expect("Operation failed");
            harmonics.iter().fold(F::zero(), |acc, &(freq, a, b)| {
                acc + a * Float::sin(freq * t_f) + b * Float::cos(freq * t_f)
            })
        });
        seasonal_components.push(seasonal);
    }

    let residuals = Array1::from_shape_fn(n, |t| {
        let mut fitted = level[t];
        if options.use_trend {
            fitted = fitted + trend[t];
        }
        for seasonal in &seasonal_components {
            fitted = fitted + seasonal[t];
        }
        ts[t] - fitted
    });

    let sum_sq = residuals.mapv(|x| x * x).sum();
    let sigma2 = (sum_sq / F::from_usize(n).expect("Operation failed"))
        .to_f64()
        .unwrap_or(1.0);
    let log_likelihood =
        -0.5 * n as f64 * ((2.0 * std::f64::consts::PI * sigma2).ln() + 1.0);

    Ok((level, trend, seasonal_components, residuals, log_likelihood))
}
/// Solve the square system `a · x = b` in `f64` and convert the solution back to `F`.
///
/// Despite the name, no regularisation is added here; callers are expected to have already
/// added any ridge term to `a`. Values that fail to convert become `0` (into `f64`) or
/// `F::zero()` (back to `F`).
///
/// # Errors
///
/// - A `DecompositionError` when `a` is not square or `b` does not match its size.
/// - A `DecompositionError` wrapping the solver's message when the linear solve fails.
#[allow(dead_code)]
fn solve_regularized_least_squares<F>(a: &Array2<F>, b: &Array1<F>) -> Result<Array1<F>>
where
    F: Float + FromPrimitive + ScalarOperand + NumCast + 'static,
{
    let rows = a.shape()[0];
    let is_square = rows == a.shape()[1];
    if !(is_square && rows == b.len()) {
        return Err(TimeSeriesError::DecompositionError(
            "Matrix dimensions mismatch".to_string(),
        ));
    }

    let widen = |x: F| x.to_f64().unwrap_or(0.0);
    let lhs = a.mapv(widen);
    let rhs = b.mapv(widen);

    solve(&lhs.view(), &rhs.view(), None)
        .map(|solution| solution.mapv(|v| F::from_f64(v).unwrap_or_else(F::zero)))
        .map_err(|e| TimeSeriesError::DecompositionError(format!("Linear solve failed: {e}")))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Build a length-`n` series whose value at index `i` is `f(i as f64)`.
    fn synthetic(n: usize, f: impl Fn(f64) -> f64) -> Array1<f64> {
        Array1::from_shape_fn(n, |i| f(i as f64))
    }

    #[test]
    fn test_tbats_basic() {
        let n = 48;
        let ts = synthetic(n, |x| {
            let trend = 0.1 * x;
            let seasonal = 2.0 * (2.0 * std::f64::consts::PI * x / 12.0).sin();
            let noise = 0.1 * (x * 0.789).sin();
            10.0 + trend + seasonal + noise
        });
        let options = TBATSOptions {
            seasonal_periods: vec![12.0],
            use_trend: true,
            use_box_cox: false,
            ..Default::default()
        };

        let result = tbats_decomposition(&ts, &options).expect("Operation failed");

        assert_eq!(result.seasonal_components.len(), 1);
        assert_eq!(result.level.len(), n);
        assert_eq!(result.trend.len(), n);
        assert_eq!(result.residuals.len(), n);
    }

    #[test]
    fn test_tbats_multiple_seasons() {
        let ts = synthetic(60, |x| {
            let trend = 0.05 * x;
            let seasonal1 = 3.0 * (2.0 * std::f64::consts::PI * x / 12.0).sin();
            let seasonal2 = 1.5 * (2.0 * std::f64::consts::PI * x / 4.0).cos();
            5.0 + trend + seasonal1 + seasonal2
        });
        let options = TBATSOptions {
            seasonal_periods: vec![12.0, 4.0],
            use_trend: true,
            use_box_cox: false,
            ..Default::default()
        };

        let result = tbats_decomposition(&ts, &options).expect("Operation failed");

        assert_eq!(result.seasonal_components.len(), 2);
        assert!(result.parameters.alpha > 0.0);
        assert_eq!(result.parameters.fourier_coefficients.len(), 2);
    }

    #[test]
    fn test_tbats_edge_cases() {
        let three_points = array![1.0, 2.0, 3.0];
        let mut options = TBATSOptions {
            seasonal_periods: vec![2.0],
            ..Default::default()
        };
        assert!(tbats_decomposition(&three_points, &options).is_ok());

        // A seasonal period of 1 or less is rejected.
        options.seasonal_periods = vec![0.5];
        assert!(tbats_decomposition(&three_points, &options).is_err());

        // Fewer than three observations is rejected.
        let two_points = array![1.0, 2.0];
        options.seasonal_periods = vec![2.0];
        assert!(tbats_decomposition(&two_points, &options).is_err());
    }

    #[test]
    fn test_tbats_no_seasonal() {
        let n = 20;
        let ts = synthetic(n, |x| {
            let trend = 0.2 * x;
            5.0 + trend
        });
        let options = TBATSOptions {
            seasonal_periods: vec![],
            use_trend: true,
            use_box_cox: false,
            ..Default::default()
        };

        let result = tbats_decomposition(&ts, &options).expect("Operation failed");

        assert_eq!(result.seasonal_components.len(), 0);
        assert_eq!(result.level.len(), n);
        assert_eq!(result.trend.len(), n);
    }
}