use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use super::common::{min, DecompositionModel, DecompositionResult};
use crate::error::{Result, TimeSeriesError};
/// Decomposes a time series using Holt-Winters (triple exponential) smoothing.
///
/// Level, trend and seasonal states are updated once per observation with the
/// smoothing factors `alpha`, `beta` and `gamma`. The returned components satisfy,
/// up to floating-point rounding, for every index `i`:
///
/// * [`DecompositionModel::Additive`]:
///   `original[i] == trend[i] + seasonal[i] + residual[i]`
/// * [`DecompositionModel::Multiplicative`]:
///   `original[i] == trend[i] * seasonal[i] * residual[i]`
///   (if `trend[i]` is zero the residual is reported as zero instead)
///
/// `trend[i]` is the one-step-ahead estimate made before observing `ts[i]`
/// (`level + slope` additively, `level * growth` multiplicatively). `seasonal[i]` is
/// the seasonal index in effect at time `i`; each index is re-smoothed with `gamma`
/// once per full period.
///
/// The multiplicative model uses a growth-ratio trend and multiplicative seasonal
/// indices, so it is only meaningful for strictly positive data.
///
/// # Arguments
///
/// * `ts` - Observed series; must contain more than `period` observations.
/// * `period` - Number of observations per seasonal cycle; must be at least 1.
/// * `alpha`, `beta`, `gamma` - Level, trend and seasonal smoothing factors, each
///   strictly between 0 and 1.
/// * `model` - Whether the components combine additively or multiplicatively.
///
/// # Errors
///
/// * [`TimeSeriesError::InvalidInput`] if `period` is zero or a smoothing factor is
///   not strictly inside `(0, 1)` (NaN included).
/// * [`TimeSeriesError::DecompositionError`] if the series is too short, if a constant
///   cannot be represented in `F`, or if the multiplicative model would divide by zero.
#[allow(dead_code)]
pub fn exponential_decomposition<F>(
    ts: &Array1<F>,
    period: usize,
    alpha: f64,
    beta: f64,
    gamma: f64,
    model: DecompositionModel,
) -> Result<DecompositionResult<F>>
where
    F: Float + FromPrimitive + Debug,
{
    // A zero period would make every seasonal index computation divide by zero.
    if period == 0 {
        return Err(TimeSeriesError::InvalidInput(
            "Seasonal period must be at least 1".to_string(),
        ));
    }
    // `<=` instead of `< period + 1` so a huge `period` cannot overflow.
    if ts.len() <= period {
        return Err(TimeSeriesError::DecompositionError(format!(
            "Time series length ({}) must be greater than the seasonal period ({})",
            ts.len(),
            period
        )));
    }
    // Phrased as "inside the open interval" so NaN is rejected too:
    // `p <= 0.0 || p >= 1.0` evaluates to false for NaN and would let it through.
    let in_open_unit_interval = |p: f64| p > 0.0 && p < 1.0;
    if !(in_open_unit_interval(alpha)
        && in_open_unit_interval(beta)
        && in_open_unit_interval(gamma))
    {
        return Err(TimeSeriesError::InvalidInput(
            "Smoothing parameters must be between 0 and 1 (exclusive)".to_string(),
        ));
    }

    // Capture-free closure, hence `Copy`, so it can be handed to `ok_or_else` repeatedly.
    let conversion_error = || {
        TimeSeriesError::DecompositionError(
            "Failed to convert a numeric constant to the series' float type".to_string(),
        )
    };
    let alpha = F::from_f64(alpha).ok_or_else(conversion_error)?;
    let beta = F::from_f64(beta).ok_or_else(conversion_error)?;
    let gamma = F::from_f64(gamma).ok_or_else(conversion_error)?;
    let period_f = F::from_usize(period).ok_or_else(conversion_error)?;
    let one = F::one();

    let n = ts.len();
    // level[i] / trend[i] are the state estimates *before* observing ts[i].
    // seasonal[i] is the seasonal index applied at time i: indices 0..period hold the
    // initial pattern, and the update made at time i is stored in seasonal[i + period]
    // so that it is used one full cycle later.
    let mut level = Array1::<F>::zeros(n + 1);
    let mut trend = Array1::<F>::zeros(n + 1);
    let mut seasonal = Array1::<F>::zeros(n + period);
    let mut residual = Array1::<F>::zeros(n);
    let mut trend_component = Array1::<F>::zeros(n);

    level[0] = ts[0];

    // Initial trend from (up to) the first ten observations. The length check above
    // guarantees n >= 2, so `steps >= 1`.
    let steps = min(n, 10) - 1;
    let steps_f = F::from_usize(steps).ok_or_else(conversion_error)?;
    trend[0] = match model {
        // Average first difference: a slope per step.
        DecompositionModel::Additive => {
            (1..=steps).fold(F::zero(), |acc, i| acc + (ts[i] - ts[i - 1])) / steps_f
        }
        // Average growth ratio: the multiplicative recursions below combine the state
        // as `level * trend` and update the trend from `level[i + 1] / level[i]`.
        DecompositionModel::Multiplicative => {
            let mut sum = F::zero();
            for i in 1..=steps {
                if ts[i - 1] == F::zero() {
                    return Err(TimeSeriesError::DecompositionError(
                        "Division by zero in multiplicative model initialization".to_string(),
                    ));
                }
                sum = sum + ts[i] / ts[i - 1];
            }
            sum / steps_f
        }
    };

    // Initial seasonal indices: deviation (additive) or ratio (multiplicative) of each
    // observation in the first period from the initial level/trend projection.
    let mut projected = level[0];
    for i in 0..period {
        seasonal[i] = match model {
            DecompositionModel::Additive => ts[i] - projected,
            DecompositionModel::Multiplicative => {
                if projected == F::zero() {
                    return Err(TimeSeriesError::DecompositionError(
                        "Division by zero in multiplicative model initialization".to_string(),
                    ));
                }
                ts[i] / projected
            }
        };
        projected = match model {
            DecompositionModel::Additive => projected + trend[0],
            DecompositionModel::Multiplicative => projected * trend[0],
        };
    }

    // Normalise the initial pattern: zero mean (additive) or unit mean (multiplicative).
    let seasonal_mean = seasonal
        .iter()
        .take(period)
        .fold(F::zero(), |acc, &x| acc + x)
        / period_f;
    match model {
        DecompositionModel::Additive => {
            seasonal
                .iter_mut()
                .take(period)
                .for_each(|x| *x = *x - seasonal_mean);
        }
        DecompositionModel::Multiplicative => {
            if seasonal_mean == F::zero() {
                return Err(TimeSeriesError::DecompositionError(
                    "Division by zero normalizing multiplicative seasonal component".to_string(),
                ));
            }
            seasonal
                .iter_mut()
                .take(period)
                .for_each(|x| *x = *x / seasonal_mean);
        }
    }

    for i in 0..n {
        // Seasonal index in effect at time i (initial value, or the update from i - period).
        let season = seasonal[i];
        match model {
            DecompositionModel::Additive => {
                let expected = level[i] + trend[i];
                trend_component[i] = expected;
                residual[i] = ts[i] - expected - season;

                level[i + 1] = alpha * (ts[i] - season) + (one - alpha) * expected;
                trend[i + 1] = beta * (level[i + 1] - level[i]) + (one - beta) * trend[i];
                seasonal[i + period] =
                    gamma * (ts[i] - level[i + 1]) + (one - gamma) * season;
            }
            DecompositionModel::Multiplicative => {
                if season == F::zero() {
                    return Err(TimeSeriesError::DecompositionError(
                        "Division by zero in multiplicative model update".to_string(),
                    ));
                }
                let expected = level[i] * trend[i];
                trend_component[i] = expected;
                // No meaningful ratio exists when the projection is zero; report zero.
                residual[i] = if expected == F::zero() {
                    F::zero()
                } else {
                    ts[i] / (expected * season)
                };

                level[i + 1] = alpha * (ts[i] / season) + (one - alpha) * expected;
                // Keep the previous growth / seasonal index when the ratio is undefined.
                trend[i + 1] = if level[i] == F::zero() {
                    trend[i]
                } else {
                    beta * (level[i + 1] / level[i]) + (one - beta) * trend[i]
                };
                seasonal[i + period] = if level[i + 1] == F::zero() {
                    season
                } else {
                    gamma * (ts[i] / level[i + 1]) + (one - gamma) * season
                };
            }
        }
    }

    let seasonal_component = Array1::from_iter(seasonal.iter().take(n).copied());
    Ok(DecompositionResult {
        trend: trend_component,
        seasonal: seasonal_component,
        residual,
        original: ts.clone(),
    })
}