use crate::core::error::{PureCvError, Result};
use crate::core::types::BorderTypes;
use crate::core::Matrix;
use crate::imgproc::derivatives::sobel;
use num_traits::{FromPrimitive, NumCast, ToPrimitive};
#[cfg(not(feature = "parallel"))]
use crate::core::utils::ParIterFallback;
#[cfg(feature = "parallel")]
use rayon::prelude::*;
/// Sentinel stored in the Canny direction map for pixels whose gradient magnitude is too small
/// (`<= 1e-5`) to define an orientation; such pixels are never edge candidates.
const CANNY_NO_DIRECTION: f32 = -1.0;

/// Quantizes the gradient vector `(gx, gy)` into one of the four directions used by Canny's
/// non-maximum suppression.
///
/// The angle `atan2(gy, gx)` is folded into `[0°, 180°]` and binned as:
/// - `0.0`: horizontal gradient (`[0°, 22.5°)` or `[157.5°, 180°]`)
/// - `1.0`: `gx` and `gy` share a sign (`[22.5°, 67.5°)`)
/// - `2.0`: vertical gradient (`[67.5°, 112.5°)`)
/// - `3.0`: `gx` and `gy` have opposite signs (`[112.5°, 157.5°)`)
fn canny_quantize_direction(gx: f64, gy: f64) -> f32 {
    let angle = gy.atan2(gx) * 180.0 / std::f64::consts::PI;
    // A gradient and its negation describe the same edge normal, so fold into [0°, 180°].
    let angle = if angle < 0.0 { angle + 180.0 } else { angle };
    if (0.0..22.5).contains(&angle) || (157.5..=180.0).contains(&angle) {
        0.0
    } else if (22.5..67.5).contains(&angle) {
        1.0
    } else if (67.5..112.5).contains(&angle) {
        2.0
    } else {
        3.0
    }
}

/// Detects edges in a single-channel image using the Canny algorithm.
///
/// Pipeline:
/// 1. Compute the x and y derivatives with [`sobel`] (`aperture_size`, `Reflect101` border).
/// 2. Compute the gradient magnitude and quantize the gradient direction into four bins.
///    The magnitude is the L1 norm `|gx| + |gy|`, or the L2 norm `sqrt(gx² + gy²)` when
///    `l2_gradient` is `true`.
/// 3. Non-maximum suppression. A pixel survives when its magnitude is at least the low
///    threshold and not smaller than either neighbour along its gradient direction. Survivors
///    at or above the high threshold are "strong"; the rest are "weak".
/// 4. Hysteresis. Strong pixels are edges. Weak pixels become edges when they are
///    8-connected, possibly through other weak pixels, to a strong one.
///
/// # Arguments
/// * `src` - input image; must have exactly one channel.
/// * `threshold1`, `threshold2` - hysteresis thresholds. The smaller one is the low threshold
///   and the larger one is the high threshold, so their order does not matter.
/// * `aperture_size` - Sobel kernel size, forwarded unchanged to [`sobel`].
/// * `l2_gradient` - use the L2 norm for the gradient magnitude instead of the L1 norm.
///
/// # Returns
/// A single-channel `u8` mask of the same size as `src`, with `255` on edge pixels and `0`
/// elsewhere. The outermost row and column on every side are never marked as edges.
///
/// # Errors
/// Returns [`PureCvError::InvalidInput`] if `src` does not have exactly one channel. Also
/// propagates any error returned by the conversion to `f64` or by [`sobel`].
pub fn canny<T>(
    src: &Matrix<T>,
    threshold1: f64,
    threshold2: f64,
    aperture_size: i32,
    l2_gradient: bool,
) -> Result<Matrix<u8>>
where
    T: Default + Clone + ToPrimitive + FromPrimitive + NumCast + Copy + Send + Sync,
{
    if src.channels != 1 {
        return Err(PureCvError::InvalidInput(
            "Canny only supports single-channel images".into(),
        ));
    }

    let rows = src.rows;
    let cols = src.cols;
    // Exclusive upper bounds of the interior. `saturating_sub` prevents `usize` underflow on
    // empty images; with fewer than three rows or columns the interior is simply empty.
    let row_end = rows.saturating_sub(1);
    let col_end = cols.saturating_sub(1);

    let src_f64 = src.convert_to::<f64>()?;
    let dx = sobel(
        &src_f64,
        1,
        0,
        aperture_size,
        1.0,
        0.0,
        BorderTypes::Reflect101,
    )?;
    let dy = sobel(
        &src_f64,
        0,
        1,
        aperture_size,
        1.0,
        0.0,
        BorderTypes::Reflect101,
    )?;

    // Step 1: per-pixel gradient magnitude and quantized direction.
    let mut map = Matrix::<f32>::new(rows, cols, 1);
    let mut mag = Matrix::<f32>::new(rows, cols, 1);
    mag.data
        .par_iter_mut()
        .zip(dx.data.par_iter())
        .zip(dy.data.par_iter())
        .zip(map.data.par_iter_mut())
        .for_each(|(((m, &gx), &gy), o)| {
            let gx_f = ToPrimitive::to_f64(&gx).unwrap_or(0.0);
            let gy_f = ToPrimitive::to_f64(&gy).unwrap_or(0.0);
            let magnitude = if l2_gradient {
                (gx_f * gx_f + gy_f * gy_f).sqrt()
            } else {
                gx_f.abs() + gy_f.abs()
            };
            *m = magnitude as f32;
            // Below this magnitude the angle is numerically meaningless.
            *o = if magnitude > 1e-5 {
                canny_quantize_direction(gx_f, gy_f)
            } else {
                CANNY_NO_DIRECTION
            };
        });

    // Step 2: non-maximum suppression.
    // Labels in `suppressed`: 2 = strong candidate, 1 = weak candidate, 0 = discarded.
    let low_threshold = threshold1.min(threshold2) as f32;
    let high_threshold = threshold1.max(threshold2) as f32;
    let mag_at = |y: i32, x: i32| -> f32 {
        *mag
            .at(y, x, 0)
            .expect("neighbours of an interior pixel are in bounds")
    };
    let mut suppressed = Matrix::<u8>::new(rows, cols, 1);
    for y in 1..row_end {
        for x in 1..col_end {
            let (yi, xi) = (y as i32, x as i32);
            let m = mag_at(yi, xi);
            if m < low_threshold {
                continue;
            }
            let direction = *map.at(yi, xi, 0).unwrap();
            // Unit step along the gradient, as (row, column). `dy` differentiates along
            // increasing row index, so when gx and gy share a sign (bin 1) the gradient
            // points down-right and the main-diagonal neighbours lie across the edge.
            let (step_y, step_x) = match direction as i32 {
                0 => (0, 1),
                1 => (1, 1),
                2 => (1, 0),
                3 => (1, -1),
                // No reliable orientation (near-zero gradient): never an edge candidate,
                // even when the thresholds are zero or negative.
                _ => continue,
            };
            let behind = mag_at(yi - step_y, xi - step_x);
            let ahead = mag_at(yi + step_y, xi + step_x);
            if m >= behind && m >= ahead {
                let label = if m >= high_threshold { 2_u8 } else { 1_u8 };
                suppressed.set(y, x, 0, label);
            }
        }
    }

    // Step 3: hysteresis. Strong pixels seed a depth-first flood fill that promotes every
    // 8-connected weak pixel. Visited pixels are reset to 0, so each is pushed at most once.
    let mut dst = Matrix::<u8>::new(rows, cols, 1);
    let mut stack = Vec::with_capacity(rows * cols / 10);
    for y in 1..row_end {
        for x in 1..col_end {
            if *suppressed.at(y as i32, x as i32, 0).unwrap() == 2 {
                stack.push((y, x));
                dst.set(y, x, 0, 255);
                suppressed.set(y, x, 0, 0);
            }
        }
    }
    while let Some((y, x)) = stack.pop() {
        // Only interior pixels are ever labelled, so every popped pixel is interior. All
        // eight of its neighbours are therefore in bounds, and border pixels (always 0)
        // are never promoted. The centre pixel was already reset to 0, so it is skipped too.
        for ny in y - 1..=y + 1 {
            for nx in x - 1..=x + 1 {
                if *suppressed.at(ny as i32, nx as i32, 0).unwrap() == 1 {
                    dst.set(ny, nx, 0, 255);
                    suppressed.set(ny, nx, 0, 0);
                    stack.push((ny, nx));
                }
            }
        }
    }
    Ok(dst)
}