use std::f64::consts::PI;
use scirs2_core::numeric::Complex64;
use crate::error::{FFTError, FFTResult};
/// Parameters describing a Morlet scattering filter bank.
///
/// Create one with [`FilterBankConfig::new`] and optionally refine it with
/// [`FilterBankConfig::with_xi0`] and [`FilterBankConfig::with_sigma`].
#[derive(Debug, Clone)]
pub struct FilterBankConfig {
    /// Number of octaves spanned by each wavelet family.
    pub j_max: usize,
    /// Wavelets per octave; entry `i` configures the wavelet family at index `i`.
    pub quality_factors: Vec<usize>,
    /// Length of the input signal in samples.
    pub signal_length: usize,
    /// Centre frequency (radians per sample) of the first, highest-frequency
    /// wavelet in each family. Defaults to π.
    pub xi0: f64,
    /// Optional explicit base width for the Gaussians. `None` derives the width
    /// from each family's quality factor.
    pub sigma: Option<f64>,
}
impl FilterBankConfig {
    /// Creates a configuration with `xi0 = π` and no custom `sigma`.
    pub fn new(j_max: usize, quality_factors: Vec<usize>, signal_length: usize) -> Self {
        let default_xi0 = PI;
        Self {
            sigma: None,
            xi0: default_xi0,
            signal_length,
            quality_factors,
            j_max,
        }
    }
    /// Returns this configuration with `xi0` replaced; all other fields are kept.
    #[must_use]
    pub fn with_xi0(self, xi0: f64) -> Self {
        Self { xi0, ..self }
    }
    /// Returns this configuration with an explicit base `sigma`; all other
    /// fields are kept.
    #[must_use]
    pub fn with_sigma(self, sigma: f64) -> Self {
        Self {
            sigma: Some(sigma),
            ..self
        }
    }
}
/// One Morlet-style band-pass filter, stored as its sampled frequency response.
#[derive(Debug, Clone)]
pub struct MorletWavelet {
    /// Centre of the Gaussian envelope in radians per sample:
    /// `xi0 / 2^(linear_index / q)`.
    pub xi: f64,
    /// Width parameter of the envelope. The response varies as
    /// `exp(-0.5 * (omega - xi)^2 * sigma^2)`, so a larger `sigma` gives a
    /// narrower band. It equals the family's base sigma times `2^(linear_index / q)`.
    pub sigma: f64,
    /// Octave index: `linear_index / q`.
    pub j: usize,
    /// Sub-octave index within octave `j`: `linear_index % q`.
    pub q_index: usize,
    /// Position of this wavelet within its family; index 0 has the highest `xi`.
    pub linear_index: usize,
    /// Frequency response on an `fft_size`-point grid (bin `k` corresponds to
    /// `omega = 2*pi*k / fft_size`). It is scaled to unit L2 norm when its
    /// energy exceeds `1e-15`.
    pub freq_response: Vec<Complex64>,
}
/// A scattering filter bank: one family of Morlet wavelets per quality factor
/// plus a single low-pass scaling function, all stored as frequency responses.
#[derive(Debug, Clone)]
pub struct FilterBank {
    /// Configuration the bank was built from.
    pub config: FilterBankConfig,
    /// Length of every frequency response: `signal_length` rounded up to a
    /// power of two.
    pub fft_size: usize,
    /// `wavelets[order]` is the family built with `config.quality_factors[order]`
    /// and holds `j_max * quality_factors[order]` wavelets.
    pub wavelets: Vec<Vec<MorletWavelet>>,
    /// Frequency response of the low-pass scaling function (`fft_size` bins).
    pub phi: Vec<Complex64>,
}
impl FilterBank {
    /// Validates `config`, then builds every wavelet family and the scaling function.
    ///
    /// The scaling function's width comes from the first quality factor, or
    /// from the custom `sigma` if one is set.
    ///
    /// # Errors
    ///
    /// Returns [`FFTError::ValueError`] if:
    /// - `j_max` is zero;
    /// - `quality_factors` is empty or contains a zero;
    /// - `signal_length` is zero, or rounding it up to a power of two would overflow `usize`;
    /// - `xi0` is not finite and strictly positive;
    /// - a custom `sigma` is set and is not finite and strictly positive.
    pub fn new(config: FilterBankConfig) -> FFTResult<Self> {
        if config.j_max == 0 {
            return Err(FFTError::ValueError("j_max must be at least 1".to_string()));
        }
        if config.quality_factors.is_empty() {
            return Err(FFTError::ValueError(
                "quality_factors must have at least one entry".to_string(),
            ));
        }
        if let Some(i) = config.quality_factors.iter().position(|&q| q == 0) {
            return Err(FFTError::ValueError(format!(
                "quality_factors[{i}] must be at least 1"
            )));
        }
        if config.signal_length == 0 {
            return Err(FFTError::ValueError(
                "signal_length must be positive".to_string(),
            ));
        }
        // Without this check, a NaN/infinite xi0 yields NaN-filled filters.
        // xi0 = 0 collapses every wavelet to zero.
        if !(config.xi0.is_finite() && config.xi0 > 0.0) {
            return Err(FFTError::ValueError(format!(
                "xi0 must be finite and positive, got {}",
                config.xi0
            )));
        }
        // The same failure modes apply to an explicit sigma. With sigma = 0
        // every Gaussian is identically 1, so the wavelets cancel to zero.
        if let Some(sigma) = config.sigma {
            if !(sigma.is_finite() && sigma > 0.0) {
                return Err(FFTError::ValueError(format!(
                    "sigma must be finite and positive, got {sigma}"
                )));
            }
        }
        // Plain `next_power_of_two` panics in debug builds and wraps to 0 in
        // release builds when the result does not fit in `usize`.
        let fft_size = config
            .signal_length
            .checked_next_power_of_two()
            .ok_or_else(|| {
                FFTError::ValueError(format!(
                    "signal_length {} is too large to round up to a power of two",
                    config.signal_length
                ))
            })?;
        let mut all_wavelets = Vec::with_capacity(config.quality_factors.len());
        for (order, &q) in config.quality_factors.iter().enumerate() {
            let sigma_base = compute_sigma_from_q(q, config.xi0, config.sigma);
            let wavelets =
                build_morlet_wavelets(config.j_max, q, config.xi0, sigma_base, fft_size, order)?;
            all_wavelets.push(wavelets);
        }
        let sigma_phi = compute_sigma_from_q(config.quality_factors[0], config.xi0, config.sigma);
        let phi = build_scaling_function(config.j_max, sigma_phi, fft_size)?;
        Ok(Self {
            config,
            fft_size,
            wavelets: all_wavelets,
            phi,
        })
    }
    /// Number of wavelets in the family at index 0 (0 if there is none).
    pub fn num_first_order(&self) -> usize {
        self.wavelets.first().map_or(0, |w| w.len())
    }
    /// Number of wavelets in the family at index 1 (0 if there is none).
    pub fn num_second_order(&self) -> usize {
        self.wavelets.get(1).map_or(0, |w| w.len())
    }
    /// Total number of wavelets across all families.
    pub fn total_wavelets(&self) -> usize {
        self.wavelets.iter().map(|w| w.len()).sum()
    }
}
/// Chooses the base Gaussian width for one wavelet family.
///
/// An explicit `custom_sigma` always takes precedence. Otherwise the width is
/// `xi0 / (q * sqrt(2 ln 2))`.
fn compute_sigma_from_q(q: usize, xi0: f64, custom_sigma: Option<f64>) -> f64 {
    match custom_sigma {
        Some(explicit) => explicit,
        None => {
            // sqrt(2 ln 2): the point where a unit Gaussian exp(-x^2/2) drops to 1/2.
            let half_max_factor = (2.0_f64 * 2.0_f64.ln()).sqrt();
            let q_f = q as f64;
            xi0 / (q_f * half_max_factor)
        }
    }
}
fn build_morlet_wavelets(
j_max: usize,
q: usize,
xi0: f64,
sigma_base: f64,
fft_size: usize,
_order: usize,
) -> FFTResult<Vec<MorletWavelet>> {
let total = j_max * q;
let mut wavelets = Vec::with_capacity(total);
let n = fft_size;
for idx in 0..total {
let j = idx / q;
let q_index = idx % q;
let scale = 2.0_f64.powf(idx as f64 / q as f64);
let xi = xi0 / scale;
let sigma = sigma_base * scale;
let mut freq_response = vec![Complex64::new(0.0, 0.0); n];
let n_f64 = n as f64;
for k in 0..n {
let omega = 2.0 * PI * k as f64 / n_f64;
let diff_pos = omega - xi;
let gauss_pos = (-0.5 * diff_pos * diff_pos * sigma * sigma).exp();
let gauss_correction = (-0.5 * xi * xi * sigma * sigma).exp();
let gauss_zero = (-0.5 * omega * omega * sigma * sigma).exp();
let value = gauss_pos - gauss_correction * gauss_zero;
freq_response[k] = Complex64::new(value, 0.0);
}
let energy: f64 = freq_response.iter().map(|c| c.norm_sqr()).sum();
if energy > 1e-15 {
let norm_factor = 1.0 / energy.sqrt();
for c in &mut freq_response {
*c = Complex64::new(c.re * norm_factor, c.im * norm_factor);
}
}
wavelets.push(MorletWavelet {
xi,
sigma,
j,
q_index,
linear_index: idx,
freq_response,
});
}
Ok(wavelets)
}
/// Builds the low-pass scaling function on an `fft_size`-point frequency grid.
///
/// The response is a Gaussian centred on DC with width `sigma_base * 2^j_max`.
/// Bins above Nyquist are evaluated at their negative-frequency equivalent
/// (`omega - 2*pi`), so the response is symmetric about DC. The result is
/// scaled to unit L2 norm when its energy exceeds `1e-15`.
fn build_scaling_function(
    j_max: usize,
    sigma_base: f64,
    fft_size: usize,
) -> FFTResult<Vec<Complex64>> {
    let bins = fft_size as f64;
    let width = sigma_base * 2.0_f64.powi(j_max as i32);
    let mut response: Vec<Complex64> = (0..fft_size)
        .map(|k| {
            let raw = 2.0 * PI * k as f64 / bins;
            let signed = if raw > PI { raw - 2.0 * PI } else { raw };
            Complex64::new((-0.5 * signed * signed * width * width).exp(), 0.0)
        })
        .collect();
    let energy: f64 = response.iter().map(|c| c.norm_sqr()).sum();
    if energy > 1e-15 {
        let gain = 1.0 / energy.sqrt();
        response.iter_mut().for_each(|c| {
            c.re *= gain;
            c.im *= gain;
        });
    }
    Ok(response)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a filter bank, failing the test with the shared message on error.
    fn build(config: FilterBankConfig) -> FilterBank {
        FilterBank::new(config).expect("filter bank creation should succeed")
    }

    #[test]
    fn test_filter_bank_creation() {
        let fb = build(FilterBankConfig::new(4, vec![8, 1], 1024));
        assert_eq!(fb.num_first_order(), 32);
        assert_eq!(fb.num_second_order(), 4);
        assert_eq!(fb.fft_size, 1024);
        assert_eq!(fb.phi.len(), 1024);
    }

    #[test]
    fn test_wavelet_frequency_peaks() {
        let fb = build(FilterBankConfig::new(3, vec![4], 512));
        let bins = fb.fft_size as f64;
        for w in &fb.wavelets[0] {
            let (peak_bin, _) = w
                .freq_response
                .iter()
                .enumerate()
                .max_by(|(_, a), (_, b)| {
                    a.norm_sqr()
                        .partial_cmp(&b.norm_sqr())
                        .unwrap_or(std::cmp::Ordering::Equal)
                })
                .expect("should find peak");
            let peak_omega = 2.0 * PI * peak_bin as f64 / bins;
            // Relative error only makes sense away from a zero centre frequency.
            let rel_error = match w.xi > 1e-6 {
                true => (peak_omega - w.xi).abs() / w.xi,
                false => peak_omega.abs(),
            };
            assert!(
                rel_error < 0.5,
                "wavelet j={} q={}: peak_omega={:.4} vs xi={:.4}, rel_error={:.4}",
                w.j,
                w.q_index,
                peak_omega,
                w.xi,
                rel_error
            );
        }
    }

    #[test]
    fn test_dyadic_scaling() {
        let fb = build(FilterBankConfig::new(4, vec![1], 1024));
        // With one wavelet per octave, consecutive centre frequencies halve.
        for (i, pair) in fb.wavelets[0].windows(2).enumerate() {
            let ratio = pair[0].xi / pair[1].xi;
            assert!(
                (ratio - 2.0).abs() < 0.1,
                "octave {i} to {}: ratio={:.4}, expected ~2.0",
                i + 1,
                ratio
            );
        }
    }

    #[test]
    fn test_filter_bank_invalid_config() {
        let bad_configs = [
            FilterBankConfig::new(0, vec![8], 1024),
            FilterBankConfig::new(4, vec![], 1024),
            FilterBankConfig::new(4, vec![0], 1024),
            FilterBankConfig::new(4, vec![8], 0),
        ];
        for config in bad_configs {
            assert!(FilterBank::new(config).is_err());
        }
    }

    #[test]
    fn test_wavelet_l2_normalization() {
        let fb = build(FilterBankConfig::new(3, vec![4], 256));
        for w in &fb.wavelets[0] {
            let energy = w.freq_response.iter().map(|c| c.norm_sqr()).sum::<f64>();
            assert!(
                (energy - 1.0).abs() < 1e-10,
                "wavelet j={} q={} has energy {:.6}, expected 1.0",
                w.j,
                w.q_index,
                energy
            );
        }
    }

    #[test]
    fn test_scaling_function_is_lowpass() {
        let fb = build(FilterBankConfig::new(3, vec![4], 512));
        let dc_mag = fb.phi[0].norm_sqr();
        let nyquist_mag = fb.phi[fb.fft_size / 2].norm_sqr();
        assert!(
            dc_mag > nyquist_mag,
            "scaling function should peak at DC: dc={:.6} vs nyquist={:.6}",
            dc_mag,
            nyquist_mag
        );
    }
}