use crate::matrix::FdMatrix;
use super::compute_mean_curve;
/// Output of [`ssa`]: the input series split into trend, seasonal and noise parts,
/// together with spectral diagnostics of the trajectory matrix.
///
/// In the degenerate case (fewer than 4 samples, or a window length outside `[2, n/2]`)
/// `trend` is a copy of the input, `seasonal` and `noise` are all zeros, and the
/// spectral vectors are empty.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct SsaResult {
/// Reconstruction from the trend component group; same length as the input.
pub trend: Vec<f64>,
/// Reconstruction from the seasonal component group; same length as the input.
pub seasonal: Vec<f64>,
/// Element-wise residual `input - trend - seasonal`.
pub noise: Vec<f64>,
/// First `n_components` singular values of the trajectory matrix, in the order the SVD returned them.
pub singular_values: Vec<f64>,
/// Share `sigma_i^2 / sum(sigma^2)` of each retained component.
pub contributions: Vec<f64>,
/// Embedding window length `L` that was used.
pub window_length: usize,
/// Number of components retained (requested count capped at the number available).
pub n_components: usize,
/// Autocorrelation-peak lag (in samples) of the first component classified as seasonal
/// during automatic grouping; `0.0` if none was found or groups were supplied explicitly.
pub detected_period: f64,
/// Singular value of that seasonal component divided by the largest singular value;
/// `0.0` whenever `detected_period` is `0.0`.
pub confidence: f64,
}
/// Decomposes `values` with Singular Spectrum Analysis into trend, seasonal and noise.
///
/// * `window_length` — embedding window `L`; defaults to `n / 2` clamped to `[2, 50]`.
///   If `n < 4` or `L` lies outside `[2, n / 2]`, the input is returned unchanged as the
///   trend, with zero seasonal/noise parts and empty spectra.
/// * `n_components` — how many leading components to keep (default 10, capped at the
///   number of singular values).
/// * `trend_components` / `seasonal_components` — explicit component indices. If either
///   is given, both groups come from the caller (a missing one is empty) and
///   `detected_period`/`confidence` are `0.0`. Otherwise components are grouped
///   automatically. Indices past the available components are ignored.
pub fn ssa(
    values: &[f64],
    window_length: Option<usize>,
    n_components: Option<usize>,
    trend_components: Option<&[usize]>,
    seasonal_components: Option<&[usize]>,
) -> SsaResult {
    let n = values.len();
    let l = match window_length {
        Some(w) => w,
        None => (n / 2).clamp(2, 50),
    };

    // Too short to embed, or a window the trajectory matrix cannot support.
    let degenerate = n < 4 || l < 2 || l > n / 2;
    if degenerate {
        return SsaResult {
            trend: values.to_vec(),
            seasonal: vec![0.0; n],
            noise: vec![0.0; n],
            singular_values: vec![],
            contributions: vec![],
            window_length: l,
            n_components: 0,
            detected_period: 0.0,
            confidence: 0.0,
        };
    }

    let k = n - l + 1;
    let trajectory = embed_trajectory(values, l, k);
    let (u, sigma, vt) = svd_decompose(&trajectory, l, k);

    let n_comp = sigma.len().min(n_components.unwrap_or(10));
    // Guard against an all-zero spectrum before normalising.
    let energy = sigma.iter().map(|&s| s * s).sum::<f64>().max(1e-15);
    let contributions: Vec<f64> = sigma[..n_comp].iter().map(|&s| s * s / energy).collect();

    let (trend_idx, seasonal_idx, detected_period, confidence) =
        match (trend_components, seasonal_components) {
            (None, None) => auto_group_ssa_components(&u, &sigma, l, k, n_comp),
            (t, s) => (
                t.map(<[usize]>::to_vec).unwrap_or_default(),
                s.map(<[usize]>::to_vec).unwrap_or_default(),
                0.0,
                0.0,
            ),
        };

    let trend = reconstruct_grouped(&u, &sigma, &vt, l, k, n, &trend_idx);
    let seasonal = reconstruct_grouped(&u, &sigma, &vt, l, k, n, &seasonal_idx);
    let noise: Vec<f64> = values
        .iter()
        .zip(&trend)
        .zip(&seasonal)
        .map(|((y, t), s)| y - t - s)
        .collect();

    let mut singular_values = sigma;
    singular_values.truncate(n_comp);

    SsaResult {
        trend,
        seasonal,
        noise,
        singular_values,
        contributions,
        window_length: l,
        n_components: n_comp,
        detected_period,
        confidence,
    }
}
/// Builds the `l x k` Hankel (trajectory) matrix of `values`, flattened column-major.
///
/// Column `j` is the lagged window `values[j], ..., values[j + l - 1]`, so entry
/// `(i, j)` sits at flat index `i + j * l`. Panics if `values` is too short to supply
/// every window (fewer than `l + k - 1` elements when `l > 0`).
pub(super) fn embed_trajectory(values: &[f64], l: usize, k: usize) -> Vec<f64> {
    let mut hankel = Vec::with_capacity(l * k);
    for lag in 0..k {
        hankel.extend((0..l).map(|row| values[lag + row]));
    }
    hankel
}
/// Singular value decomposition of the column-major `l x k` trajectory matrix (via nalgebra).
///
/// Returns `(u, sigma, vt)`, with `u` and `vt` flattened in nalgebra's column-major
/// iteration order. Callers in this module index `u` as `l x r` and `vt` as `r x k`,
/// where `r = sigma.len()`. If nalgebra does not produce `u` or `v_t`, all three vectors
/// are empty.
///
/// NOTE(review): downstream code treats index 0 as the dominant component, which assumes
/// `SVD::new` returns singular values in descending order — confirm for the nalgebra
/// version in use.
pub(super) fn svd_decompose(
    trajectory: &[f64],
    l: usize,
    k: usize,
) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
    use nalgebra::{DMatrix, SVD};
    let decomposition = SVD::new(DMatrix::from_column_slice(l, k, trajectory), true, true);
    match (decomposition.u, decomposition.v_t) {
        (Some(left), Some(right_t)) => (
            left.iter().copied().collect(),
            decomposition.singular_values.iter().copied().collect(),
            right_t.iter().copied().collect(),
        ),
        _ => (Vec::new(), Vec::new(), Vec::new()),
    }
}
/// Label assigned to one left singular vector during automatic SSA grouping.
pub(super) enum SsaComponentKind {
/// Few sign changes (slowly varying); grouped into the trend.
Trend,
/// Passed the autocorrelation periodicity test; carries the detected lag in samples.
Seasonal(f64),
/// Neither trend-like nor periodic; left out of both groups.
Noise,
}
/// Classifies one left singular vector as trend, seasonal or noise.
///
/// A trend-like vector is accepted as trend only while fewer than two trend components
/// have been collected (`trend_count < 2`). Any other vector goes through the
/// periodicity test, and vectors that fail it are noise.
pub(super) fn classify_ssa_component(u_col: &[f64], trend_count: usize) -> SsaComponentKind {
    if trend_count < 2 && is_trend_component(u_col) {
        return SsaComponentKind::Trend;
    }
    match is_periodic_component(u_col) {
        (true, period) => SsaComponentKind::Seasonal(period),
        (false, _) => SsaComponentKind::Noise,
    }
}
/// Fills in fallback groups when automatic classification left a group empty.
///
/// * Empty trend (and `n_comp > 0`): use component 0, unless component 0 was already
///   classified as seasonal.
/// * Empty seasonal (and `n_comp >= 3`): use the first two components in `1..n_comp`
///   that are not already in the trend group.
///
/// The two groups never share an index. Sharing one would make a component appear in both
/// reconstructions and be subtracted twice when the noise residual is formed.
pub(super) fn apply_ssa_grouping_defaults(
    trend_idx: &mut Vec<usize>,
    seasonal_idx: &mut Vec<usize>,
    n_comp: usize,
) {
    if trend_idx.is_empty() && n_comp > 0 && !seasonal_idx.contains(&0) {
        trend_idx.push(0);
    }
    if seasonal_idx.is_empty() && n_comp >= 3 {
        // Skip indices already claimed by the trend (up to two trend components are
        // allowed by classification, e.g. [0, 1]).
        seasonal_idx.extend((1..n_comp).filter(|c| !trend_idx.contains(c)).take(2));
    }
}
/// Automatically groups the first `n_comp` SSA components into trend and seasonal sets.
///
/// `u` holds the left singular vectors column-major with `l` rows; each column is
/// labelled by `classify_ssa_component`. The period of the first seasonal component,
/// and its confidence `sigma[i] / sigma[0]`, are reported. Fallback groups from
/// `apply_ssa_grouping_defaults` are applied last.
///
/// Returns `(trend_idx, seasonal_idx, detected_period, confidence)`; the period and
/// confidence are `0.0` when no seasonal component was found.
pub(super) fn auto_group_ssa_components(
    u: &[f64],
    sigma: &[f64],
    l: usize,
    _k: usize,
    n_comp: usize,
) -> (Vec<usize>, Vec<usize>, f64, f64) {
    let mut trend_idx: Vec<usize> = Vec::new();
    let mut seasonal_idx: Vec<usize> = Vec::new();
    // (period, confidence) of the first periodic component encountered.
    let mut first_seasonal: Option<(f64, f64)> = None;

    for comp in (0..sigma.len()).take(n_comp) {
        let column = &u[comp * l..(comp + 1) * l];
        match classify_ssa_component(column, trend_idx.len()) {
            SsaComponentKind::Trend => trend_idx.push(comp),
            SsaComponentKind::Seasonal(period) => {
                seasonal_idx.push(comp);
                if first_seasonal.is_none() && period > 0.0 {
                    first_seasonal = Some((period, sigma[comp] / sigma[0].max(1e-15)));
                }
            }
            SsaComponentKind::Noise => {}
        }
    }

    apply_ssa_grouping_defaults(&mut trend_idx, &mut seasonal_idx, n_comp);
    let (detected_period, confidence) = first_seasonal.unwrap_or((0.0, 0.0));
    (trend_idx, seasonal_idx, detected_period, confidence)
}
/// Returns `true` when `u_col` is slowly varying, i.e. it changes sign between
/// neighbouring entries at most `len / 10` times. Vectors shorter than 3 are never
/// trend-like. Zero entries do not count as a sign change.
pub(super) fn is_trend_component(u_col: &[f64]) -> bool {
    let len = u_col.len();
    if len < 3 {
        return false;
    }
    let crossings = u_col
        .windows(2)
        .filter(|pair| pair[0] * pair[1] < 0.0)
        .count();
    crossings <= len / 10
}
/// Tests whether `u_col` oscillates, returning `(is_periodic, period_in_samples)`.
///
/// The vector is mean-centred, and its biased autocorrelation (normalised by the total
/// sum of squares) is computed for lags `1..=n/2`. Only *local peaks* at lags `2..n/2`
/// are candidates. The highest peak whose autocorrelation exceeds 0.3 gives the period.
///
/// Requiring a local peak is what makes long periods detectable. For any smooth vector
/// the autocorrelation at lag 2 is close to 1 just because neighbouring samples are
/// similar. A global-maximum search would therefore report period 2 for, e.g., a
/// period-24 sinusoid over 50 samples, and for monotone non-periodic vectors.
///
/// Returns `(false, 0.0)` in three cases: fewer than 4 samples, (numerically) zero
/// variance, or no qualifying peak.
pub(super) fn is_periodic_component(u_col: &[f64]) -> (bool, f64) {
    // Minimum normalised autocorrelation a peak must exceed to count as periodic.
    const MIN_PEAK_ACF: f64 = 0.3;

    let n = u_col.len();
    if n < 4 {
        return (false, 0.0);
    }
    let mean = u_col.iter().sum::<f64>() / n as f64;
    let centered: Vec<f64> = u_col.iter().map(|&x| x - mean).collect();
    let var: f64 = centered.iter().map(|&x| x * x).sum();
    if var < 1e-15 {
        return (false, 0.0);
    }

    // acf[lag] for lag in 0..=n/2. Lags 1 and n/2 are computed so that every candidate
    // lag in 2..n/2 can be compared with both neighbours.
    let max_lag = n / 2;
    let acf: Vec<f64> = (0..=max_lag)
        .map(|lag| {
            centered
                .iter()
                .zip(&centered[lag..])
                .map(|(a, b)| a * b)
                .sum::<f64>()
                / var
        })
        .collect();

    let mut best_lag = 0usize;
    let mut best_acf = MIN_PEAK_ACF;
    // Each window is [acf[lag - 1], acf[lag], acf[lag + 1]] for lag = 2, 3, ..., max_lag - 1.
    for (lag, w) in (2usize..).zip(acf[1..].windows(3)) {
        let (prev, value, next) = (w[0], w[1], w[2]);
        let is_peak = value > prev && value >= next;
        if is_peak && value > best_acf {
            best_acf = value;
            best_lag = lag;
        }
    }

    if best_lag > 0 {
        (true, best_lag as f64)
    } else {
        (false, 0.0)
    }
}
/// Reconstructs a length-`n` series from the SSA components listed in `group_idx`.
///
/// Each trajectory-matrix cell is the sum of the rank-one terms
/// `sigma[c] * u[row, c] * vt[c, col]` over the group, accumulated in `group_idx` order.
/// Here `u` is column-major `l x r` and `vt` is column-major with stride `min(r, l)`,
/// where `r = sigma.len()`. The resulting `l x k` matrix is then diagonally averaged
/// back into a series. Indices `>= sigma.len()` are skipped; an empty group yields
/// all zeros.
pub(super) fn reconstruct_grouped(
    u: &[f64],
    sigma: &[f64],
    vt: &[f64],
    l: usize,
    k: usize,
    n: usize,
    group_idx: &[usize],
) -> Vec<f64> {
    if group_idx.is_empty() {
        return vec![0.0; n];
    }
    let rank = sigma.len();
    let vt_stride = rank.min(l);
    let members: Vec<usize> = group_idx.iter().copied().filter(|&c| c < rank).collect();

    // Cells in column-major order: flat index row + col * l.
    let summed: Vec<f64> = (0..k)
        .flat_map(|col| (0..l).map(move |row| (row, col)))
        .map(|(row, col)| {
            members.iter().fold(0.0, |acc, &c| {
                acc + sigma[c] * u[row + c * l] * vt[c + col * vt_stride]
            })
        })
        .collect();

    diagonal_average(&summed, l, k, n)
}
/// Hankelises a column-major `l x k` matrix back into a series of length `n`.
///
/// Output element `t` is the mean of all entries `(i, j)` with `i + j == t`, summed in
/// increasing column order. Positions with no contributing entries (`t >= l + k - 1`)
/// are `0.0`.
pub(super) fn diagonal_average(matrix: &[f64], l: usize, k: usize, n: usize) -> Vec<f64> {
    (0..n)
        .map(|t| {
            // Columns j whose row index i = t - j falls inside 0..l, restricted to 0..k.
            let first_col = (t + 1).saturating_sub(l);
            let end_col = (t + 1).min(k);
            if first_col >= end_col {
                return 0.0;
            }
            let total = (first_col..end_col).fold(0.0, |acc, j| acc + matrix[(t - j) + j * l]);
            total / (end_col - first_col) as f64
        })
        .collect()
}
/// Runs [`ssa`] with automatic component grouping on the curve produced by
/// `compute_mean_curve(data)`.
pub fn ssa_fdata(
    data: &FdMatrix,
    window_length: Option<usize>,
    n_components: Option<usize>,
) -> SsaResult {
    ssa(&compute_mean_curve(data), window_length, n_components, None, None)
}
pub fn ssa_seasonality(
values: &[f64],
window_length: Option<usize>,
confidence_threshold: Option<f64>,
) -> (bool, f64, f64) {
let result = ssa(values, window_length, None, None, None);
let threshold = confidence_threshold.unwrap_or(0.1);
let is_seasonal = result.confidence >= threshold && result.detected_period > 0.0;
(is_seasonal, result.detected_period, result.confidence)
}