use std::f64::consts::PI;
use crate::error::{FFTError, FFTResult};
/// An `m × k` matrix of polynomials, each stored as a dense coefficient vector.
///
/// Entry `(row, col)` lives in `data[row][col]`. [`PolyphaseMatrix::zeros`] gives every entry
/// the same number of coefficients; [`PolyphaseMatrix::set`] does not enforce any length.
#[derive(Debug, Clone)]
pub struct PolyphaseMatrix {
    /// Number of rows.
    pub m: usize,
    /// Number of columns.
    pub k: usize,
    /// Polynomial entries, indexed as `data[row][col]`.
    pub data: Vec<Vec<Vec<f64>>>,
}

impl PolyphaseMatrix {
    /// Creates an `m × k` matrix whose entries are all-zero polynomials with `poly_len`
    /// coefficients.
    pub fn zeros(m: usize, k: usize, poly_len: usize) -> Self {
        let zero_poly = vec![0.0_f64; poly_len];
        let zero_row = vec![zero_poly; k];
        Self {
            m,
            k,
            data: vec![zero_row; m],
        }
    }

    /// Returns the polynomial at `(m_idx, k_idx)`, or `None` when either index falls outside
    /// `data`.
    pub fn get(&self, m_idx: usize, k_idx: usize) -> Option<&Vec<f64>> {
        self.data.get(m_idx).and_then(|row| row.get(k_idx))
    }

    /// Replaces the polynomial at `(m_idx, k_idx)` with `poly`.
    ///
    /// # Errors
    /// Returns `FFTError::DimensionError` when `m_idx >= self.m` or `k_idx >= self.k`.
    pub fn set(&mut self, m_idx: usize, k_idx: usize, poly: Vec<f64>) -> FFTResult<()> {
        let row_ok = m_idx < self.m;
        let col_ok = k_idx < self.k;
        if row_ok && col_ok {
            self.data[m_idx][k_idx] = poly;
            return Ok(());
        }
        Err(FFTError::DimensionError(format!(
            "index ({m_idx}, {k_idx}) out of bounds for {m}×{k} matrix",
            m = self.m,
            k = self.k
        )))
    }
}
/// Splits filter `h` into its `m` type-1 polyphase components.
///
/// Component `r` holds `h[r], h[r + m], h[r + 2m], …`, i.e. `comps[r][c] == h[c * m + r]`.
/// Every component has `ceil(h.len() / m)` entries; positions past the end of `h` are `0.0`.
///
/// # Errors
/// Returns `FFTError::ValueError` if `h` is empty (checked first) or `m == 0`.
pub fn polyphase_decompose(h: &[f64], m: usize) -> FFTResult<Vec<Vec<f64>>> {
    if h.is_empty() {
        return Err(FFTError::ValueError(
            "filter h must not be empty".to_string(),
        ));
    }
    if m == 0 {
        return Err(FFTError::ValueError(
            "decimation factor m must be >= 1".to_string(),
        ));
    }
    let phases = h.len().div_ceil(m);
    let comps = (0..m)
        .map(|row| {
            // Every m-th coefficient starting at `row`, then zero-pad to the common length.
            let mut comp: Vec<f64> = h.iter().skip(row).step_by(m).copied().collect();
            comp.resize(phases, 0.0);
            comp
        })
        .collect();
    Ok(comps)
}
/// Full linear convolution of `a` and `b` (length `a.len() + b.len() - 1`).
///
/// Returns an empty vector if either input is empty. Output `n` is `Σ_i a[i]·b[n − i]`,
/// accumulated in ascending `i`.
fn convolve_full(a: &[f64], b: &[f64]) -> Vec<f64> {
    if a.is_empty() || b.is_empty() {
        return Vec::new();
    }
    let last_a = a.len() - 1;
    let last_b = b.len() - 1;
    (0..=last_a + last_b)
        .map(|n| {
            // Only i with 0 <= n - i <= last_b and i <= last_a contribute.
            let first = n.saturating_sub(last_b);
            let last = n.min(last_a);
            (first..=last).fold(0.0_f64, |acc, i| acc + a[i] * b[n - i])
        })
        .collect()
}
/// Filters `signal` with `filter` (full linear convolution) and keeps every `factor`-th
/// output sample, starting at index 0.
///
/// The result has `ceil((signal.len() + filter.len() - 1) / factor)` samples; output `k` is
/// `Σ_j filter[j]·signal[k·factor − j]` over the indices that fall inside `signal`.
/// Callers must ensure `filter` is non-empty and `factor >= 1`.
fn filter_downsample(signal: &[f64], filter: &[f64], factor: usize) -> Vec<f64> {
    let full_len = signal.len() + filter.len() - 1;
    (0..full_len.div_ceil(factor))
        .map(|k| {
            let t = k * factor;
            filter
                .iter()
                .enumerate()
                .filter_map(|(j, &tap)| {
                    // Skip taps whose source sample lies before the start or past the end.
                    t.checked_sub(j)
                        .and_then(|idx| signal.get(idx))
                        .map(|&x| x * tap)
                })
                .fold(0.0_f64, |acc, term| acc + term)
        })
        .collect()
}
fn upsample_filter(subband: &[f64], filter: &[f64], target_len: usize) -> Vec<f64> {
let flen = filter.len();
let up = subband.len();
let factor = if up == 0 { 1 } else { (target_len + up - 1) / up };
let mut out = vec![0.0_f64; target_len];
for n in 0..target_len {
let mut acc = 0.0_f64;
for k in 0..up {
let filter_idx = n as isize - (k as isize) * (factor as isize);
if filter_idx >= 0 && (filter_idx as usize) < flen {
acc += subband[k] * filter[filter_idx as usize];
}
}
out[n] = acc;
}
out
}
/// Analysis side of a filter bank: filters `signal` with every filter and decimates each
/// result by `decimation`.
///
/// Subband `i` is the full linear convolution of `signal` with `filters[i]`, keeping every
/// `decimation`-th sample starting at index 0.
///
/// # Errors
/// Returns `FFTError::ValueError` if `filters` is empty, `decimation == 0`, or any individual
/// filter is empty (the message names the first empty filter's index).
pub fn analysis_filter_bank(
    signal: &[f64],
    filters: &[Vec<f64>],
    decimation: usize,
) -> FFTResult<Vec<Vec<f64>>> {
    if filters.is_empty() {
        return Err(FFTError::ValueError(
            "filters must be non-empty".to_string(),
        ));
    }
    if decimation == 0 {
        return Err(FFTError::ValueError(
            "decimation factor must be >= 1".to_string(),
        ));
    }
    if let Some(m) = filters.iter().position(Vec::is_empty) {
        return Err(FFTError::ValueError(format!("filter[{m}] is empty")));
    }
    let subbands = filters
        .iter()
        .map(|h| filter_downsample(signal, h, decimation))
        .collect();
    Ok(subbands)
}
/// Synthesis (reconstruction) side of a filter bank.
///
/// Each subband is expanded by `interpolation` (zeros inserted between samples) and filtered
/// with its synthesis filter, and the branch outputs are summed. The output has
/// `max_len * interpolation` samples, where `max_len` is the length of the longest subband.
///
/// Subbands may differ in length; this happens when [`analysis_filter_bank`] is used with
/// filters of different lengths. Shorter subbands are treated as zero-padded to `max_len`.
/// Trailing zeros do not change a branch's contribution, and padding guarantees every branch
/// is expanded by exactly `interpolation`.
///
/// # Errors
/// - `FFTError::DimensionError` if `subbands` and `filters` differ in length.
/// - `FFTError::ValueError` if they are empty or `interpolation == 0`.
pub fn synthesis_filter_bank(
    subbands: &[Vec<f64>],
    filters: &[Vec<f64>],
    interpolation: usize,
) -> FFTResult<Vec<f64>> {
    if subbands.len() != filters.len() {
        return Err(FFTError::DimensionError(format!(
            "subbands ({}) and filters ({}) must have the same length",
            subbands.len(),
            filters.len()
        )));
    }
    if subbands.is_empty() {
        return Err(FFTError::ValueError("subbands must be non-empty".to_string()));
    }
    if interpolation == 0 {
        return Err(FFTError::ValueError(
            "interpolation factor must be >= 1".to_string(),
        ));
    }
    // `upsample_filter` infers each branch's expansion factor as
    // ceil(target_len / subband.len()). Every branch must therefore be presented at the same
    // length; otherwise a subband shorter or longer than the reference would be upsampled by
    // the wrong factor and silently corrupt the output.
    let sub_len = subbands.iter().map(Vec::len).max().unwrap_or(0);
    let target_len = sub_len * interpolation;
    let mut output = vec![0.0_f64; target_len];
    let mut padded: Vec<f64> = Vec::new();
    for (subband, filter) in subbands.iter().zip(filters.iter()) {
        let input: &[f64] = if subband.len() == sub_len {
            subband
        } else {
            padded.clear();
            padded.extend_from_slice(subband);
            padded.resize(sub_len, 0.0);
            &padded
        };
        let branch = upsample_filter(input, filter, target_len);
        for (o, b) in output.iter_mut().zip(branch.iter()) {
            *o += b;
        }
    }
    Ok(output)
}
/// Builds an `m`-channel cosine-modulated filter bank from a lowpass prototype.
///
/// Channel `k` is `h_k[n] = 2·p[n]·cos(ω_k·(n − (N − 1)/2) + θ_k)`, where `p` is the
/// prototype of length `N`, `ω_k = (2k + 1)·π / (2m)`, and `θ_k = +π/4` for even `k` and
/// `−π/4` for odd `k`.
///
/// # Arguments
/// * `prototype` — lowpass prototype whose length must be a non-zero multiple of `2m`; when
///   `None`, a Kaiser-windowed prototype of length `8m` is designed.
/// * `m` — number of channels (at least 2).
///
/// # Errors
/// Returns `FFTError::ValueError` if `m < 2`, or if the supplied prototype is empty or its
/// length is not a multiple of `2m`.
pub fn cosine_modulated_fb(
    prototype: Option<&[f64]>,
    m: usize,
) -> FFTResult<Vec<Vec<f64>>> {
    if m < 2 {
        return Err(FFTError::ValueError(
            "number of channels m must be >= 2".to_string(),
        ));
    }
    let proto: Vec<f64> = if let Some(p) = prototype {
        if p.is_empty() {
            return Err(FFTError::ValueError("prototype is empty".to_string()));
        }
        let block = 2 * m;
        if p.len() % block != 0 {
            return Err(FFTError::ValueError(format!(
                "prototype length {} must be a multiple of 2*m={}",
                p.len(),
                block
            )));
        }
        p.to_vec()
    } else {
        design_kaiser_prototype(m, 4)
    };

    // Centre of symmetry of the prototype; modulation is referenced to it.
    let centre = (proto.len() as f64 - 1.0) / 2.0;
    let band_denominator = 2.0 * m as f64;
    let filters = (0..m)
        .map(|k| {
            let omega = (2 * k + 1) as f64 * PI / band_denominator;
            let theta = if k % 2 == 0 { PI / 4.0 } else { -PI / 4.0 };
            proto
                .iter()
                .enumerate()
                .map(|(n, &p_n)| {
                    let arg = omega * (n as f64 - centre) + theta;
                    2.0 * p_n * arg.cos()
                })
                .collect::<Vec<f64>>()
        })
        .collect();
    Ok(filters)
}
/// Designs the default lowpass prototype for [`cosine_modulated_fb`]: an ideal lowpass
/// (sinc) of length `2 * m * k` multiplied by a Kaiser window with shape parameter `β = 8`.
///
/// The cosine modulation shifts the prototype to the centre frequencies `(2c + 1)·π / (2m)`,
/// and each channel is meant to cover a band of width `π / m`. The prototype's cutoff
/// therefore has to be half a channel width, `π / (2m)`. A cutoff of `π / m` would make
/// every channel twice as wide as intended, so adjacent channels would overlap completely.
///
/// * `m` — number of channels.
/// * `k` — overlap factor; the prototype spans `k` blocks of `2m` samples.
fn design_kaiser_prototype(m: usize, k: usize) -> Vec<f64> {
    const BETA: f64 = 8.0;
    let n = 2 * m * k;
    let i0_beta = bessel_i0(BETA);
    // Centre of symmetry; for even n it falls between two samples.
    let half = (n as f64 - 1.0) / 2.0;
    // Half of the channel bandwidth π/m.
    let cutoff = PI / (2.0 * m as f64);
    (0..n)
        .map(|i| {
            let t = i as f64 - half;
            let sinc = if t == 0.0 {
                cutoff / PI
            } else {
                (cutoff * t).sin() / (PI * t)
            };
            // Kaiser window: I0(β·sqrt(1 − (t/half)²)) / I0(β); clamp guards against
            // tiny negative rounding at the endpoints.
            let arg = 1.0 - (t / half).powi(2);
            let w = bessel_i0(BETA * arg.max(0.0).sqrt()) / i0_beta;
            sinc * w
        })
        .collect()
}
/// Modified Bessel function of the first kind, order zero: `I₀(x)`.
///
/// Evaluated from the power series `I₀(x) = Σ_k (x/2)^{2k} / (k!)²`, summing until the newest
/// term drops below `1e-15` of the running sum. All terms are non-negative, so the sum grows
/// monotonically. The number of terms needed grows roughly like `|x|/2`.
fn bessel_i0(x: f64) -> f64 {
    // Safety net only (e.g. non-finite input). The previous fixed cap of 40 terms silently
    // truncated the series for large |x| (≈0.4 % low at x = 60); 1000 terms covers every
    // argument up to the point where I₀ overflows f64 (|x| ≈ 713).
    const MAX_TERMS: usize = 1000;
    let half_sq = (x / 2.0).powi(2);
    let mut sum = 1.0_f64;
    let mut term = 1.0_f64;
    for k in 1..=MAX_TERMS {
        // term_k = term_{k-1} · (x/2)² / k²
        term *= half_sq / (k as f64 * k as f64);
        sum += term;
        if term.abs() < sum.abs() * 1e-15 {
            break;
        }
    }
    sum
}
/// Builds a quadrature-mirror pair from a low-pass filter.
///
/// Returns `(lo, hi)` where `lo` is a copy of the input and
/// `hi[k] = (-1)^(n-1-k) · lo[n-1-k]` for `n = lo.len()`: the low-pass taps are reversed
/// and every tap that came from an odd index of `lo` is negated.
///
/// # Errors
/// Returns `FFTError::ValueError` if `lo` is empty.
pub fn qmf_pair(lo: &[f64]) -> FFTResult<(Vec<f64>, Vec<f64>)> {
    let Some(last) = lo.len().checked_sub(1) else {
        return Err(FFTError::ValueError(
            "low-pass filter must not be empty".to_string(),
        ));
    };
    let hi: Vec<f64> = (0..=last)
        .map(|k| {
            let src = last - k;
            let tap = lo[src];
            if src % 2 == 0 {
                tap
            } else {
                -tap
            }
        })
        .collect();
    Ok((lo.to_vec(), hi))
}
pub fn perfect_reconstruction_check(
analysis_filters: &[Vec<f64>],
synthesis_filters: &[Vec<f64>],
m: usize,
) -> FFTResult<bool> {
if analysis_filters.len() != synthesis_filters.len() {
return Err(FFTError::DimensionError(format!(
"analysis ({}) and synthesis ({}) filter counts differ",
analysis_filters.len(),
synthesis_filters.len()
)));
}
if m == 0 {
return Err(FFTError::ValueError(
"decimation factor m must be >= 1".to_string(),
));
}
if analysis_filters.is_empty() {
return Err(FFTError::ValueError("filter banks are empty".to_string()));
}
let max_len: usize = analysis_filters
.iter()
.zip(synthesis_filters.iter())
.map(|(h, g)| h.len() + g.len() - 1)
.max()
.unwrap_or(0);
let mut sum_poly = vec![0.0_f64; max_len];
for (h, g) in analysis_filters.iter().zip(synthesis_filters.iter()) {
let prod = convolve_full(h, g);
for (i, &v) in prod.iter().enumerate() {
sum_poly[i] += v;
}
}
let max_tap = sum_poly
.iter()
.cloned()
.fold(0.0_f64, f64::max);
if max_tap < 1e-15 {
return Ok(false);
}
let tol = 1e-7;
let mut non_zero_count = 0_usize;
for &v in &sum_poly {
if (v / max_tap).abs() > tol {
non_zero_count += 1;
}
}
Ok(non_zero_count == 1)
}
/// Energy of a real signal: the sum of its squared samples, `Σ x[n]²`.
pub fn signal_energy(x: &[f64]) -> f64 {
    let squares = x.iter().map(|sample| sample * sample);
    squares.sum::<f64>()
}
/// Runs `signal` through the analysis bank and then the synthesis bank, using `m` as both
/// the decimation and interpolation factor.
///
/// # Errors
/// Propagates any error from [`analysis_filter_bank`] or [`synthesis_filter_bank`].
pub fn round_trip(
    signal: &[f64],
    analysis_filters: &[Vec<f64>],
    synthesis_filters: &[Vec<f64>],
    m: usize,
) -> FFTResult<Vec<f64>> {
    analysis_filter_bank(signal, analysis_filters, m)
        .and_then(|subbands| synthesis_filter_bank(&subbands, synthesis_filters, m))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Orthonormal Haar analysis pair `[lo, hi]`.
    fn haar_analysis() -> Vec<Vec<f64>> {
        let s2 = 0.5_f64.sqrt();
        vec![vec![s2, s2], vec![s2, -s2]]
    }

    /// Synthesis pair that cancels the aliasing of the Haar analysis pair; together they
    /// reproduce the input delayed by one sample.
    fn haar_synthesis() -> Vec<Vec<f64>> {
        let s2 = 0.5_f64.sqrt();
        vec![vec![s2, s2], vec![-s2, s2]]
    }

    #[test]
    fn test_polyphase_decompose_even() {
        let h = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let comps = polyphase_decompose(&h, 2).expect("decompose");
        assert_eq!(comps.len(), 2);
        assert_eq!(comps[0], vec![1.0, 3.0, 5.0]);
        assert_eq!(comps[1], vec![2.0, 4.0, 6.0]);
    }

    #[test]
    fn test_polyphase_decompose_three_channels() {
        let h: Vec<f64> = (0..9).map(|i| i as f64).collect();
        let comps = polyphase_decompose(&h, 3).expect("decompose");
        assert_eq!(comps.len(), 3);
        assert_eq!(comps[0], vec![0.0, 3.0, 6.0]);
        assert_eq!(comps[1], vec![1.0, 4.0, 7.0]);
        assert_eq!(comps[2], vec![2.0, 5.0, 8.0]);
    }

    #[test]
    fn test_polyphase_decompose_non_divisible() {
        let h = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let comps = polyphase_decompose(&h, 2).expect("decompose");
        assert_eq!(comps[0].len(), 3);
        assert_eq!(comps[1].len(), 3);
        assert_eq!(comps[0], vec![1.0, 3.0, 5.0]);
        assert_eq!(comps[1], vec![2.0, 4.0, 0.0]);
    }

    #[test]
    fn test_polyphase_decompose_error_empty() {
        assert!(polyphase_decompose(&[], 2).is_err());
    }

    #[test]
    fn test_polyphase_decompose_error_m_zero() {
        assert!(polyphase_decompose(&[1.0, 2.0], 0).is_err());
    }

    #[test]
    fn test_polyphase_matrix_construction() {
        let mut pm = PolyphaseMatrix::zeros(3, 1, 4);
        assert_eq!(pm.m, 3);
        assert_eq!(pm.k, 1);
        pm.set(0, 0, vec![1.0, 0.0, -1.0, 0.0]).expect("set");
        assert_eq!(pm.get(0, 0), Some(&vec![1.0, 0.0, -1.0, 0.0]));
    }

    #[test]
    fn test_polyphase_matrix_out_of_bounds() {
        let mut pm = PolyphaseMatrix::zeros(2, 2, 3);
        assert!(pm.set(5, 0, vec![1.0; 3]).is_err());
        assert!(pm.set(0, 5, vec![1.0; 3]).is_err());
    }

    #[test]
    fn test_analysis_two_channel() {
        let signal: Vec<f64> = (0..32).map(|i| i as f64).collect();
        let subbands = analysis_filter_bank(&signal, &haar_analysis(), 2).expect("afb");
        assert_eq!(subbands.len(), 2);
        assert!(subbands[0].len() >= signal.len() / 2);
        // ceil((32 + 2 - 1) / 2) samples per subband.
        assert_eq!(subbands[0].len(), 17);
        assert_eq!(subbands[1].len(), 17);
        // Sample k of the low band is s2·(x[2k] + x[2k-1]).
        let s2 = 0.5_f64.sqrt();
        assert!((subbands[0][1] - s2 * 3.0).abs() < 1e-12);
        assert!((subbands[1][1] - s2 * 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_analysis_fb_error_empty_filters() {
        let signal = vec![1.0; 16];
        assert!(analysis_filter_bank(&signal, &[], 2).is_err());
    }

    #[test]
    fn test_analysis_fb_error_zero_decimation() {
        let signal = vec![1.0; 16];
        let h = vec![1.0];
        assert!(analysis_filter_bank(&signal, &[h], 0).is_err());
    }

    #[test]
    fn test_synthesis_two_channel_length() {
        let subbands = vec![vec![1.0; 16], vec![0.0; 16]];
        let out = synthesis_filter_bank(&subbands, &haar_analysis(), 2).expect("sfb");
        assert_eq!(out.len(), 32);
    }

    #[test]
    fn test_synthesis_fb_dimension_mismatch() {
        let subbands = vec![vec![1.0; 8], vec![0.0; 8]];
        let filters = vec![vec![1.0]];
        assert!(synthesis_filter_bank(&subbands, &filters, 2).is_err());
    }

    #[test]
    fn test_cmfb_channel_count() {
        let filters = cosine_modulated_fb(None, 4).expect("cmfb");
        assert_eq!(filters.len(), 4);
    }

    #[test]
    fn test_cmfb_equal_filter_lengths() {
        let filters = cosine_modulated_fb(None, 8).expect("cmfb");
        let l0 = filters[0].len();
        assert!(filters.iter().all(|f| f.len() == l0));
    }

    #[test]
    fn test_cmfb_error_m_lt_2() {
        assert!(cosine_modulated_fb(None, 1).is_err());
    }

    #[test]
    fn test_cmfb_custom_prototype() {
        let m = 4;
        let proto = vec![0.1_f64; 16];
        let filters = cosine_modulated_fb(Some(&proto), m).expect("cmfb custom");
        assert_eq!(filters.len(), m);
        assert_eq!(filters[0].len(), 16);
    }

    #[test]
    fn test_cmfb_custom_prototype_bad_length() {
        let m = 4;
        // 9 is not a multiple of 2*m = 8.
        let proto = vec![0.1_f64; 9];
        assert!(cosine_modulated_fb(Some(&proto), m).is_err());
    }

    #[test]
    fn test_qmf_pair_haar() {
        let s2 = 0.5_f64.sqrt();
        let lo = vec![s2, s2];
        let (h0, h1) = qmf_pair(&lo).expect("qmf");
        assert_eq!(h0.len(), 2);
        assert_eq!(h1.len(), 2);
        let energy: f64 = h1.iter().map(|&v| v * v).sum();
        assert!((energy - 1.0).abs() < 1e-12, "QMF energy {energy}");
        // The high-pass branch is orthogonal to the low-pass branch.
        let dot: f64 = h0.iter().zip(h1.iter()).map(|(a, b)| a * b).sum();
        assert!(dot.abs() < 1e-12, "QMF inner product {dot}");
    }

    #[test]
    fn test_qmf_pair_error_empty() {
        assert!(qmf_pair(&[]).is_err());
    }

    #[test]
    fn test_pr_check_haar_2channel() {
        // Using the analysis filters for synthesis gives A_0 = [1, 0, 1]: not a pure delay.
        let result = perfect_reconstruction_check(&haar_analysis(), &haar_analysis(), 2)
            .expect("pr_check");
        assert!(!result);
    }

    #[test]
    fn test_pr_check_haar_alias_cancelling() {
        let result = perfect_reconstruction_check(&haar_analysis(), &haar_synthesis(), 2)
            .expect("pr_check");
        assert!(result);
    }

    #[test]
    fn test_pr_check_dimension_mismatch() {
        let h = vec![vec![1.0, 0.5], vec![1.0, -0.5]];
        let g = vec![vec![1.0, 0.5]];
        assert!(perfect_reconstruction_check(&h, &g, 2).is_err());
    }

    #[test]
    fn test_bessel_i0_identity() {
        assert!((bessel_i0(0.0) - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_bessel_i0_known_value() {
        let expected = 1.2660658777520082_f64;
        let got = bessel_i0(1.0);
        assert!((got - expected).abs() < 1e-10, "I₀(1)={got}");
    }

    #[test]
    fn test_signal_energy() {
        let x = vec![1.0, 2.0, 3.0];
        assert!((signal_energy(&x) - 14.0).abs() < 1e-12);
    }

    #[test]
    fn test_round_trip_runs_without_error() {
        let signal: Vec<f64> = (0..16).map(|i| i as f64 / 16.0).collect();
        let recon = round_trip(&signal, &haar_analysis(), &haar_analysis(), 2);
        assert!(recon.is_ok());
    }

    #[test]
    fn test_round_trip_haar_reconstructs_with_unit_delay() {
        let signal: Vec<f64> = (0..16).map(|i| (i as f64 * 0.37).sin()).collect();
        let recon =
            round_trip(&signal, &haar_analysis(), &haar_synthesis(), 2).expect("round trip");
        assert!(recon.len() > signal.len());
        assert!(recon[0].abs() < 1e-12);
        for (n, &x) in signal.iter().enumerate() {
            assert!(
                (recon[n + 1] - x).abs() < 1e-12,
                "sample {n}: got {}, expected {x}",
                recon[n + 1]
            );
        }
    }
}