scirs2_fft/wavelet_packets.rs
1//! Wavelet Packet Transform (WPT)
2//!
3//! Wavelet packets generalize the discrete wavelet transform by allowing full
4//! decomposition of both approximation and detail subbands at each level.
5//! This produces a complete binary tree of subband coefficients.
6//!
7//! The best-basis algorithm (Coifman–Wickerhauser 1992) selects an optimal
8//! orthonormal basis from the packet tree by minimising an additive cost function
9//! (e.g. Shannon entropy or log-energy).
10//!
11//! # References
12//! - Coifman, R.R. & Wickerhauser, M.V. (1992). Entropy-based algorithms for best
13//! basis selection. IEEE Trans. Inf. Theory, 38(2), 713–718.
14//! - Mallat, S. (1999). A Wavelet Tour of Signal Processing. Academic Press.
15
16use std::collections::HashMap;
17use std::f64::consts::LN_2;
18
19use crate::error::{FFTError, FFTResult};
20
21// ─────────────────────────────────────────────────────────────────────────────
22// Wavelet filter definitions
23// ─────────────────────────────────────────────────────────────────────────────
24
/// Supported orthonormal wavelet families.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Wavelet {
    /// Haar wavelet (db1)
    Haar,
    /// Daubechies 4-tap (db2)
    Db2,
    /// Daubechies 6-tap (db3)
    Db3,
    /// Daubechies 8-tap (db4)
    Db4,
    /// Daubechies 10-tap (db5)
    Db5,
    /// Symlet 4-tap (sym2)
    Sym2,
    /// Symlet 8-tap (sym4)
    Sym4,
    /// Coiflet 6-tap (coif1)
    Coif1,
    /// Biorthogonal 2.2 (bior2.2) – for approximation only; analysis filters
    Bior22,
}

/// Low-pass (scaling) and high-pass (wavelet) analysis filters for a wavelet.
#[derive(Debug, Clone)]
pub struct WaveletFilters {
    /// Low-pass decomposition filter h₀
    pub lo_d: Vec<f64>,
    /// High-pass decomposition filter h₁
    pub hi_d: Vec<f64>,
    /// Low-pass reconstruction filter g₀
    pub lo_r: Vec<f64>,
    /// High-pass reconstruction filter g₁
    pub hi_r: Vec<f64>,
}

impl WaveletFilters {
    /// Build an orthogonal filter bank from its low-pass decomposition filter.
    ///
    /// The high-pass filter is derived via the QMF relation (`qmf_hi`), and —
    /// because synthesis is implemented as the transpose (adjoint) of the
    /// analysis step — the reconstruction filters equal the decomposition
    /// filters.
    fn orthogonal(lo: Vec<f64>) -> Self {
        let hi = qmf_hi(&lo);
        WaveletFilters {
            lo_r: lo.clone(),
            hi_r: hi.clone(),
            lo_d: lo,
            hi_d: hi,
        }
    }

    /// Return filters for the given wavelet.
    ///
    /// All orthogonal families (Haar, Db2–Db5, Sym2, Sym4, Coif1) yield
    /// unit-energy filters: Σ h[n]² = 1 and Σ h₀[n] = √2.
    /// `Bior22` returns the *analysis* pair only; its "reconstruction"
    /// filters are plain time reversals and do not form a
    /// perfect-reconstruction biorthogonal pair (see the enum docs).
    pub fn for_wavelet(w: Wavelet) -> Self {
        match w {
            Wavelet::Haar => {
                let s = 1.0_f64 / 2.0_f64.sqrt();
                // Haar keeps the explicit [s, -s] high-pass convention rather
                // than `qmf_hi` (which would give the sign-flipped [-s, s]);
                // either sign choice is a valid orthonormal wavelet.
                WaveletFilters {
                    lo_d: vec![s, s],
                    hi_d: vec![s, -s],
                    lo_r: vec![s, s],
                    hi_r: vec![s, -s],
                }
            }
            Wavelet::Db2 => {
                let s3 = 3.0_f64.sqrt();
                let norm = 4.0 * 2.0_f64.sqrt(); // 4·√2 ⇒ unit filter energy
                Self::orthogonal(vec![
                    (1.0 + s3) / norm,
                    (3.0 + s3) / norm,
                    (3.0 - s3) / norm,
                    (1.0 - s3) / norm,
                ])
            }
            Wavelet::Db3 => {
                // Daubechies db3 (6-tap) coefficients
                Self::orthogonal(vec![
                    0.035226291882100656,
                    -0.08544127388202666,
                    -0.13501102001039084,
                    0.4598775021193313,
                    0.8068915093133388,
                    0.3326705529509569,
                ])
            }
            Wavelet::Db4 => {
                // Daubechies db4 (8-tap) coefficients
                Self::orthogonal(vec![
                    -0.010597401784997278,
                    0.032883011666982945,
                    0.030841381835986965,
                    -0.18703481171888114,
                    -0.027_983_769_416_983_85,
                    0.6308807679295904,
                    0.7148465705525415,
                    0.23037781330885523,
                ])
            }
            Wavelet::Db5 => {
                // Daubechies db5 (10-tap) coefficients
                Self::orthogonal(vec![
                    0.003335725285001549,
                    -0.012580751999015526,
                    -0.006241490213011705,
                    0.07757149384006515,
                    -0.03224486958502952,
                    -0.24229488706619015,
                    0.13842814590110342,
                    0.7243085284377729,
                    0.6038292697974729,
                    0.160102397974125,
                ])
            }
            Wavelet::Sym2 => {
                // Symlet sym2 = db2 coefficients in reverse order.
                // BUG FIX: the normaliser must be 4·√2 (= √32) so the filter
                // has unit energy; the previous √8 produced an energy of 4
                // and broke energy preservation of the transform.
                let s3 = 3.0_f64.sqrt();
                let norm = 4.0 * 2.0_f64.sqrt();
                Self::orthogonal(vec![
                    (1.0 - s3) / norm,
                    (3.0 - s3) / norm,
                    (3.0 + s3) / norm,
                    (1.0 + s3) / norm,
                ])
            }
            Wavelet::Sym4 => {
                // Symlet sym4 (8-tap)
                Self::orthogonal(vec![
                    -0.07576571478927333,
                    -0.02963552764599851,
                    0.49761866763201545,
                    0.8037387518059161,
                    0.29785779560527736,
                    -0.09921954357684722,
                    -0.012603967262037833,
                    0.032_223_100_604_042_7,
                ])
            }
            Wavelet::Coif1 => {
                // Coiflet coif1 (6-tap)
                Self::orthogonal(vec![
                    -0.015655728135960927,
                    -0.07273261951285047,
                    0.3848648565381134,
                    0.8525720202122554,
                    0.3378976624578092,
                    -0.07273261951285047,
                ])
            }
            Wavelet::Bior22 => {
                // Biorthogonal 2.2 analysis filters. Reversal of a symmetric
                // filter is a no-op; this pair is suitable for analysis /
                // approximation only, not perfect reconstruction.
                let lo = vec![-0.125, 0.25, 0.75, 0.25, -0.125];
                let hi = vec![-0.25, 0.5, -0.25];
                let lo_r: Vec<f64> = lo.iter().rev().cloned().collect();
                let hi_r: Vec<f64> = hi.iter().rev().cloned().collect();
                WaveletFilters {
                    lo_d: lo,
                    hi_d: hi,
                    lo_r,
                    hi_r,
                }
            }
        }
    }
}

/// Build the high-pass QMF filter from a low-pass filter of length `L`.
///
/// h₁[k] = (−1)^(L−1−k) · h₀[L−1−k]
///
/// i.e. the alternating-sign time reversal of `lo`. For even `L` this is the
/// global negation of the textbook h₁[k] = (−1)^k · h₀[L−1−k]; the sign flip
/// affects neither orthonormality nor subband energies, and it cancels in
/// transpose-based reconstruction because the same filter is applied once in
/// analysis and once in synthesis.
fn qmf_hi(lo: &[f64]) -> Vec<f64> {
    let n = lo.len();
    lo.iter()
        .rev()
        .enumerate()
        .map(|(k, &v)| if (n - 1 - k) % 2 == 0 { v } else { -v })
        .collect()
}
261
262// ─────────────────────────────────────────────────────────────────────────────
263// Core convolution / subsampling helpers
264// ─────────────────────────────────────────────────────────────────────────────
265
/// Convolve `signal` with `filter` using periodic (circular) boundary extension
/// and then down-sample by 2 (keep even-indexed samples).
///
/// The output length is `ceil(signal.len() / 2)` = `(signal.len() + 1) / 2`.
/// Because the boundary is periodic, the output length depends only on the
/// signal length, never on the filter length.
fn conv_downsample(signal: &[f64], filter: &[f64]) -> Vec<f64> {
    let n = signal.len() as isize;
    (0..signal.len().div_ceil(2))
        .map(|k| {
            // Output sample k taps the circular convolution at t = 2k.
            let t = 2 * k as isize;
            filter
                .iter()
                .enumerate()
                .map(|(j, &h)| signal[(t - j as isize).rem_euclid(n) as usize] * h)
                .sum::<f64>()
        })
        .collect()
}
288
/// Synthesis step: transpose of `conv_downsample` with periodic (circular) boundary.
///
/// This is the adjoint (transpose) of the analysis step, ensuring perfect
/// reconstruction for orthogonal wavelets:
///
/// xhat[n] = Σₖ input[k] · Σⱼ filter[j]
///           over all j ≡ (2k − n) (mod target_len) with j < filter.len()
///
/// Summing over *every* congruent tap matters when the filter is longer than
/// `target_len` (deep decomposition of a short signal): the analysis
/// convolution then wraps the filter around the signal more than once, and
/// the adjoint must accumulate each wrapped tap, not just the first one.
/// (The previous implementation applied only the first congruent tap, which
/// broke the adjoint property for `filter.len() > target_len`.)
///
/// `target_len` must equal the length of the signal that was passed to
/// `conv_downsample` to produce `input`.
fn synthesis_step(input: &[f64], filter: &[f64], target_len: usize) -> Vec<f64> {
    let flen = filter.len();
    let n_out = target_len;
    let mut out = vec![0.0_f64; n_out];
    for (n_idx, slot) in out.iter_mut().enumerate() {
        let mut acc = 0.0_f64;
        for (k, &x) in input.iter().enumerate() {
            // Smallest filter index congruent to (2k − n) modulo n_out …
            let mut j = ((2 * k as isize - n_idx as isize).rem_euclid(n_out as isize)) as usize;
            // … then every further wrap of the filter support.
            while j < flen {
                acc += x * filter[j];
                j += n_out;
            }
        }
        *slot = acc;
    }
    out
}
318
319// ─────────────────────────────────────────────────────────────────────────────
320// Node & tree structures
321// ─────────────────────────────────────────────────────────────────────────────
322
/// A single node in the wavelet packet tree.
///
/// Each node holds one subband's coefficients together with its position in
/// the tree, identified by `(level, index)` where `index ∈ [0, 2^level)`.
#[derive(Debug, Clone)]
pub struct WaveletPacketNode {
    /// Subband coefficients at this node.
    pub coeffs: Vec<f64>,
    /// Decomposition level (0 = root, i.e. the original signal).
    pub level: usize,
    /// Node index within the level (frequency-ordered).
    pub index: usize,
}

impl WaveletPacketNode {
    /// Create a new node.
    pub fn new(coeffs: Vec<f64>, level: usize, index: usize) -> Self {
        Self {
            coeffs,
            level,
            index,
        }
    }

    /// Returns `true` if this node is the root (level 0).
    pub fn is_root(&self) -> bool {
        matches!(self.level, 0)
    }

    /// Flat key used for `HashMap` storage: the level in the upper 32 bits
    /// and the index in the lower 32 bits.
    fn key(level: usize, index: usize) -> u64 {
        ((level as u64) << 32) | (index as u64)
    }
}
357
358/// A full binary tree of wavelet packet nodes.
359///
360/// Nodes are stored in a `HashMap` keyed by `(level, index)`.
361/// The tree is built by `wpd` and each interior node stores *both* the
362/// node's own coefficients and its children (low/high subbands).
363#[derive(Debug, Clone)]
364pub struct WaveletPacketTree {
365 /// All computed nodes, keyed by `WaveletPacketNode::key(level, index)`.
366 nodes: HashMap<u64, WaveletPacketNode>,
367 /// Maximum decomposition depth.
368 pub max_level: usize,
369 /// Wavelet used to build this tree.
370 pub wavelet: Wavelet,
371 /// Length of the original signal (needed for reconstruction).
372 pub signal_len: usize,
373}
374
375impl WaveletPacketTree {
376 /// Create an empty tree.
377 pub fn new(wavelet: Wavelet, max_level: usize, signal_len: usize) -> Self {
378 WaveletPacketTree {
379 nodes: HashMap::new(),
380 max_level,
381 wavelet,
382 signal_len,
383 }
384 }
385
386 /// Insert a node into the tree.
387 pub fn insert(&mut self, node: WaveletPacketNode) {
388 let key = WaveletPacketNode::key(node.level, node.index);
389 self.nodes.insert(key, node);
390 }
391
392 /// Retrieve a node by `(level, index)`.
393 pub fn get(&self, level: usize, index: usize) -> Option<&WaveletPacketNode> {
394 self.nodes.get(&WaveletPacketNode::key(level, index))
395 }
396
397 /// Iterate over all nodes at a given `level`.
398 pub fn nodes_at_level(&self, level: usize) -> impl Iterator<Item = &WaveletPacketNode> {
399 self.nodes.values().filter(move |n| n.level == level)
400 }
401
402 /// All nodes in the tree.
403 pub fn all_nodes(&self) -> impl Iterator<Item = &WaveletPacketNode> {
404 self.nodes.values()
405 }
406}
407
408// ─────────────────────────────────────────────────────────────────────────────
409// Wavelet Packet Decomposition (WPD)
410// ─────────────────────────────────────────────────────────────────────────────
411
412/// Perform a full wavelet packet decomposition up to `max_level`.
413///
414/// Every node (approximation *and* detail) at every level is recursively
415/// decomposed, producing a complete binary tree with `2^(max_level+1) - 1` nodes.
416///
417/// # Arguments
418///
419/// * `signal` – Real-valued input signal.
420/// * `wavelet` – Wavelet to use (determines analysis filters).
421/// * `max_level` – Maximum decomposition depth. The root is level 0.
422///
423/// # Errors
424///
425/// Returns `FFTError::ValueError` if `signal` is empty or `max_level == 0`.
426///
427/// # Example
428///
429/// ```
430/// use scirs2_fft::wavelet_packets::{wpd, Wavelet};
431///
432/// let signal: Vec<f64> = (0..64).map(|i| (i as f64 * 0.1).sin()).collect();
433/// let tree = wpd(&signal, Wavelet::Db4, 3).expect("decomposition failed");
434/// // Tree has nodes at levels 0 through 3
435/// assert!(tree.get(0, 0).is_some());
436/// assert!(tree.get(3, 7).is_some());
437/// ```
438pub fn wpd(signal: &[f64], wavelet: Wavelet, max_level: usize) -> FFTResult<WaveletPacketTree> {
439 if signal.is_empty() {
440 return Err(FFTError::ValueError("signal must be non-empty".to_string()));
441 }
442 if max_level == 0 {
443 return Err(FFTError::ValueError("max_level must be >= 1".to_string()));
444 }
445
446 let filters = WaveletFilters::for_wavelet(wavelet);
447 let signal_len = signal.len();
448 let mut tree = WaveletPacketTree::new(wavelet, max_level, signal_len);
449
450 // Root node (level 0, index 0) = original signal
451 tree.insert(WaveletPacketNode::new(signal.to_vec(), 0, 0));
452
453 // BFS decomposition
454 for level in 0..max_level {
455 let num_nodes = 1_usize << level;
456 for index in 0..num_nodes {
457 let coeffs = match tree.get(level, index) {
458 Some(n) => n.coeffs.clone(),
459 None => {
460 return Err(FFTError::InternalError(format!(
461 "missing node ({level}, {index})"
462 )))
463 }
464 };
465
466 // Low-pass child (approximation) → (level+1, 2*index)
467 let lo = conv_downsample(&coeffs, &filters.lo_d);
468 tree.insert(WaveletPacketNode::new(lo, level + 1, 2 * index));
469
470 // High-pass child (detail) → (level+1, 2*index+1)
471 let hi = conv_downsample(&coeffs, &filters.hi_d);
472 tree.insert(WaveletPacketNode::new(hi, level + 1, 2 * index + 1));
473 }
474 }
475
476 Ok(tree)
477}
478
479// ─────────────────────────────────────────────────────────────────────────────
480// Cost functions
481// ─────────────────────────────────────────────────────────────────────────────
482
/// Shannon entropy cost function.
///
/// E(s) = -∑ |s_i|² log₂(|s_i|²)
///
/// Zero coefficients are excluded from the sum (lim_{p→0} p log p = 0).
///
/// # Example
///
/// ```
/// use scirs2_fft::wavelet_packets::shannon_entropy;
///
/// let coeffs = vec![0.5, -0.5, 0.5, -0.5];
/// let e = shannon_entropy(&coeffs);
/// assert!(e >= 0.0);
/// ```
pub fn shannon_entropy(coeffs: &[f64]) -> f64 {
    let mut entropy = 0.0_f64;
    for &c in coeffs {
        let energy = c * c;
        // Zero-energy terms contribute nothing (p·log p → 0 as p → 0).
        if energy > 0.0 {
            entropy -= energy * energy.log2();
        }
    }
    entropy
}
511
/// Log-energy entropy cost function.
///
/// E(s) = ∑ log₂(|s_i|²), summed over non-zero coefficients only
/// (computed as ln(|s_i|²) / ln 2).
pub fn log_energy_entropy(coeffs: &[f64]) -> f64 {
    let mut total = 0.0_f64;
    for &c in coeffs {
        let energy = c * c;
        // Skip zeros: log(0) is -∞ and would poison the sum.
        if energy > 0.0 {
            total += energy.ln() / LN_2;
        }
    }
    total
}
528
/// Lp-norm (p ≠ 2) cost function – measures sparsity.
///
/// E(s) = ∑ |s_i|^p
pub fn lp_norm_cost(coeffs: &[f64], p: f64) -> f64 {
    coeffs.iter().fold(0.0, |acc, &c| acc + c.abs().powf(p))
}
535
536// ─────────────────────────────────────────────────────────────────────────────
537// Best Basis Selection (Coifman–Wickerhauser)
538// ─────────────────────────────────────────────────────────────────────────────
539
/// Select the best orthonormal basis from a wavelet packet tree.
///
/// The algorithm minimises an additive cost function `cost_fn` using a
/// bottom-up pass: a parent node is kept when its cost is *less than or equal*
/// to the sum of the costs of its two children.
///
/// # Arguments
///
/// * `tree` – Packet tree produced by `wpd`.
/// * `cost_fn` – Additive cost function; must satisfy `cost(A∪B) = cost(A) + cost(B)`.
///
/// # Returns
///
/// A `Vec<WaveletPacketNode>` that forms a partition of the time-frequency
/// plane (i.e. a valid orthonormal basis).
///
/// # Errors
///
/// Returns `FFTError::ValueError` if the tree is empty.
///
/// # Example
///
/// ```
/// use scirs2_fft::wavelet_packets::{wpd, best_basis, shannon_entropy, Wavelet};
///
/// let signal: Vec<f64> = (0..64).map(|i| (i as f64 * 0.1).sin()).collect();
/// let tree = wpd(&signal, Wavelet::Haar, 3).expect("decomp");
/// let basis = best_basis(&tree, shannon_entropy).expect("basis");
/// assert!(!basis.is_empty());
/// ```
pub fn best_basis<F>(tree: &WaveletPacketTree, cost_fn: F) -> FFTResult<Vec<WaveletPacketNode>>
where
    F: Fn(&[f64]) -> f64,
{
    if tree.max_level == 0 {
        return Err(FFTError::ValueError("tree is empty".to_string()));
    }

    // Pre-compute costs for every node in the tree
    let mut costs: HashMap<u64, f64> = HashMap::new();
    for node in tree.all_nodes() {
        let key = WaveletPacketNode::key(node.level, node.index);
        costs.insert(key, cost_fn(&node.coeffs));
    }

    // best_flag[key] = true → this node is SPLIT (its children are used);
    // false (or absent) → the node itself is kept as a best-basis leaf.
    // `collect_basis` relies on exactly this convention.
    let mut best_flag: HashMap<u64, bool> = HashMap::new();

    // Bottom-up: iterate from max_level - 1 down to 0
    for level in (0..tree.max_level).rev() {
        let num_nodes = 1_usize << level;
        for index in 0..num_nodes {
            let parent_key = WaveletPacketNode::key(level, index);
            let left_key = WaveletPacketNode::key(level + 1, 2 * index);
            let right_key = WaveletPacketNode::key(level + 1, 2 * index + 1);

            let parent_cost = match costs.get(&parent_key) {
                Some(&c) => c,
                None => continue,
            };

            // Children cost is the sum; if a child is already "split", we use
            // the *best* cost that the subtree achieves (propagated upward).
            let left_cost = effective_cost(&costs, &best_flag, level + 1, 2 * index);
            let right_cost = effective_cost(&costs, &best_flag, level + 1, 2 * index + 1);
            let children_cost = left_cost + right_cost;

            if parent_cost <= children_cost {
                // Parent is better (or equal) → keep parent, prune children
                best_flag.insert(parent_key, false); // false = "not split"
                costs.insert(parent_key, parent_cost); // re-insert is a no-op; kept for symmetry
            } else {
                // Split is better → mark parent as "split"
                best_flag.insert(parent_key, true);
                // Update the effective cost of this node to the children sum
                // so grandparents can compare correctly
                costs.insert(parent_key, children_cost);
            }

            // Ensure leaf flags exist for the children (they have no children of their own)
            best_flag.entry(left_key).or_insert(false);
            best_flag.entry(right_key).or_insert(false);
        }
    }

    // Collect the basis by selecting nodes that are NOT split
    let mut basis: Vec<WaveletPacketNode> = Vec::new();
    collect_basis(tree, &best_flag, 0, 0, &mut basis)?;

    Ok(basis)
}
631
632/// Recursively collect basis nodes starting from `(level, index)`.
633fn collect_basis(
634 tree: &WaveletPacketTree,
635 best_flag: &HashMap<u64, bool>,
636 level: usize,
637 index: usize,
638 out: &mut Vec<WaveletPacketNode>,
639) -> FFTResult<()> {
640 let key = WaveletPacketNode::key(level, index);
641 let is_split = best_flag.get(&key).copied().unwrap_or(false);
642
643 if !is_split || level == tree.max_level {
644 // Leaf of best basis tree
645 if let Some(node) = tree.get(level, index) {
646 out.push(node.clone());
647 }
648 } else {
649 collect_basis(tree, best_flag, level + 1, 2 * index, out)?;
650 collect_basis(tree, best_flag, level + 1, 2 * index + 1, out)?;
651 }
652 Ok(())
653}
654
655/// Return the effective (post-best-basis) cost for a node.
656fn effective_cost(
657 costs: &HashMap<u64, f64>,
658 best_flag: &HashMap<u64, bool>,
659 level: usize,
660 index: usize,
661) -> f64 {
662 // If the node has already been processed (and possibly "split"), its
663 // cost in the map already reflects the best achievable cost.
664 let key = WaveletPacketNode::key(level, index);
665 costs.get(&key).copied().unwrap_or(f64::INFINITY)
666}
667
668// ─────────────────────────────────────────────────────────────────────────────
669// Reconstruction
670// ─────────────────────────────────────────────────────────────────────────────
671
672/// Reconstruct the signal from a set of wavelet packet nodes forming a basis.
673///
674/// The nodes must constitute a valid partition of the time-frequency plane
675/// (e.g. those returned by `best_basis`). Mixed-level bases (where some nodes
676/// are at depth 2 and others at depth 3, etc.) are fully supported.
677///
678/// # Arguments
679///
680/// * `tree` – The original packet tree (provides wavelet & signal length).
681/// * `basis_nodes` – A valid wavelet packet basis (partition of the root).
682///
683/// # Errors
684///
685/// Returns `FFTError::InternalError` if reconstruction encounters a missing node.
686///
687/// # Example
688///
689/// ```
690/// use scirs2_fft::wavelet_packets::{wpd, best_basis, wp_reconstruct, shannon_entropy, Wavelet};
691///
692/// let signal: Vec<f64> = (0..64).map(|i| (i as f64 * 0.1).sin()).collect();
693/// let tree = wpd(&signal, Wavelet::Haar, 3).expect("decomp");
694/// let basis = best_basis(&tree, shannon_entropy).expect("basis");
695/// let recon = wp_reconstruct(&tree, &basis).expect("recon");
696/// assert_eq!(recon.len(), signal.len());
697/// // Perfect reconstruction (approx)
698/// for (a, b) in signal.iter().zip(recon.iter()) {
699/// assert!((a - b).abs() < 1e-10, "mismatch: {} vs {}", a, b);
700/// }
701/// ```
702pub fn wp_reconstruct(
703 tree: &WaveletPacketTree,
704 basis_nodes: &[WaveletPacketNode],
705) -> FFTResult<Vec<f64>> {
706 if basis_nodes.is_empty() {
707 return Err(FFTError::ValueError(
708 "basis_nodes must be non-empty".to_string(),
709 ));
710 }
711
712 let filters = WaveletFilters::for_wavelet(tree.wavelet);
713
714 // Map each basis node into the tree storage so we can do upward synthesis
715 let mut node_map: HashMap<u64, Vec<f64>> = HashMap::new();
716 for node in basis_nodes {
717 let key = WaveletPacketNode::key(node.level, node.index);
718 node_map.insert(key, node.coeffs.clone());
719 }
720
721 // We need to know the coefficient length at each level.
722 // We reuse the already-computed nodes in the tree.
723 // Bottom-up synthesis from max_level to level 0
724 for level in (1..=tree.max_level).rev() {
725 let num_nodes = 1_usize << level;
726 let parent_level = level - 1;
727 let num_parents = 1_usize << parent_level;
728
729 for p_idx in 0..num_parents {
730 let left_key = WaveletPacketNode::key(level, 2 * p_idx);
731 let right_key = WaveletPacketNode::key(level, 2 * p_idx + 1);
732 let parent_key = WaveletPacketNode::key(parent_level, p_idx);
733
734 // Skip if parent already exists in the basis (it was a leaf)
735 if node_map.contains_key(&parent_key) {
736 continue;
737 }
738
739 // Both children must be present to reconstruct the parent
740 let left_coeffs = match node_map.get(&left_key) {
741 Some(c) => c.clone(),
742 None => continue,
743 };
744 let right_coeffs = match node_map.get(&right_key) {
745 Some(c) => c.clone(),
746 None => continue,
747 };
748
749 // Target length: get from the tree if available, else estimate
750 let target_len = tree
751 .get(parent_level, p_idx)
752 .map(|n| n.coeffs.len())
753 .unwrap_or_else(|| {
754 // Estimate: parent length ≈ 2 * child length
755 left_coeffs.len() * 2
756 });
757
758 // Synthesis: lo branch + hi branch (transpose of analysis)
759 let lo_rec = synthesis_step(&left_coeffs, &filters.lo_r, target_len);
760 let hi_rec = synthesis_step(&right_coeffs, &filters.hi_r, target_len);
761 let parent_coeffs: Vec<f64> = lo_rec
762 .iter()
763 .zip(hi_rec.iter())
764 .map(|(a, b)| a + b)
765 .collect();
766
767 node_map.insert(parent_key, parent_coeffs);
768 }
769
770 // We no longer need the children at this level to save memory
771 for idx in 0..num_nodes {
772 // Only remove if both siblings have been consumed
773 let left_key = WaveletPacketNode::key(level, idx);
774 // Keep it if still needed (might be a basis leaf)
775 let _ = left_key;
776 }
777 }
778
779 // The reconstructed signal is the root (level 0, index 0)
780 let root_key = WaveletPacketNode::key(0, 0);
781 node_map.remove(&root_key).ok_or_else(|| {
782 FFTError::InternalError("reconstruction failed: root not reached".to_string())
783 })
784}
785
786// ─────────────────────────────────────────────────────────────────────────────
787// WPT Denoising
788// ─────────────────────────────────────────────────────────────────────────────
789
/// Thresholding method for wavelet denoising.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ThresholdMethod {
    /// Hard thresholding: coefficients with |c| < τ are set to 0.
    Hard,
    /// Soft thresholding: shrinks coefficients toward zero by τ.
    Soft,
    /// Garrote (non-negative garrote): c → c - τ²/c for |c| > τ.
    Garrote,
    /// Firm (semi-soft) thresholding: zero below τ, identity above `t2`,
    /// linear transition in between (Gao & Bruce, 1997). Requires `t2 > τ`.
    Firm { t2: f64 },
}

/// Apply a scalar threshold to a coefficient vector.
fn threshold_coeffs(coeffs: &[f64], tau: f64, method: ThresholdMethod) -> Vec<f64> {
    coeffs
        .iter()
        .map(|&c| apply_threshold(c, tau, method))
        .collect()
}

/// Apply threshold to a single coefficient.
fn apply_threshold(c: f64, tau: f64, method: ThresholdMethod) -> f64 {
    match method {
        ThresholdMethod::Hard => {
            // Keep coefficients at or above the threshold, zero the rest.
            if c.abs() >= tau {
                c
            } else {
                0.0
            }
        }
        ThresholdMethod::Soft => {
            // Shrink every surviving coefficient toward zero by τ.
            if c > tau {
                c - tau
            } else if c < -tau {
                c + tau
            } else {
                0.0
            }
        }
        ThresholdMethod::Garrote => {
            if c.abs() <= tau {
                0.0
            } else {
                c - tau * tau / c
            }
        }
        ThresholdMethod::Firm { t2 } => {
            let t1 = tau;
            let abs_c = c.abs();
            if abs_c <= t1 {
                0.0
            } else if abs_c >= t2 {
                c
            } else {
                // Linear ramp from 0 (at |c| = t1) up to ±t2 (at |c| = t2),
                // making the estimator continuous at both knees.
                // BUG FIX: the numerator must scale by t2, not t1; using t1
                // left a jump discontinuity at |c| = t2.
                c.signum() * t2 * (abs_c - t1) / (t2 - t1)
            }
        }
    }
}
851
852/// Denoise a signal using the Wavelet Packet Transform.
853///
854/// The procedure is:
855/// 1. Compute the full WPT up to `max_level`.
856/// 2. Select the best basis using Shannon entropy.
857/// 3. Threshold the coefficients in the best-basis nodes.
858/// 4. Reconstruct the signal.
859///
860/// # Arguments
861///
862/// * `signal` – Noisy input signal.
863/// * `wavelet` – Wavelet to use.
864/// * `max_level` – Maximum decomposition depth.
865/// * `threshold` – Threshold value τ.
866/// * `method` – Thresholding method.
867///
868/// # Errors
869///
870/// Propagates any error from `wpd`, `best_basis`, or `wp_reconstruct`.
871///
872/// # Example
873///
874/// ```
875/// use scirs2_fft::wavelet_packets::{wp_denoising, ThresholdMethod, Wavelet};
876///
877/// let signal: Vec<f64> = (0..64).map(|i| (i as f64 * 0.2).sin()).collect();
878/// let denoised = wp_denoising(&signal, Wavelet::Db4, 3, 0.05, ThresholdMethod::Soft)
879/// .expect("denoising failed");
880/// assert_eq!(denoised.len(), signal.len());
881/// ```
882pub fn wp_denoising(
883 signal: &[f64],
884 wavelet: Wavelet,
885 max_level: usize,
886 threshold: f64,
887 method: ThresholdMethod,
888) -> FFTResult<Vec<f64>> {
889 // 1. Decompose
890 let tree = wpd(signal, wavelet, max_level)?;
891
892 // 2. Best basis
893 let basis = best_basis(&tree, shannon_entropy)?;
894
895 // 3. Threshold coefficients (do NOT threshold the root / approximation leaf)
896 let thresholded: Vec<WaveletPacketNode> = basis
897 .into_iter()
898 .map(|mut node| {
899 if node.level > 0 {
900 node.coeffs = threshold_coeffs(&node.coeffs, threshold, method);
901 }
902 node
903 })
904 .collect();
905
906 // 4. Reconstruct
907 let mut recon = wp_reconstruct(&tree, &thresholded)?;
908
909 // Trim or extend to original signal length
910 recon.truncate(signal.len());
911 while recon.len() < signal.len() {
912 recon.push(0.0);
913 }
914
915 Ok(recon)
916}
917
918// ─────────────────────────────────────────────────────────────────────────────
919// Tests
920// ─────────────────────────────────────────────────────────────────────────────
921
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a simple two-tone test signal (5- and 13-cycle sinusoids over
    /// one period of length `n`).
    fn test_signal(n: usize) -> Vec<f64> {
        (0..n)
            .map(|i| {
                let t = i as f64 / n as f64;
                (2.0 * std::f64::consts::PI * 5.0 * t).sin()
                    + 0.5 * (2.0 * std::f64::consts::PI * 13.0 * t).sin()
            })
            .collect()
    }

    // A depth-3 decomposition must populate the full bottom level (2³ nodes).
    #[test]
    fn test_haar_decomp_shape() {
        let sig = test_signal(64);
        let tree = wpd(&sig, Wavelet::Haar, 3).expect("wpd failed");
        // All nodes at level 3 should exist
        for idx in 0..8 {
            assert!(tree.get(3, idx).is_some(), "missing node (3, {idx})");
        }
    }

    #[test]
    fn test_qmf_energy_preservation() {
        // lo and hi filters of db2 should each have unit energy
        let filters = WaveletFilters::for_wavelet(Wavelet::Db2);
        let e_lo: f64 = filters.lo_d.iter().map(|&c| c * c).sum();
        let e_hi: f64 = filters.hi_d.iter().map(|&c| c * c).sum();
        assert!((e_lo - 1.0).abs() < 1e-10, "lo energy {e_lo}");
        assert!((e_hi - 1.0).abs() < 1e-10, "hi energy {e_hi}");
    }

    #[test]
    fn test_shannon_entropy_uniform() {
        // Uniform signal (all equal nonzero): entropy should be positive
        let coeffs = vec![0.5_f64; 8];
        let e = shannon_entropy(&coeffs);
        assert!(e > 0.0, "expected positive entropy, got {e}");
    }

    #[test]
    fn test_shannon_entropy_sparse() {
        // A single non-zero coefficient → minimum entropy (sparse)
        // (|c|² = 1 contributes -1·log₂(1) = 0; zeros are excluded).
        let mut coeffs = vec![0.0_f64; 64];
        coeffs[0] = 1.0;
        let e = shannon_entropy(&coeffs);
        assert!((e - 0.0).abs() < 1e-12, "sparse signal entropy {e}");
    }

    #[test]
    fn test_best_basis_returns_valid_partition() {
        let sig = test_signal(64);
        let tree = wpd(&sig, Wavelet::Db2, 3).expect("wpd");
        let basis = best_basis(&tree, shannon_entropy).expect("best_basis");

        // Basis must be non-empty
        assert!(!basis.is_empty(), "basis is empty");

        // All nodes in basis must exist in the tree
        for node in &basis {
            assert!(
                tree.get(node.level, node.index).is_some(),
                "basis node ({}, {}) not in tree",
                node.level,
                node.index
            );
        }
    }

    // Synthesising every level-2 leaf back through the transpose filters
    // must reproduce the input exactly (orthogonal perfect reconstruction).
    #[test]
    fn test_haar_perfect_reconstruction() {
        let sig = test_signal(64);
        let tree = wpd(&sig, Wavelet::Haar, 2).expect("wpd");
        // Use all leaf nodes as basis (no simplification)
        let basis: Vec<WaveletPacketNode> = (0..4_usize)
            .filter_map(|idx| tree.get(2, idx).cloned())
            .collect();
        let recon = wp_reconstruct(&tree, &basis).expect("recon");
        for (i, (&s, &r)) in sig.iter().zip(recon.iter()).enumerate() {
            assert!(
                (s - r).abs() < 1e-10,
                "mismatch at {i}: orig={s}, recon={r}"
            );
        }
    }

    // Denoising must always hand back a signal of the original length.
    #[test]
    fn test_denoising_length_preserved() {
        let sig = test_signal(64);
        let denoised =
            wp_denoising(&sig, Wavelet::Db4, 3, 0.1, ThresholdMethod::Soft).expect("denoise");
        assert_eq!(denoised.len(), sig.len());
    }

    #[test]
    fn test_threshold_hard() {
        // |c| ≥ τ survives unchanged; everything else is zeroed.
        let coeffs = vec![1.0, -0.5, 0.3, -0.1, 2.0];
        let out = threshold_coeffs(&coeffs, 0.4, ThresholdMethod::Hard);
        assert_eq!(out, vec![1.0, -0.5, 0.0, 0.0, 2.0]);
    }

    #[test]
    fn test_threshold_soft() {
        // Survivors shrink toward zero by τ; |c| ≤ τ is zeroed.
        let out = threshold_coeffs(&[1.0, -1.5, 0.2], 0.5, ThresholdMethod::Soft);
        assert!((out[0] - 0.5).abs() < 1e-12);
        assert!((out[1] - (-1.0)).abs() < 1e-12);
        assert!((out[2] - 0.0).abs() < 1e-12);
    }

    #[test]
    fn test_wpd_error_on_empty() {
        let result = wpd(&[], Wavelet::Haar, 2);
        assert!(result.is_err());
    }

    #[test]
    fn test_wpd_error_on_zero_level() {
        let result = wpd(&[1.0, 2.0, 3.0], Wavelet::Haar, 0);
        assert!(result.is_err());
    }
}
1045}