align3d/bilateral/
grid.rs

1use ndarray::{Array2, Array4, Axis};
2use num::{clamp, ToPrimitive};
3use std::cmp::{max, min};
4
5/// Bilateral grid. A data structure for representing images
6/// within its intensity space.
7///
8/// More information: Chen, J., Paris, S.,
9/// & Durand, F. (2007). Real-time edge-aware image processing with
10/// the bilateral grid. ACM Transactions on Graphics (TOG), 26(3), 103-es.
11pub struct BilateralGrid<I> {
12    /// The grid data. Shape is [H W Z 2], where the last dimension contains,
13    /// in order, the color value and the counter.
14    pub data: Array4<f64>,
15    sigma_space: f64,
16    sigma_color: f64,
17    color_min: I,
18    space_pad: usize,
19    color_pad: usize,
20}
21
22impl<I> BilateralGrid<I>
23where
24    I: num::Bounded
25        + Ord
26        + Copy
27        + std::ops::Sub
28        + ToPrimitive
29        + std::convert::From<<I as std::ops::Sub>::Output>
30        + num::NumCast,
31{
32    pub fn from_image(image: &Array2<I>, sigma_space: f64, sigma_color: f64) -> Self {
33        let space_pad = 2;
34        let color_pad = 2;
35
36        let (image_height, image_width) = image.dim();
37
38        let grid_height = ((image_height - 1) as f64 / sigma_space) as usize + 1 + 2 * space_pad;
39        let grid_width = ((image_width - 1) as f64 / sigma_space) as usize + 1 + 2 * space_pad;
40
41        let (color_min, color_max) = {
42            let mut mi = I::max_value();
43            let mut ma = I::min_value();
44            image.iter().for_each(|v| {
45                mi = min(mi, *v);
46                ma = max(ma, *v);
47            });
48            (mi, ma)
49        };
50
51        let grid_depth = {
52            let diff: I = (color_max - color_min).into();
53            (diff.to_f64().unwrap() / sigma_color) as usize + 1 + 2 * color_pad
54        };
55
56        let inv_sigma_space = 1.0 / sigma_space;
57        let inv_sigma_color = 1.0 / sigma_color;
58
59        let mut grid = Array4::<f64>::zeros((grid_height, grid_width, grid_depth, 2));
60        for row in 0..image_height {
61            let grid_row = (row as f64 * inv_sigma_space + 0.5) as usize + space_pad;
62
63            for col in 0..image_width {
64                let grid_col = (col as f64 * inv_sigma_space + 0.5) as usize + space_pad;
65
66                let color = image[(row, col)];
67                if color <= I::min_value() {
68                    continue;
69                }
70
71                let channel = {
72                    let diff: I = (color - color_min).into();
73                    (diff.to_f64().unwrap() * inv_sigma_color + 0.5) as usize + color_pad
74                };
75                grid[(grid_row, grid_col, channel, 0)] += color.to_f64().unwrap();
76                grid[(grid_row, grid_col, channel, 1)] += 1.0;
77            }
78        }
79
80        Self {
81            data: grid,
82            sigma_color,
83            sigma_space,
84            color_min,
85            space_pad,
86            color_pad,
87        }
88    }
89
90    pub fn normalize(&mut self) {
91        let dim = self.dim();
92        self.data
93            .view_mut()
94            .into_shape((dim.0 * dim.1 * dim.2, 2))
95            .unwrap()
96            .axis_iter_mut(Axis(0))
97            .for_each(|mut color_count| {
98                let count = color_count[1];
99                if count > 0.0 {
100                    color_count[0] /= count;
101                    color_count[1] = 1.0;
102                }
103            });
104    }
105
106    pub fn slice(&self, image: &Array2<I>) -> Array2<I>
107    where
108        I: num::Zero,
109    {
110        let inv_sigma_space = 1.0 / self.sigma_space;
111        let inv_sigma_color = 1.0 / self.sigma_color;
112        let space_pad = self.space_pad as f64;
113        let color_pad = self.color_pad as f64;
114
115        let mut dst_image = Array2::<I>::zeros(image.dim());
116        image
117            .iter()
118            .zip(dst_image.indexed_iter_mut())
119            .for_each(|(color, ((row, col), dst))| {
120                let trilinear = self.trilinear(
121                    row as f64 * inv_sigma_space + space_pad,
122                    col as f64 * inv_sigma_space + space_pad,
123                    {
124                        let diff: I = (*color - self.color_min).into();
125                        diff.to_f64().unwrap() * inv_sigma_color + color_pad
126                    },
127                );
128
129                *dst = num::cast::cast(trilinear).unwrap();
130            });
131        dst_image
132    }
133
134    pub fn trilinear(&self, row: f64, col: f64, channel: f64) -> f64 {
135        let (height, width, depth, _) = self.data.dim();
136
137        let z_index = clamp(channel as usize, 0, depth - 1);
138        let zz_index: usize = clamp((channel + 1.0) as usize, 0, depth - 1);
139        let z_alpha = channel - z_index as f64;
140
141        let y_index = clamp(row as usize, 0, height - 1);
142        let yy_index: usize = clamp((row + 1.0) as usize, 0, height - 1);
143        let y_alpha = row - y_index as f64;
144
145        let x_index = clamp(col as usize, 0, width - 1);
146        let xx_index: usize = clamp((col + 1.0) as usize, 0, width - 1);
147        let x_alpha = col - x_index as f64;
148
149        #[rustfmt::skip]
150        let value =
151        {
152              (1.0 - y_alpha) * (1.0 - x_alpha) * (1.0 - z_alpha) *  self.data[(y_index,  x_index , z_index,  0)]
153            + (1.0 - y_alpha) * x_alpha         * (1.0 - z_alpha) *  self.data[(y_index,  xx_index, z_index,  0)]
154            + y_alpha         * (1.0 - x_alpha) * (1.0 - z_alpha) *  self.data[(yy_index, x_index , z_index,  0)]
155            + y_alpha         * x_alpha         * (1.0 - z_alpha) *  self.data[(yy_index, xx_index, z_index,  0)]
156            + (1.0 - y_alpha) * (1.0 - x_alpha) * z_alpha         *  self.data[(y_index,  x_index , zz_index, 0)]
157            + (1.0 - y_alpha) * x_alpha         * z_alpha         *  self.data[(y_index,  xx_index, zz_index, 0)]
158            + y_alpha         * (1.0 - x_alpha) * z_alpha         *  self.data[(yy_index, x_index , zz_index, 0)]
159            + y_alpha         * x_alpha         * z_alpha         *  self.data[(yy_index, xx_index, zz_index, 0)]
160        };
161        value
162    }
163
164    pub fn dim(&self) -> (usize, usize, usize, usize) {
165        self.data.dim()
166    }
167}
168
#[cfg(test)]
mod tests {
    use ndarray::Array2;
    use rstest::{fixture, rstest};

    use super::BilateralGrid;
    use crate::unit_test::bloei_luma16;

    /// Spatial sigma shared by every grid built in these tests.
    const SIGMA_SPACE: f64 = 4.5;
    /// Intensity sigma shared by every grid built in these tests.
    const SIGMA_COLOR: f64 = 30.0;

    /// Grid built from the shared `bloei_luma16` fixture image.
    #[fixture]
    fn bilateral_grid(bloei_luma16: Array2<u16>) -> BilateralGrid<u16> {
        BilateralGrid::from_image(&bloei_luma16, SIGMA_SPACE, SIGMA_COLOR)
    }

    /// The grid shape follows from the image size, the sigmas and the padding.
    #[rstest]
    fn verify_grid_creation(bilateral_grid: BilateralGrid<u16>) {
        let expected_shape = (138, 104, 173, 2);
        assert_eq!(bilateral_grid.dim(), expected_shape);
    }

    /// Slicing a normalized grid keeps the image size and yields known values.
    #[rstest]
    fn verify_slice(bloei_luma16: Array2<u16>, mut bilateral_grid: BilateralGrid<u16>) {
        bilateral_grid.normalize();
        let sliced = bilateral_grid.slice(&bloei_luma16);

        let (height, width) = sliced.dim();
        assert_eq!((height, width), (600, 450));
        assert_eq!(sliced[(421, 123)], 2266);
    }
}