1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
use ndarray::{Array2, Array3};

use crate::kernels::Kernel;
use crate::layer::WaveletLayerBuffer;

/// One level of wavelet decomposition performed in place on an image buffer.
///
/// In the implementations in this file, `self` is replaced by a smoothed copy of
/// itself (a convolution with `kernel` whose taps are spaced `2^pixel_scale`
/// pixels apart), and the returned [`WaveletLayerBuffer`] holds the difference
/// `original - smoothed`.
pub trait WaveletDecompose {
    /// Splits `self` into a smoothed residue (left in `self`) and a detail layer
    /// (returned).
    ///
    /// * `kernel` – convolution kernel; its `values()` supply the tap weights and
    ///   its `compute_extended_index` resolves where each tap samples the input.
    /// * `pixel_scale` – decomposition level; kernel taps are `2^pixel_scale`
    ///   pixels apart.
    /// * `width`, `height` – spatial size of the smoothed output. The
    ///   implementations here allocate `(height, width, ..)` buffers and combine
    ///   them with `self`, so these presumably must match `self`'s dimensions —
    ///   TODO confirm with callers.
    fn wavelet_decompose<const KERNEL_SIZE: usize>(
        &mut self,
        kernel: impl Kernel<KERNEL_SIZE>,
        pixel_scale: usize,
        width: usize,
        height: usize,
    ) -> WaveletLayerBuffer;
}

impl WaveletDecompose for Array2<f32> {
    /// Performs one level of dilated-kernel wavelet decomposition on a grayscale
    /// image, in place.
    ///
    /// Each output pixel `(y, x)` is the weighted sum of input samples taken at
    /// offsets that are multiples of `2^pixel_scale` pixels. The weights come from
    /// `kernel.values()`, indexed `[x offset][y offset]`. The actual sample
    /// position of each tap is resolved by `kernel.compute_extended_index`, which
    /// presumably handles positions outside the image (border mode lives in the
    /// `Kernel` impl — not visible here).
    ///
    /// After the call, `self` holds the smoothed `(height, width)` image. The
    /// returned `WaveletLayerBuffer::Grayscale` holds the detail layer
    /// `original - smoothed`.
    ///
    /// NOTE(review): `width`/`height` are assumed to match `self.dim()`
    /// (`(height, width)`). Otherwise indexing `self` or the final subtraction may
    /// panic.
    fn wavelet_decompose<const KERNEL_SIZE: usize>(
        &mut self,
        kernel: impl Kernel<KERNEL_SIZE>,
        pixel_scale: usize,
        width: usize,
        height: usize,
    ) -> WaveletLayerBuffer {
        // Spacing between kernel taps at this scale.
        let distance = 2_usize.pow(pixel_scale as u32) as isize;

        // Loop-invariant kernel data: fetch once instead of once per pixel.
        let half = (kernel.size() / 2) as isize;
        let kernel_values = kernel.values();

        let mut current_data = Array2::<f32>::zeros((height, width));
        let dim = current_data.dim();

        // y outer / x inner walks the row-major (height, width) buffer
        // contiguously. Pixels are independent, so the order does not affect results.
        for y in 0..height {
            for x in 0..width {
                let mut pixels_sum = 0.0;

                for kernel_index_x in -half..=half {
                    for kernel_index_y in -half..=half {
                        let index = kernel.compute_extended_index(
                            x,
                            y,
                            kernel_index_x * distance,
                            kernel_index_y * distance,
                            dim,
                        );
                        let kernel_value = kernel_values[(kernel_index_x + half) as usize]
                            [(kernel_index_y + half) as usize];

                        pixels_sum += kernel_value * self[index];
                    }
                }

                current_data[[y, x]] = pixels_sum;
            }
        }

        // Move the smoothed data into `self` and reuse the original image's
        // allocation for the detail layer, instead of cloning the whole image.
        let original = std::mem::replace(self, current_data);
        let final_data = original - &*self;

        WaveletLayerBuffer::Grayscale { data: final_data }
    }
}

impl WaveletDecompose for Array3<f32> {
    /// Performs one level of dilated-kernel wavelet decomposition on a 3-channel
    /// image, in place, filtering each channel independently.
    ///
    /// Each output sample `(y, x, channel)` is the weighted sum of input samples
    /// of the same channel, taken at offsets that are multiples of `2^pixel_scale`
    /// pixels. The weights come from `kernel.values()`, indexed
    /// `[x offset][y offset]`. The sample position of each tap is resolved by
    /// `kernel.compute_extended_index` on the `(height, width)` plane, which
    /// presumably handles positions outside the image (border mode lives in the
    /// `Kernel` impl — not visible here).
    ///
    /// After the call, `self` holds the smoothed `(height, width, 3)` image. The
    /// returned `WaveletLayerBuffer::Rgb` holds the detail layer
    /// `original - smoothed`.
    ///
    /// NOTE(review): `width`/`height` are assumed to match `self`'s first two
    /// dimensions, and `self` is assumed to have 3 channels. Otherwise indexing
    /// `self` or the final subtraction may panic.
    fn wavelet_decompose<const KERNEL_SIZE: usize>(
        &mut self,
        kernel: impl Kernel<KERNEL_SIZE>,
        pixel_scale: usize,
        width: usize,
        height: usize,
    ) -> WaveletLayerBuffer {
        // Number of channels in the smoothed output / detail layer.
        const CHANNELS: usize = 3;

        // Spacing between kernel taps at this scale.
        let distance = 2_usize.pow(pixel_scale as u32) as isize;

        // Loop-invariant kernel data: fetch once instead of once per sample.
        let half = (kernel.size() / 2) as isize;
        let kernel_values = kernel.values();

        let mut current_data = Array3::<f32>::zeros((height, width, CHANNELS));
        let plane_dim = (current_data.dim().0, current_data.dim().1);

        // y outer / x inner walks the row-major buffer contiguously. Pixels are
        // independent, so the order does not affect results.
        for y in 0..height {
            for x in 0..width {
                let mut pixels_sum = [0.0_f32; CHANNELS];

                for kernel_index_x in -half..=half {
                    for kernel_index_y in -half..=half {
                        // The sampled position does not depend on the channel,
                        // so resolve it once per tap rather than once per channel.
                        let index = kernel.compute_extended_index(
                            x,
                            y,
                            kernel_index_x * distance,
                            kernel_index_y * distance,
                            plane_dim,
                        );
                        let kernel_value = kernel_values[(kernel_index_x + half) as usize]
                            [(kernel_index_y + half) as usize];

                        // Each channel still accumulates its taps in the same
                        // order as a per-channel loop would, so sums are unchanged.
                        for (channel, sum) in pixels_sum.iter_mut().enumerate() {
                            *sum += kernel_value * self[[index[0], index[1], channel]];
                        }
                    }
                }

                for (channel, &sum) in pixels_sum.iter().enumerate() {
                    current_data[[y, x, channel]] = sum;
                }
            }
        }

        // Move the smoothed data into `self` and reuse the original image's
        // allocation for the detail layer, instead of cloning the whole image.
        let original = std::mem::replace(self, current_data);
        let final_data = original - &*self;

        WaveletLayerBuffer::Rgb { data: final_data }
    }
}