image_dwt/
transform.rs

1use image::DynamicImage;
2use ndarray::{Array2, Array3};
3
4use crate::aggregate::Aggregate;
5use crate::decompose::WaveletDecompose;
6use crate::kernels::Kernel;
7use crate::layer::{WaveletLayer, WaveletLayerBuffer};
8
/// Linear rescaling that maps the interval `[min, max]` onto `[0, 1]`.
#[derive(Debug, Copy, Clone)]
pub struct Scale {
    /// Source value that [`Scale::apply`] maps to `0.0`.
    min: f32,
    /// Source value that [`Scale::apply`] maps to `1.0`.
    // Never read after construction (the width is cached in `scaling_ratio`),
    // hence the lint suppression.
    #[allow(unused)]
    max: f32,
    /// Width of the source interval, `max - min`, computed once in `new`.
    scaling_ratio: f32,
}

impl Scale {
    /// Creates a scale for the source interval `[min, max]`.
    pub fn new(min: f32, max: f32) -> Self {
        Self {
            min,
            max,
            scaling_ratio: max - min,
        }
    }

    /// Rescales `value` so that `min` becomes `0.0` and `max` becomes `1.0`.
    ///
    /// The result is not clamped: inputs outside `[min, max]` produce outputs
    /// outside `[0, 1]`.
    ///
    /// A zero-width interval (`min == max`) has no meaningful rescaling, so
    /// `0.0` is returned instead of the NaN / infinity a division by zero
    /// would yield.
    #[inline]
    pub fn apply(&self, value: f32) -> f32 {
        if self.scaling_ratio == 0.0 {
            return 0.0;
        }

        (value - self.min) / self.scaling_ratio
    }
}
31
32#[derive(Clone)]
33enum ATrousTransformInput {
34    Grayscale { data: Array2<f32> },
35    Rgb { data: Array3<f32> },
36}
37
38impl Aggregate for ATrousTransformInput {
39    fn min(&self) -> f32 {
40        match self {
41            ATrousTransformInput::Grayscale { data } => data.min(),
42            ATrousTransformInput::Rgb { data } => data.min(),
43        }
44    }
45
46    fn max(&self) -> f32 {
47        match self {
48            ATrousTransformInput::Grayscale { data } => data.max(),
49            ATrousTransformInput::Rgb { data } => data.max(),
50        }
51    }
52}
53
54#[derive(Clone)]
55pub struct ATrousTransform {
56    input: ATrousTransformInput,
57    levels: usize,
58    kernel: Kernel,
59    current_level: usize,
60}
61
62impl ATrousTransform {
63    pub fn new(input: &DynamicImage, levels: usize, kernel: Kernel) -> Self {
64        let (width, height) = (input.width() as usize, input.height() as usize);
65
66        let input = match &input {
67            DynamicImage::ImageLuma8(_)
68            | DynamicImage::ImageLumaA8(_)
69            | DynamicImage::ImageLuma16(_)
70            | DynamicImage::ImageLumaA16(_) => {
71                let mut data = Array2::zeros((height, width));
72                for (x, y, pixel) in input.to_luma32f().enumerate_pixels() {
73                    data[[y as usize, x as usize]] = pixel.0[0];
74                }
75
76                ATrousTransformInput::Grayscale { data }
77            }
78            _ => {
79                let mut data = Array3::zeros((height, width, 3));
80                for (x, y, pixel) in input.to_rgb32f().enumerate_pixels() {
81                    let [red, green, blue] = pixel.0;
82                    data[[y as usize, x as usize, 0]] = red;
83                    data[[y as usize, x as usize, 1]] = green;
84                    data[[y as usize, x as usize, 2]] = blue;
85                }
86
87                ATrousTransformInput::Rgb { data }
88            }
89        };
90
91        Self {
92            input,
93            levels,
94            kernel,
95            current_level: 0,
96        }
97    }
98
99    pub fn linear(input: &DynamicImage, levels: usize) -> Self {
100        ATrousTransform::new(input, levels, Kernel::LinearInterpolationKernel)
101    }
102
103    pub fn low_scale(input: &DynamicImage, levels: usize) -> Self {
104        ATrousTransform::new(input, levels, Kernel::LowScaleKernel)
105    }
106
107    pub fn b_spline(input: &DynamicImage, levels: usize) -> Self {
108        ATrousTransform::new(input, levels, Kernel::B3SplineKernel)
109    }
110}
111
112impl Iterator for ATrousTransform {
113    type Item = WaveletLayer;
114
115    fn next(&mut self) -> Option<Self::Item> {
116        let pixel_scale = self.current_level;
117        self.current_level += 1;
118
119        if pixel_scale > self.levels {
120            return None;
121        }
122
123        match &mut self.input {
124            ATrousTransformInput::Grayscale { data } => {
125                if pixel_scale == self.levels {
126                    return Some(WaveletLayer {
127                        buffer: WaveletLayerBuffer::Grayscale { data: data.clone() },
128                        pixel_scale: None,
129                    });
130                }
131
132                let kernel = self.kernel;
133
134                let layer_buffer = data.wavelet_decompose(kernel, pixel_scale);
135                Some(WaveletLayer {
136                    pixel_scale: Some(pixel_scale),
137                    buffer: layer_buffer,
138                })
139            }
140            ATrousTransformInput::Rgb { data } => {
141                if pixel_scale == self.levels {
142                    return Some(WaveletLayer {
143                        buffer: WaveletLayerBuffer::Rgb { data: data.clone() },
144                        pixel_scale: None,
145                    });
146                }
147
148                let kernel = self.kernel;
149
150                let layer_buffer = data.wavelet_decompose(kernel, pixel_scale);
151                Some(WaveletLayer {
152                    pixel_scale: Some(pixel_scale),
153                    buffer: layer_buffer,
154                })
155            }
156        }
157    }
158}