align3d/bilateral/
edge_aware_filter.rs1use std::{marker::PhantomData, mem::swap};
2
3use ndarray::{Array2, Array4, Axis};
4use num::ToPrimitive;
5
6use super::BilateralGrid;
7
/// Edge-aware smoothing filter for 2-D images of pixel type `I`, built on a `BilateralGrid`.
///
/// The filter stores no pixel data; `I` is tracked only at the type level so that the
/// methods can be bounded on the pixel type.
#[derive(Clone, Debug)]
pub struct BilateralFilter<I> {
    /// Zero-sized marker binding the filter to its pixel type `I`.
    _phantom: std::marker::PhantomData<I>,
    /// Spatial sigma; forwarded to `BilateralGrid::from_image` when filtering.
    pub sigma_space: f64,
    /// Range (intensity) sigma; forwarded to `BilateralGrid::from_image` when filtering.
    pub sigma_color: f64,
}
19
20impl<I> Default for BilateralFilter<I>
21where
22 I: num::Bounded
23 + Ord
24 + Copy
25 + std::ops::Sub
26 + ToPrimitive
27 + std::convert::From<<I as std::ops::Sub>::Output>
28 + num::NumCast,
29{
30 fn default() -> Self {
31 BilateralFilter {
32 sigma_space: 4.50000000225,
33 sigma_color: 29.9999880000072,
34 _phantom: PhantomData,
35 }
36 }
37}
38
39impl<I> BilateralFilter<I>
40where
41 I: num::Bounded
42 + Ord
43 + Copy
44 + std::ops::Sub
45 + ToPrimitive
46 + std::convert::From<<I as std::ops::Sub>::Output>
47 + num::NumCast,
48{
49 pub fn new(sigma_space: f64, sigma_color: f64) -> Self {
50 Self {
51 sigma_space,
52 sigma_color,
53 _phantom: PhantomData,
54 }
55 }
56
57 fn convolution(grid: &mut BilateralGrid<I>) {
58 let mut data_ptr = grid.data.as_mut_ptr();
59
60 let mut buffer = Array4::zeros(grid.dim());
61 let mut buffer_ptr: *mut f64 = buffer.as_mut_ptr();
62
63 let (grid_height, grid_width, grid_depth, _) = grid.dim();
64
65 let row_stride = grid.data.stride_of(Axis(0));
66 let col_stride = grid.data.stride_of(Axis(1));
67 let channel_stride = grid.data.stride_of(Axis(2));
68 for plane_offset in &[row_stride, col_stride, channel_stride] {
69 let plane_offset = *plane_offset;
70 for _ in 0..2 {
71 swap(&mut data_ptr, &mut buffer_ptr);
72 for row in 1..grid_height - 1 {
73 for col in 1..grid_width - 1 {
74 let mut b_ptr = unsafe {
75 buffer_ptr.offset(row as isize * row_stride + col as isize * col_stride)
76 };
77 let mut d_ptr = unsafe {
78 data_ptr.offset(row as isize * row_stride + col as isize * col_stride)
79 };
80
81 for _channel in 1..grid_depth {
82 let (prev_value, curr_value, next_value) = {
83 unsafe {
84 let prev = (
85 *b_ptr.offset(-plane_offset),
86 *b_ptr.offset(-plane_offset + 1),
87 );
88 let curr = (*b_ptr, *b_ptr.add(1));
89 let next = (
90 *b_ptr.add(plane_offset as usize),
91 *b_ptr.add((plane_offset + 1) as usize),
92 );
93
94 (prev, curr, next)
95 }
96 };
97
98 let (value, weight) = (
99 (prev_value.0 + next_value.0 + 2.0 * curr_value.0) * 0.25,
100 (prev_value.1 + next_value.1 + 2.0 * curr_value.1) * 0.25,
101 );
102
103 unsafe {
104 *d_ptr = value;
105 *d_ptr.add(1) = weight;
106
107 b_ptr = b_ptr.offset(channel_stride);
108 d_ptr = d_ptr.offset(channel_stride);
109 }
110 }
111 }
112 }
113 }
114 }
115 }
116
117 pub fn filter(&self, image: &Array2<I>) -> Array2<I>
127 where
128 I: num::Zero,
129 {
130 let mut grid = BilateralGrid::from_image(image, self.sigma_space, self.sigma_color);
131 BilateralFilter::convolution(&mut grid);
132
133 grid.normalize();
134 grid.slice(image)
135 }
136
137 pub fn scale_down(&self, image: &Array2<I>) -> Array2<I>
138 where
139 I: num::Zero,
140 {
141 let (src_height, src_width) = image.dim();
142 let image = self.filter(image);
143 let (dst_height, dst_width) = (src_height / 2, src_width / 2);
144 Array2::<I>::from_shape_fn((dst_height, dst_width), |(i_dst, j_dst)| {
145 image[[i_dst * 2, j_dst * 2]]
146 })
147 }
148}