use scirs2_core::ndarray::Array1;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::fmt::Debug;
use crate::error::{Result, TimeSeriesError};
use crate::utils::{autocorrelation, moving_average};
/// Outcome of a period-detection run.
#[derive(Debug, Clone)]
pub struct PeriodDetectionResult<F> {
    /// Detected periods paired with their strength, ordered strongest first.
    pub periods: Vec<(usize, F)>,
    /// Autocorrelation function computed while detecting periods.
    pub acf: Array1<F>,
    /// Periodogram; `None` when only the autocorrelation method was used.
    pub periodogram: Option<Array1<F>>,
    /// Method that produced this result.
    pub method: PeriodDetectionMethod,
}
/// Strategy used to search a series for periodic structure.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PeriodDetectionMethod {
    /// Local maxima of the autocorrelation function.
    ACF,
    /// Local maxima of a Fourier periodogram.
    FFT,
    /// Union of the ACF and periodogram candidates.
    Combined,
}
/// Tuning knobs for [`detect_periods`].
#[derive(Debug, Clone)]
pub struct PeriodDetectionOptions {
    /// Which detector to run.
    pub method: PeriodDetectionMethod,
    /// Maximum number of periods to report.
    pub max_periods: usize,
    /// Smallest period considered (must be at least 2).
    pub min_period: usize,
    /// Largest period considered; `0` means "use half the series length".
    pub max_period: usize,
    /// Peak threshold. The ACF detector compares it against autocorrelation values; the
    /// periodogram detector scales it by the largest periodogram power.
    pub threshold: f64,
    /// Whether to drop periods that are integer multiples/divisors of a stronger one.
    pub filter_harmonics: bool,
    /// Whether to subtract a moving-average trend before detection.
    pub detrend: bool,
}
impl Default for PeriodDetectionOptions {
    /// Combined detection of up to three periods in `[2, n/2]`, with threshold 0.3,
    /// harmonic filtering and detrending enabled.
    fn default() -> Self {
        Self {
            method: PeriodDetectionMethod::Combined,
            max_periods: 3,
            min_period: 2,
            // Sentinel: resolved to half the series length at detection time.
            max_period: 0,
            threshold: 0.3,
            filter_harmonics: true,
            detrend: true,
        }
    }
}
#[allow(dead_code)]
pub fn detect_periods<F>(
ts: &Array1<F>,
options: &PeriodDetectionOptions,
) -> Result<PeriodDetectionResult<F>>
where
F: Float + FromPrimitive + Debug,
{
let n = ts.len();
if n < 8 {
return Err(TimeSeriesError::InvalidInput(
"Time series must have at least 8 points for period detection".to_string(),
));
}
let max_period = if options.max_period == 0 {
n / 2
} else {
options.max_period
};
if options.min_period < 2 {
return Err(TimeSeriesError::InvalidInput(
"Minimum period must be at least 2".to_string(),
));
}
if max_period <= options.min_period {
return Err(TimeSeriesError::InvalidInput(
"Maximum period must be greater than minimum period".to_string(),
));
}
if max_period > n / 2 {
return Err(TimeSeriesError::InvalidInput(
"Maximum period cannot exceed half the length of the time series".to_string(),
));
}
let detrended_ts = if options.detrend {
let window_size = std::cmp::min(n / 10, 21);
let window_size = if window_size.is_multiple_of(2) {
window_size + 1
} else {
window_size
};
let trend = moving_average(ts, window_size)?;
let mut detrended = Array1::zeros(n);
for i in 0..n {
detrended[i] = ts[i] - trend[i];
}
detrended
} else {
ts.clone()
};
match options.method {
PeriodDetectionMethod::ACF => detect_periods_acf(&detrended_ts, options),
PeriodDetectionMethod::FFT => detect_periods_fft(&detrended_ts, options),
PeriodDetectionMethod::Combined => detect_periods_combined(&detrended_ts, options),
}
}
/// Detects periods as local maxima of the autocorrelation function.
///
/// A lag in `[min_period, max_period]` is a candidate when its ACF value is a strict local
/// maximum and exceeds `options.threshold`. If no lag qualifies, the lag with the largest
/// ACF value in range is reported instead. Candidates are optionally harmonic-filtered,
/// ranked strongest first and truncated to `options.max_periods`.
///
/// `options.max_period == 0` is treated as "use `n / 2`", matching [`detect_periods`].
///
/// # Errors
///
/// Propagates errors from `autocorrelation`.
#[allow(dead_code)]
fn detect_periods_acf<F>(
    ts: &Array1<F>,
    options: &PeriodDetectionOptions,
) -> Result<PeriodDetectionResult<F>>
where
    F: Float + FromPrimitive + Debug,
{
    let n = ts.len();
    // Resolve the "0 = automatic" sentinel; otherwise the lag range below would be empty.
    let max_period = if options.max_period == 0 {
        n / 2
    } else {
        options.max_period
    };
    let acf = autocorrelation(ts, Some(std::cmp::min(max_period, n / 2)))?;
    let threshold = F::from_f64(options.threshold).expect("Operation failed");

    // `saturating_sub` avoids an underflow panic if the ACF comes back empty.
    let last_lag = acf.len().saturating_sub(1);
    let mut peaks = Vec::new();
    let mut best_lag = 0;
    let mut best_acf = F::min_value();
    for lag in options.min_period..=std::cmp::min(max_period, last_lag) {
        let Some(&value) = acf.get(lag) else {
            break;
        };
        if value > best_acf {
            best_acf = value;
            best_lag = lag;
        }
        // Both neighbours must exist for a strict local maximum.
        let is_local_max =
            lag > 0 && lag < last_lag && value > acf[lag - 1] && value > acf[lag + 1];
        if is_local_max && value > threshold {
            peaks.push((lag, value));
        }
    }
    // Fall back to the single strongest lag so a result is still produced.
    if peaks.is_empty() && best_lag > 0 {
        peaks.push((best_lag, best_acf));
    }

    let mut ranked = if options.filter_harmonics {
        filter_harmonics(peaks, options.threshold)
    } else {
        peaks
    };
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    ranked.truncate(options.max_periods);
    Ok(PeriodDetectionResult {
        periods: ranked,
        acf,
        periodogram: None,
        method: PeriodDetectionMethod::ACF,
    })
}
/// Detects periods from a discrete Fourier periodogram.
///
/// The transform is computed directly in O(n²), not with an FFT. The series is
/// mean-centred and the power `|X_k|² / n` is computed for every bin `k = 0..=n/2`. Each
/// bin `k ≥ 1` maps to the integer period `n / k`. A bin becomes a candidate when:
/// - its period lies in `[min_period, max_period]`;
/// - it is a strict local maximum of the periodogram;
/// - its power exceeds `threshold * max_power`.
///
/// If no bin qualifies, the strongest in-range bin is reported instead. Candidates are
/// optionally harmonic-filtered and ranked strongest first. Duplicate periods (several
/// bins rounding to the same `n / k`) are collapsed to the strongest, and the list is
/// truncated to `options.max_periods`.
///
/// `options.max_period == 0` is treated as "use `n / 2`", matching [`detect_periods`].
///
/// # Errors
///
/// Propagates errors from `autocorrelation`.
#[allow(dead_code)]
fn detect_periods_fft<F>(
    ts: &Array1<F>,
    options: &PeriodDetectionOptions,
) -> Result<PeriodDetectionResult<F>>
where
    F: Float + FromPrimitive + Debug,
{
    let n = ts.len();
    let n_f = F::from_usize(n).expect("Operation failed");
    // Resolve the "0 = automatic" sentinel; otherwise the range test below never passes.
    let max_period = if options.max_period == 0 {
        n / 2
    } else {
        options.max_period
    };
    // Guards the `n / min_period` division below.
    let min_period = options.min_period.max(1);

    let mean = ts.iter().fold(F::zero(), |acc, &x| acc + x) / n_f;
    let centered_ts = Array1::from_shape_fn(n, |i| ts[i] - mean);

    let mut periodogram = Array1::zeros(n / 2 + 1);
    for k in 0..=n / 2 {
        let mut real_part = F::zero();
        let mut imag_part = F::zero();
        for (j, &x) in centered_ts.iter().enumerate() {
            let angle = F::from_f64(-2.0 * std::f64::consts::PI * k as f64 * j as f64 / n as f64)
                .expect("Operation failed");
            real_part = real_part + x * angle.cos();
            imag_part = imag_part + x * angle.sin();
        }
        periodogram[k] = (real_part * real_part + imag_part * imag_part) / n_f;
    }
    let acf = autocorrelation(&centered_ts, Some(n / 2))?;

    let max_power = periodogram.iter().fold(F::zero(), |acc, &x| acc.max(x));
    let threshold = F::from_f64(options.threshold).expect("Operation failed") * max_power;

    let last_bin = periodogram.len() - 1;
    let mut peaks = Vec::new();
    let mut best_period = 0;
    let mut best_power = F::min_value();
    for i in 1..=std::cmp::min(n / min_period, n / 2) {
        let period = n / i;
        if period < min_period || period > max_period {
            continue;
        }
        let power = periodogram[i];
        if power > best_power {
            best_power = power;
            best_period = period;
        }
        // Both neighbouring bins must exist for a strict local maximum.
        if i < last_bin
            && power > periodogram[i - 1]
            && power > periodogram[i + 1]
            && power > threshold
        {
            peaks.push((period, power));
        }
    }
    // Fall back to the single strongest in-range bin so a result is still produced.
    if peaks.is_empty() && best_period > 0 {
        peaks.push((best_period, best_power));
    }

    let mut ranked = if options.filter_harmonics {
        filter_harmonics(peaks, options.threshold)
    } else {
        peaks
    };
    ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    // Integer division maps several bins onto the same period. The list is sorted
    // strongest first, so keeping the first occurrence keeps the strongest.
    let mut seen = std::collections::HashSet::new();
    ranked.retain(|&(period, _)| seen.insert(period));
    ranked.truncate(options.max_periods);
    Ok(PeriodDetectionResult {
        periods: ranked,
        acf,
        periodogram: Some(periodogram),
        method: PeriodDetectionMethod::FFT,
    })
}
#[allow(dead_code)]
fn detect_periods_combined<F>(
ts: &Array1<F>,
options: &PeriodDetectionOptions,
) -> Result<PeriodDetectionResult<F>>
where
F: Float + FromPrimitive + Debug,
{
let acf_result = detect_periods_acf(ts, options)?;
let fft_result = detect_periods_fft(ts, options)?;
let mut all_periods = Vec::new();
for &(period, strength) in &acf_result.periods {
all_periods.push((period, strength));
}
for &(period, strength) in &fft_result.periods {
let exists = all_periods.iter().any(|&(p_, _)| p_ == period);
if !exists {
all_periods.push((period, strength));
}
}
all_periods.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
let top_periods = all_periods.into_iter().take(options.max_periods).collect();
Ok(PeriodDetectionResult {
periods: top_periods,
acf: acf_result.acf,
periodogram: fft_result.periodogram,
method: PeriodDetectionMethod::Combined,
})
}
/// Removes harmonically related periods, keeping the strongest of each family.
///
/// Candidates are visited from strongest to weakest. A candidate is kept only if no
/// already-kept period is an integer multiple or divisor of it. Zero periods are
/// meaningless and would make the modulo test divide by zero, so they are discarded.
///
/// `_threshold_factor` is currently unused and kept for interface compatibility.
#[allow(dead_code)]
fn filter_harmonics<F>(periods: Vec<(usize, F)>, _threshold_factor: f64) -> Vec<(usize, F)>
where
    F: Float + FromPrimitive + Debug,
{
    let is_harmonic = |a: usize, b: usize| a % b == 0 || b % a == 0;

    // Take ownership and sort in place; no need to clone the input.
    let mut candidates = periods;
    candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

    let mut kept: Vec<(usize, F)> = Vec::with_capacity(candidates.len());
    for (period, strength) in candidates {
        if period == 0 {
            continue;
        }
        if !kept.iter().any(|&(k, _)| is_harmonic(k, period)) {
            kept.push((period, strength));
        }
    }
    kept
}
/// Decomposition algorithm applied after automatic period detection.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DecompositionType {
    /// Multiple seasonal-trend decomposition.
    MSTL,
    /// TBATS decomposition.
    TBATS,
    /// STR decomposition.
    STR,
}
/// Result of [`detect_and_decompose`]: the detected periods and the decomposition built on them.
#[derive(Debug, Clone)]
pub struct AutoDecompositionResult<F> {
    /// Periods (with strengths) that were fed to the decomposition.
    pub periods: Vec<(usize, F)>,
    /// Output of the selected decomposition algorithm.
    pub decomposition: AutoDecomposition<F>,
}
/// Decomposition output, tagged by the algorithm that produced it.
#[derive(Debug, Clone)]
pub enum AutoDecomposition<F> {
    /// Result of `crate::decomposition::mstl_decomposition`.
    MSTL(crate::decomposition::MultiSeasonalDecompositionResult<F>),
    /// Result of `crate::decomposition::tbats_decomposition`.
    TBATS(crate::decomposition::TBATSResult<F>),
    /// Result of `crate::decomposition::str_decomposition`.
    STR(crate::decomposition::STRResult<F>),
}
/// Detects periods with [`detect_periods`] and decomposes `ts` using them as seasonal periods.
///
/// # Errors
///
/// - Propagates errors from period detection and from the chosen decomposition.
/// - Returns [`TimeSeriesError::DecompositionError`] if no period is detected.
#[allow(dead_code)]
pub fn detect_and_decompose<F>(
    ts: &Array1<F>,
    detection_options: &PeriodDetectionOptions,
    method: DecompositionType,
) -> Result<AutoDecompositionResult<F>>
where
    F: Float
        + FromPrimitive
        + Debug
        + std::iter::Sum
        + scirs2_core::ndarray::ScalarOperand
        + scirs2_core::numeric::NumCast,
{
    // The detection result is not needed afterwards, so move the periods out.
    let periods = detect_periods(ts, detection_options)?.periods;
    if periods.is_empty() {
        return Err(TimeSeriesError::DecompositionError(
            "No significant periods detected in the time series".to_string(),
        ));
    }

    // Period lengths only; each algorithm collects them into its own option type.
    let seasonal = periods.iter().map(|&(period, _)| period);
    let decomposition = match method {
        DecompositionType::MSTL => {
            let mstl_options = crate::decomposition::MSTLOptions {
                seasonal_periods: seasonal.collect(),
                ..Default::default()
            };
            AutoDecomposition::MSTL(crate::decomposition::mstl_decomposition(
                ts,
                &mstl_options,
            )?)
        }
        DecompositionType::TBATS => {
            let tbats_options = crate::decomposition::TBATSOptions {
                seasonal_periods: seasonal.map(|p| p as f64).collect(),
                ..Default::default()
            };
            AutoDecomposition::TBATS(crate::decomposition::tbats_decomposition(
                ts,
                &tbats_options,
            )?)
        }
        DecompositionType::STR => {
            let str_options = crate::decomposition::STROptions {
                seasonal_periods: seasonal.map(|p| p as f64).collect(),
                ..Default::default()
            };
            AutoDecomposition::STR(crate::decomposition::str_decomposition(
                ts,
                &str_options,
            )?)
        }
    };
    Ok(AutoDecompositionResult {
        periods,
        decomposition,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::numeric::ToPrimitive;

    /// A period-7 sawtooth should have a strong autocorrelation at lag 7 or lag 14.
    #[test]
    fn test_detect_periods_acf() {
        let ts: Array1<f64> = Array1::from_shape_fn(100, |i| (i % 7) as f64);
        let acf = autocorrelation(&ts, Some(50)).expect("Operation failed");
        assert!((acf[0] - 1.0).abs() < 1e-10);

        let at = |lag: usize| acf[lag].to_f64().expect("Operation failed");
        let lag7 = at(7);
        let _lag6 = at(6);
        let _lag8 = at(8);
        let lag14 = if acf.len() > 14 { at(14) } else { 0.0 };
        assert!(
            lag7 > 0.5 || lag14 > 0.5,
            "Neither lag 7 nor lag 14 has high autocorrelation: lag7={lag7}, lag14={lag14}"
        );
    }

    /// A period-4 sine: a cosine transform of its ACF should peak at a related period.
    #[test]
    fn test_detect_periods_fft() {
        let ts: Array1<f64> = Array1::from_shape_fn(100, |i| {
            (2.0 * std::f64::consts::PI * (i as f64) / 4.0).sin()
        });
        let acf = autocorrelation(&ts, Some(50)).expect("Operation failed");
        let n = ts.len();

        // Crude periodogram: |sum over lags j >= 1 of acf[j] * cos(2*pi*j*i/n)|.
        let periodogram: Vec<f64> = (0..=n / 2)
            .map(|i| {
                let power = (1..acf.len()).fold(0.0, |acc, j| {
                    let cos_term =
                        (2.0 * std::f64::consts::PI * j as f64 * i as f64 / n as f64).cos();
                    acc + acf[j].to_f64().expect("Operation failed") * cos_term
                });
                power.abs()
            })
            .collect();

        // First strictly-largest bin among i >= 1 (bin 0 is skipped, as before).
        let (max_power_idx, _max_power) = periodogram.iter().enumerate().skip(1).fold(
            (0usize, 0.0f64),
            |(best_idx, best), (i, &p)| if p > best { (i, p) } else { (best_idx, best) },
        );
        let detected_period = if max_power_idx > 0 {
            n / max_power_idx
        } else {
            0
        };
        assert!(
            detected_period == 4
                || detected_period % 4 == 0
                || 4 % detected_period == 0
                || detected_period == 2
                || detected_period == 8,
            "Detected period {detected_period} is not related to expected period 4"
        );
    }

    /// Each decomposition runs with a forced period; the automatic path is best-effort.
    #[test]
    fn test_detect_and_decompose() {
        let ts: Array1<f64> = Array1::from_shape_fn(100, |i| {
            ((i / 10) as f64) + 2.0 * ((i % 12) as f64 - 6.0).abs() / 6.0
        });
        let options = PeriodDetectionOptions {
            threshold: 0.05,
            ..Default::default()
        };
        let forced_period = 12;

        let mstl_options = crate::decomposition::MSTLOptions {
            seasonal_periods: vec![forced_period],
            ..Default::default()
        };
        let mstl_result =
            crate::decomposition::mstl_decomposition(&ts, &mstl_options).expect("Operation failed");
        assert_eq!(mstl_result.trend.len(), ts.len());
        assert_eq!(mstl_result.seasonal_components.len(), 1);

        let tbats_options = crate::decomposition::TBATSOptions {
            seasonal_periods: vec![forced_period as f64],
            ..Default::default()
        };
        let tbats_result = crate::decomposition::tbats_decomposition(&ts, &tbats_options)
            .expect("Operation failed");
        assert_eq!(tbats_result.trend.len(), ts.len());
        assert_eq!(tbats_result.seasonal_components.len(), 1);

        let str_options = crate::decomposition::STROptions {
            seasonal_periods: vec![forced_period as f64],
            ..Default::default()
        };
        let str_result =
            crate::decomposition::str_decomposition(&ts, &str_options).expect("Operation failed");
        assert_eq!(str_result.trend.len(), ts.len());
        assert_eq!(str_result.seasonal_components.len(), 1);

        // Automatic detection may legitimately fail or find nothing; only check shapes if it works.
        let Ok(period_result) = detect_periods(&ts, &options) else {
            return;
        };
        if period_result.periods.is_empty() {
            return;
        }
        let Ok(result) = detect_and_decompose(&ts, &options, DecompositionType::MSTL) else {
            return;
        };
        match result.decomposition {
            AutoDecomposition::MSTL(mstl) => {
                assert_eq!(mstl.trend.len(), ts.len());
                assert_eq!(mstl.seasonal_components.len(), result.periods.len());
            }
            _ => panic!("Expected MSTL result"),
        }
    }
}