1use crate::core::{ColourModel, Image, ImageBase};
2use crate::processing::*;
3use ndarray::prelude::*;
4use ndarray::{DataMut, IntoDimension};
5use num_traits::{cast::FromPrimitive, real::Real, Num, NumAssignOps};
6use std::collections::HashSet;
7use std::marker::PhantomData;
8
9pub trait CannyEdgeDetectorExt<T> {
11 type Output;
13
14 fn canny_edge_detector(&self, params: CannyParameters<T>) -> Result<Self::Output, Error>;
18}
19
20#[derive(Clone, Eq, PartialEq, Hash, Debug)]
23pub struct CannyBuilder<T> {
24 blur: Option<Array3<T>>,
25 t1: Option<T>,
26 t2: Option<T>,
27}
28
29#[derive(Clone, Eq, PartialEq, Hash, Debug)]
31pub struct CannyParameters<T> {
32 pub blur: Array3<T>,
35 pub t1: T,
37 pub t2: T,
39}
40
41impl<T, U, C> CannyEdgeDetectorExt<T> for ImageBase<U, C>
42where
43 U: DataMut<Elem = T>,
44 T: Copy + Clone + FromPrimitive + Real + Num + NumAssignOps,
45 C: ColourModel,
46{
47 type Output = Image<bool, C>;
48
49 fn canny_edge_detector(&self, params: CannyParameters<T>) -> Result<Self::Output, Error> {
50 let data = self.data.canny_edge_detector(params)?;
51 Ok(Self::Output {
52 data,
53 model: PhantomData,
54 })
55 }
56}
57
58impl<T, U> CannyEdgeDetectorExt<T> for ArrayBase<U, Ix3>
59where
60 U: DataMut<Elem = T>,
61 T: Copy + Clone + FromPrimitive + Real + Num + NumAssignOps,
62{
63 type Output = Array3<bool>;
64
65 fn canny_edge_detector(&self, params: CannyParameters<T>) -> Result<Self::Output, Error> {
66 if self.shape()[2] > 1 {
67 Err(Error::ChannelDimensionMismatch)
68 } else {
69 let blurred = self.conv2d(params.blur.view())?;
71 let (mag, rot) = blurred.full_sobel()?;
72
73 let mag = non_maxima_supression(mag, rot.view());
74
75 Ok(link_edges(mag, params.t1, params.t2))
76 }
77 }
78}
79
80fn non_maxima_supression<T>(magnitudes: Array3<T>, rotations: ArrayView3<T>) -> Array3<T>
81where
82 T: Copy + Clone + FromPrimitive + Real + Num + NumAssignOps,
83{
84 let row_size = magnitudes.shape()[0] as isize;
85 let column_size = magnitudes.shape()[1] as isize;
86
87 let get_neighbours = |r, c, dr, dc| {
88 if (r == 0 && dr < 0) || (r == (row_size - 1) && dr > 0) {
89 T::zero()
90 } else if (c == 0 && dc < 0) || (c == (column_size - 1) && dc > 0) {
91 T::zero()
92 } else {
93 magnitudes[[(r + dr) as usize, (c + dc) as usize, 0]]
94 }
95 };
96
97 let mut result = magnitudes.clone();
98
99 for (i, mut row) in result.outer_iter_mut().enumerate() {
100 let i = i as isize;
101 for (j, mut col) in row.outer_iter_mut().enumerate() {
102 let mut dir = rotations[[i as usize, j, 0]]
103 .to_degrees()
104 .to_f64()
105 .unwrap_or(0.0);
106
107 let j = j as isize;
108 if dir >= 180.0 {
109 dir -= 180.0;
110 } else if dir < 0.0 {
111 dir += 180.0;
112 }
113 let (a, b) = if dir < 45.0 {
115 (get_neighbours(i, j, 0, -1), get_neighbours(i, j, 0, 1))
116 } else if dir < 90.0 {
117 (get_neighbours(i, j, -1, -1), get_neighbours(i, j, 1, 1))
118 } else if dir < 135.0 {
119 (get_neighbours(i, j, -1, 0), get_neighbours(i, j, 1, 0))
120 } else {
121 (get_neighbours(i, j, -1, 1), get_neighbours(i, j, 1, -1))
122 };
123
124 if a > col[[0]] || b > col[[0]] {
125 col.fill(T::zero());
126 }
127 }
128 }
129 result
130}
131
/// Collects the neighbours of `coord` that may extend an edge.
///
/// * `coord` - `(row, column)` of the pixel being expanded.
/// * `bounds` - `(rows, columns)` of the image; positions outside are skipped.
/// * `closed_set` - coordinates that must not be returned (the caller passes
///   the pixels it has already visited).
///
/// Returns every in-bounds pixel of the 3x3 window around `coord`, excluding
/// `coord` itself and anything in `closed_set`. This is the 8-connected
/// neighbourhood, including the pixels to the left and right in the same row.
fn get_candidates(
    coord: (usize, usize),
    bounds: (usize, usize),
    closed_set: &HashSet<[usize; 2]>,
) -> Vec<[usize; 2]> {
    let (r, c) = coord;
    let (rows, cols) = bounds;

    // Clamp the 3x3 window to the image: `saturating_sub` handles the first
    // row/column, and `min` handles the last. No arithmetic can underflow,
    // even for empty bounds.
    let row_range = r.saturating_sub(1)..rows.min(r + 2);
    let col_range = c.saturating_sub(1)..cols.min(c + 2);

    let mut result = Vec::with_capacity(8);
    for nr in row_range {
        for nc in col_range.clone() {
            let candidate = [nr, nc];
            // Test exactly the pixel that is about to be pushed.
            if (nr, nc) != coord && !closed_set.contains(&candidate) {
                result.push(candidate);
            }
        }
    }
    result
}
165
166fn link_edges<T>(magnitudes: Array3<T>, lower: T, upper: T) -> Array3<bool>
167where
168 T: Copy + Clone + FromPrimitive + Real + Num + NumAssignOps,
169{
170 let magnitudes = magnitudes.mapv(|x| if x >= lower { x } else { T::zero() });
171 let mut result = magnitudes.mapv(|x| x >= upper);
172 let mut visited = HashSet::new();
173
174 let rows = result.shape()[0];
175 let cols = result.shape()[1];
176
177 for r in 0..rows {
178 for c in 0..cols {
179 if result[[r, c, 0]] {
181 visited.insert([r, c]);
182 let mut buffer = get_candidates((r, c), (rows, cols), &visited);
183
184 while let Some(cand) = buffer.pop() {
185 let coord3 = [cand[0], cand[1], 0];
186 if magnitudes[coord3] > lower {
187 visited.insert(cand);
188 result[coord3] = true;
189
190 let temp = get_candidates((cand[0], cand[1]), (rows, cols), &visited);
191 buffer.extend_from_slice(temp.as_slice());
192 }
193 }
194 }
195 }
196 }
197 result
198}
199
200impl<T> Default for CannyBuilder<T>
201where
202 T: Copy + Clone + FromPrimitive + Real + Num,
203{
204 fn default() -> Self {
205 Self::new()
206 }
207}
208
209impl<T> CannyBuilder<T>
210where
211 T: Copy + Clone + FromPrimitive + Real + Num,
212{
213 pub fn new() -> Self {
215 Self {
216 blur: None,
217 t1: None,
218 t2: None,
219 }
220 }
221
222 pub fn lower_threshold(self, t1: T) -> Self {
224 Self {
225 blur: self.blur,
226 t1: Some(t1),
227 t2: self.t2,
228 }
229 }
230
231 pub fn upper_threshold(self, t2: T) -> Self {
233 Self {
234 blur: self.blur,
235 t1: self.t1,
236 t2: Some(t2),
237 }
238 }
239
240 pub fn blur<D>(self, shape: D, covariance: [f64; 2]) -> Self
243 where
244 D: Copy + IntoDimension<Dim = Ix2>,
245 {
246 let shape = shape.into_dimension();
247 let shape = (shape[0], shape[1], 1);
248 if let Ok(blur) = GaussianFilter::build_with_params(shape, covariance) {
249 Self {
250 blur: Some(blur),
251 t1: self.t1,
252 t2: self.t2,
253 }
254 } else {
255 self
256 }
257 }
258
259 pub fn build(self) -> CannyParameters<T> {
266 let blur = match self.blur {
267 Some(b) => b,
268 None => GaussianFilter::build_with_params((5, 5, 1), [2.0, 2.0]).unwrap(),
269 };
270 let mut t1 = match self.t1 {
271 Some(t) => t,
272 None => T::from_f64(0.3).unwrap(),
273 };
274 let mut t2 = match self.t2 {
275 Some(t) => t,
276 None => T::from_f64(0.7).unwrap(),
277 };
278 if t2 < t1 {
279 std::mem::swap(&mut t1, &mut t2);
280 }
281 CannyParameters { blur, t1, t2 }
282 }
283}
284
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr3;

    #[test]
    fn canny_builder() {
        let builder = CannyBuilder::<f64>::new()
            .lower_threshold(0.75)
            .upper_threshold(0.25);

        // The builder stores the thresholds exactly as given...
        assert_eq!((builder.t1, builder.t2), (Some(0.75), Some(0.25)));
        assert_eq!(builder.blur, None);

        // ...and `build` puts them into ascending order.
        let params = builder.clone().build();
        assert_eq!((params.t1, params.t2), (0.25, 0.75));

        // Configuring a blur leaves the thresholds untouched.
        let with_blur = builder.blur((3, 3), [0.2, 0.2]);
        assert_eq!((with_blur.t1, with_blur.t2), (Some(0.75), Some(0.25)));
        assert!(with_blur.blur.is_some());

        let kernel = with_blur.blur.unwrap();
        assert_eq!(kernel.shape(), [3, 3, 1]);
    }

    #[test]
    fn canny_thresholding() {
        let input = arr3(&[
            [[0.2], [0.4], [0.0]],
            [[0.7], [0.5], [0.8]],
            [[0.1], [0.6], [0.0]],
        ]);

        let want = arr3(&[
            [[false], [false], [false]],
            [[true], [true], [true]],
            [[false], [true], [false]],
        ]);

        assert_eq!(link_edges(input, 0.4, 0.69), want);
    }
}