1#![doc = include_str!("../README.md")]
2
3use ndarray::{s, ArrayView1, ArrayViewMut1, ArrayViewMut2, Axis};
4use std::ops::{AddAssign, DivAssign, MulAssign};
5
/// Floating-point element type used by every transform in this crate.
///
/// Selected at compile time: `f64` when the `f64` cargo feature is enabled.
#[cfg(feature = "f64")]
type FLTYPE = f64;
/// Floating-point element type used by every transform in this crate.
///
/// Defaults to `f32` when the `f64` cargo feature is not enabled.
#[cfg(not(feature = "f64"))]
type FLTYPE = f32;

/// Decimation factor of a single transform level (and the `2` subtracted from the
/// window length when sizing scratch buffers).
const TWO: usize = 2;
/// `1/√2` converted to [`FLTYPE`]; `decompose` multiplies its output by this and
/// `reconstruct` divides it back out.
const NORM_COEFF: FLTYPE = std::f64::consts::FRAC_1_SQRT_2 as FLTYPE;
13
14pub mod wavelet {
15 use super::*;
25
26 mod bior;
27 mod coif;
28 mod db;
29 mod sym;
30
31 pub use bior::*;
32 pub use coif::*;
33 pub use db::*;
34 pub use sym::*;
35 pub const HAAR: Wavelet = Wavelet {
36 decomp_low: &[0.7071067811865476, 0.7071067811865476],
37 decomp_high: &[-0.7071067811865476, 0.7071067811865476],
38 recons_low: &[0.7071067811865476, 0.7071067811865476],
39 recons_high: &[-0.7071067811865476, 0.7071067811865476],
40 };
41}
42
43#[derive(Debug, Clone, Copy)]
57pub struct Wavelet<'a> {
58 pub decomp_low: &'a [FLTYPE],
59 pub decomp_high: &'a [FLTYPE],
60 pub recons_low: &'a [FLTYPE],
61 pub recons_high: &'a [FLTYPE],
62}
63impl<'a> Wavelet<'a> {
64 #[inline]
65 pub const fn window_size(&self) -> usize {
66 self.decomp_low.len()
67 }
68 #[inline]
69 pub const fn half_padding_length(&self) -> usize {
70 (self.decomp_low.len() - 2) >> 1
71 }
72}
73
74#[doc(alias = "dwt")]
124pub fn decompose(
125 mut signal: ArrayViewMut1<FLTYPE>,
126 mut buffer: ArrayViewMut1<FLTYPE>,
127 wavelet: Wavelet,
128) {
129 let signal_size = signal.len();
130 let window_size = wavelet.window_size();
131
132 let expected_buffer_size = signal_size + window_size - TWO; let mut step_buffer = 0;
137 while step_buffer < expected_buffer_size {
138 buffer
139 .slice_mut(s![step_buffer..])
140 .iter_mut()
141 .zip(signal.view())
142 .for_each(|(dst, src)| *dst = *src);
143 step_buffer += signal_size;
144 }
145
146 let (low_pass, high_pass) = (
149 ArrayView1::from_shape(window_size, wavelet.decomp_low).unwrap(),
150 ArrayView1::from_shape(window_size, wavelet.decomp_high).unwrap(),
151 );
152
153 let half = signal_size / TWO;
154 for (step_signal, step_buffer) in (0..signal_size).step_by(TWO).enumerate() {
155 let slice_signal = buffer.slice(s![step_buffer..step_buffer + window_size]);
156
157 signal[step_signal] = slice_signal.dot(&low_pass); signal[half + step_signal] = slice_signal.dot(&high_pass);
159 }
160
161 signal.mul_assign(NORM_COEFF);
164}
165
166#[doc(alias = "idwt")]
176pub fn reconstruct(
177 mut signal: ArrayViewMut1<FLTYPE>,
178 mut buffer: ArrayViewMut1<FLTYPE>,
179 wavelet: Wavelet,
180) {
181 let signal_size = signal.len();
182 let window_size = wavelet.window_size();
183
184 let occupied_buffer_size = signal_size + window_size - TWO;
185
186 let half = signal_size / TWO;
189 signal.div_assign(NORM_COEFF);
190
191 buffer.slice_mut(s![..occupied_buffer_size]).fill(0.);
194 let (low_pass, high_pass) = (
195 ArrayView1::from_shape(window_size, wavelet.recons_low).unwrap(),
196 ArrayView1::from_shape(window_size, wavelet.recons_high).unwrap(),
197 );
198
199 for (mut step_buffer, approx_n) in signal.slice(s![..half]).into_iter().enumerate() {
200 step_buffer *= 2;
201 buffer
202 .slice_mut(s![step_buffer..step_buffer + window_size])
203 .add_assign(&(&low_pass * *approx_n));
204 }
205 for (mut step_buffer, detail_n) in signal.slice(s![half..]).into_iter().enumerate() {
206 step_buffer *= 2;
207 buffer
208 .slice_mut(s![step_buffer..step_buffer + window_size])
209 .add_assign(&(&high_pass * *detail_n));
210 }
211
212 signal.fill(0.);
215
216 let mut step_signal = 0;
217 while step_signal < occupied_buffer_size - signal_size {
218 signal.add_assign(&buffer.slice(s![step_signal..step_signal + signal_size]));
219 step_signal += signal_size;
220 }
221 signal
222 .iter_mut()
223 .zip(buffer.slice(s![step_signal..occupied_buffer_size]))
224 .for_each(|(sig, buf)| *sig += *buf);
225}
226
227#[doc(alias = "dwt")]
238pub fn completely_decompose(
239 mut signal: ArrayViewMut1<FLTYPE>,
240 mut buffer: ArrayViewMut1<FLTYPE>,
241 wavelet: Wavelet,
242) {
243 let mut signal_size = signal.len();
244 debug_assert!(signal_size.is_power_of_two());
245
246 while signal_size >= TWO {
247 decompose(
248 signal.slice_mut(s![..signal_size]),
249 buffer.view_mut(),
250 wavelet,
251 );
252 signal_size >>= 1;
253 }
254}
255
256#[doc(alias = "idwt")]
267pub fn completely_reconstruct(
268 mut signal: ArrayViewMut1<FLTYPE>,
269 mut buffer: ArrayViewMut1<FLTYPE>,
270 wavelet: Wavelet,
271) {
272 let signal_size = signal.len();
273 debug_assert!(signal_size.is_power_of_two());
274
275 let mut stage = TWO;
276 while stage <= signal_size {
277 reconstruct(signal.slice_mut(s![..stage]), buffer.view_mut(), wavelet);
278 stage <<= 1;
279 }
280}
281
282#[doc(alias = "dwt2")]
298pub fn decompose_2d(
299 mut signal_2d: ArrayViewMut2<FLTYPE>,
300 mut buffer: ArrayViewMut1<FLTYPE>,
301 wavelet: Wavelet,
302) {
303 signal_2d
304 .rows_mut()
305 .into_iter()
306 .for_each(|row| decompose(row, buffer.view_mut(), wavelet));
307
308 signal_2d
309 .columns_mut()
310 .into_iter()
311 .for_each(|col| decompose(col, buffer.view_mut(), wavelet));
312}
313
314#[doc(alias = "idwt2")]
330pub fn reconstruct_2d(
331 mut signal_2d: ArrayViewMut2<FLTYPE>,
332 mut buffer: ArrayViewMut1<FLTYPE>,
333 wavelet: Wavelet,
334) {
335 signal_2d
336 .columns_mut()
337 .into_iter()
338 .for_each(|col| reconstruct(col, buffer.view_mut(), wavelet));
339
340 signal_2d
341 .rows_mut()
342 .into_iter()
343 .for_each(|row| reconstruct(row, buffer.view_mut(), wavelet));
344}
345
346#[doc(alias = "dwt2")]
368pub fn completely_decompose_2d(
369 mut signal_2d: ArrayViewMut2<FLTYPE>,
370 mut buffer: ArrayViewMut1<FLTYPE>,
371 wavelet: Wavelet,
372) {
373 let height = signal_2d.len_of(Axis(0));
374 let mut width = signal_2d.len_of(Axis(1));
375 debug_assert!(width == height);
376 debug_assert!(width.is_power_of_two());
377
378 while width >= TWO {
379 decompose_2d(
380 signal_2d.slice_mut(s![..width, ..width]),
381 buffer.view_mut(),
382 wavelet,
383 );
384 width >>= 1;
385 }
386}
387
388#[doc(alias = "idwt2")]
410pub fn completely_reconstruct_2d(
411 mut signal_2d: ArrayViewMut2<FLTYPE>,
412 mut buffer: ArrayViewMut1<FLTYPE>,
413 wavelet: Wavelet,
414) {
415 let height = signal_2d.len_of(Axis(0));
416 let width = signal_2d.len_of(Axis(1));
417 debug_assert!(width == height);
418 debug_assert!(width.is_power_of_two());
419
420 let mut stage = TWO;
421 while stage <= width {
422 reconstruct_2d(
423 signal_2d.slice_mut(s![..stage, ..stage]),
424 buffer.view_mut(),
425 wavelet,
426 );
427 stage <<= 1;
428 }
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{s, Array1, Array2, Axis};

    /// Zeroed scratch buffer sized for lanes of up to `len` samples.
    fn scratch_for(len: usize, wavelet: Wavelet) -> Array1<FLTYPE> {
        Array1::<FLTYPE>::zeros(len + wavelet.window_size() - TWO)
    }

    /// Asserts that two sequences agree element-wise within `1e-4`.
    fn assert_all_close(
        expected: impl IntoIterator<Item = FLTYPE>,
        actual: impl IntoIterator<Item = FLTYPE>,
    ) {
        expected
            .into_iter()
            .zip(actual)
            .for_each(|(a, b)| assert_abs_diff_eq!(a, b, epsilon = 0.0001));
    }

    #[test]
    fn norm() {
        let wavelet = wavelet::BIOR_3_1;
        let inputs: [[FLTYPE; 8]; 2] = [
            [0., 0., 0., 0., 255., 255., 255., 255.],
            [255., 255., 255., 255., 255., 255., 255., 255.],
        ];

        for input in inputs {
            let mut signal = Array1::<FLTYPE>::from_vec(input.to_vec());
            let mut buffer = scratch_for(signal.len(), wavelet);
            completely_decompose(signal.view_mut(), buffer.view_mut(), wavelet);
            println!("{signal:>5.1}");

            // The approximation coefficient stays in the input range; details stay bounded.
            assert!((0. ..=255.).contains(&signal[0]));
            for v in signal.slice(s![1..]) {
                assert!((-128. ..=128.).contains(v));
            }
        }
    }

    #[test]
    fn auto_2d() {
        let wavelet = wavelet::BIOR_3_5;
        #[rustfmt::skip]
        let raw = Array2::<FLTYPE>::from_shape_vec((8, 8), vec![
            52., 55., 61., 66., 70., 61., 64., 73.,
            63., 59., 55., 90., 109.,85., 69., 72.,
            62., 59., 68., 113.,144.,104.,66., 73.,
            63., 58., 71., 122.,154.,106.,70., 69.,
            67., 61., 68., 104.,126.,88., 68., 70.,
            79., 65., 60., 70., 77., 68., 58., 75.,
            85., 71., 64., 59., 55., 61., 65., 83.,
            87., 79., 69., 68., 65., 76., 78., 94.,
        ]).unwrap();
        let mut signal_2d = raw.clone();
        let mut buffer = scratch_for(signal_2d.len_of(Axis(0)), wavelet);

        completely_decompose_2d(signal_2d.view_mut(), buffer.view_mut(), wavelet);
        println!("{signal_2d:>5.1}");
        completely_reconstruct_2d(signal_2d.view_mut(), buffer.view_mut(), wavelet);

        assert_all_close(raw, signal_2d);
    }

    #[test]
    fn manual_2d() {
        let wavelet = wavelet::BIOR_3_1;
        #[rustfmt::skip]
        let raw = Array2::<FLTYPE>::from_shape_vec((6, 10), vec![
            0., 0., 0.,99.,99.,99.,99., 0., 0., 0.,
            0., 0.,99.,99.,99.,99.,99.,99., 0., 0.,
            0.,99.,99.,99.,99.,99.,99.,99.,99., 0.,
            0.,99.,99.,99.,99.,99.,99.,99.,99., 0.,
            0., 0.,99.,99.,99.,99.,99.,99., 0., 0.,
            0., 0., 0.,99.,99.,99.,99., 0., 0., 0.,
        ]).unwrap();
        let mut signal_2d = raw.clone();
        // Rows (length 10) are the longest lanes, so size the scratch buffer for them.
        let mut buffer = scratch_for(signal_2d.len_of(Axis(1)), wavelet);

        decompose_2d(signal_2d.view_mut(), buffer.view_mut(), wavelet);
        println!("{signal_2d:>5.1}");
        reconstruct_2d(signal_2d.view_mut(), buffer.view_mut(), wavelet);

        assert_all_close(raw, signal_2d);
    }

    #[test]
    fn auto_1d() {
        let wavelet = wavelet::BIOR_2_2;
        let raw = Array1::<FLTYPE>::from_vec(vec![0., 10., 100., 200., 250., 30., 20., 10.]);
        let mut signal = raw.clone();
        let mut buffer = scratch_for(signal.len(), wavelet);

        completely_decompose(signal.view_mut(), buffer.view_mut(), wavelet);
        println!("{signal:>5.1}");
        completely_reconstruct(signal.view_mut(), buffer.view_mut(), wavelet);

        assert_all_close(raw, signal);
    }

    #[test]
    fn manual_1d() {
        let wavelet = wavelet::HAAR;
        let raw = Array1::<FLTYPE>::from_vec(vec![31., 41., 59., 26., 53., 58., 97., 93.]);
        let mut signal = raw.clone();
        let mut buffer = scratch_for(signal.len(), wavelet);

        for len in [8usize, 4, 2] {
            decompose(signal.slice_mut(s![..len]), buffer.view_mut(), wavelet);
        }
        println!("{signal:>5.1}");
        for len in [2usize, 4, 8] {
            reconstruct(signal.slice_mut(s![..len]), buffer.view_mut(), wavelet);
        }

        assert_all_close(raw, signal);
    }
}