1#[derive(Debug, Clone)]
11pub struct DwtResult {
12 pub approximation: Vec<f32>,
14 pub detail: Vec<f32>,
16 pub wavelet: WaveletType,
18 pub level: usize,
20}
21
22impl DwtResult {
23 pub fn total_coefficients(&self) -> usize {
25 self.approximation.len() + self.detail.len()
26 }
27}
28
29#[derive(Debug, Clone)]
31pub struct DwtMultiLevel {
32 pub approximation: Vec<f32>,
34 pub details: Vec<Vec<f32>>,
36 pub wavelet: WaveletType,
38 pub levels: usize,
40 pub original_length: usize,
42 pub lengths: Vec<usize>,
44}
45
/// Wavelet families supported by [`WaveletAnalyzer`].
///
/// Each variant selects a fixed low-pass decomposition filter; the remaining
/// filters are derived from it. `PartialEq`/`Eq`/`Hash` let callers compare
/// wavelet choices and use them as map or set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WaveletType {
    /// Haar wavelet (2-tap filters).
    Haar,
    /// Daubechies wavelet with 4-tap filters.
    Daubechies2,
    /// Daubechies wavelet with 8-tap filters.
    Daubechies4,
    /// Daubechies wavelet with 12-tap filters.
    Daubechies6,
    /// Symlet with 4-tap filters (time-reversed `Daubechies2` low-pass taps).
    Symlet2,
    /// Symlet with 8-tap filters.
    Symlet4,
    /// Coiflet with 6-tap filters.
    Coiflet1,
}
64
65impl WaveletType {
66 pub fn decomposition_low(&self) -> Vec<f32> {
68 match self {
69 WaveletType::Haar => vec![0.707_106_77, 0.707_106_77],
70 WaveletType::Daubechies2 => {
71 vec![0.482_962_9, 0.836_516_3, 0.224_143_9, -0.129_409_5]
72 }
73 WaveletType::Daubechies4 => {
74 vec![
75 0.230_377_8,
76 0.714_846_6,
77 0.630_880_8,
78 -0.027_983_77,
79 -0.187_034_8,
80 0.030_841_38,
81 0.032_883_0,
82 -0.010_597_4,
83 ]
84 }
85 WaveletType::Daubechies6 => {
86 vec![
87 0.111_540_7,
88 0.494_623_9,
89 0.751_133_9,
90 0.315_250_4,
91 -0.226_264_7,
92 -0.129_766_9,
93 0.097_501_6,
94 0.027_522_87,
95 -0.031_582_0,
96 0.000_553_84,
97 0.004_777_26,
98 -0.001_077_3,
99 ]
100 }
101 WaveletType::Symlet2 => {
102 vec![-0.129_409_5, 0.224_143_9, 0.836_516_3, 0.482_962_9]
103 }
104 WaveletType::Symlet4 => {
105 vec![
106 -0.075_765_7,
107 -0.029_635_53,
108 0.497_618_7,
109 0.803_738_8,
110 0.297_857_8,
111 -0.099_219_5,
112 -0.012_604_0,
113 0.032_223_1,
114 ]
115 }
116 WaveletType::Coiflet1 => {
117 vec![
118 -0.015_655_73,
119 -0.072_732_6,
120 0.384_864_9,
121 0.852_572,
122 0.337_897_7,
123 -0.072_732_6,
124 ]
125 }
126 }
127 }
128
129 pub fn decomposition_high(&self) -> Vec<f32> {
131 let low = self.decomposition_low();
132 let n = low.len();
133 low.iter()
134 .enumerate()
135 .map(|(i, _)| if i % 2 == 0 { -1.0 } else { 1.0 } * low[n - 1 - i])
136 .collect()
137 }
138
139 pub fn reconstruction_low(&self) -> Vec<f32> {
141 self.decomposition_low().into_iter().rev().collect()
142 }
143
144 pub fn reconstruction_high(&self) -> Vec<f32> {
146 self.decomposition_high().into_iter().rev().collect()
147 }
148
149 pub fn filter_length(&self) -> usize {
151 self.decomposition_low().len()
152 }
153}
154
155#[derive(Debug, Clone)]
157pub struct WaveletAnalyzer {
158 wavelet: WaveletType,
159}
160
161impl WaveletAnalyzer {
162 pub fn new(wavelet: WaveletType) -> Self {
164 Self { wavelet }
165 }
166
167 pub fn dwt(&self, signal: &[f32]) -> DwtResult {
169 let n = signal.len();
170 let out_len = n.div_ceil(2);
171
172 let mut approx = Vec::with_capacity(out_len);
173 let mut detail = Vec::with_capacity(out_len);
174
175 match self.wavelet {
176 WaveletType::Haar => {
177 let sqrt2_inv = 1.0 / std::f32::consts::SQRT_2;
178 for i in 0..out_len {
179 let idx0 = i * 2;
180 let idx1 = (idx0 + 1).min(n - 1);
181 let x0 = signal[idx0];
182 let x1 = signal[idx1];
183 approx.push((x0 + x1) * sqrt2_inv);
184 detail.push((x0 - x1) * sqrt2_inv);
185 }
186 }
187 _ => {
188 let low_filter = self.wavelet.decomposition_low();
189 let high_filter = self.wavelet.decomposition_high();
190 let f_len = low_filter.len();
191
192 for i in 0..out_len {
193 let center = i * 2;
194 let mut sum_low = 0.0f32;
195 let mut sum_high = 0.0f32;
196
197 for j in 0..f_len {
198 let idx = (center + j) as isize - (f_len as isize / 2);
199 let val = self.extend_signal(signal, idx, n);
200 sum_low += val * low_filter[j];
201 sum_high += val * high_filter[j];
202 }
203
204 approx.push(sum_low);
205 detail.push(sum_high);
206 }
207 }
208 }
209
210 DwtResult {
211 approximation: approx,
212 detail,
213 wavelet: self.wavelet,
214 level: 1,
215 }
216 }
217
218 pub fn dwt_multilevel(&self, signal: &[f32], levels: usize) -> DwtMultiLevel {
220 let mut details = Vec::with_capacity(levels);
221 let mut lengths = Vec::with_capacity(levels);
222 let mut current = signal.to_vec();
223 let original_length = signal.len();
224
225 for _ in 0..levels {
226 if current.len() < self.wavelet.filter_length() {
227 break;
228 }
229
230 lengths.push(current.len());
231 let result = self.dwt(¤t);
232 details.push(result.detail);
233 current = result.approximation;
234 }
235
236 let num_levels = details.len();
237
238 DwtMultiLevel {
239 approximation: current,
240 details,
241 wavelet: self.wavelet,
242 levels: num_levels,
243 original_length,
244 lengths,
245 }
246 }
247
248 pub fn idwt(&self, approx: &[f32], detail: &[f32], output_length: usize) -> Vec<f32> {
250 let mut result = vec![0.0f32; output_length];
251
252 match self.wavelet {
253 WaveletType::Haar => {
254 let sqrt2_inv = 1.0 / std::f32::consts::SQRT_2;
255 for (i, (&a, &d)) in approx.iter().zip(detail.iter()).enumerate() {
256 let idx0 = i * 2;
257 let idx1 = idx0 + 1;
258
259 if idx0 < output_length {
260 result[idx0] = (a + d) * sqrt2_inv;
261 }
262 if idx1 < output_length {
263 result[idx1] = (a - d) * sqrt2_inv;
264 }
265 }
266 }
267 _ => {
268 let low_filter = self.wavelet.reconstruction_low();
269 let high_filter = self.wavelet.reconstruction_high();
270 let f_len = low_filter.len();
271
272 for (i, (&a, &d)) in approx.iter().zip(detail.iter()).enumerate() {
273 let pos = i * 2;
274 for j in 0..f_len {
275 let out_idx = pos as isize + j as isize - (f_len as isize / 2) + 1;
276 if out_idx >= 0 && (out_idx as usize) < output_length {
277 result[out_idx as usize] += a * low_filter[j] + d * high_filter[j];
278 }
279 }
280 }
281 }
282 }
283
284 result
285 }
286
287 pub fn idwt_multilevel(&self, decomp: &DwtMultiLevel) -> Vec<f32> {
289 let mut current = decomp.approximation.clone();
290
291 for (i, detail) in decomp.details.iter().enumerate().rev() {
292 let output_len = if i < decomp.lengths.len() {
293 decomp.lengths[i]
294 } else {
295 detail.len() * 2
296 };
297 current = self.idwt(¤t, detail, output_len);
298 }
299
300 current.truncate(decomp.original_length);
301 current
302 }
303
304 #[allow(dead_code)]
306 fn convolve_downsample(&self, signal: &[f32], filter: &[f32]) -> Vec<f32> {
307 let n = signal.len();
308 let f_len = filter.len();
309 let out_len = (n + f_len - 1) / 2;
310
311 let mut result = Vec::with_capacity(out_len);
312
313 for i in 0..out_len {
314 let center = i * 2;
315 let mut sum = 0.0f32;
316
317 for (j, &f) in filter.iter().enumerate() {
318 let idx = center as isize + j as isize - (f_len as isize - 1);
319 let val = self.extend_signal(signal, idx, n);
320 sum += val * f;
321 }
322
323 result.push(sum);
324 }
325
326 result
327 }
328
329 fn extend_signal(&self, signal: &[f32], idx: isize, n: usize) -> f32 {
331 if idx < 0 {
332 signal[(-1 - idx) as usize % n]
333 } else if idx >= n as isize {
334 let reflected = 2 * n as isize - 2 - idx;
335 if reflected >= 0 && (reflected as usize) < n {
336 signal[reflected as usize]
337 } else {
338 signal[n - 1]
339 }
340 } else {
341 signal[idx as usize]
342 }
343 }
344
345 #[allow(dead_code)]
347 fn upsample_convolve(&self, signal: &[f32], filter: &[f32], output_length: usize) -> Vec<f32> {
348 let upsampled_len = signal.len() * 2;
349 let mut upsampled = vec![0.0f32; upsampled_len];
350
351 for (i, &s) in signal.iter().enumerate() {
352 upsampled[i * 2] = s;
353 }
354
355 let f_len = filter.len();
356 let mut result = vec![0.0; output_length];
357
358 for (i, res) in result.iter_mut().enumerate() {
359 let mut sum = 0.0f32;
360 for (j, &f) in filter.iter().enumerate() {
361 let idx = i as isize + j as isize - (f_len as isize - 1);
362 if idx >= 0 && (idx as usize) < upsampled_len {
363 sum += upsampled[idx as usize] * f;
364 }
365 }
366 *res = sum;
367 }
368
369 result
370 }
371
372 pub fn swt(&self, signal: &[f32], levels: usize) -> Vec<(Vec<f32>, Vec<f32>)> {
375 let mut results = Vec::with_capacity(levels);
376 let mut low_filter = self.wavelet.decomposition_low();
377 let mut high_filter = self.wavelet.decomposition_high();
378 let mut current = signal.to_vec();
379
380 for _ in 0..levels {
381 let approx = self.convolve_full(¤t, &low_filter);
382 let detail = self.convolve_full(¤t, &high_filter);
383
384 results.push((approx.clone(), detail));
385 current = approx;
386
387 low_filter = self.upsample_filter(&low_filter);
388 high_filter = self.upsample_filter(&high_filter);
389 }
390
391 results
392 }
393
394 fn convolve_full(&self, signal: &[f32], filter: &[f32]) -> Vec<f32> {
396 let n = signal.len();
397 let f_len = filter.len();
398 let mut result = Vec::with_capacity(n);
399
400 for i in 0..n {
401 let mut sum = 0.0f32;
402 for (j, &f) in filter.iter().enumerate() {
403 let idx = i as isize + j as isize - (f_len as isize / 2);
404 let val = if idx < 0 {
405 signal[(-idx - 1) as usize % n]
406 } else if idx >= n as isize {
407 signal[(2 * n as isize - idx - 1) as usize % n]
408 } else {
409 signal[idx as usize]
410 };
411 sum += val * f;
412 }
413 result.push(sum);
414 }
415
416 result
417 }
418
419 fn upsample_filter(&self, filter: &[f32]) -> Vec<f32> {
421 let mut result = Vec::with_capacity(filter.len() * 2 - 1);
422 for (i, &f) in filter.iter().enumerate() {
423 result.push(f);
424 if i < filter.len() - 1 {
425 result.push(0.0);
426 }
427 }
428 result
429 }
430
431 pub fn wavelet_energy(&self, decomp: &DwtMultiLevel) -> Vec<f32> {
433 let mut energies = Vec::with_capacity(decomp.levels + 1);
434
435 for detail in &decomp.details {
436 let energy: f32 = detail.iter().map(|&x| x * x).sum();
437 energies.push(energy);
438 }
439
440 let approx_energy: f32 = decomp.approximation.iter().map(|&x| x * x).sum();
441 energies.push(approx_energy);
442
443 let total: f32 = energies.iter().sum();
444 if total > 0.0 {
445 for e in &mut energies {
446 *e /= total;
447 }
448 }
449
450 energies
451 }
452
453 pub fn denoise(&self, signal: &[f32], levels: usize, threshold: f32) -> Vec<f32> {
455 let decomp = self.dwt_multilevel(signal, levels);
456
457 let thresholded_details: Vec<Vec<f32>> = decomp
458 .details
459 .iter()
460 .map(|detail| {
461 detail
462 .iter()
463 .map(|&x| Self::soft_threshold(x, threshold))
464 .collect()
465 })
466 .collect();
467
468 let thresholded = DwtMultiLevel {
469 approximation: decomp.approximation,
470 details: thresholded_details,
471 wavelet: decomp.wavelet,
472 levels: decomp.levels,
473 original_length: decomp.original_length,
474 lengths: decomp.lengths,
475 };
476
477 self.idwt_multilevel(&thresholded)
478 }
479
480 fn soft_threshold(x: f32, threshold: f32) -> f32 {
482 if x.abs() <= threshold {
483 0.0
484 } else if x > 0.0 {
485 x - threshold
486 } else {
487 x + threshold
488 }
489 }
490
491 pub fn universal_threshold(detail: &[f32]) -> f32 {
493 let n = detail.len() as f32;
494 let sigma = Self::mad_sigma(detail);
495 sigma * (2.0 * n.ln()).sqrt()
496 }
497
498 fn mad_sigma(data: &[f32]) -> f32 {
500 if data.is_empty() {
501 return 0.0;
502 }
503
504 let mut abs_data: Vec<f32> = data.iter().map(|&x| x.abs()).collect();
505 abs_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
506
507 let median = if abs_data.len().is_multiple_of(2) {
508 (abs_data[abs_data.len() / 2 - 1] + abs_data[abs_data.len() / 2]) / 2.0
509 } else {
510 abs_data[abs_data.len() / 2]
511 };
512
513 median / 0.674_489_75
514 }
515}