// kizzasi_io/signal/wavelets.rs

//! Wavelet transforms and analysis
//!
//! This module provides wavelet-based signal processing capabilities, including:
//! - Discrete Wavelet Transform (DWT)
//! - Multi-level wavelet decomposition
//! - Wavelet denoising
//! - Stationary Wavelet Transform (SWT)

9/// Discrete Wavelet Transform (DWT) result
10#[derive(Debug, Clone)]
11pub struct DwtResult {
12    /// Approximation coefficients (low-frequency)
13    pub approximation: Vec<f32>,
14    /// Detail coefficients (high-frequency)
15    pub detail: Vec<f32>,
16    /// Wavelet type used
17    pub wavelet: WaveletType,
18    /// Decomposition level
19    pub level: usize,
20}
21
22impl DwtResult {
23    /// Get total number of coefficients
24    pub fn total_coefficients(&self) -> usize {
25        self.approximation.len() + self.detail.len()
26    }
27}
28
29/// Multi-level DWT result
30#[derive(Debug, Clone)]
31pub struct DwtMultiLevel {
32    /// Approximation coefficients at the coarsest level
33    pub approximation: Vec<f32>,
34    /// Detail coefficients at each level (from finest to coarsest)
35    pub details: Vec<Vec<f32>>,
36    /// Wavelet type used
37    pub wavelet: WaveletType,
38    /// Number of decomposition levels
39    pub levels: usize,
40    /// Original signal length
41    pub original_length: usize,
42    /// Signal lengths at each level (for reconstruction)
43    pub lengths: Vec<usize>,
44}
45
/// Wavelet families available for the transforms in this module.
///
/// The enum is a plain tag, so it is `Copy` and can be compared or used as a hash key
/// (for example to group results by wavelet).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WaveletType {
    /// Haar wavelet (2 taps; simplest, good for edge detection)
    Haar,
    /// Daubechies-2 (db2) wavelet, 4 taps
    Daubechies2,
    /// Daubechies-4 (db4) wavelet, 8 taps
    Daubechies4,
    /// Daubechies-6 (db6) wavelet, 12 taps
    Daubechies6,
    /// Symlet-2 (sym2) wavelet, 4 taps
    Symlet2,
    /// Symlet-4 (sym4) wavelet, 8 taps
    Symlet4,
    /// Coiflet-1 (coif1) wavelet, 6 taps
    Coiflet1,
}

65impl WaveletType {
66    /// Get the low-pass decomposition filter coefficients (scaling function)
67    pub fn decomposition_low(&self) -> Vec<f32> {
68        match self {
69            WaveletType::Haar => vec![0.707_106_77, 0.707_106_77],
70            WaveletType::Daubechies2 => {
71                vec![0.482_962_9, 0.836_516_3, 0.224_143_9, -0.129_409_5]
72            }
73            WaveletType::Daubechies4 => {
74                vec![
75                    0.230_377_8,
76                    0.714_846_6,
77                    0.630_880_8,
78                    -0.027_983_77,
79                    -0.187_034_8,
80                    0.030_841_38,
81                    0.032_883_0,
82                    -0.010_597_4,
83                ]
84            }
85            WaveletType::Daubechies6 => {
86                vec![
87                    0.111_540_7,
88                    0.494_623_9,
89                    0.751_133_9,
90                    0.315_250_4,
91                    -0.226_264_7,
92                    -0.129_766_9,
93                    0.097_501_6,
94                    0.027_522_87,
95                    -0.031_582_0,
96                    0.000_553_84,
97                    0.004_777_26,
98                    -0.001_077_3,
99                ]
100            }
101            WaveletType::Symlet2 => {
102                vec![-0.129_409_5, 0.224_143_9, 0.836_516_3, 0.482_962_9]
103            }
104            WaveletType::Symlet4 => {
105                vec![
106                    -0.075_765_7,
107                    -0.029_635_53,
108                    0.497_618_7,
109                    0.803_738_8,
110                    0.297_857_8,
111                    -0.099_219_5,
112                    -0.012_604_0,
113                    0.032_223_1,
114                ]
115            }
116            WaveletType::Coiflet1 => {
117                vec![
118                    -0.015_655_73,
119                    -0.072_732_6,
120                    0.384_864_9,
121                    0.852_572,
122                    0.337_897_7,
123                    -0.072_732_6,
124                ]
125            }
126        }
127    }
128
129    /// Get the high-pass decomposition filter coefficients (wavelet function)
130    pub fn decomposition_high(&self) -> Vec<f32> {
131        let low = self.decomposition_low();
132        let n = low.len();
133        low.iter()
134            .enumerate()
135            .map(|(i, _)| if i % 2 == 0 { -1.0 } else { 1.0 } * low[n - 1 - i])
136            .collect()
137    }
138
139    /// Get the low-pass reconstruction filter coefficients
140    pub fn reconstruction_low(&self) -> Vec<f32> {
141        self.decomposition_low().into_iter().rev().collect()
142    }
143
144    /// Get the high-pass reconstruction filter coefficients
145    pub fn reconstruction_high(&self) -> Vec<f32> {
146        self.decomposition_high().into_iter().rev().collect()
147    }
148
149    /// Get the filter length
150    pub fn filter_length(&self) -> usize {
151        self.decomposition_low().len()
152    }
153}
154
155/// Wavelet analyzer for performing wavelet transforms
156#[derive(Debug, Clone)]
157pub struct WaveletAnalyzer {
158    wavelet: WaveletType,
159}
160
161impl WaveletAnalyzer {
162    /// Create a new wavelet analyzer
163    pub fn new(wavelet: WaveletType) -> Self {
164        Self { wavelet }
165    }
166
167    /// Perform single-level DWT decomposition using lifting scheme
168    pub fn dwt(&self, signal: &[f32]) -> DwtResult {
169        let n = signal.len();
170        let out_len = n.div_ceil(2);
171
172        let mut approx = Vec::with_capacity(out_len);
173        let mut detail = Vec::with_capacity(out_len);
174
175        match self.wavelet {
176            WaveletType::Haar => {
177                let sqrt2_inv = 1.0 / std::f32::consts::SQRT_2;
178                for i in 0..out_len {
179                    let idx0 = i * 2;
180                    let idx1 = (idx0 + 1).min(n - 1);
181                    let x0 = signal[idx0];
182                    let x1 = signal[idx1];
183                    approx.push((x0 + x1) * sqrt2_inv);
184                    detail.push((x0 - x1) * sqrt2_inv);
185                }
186            }
187            _ => {
188                let low_filter = self.wavelet.decomposition_low();
189                let high_filter = self.wavelet.decomposition_high();
190                let f_len = low_filter.len();
191
192                for i in 0..out_len {
193                    let center = i * 2;
194                    let mut sum_low = 0.0f32;
195                    let mut sum_high = 0.0f32;
196
197                    for j in 0..f_len {
198                        let idx = (center + j) as isize - (f_len as isize / 2);
199                        let val = self.extend_signal(signal, idx, n);
200                        sum_low += val * low_filter[j];
201                        sum_high += val * high_filter[j];
202                    }
203
204                    approx.push(sum_low);
205                    detail.push(sum_high);
206                }
207            }
208        }
209
210        DwtResult {
211            approximation: approx,
212            detail,
213            wavelet: self.wavelet,
214            level: 1,
215        }
216    }
217
218    /// Perform multi-level DWT decomposition
219    pub fn dwt_multilevel(&self, signal: &[f32], levels: usize) -> DwtMultiLevel {
220        let mut details = Vec::with_capacity(levels);
221        let mut lengths = Vec::with_capacity(levels);
222        let mut current = signal.to_vec();
223        let original_length = signal.len();
224
225        for _ in 0..levels {
226            if current.len() < self.wavelet.filter_length() {
227                break;
228            }
229
230            lengths.push(current.len());
231            let result = self.dwt(&current);
232            details.push(result.detail);
233            current = result.approximation;
234        }
235
236        let num_levels = details.len();
237
238        DwtMultiLevel {
239            approximation: current,
240            details,
241            wavelet: self.wavelet,
242            levels: num_levels,
243            original_length,
244            lengths,
245        }
246    }
247
248    /// Perform inverse DWT (single level reconstruction) using lifting scheme
249    pub fn idwt(&self, approx: &[f32], detail: &[f32], output_length: usize) -> Vec<f32> {
250        let mut result = vec![0.0f32; output_length];
251
252        match self.wavelet {
253            WaveletType::Haar => {
254                let sqrt2_inv = 1.0 / std::f32::consts::SQRT_2;
255                for (i, (&a, &d)) in approx.iter().zip(detail.iter()).enumerate() {
256                    let idx0 = i * 2;
257                    let idx1 = idx0 + 1;
258
259                    if idx0 < output_length {
260                        result[idx0] = (a + d) * sqrt2_inv;
261                    }
262                    if idx1 < output_length {
263                        result[idx1] = (a - d) * sqrt2_inv;
264                    }
265                }
266            }
267            _ => {
268                let low_filter = self.wavelet.reconstruction_low();
269                let high_filter = self.wavelet.reconstruction_high();
270                let f_len = low_filter.len();
271
272                for (i, (&a, &d)) in approx.iter().zip(detail.iter()).enumerate() {
273                    let pos = i * 2;
274                    for j in 0..f_len {
275                        let out_idx = pos as isize + j as isize - (f_len as isize / 2) + 1;
276                        if out_idx >= 0 && (out_idx as usize) < output_length {
277                            result[out_idx as usize] += a * low_filter[j] + d * high_filter[j];
278                        }
279                    }
280                }
281            }
282        }
283
284        result
285    }
286
287    /// Perform multi-level inverse DWT reconstruction
288    pub fn idwt_multilevel(&self, decomp: &DwtMultiLevel) -> Vec<f32> {
289        let mut current = decomp.approximation.clone();
290
291        for (i, detail) in decomp.details.iter().enumerate().rev() {
292            let output_len = if i < decomp.lengths.len() {
293                decomp.lengths[i]
294            } else {
295                detail.len() * 2
296            };
297            current = self.idwt(&current, detail, output_len);
298        }
299
300        current.truncate(decomp.original_length);
301        current
302    }
303
304    /// Convolve and downsample by 2 (for decomposition)
305    #[allow(dead_code)]
306    fn convolve_downsample(&self, signal: &[f32], filter: &[f32]) -> Vec<f32> {
307        let n = signal.len();
308        let f_len = filter.len();
309        let out_len = (n + f_len - 1) / 2;
310
311        let mut result = Vec::with_capacity(out_len);
312
313        for i in 0..out_len {
314            let center = i * 2;
315            let mut sum = 0.0f32;
316
317            for (j, &f) in filter.iter().enumerate() {
318                let idx = center as isize + j as isize - (f_len as isize - 1);
319                let val = self.extend_signal(signal, idx, n);
320                sum += val * f;
321            }
322
323            result.push(sum);
324        }
325
326        result
327    }
328
329    /// Get signal value with symmetric extension
330    fn extend_signal(&self, signal: &[f32], idx: isize, n: usize) -> f32 {
331        if idx < 0 {
332            signal[(-1 - idx) as usize % n]
333        } else if idx >= n as isize {
334            let reflected = 2 * n as isize - 2 - idx;
335            if reflected >= 0 && (reflected as usize) < n {
336                signal[reflected as usize]
337            } else {
338                signal[n - 1]
339            }
340        } else {
341            signal[idx as usize]
342        }
343    }
344
345    /// Upsample by 2 and convolve (for reconstruction)
346    #[allow(dead_code)]
347    fn upsample_convolve(&self, signal: &[f32], filter: &[f32], output_length: usize) -> Vec<f32> {
348        let upsampled_len = signal.len() * 2;
349        let mut upsampled = vec![0.0f32; upsampled_len];
350
351        for (i, &s) in signal.iter().enumerate() {
352            upsampled[i * 2] = s;
353        }
354
355        let f_len = filter.len();
356        let mut result = vec![0.0; output_length];
357
358        for (i, res) in result.iter_mut().enumerate() {
359            let mut sum = 0.0f32;
360            for (j, &f) in filter.iter().enumerate() {
361                let idx = i as isize + j as isize - (f_len as isize - 1);
362                if idx >= 0 && (idx as usize) < upsampled_len {
363                    sum += upsampled[idx as usize] * f;
364                }
365            }
366            *res = sum;
367        }
368
369        result
370    }
371
372    /// Compute stationary wavelet transform (SWT) - undecimated DWT
373    /// Returns coefficients at full resolution for each level
374    pub fn swt(&self, signal: &[f32], levels: usize) -> Vec<(Vec<f32>, Vec<f32>)> {
375        let mut results = Vec::with_capacity(levels);
376        let mut low_filter = self.wavelet.decomposition_low();
377        let mut high_filter = self.wavelet.decomposition_high();
378        let mut current = signal.to_vec();
379
380        for _ in 0..levels {
381            let approx = self.convolve_full(&current, &low_filter);
382            let detail = self.convolve_full(&current, &high_filter);
383
384            results.push((approx.clone(), detail));
385            current = approx;
386
387            low_filter = self.upsample_filter(&low_filter);
388            high_filter = self.upsample_filter(&high_filter);
389        }
390
391        results
392    }
393
394    /// Convolve without downsampling (for SWT)
395    fn convolve_full(&self, signal: &[f32], filter: &[f32]) -> Vec<f32> {
396        let n = signal.len();
397        let f_len = filter.len();
398        let mut result = Vec::with_capacity(n);
399
400        for i in 0..n {
401            let mut sum = 0.0f32;
402            for (j, &f) in filter.iter().enumerate() {
403                let idx = i as isize + j as isize - (f_len as isize / 2);
404                let val = if idx < 0 {
405                    signal[(-idx - 1) as usize % n]
406                } else if idx >= n as isize {
407                    signal[(2 * n as isize - idx - 1) as usize % n]
408                } else {
409                    signal[idx as usize]
410                };
411                sum += val * f;
412            }
413            result.push(sum);
414        }
415
416        result
417    }
418
419    /// Upsample filter by inserting zeros
420    fn upsample_filter(&self, filter: &[f32]) -> Vec<f32> {
421        let mut result = Vec::with_capacity(filter.len() * 2 - 1);
422        for (i, &f) in filter.iter().enumerate() {
423            result.push(f);
424            if i < filter.len() - 1 {
425                result.push(0.0);
426            }
427        }
428        result
429    }
430
431    /// Compute wavelet energy at each level
432    pub fn wavelet_energy(&self, decomp: &DwtMultiLevel) -> Vec<f32> {
433        let mut energies = Vec::with_capacity(decomp.levels + 1);
434
435        for detail in &decomp.details {
436            let energy: f32 = detail.iter().map(|&x| x * x).sum();
437            energies.push(energy);
438        }
439
440        let approx_energy: f32 = decomp.approximation.iter().map(|&x| x * x).sum();
441        energies.push(approx_energy);
442
443        let total: f32 = energies.iter().sum();
444        if total > 0.0 {
445            for e in &mut energies {
446                *e /= total;
447            }
448        }
449
450        energies
451    }
452
453    /// Denoise signal using wavelet thresholding
454    pub fn denoise(&self, signal: &[f32], levels: usize, threshold: f32) -> Vec<f32> {
455        let decomp = self.dwt_multilevel(signal, levels);
456
457        let thresholded_details: Vec<Vec<f32>> = decomp
458            .details
459            .iter()
460            .map(|detail| {
461                detail
462                    .iter()
463                    .map(|&x| Self::soft_threshold(x, threshold))
464                    .collect()
465            })
466            .collect();
467
468        let thresholded = DwtMultiLevel {
469            approximation: decomp.approximation,
470            details: thresholded_details,
471            wavelet: decomp.wavelet,
472            levels: decomp.levels,
473            original_length: decomp.original_length,
474            lengths: decomp.lengths,
475        };
476
477        self.idwt_multilevel(&thresholded)
478    }
479
480    /// Soft thresholding function
481    fn soft_threshold(x: f32, threshold: f32) -> f32 {
482        if x.abs() <= threshold {
483            0.0
484        } else if x > 0.0 {
485            x - threshold
486        } else {
487            x + threshold
488        }
489    }
490
491    /// Estimate universal threshold (VisuShrink)
492    pub fn universal_threshold(detail: &[f32]) -> f32 {
493        let n = detail.len() as f32;
494        let sigma = Self::mad_sigma(detail);
495        sigma * (2.0 * n.ln()).sqrt()
496    }
497
498    /// Estimate noise standard deviation using MAD (Median Absolute Deviation)
499    fn mad_sigma(data: &[f32]) -> f32 {
500        if data.is_empty() {
501            return 0.0;
502        }
503
504        let mut abs_data: Vec<f32> = data.iter().map(|&x| x.abs()).collect();
505        abs_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
506
507        let median = if abs_data.len().is_multiple_of(2) {
508            (abs_data[abs_data.len() / 2 - 1] + abs_data[abs_data.len() / 2]) / 2.0
509        } else {
510            abs_data[abs_data.len() / 2]
511        };
512
513        median / 0.674_489_75
514    }
515}