use crate::error::{Result, TimeSeriesError};
use scirs2_fft::{
fft, ifft, rfft, rfftfreq,
wavelet_packets::{wp_reconstruct, wpd, Wavelet},
};
use std::f64::consts::PI;
/// Validate the arguments of the FFT-based band-pass filter.
///
/// # Arguments
/// * `n` - number of samples in the series.
/// * `low`, `high` - pass-band edges, in the same units as `fs`.
/// * `fs` - sampling frequency.
///
/// # Errors
/// - `InsufficientData` when `n < 4`.
/// - `InvalidInput` when `fs` is not positive, either cutoff is not positive,
///   `low >= high`, or `high` exceeds the Nyquist frequency `fs / 2`.
///
/// NaN in `fs`, `low` or `high` is rejected. Every comparison with NaN is false, so
/// without the explicit `is_nan` guards a NaN would pass all of the checks below.
fn check_bandpass_params(n: usize, low: f64, high: f64, fs: f64) -> Result<()> {
    if n < 4 {
        return Err(TimeSeriesError::InsufficientData {
            message: "bandpass filter requires at least 4 samples".to_string(),
            required: 4,
            actual: n,
        });
    }
    if fs.is_nan() || fs <= 0.0 {
        return Err(TimeSeriesError::InvalidInput(
            "sampling frequency must be positive".to_string(),
        ));
    }
    if low.is_nan() || high.is_nan() || low <= 0.0 || high <= 0.0 {
        return Err(TimeSeriesError::InvalidInput(
            "cutoff frequencies must be positive".to_string(),
        ));
    }
    if low >= high {
        return Err(TimeSeriesError::InvalidInput(
            "low_freq must be strictly less than high_freq".to_string(),
        ));
    }
    if high > fs / 2.0 {
        return Err(TimeSeriesError::InvalidInput(format!(
            "high_freq ({high}) must not exceed Nyquist ({nyq})",
            nyq = fs / 2.0
        )));
    }
    Ok(())
}
use scirs2_core::numeric::Complex64;
/// Band-pass filter `ts` in the frequency domain with a Butterworth-shaped magnitude response.
///
/// Steps:
/// 1. Transform the series with `fft`.
/// 2. Scale every bin by [`butterworth_bandpass_gain`], evaluated at the bin's absolute
///    frequency.
/// 3. Return the real part of the inverse transform.
///
/// Each gain is a real scalar applied equally to the real and imaginary parts, so the
/// phase of every bin is left unchanged.
///
/// # Arguments
/// * `ts` - input samples (at least 4).
/// * `low_freq`, `high_freq` - pass-band edges, in the same units as `fs`.
/// * `fs` - sampling frequency.
/// * `order` - Butterworth order; `0` is treated as `1`.
///
/// # Errors
/// - Any validation error from `check_bandpass_params`.
/// - `ComputationError` if the forward or inverse FFT fails.
pub fn bandpass_filter_series(
    ts: &[f64],
    low_freq: f64,
    high_freq: f64,
    fs: f64,
    order: usize,
) -> Result<Vec<f64>> {
    let n = ts.len();
    check_bandpass_params(n, low_freq, high_freq, fs)?;
    let spectrum = fft(ts, None).map_err(|e| TimeSeriesError::ComputationError(e.to_string()))?;
    let order = order.max(1);
    let freq_resolution = fs / n as f64;
    let masked: Vec<Complex64> = spectrum
        .iter()
        .enumerate()
        .map(|(k, c)| {
            // Bins above n/2 hold the negative frequencies of the two-sided spectrum.
            let freq = if k <= n / 2 {
                k as f64 * freq_resolution
            } else {
                (k as i64 - n as i64) as f64 * freq_resolution
            };
            let gain = butterworth_bandpass_gain(freq.abs(), low_freq, high_freq, order);
            Complex64::new(c.re * gain, c.im * gain)
        })
        .collect();
    let recovered = ifft(&masked, None)
        .map_err(|e| TimeSeriesError::ComputationError(e.to_string()))?;
    Ok(recovered.iter().map(|c| c.re).collect())
}
/// Magnitude response of a Butterworth band-pass built as a high-pass stage (corner `low`)
/// cascaded with a low-pass stage (corner `high`), both of order `order`.
///
/// `freq`, `low` and `high` must share units. Frequencies below `f64::EPSILON` (DC) get zero
/// gain. This also keeps the high-pass term from dividing by zero.
fn butterworth_bandpass_gain(freq: f64, low: f64, high: f64, order: usize) -> f64 {
    if freq < f64::EPSILON {
        0.0
    } else {
        let exponent = 2 * order as i32;
        // One Butterworth stage: 1 / sqrt(1 + ratio^(2*order)).
        let stage = |ratio: f64| 1.0 / (1.0 + ratio.powi(exponent)).sqrt();
        let high_pass = stage(low / freq);
        let low_pass = stage(freq / high);
        high_pass * low_pass
    }
}
/// Hodrick–Prescott filter: split `ts` into a smooth trend and a cyclical residual.
///
/// The trend `τ` solves `(I + λ·D'D) τ = ts`, where `D` is the `(n-2) × n`
/// second-difference operator (rows `[1, -2, 1]`). This pentadiagonal system is solved
/// with `band_ldl_solve`, and the cycle is `ts - τ`.
///
/// # Arguments
/// * `ts` - observations (at least 4).
/// * `lambda` - smoothing parameter `λ >= 0`. With `λ = 0` the system is the identity,
///   so the trend equals `ts`.
///
/// # Returns
/// `(trend, cycle)`.
///
/// # Errors
/// - `InsufficientData` if `ts.len() < 4`.
/// - `InvalidInput` if `lambda` is negative or NaN.
/// - `NumericalInstability` propagated from the banded solver.
pub fn hp_filter(ts: &[f64], lambda: f64) -> Result<(Vec<f64>, Vec<f64>)> {
    let n = ts.len();
    if n < 4 {
        return Err(TimeSeriesError::InsufficientData {
            message: "hp_filter requires at least 4 samples".to_string(),
            required: 4,
            actual: n,
        });
    }
    // `NaN < 0.0` is false, so NaN must be rejected explicitly. Otherwise it would flow
    // through the solver and come back as an all-NaN "trend".
    if lambda.is_nan() || lambda < 0.0 {
        return Err(TimeSeriesError::InvalidInput(
            "lambda must be non-negative".to_string(),
        ));
    }
    // Main diagonal of I + λ·D'D. The diagonal of D'D is [1, 5, 6, ..., 6, 5, 1].
    let diag: Vec<f64> = (0..n)
        .map(|i| {
            let dd = if i == 0 || i == n - 1 {
                1.0
            } else if i == 1 || i == n - 2 {
                5.0
            } else {
                6.0
            };
            1.0 + lambda * dd
        })
        .collect();
    // First off-diagonal of λ·D'D. The first off-diagonal of D'D is [-2, -4, ..., -4, -2].
    let off1: Vec<f64> = (0..n - 1)
        .map(|i| {
            let dd = if i == 0 || i == n - 2 { -2.0 } else { -4.0 };
            lambda * dd
        })
        .collect();
    // Second off-diagonal of D'D is all ones.
    let off2 = vec![lambda; n - 2];
    let trend = band_ldl_solve(&diag, &off1, &off2, ts)?;
    let cycle: Vec<f64> = ts.iter().zip(&trend).map(|(y, t)| y - t).collect();
    Ok((trend, cycle))
}
/// Solve `A x = rhs` for a symmetric pentadiagonal matrix `A` by LDLᵀ factorisation.
///
/// # Arguments
/// * `d` - main diagonal of `A` (length `n`).
/// * `e1` - first off-diagonal, `e1[i] = A[i][i+1] = A[i+1][i]` (length `n - 1`).
/// * `e2` - second off-diagonal, `e2[i] = A[i][i+2] = A[i+2][i]` (length `n - 2`).
/// * `rhs` - right-hand side (length `n`).
///
/// `L` is unit lower-triangular with two sub-diagonals, stored as `l1[i] = L[i][i-1]` and
/// `l2[i] = L[i][i-2]`; `D` is kept in `dd`. The factorisation is right-looking: once
/// column `i` is formed, the still-unfactored band entries are updated in place.
///
/// No pivoting is performed, so `A` should be positive definite, for example the
/// `I + λ·D'D` system built by `hp_filter`. An empty system (`n == 0`) yields an empty
/// solution.
///
/// # Errors
/// `NumericalInstability` if a pivot of `D` has magnitude below `f64::EPSILON`.
///
/// # Panics
/// May panic (index out of bounds) if `e2` or `rhs` is shorter than described above, or if
/// `e1` is longer than `n`.
fn band_ldl_solve(d: &[f64], e1: &[f64], e2: &[f64], rhs: &[f64]) -> Result<Vec<f64>> {
    let n = d.len();
    // The back-substitution loop below ranges over `0..n - 1`, which would underflow for n == 0.
    if n == 0 {
        return Ok(Vec::new());
    }
    let mut dd = d.to_vec();
    let mut l1 = vec![0.0_f64; n];
    let mut l2 = vec![0.0_f64; n];
    // Working copy of the first off-diagonal; it is updated by the Schur complement.
    let mut ee1 = vec![0.0_f64; n];
    ee1[..e1.len()].copy_from_slice(e1);
    for i in 0..n {
        // dd[i] is final here: later steps only touch dd[i + 1] and dd[i + 2]. So this
        // single check covers every pivot that the solve stage divides by.
        if dd[i].abs() < f64::EPSILON {
            return Err(TimeSeriesError::NumericalInstability(
                "near-zero pivot in band_ldl_solve".to_string(),
            ));
        }
        if i + 1 < n {
            l1[i + 1] = ee1[i] / dd[i];
            dd[i + 1] -= l1[i + 1] * ee1[i];
            if i + 2 < n {
                // Schur update of A[i+2][i+1] by column i: L[i+2][i]·D[i]·L[i+1][i].
                ee1[i + 1] -= l1[i + 1] * e2[i];
            }
        }
        if i + 2 < n {
            l2[i + 2] = e2[i] / dd[i];
            dd[i + 2] -= l2[i + 2] * e2[i];
        }
    }
    // Forward substitution: L y = rhs.
    let mut y = rhs.to_vec();
    for i in 1..n {
        y[i] -= l1[i] * y[i - 1];
        if i >= 2 {
            y[i] -= l2[i] * y[i - 2];
        }
    }
    // Diagonal solve: D z = y (pivots already verified non-negligible above).
    for (yi, di) in y.iter_mut().zip(&dd) {
        *yi /= di;
    }
    // Back substitution: Lᵀ x = z.
    let mut x = y;
    for i in (0..n - 1).rev() {
        x[i] -= l1[i + 1] * x[i + 1];
        if i + 2 < n {
            x[i] -= l2[i + 2] * x[i + 2];
        }
    }
    Ok(x)
}
/// Baxter–King band-pass filter keeping periodic components with periods between `low`
/// and `high` observations.
///
/// How the filter is built and applied:
/// - Ideal band-pass weights are truncated at lag `k`.
/// - The weights are shifted by their mean so they sum to zero. This removes constants,
///   and because the weights are symmetric it also removes linear trends.
/// - The weights are applied as a centred moving average over windows of `2k + 1` samples.
///
/// `k` observations are lost at each end, so the output has `ts.len() - 2k` values.
/// Output element `t` corresponds to input index `t + k`.
///
/// # Arguments
/// * `ts` - input series.
/// * `low`, `high` - shortest and longest period to keep, in observations.
/// * `k` - truncation lag (half-width of the filter window).
///
/// # Errors
/// - `InsufficientData` if `ts.len() < 2k + 1`.
/// - `InvalidInput` unless `1 < low < high`. NaN is rejected.
pub fn bandpass_filter_bk(ts: &[f64], low: f64, high: f64, k: usize) -> Result<Vec<f64>> {
    let n = ts.len();
    let window = 2 * k + 1;
    if n < window {
        return Err(TimeSeriesError::InsufficientData {
            message: format!("BK filter with K={k} requires at least {window} samples"),
            required: window,
            actual: n,
        });
    }
    // Comparisons with NaN are all false, so NaN is rejected explicitly.
    if low.is_nan() || high.is_nan() || low <= 1.0 || high <= low {
        return Err(TimeSeriesError::InvalidInput(
            "BK filter: need 1 < low < high".to_string(),
        ));
    }
    // Periods map inversely to angular frequency: the long-period edge is the low cutoff.
    let omega_l = 2.0 * PI / high;
    let omega_h = 2.0 * PI / low;
    let mut weights = vec![0.0_f64; window];
    weights[k] = (omega_h - omega_l) / PI;
    for j in 1..=k {
        let jf = j as f64;
        let bj = (omega_h * jf).sin() / (PI * jf) - (omega_l * jf).sin() / (PI * jf);
        weights[k + j] = bj;
        weights[k - j] = bj;
    }
    let mean = weights.iter().sum::<f64>() / window as f64;
    for w in weights.iter_mut() {
        *w -= mean;
    }
    // Every window of 2k+1 consecutive samples is fully inside `ts`, so no bounds test is needed.
    Ok(ts
        .windows(window)
        .map(|segment| {
            segment
                .iter()
                .zip(&weights)
                .fold(0.0, |acc, (x, w)| acc + w * x)
        })
        .collect())
}
/// Full-sample, Christiano–Fitzgerald-style band-pass filter for periods between `low` and
/// `high` observations.
///
/// For each observation `t`, the cycle is a weighted sum over the whole sample:
/// `cycle[t] = Σ_j (b(j) - c_t) · ts[t + j]`, with `j` running from `-t` to `n - 1 - t`.
/// Here `b(j)` are the ideal band-pass weights and `c_t` is the mean of those `n` weights,
/// so each row of weights sums to zero and constants map to a zero cycle. The trend is the
/// residual `ts - cycle`.
///
/// # Returns
/// `(cycle, trend)`. This order is the reverse of [`hp_filter`], which returns
/// `(trend, cycle)`.
///
/// # Errors
/// - `InsufficientData` if `ts.len() < 4`.
/// - `InvalidInput` unless `1 < low < high`. NaN is rejected.
pub fn christiano_fitzgerald(
    ts: &[f64],
    low: f64,
    high: f64,
) -> Result<(Vec<f64>, Vec<f64>)> {
    let n = ts.len();
    if n < 4 {
        return Err(TimeSeriesError::InsufficientData {
            message: "christiano_fitzgerald requires at least 4 samples".to_string(),
            required: 4,
            actual: n,
        });
    }
    // Comparisons with NaN are all false, so NaN is rejected explicitly.
    if low.is_nan() || high.is_nan() || low <= 1.0 || high <= low {
        return Err(TimeSeriesError::InvalidInput(
            "CF filter: need 1 < low < high".to_string(),
        ));
    }
    let omega_l = 2.0 * PI / high;
    let omega_h = 2.0 * PI / low;
    // Ideal band-pass weight b(j) for every lag j in -(n-1)..=(n-1), stored at index j + n - 1.
    // Computing them once removes the repeated `sin` calls from the O(n^2) loop below.
    let max_lag = (n - 1) as i64;
    let weights: Vec<f64> = (-max_lag..=max_lag)
        .map(|j| {
            if j == 0 {
                (omega_h - omega_l) / PI
            } else {
                ((omega_h * j as f64).sin() - (omega_l * j as f64).sin()) / (PI * j as f64)
            }
        })
        .collect();
    let n_f = n as f64;
    let cycle: Vec<f64> = (0..n)
        .map(|t| {
            // Lags -t..=n-1-t; this slice lines up element-wise with ts[0..n].
            let row = &weights[n - 1 - t..2 * n - 1 - t];
            let correction = row.iter().sum::<f64>() / n_f;
            row.iter()
                .zip(ts)
                .fold(0.0, |acc, (b, y)| acc + (b - correction) * y)
        })
        .collect();
    let trend: Vec<f64> = ts.iter().zip(&cycle).map(|(y, c)| y - c).collect();
    Ok((cycle, trend))
}
/// Wavelet selector accepted by [`wavelet_decompose_ts`] and [`reconstruct_wavelet`];
/// an alias for `Wavelet` imported from `scirs2_fft::wavelet_packets`.
pub type WaveletType = Wavelet;
/// Split `ts` into `n_levels + 1` component series using a wavelet packet tree.
///
/// The tree is built with `wpd` to depth `n_levels`. Components are produced in this order:
/// - For each level `1..=n_levels`, the series reconstructed (via `wp_reconstruct`) from the
///   single node `(level, 1)`, used as that level's detail.
/// - Last, the series reconstructed from node `(n_levels, 0)`, used as the approximation.
///
/// A node absent from the tree contributes an all-zero series of length `ts.len()`. The
/// length of reconstructed series is whatever `wp_reconstruct` returns; presumably
/// `ts.len()` (TODO confirm for non-power-of-two inputs).
///
/// # Errors
/// - `InsufficientData` if `ts.len() < 4`.
/// - `InvalidInput` if `n_levels == 0`.
/// - `ComputationError` if decomposition or reconstruction fails.
pub fn wavelet_decompose_ts(
    ts: &[f64],
    wavelet: WaveletType,
    n_levels: usize,
) -> Result<Vec<Vec<f64>>> {
    let len = ts.len();
    if len < 4 {
        return Err(TimeSeriesError::InsufficientData {
            message: "wavelet_decompose_ts requires at least 4 samples".to_string(),
            required: 4,
            actual: len,
        });
    }
    if n_levels == 0 {
        return Err(TimeSeriesError::InvalidInput(
            "n_levels must be >= 1".to_string(),
        ));
    }
    let tree = wpd(ts, wavelet, n_levels)
        .map_err(|e| TimeSeriesError::ComputationError(e.to_string()))?;

    // Signal contribution of one packet node; a missing node contributes zeros.
    let node_signal = |level, index| -> Result<Vec<f64>> {
        if let Some(node) = tree.get(level, index) {
            let basis = vec![node.clone()];
            wp_reconstruct(&tree, &basis)
                .map_err(|e| TimeSeriesError::ComputationError(e.to_string()))
        } else {
            Ok(vec![0.0; len])
        }
    };

    let mut parts: Vec<Vec<f64>> = Vec::with_capacity(n_levels + 1);
    for level in 1..=n_levels {
        parts.push(node_signal(level, 1)?);
    }
    parts.push(node_signal(n_levels, 0)?);
    Ok(parts)
}
/// Recombine wavelet components (for example the output of [`wavelet_decompose_ts`]) into
/// one series.
///
/// The result is the element-wise sum of all components. `_wavelet` is accepted for API
/// symmetry with the decomposition but is not used.
///
/// # Errors
/// - `InvalidInput` if `components` is empty.
/// - `DimensionMismatch` (reporting the first offending length) if any component's length
///   differs from the first component's.
pub fn reconstruct_wavelet(components: &[Vec<f64>], _wavelet: WaveletType) -> Result<Vec<f64>> {
    if components.is_empty() {
        return Err(TimeSeriesError::InvalidInput(
            "components must be non-empty".to_string(),
        ));
    }
    let n = components[0].len();
    if let Some(bad) = components.iter().find(|comp| comp.len() != n) {
        return Err(TimeSeriesError::DimensionMismatch {
            expected: n,
            actual: bad.len(),
        });
    }
    let mut out = vec![0.0_f64; n];
    for comp in components {
        for (o, c) in out.iter_mut().zip(comp) {
            *o += c;
        }
    }
    Ok(out)
}
pub use scirs2_fft::wavelet_packets::{WaveletPacketNode, WaveletPacketTree};
/// Frequency axis for a real FFT of `n` samples taken at sampling frequency `fs`.
///
/// Returns the output of `scirs2_fft::rfftfreq(n, 1 / fs)`. Assuming NumPy-compatible
/// `rfftfreq` semantics (TODO confirm), these are the non-negative bin frequencies in the
/// units of `fs`.
///
/// # Errors
/// - `InvalidInput` if `fs` is not a positive number (zero, negative or NaN).
/// - `ComputationError` if `rfftfreq` fails.
pub fn rfft_freq_axis(n: usize, fs: f64) -> Result<Vec<f64>> {
    // `NaN <= 0.0` is false, so NaN must be rejected explicitly.
    if fs.is_nan() || fs <= 0.0 {
        return Err(TimeSeriesError::InvalidInput(
            "sampling frequency must be positive".to_string(),
        ));
    }
    rfftfreq(n, 1.0 / fs).map_err(|e| TimeSeriesError::ComputationError(e.to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    // `n` samples of a unit-amplitude sine at `freq`, sampled at `fs`.
    fn make_sine(n: usize, freq: f64, fs: f64) -> Vec<f64> {
        (0..n)
            .map(|i| (2.0 * PI * freq * i as f64 / fs).sin())
            .collect()
    }

    #[test]
    fn test_hp_filter_trend_smooth() {
        let n = 100;
        let ts: Vec<f64> = (0..n)
            .map(|i| i as f64 / 10.0 + (2.0 * PI * 0.1 * i as f64).sin())
            .collect();
        let (trend, cycle) = hp_filter(&ts, 1600.0).expect("hp_filter failed");
        assert_eq!(trend.len(), n);
        assert_eq!(cycle.len(), n);
        for i in 0..n {
            assert!(
                (trend[i] + cycle[i] - ts[i]).abs() < 1e-8,
                "trend+cycle != original at i={i}"
            );
        }
        let cycle_mean: f64 = cycle.iter().sum::<f64>() / n as f64;
        assert!(cycle_mean.abs() < 0.5, "cycle mean too large: {cycle_mean}");
    }

    #[test]
    fn test_hp_filter_errors() {
        let ts = vec![1.0, 2.0];
        assert!(hp_filter(&ts, 1600.0).is_err());
        let ts4 = vec![1.0, 2.0, 3.0, 4.0];
        assert!(hp_filter(&ts4, -1.0).is_err());
    }

    #[test]
    fn test_hp_filter_zero_lambda_is_identity() {
        // With lambda = 0 the system matrix is the identity, so the trend is the series itself.
        let ts: Vec<f64> = (0..10).map(|i| (i as f64 * 0.7).cos()).collect();
        let (trend, cycle) = hp_filter(&ts, 0.0).expect("hp_filter failed");
        assert_eq!(trend, ts);
        assert!(cycle.iter().all(|c| *c == 0.0));
    }

    #[test]
    fn test_band_ldl_solve_preserves_linear() {
        // I + D'D (lambda = 1): D annihilates linear sequences, so a linear
        // right-hand side is its own solution.
        let d = [2.0, 6.0, 6.0, 2.0];
        let e1 = [-2.0, -4.0, -2.0];
        let e2 = [1.0, 1.0];
        let rhs = [0.0, 1.0, 2.0, 3.0];
        let x = band_ldl_solve(&d, &e1, &e2, &rhs).expect("band_ldl_solve failed");
        assert_eq!(x.len(), rhs.len());
        for (xi, ri) in x.iter().zip(rhs.iter()) {
            assert!((xi - ri).abs() < 1e-12, "{xi} vs {ri}");
        }
    }

    #[test]
    fn test_bk_filter_output_length() {
        let n = 100;
        let k = 12;
        let ts: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let out = bandpass_filter_bk(&ts, 6.0, 32.0, k).expect("BK filter failed");
        assert_eq!(out.len(), n - 2 * k);
    }

    #[test]
    fn test_bk_filter_removes_trend() {
        let n = 200;
        let k = 12;
        let ts: Vec<f64> = (0..n).map(|i| i as f64).collect();
        let out = bandpass_filter_bk(&ts, 6.0, 32.0, k).expect("BK filter failed");
        let max_val = out.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
        assert!(max_val < 1e-8, "BK should remove linear trend, max_val={max_val}");
    }

    #[test]
    fn test_bk_filter_insufficient_data() {
        let ts = vec![1.0; 10];
        assert!(bandpass_filter_bk(&ts, 6.0, 32.0, 12).is_err());
    }

    #[test]
    fn test_cf_filter_output_length() {
        let n = 100;
        let ts: Vec<f64> = (0..n).map(|i| i as f64 % 10.0).collect();
        let (cycle, trend) = christiano_fitzgerald(&ts, 6.0, 32.0)
            .expect("CF filter failed");
        assert_eq!(cycle.len(), n);
        assert_eq!(trend.len(), n);
    }

    #[test]
    fn test_cf_filter_reconstruction() {
        let n = 50;
        let ts: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * 0.05 * i as f64).sin() + (i as f64) * 0.01)
            .collect();
        let (cycle, trend) = christiano_fitzgerald(&ts, 6.0, 32.0)
            .expect("CF filter failed");
        for i in 0..n {
            assert!(
                (cycle[i] + trend[i] - ts[i]).abs() < 1e-10,
                "cycle+trend != ts at i={i}"
            );
        }
    }

    #[test]
    fn test_butterworth_gain_shape() {
        assert_eq!(butterworth_bandpass_gain(0.0, 3.0, 10.0, 4), 0.0);
        let centre = (3.0_f64 * 10.0).sqrt();
        assert!(butterworth_bandpass_gain(centre, 3.0, 10.0, 4) > 0.98);
        assert!(butterworth_bandpass_gain(40.0, 3.0, 10.0, 4) < 0.01);
    }

    #[test]
    fn test_bandpass_filter_removes_out_of_band() {
        let fs = 100.0;
        let n = 512;
        // 5 Hz lies inside the 3-10 Hz pass-band; 40 Hz lies well outside it.
        let in_band = make_sine(n, 5.0, fs);
        let out_of_band = make_sine(n, 40.0, fs);
        let ts: Vec<f64> = in_band
            .iter()
            .zip(out_of_band.iter())
            .map(|(a, b)| a + b)
            .collect();
        let filtered =
            bandpass_filter_series(&ts, 3.0, 10.0, fs, 4).expect("bandpass_filter failed");
        assert_eq!(filtered.len(), n);
        let in_energy: f64 = ts.iter().map(|v| v * v).sum();
        let out_energy: f64 = filtered.iter().map(|v| v * v).sum();
        assert!(
            out_energy < in_energy * 0.7,
            "output energy {out_energy:.3} should be less than input energy {in_energy:.3}"
        );
    }

    #[test]
    fn test_wavelet_mra_component_count() {
        let n = 64;
        let ts: Vec<f64> = (0..n).map(|i| (i as f64 * 0.1).sin()).collect();
        let levels = 3;
        let components = wavelet_decompose_ts(&ts, WaveletType::Db4, levels)
            .expect("wavelet decompose failed");
        assert_eq!(components.len(), levels + 1);
        for comp in &components {
            assert_eq!(comp.len(), n, "component length mismatch");
        }
    }

    #[test]
    fn test_wavelet_reconstruct_matches_sum() {
        let n = 64;
        let ts: Vec<f64> = (0..n).map(|i| (i as f64 * 0.1).sin()).collect();
        let components =
            wavelet_decompose_ts(&ts, WaveletType::Haar, 2).expect("decompose failed");
        let recon =
            reconstruct_wavelet(&components, WaveletType::Haar).expect("reconstruct failed");
        assert_eq!(recon.len(), n);
        let sum: Vec<f64> = (0..n)
            .map(|i| components.iter().map(|c| c[i]).sum::<f64>())
            .collect();
        for i in 0..n {
            assert!(
                (recon[i] - sum[i]).abs() < 1e-10,
                "reconstruct != sum at i={i}: {} vs {}",
                recon[i],
                sum[i]
            );
        }
    }
}