Skip to main content

oxiphysics_core/wavelet_transform/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5#![allow(clippy::needless_range_loop)]
6use std::f64::consts::PI;
7
8use super::types::{
9    CwtResult, DwtDecomposition, DwtLevel, ModwtDecomposition, MotherWavelet,
10    MultiresolutionAnalysis, Scalogram, SwtDecomposition, ThresholdMode, WaveletFamily,
11    WaveletPacketNode, WaveletPacketTree,
12};
13
// Orthonormal Daubechies analysis filter banks (db1..db6).
// For every pair below the high-pass filter is the quadrature mirror of its
// low-pass partner: hi[k] = (-1)^k * lo[L - 1 - k], as can be checked directly
// against the coefficient tables.
/// Haar low-pass filter coefficients (db1).
pub(super) const HAAR_LO: [f64; 2] = [
    std::f64::consts::FRAC_1_SQRT_2,
    std::f64::consts::FRAC_1_SQRT_2,
];
/// Haar high-pass filter coefficients (db1).
pub(super) const HAAR_HI: [f64; 2] = [
    std::f64::consts::FRAC_1_SQRT_2,
    -std::f64::consts::FRAC_1_SQRT_2,
];
/// Daubechies-2 low-pass filter coefficients.
pub(super) const DB2_LO: [f64; 4] = [
    0.4829629131445341,
    0.8365163037378079,
    0.2241438680420134,
    -0.1294095225512604,
];
/// Daubechies-2 high-pass filter coefficients.
pub(super) const DB2_HI: [f64; 4] = [
    -0.1294095225512604,
    -0.2241438680420134,
    0.8365163037378079,
    -0.4829629131445341,
];
/// Daubechies-3 low-pass filter coefficients.
pub(super) const DB3_LO: [f64; 6] = [
    0.3326705529500826,
    0.8068915093110925,
    0.4598775021184915,
    -0.1350110200102546,
    -0.0854412738820267,
    0.0352262918857095,
];
/// Daubechies-3 high-pass filter coefficients.
pub(super) const DB3_HI: [f64; 6] = [
    0.0352262918857095,
    0.0854412738820267,
    -0.1350110200102546,
    -0.4598775021184915,
    0.8068915093110925,
    -0.3326705529500826,
];
/// Daubechies-4 low-pass filter coefficients.
pub(super) const DB4_LO: [f64; 8] = [
    0.2303778133088965,
    0.7148465705529156,
    0.6308807679298589,
    -0.0279837694169839,
    -0.1870348117190931,
    0.0308413818355607,
    0.0328830116668852,
    -0.0105974017850690,
];
/// Daubechies-4 high-pass filter coefficients.
pub(super) const DB4_HI: [f64; 8] = [
    -0.0105974017850690,
    -0.0328830116668852,
    0.0308413818355607,
    0.1870348117190931,
    -0.0279837694169839,
    -0.6308807679298589,
    0.7148465705529156,
    -0.2303778133088965,
];
/// Daubechies-5 low-pass filter coefficients.
pub(super) const DB5_LO: [f64; 10] = [
    0.160_102_397_974_193,
    0.6038292697971898,
    0.7243085284377729,
    0.1384281459013204,
    -0.2422948870663824,
    -0.0322448695846381,
    0.0775714938400459,
    -0.0062414902127983,
    -0.0125807519990820,
    0.0033357252854738,
];
/// Daubechies-5 high-pass filter coefficients.
pub(super) const DB5_HI: [f64; 10] = [
    0.0033357252854738,
    0.0125807519990820,
    -0.0062414902127983,
    -0.0775714938400459,
    -0.0322448695846381,
    0.2422948870663824,
    0.1384281459013204,
    -0.7243085284377729,
    0.6038292697971898,
    -0.160_102_397_974_193,
];
/// Daubechies-6 low-pass filter coefficients.
pub(super) const DB6_LO: [f64; 12] = [
    0.1115407433501095,
    0.4946238903984533,
    0.7511339080210959,
    0.3152503517091982,
    -0.226_264_693_965_44,
    -0.1297668675672625,
    0.0975016055873225,
    0.0275228655303053,
    -0.0315820393174862,
    0.0005538422011614,
    0.0047772575109455,
    -0.0010773010853085,
];
/// Daubechies-6 high-pass filter coefficients.
pub(super) const DB6_HI: [f64; 12] = [
    -0.0010773010853085,
    -0.0047772575109455,
    0.0005538422011614,
    0.0315820393174862,
    0.0275228655303053,
    -0.0975016055873225,
    -0.1297668675672625,
    0.226_264_693_965_44,
    0.3152503517091982,
    -0.7511339080210959,
    0.4946238903984533,
    -0.1115407433501095,
];
/// Convolve `signal` with `filter` and downsample by 2 (periodic extension).
///
/// The signal is treated as periodic with period `signal.len()`. The output
/// holds every second sample of the full convolution, i.e.
/// `(signal.len() + filter.len() - 1) / 2` values.
///
/// Returns an empty vector when either input is empty; the previous code
/// divided by zero (`% n` with `n == 0`) or underflowed (`flen - 1` with
/// `flen == 0`) in those cases.
pub fn convolve_downsample(signal: &[f64], filter: &[f64]) -> Vec<f64> {
    let n = signal.len();
    let flen = filter.len();
    if n == 0 || flen == 0 {
        return Vec::new();
    }
    let out_len = (n + flen - 1) / 2;
    let mut result = Vec::with_capacity(out_len);
    for i in 0..out_len {
        // Anchor the filter at sample 2*i + flen - 1 and walk backwards with
        // periodic wrap-around. Since 2*i + flen - 1 >= j for every j < flen,
        // the subtraction cannot underflow.
        let sum: f64 = filter
            .iter()
            .enumerate()
            .map(|(j, &h)| h * signal[(2 * i + flen - 1 - j) % n])
            .sum();
        result.push(sum);
    }
    result
}
/// Upsample by 2 (zero-insertion) and convolve with `filter` (periodic extension).
///
/// Each coefficient `coeffs[k]` is placed at position `2*k` and spread through
/// the filter taps, wrapping modulo `target_len`.
///
/// Returns an empty vector when `target_len == 0`; the previous code would
/// panic on `% 0` there. (Also removes a dead `_flen` binding.)
pub fn upsample_convolve(coeffs: &[f64], filter: &[f64], target_len: usize) -> Vec<f64> {
    if target_len == 0 {
        return Vec::new();
    }
    let mut result = vec![0.0; target_len];
    for (k, &c) in coeffs.iter().enumerate() {
        for (j, &h) in filter.iter().enumerate() {
            result[(2 * k + j) % target_len] += c * h;
        }
    }
    result
}
162/// Perform a single-level forward DWT.
163pub fn dwt_single(signal: &[f64], wavelet: WaveletFamily) -> DwtLevel {
164    let lo = wavelet.lo_dec();
165    let hi = wavelet.hi_dec();
166    DwtLevel {
167        approx: convolve_downsample(signal, lo),
168        detail: convolve_downsample(signal, hi),
169    }
170}
171/// Perform a single-level inverse DWT.
172pub fn idwt_single(level: &DwtLevel, wavelet: WaveletFamily, target_len: usize) -> Vec<f64> {
173    let lo_r = wavelet.lo_rec();
174    let hi_r = wavelet.hi_rec();
175    let a = upsample_convolve(&level.approx, &lo_r, target_len);
176    let d = upsample_convolve(&level.detail, &hi_r, target_len);
177    a.iter().zip(d.iter()).map(|(&x, &y)| x + y).collect()
178}
179/// Perform multi-level forward DWT decomposition.
180///
181/// `levels` is the number of decomposition levels. If the signal is too
182/// short for the requested number of levels, fewer levels are computed.
183pub fn dwt(signal: &[f64], wavelet: WaveletFamily, levels: usize) -> DwtDecomposition {
184    let mut current = signal.to_vec();
185    let mut details = Vec::with_capacity(levels);
186    let mut lengths = Vec::with_capacity(levels);
187    let filter_len = wavelet.filter_length();
188    for _level in 0..levels {
189        if current.len() < filter_len {
190            break;
191        }
192        lengths.push(current.len());
193        let decomp = dwt_single(&current, wavelet);
194        details.push(decomp.detail);
195        current = decomp.approx;
196    }
197    DwtDecomposition {
198        details,
199        approx: current,
200        wavelet,
201        lengths,
202    }
203}
204/// Perform multi-level inverse DWT reconstruction.
205pub fn idwt(decomp: &DwtDecomposition) -> Vec<f64> {
206    let mut current = decomp.approx.clone();
207    let n_levels = decomp.details.len();
208    for i in (0..n_levels).rev() {
209        let target_len = decomp.lengths[i];
210        let level = DwtLevel {
211            approx: current,
212            detail: decomp.details[i].clone(),
213        };
214        current = idwt_single(&level, decomp.wavelet, target_len);
215    }
216    current
217}
/// Perform multiresolution analysis using the given wavelet.
///
/// Returns approximations and detail contributions at each level.
///
/// `approximations` runs coarsest-first (index 0 is the deepest DWT
/// approximation, the last entry is the full reconstruction), giving
/// `n_levels + 1` entries. `detail_contributions` holds each detail band's
/// signal-domain contribution, ordered finest-first to match
/// `DwtDecomposition::details`.
pub fn multiresolution_analysis(
    signal: &[f64],
    wavelet: WaveletFamily,
    levels: usize,
) -> MultiresolutionAnalysis {
    let decomp = dwt(signal, wavelet, levels);
    let n_levels = decomp.details.len();
    let mut approximations = Vec::with_capacity(n_levels + 1);
    let mut detail_contributions = Vec::with_capacity(n_levels);
    let mut current_approx = decomp.approx.clone();
    approximations.push(current_approx.clone());
    // Walk back up from the coarsest level to the finest.
    for i in (0..n_levels).rev() {
        let target_len = decomp.lengths[i];
        // Inverse-transform with a zeroed approximation band to isolate the
        // contribution of this level's detail coefficients alone.
        let zero_approx = vec![0.0; current_approx.len()];
        let detail_level = DwtLevel {
            approx: zero_approx,
            detail: decomp.details[i].clone(),
        };
        let detail_contrib = idwt_single(&detail_level, wavelet, target_len);
        detail_contributions.push(detail_contrib);
        // Full inverse step (approximation + detail) feeds the next iteration.
        let full_level = DwtLevel {
            approx: current_approx,
            detail: decomp.details[i].clone(),
        };
        current_approx = idwt_single(&full_level, wavelet, target_len);
        approximations.push(current_approx.clone());
    }
    // Contributions were produced coarsest-first; flip to finest-first.
    detail_contributions.reverse();
    MultiresolutionAnalysis {
        approximations,
        detail_contributions,
    }
}
/// Compute Shannon entropy of a coefficient vector (normalized).
///
/// The coefficients are turned into an energy distribution
/// `p_i = c_i^2 / E_total`, and `-sum(p ln p)` is returned. A (near-)zero
/// total energy yields 0.
pub fn shannon_entropy(coeffs: &[f64]) -> f64 {
    let total: f64 = coeffs.iter().map(|&c| c * c).sum();
    if total < 1e-30 {
        return 0.0;
    }
    coeffs
        .iter()
        .map(|&c| (c * c) / total)
        .filter(|&p| p > 1e-30)
        .map(|p| -(p * p.ln()))
        .sum()
}
269/// Perform wavelet packet decomposition.
270///
271/// Builds a full binary tree of wavelet packet nodes up to `max_level` levels.
272pub fn wavelet_packet_decompose(
273    signal: &[f64],
274    wavelet: WaveletFamily,
275    max_level: usize,
276) -> WaveletPacketTree {
277    let root = WaveletPacketNode {
278        coefficients: signal.to_vec(),
279        level: 0,
280        position: 0,
281        entropy: shannon_entropy(signal),
282    };
283    let mut nodes: Vec<Vec<WaveletPacketNode>> = vec![vec![root]];
284    let filter_len = wavelet.filter_length();
285    for level in 0..max_level {
286        let mut next_level = Vec::new();
287        for node in &nodes[level] {
288            if node.coefficients.len() < filter_len {
289                continue;
290            }
291            let lo = wavelet.lo_dec();
292            let hi = wavelet.hi_dec();
293            let approx = convolve_downsample(&node.coefficients, lo);
294            let detail = convolve_downsample(&node.coefficients, hi);
295            next_level.push(WaveletPacketNode {
296                entropy: shannon_entropy(&approx),
297                coefficients: approx,
298                level: level + 1,
299                position: 2 * node.position,
300            });
301            next_level.push(WaveletPacketNode {
302                entropy: shannon_entropy(&detail),
303                coefficients: detail,
304                level: level + 1,
305                position: 2 * node.position + 1,
306            });
307        }
308        if next_level.is_empty() {
309            break;
310        }
311        nodes.push(next_level);
312    }
313    WaveletPacketTree {
314        nodes,
315        wavelet,
316        max_level,
317    }
318}
/// Select best basis from wavelet packet tree using minimum entropy criterion.
///
/// Returns indices (level, position) of the selected nodes.
///
/// NOTE(review): this compares each parent's entropy only against the sum of
/// its two *direct* children's entropies, not against the optimal cost of each
/// child subtree as in the classic Coifman–Wickerhauser dynamic program —
/// confirm this greedy variant is intended.
///
/// NOTE(review): `child_base = 2 * i` assumes `nodes[level + 1]` is densely
/// indexed by parent position, which holds only when no node at `level` was
/// too short to split (such nodes are skipped by `wavelet_packet_decompose`).
pub fn best_basis_selection(tree: &WaveletPacketTree) -> Vec<(usize, usize)> {
    // Assumes the tree has at least the root level (true for trees built by
    // `wavelet_packet_decompose`); an empty `nodes` would underflow here.
    let max_lvl = tree.nodes.len() - 1;
    if max_lvl == 0 {
        return vec![(0, 0)];
    }
    // Start with every node selected; the sweep below keeps either a parent
    // or its subtree for each split.
    let mut selected: Vec<Vec<bool>> = tree
        .nodes
        .iter()
        .map(|level| vec![true; level.len()])
        .collect();
    // Bottom-up sweep over all internal levels.
    for level in (0..max_lvl).rev() {
        for (i, node) in tree.nodes[level].iter().enumerate() {
            let child_base = 2 * i;
            // Parents whose children are missing remain selected as leaves.
            if level + 1 < tree.nodes.len() && child_base + 1 < tree.nodes[level + 1].len() {
                let child_entropy = tree.nodes[level + 1][child_base].entropy
                    + tree.nodes[level + 1][child_base + 1].entropy;
                if node.entropy <= child_entropy {
                    // Parent is at least as cheap: keep it, drop the subtree.
                    selected[level][i] = true;
                    deselect_subtree(&mut selected, level + 1, child_base);
                    deselect_subtree(&mut selected, level + 1, child_base + 1);
                } else {
                    selected[level][i] = false;
                }
            }
        }
    }
    // Collect coordinates of the nodes that survived the pruning.
    let mut result = Vec::new();
    for (level, level_selected) in selected.iter().enumerate() {
        for (pos, &sel) in level_selected.iter().enumerate() {
            if sel {
                result.push((level, pos));
            }
        }
    }
    result
}
/// Deselect a subtree rooted at (level, position).
///
/// Clears the flag of the root node and of every descendant (children of
/// `(lvl, p)` live at `(lvl + 1, 2p)` and `(lvl + 1, 2p + 1)`). Positions
/// outside the stored level widths are ignored.
pub fn deselect_subtree(selected: &mut [Vec<bool>], level: usize, pos: usize) {
    // Iterative depth-first traversal with an explicit stack; end state is
    // identical to the recursive formulation.
    let mut pending = vec![(level, pos)];
    while let Some((lvl, p)) = pending.pop() {
        if lvl >= selected.len() || p >= selected[lvl].len() {
            continue;
        }
        selected[lvl][p] = false;
        if lvl + 1 < selected.len() {
            pending.push((lvl + 1, 2 * p));
            pending.push((lvl + 1, 2 * p + 1));
        }
    }
}
369/// Compute the continuous wavelet transform.
370///
371/// # Arguments
372/// * `signal` - Input signal.
373/// * `scales` - Array of scales at which to compute the CWT.
374/// * `wavelet` - Mother wavelet to use.
375/// * `dt` - Sampling period.
376pub fn cwt(signal: &[f64], scales: &[f64], wavelet: MotherWavelet, dt: f64) -> CwtResult {
377    let n = signal.len();
378    let mut coefficients = Vec::with_capacity(scales.len());
379    for &scale in scales {
380        let mut row = Vec::with_capacity(n);
381        let norm = (dt / scale).sqrt();
382        for b in 0..n {
383            let mut sum = 0.0;
384            for (k, &s) in signal.iter().enumerate() {
385                let t = ((k as f64) - (b as f64)) * dt / scale;
386                sum += s * wavelet.evaluate(t);
387            }
388            row.push(sum * norm * dt);
389        }
390        coefficients.push(row);
391    }
392    CwtResult {
393        coefficients,
394        scales: scales.to_vec(),
395        wavelet,
396    }
397}
398/// Compute the inverse CWT using the Morlet reconstruction formula.
399///
400/// This is an approximate reconstruction using the admissibility constant.
401pub fn icwt(cwt_result: &CwtResult, dt: f64) -> Vec<f64> {
402    let n = if cwt_result.coefficients.is_empty() {
403        0
404    } else {
405        cwt_result.coefficients[0].len()
406    };
407    let mut signal = vec![0.0; n];
408    if cwt_result.scales.is_empty() || n == 0 {
409        return signal;
410    }
411    let c_psi = 1.0;
412    for (si, &scale) in cwt_result.scales.iter().enumerate() {
413        let norm = 1.0 / (scale * scale);
414        for b in 0..n {
415            signal[b] += cwt_result.coefficients[si][b] * norm * dt;
416        }
417    }
418    let dj = if cwt_result.scales.len() > 1 {
419        (cwt_result.scales[1] / cwt_result.scales[0]).ln()
420    } else {
421        1.0
422    };
423    for val in &mut signal {
424        *val *= dj / c_psi;
425    }
426    signal
427}
428/// Compute the scalogram from a CWT result.
429pub fn scalogram(cwt_result: &CwtResult) -> Scalogram {
430    let mut energy = Vec::with_capacity(cwt_result.coefficients.len());
431    let mut scale_energy = Vec::with_capacity(cwt_result.scales.len());
432    for row in &cwt_result.coefficients {
433        let e_row: Vec<f64> = row.iter().map(|&c| c * c).collect();
434        let total: f64 = e_row.iter().sum();
435        scale_energy.push(total);
436        energy.push(e_row);
437    }
438    Scalogram {
439        energy,
440        scales: cwt_result.scales.clone(),
441        scale_energy,
442    }
443}
444/// Compute the global wavelet spectrum (time-averaged energy at each scale).
445pub fn global_wavelet_spectrum(scalo: &Scalogram) -> Vec<f64> {
446    scalo
447        .energy
448        .iter()
449        .map(|row| {
450            if row.is_empty() {
451                0.0
452            } else {
453                row.iter().sum::<f64>() / row.len() as f64
454            }
455        })
456        .collect()
457}
458/// Apply thresholding to a coefficient vector.
459pub fn apply_threshold(coeffs: &[f64], threshold: f64, mode: ThresholdMode) -> Vec<f64> {
460    match mode {
461        ThresholdMode::Hard => coeffs
462            .iter()
463            .map(|&c| if c.abs() < threshold { 0.0 } else { c })
464            .collect(),
465        ThresholdMode::Soft => coeffs
466            .iter()
467            .map(|&c| {
468                if c.abs() < threshold {
469                    0.0
470                } else {
471                    c.signum() * (c.abs() - threshold)
472                }
473            })
474            .collect(),
475    }
476}
/// Compute the universal (VisuShrink) threshold `sigma * sqrt(2 ln n)`.
///
/// `sigma` is the noise standard deviation.
/// `n` is the signal length.
pub fn universal_threshold(sigma: f64, n: usize) -> f64 {
    let ln_n = (n as f64).ln();
    (2.0 * ln_n).sqrt() * sigma
}
/// Estimate noise standard deviation from the finest detail coefficients.
///
/// Uses the median absolute deviation (MAD) estimator: median(|d|) / 0.6745,
/// where 0.6745 ≈ Φ⁻¹(0.75) rescales the MAD to a Gaussian standard
/// deviation. Returns 0 for empty input.
pub fn estimate_noise_sigma(detail_coeffs: &[f64]) -> f64 {
    if detail_coeffs.is_empty() {
        return 0.0;
    }
    let mut abs_coeffs: Vec<f64> = detail_coeffs.iter().map(|c| c.abs()).collect();
    // `f64::total_cmp` is a true total order (NaN-safe); the previous
    // `partial_cmp(..).unwrap_or(Equal)` comparator is inconsistent in the
    // presence of NaN. `sort_unstable` also avoids the stable sort's buffer.
    abs_coeffs.sort_unstable_by(f64::total_cmp);
    let mid = abs_coeffs.len() / 2;
    let median = if abs_coeffs.len() % 2 == 0 {
        (abs_coeffs[mid - 1] + abs_coeffs[mid]) / 2.0
    } else {
        abs_coeffs[mid]
    };
    median / 0.6745
}
500/// Denoise a signal using wavelet thresholding.
501///
502/// # Arguments
503/// * `signal` - Noisy input signal.
504/// * `wavelet` - Wavelet family to use.
505/// * `levels` - Number of decomposition levels.
506/// * `mode` - Thresholding mode (hard or soft).
507/// * `threshold` - Optional threshold value. If `None`, uses universal threshold.
508#[allow(clippy::too_many_arguments)]
509pub fn wavelet_denoise(
510    signal: &[f64],
511    wavelet: WaveletFamily,
512    levels: usize,
513    mode: ThresholdMode,
514    threshold: Option<f64>,
515) -> Vec<f64> {
516    let mut decomp = dwt(signal, wavelet, levels);
517    let thresh = threshold.unwrap_or_else(|| {
518        if decomp.details.is_empty() {
519            0.0
520        } else {
521            let sigma = estimate_noise_sigma(&decomp.details[0]);
522            universal_threshold(sigma, signal.len())
523        }
524    });
525    for detail in &mut decomp.details {
526        *detail = apply_threshold(detail, thresh, mode);
527    }
528    idwt(&decomp)
529}
/// Energy (sum of squares) of wavelet coefficients at a given level.
pub fn level_energy(coeffs: &[f64]) -> f64 {
    coeffs.iter().fold(0.0, |acc, &c| acc + c * c)
}
534/// Compute energy distribution across all DWT levels.
535///
536/// Returns a vector of (level_index, energy) pairs, plus the approximation energy.
537pub fn energy_distribution(decomp: &DwtDecomposition) -> (Vec<f64>, f64) {
538    let detail_energies: Vec<f64> = decomp.details.iter().map(|d| level_energy(d)).collect();
539    let approx_energy = level_energy(&decomp.approx);
540    (detail_energies, approx_energy)
541}
542/// Compute the relative energy (percentage) at each level.
543pub fn relative_energy(decomp: &DwtDecomposition) -> Vec<f64> {
544    let (detail_energies, approx_energy) = energy_distribution(decomp);
545    let total: f64 = detail_energies.iter().sum::<f64>() + approx_energy;
546    if total < 1e-30 {
547        return vec![0.0; detail_energies.len() + 1];
548    }
549    let mut result: Vec<f64> = detail_energies.iter().map(|&e| e / total).collect();
550    result.push(approx_energy / total);
551    result
552}
553/// Compute the wavelet entropy (measure of signal complexity).
554pub fn wavelet_entropy(decomp: &DwtDecomposition) -> f64 {
555    let (detail_energies, approx_energy) = energy_distribution(decomp);
556    let total: f64 = detail_energies.iter().sum::<f64>() + approx_energy;
557    if total < 1e-30 {
558        return 0.0;
559    }
560    let mut entropy = 0.0;
561    for &e in detail_energies
562        .iter()
563        .chain(std::iter::once(&approx_energy))
564    {
565        let p = e / total;
566        if p > 1e-30 {
567            entropy -= p * p.ln();
568        }
569    }
570    entropy
571}
/// Perform a stationary wavelet transform (SWT).
///
/// Also known as the "algorithme a trous" (with holes): nothing is
/// downsampled; instead the filters are dilated by a factor of 2^level, so
/// every level keeps the full signal length.
///
/// NOTE(review): the analysis filters are applied *unscaled* here, while
/// `modwt` rescales them by 1/sqrt(2), and the periodic index runs forward
/// (`i + j*step`) rather than backward as in `modwt` — confirm both
/// conventions are intended.
pub fn swt(signal: &[f64], wavelet: WaveletFamily, levels: usize) -> SwtDecomposition {
    let n = signal.len();
    let lo = wavelet.lo_dec();
    let hi = wavelet.hi_dec();
    let flen = lo.len();
    let mut current = signal.to_vec();
    let mut details = Vec::with_capacity(levels);
    for level in 0..levels {
        // Filter dilation for this level: taps are `step` samples apart.
        let step = 1_usize << level;
        let mut approx = vec![0.0; n];
        let mut detail = vec![0.0; n];
        for i in 0..n {
            let mut sum_lo = 0.0;
            let mut sum_hi = 0.0;
            for j in 0..flen {
                // Periodic (circular) extension of the signal.
                let idx = (i + j * step) % n;
                sum_lo += lo[j] * current[idx];
                sum_hi += hi[j] * current[idx];
            }
            approx[i] = sum_lo;
            detail[i] = sum_hi;
        }
        details.push(detail);
        current = approx;
    }
    SwtDecomposition {
        details,
        approx: current,
        wavelet,
    }
}
606/// Compute wavelet cross-spectrum between two CWT results.
607///
608/// Returns the cross-spectrum matrix (scale x time).
609pub fn wavelet_cross_spectrum(cwt_x: &CwtResult, cwt_y: &CwtResult) -> Vec<Vec<f64>> {
610    let n_scales = cwt_x.scales.len().min(cwt_y.scales.len());
611    let mut cross = Vec::with_capacity(n_scales);
612    for s in 0..n_scales {
613        let n_time = cwt_x.coefficients[s].len().min(cwt_y.coefficients[s].len());
614        let row: Vec<f64> = (0..n_time)
615            .map(|t| cwt_x.coefficients[s][t] * cwt_y.coefficients[s][t])
616            .collect();
617        cross.push(row);
618    }
619    cross
620}
/// Compute wavelet coherence between two signals.
///
/// Returns values between 0 and 1 at each (scale, time) point.
///
/// Coherence is |smoothed cross-spectrum| / sqrt(smoothed power_x * power_y),
/// smoothed in time over a boxcar window of `2*smoothing + 1` samples
/// (clipped at the edges).
///
/// NOTE(review): with the real-valued CWT produced by `cwt`, an unsmoothed
/// (`smoothing == 0`) coherence is identically 1 wherever both powers are
/// non-zero — a non-zero `smoothing` is essential for a meaningful result.
pub fn wavelet_coherence(cwt_x: &CwtResult, cwt_y: &CwtResult, smoothing: usize) -> Vec<Vec<f64>> {
    let cross = wavelet_cross_spectrum(cwt_x, cwt_y);
    let scalo_x = scalogram(cwt_x);
    let scalo_y = scalogram(cwt_y);
    let n_scales = cross.len();
    let mut coherence = Vec::with_capacity(n_scales);
    for s in 0..n_scales {
        let n_time = cross[s].len();
        let mut coh_row = Vec::with_capacity(n_time);
        for t in 0..n_time {
            // Boxcar smoothing window [t - smoothing, t + smoothing], clipped
            // to the valid time range.
            let lo = t.saturating_sub(smoothing);
            let hi_t = (t + smoothing + 1).min(n_time);
            let mut sum_cross = 0.0;
            let mut sum_xx = 0.0;
            let mut sum_yy = 0.0;
            for k in lo..hi_t {
                sum_cross += cross[s][k];
                sum_xx += scalo_x.energy[s][k];
                sum_yy += scalo_y.energy[s][k];
            }
            let denom = (sum_xx * sum_yy).sqrt();
            // Clamp to [0, 1]; zero-power windows yield 0 coherence.
            let c = if denom > 1e-30 {
                (sum_cross.abs() / denom).min(1.0)
            } else {
                0.0
            };
            coh_row.push(c);
        }
        coherence.push(coh_row);
    }
    coherence
}
656/// Compute wavelet variance (energy per scale, normalized by signal length).
657pub fn wavelet_variance(cwt_result: &CwtResult) -> Vec<f64> {
658    cwt_result
659        .coefficients
660        .iter()
661        .map(|row| {
662            let n = row.len().max(1) as f64;
663            row.iter().map(|&c| c * c).sum::<f64>() / n
664        })
665        .collect()
666}
667/// Compute instantaneous wavelet power at each (scale, time) point.
668pub fn wavelet_power(cwt_result: &CwtResult) -> Vec<Vec<f64>> {
669    cwt_result
670        .coefficients
671        .iter()
672        .map(|row| row.iter().map(|&c| c * c).collect())
673        .collect()
674}
/// Generate logarithmically spaced scales for CWT.
///
/// # Arguments
/// * `s0` - Smallest scale.
/// * `num_scales` - Number of scales.
/// * `dj` - Scale spacing in octaves (e.g. 0.25 for 4 scales per octave).
///
/// Scale `j` is `s0 * 2^(j * dj)`.
pub fn log_scales(s0: f64, num_scales: usize, dj: f64) -> Vec<f64> {
    let mut scales = Vec::with_capacity(num_scales);
    for j in 0..num_scales {
        scales.push(s0 * (2.0_f64).powf(j as f64 * dj));
    }
    scales
}
/// Generate linearly spaced scales for CWT.
///
/// Returns `num_scales` values evenly spaced from `s_min` to `s_max`
/// inclusive. A request for one scale yields `[s_min]`; a request for zero
/// scales yields an empty vector (previously `num_scales == 0` incorrectly
/// returned one element).
pub fn linear_scales(s_min: f64, s_max: f64, num_scales: usize) -> Vec<f64> {
    match num_scales {
        0 => Vec::new(),
        1 => vec![s_min],
        _ => {
            let step = (s_max - s_min) / (num_scales - 1) as f64;
            (0..num_scales).map(|i| s_min + i as f64 * step).collect()
        }
    }
}
694/// Detect ridges in a CWT scalogram (local maxima along scale axis).
695///
696/// Returns positions of ridge points as (scale_index, time_index).
697pub fn detect_ridges(scalo: &Scalogram) -> Vec<(usize, usize)> {
698    let n_scales = scalo.energy.len();
699    if n_scales < 3 {
700        return Vec::new();
701    }
702    let mut ridges = Vec::new();
703    let n_time = if scalo.energy.is_empty() {
704        0
705    } else {
706        scalo.energy[0].len()
707    };
708    for t in 0..n_time {
709        for s in 1..n_scales - 1 {
710            if s < scalo.energy.len()
711                && t < scalo.energy[s].len()
712                && t < scalo.energy[s - 1].len()
713                && t < scalo.energy[s + 1].len()
714            {
715                let e = scalo.energy[s][t];
716                if e > scalo.energy[s - 1][t] && e > scalo.energy[s + 1][t] && e > 1e-30 {
717                    ridges.push((s, t));
718                }
719            }
720        }
721    }
722    ridges
723}
724/// Compress a signal by keeping only the top `fraction` of wavelet coefficients.
725///
726/// Returns the denoised/compressed signal.
727pub fn wavelet_compress(
728    signal: &[f64],
729    wavelet: WaveletFamily,
730    levels: usize,
731    fraction: f64,
732) -> Vec<f64> {
733    let mut decomp = dwt(signal, wavelet, levels);
734    let mut all_mags: Vec<f64> = decomp
735        .details
736        .iter()
737        .flat_map(|d| d.iter().map(|&c| c.abs()))
738        .chain(decomp.approx.iter().map(|&c| c.abs()))
739        .collect();
740    all_mags.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
741    let keep_count = ((all_mags.len() as f64 * fraction).ceil() as usize).min(all_mags.len());
742    let threshold = if keep_count > 0 && keep_count <= all_mags.len() {
743        all_mags[keep_count.saturating_sub(1)]
744    } else {
745        0.0
746    };
747    for detail in &mut decomp.details {
748        for c in detail.iter_mut() {
749            if c.abs() < threshold {
750                *c = 0.0;
751            }
752        }
753    }
754    for c in decomp.approx.iter_mut() {
755        if c.abs() < threshold {
756            *c = 0.0;
757        }
758    }
759    idwt(&decomp)
760}
761/// Compute the compression ratio (fraction of non-zero coefficients).
762pub fn compression_ratio(decomp: &DwtDecomposition) -> f64 {
763    let total: usize = decomp.details.iter().map(|d| d.len()).sum::<usize>() + decomp.approx.len();
764    let nonzero: usize = decomp
765        .details
766        .iter()
767        .flat_map(|d| d.iter())
768        .chain(decomp.approx.iter())
769        .filter(|&&c| c.abs() > 1e-30)
770        .count();
771    if total == 0 {
772        0.0
773    } else {
774        nonzero as f64 / total as f64
775    }
776}
777/// Compute the MODWT (non-decimated DWT with rescaled filters).
778pub fn modwt(signal: &[f64], wavelet: WaveletFamily, levels: usize) -> ModwtDecomposition {
779    let n = signal.len();
780    let lo_orig = wavelet.lo_dec();
781    let hi_orig = wavelet.hi_dec();
782    let flen = lo_orig.len();
783    let scale = std::f64::consts::FRAC_1_SQRT_2;
784    let lo: Vec<f64> = lo_orig.iter().map(|&h| h * scale).collect();
785    let hi: Vec<f64> = hi_orig.iter().map(|&h| h * scale).collect();
786    let mut current = signal.to_vec();
787    let mut details = Vec::with_capacity(levels);
788    for level in 0..levels {
789        let step = 1_usize << level;
790        let mut approx = vec![0.0; n];
791        let mut detail = vec![0.0; n];
792        for i in 0..n {
793            let mut s_lo = 0.0;
794            let mut s_hi = 0.0;
795            for j in 0..flen {
796                let idx = (i + n - j * step) % n;
797                s_lo += lo[j] * current[idx];
798                s_hi += hi[j] * current[idx];
799            }
800            approx[i] = s_lo;
801            detail[i] = s_hi;
802        }
803        details.push(detail);
804        current = approx;
805    }
806    ModwtDecomposition {
807        details,
808        approx: current,
809        wavelet,
810    }
811}
812/// Compute the cone of influence for CWT.
813///
814/// The cone of influence defines the region where edge effects become important.
815/// Returns the e-folding time for each scale.
816pub fn cone_of_influence(scales: &[f64], wavelet: MotherWavelet) -> Vec<f64> {
817    scales
818        .iter()
819        .map(|&s| match wavelet {
820            MotherWavelet::Morlet { .. } => s * 2.0_f64.sqrt(),
821            MotherWavelet::MexicanHat => s * (2.0_f64.sqrt()),
822        })
823        .collect()
824}
825/// Convert CWT scale to approximate pseudo-frequency.
826///
827/// f = center_frequency / (scale * dt)
828pub fn scale_to_frequency(scale: f64, dt: f64, wavelet: MotherWavelet) -> f64 {
829    let center_freq = match wavelet {
830        MotherWavelet::Morlet { omega0 } => omega0 / (2.0 * PI),
831        MotherWavelet::MexicanHat => 2.0 / (PI * (2.0_f64 / 3.0).sqrt()),
832    };
833    center_freq / (scale * dt)
834}
835/// Convert frequency to CWT scale.
836pub fn frequency_to_scale(freq: f64, dt: f64, wavelet: MotherWavelet) -> f64 {
837    let center_freq = match wavelet {
838        MotherWavelet::Morlet { omega0 } => omega0 / (2.0 * PI),
839        MotherWavelet::MexicanHat => 2.0 / (PI * (2.0_f64 / 3.0).sqrt()),
840    };
841    center_freq / (freq * dt)
842}
843/// Compute wavelet-based signal features for classification/analysis.
844///
845/// Returns: mean energy, std energy, max coefficient, min coefficient per level.
846pub fn wavelet_features(decomp: &DwtDecomposition) -> Vec<[f64; 4]> {
847    let mut features = Vec::new();
848    for detail in &decomp.details {
849        if detail.is_empty() {
850            features.push([0.0; 4]);
851            continue;
852        }
853        let n = detail.len() as f64;
854        let energy: f64 = detail.iter().map(|&c| c * c).sum();
855        let mean_energy = energy / n;
856        let mean = detail.iter().sum::<f64>() / n;
857        let var = detail.iter().map(|&c| (c - mean).powi(2)).sum::<f64>() / n;
858        let std_energy = var.sqrt();
859        let max_c = detail.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
860        let min_c = detail.iter().cloned().fold(f64::INFINITY, f64::min);
861        features.push([mean_energy, std_energy, max_c, min_c]);
862    }
863    if !decomp.approx.is_empty() {
864        let n = decomp.approx.len() as f64;
865        let energy: f64 = decomp.approx.iter().map(|&c| c * c).sum();
866        let mean = decomp.approx.iter().sum::<f64>() / n;
867        let var = decomp
868            .approx
869            .iter()
870            .map(|&c| (c - mean).powi(2))
871            .sum::<f64>()
872            / n;
873        let max_c = decomp
874            .approx
875            .iter()
876            .cloned()
877            .fold(f64::NEG_INFINITY, f64::max);
878        let min_c = decomp.approx.iter().cloned().fold(f64::INFINITY, f64::min);
879        features.push([energy / n, var.sqrt(), max_c, min_c]);
880    }
881    features
882}
/// Compute the reconstruction error (RMS of the difference) between original
/// and reconstructed signals.
///
/// Only the overlapping prefix of the two slices is compared. Returns 0.0
/// when either slice is empty (previously this computed `sqrt(0.0 / 0.0)`,
/// i.e. NaN).
pub fn reconstruction_error(original: &[f64], reconstructed: &[f64]) -> f64 {
    let n = original.len().min(reconstructed.len());
    if n == 0 {
        return 0.0;
    }
    let sum_sq: f64 = original
        .iter()
        .zip(reconstructed)
        .map(|(&a, &b)| (a - b).powi(2))
        .sum();
    (sum_sq / n as f64).sqrt()
}
/// Compute the signal-to-noise ratio (SNR) of the reconstruction, in dB.
///
/// Only the overlapping prefix of the two slices is compared; a (near-)zero
/// noise power yields positive infinity.
pub fn reconstruction_snr(original: &[f64], reconstructed: &[f64]) -> f64 {
    let n = original.len().min(reconstructed.len());
    let mut signal_power = 0.0;
    let mut noise_power = 0.0;
    for i in 0..n {
        signal_power += original[i].powi(2);
        noise_power += (original[i] - reconstructed[i]).powi(2);
    }
    if noise_power < 1e-30 {
        f64::INFINITY
    } else {
        10.0 * (signal_power / noise_power).log10()
    }
}
/// Compute the BayesShrink adaptive threshold for a detail band.
///
/// The BayesShrink threshold is `sigma_noise^2 / sigma_signal`, where
/// `sigma_signal^2 = max(0, sigma_x^2 - sigma_noise^2)` and `sigma_x^2` is
/// the empirical second moment of the coefficients. When the signal variance
/// is negligible, the largest coefficient magnitude is returned so that
/// thresholding removes everything. Empty input yields 0.
pub fn bayes_shrink_threshold(detail_coeffs: &[f64], noise_sigma: f64) -> f64 {
    let n = detail_coeffs.len() as f64;
    if n < 1.0 {
        return 0.0;
    }
    let sigma_x_sq = detail_coeffs.iter().map(|&c| c * c).sum::<f64>() / n;
    let sigma_n_sq = noise_sigma * noise_sigma;
    let sigma_s_sq = (sigma_x_sq - sigma_n_sq).max(0.0);
    if sigma_s_sq < 1e-30 {
        // No discernible signal: threshold at the largest magnitude.
        detail_coeffs.iter().fold(0.0_f64, |m, &c| m.max(c.abs()))
    } else {
        sigma_n_sq / sigma_s_sq.sqrt()
    }
}
924/// Denoise using BayesShrink (adaptive, level-dependent threshold).
925pub fn bayes_shrink_denoise(
926    signal: &[f64],
927    wavelet: WaveletFamily,
928    levels: usize,
929    mode: ThresholdMode,
930) -> Vec<f64> {
931    let mut decomp = dwt(signal, wavelet, levels);
932    let noise_sigma = if decomp.details.is_empty() {
933        0.0
934    } else {
935        estimate_noise_sigma(&decomp.details[0])
936    };
937    for detail in &mut decomp.details {
938        let thresh = bayes_shrink_threshold(detail, noise_sigma);
939        *detail = apply_threshold(detail, thresh, mode);
940    }
941    idwt(&decomp)
942}