align3d/bilateral/
edge_aware_filter.rs

1use std::{marker::PhantomData, mem::swap};
2
3use ndarray::{Array2, Array4, Axis};
4use num::ToPrimitive;
5
6use super::BilateralGrid;
7
/// Bilateral filter implemented on top of a bilateral grid.
///
/// Ported from https://gist.github.com/ginrou/02e945562607fad170a1.
///
/// `I` is the pixel type of the images being filtered; it is only tracked at
/// the type level and no value of `I` is stored.
#[derive(Clone, Debug)]
pub struct BilateralFilter<I> {
    /// Zero-sized marker tying the filter to its pixel type `I`.
    _phantom: PhantomData<I>,
    /// Down-sampling factor used along the spatial (XY) axes.
    pub sigma_space: f64,
    /// Down-sampling factor used along the intensity axis.
    pub sigma_color: f64,
}
19
20impl<I> Default for BilateralFilter<I>
21where
22    I: num::Bounded
23        + Ord
24        + Copy
25        + std::ops::Sub
26        + ToPrimitive
27        + std::convert::From<<I as std::ops::Sub>::Output>
28        + num::NumCast,
29{
30    fn default() -> Self {
31        BilateralFilter {
32            sigma_space: 4.50000000225,
33            sigma_color: 29.9999880000072,
34            _phantom: PhantomData,
35        }
36    }
37}
38
39impl<I> BilateralFilter<I>
40where
41    I: num::Bounded
42        + Ord
43        + Copy
44        + std::ops::Sub
45        + ToPrimitive
46        + std::convert::From<<I as std::ops::Sub>::Output>
47        + num::NumCast,
48{
49    pub fn new(sigma_space: f64, sigma_color: f64) -> Self {
50        Self {
51            sigma_space,
52            sigma_color,
53            _phantom: PhantomData,
54        }
55    }
56
57    fn convolution(grid: &mut BilateralGrid<I>) {
58        let mut data_ptr = grid.data.as_mut_ptr();
59
60        let mut buffer = Array4::zeros(grid.dim());
61        let mut buffer_ptr: *mut f64 = buffer.as_mut_ptr();
62
63        let (grid_height, grid_width, grid_depth, _) = grid.dim();
64
65        let row_stride = grid.data.stride_of(Axis(0));
66        let col_stride = grid.data.stride_of(Axis(1));
67        let channel_stride = grid.data.stride_of(Axis(2));
68        for plane_offset in &[row_stride, col_stride, channel_stride] {
69            let plane_offset = *plane_offset;
70            for _ in 0..2 {
71                swap(&mut data_ptr, &mut buffer_ptr);
72                for row in 1..grid_height - 1 {
73                    for col in 1..grid_width - 1 {
74                        let mut b_ptr = unsafe {
75                            buffer_ptr.offset(row as isize * row_stride + col as isize * col_stride)
76                        };
77                        let mut d_ptr = unsafe {
78                            data_ptr.offset(row as isize * row_stride + col as isize * col_stride)
79                        };
80
81                        for _channel in 1..grid_depth {
82                            let (prev_value, curr_value, next_value) = {
83                                unsafe {
84                                    let prev = (
85                                        *b_ptr.offset(-plane_offset),
86                                        *b_ptr.offset(-plane_offset + 1),
87                                    );
88                                    let curr = (*b_ptr, *b_ptr.add(1));
89                                    let next = (
90                                        *b_ptr.add(plane_offset as usize),
91                                        *b_ptr.add((plane_offset + 1) as usize),
92                                    );
93
94                                    (prev, curr, next)
95                                }
96                            };
97
98                            let (value, weight) = (
99                                (prev_value.0 + next_value.0 + 2.0 * curr_value.0) * 0.25,
100                                (prev_value.1 + next_value.1 + 2.0 * curr_value.1) * 0.25,
101                            );
102
103                            unsafe {
104                                *d_ptr = value;
105                                *d_ptr.add(1) = weight;
106
107                                b_ptr = b_ptr.offset(channel_stride);
108                                d_ptr = d_ptr.offset(channel_stride);
109                            }
110                        }
111                    }
112                }
113            }
114        }
115    }
116
117    /// Filters the image. It will try to reuse buffers if possible.
118    ///
119    /// # Arguments:
120    ///
121    /// * `image`: Input image.
122    ///
123    /// # Returns:
124    ///
125    /// * The filtered image.
126    pub fn filter(&self, image: &Array2<I>) -> Array2<I>
127    where
128        I: num::Zero,
129    {
130        let mut grid = BilateralGrid::from_image(image, self.sigma_space, self.sigma_color);
131        BilateralFilter::convolution(&mut grid);
132
133        grid.normalize();
134        grid.slice(image)
135    }
136
137    pub fn scale_down(&self, image: &Array2<I>) -> Array2<I>
138    where
139        I: num::Zero,
140    {
141        let (src_height, src_width) = image.dim();
142        let image = self.filter(image);
143        let (dst_height, dst_width) = (src_height / 2, src_width / 2);
144        Array2::<I>::from_shape_fn((dst_height, dst_width), |(i_dst, j_dst)| {
145            image[[i_dst * 2, j_dst * 2]]
146        })
147    }
148}