use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use super::common::DecompositionResult;
use crate::error::{Result, TimeSeriesError};
/// Configuration for [`stl_decomposition`].
///
/// Both window lengths must be odd; `stl_decomposition` rejects even values.
#[derive(Debug, Clone)]
pub struct STLOptions {
/// Length (in observations) of the centred window used to smooth the
/// seasonally adjusted series into the trend. Must be odd.
pub trend_window: usize,
/// Length of the centred window used to smooth each cycle-subseries, counted
/// in subseries points (one point per seasonal cycle). Must be odd.
pub seasonal_window: usize,
/// Number of inner passes (seasonal smoothing followed by trend smoothing)
/// run in each outer iteration.
pub n_inner: usize,
/// Number of outer iterations. When `robust` is set, robustness weights are
/// computed at the end of an outer iteration and used by the next one.
pub n_outer: usize,
/// Whether to down-weight observations with large residuals. Because weights
/// derived in one outer iteration only affect the following one, this has an
/// effect only when `n_outer >= 2`.
pub robust: bool,
}
impl Default for STLOptions {
/// Trend window 21, seasonal window 13, two inner passes, one outer
/// iteration, non-robust. With `n_outer` left at 1, enabling `robust` alone
/// changes nothing.
fn default() -> Self {
Self {
trend_window: 21,
seasonal_window: 13,
n_inner: 2,
n_outer: 1,
robust: false,
}
}
}
/// Configuration for [`mstl_decomposition`].
#[derive(Debug, Clone)]
pub struct MSTLOptions {
/// Seasonal periods to extract, in the order they are removed. Must be
/// non-empty: the default (empty) value is rejected by `mstl_decomposition`.
pub seasonal_periods: Vec<usize>,
/// Trend window forwarded to every per-period STL fit. Must be odd.
pub trend_window: usize,
/// Optional per-period seasonal windows, one per entry of `seasonal_periods`
/// (each must be odd). When `None`, a window is derived from each period.
pub seasonal_windows: Option<Vec<usize>>,
/// Inner passes per outer iteration, forwarded to every STL fit.
pub n_inner: usize,
/// Outer iterations, forwarded to every STL fit.
pub n_outer: usize,
/// Robustness flag forwarded to every STL fit (only effective when
/// `n_outer >= 2`).
pub robust: bool,
}
impl Default for MSTLOptions {
/// No seasonal periods (callers must fill in `seasonal_periods`), trend
/// window 21, derived seasonal windows, two inner passes, one outer
/// iteration, non-robust.
fn default() -> Self {
Self {
seasonal_periods: Vec::new(),
trend_window: 21,
seasonal_windows: None,
n_inner: 2,
n_outer: 1,
robust: false,
}
}
}
/// Output of [`mstl_decomposition`].
#[derive(Debug, Clone)]
pub struct MultiSeasonalDecompositionResult<F> {
/// Trend component.
pub trend: Array1<F>,
/// One seasonal component per entry of `MSTLOptions::seasonal_periods`, in
/// the same order.
pub seasonal_components: Vec<Array1<F>>,
/// `original` minus `trend` minus every seasonal component.
pub residual: Array1<F>,
/// Copy of the input series.
pub original: Array1<F>,
}
/// Decomposes `ts` into trend, seasonal and residual components using a
/// simplified, moving-average variant of STL (Seasonal-Trend decomposition).
///
/// Each inner pass:
/// 1. subtracts the current trend estimate from the series;
/// 2. splits the detrended series into `period` cycle-subseries (observations
///    `k, k + period, k + 2 * period, ...`) and smooths each one with a
///    centred weighted mean over `options.seasonal_window` subseries points;
/// 3. subtracts that seasonal estimate from the series and smooths the result
///    with a centred weighted mean over `options.trend_window` observations.
///
/// Points closer than half a window to either end of the sequence being
/// smoothed keep their input value.
///
/// When `options.robust` is set, robustness weights are derived at the end of
/// each outer iteration from `r = |residual| / max|residual|`: weight 1 for
/// `r < 0.5`, `(1 - r²)²` for `0.5 <= r < 1`, and 0 otherwise. The weights
/// only influence later outer iterations, so robustness has an effect only
/// when `options.n_outer >= 2`.
///
/// # Errors
///
/// * [`TimeSeriesError::InvalidParameter`] if `period` is zero or either
///   window length is even.
/// * [`TimeSeriesError::DecompositionError`] if `ts` is shorter than
///   `2 * period`.
#[allow(dead_code)]
pub fn stl_decomposition<F>(
    ts: &Array1<F>,
    period: usize,
    options: &STLOptions,
) -> Result<DecompositionResult<F>>
where
    F: Float + FromPrimitive + Debug,
{
    // A zero period would divide by zero when splitting into cycle-subseries.
    if period == 0 {
        return Err(TimeSeriesError::InvalidParameter {
            name: "period".to_string(),
            message: "Seasonal period must be positive".to_string(),
        });
    }
    if ts.len() < 2 * period {
        return Err(TimeSeriesError::DecompositionError(format!(
            "Time series length ({}) must be at least twice the seasonal period ({})",
            ts.len(),
            period
        )));
    }
    if options.trend_window.is_multiple_of(2) {
        return Err(TimeSeriesError::InvalidParameter {
            name: "trend_window".to_string(),
            message: "Trend window size must be odd".to_string(),
        });
    }
    if options.seasonal_window.is_multiple_of(2) {
        return Err(TimeSeriesError::InvalidParameter {
            name: "seasonal_window".to_string(),
            message: "Seasonal window size must be odd".to_string(),
        });
    }

    let n = ts.len();
    let original = ts.clone();
    let seasonal_half = options.seasonal_window / 2;
    let trend_half = options.trend_window / 2;

    let mut seasonal: Array1<F> = Array1::zeros(n);
    let mut trend: Array1<F> = Array1::zeros(n);
    // Robustness weights; all ones makes every smoother a plain mean.
    let mut weights = vec![F::one(); n];

    for outer in 0..options.n_outer {
        for _ in 0..options.n_inner {
            let detrended = original.clone() - &trend;

            // Smooth every cycle-subseries independently. `n >= 2 * period`
            // guarantees each one has at least two points.
            let mut new_seasonal: Array1<F> = Array1::zeros(n);
            for offset in 0..period {
                let positions: Vec<usize> = (offset..n).step_by(period).collect();
                let values: Vec<F> = positions.iter().map(|&i| detrended[i]).collect();
                let sub_weights: Vec<F> = positions.iter().map(|&i| weights[i]).collect();
                let smoothed = stl_weighted_window_mean(&values, &sub_weights, seasonal_half);
                for (&i, value) in positions.iter().zip(smoothed) {
                    new_seasonal[i] = value;
                }
            }

            let deseasonalized = original.clone() - &new_seasonal;
            let new_trend =
                stl_weighted_window_mean(&deseasonalized.to_vec(), &weights, trend_half);
            trend = Array1::from_vec(new_trend);
            seasonal = new_seasonal;
        }

        // Weights produced after the final outer pass would never be read.
        if options.robust && outer + 1 < options.n_outer {
            let residual = original.clone() - &trend - &seasonal;
            stl_update_robustness_weights(&residual, &mut weights);
        }
    }

    let residual = original.clone() - &trend - &seasonal;
    Ok(DecompositionResult {
        trend,
        seasonal,
        residual,
        original,
    })
}

/// Centred weighted moving average over `2 * half_window + 1` points.
///
/// Position `i` becomes `Σ w_j·v_j / Σ w_j` for `j` in
/// `i - half_window ..= i + half_window`. Positions whose full window would run
/// past either end of `values`, and positions whose window weights sum to zero
/// (every neighbour fully down-weighted), keep their input value.
fn stl_weighted_window_mean<F: Float>(values: &[F], weights: &[F], half_window: usize) -> Vec<F> {
    debug_assert_eq!(values.len(), weights.len());
    let len = values.len();
    (0..len)
        .map(|i| {
            if i < half_window || i + half_window >= len {
                return values[i];
            }
            let window = (i - half_window)..=(i + half_window);
            let (weighted_sum, weight_sum) = values[window.clone()]
                .iter()
                .zip(&weights[window])
                .fold((F::zero(), F::zero()), |(sum, w_sum), (&v, &w)| {
                    (sum + v * w, w_sum + w)
                });
            if weight_sum > F::zero() {
                weighted_sum / weight_sum
            } else {
                values[i]
            }
        })
        .collect()
}

/// Recomputes STL robustness weights in place from the current residuals.
///
/// With `r = |residual| / max|residual|`, each weight becomes 1 for `r < 0.5`,
/// `(1 - r²)²` for `0.5 <= r < 1`, and 0 otherwise. If the largest absolute
/// residual is not positive (e.g. a perfect fit) the weights are left unchanged.
fn stl_update_robustness_weights<F: Float>(residual: &Array1<F>, weights: &mut [F]) {
    let abs_residuals: Vec<F> = residual.iter().map(|&r| r.abs()).collect();
    let max_residual = abs_residuals
        .iter()
        .fold(F::zero(), |a, &b| if a > b { a } else { b });
    if max_residual > F::zero() {
        // Exactly 0.5 in any binary floating-point type, with no fallible conversion.
        let half = F::one() / (F::one() + F::one());
        for (weight, &abs_res) in weights.iter_mut().zip(&abs_residuals) {
            let r = abs_res / max_residual;
            *weight = if r < half {
                F::one()
            } else if r < F::one() {
                let t = F::one() - r * r;
                t * t
            } else {
                F::zero()
            };
        }
    }
}
/// Decomposes `ts` into a trend, one seasonal component per entry of
/// `options.seasonal_periods`, and a residual (MSTL).
///
/// Seasonal components are extracted sequentially, in the order the periods are
/// given. Each period is fitted with [`stl_decomposition`] on the series with all
/// previously extracted seasonal components removed. The trend is the one
/// estimated by the final STL fit. The residual is what remains after removing
/// that trend and every seasonal component from the original series.
///
/// Unless `options.seasonal_windows` is supplied, the seasonal window for a
/// period `p` is `max(7, p / 2)` forced odd (an even value is incremented by one).
///
/// # Errors
///
/// * [`TimeSeriesError::InvalidParameter`] if `seasonal_periods` is empty or
///   contains zero, or if `seasonal_windows` has a different length.
/// * [`TimeSeriesError::DecompositionError`] if `ts` is shorter than twice any
///   seasonal period.
/// * Any error returned by the per-period [`stl_decomposition`] call (e.g. an
///   even window length).
#[allow(dead_code)]
pub fn mstl_decomposition<F>(
    ts: &Array1<F>,
    options: &MSTLOptions,
) -> Result<MultiSeasonalDecompositionResult<F>>
where
    F: Float + FromPrimitive + Debug,
{
    if options.seasonal_periods.is_empty() {
        return Err(TimeSeriesError::InvalidParameter {
            name: "seasonal_periods".to_string(),
            message: "At least one seasonal period must be specified".to_string(),
        });
    }
    let n_seasons = options.seasonal_periods.len();
    if let Some(windows) = &options.seasonal_windows {
        if windows.len() != n_seasons {
            return Err(TimeSeriesError::InvalidParameter {
                name: "seasonal_windows".to_string(),
                message: format!(
                    "Number of seasonal windows ({}) must match number of seasonal periods ({})",
                    windows.len(),
                    n_seasons
                ),
            });
        }
    }
    for &period in &options.seasonal_periods {
        // A zero period would make the cycle-subseries split divide by zero.
        if period == 0 {
            return Err(TimeSeriesError::InvalidParameter {
                name: "seasonal_periods".to_string(),
                message: "Seasonal periods must be positive".to_string(),
            });
        }
        if ts.len() < 2 * period {
            return Err(TimeSeriesError::DecompositionError(format!(
                "Time series length ({}) must be at least twice the seasonal period ({})",
                ts.len(),
                period
            )));
        }
    }

    let n = ts.len();
    let original = ts.clone();
    let mut seasonal_components = Vec::with_capacity(n_seasons);
    let mut deseasonal = original.clone();
    // Overwritten by every STL fit; the loop runs at least once because
    // `seasonal_periods` was checked to be non-empty.
    let mut trend: Array1<F> = Array1::zeros(n);

    for (i, &period) in options.seasonal_periods.iter().enumerate() {
        let seasonal_window = match &options.seasonal_windows {
            Some(windows) => windows[i],
            // Setting the low bit makes the derived window odd, as STL requires.
            None => std::cmp::max(7, period / 2) | 1,
        };
        let stl_options = STLOptions {
            trend_window: options.trend_window,
            seasonal_window,
            n_inner: options.n_inner,
            n_outer: options.n_outer,
            robust: options.robust,
        };
        let result = stl_decomposition(&deseasonal, period, &stl_options)?;
        deseasonal = deseasonal - &result.seasonal;
        // Keep the smoothed trend from the latest fit. Using `deseasonal` here
        // instead would fold all noise into the trend and zero the residual.
        trend = result.trend;
        seasonal_components.push(result.seasonal);
    }

    let residual = seasonal_components
        .iter()
        .fold(original.clone() - &trend, |acc, seasonal| acc - seasonal);

    Ok(MultiSeasonalDecompositionResult {
        trend,
        seasonal_components,
        residual,
        original,
    })
}