use std::collections::HashMap;
use std::f64::consts::LN_2;
use crate::error::{FFTError, FFTResult};
/// Mother wavelets supported by the wavelet packet transform in this module.
///
/// The Haar/Db/Sym/Coif variants are orthogonal families whose filters are
/// derived in `WaveletFilters::for_wavelet`; `Bior22` uses explicitly
/// tabulated biorthogonal filters.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Wavelet {
    /// Haar wavelet (2 taps).
    Haar,
    /// Daubechies wavelet, 2 vanishing moments (4 taps).
    Db2,
    /// Daubechies wavelet, 3 vanishing moments (6 taps).
    Db3,
    /// Daubechies wavelet, 4 vanishing moments (8 taps).
    Db4,
    /// Daubechies wavelet, 5 vanishing moments (10 taps).
    Db5,
    /// Symlet, 2 vanishing moments (4 taps).
    Sym2,
    /// Symlet, 4 vanishing moments (8 taps).
    Sym4,
    /// Coiflet, order 1 (6 taps).
    Coif1,
    /// Biorthogonal 2.2 spline wavelet (5/3-tap analysis pair).
    Bior22,
}
/// The four FIR filters of a two-channel wavelet filter bank.
#[derive(Debug, Clone)]
pub struct WaveletFilters {
    /// Decomposition (analysis) low-pass filter.
    pub lo_d: Vec<f64>,
    /// Decomposition (analysis) high-pass filter.
    pub hi_d: Vec<f64>,
    /// Reconstruction (synthesis) low-pass filter.
    pub lo_r: Vec<f64>,
    /// Reconstruction (synthesis) high-pass filter.
    pub hi_r: Vec<f64>,
}
impl WaveletFilters {
    /// Returns the analysis/synthesis filter bank for wavelet `w`.
    ///
    /// For the orthogonal families (Haar, Daubechies, Symlets, Coiflets) the
    /// high-pass decomposition filter is derived from the low-pass one via
    /// the quadrature-mirror relation (`qmf_hi`), and the reconstruction
    /// filters are the time-reversed decomposition filters. `Bior22` keeps
    /// its explicitly tabulated filters.
    pub fn for_wavelet(w: Wavelet) -> Self {
        // Builds a complete orthogonal filter bank from the decomposition
        // low-pass filter alone: hi_d via QMF, reconstruction = time-reversal.
        fn orthogonal(lo: Vec<f64>) -> WaveletFilters {
            let hi = qmf_hi(&lo);
            let lo_r: Vec<f64> = lo.iter().rev().cloned().collect();
            let hi_r: Vec<f64> = hi.iter().rev().cloned().collect();
            WaveletFilters {
                lo_d: lo,
                hi_d: hi,
                lo_r,
                hi_r,
            }
        }
        match w {
            Wavelet::Haar => {
                // 1/sqrt(2) scaling keeps the transform orthonormal. Haar
                // filters are symmetric under time-reversal, so the
                // reconstruction filters equal the decomposition filters.
                let s = 1.0_f64 / 2.0_f64.sqrt();
                let lo = vec![s, s];
                let hi = vec![s, -s];
                WaveletFilters {
                    lo_d: lo.clone(),
                    hi_d: hi.clone(),
                    lo_r: lo,
                    hi_r: hi,
                }
            }
            Wavelet::Db2 => {
                // Closed-form Db2 taps: (1±√3, 3±√3) / (4√2); unit energy.
                let s3 = 3.0_f64.sqrt();
                let norm = 4.0 * 2.0_f64.sqrt();
                orthogonal(vec![
                    (1.0 + s3) / norm,
                    (3.0 + s3) / norm,
                    (3.0 - s3) / norm,
                    (1.0 - s3) / norm,
                ])
            }
            Wavelet::Db3 => orthogonal(vec![
                0.035226291882100656,
                -0.08544127388202666,
                -0.13501102001039084,
                0.4598775021193313,
                0.8068915093133388,
                0.3326705529509569,
            ]),
            Wavelet::Db4 => orthogonal(vec![
                -0.010597401784997278,
                0.032883011666982945,
                0.030841381835986965,
                -0.18703481171888114,
                -0.027983769416983849,
                0.6308807679295904,
                0.7148465705525415,
                0.23037781330885523,
            ]),
            Wavelet::Db5 => orthogonal(vec![
                0.003335725285001549,
                -0.012580751999015526,
                -0.006241490213011705,
                0.07757149384006515,
                -0.03224486958502952,
                -0.24229488706619015,
                0.13842814590110342,
                0.7243085284377729,
                0.6038292697974729,
                0.160102397974125,
            ]),
            Wavelet::Sym2 => {
                // BUG FIX: this previously divided by sqrt(8), which left the
                // filter with energy 4 instead of 1. The orthonormal Sym2
                // taps (the mirror of Db2) are (1±√3, 3±√3) / (4√2), the same
                // normalization used for Db2 above.
                let s3 = 3.0_f64.sqrt();
                let norm = 4.0 * 2.0_f64.sqrt();
                orthogonal(vec![
                    (1.0 - s3) / norm,
                    (3.0 - s3) / norm,
                    (3.0 + s3) / norm,
                    (1.0 + s3) / norm,
                ])
            }
            Wavelet::Sym4 => orthogonal(vec![
                -0.07576571478927333,
                -0.02963552764599851,
                0.49761866763201545,
                0.8037387518059161,
                0.29785779560527736,
                -0.09921954357684722,
                -0.012603967262037833,
                0.032223100604042702,
            ]),
            Wavelet::Coif1 => orthogonal(vec![
                -0.015655728135960927,
                -0.07273261951285047,
                0.3848648565381134,
                0.8525720202122554,
                0.3378976624578092,
                -0.07273261951285047,
            ]),
            Wavelet::Bior22 => {
                // NOTE(review): for a biorthogonal pair the reconstruction
                // filters are generally NOT the time-reversed analysis
                // filters (they come from the dual filter bank). This mirrors
                // the original behavior — verify before relying on Bior22
                // reconstruction.
                let lo = vec![-0.125, 0.25, 0.75, 0.25, -0.125];
                let hi = vec![-0.25, 0.5, -0.25];
                let lo_r: Vec<f64> = lo.iter().rev().cloned().collect();
                let hi_r: Vec<f64> = hi.iter().rev().cloned().collect();
                WaveletFilters {
                    lo_d: lo,
                    hi_d: hi,
                    lo_r,
                    hi_r,
                }
            }
        }
    }
}
/// Builds the high-pass decomposition filter from the low-pass filter via
/// the quadrature-mirror relation: reverse the tap order and negate every
/// tap that sat at an odd position in the original filter.
fn qmf_hi(lo: &[f64]) -> Vec<f64> {
    let n = lo.len();
    let mut hi = Vec::with_capacity(n);
    for k in 0..n {
        let src = n - 1 - k;
        let tap = lo[src];
        hi.push(if src % 2 == 0 { tap } else { -tap });
    }
    hi
}
/// Circular (periodized) convolution of `signal` with `filter`, keeping
/// every second output sample.
///
/// Output sample `k` is `sum_j filter[j] * signal[(2k - j) mod n]`: FIR
/// filtering with the signal extended periodically, then downsampling by
/// two. The output has `ceil(n / 2)` samples — the critically sampled
/// length for a periodized transform.
fn conv_downsample(signal: &[f64], filter: &[f64]) -> Vec<f64> {
    let n = signal.len();
    // Bug fix: the length was previously (n + filter.len() - 1) / 2 — the
    // *linear* full-convolution length. Combined with the circular indexing
    // below, any filter longer than 2 taps emitted duplicated wrap-around
    // coefficients (e.g. out[n/2] == out[0]), breaking critical sampling and
    // perfect reconstruction for the Db/Sym/Coif filters. Haar (2 taps) is
    // unaffected by this change.
    let out_len = (n + 1) / 2;
    let mut out = vec![0.0_f64; out_len];
    for (k, slot) in out.iter_mut().enumerate() {
        let t = 2 * k;
        let mut acc = 0.0_f64;
        for (j, &h) in filter.iter().enumerate() {
            // rem_euclid keeps the index non-negative for t < j (wrap-around).
            let idx = (t as isize - j as isize).rem_euclid(n as isize) as usize;
            acc += signal[idx] * h;
        }
        *slot = acc;
    }
    out
}
/// Synthesis step: zero-interleaved upsampling of `input` followed by
/// circular convolution with `filter`, producing exactly `target_len`
/// samples (the upsampled sequence is treated as periodic with period
/// `2 * input.len()`).
///
/// This is the adjoint of `conv_downsample`: even positions of the virtual
/// upsampled sequence hold `input`, odd positions are the inserted zeros
/// and contribute nothing to the sum.
fn upsample_conv(input: &[f64], filter: &[f64], target_len: usize) -> Vec<f64> {
    // Fix: the original also bound `let flen = filter.len();` which was
    // never used (dead local, compiler warning).
    let period = (input.len() * 2) as isize;
    (0..target_len)
        .map(|k| {
            filter
                .iter()
                .enumerate()
                .filter_map(|(j, &h)| {
                    // Position in the zero-stuffed sequence, wrapped to the period.
                    let t = (k as isize - j as isize).rem_euclid(period) as usize;
                    if t % 2 == 0 {
                        Some(input[t / 2] * h)
                    } else {
                        None // inserted zero — no contribution
                    }
                })
                .sum::<f64>()
        })
        .collect()
}
/// One node of a wavelet packet tree: a coefficient vector together with its
/// position (`level`, `index`) in the binary decomposition tree.
#[derive(Debug, Clone)]
pub struct WaveletPacketNode {
    pub coeffs: Vec<f64>,
    pub level: usize,
    pub index: usize,
}

impl WaveletPacketNode {
    /// Wraps a coefficient vector with its tree position.
    pub fn new(coeffs: Vec<f64>, level: usize, index: usize) -> Self {
        Self {
            coeffs,
            level,
            index,
        }
    }

    /// The root (level 0) holds the original, undecomposed signal.
    pub fn is_root(&self) -> bool {
        self.level == 0
    }

    /// Packs (level, index) into one map key: level in the high 32 bits,
    /// index in the low 32 bits.
    fn key(level: usize, index: usize) -> u64 {
        ((level as u64) << 32) | (index as u64)
    }
}
/// Sparse wavelet packet tree: nodes stored in a map keyed by their packed
/// (level, index) position, plus the parameters the tree was built with.
#[derive(Debug, Clone)]
pub struct WaveletPacketTree {
    nodes: HashMap<u64, WaveletPacketNode>,
    pub max_level: usize,
    pub wavelet: Wavelet,
    pub signal_len: usize,
}

impl WaveletPacketTree {
    /// Creates an empty tree that records the wavelet, depth, and original
    /// signal length used for the decomposition.
    pub fn new(wavelet: Wavelet, max_level: usize, signal_len: usize) -> Self {
        Self {
            nodes: HashMap::new(),
            max_level,
            wavelet,
            signal_len,
        }
    }

    /// Stores `node`, replacing any node already at its (level, index).
    pub fn insert(&mut self, node: WaveletPacketNode) {
        self.nodes
            .insert(WaveletPacketNode::key(node.level, node.index), node);
    }

    /// Looks up the node at (`level`, `index`), if present.
    pub fn get(&self, level: usize, index: usize) -> Option<&WaveletPacketNode> {
        self.nodes.get(&WaveletPacketNode::key(level, index))
    }

    /// Iterates over all nodes stored at the given depth (arbitrary order).
    pub fn nodes_at_level(&self, level: usize) -> impl Iterator<Item = &WaveletPacketNode> {
        self.nodes.values().filter(move |n| n.level == level)
    }

    /// Iterates over every node in the tree (arbitrary order).
    pub fn all_nodes(&self) -> impl Iterator<Item = &WaveletPacketNode> {
        self.nodes.values()
    }
}
/// Full wavelet packet decomposition of `signal` down to `max_level`.
///
/// Every node at each level is split into an approximation (low-pass) child
/// at index `2i` and a detail (high-pass) child at index `2i + 1`, so the
/// returned tree contains `2^level` nodes at each level up to `max_level`.
///
/// # Errors
/// Returns `FFTError::ValueError` for an empty signal or `max_level == 0`,
/// and `FFTError::InternalError` if an expected node is missing mid-build.
pub fn wpd(signal: &[f64], wavelet: Wavelet, max_level: usize) -> FFTResult<WaveletPacketTree> {
    if signal.is_empty() {
        return Err(FFTError::ValueError("signal must be non-empty".to_string()));
    }
    if max_level == 0 {
        return Err(FFTError::ValueError("max_level must be >= 1".to_string()));
    }
    let filters = WaveletFilters::for_wavelet(wavelet);
    let mut tree = WaveletPacketTree::new(wavelet, max_level, signal.len());
    // The root node holds the raw signal.
    tree.insert(WaveletPacketNode::new(signal.to_vec(), 0, 0));
    for level in 0..max_level {
        for index in 0..(1_usize << level) {
            // Clone the parent's coefficients so the tree can be mutated below.
            let parent = tree
                .get(level, index)
                .map(|n| n.coeffs.clone())
                .ok_or_else(|| {
                    FFTError::InternalError(format!("missing node ({level}, {index})"))
                })?;
            let approx = conv_downsample(&parent, &filters.lo_d);
            let detail = conv_downsample(&parent, &filters.hi_d);
            tree.insert(WaveletPacketNode::new(approx, level + 1, 2 * index));
            tree.insert(WaveletPacketNode::new(detail, level + 1, 2 * index + 1));
        }
    }
    Ok(tree)
}
/// Shannon entropy cost of a coefficient vector, with p = c^2 as the
/// (unnormalized) energy weight: `-sum p * log2(p)` over nonzero entries.
/// Zero coefficients contribute nothing (the limit of p·log p at 0 is 0).
pub fn shannon_entropy(coeffs: &[f64]) -> f64 {
    let mut entropy = 0.0_f64;
    for &c in coeffs {
        let p = c * c;
        if p > 0.0 {
            entropy -= p * p.log2();
        }
    }
    entropy
}
/// Log-energy cost: `sum log2(c^2)` over the nonzero coefficients
/// (computed as `ln / LN_2`); zero coefficients are skipped to avoid -inf.
pub fn log_energy_entropy(coeffs: &[f64]) -> f64 {
    let mut total = 0.0_f64;
    for &c in coeffs {
        let energy = c * c;
        if energy > 0.0 {
            total += energy.ln() / LN_2;
        }
    }
    total
}
/// l^p cost of a coefficient vector: `sum |c|^p` (the p-th power of the
/// l^p norm, not the norm itself).
pub fn lp_norm_cost(coeffs: &[f64], p: f64) -> f64 {
    coeffs.iter().fold(0.0_f64, |acc, &c| acc + c.abs().powf(p))
}
/// Coifman–Wickerhauser best-basis search over a wavelet packet tree.
///
/// Works bottom-up: for every internal node it compares the node's own cost
/// (per `cost_fn`) against the best achievable cost of its two subtrees, and
/// marks the node "split" when the children win. `collect_basis` then walks
/// the flags top-down to produce a disjoint cover of the signal.
///
/// # Errors
/// Returns `FFTError::ValueError` when the tree has `max_level == 0`.
pub fn best_basis<F>(tree: &WaveletPacketTree, cost_fn: F) -> FFTResult<Vec<WaveletPacketNode>>
where
    F: Fn(&[f64]) -> f64,
{
    if tree.max_level == 0 {
        return Err(FFTError::ValueError("tree is empty".to_string()));
    }
    // Per-node cost, updated in place during the bottom-up sweep so that an
    // entry eventually holds the best achievable cost of that node's subtree.
    let mut costs: HashMap<u64, f64> = HashMap::new();
    for node in tree.all_nodes() {
        let key = WaveletPacketNode::key(node.level, node.index);
        costs.insert(key, cost_fn(&node.coeffs));
    }
    // best_flag[node] == true means "split": the children's combined best
    // bases are strictly cheaper than keeping this node whole.
    let mut best_flag: HashMap<u64, bool> = HashMap::new();
    for level in (0..tree.max_level).rev() {
        for index in 0..(1_usize << level) {
            let parent_key = WaveletPacketNode::key(level, index);
            let parent_cost = match costs.get(&parent_key) {
                Some(&c) => c,
                None => continue,
            };
            let children_cost = effective_cost(&costs, &best_flag, level + 1, 2 * index)
                + effective_cost(&costs, &best_flag, level + 1, 2 * index + 1);
            if parent_cost <= children_cost {
                // Keeping the parent is at least as cheap. Its stored cost is
                // still `parent_cost`, so re-inserting it (as the original
                // code did) was a no-op and has been dropped.
                best_flag.insert(parent_key, false);
            } else {
                best_flag.insert(parent_key, true);
                costs.insert(parent_key, children_cost);
            }
            // Note: children need no explicit flag entry — collect_basis
            // treats a missing flag as "keep", so the previous
            // `or_insert(false)` seeding was redundant.
        }
    }
    let mut basis: Vec<WaveletPacketNode> = Vec::new();
    collect_basis(tree, &best_flag, 0, 0, &mut basis)?;
    Ok(basis)
}
/// Walks the split flags top-down, appending each "kept" node to `out`.
///
/// A node is kept when its flag is absent/false or when the bottom of the
/// tree is reached; otherwise recursion descends into both children.
/// Positions absent from the tree are silently skipped.
fn collect_basis(
    tree: &WaveletPacketTree,
    best_flag: &HashMap<u64, bool>,
    level: usize,
    index: usize,
    out: &mut Vec<WaveletPacketNode>,
) -> FFTResult<()> {
    let split = best_flag
        .get(&WaveletPacketNode::key(level, index))
        .copied()
        .unwrap_or(false);
    if split && level < tree.max_level {
        // The children's combined basis beat this node — descend.
        collect_basis(tree, best_flag, level + 1, 2 * index, out)?;
        collect_basis(tree, best_flag, level + 1, 2 * index + 1, out)?;
    } else if let Some(node) = tree.get(level, index) {
        out.push(node.clone());
    }
    Ok(())
}
/// Looks up the (bottom-up-updated) cost stored for a node; nodes missing
/// from the map are treated as infinitely expensive so they can never win
/// a merge comparison in `best_basis`.
fn effective_cost(
    costs: &HashMap<u64, f64>,
    // Fix: this parameter was never read (the costs map already reflects the
    // split decisions), which triggered an unused-variable warning. It is
    // kept — underscored — so existing call sites remain valid.
    _best_flag: &HashMap<u64, bool>,
    level: usize,
    index: usize,
) -> f64 {
    costs
        .get(&WaveletPacketNode::key(level, index))
        .copied()
        .unwrap_or(f64::INFINITY)
}
/// Reconstructs the root signal from a set of basis nodes by repeatedly
/// merging sibling coefficient pairs bottom-up with the synthesis filters.
///
/// Nodes supplied directly in `basis_nodes` take precedence over merged
/// results at the same position.
///
/// # Errors
/// Returns `FFTError::ValueError` for an empty `basis_nodes`, and
/// `FFTError::InternalError` when the provided nodes do not cover enough of
/// the tree for the merge to reach the root.
pub fn wp_reconstruct(
    tree: &WaveletPacketTree,
    basis_nodes: &[WaveletPacketNode],
) -> FFTResult<Vec<f64>> {
    if basis_nodes.is_empty() {
        return Err(FFTError::ValueError(
            "basis_nodes must be non-empty".to_string(),
        ));
    }
    let filters = WaveletFilters::for_wavelet(tree.wavelet);
    // Working map from packed (level, index) key to coefficients; seeded with
    // the basis nodes and filled upward until the root appears.
    let mut node_map: HashMap<u64, Vec<f64>> = HashMap::new();
    for node in basis_nodes {
        let key = WaveletPacketNode::key(node.level, node.index);
        node_map.insert(key, node.coeffs.clone());
    }
    for level in (1..=tree.max_level).rev() {
        let parent_level = level - 1;
        let num_parents = 1_usize << parent_level;
        for p_idx in 0..num_parents {
            let left_key = WaveletPacketNode::key(level, 2 * p_idx);
            let right_key = WaveletPacketNode::key(level, 2 * p_idx + 1);
            let parent_key = WaveletPacketNode::key(parent_level, p_idx);
            if node_map.contains_key(&parent_key) {
                // Parent already supplied directly by the basis.
                continue;
            }
            // Both children are required for a merge; otherwise leave this
            // position for a shallower basis node (or fail at the root).
            let left_coeffs = match node_map.get(&left_key) {
                Some(c) => c.clone(),
                None => continue,
            };
            let right_coeffs = match node_map.get(&right_key) {
                Some(c) => c.clone(),
                None => continue,
            };
            // Prefer the analysis-time length of the parent so synthesis
            // lines up sample-for-sample; fall back to twice the child
            // length when the tree node is unavailable.
            let target_len = tree
                .get(parent_level, p_idx)
                .map(|n| n.coeffs.len())
                .unwrap_or_else(|| left_coeffs.len() * 2);
            let lo_rec = upsample_conv(&left_coeffs, &filters.lo_r, target_len);
            let hi_rec = upsample_conv(&right_coeffs, &filters.hi_r, target_len);
            let parent_coeffs: Vec<f64> = lo_rec
                .iter()
                .zip(hi_rec.iter())
                .map(|(a, b)| a + b)
                .collect();
            node_map.insert(parent_key, parent_coeffs);
        }
        // Fix: a leftover per-level loop that computed a key and immediately
        // discarded it (pure dead code) was removed here.
    }
    let root_key = WaveletPacketNode::key(0, 0);
    node_map.remove(&root_key).ok_or_else(|| {
        FFTError::InternalError("reconstruction failed: root not reached".to_string())
    })
}
/// Coefficient shrinkage rules for wavelet packet denoising.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ThresholdMethod {
    /// Keep coefficients with |c| >= tau unchanged; zero the rest.
    Hard,
    /// Zero the dead zone and shrink survivors toward zero by tau.
    Soft,
    /// Non-negative garrote: `c - tau^2 / c` outside the dead zone.
    Garrote,
    /// Firm (Gao–Bruce) shrinkage: zero below `tau`, identity above `t2`,
    /// linear interpolation between the two thresholds.
    Firm { t2: f64 },
}

/// Applies `apply_threshold` element-wise to a coefficient vector.
fn threshold_coeffs(coeffs: &[f64], tau: f64, method: ThresholdMethod) -> Vec<f64> {
    coeffs
        .iter()
        .map(|&c| apply_threshold(c, tau, method))
        .collect()
}

/// Shrinks a single coefficient `c` with threshold `tau` under `method`.
fn apply_threshold(c: f64, tau: f64, method: ThresholdMethod) -> f64 {
    match method {
        ThresholdMethod::Hard => {
            if c.abs() >= tau {
                c
            } else {
                0.0
            }
        }
        ThresholdMethod::Soft => {
            if c > tau {
                c - tau
            } else if c < -tau {
                c + tau
            } else {
                0.0
            }
        }
        ThresholdMethod::Garrote => {
            // |c| <= tau is zeroed first, so the division below never sees
            // c == 0 (assuming tau >= 0).
            if c.abs() <= tau {
                0.0
            } else {
                c - tau * tau / c
            }
        }
        ThresholdMethod::Firm { t2 } => {
            let t1 = tau;
            let abs_c = c.abs();
            if abs_c <= t1 {
                0.0
            } else if abs_c >= t2 {
                c
            } else {
                // Bug fix: the interpolation factor must be t2 (not t1) so
                // the rule is continuous at |c| == t2 — the value there is
                // t2 * (t2 - t1) / (t2 - t1) == |c| — matching the Gao–Bruce
                // firm-shrinkage definition. The old t1 factor jumped from
                // t1 to t2 at the upper threshold.
                c.signum() * t2 * (abs_c - t1) / (t2 - t1)
            }
        }
    }
}
/// Denoises `signal` by thresholding wavelet packet coefficients in the
/// Shannon-entropy best basis, then reconstructing.
///
/// The root node (level 0, the raw signal itself) is never thresholded.
/// The result always has exactly `signal.len()` samples: over-long
/// reconstructions are truncated and short ones are zero-padded.
///
/// # Errors
/// Propagates any error from decomposition, basis selection, or
/// reconstruction.
pub fn wp_denoising(
    signal: &[f64],
    wavelet: Wavelet,
    max_level: usize,
    threshold: f64,
    method: ThresholdMethod,
) -> FFTResult<Vec<f64>> {
    let tree = wpd(signal, wavelet, max_level)?;
    let basis = best_basis(&tree, shannon_entropy)?;
    let thresholded: Vec<WaveletPacketNode> = basis
        .into_iter()
        .map(|mut node| {
            if node.level > 0 {
                node.coeffs = threshold_coeffs(&node.coeffs, threshold, method);
            }
            node
        })
        .collect();
    let mut recon = wp_reconstruct(&tree, &thresholded)?;
    // resize both truncates and zero-pads in one call — replaces the
    // previous `truncate` followed by a push(0.0) loop.
    recon.resize(signal.len(), 0.0);
    Ok(recon)
}
#[cfg(test)]
mod tests {
    use super::*;

    // Two-tone fixture: a 5 Hz sine plus a half-amplitude 13 Hz sine,
    // sampled at n points over one unit interval.
    fn test_signal(n: usize) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                (2.0 * std::f64::consts::PI * 5.0 * t).sin()
                    + 0.5 * (2.0 * std::f64::consts::PI * 13.0 * t).sin()
            })
            .collect()
    }

    // A 3-level decomposition must populate all 2^3 = 8 nodes at level 3.
    #[test]
    fn test_haar_decomp_shape() {
        let sig = test_signal(64);
        let tree = wpd(&sig, Wavelet::Haar, 3).expect("wpd failed");
        for idx in 0..8 {
            assert!(
                tree.get(3, idx).is_some(),
                "missing node (3, {idx})"
            );
        }
    }

    // Orthonormal filters have unit energy: the sum of squared taps is 1
    // for both the low-pass and the QMF-derived high-pass filter.
    #[test]
    fn test_qmf_energy_preservation() {
        let filters = WaveletFilters::for_wavelet(Wavelet::Db2);
        let e_lo: f64 = filters.lo_d.iter().map(|&c| c * c).sum();
        let e_hi: f64 = filters.hi_d.iter().map(|&c| c * c).sum();
        assert!((e_lo - 1.0).abs() < 1e-10, "lo energy {e_lo}");
        assert!((e_hi - 1.0).abs() < 1e-10, "hi energy {e_hi}");
    }

    // A spread-out (uniform) coefficient vector has positive Shannon cost.
    #[test]
    fn test_shannon_entropy_uniform() {
        let coeffs = vec![0.5_f64; 8];
        let e = shannon_entropy(&coeffs);
        assert!(e > 0.0, "expected positive entropy, got {e}");
    }

    // A maximally sparse vector (one unit coefficient) has zero cost,
    // since p = 1 contributes -1 * log2(1) = 0.
    #[test]
    fn test_shannon_entropy_sparse() {
        let mut coeffs = vec![0.0_f64; 64];
        coeffs[0] = 1.0;
        let e = shannon_entropy(&coeffs);
        assert!((e - 0.0).abs() < 1e-12, "sparse signal entropy {e}");
    }

    // Every node chosen by the best-basis search must exist in the tree.
    #[test]
    fn test_best_basis_returns_valid_partition() {
        let sig = test_signal(64);
        let tree = wpd(&sig, Wavelet::Db2, 3).expect("wpd");
        let basis = best_basis(&tree, shannon_entropy).expect("best_basis");
        assert!(!basis.is_empty(), "basis is empty");
        for node in &basis {
            assert!(
                tree.get(node.level, node.index).is_some(),
                "basis node ({}, {}) not in tree",
                node.level,
                node.index
            );
        }
    }

    // Reconstructing from the complete level-2 leaf set must reproduce the
    // original signal to within floating-point tolerance (Haar is exact).
    #[test]
    fn test_haar_perfect_reconstruction() {
        let sig = test_signal(64);
        let tree = wpd(&sig, Wavelet::Haar, 2).expect("wpd");
        let basis: Vec<WaveletPacketNode> = (0..4_usize)
            .filter_map(|idx| tree.get(2, idx).cloned())
            .collect();
        let recon = wp_reconstruct(&tree, &basis).expect("recon");
        for (i, (&s, &r)) in sig.iter().zip(recon.iter()).enumerate() {
            assert!(
                (s - r).abs() < 1e-10,
                "mismatch at {i}: orig={s}, recon={r}"
            );
        }
    }

    // The denoised output must have exactly the input length.
    #[test]
    fn test_denoising_length_preserved() {
        let sig = test_signal(64);
        let denoised =
            wp_denoising(&sig, Wavelet::Db4, 3, 0.1, ThresholdMethod::Soft).expect("denoise");
        assert_eq!(denoised.len(), sig.len());
    }

    // Hard thresholding zeroes only the coefficients with |c| < tau.
    #[test]
    fn test_threshold_hard() {
        let coeffs = vec![1.0, -0.5, 0.3, -0.1, 2.0];
        let out = threshold_coeffs(&coeffs, 0.4, ThresholdMethod::Hard);
        assert_eq!(out, vec![1.0, -0.5, 0.0, 0.0, 2.0]);
    }

    // Soft thresholding shrinks survivors toward zero by tau.
    #[test]
    fn test_threshold_soft() {
        let out = threshold_coeffs(&[1.0, -1.5, 0.2], 0.5, ThresholdMethod::Soft);
        assert!((out[0] - 0.5).abs() < 1e-12);
        assert!((out[1] - (-1.0)).abs() < 1e-12);
        assert!((out[2] - 0.0).abs() < 1e-12);
    }

    // An empty signal is rejected with an error.
    #[test]
    fn test_wpd_error_on_empty() {
        let result = wpd(&[], Wavelet::Haar, 2);
        assert!(result.is_err());
    }

    // A zero decomposition depth is rejected with an error.
    #[test]
    fn test_wpd_error_on_zero_level() {
        let result = wpd(&[1.0, 2.0, 3.0], Wavelet::Haar, 0);
        assert!(result.is_err());
    }
}