omni_wave/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use ndarray::{s, ArrayView1, ArrayViewMut1, ArrayViewMut2, Axis};
4use std::ops::{AddAssign, DivAssign, MulAssign};
5
/// Floating-point element type used by every transform (`f64` when the `f64` feature is on).
#[cfg(feature = "f64")]
type FLTYPE = f64;
/// Floating-point element type used by every transform (`f32` by default).
#[cfg(not(feature = "f64"))]
type FLTYPE = f32;

/// Dyadic factor: filter windows advance two samples at a time and each level halves the signal.
const TWO: usize = 2;
/// `1/√2`: applied to the coefficients after decomposition and divided out before reconstruction.
const NORM_COEFF: FLTYPE = std::f64::consts::FRAC_1_SQRT_2 as FLTYPE;
13
14pub mod wavelet {
15    //! Thanks to [Wavelet Browser](https://wavelets.pybytes.com/)!
16    //!
17    //! > *但你们为什么要把分解滤波器给反过来??*
18    //! >
19    //! > 所以,这里的所有小波也都是反过来的。懒得改了,反正不影响使用就对了。
20    //!
21    //! The number of wavelets are currently limited, and none of them has more than **12** coefficients.
22    //!
23    //! *(Because I'm lazy :)*
24    use super::*;
25
26    mod bior;
27    mod coif;
28    mod db;
29    mod sym;
30
31    pub use bior::*;
32    pub use coif::*;
33    pub use db::*;
34    pub use sym::*;
35    pub const HAAR: Wavelet = Wavelet {
36        decomp_low: &[0.7071067811865476, 0.7071067811865476],
37        decomp_high: &[-0.7071067811865476, 0.7071067811865476],
38        recons_low: &[0.7071067811865476, 0.7071067811865476],
39        recons_high: &[-0.7071067811865476, 0.7071067811865476],
40    };
41}
42
43/// Check [`wavelet`] to see all the wavelets provided.
44///
45/// Each filter of a single wavelet must be of equal length!
46///
47/// ``` plaintext
48///  |<------------->|- window_size (N.coeffs)
49///  ↓               ↓
50///  (A B C;D E;F G H)
51///        ↑   ↑     ↑
52///        |   |<--->|- half_padding_length
53///        |   |
54///        |<->|------- *Sliding*!
55/// ```
56#[derive(Debug, Clone, Copy)]
57pub struct Wavelet<'a> {
58    pub decomp_low: &'a [FLTYPE],
59    pub decomp_high: &'a [FLTYPE],
60    pub recons_low: &'a [FLTYPE],
61    pub recons_high: &'a [FLTYPE],
62}
63impl<'a> Wavelet<'a> {
64    #[inline]
65    pub const fn window_size(&self) -> usize {
66        self.decomp_low.len()
67    }
68    #[inline]
69    pub const fn half_padding_length(&self) -> usize {
70        (self.decomp_low.len() - 2) >> 1
71    }
72}
73
74// /// Signal extension modes.
75// ///
76// /// ``` plaintext
77// /// Symmetric   /      *\. *      |      * | *      |
78// ///            |      * *|* *     |     * *|* *     |
79// ///            |  ****   |   **** | ****   |   **** |
80// ///             \*      /'       *|*       |       *|
81// ///
82// /// Reflect     /      *\.*      |     * |*      |
83// ///            |      * *| *     |    * *| *     |
84// ///            |  ****   |  **** |****   |  **** |
85// ///             \*      /'      *|       |      *|
86// ///
87// /// Periodic    /      *\.      * |      * |
88// ///            |      * *|     * *|     * *|
89// ///            |  ****   | ****   | ****   |
90// ///             \*      /'*       |*       |
91// ///
92// /// Edge        /      *\.        |        |
93// ///            |      * *|********|********|
94// ///            |  ****   |        |        |
95// ///             \*      /'        |        |
96// /// Zero ==
97// /// Constant(0) /      *\.        |        |
98// ///            |      * *|        |        |
99// ///            |  ****   |        |        |
100// ///             \*      /'********|********|
101// /// ```
102// #[non_exhaustive]
103// #[derive(Debug, Default, Clone, Copy, PartialEq)]
104// pub enum ExtensionMode {
105//     #[default]
106//     Symmetric,
107//     Reflect,
108//     Periodic,
109//     Edge,
110//     Zero,
111//     Constant(FLTYPE),
112// }
113
114/// Forward wavelet transform, 1D, only once, inplace.
115///
116/// ``` plaintext
117/// [      Signal      ] => [ Approx ][ Detail ]
118/// ```
119///
120/// # Hard requirements
121///
122/// - `buffer_size` >= `signal_size + window_size - TWO`
123#[doc(alias = "dwt")]
124pub fn decompose(
125    mut signal: ArrayViewMut1<FLTYPE>,
126    mut buffer: ArrayViewMut1<FLTYPE>,
127    wavelet: Wavelet,
128) {
129    let signal_size = signal.len();
130    let window_size = wavelet.window_size();
131
132    let expected_buffer_size = signal_size + window_size - TWO; // 这个公式不是很直观……但它没错就对了
133
134    /* 填充 */
135
136    let mut step_buffer = 0;
137    while step_buffer < expected_buffer_size {
138        buffer
139            .slice_mut(s![step_buffer..])
140            .iter_mut()
141            .zip(signal.view())
142            .for_each(|(dst, src)| *dst = *src);
143        step_buffer += signal_size;
144    }
145
146    /* 卷积 */
147
148    let (low_pass, high_pass) = (
149        ArrayView1::from_shape(window_size, wavelet.decomp_low).unwrap(),
150        ArrayView1::from_shape(window_size, wavelet.decomp_high).unwrap(),
151    );
152
153    let half = signal_size / TWO;
154    for (step_signal, step_buffer) in (0..signal_size).step_by(TWO).enumerate() {
155        let slice_signal = buffer.slice(s![step_buffer..step_buffer + window_size]);
156
157        signal[step_signal] = slice_signal.dot(&low_pass); // 就地操作
158        signal[half + step_signal] = slice_signal.dot(&high_pass);
159    }
160
161    /* 归一化 */
162
163    signal.mul_assign(NORM_COEFF);
164}
165
166/// Inverse wavelet transform, 1D, only once, inplace.
167///
168/// ``` plaintext
169/// [ Approx ][ Detail ] => [ Original signal ]
170/// ```
171///
172/// # Hard requirements
173///
174/// - `buffer_size` >= `signal_size + window_size - TWO`
175#[doc(alias = "idwt")]
176pub fn reconstruct(
177    mut signal: ArrayViewMut1<FLTYPE>,
178    mut buffer: ArrayViewMut1<FLTYPE>,
179    wavelet: Wavelet,
180) {
181    let signal_size = signal.len();
182    let window_size = wavelet.window_size();
183
184    let occupied_buffer_size = signal_size + window_size - TWO;
185
186    /* 反归一化 */
187
188    let half = signal_size / TWO;
189    signal.div_assign(NORM_COEFF);
190
191    /* 卷积 */
192
193    buffer.slice_mut(s![..occupied_buffer_size]).fill(0.);
194    let (low_pass, high_pass) = (
195        ArrayView1::from_shape(window_size, wavelet.recons_low).unwrap(),
196        ArrayView1::from_shape(window_size, wavelet.recons_high).unwrap(),
197    );
198
199    for (mut step_buffer, approx_n) in signal.slice(s![..half]).into_iter().enumerate() {
200        step_buffer *= 2;
201        buffer
202            .slice_mut(s![step_buffer..step_buffer + window_size])
203            .add_assign(&(&low_pass * *approx_n));
204    }
205    for (mut step_buffer, detail_n) in signal.slice(s![half..]).into_iter().enumerate() {
206        step_buffer *= 2;
207        buffer
208            .slice_mut(s![step_buffer..step_buffer + window_size])
209            .add_assign(&(&high_pass * *detail_n));
210    }
211
212    /* “折叠”!!*/
213
214    signal.fill(0.);
215
216    let mut step_signal = 0;
217    while step_signal < occupied_buffer_size - signal_size {
218        signal.add_assign(&buffer.slice(s![step_signal..step_signal + signal_size]));
219        step_signal += signal_size;
220    }
221    signal
222        .iter_mut()
223        .zip(buffer.slice(s![step_signal..occupied_buffer_size]))
224        .for_each(|(sig, buf)| *sig += *buf);
225}
226
227/// Forward wavelet transform, 1D, completely, inplace.
228///
229/// ``` plaintext
230/// [            Signal            ] => [A₂][D₂][  D₁  ][      D₀      ]
231/// ```
232///
233/// # Hard requirements
234///
235/// - `signal_size` should be exactly a power of 2, otherwise it will panic in debug builds.
236/// - `buffer_size` >= `signal_size + window_size - TWO`
237#[doc(alias = "dwt")]
238pub fn completely_decompose(
239    mut signal: ArrayViewMut1<FLTYPE>,
240    mut buffer: ArrayViewMut1<FLTYPE>,
241    wavelet: Wavelet,
242) {
243    let mut signal_size = signal.len();
244    debug_assert!(signal_size.is_power_of_two());
245
246    while signal_size >= TWO {
247        decompose(
248            signal.slice_mut(s![..signal_size]),
249            buffer.view_mut(),
250            wavelet,
251        );
252        signal_size >>= 1;
253    }
254}
255
256/// Inverse wavelet transform, 1D, completely, inplace.
257///
258/// ``` plaintext
259/// [A₂][D₂][  D₁  ][      D₀      ] => [        Original signal        ]
260/// ```
261///
262/// # Hard requirements
263///
264/// - `signal_size` should be exactly a power of 2, otherwise it will panic in debug builds.
265/// - `buffer_size` >= `signal_size + window_size - TWO`
266#[doc(alias = "idwt")]
267pub fn completely_reconstruct(
268    mut signal: ArrayViewMut1<FLTYPE>,
269    mut buffer: ArrayViewMut1<FLTYPE>,
270    wavelet: Wavelet,
271) {
272    let signal_size = signal.len();
273    debug_assert!(signal_size.is_power_of_two());
274
275    let mut stage = TWO;
276    while stage <= signal_size {
277        reconstruct(signal.slice_mut(s![..stage]), buffer.view_mut(), wavelet);
278        stage <<= 1;
279    }
280}
281
282/// Forward wavelet transform, 2D, only twice, inplace.
283///
284/// ``` plaintext
285/// +-----------+    +-----+-----+
286/// |           |    |  A  |  H  |
287/// | 2D Signal | => +-----+-----+
288/// |           |    |  V  |  D  |
289/// +-----------+    +-----+-----+
290/// ```
291///
292/// Horizontal firstly, then vertical.
293///
294/// # Hard requirements
295///
296/// - `buffer_size` >= `signal_side_length + window_size - TWO`
297#[doc(alias = "dwt2")]
298pub fn decompose_2d(
299    mut signal_2d: ArrayViewMut2<FLTYPE>,
300    mut buffer: ArrayViewMut1<FLTYPE>,
301    wavelet: Wavelet,
302) {
303    signal_2d
304        .rows_mut()
305        .into_iter()
306        .for_each(|row| decompose(row, buffer.view_mut(), wavelet));
307
308    signal_2d
309        .columns_mut()
310        .into_iter()
311        .for_each(|col| decompose(col, buffer.view_mut(), wavelet));
312}
313
314/// Inverse wavelet transform, 2D, only twice, inplace.
315///
316/// ``` plaintext
317/// +-----+-----+    +-----------+
318/// |  A  |  H  |    |           |
319/// +-----+-----+ => | 2D Signal |
320/// |  V  |  D  |    |           |
321/// +-----+-----+    +-----------+
322/// ```
323///
324/// Vertical firstly, then horizontal.
325///
326/// # Hard requirements
327///
328/// - `buffer_size` >= `signal_side_length + window_size - TWO`
329#[doc(alias = "idwt2")]
330pub fn reconstruct_2d(
331    mut signal_2d: ArrayViewMut2<FLTYPE>,
332    mut buffer: ArrayViewMut1<FLTYPE>,
333    wavelet: Wavelet,
334) {
335    signal_2d
336        .columns_mut()
337        .into_iter()
338        .for_each(|col| reconstruct(col, buffer.view_mut(), wavelet));
339
340    signal_2d
341        .rows_mut()
342        .into_iter()
343        .for_each(|row| reconstruct(row, buffer.view_mut(), wavelet));
344}
345
346/// Forward wavelet transform, 2D, completely, inplace.
347///
348/// ``` plaintext
349/// +-------------------+    +----+----+---------+
350/// |                   |    |----| H₁ |         |
351/// |                   |    +----+----+    H₀   |
352/// |                   |    | V₁ | D₁ |         |
353/// |     2D Signal     | => +----+----+---------+
354/// |                   |    |         |         |
355/// |                   |    |    V₀   |    D₀   |
356/// |                   |    |         |         |
357/// +-------------------+    +---------+---------+ ...
358/// ```
359///
360/// Horizontal firstly, then vertical.
361///
362/// # Hard requirements
363///
364/// - `signal_shape` should be exactly a **square**, otherwise it will panic in debug builds.
365/// - `signal_side_length` should be exactly a power of 2, otherwise it will panic in debug builds.
366/// - `buffer_size` >= `signal_side_length + window_size - TWO`
367#[doc(alias = "dwt2")]
368pub fn completely_decompose_2d(
369    mut signal_2d: ArrayViewMut2<FLTYPE>,
370    mut buffer: ArrayViewMut1<FLTYPE>,
371    wavelet: Wavelet,
372) {
373    let height = signal_2d.len_of(Axis(0));
374    let mut width = signal_2d.len_of(Axis(1));
375    debug_assert!(width == height);
376    debug_assert!(width.is_power_of_two());
377
378    while width >= TWO {
379        decompose_2d(
380            signal_2d.slice_mut(s![..width, ..width]),
381            buffer.view_mut(),
382            wavelet,
383        );
384        width >>= 1;
385    }
386}
387
388/// Inverse wavelet transform, 2D, completely, inplace.
389///
390/// ``` plaintext
391/// +----+----+---------+    +-------------------+
392/// |----| H₁ |         |    |                   |
393/// +----+----+    H₀   |    |                   |
394/// | V₁ | D₁ |         |    |     Original      |
395/// +----+----+---------+ => |                   |
396/// |         |         |    |     2D Signal     |
397/// |    V₀   |    D₀   |    |                   |
398/// |         |         |    |                   |
399/// +---------+---------+    +-------------------+
400/// ```
401///
402/// Vertical firstly, then horizontal.
403///
404/// # Hard requirements
405///
406/// - `signal_shape` should be exactly a **square**, otherwise it will panic in debug builds.
407/// - `signal_side_length` should be exactly a power of 2, otherwise it will panic in debug builds.
408/// - `buffer_size` >= `signal_side_length + window_size - TWO`
409#[doc(alias = "idwt2")]
410pub fn completely_reconstruct_2d(
411    mut signal_2d: ArrayViewMut2<FLTYPE>,
412    mut buffer: ArrayViewMut1<FLTYPE>,
413    wavelet: Wavelet,
414) {
415    let height = signal_2d.len_of(Axis(0));
416    let width = signal_2d.len_of(Axis(1));
417    debug_assert!(width == height);
418    debug_assert!(width.is_power_of_two());
419
420    let mut stage = TWO;
421    while stage <= width {
422        reconstruct_2d(
423            signal_2d.slice_mut(s![..stage, ..stage]),
424            buffer.view_mut(),
425            wavelet,
426        );
427        stage <<= 1;
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use ndarray::{s, Array1, Array2, Axis};

    /// Element-wise comparison of two sequences with an absolute tolerance of `1e-4`.
    fn assert_all_close<'a, 'b>(
        expected: impl IntoIterator<Item = &'a FLTYPE>,
        actual: impl IntoIterator<Item = &'b FLTYPE>,
    ) {
        for (want, got) in expected.into_iter().zip(actual) {
            assert_abs_diff_eq!(*want, *got, epsilon = 0.0001);
        }
    }

    /// Fully decomposes `samples`, then checks that the first coefficient lies in
    /// `[0, 255]` and every remaining coefficient lies in `[-128, 128]`.
    fn check_coefficient_ranges(samples: Vec<FLTYPE>, wavelet: Wavelet) {
        let mut signal = Array1::<FLTYPE>::from_vec(samples);
        let mut buffer = Array1::<FLTYPE>::zeros(signal.len() + wavelet.window_size() - TWO);
        completely_decompose(signal.view_mut(), buffer.view_mut(), wavelet);
        println!("{signal:>5.1}");

        let approx_range: std::ops::RangeInclusive<FLTYPE> = 0.0..=255.0;
        let detail_range: std::ops::RangeInclusive<FLTYPE> = -128.0..=128.0;
        assert!(approx_range.contains(&signal[0]));
        signal
            .slice(s![1..])
            .iter()
            .for_each(|coeff| assert!(detail_range.contains(coeff)));
    }

    #[test]
    fn norm() {
        let wavelet = wavelet::BIOR_3_1;
        check_coefficient_ranges(vec![0., 0., 0., 0., 255., 255., 255., 255.], wavelet);
        check_coefficient_ranges(vec![255.; 8], wavelet);
    }

    #[test]
    fn auto_2d() {
        let wavelet = wavelet::BIOR_3_5;
        #[rustfmt::skip]
        let raw = Array2::<FLTYPE>::from_shape_vec((8, 8), vec![
            52., 55., 61., 66., 70., 61., 64., 73.,
            63., 59., 55., 90., 109.,85., 69., 72.,
            62., 59., 68., 113.,144.,104.,66., 73.,
            63., 58., 71., 122.,154.,106.,70., 69.,
            67., 61., 68., 104.,126.,88., 68., 70.,
            79., 65., 60., 70., 77., 68., 58., 75.,
            85., 71., 64., 59., 55., 61., 65., 83.,
            87., 79., 69., 68., 65., 76., 78., 94.,
        ]).unwrap();
        let mut signal_2d = raw.clone();
        let side = signal_2d.len_of(Axis(0));
        let mut buffer = Array1::<FLTYPE>::zeros(side + wavelet.window_size() - TWO);

        completely_decompose_2d(signal_2d.view_mut(), buffer.view_mut(), wavelet);
        println!("{signal_2d:>5.1}");
        completely_reconstruct_2d(signal_2d.view_mut(), buffer.view_mut(), wavelet);

        assert_all_close(raw.iter(), signal_2d.iter());
    }

    #[test]
    fn manual_2d() {
        let wavelet = wavelet::BIOR_3_1;
        #[rustfmt::skip]
        let raw = Array2::<FLTYPE>::from_shape_vec((6, 10), vec![
            0., 0., 0.,99.,99.,99.,99., 0., 0., 0.,
            0., 0.,99.,99.,99.,99.,99.,99., 0., 0.,
            0.,99.,99.,99.,99.,99.,99.,99.,99., 0.,
            0.,99.,99.,99.,99.,99.,99.,99.,99., 0.,
            0., 0.,99.,99.,99.,99.,99.,99., 0., 0.,
            0., 0., 0.,99.,99.,99.,99., 0., 0., 0.,
        ]).unwrap();
        let mut signal_2d = raw.clone();
        // Rows are the longer axis here, so size the shared buffer by the row length.
        let row_len = signal_2d.len_of(Axis(1));
        let mut buffer = Array1::<FLTYPE>::zeros(row_len + wavelet.window_size() - TWO);

        decompose_2d(signal_2d.view_mut(), buffer.view_mut(), wavelet);
        println!("{signal_2d:>5.1}");
        reconstruct_2d(signal_2d.view_mut(), buffer.view_mut(), wavelet);

        assert_all_close(raw.iter(), signal_2d.iter());
    }

    #[test]
    fn auto_1d() {
        let wavelet = wavelet::BIOR_2_2;
        let raw = Array1::<FLTYPE>::from_vec(vec![0., 10., 100., 200., 250., 30., 20., 10.]);
        let mut signal = raw.clone();
        let mut buffer = Array1::<FLTYPE>::zeros(signal.len() + wavelet.window_size() - TWO);

        completely_decompose(signal.view_mut(), buffer.view_mut(), wavelet);
        println!("{signal:>5.1}");
        completely_reconstruct(signal.view_mut(), buffer.view_mut(), wavelet);

        assert_all_close(raw.iter(), signal.iter());
    }

    #[test]
    fn manual_1d() {
        let wavelet = wavelet::HAAR;
        let raw = Array1::<FLTYPE>::from_vec(vec![31., 41., 59., 26., 53., 58., 97., 93.]);
        let mut signal = raw.clone();
        let mut buffer = Array1::<FLTYPE>::zeros(signal.len() + wavelet.window_size() - TWO);

        // Same levels `completely_decompose` would visit, spelled out by hand.
        for size in [8usize, 4, 2] {
            decompose(signal.slice_mut(s![..size]), buffer.view_mut(), wavelet);
        }
        println!("{signal:>5.1}");
        for size in [2usize, 4, 8] {
            reconstruct(signal.slice_mut(s![..size]), buffer.view_mut(), wavelet);
        }

        assert_all_close(raw.iter(), signal.iter());
    }
}