use scirs2_core::ndarray::{Array2, ArrayView1, ArrayView2, Axis, Zip};
use scirs2_core::numeric::{Float, FromPrimitive};
use scirs2_core::parallel_ops::*;
use scirs2_core::simd_ops::SimdUnifiedOps;
use std::fmt::Debug;
use super::BorderMode;
use crate::error::{NdimageError, NdimageResult};
use crate::utils::safe_f64_to_float;
#[allow(dead_code)]
pub fn sobel_2d_optimized<T>(
input: &ArrayView2<T>,
axis: usize,
mode: Option<BorderMode>,
) -> NdimageResult<Array2<T>>
where
T: Float + FromPrimitive + Debug + Send + Sync + 'static,
T: SimdUnifiedOps,
{
let border_mode = mode.unwrap_or(BorderMode::Reflect);
let (height, width) = input.dim();
if axis > 1 {
return Err(NdimageError::InvalidInput(format!(
"Invalid axis {} for 2D array",
axis
)));
}
let mut output = Array2::zeros((height, width));
let k2 = safe_f64_to_float::<T>(2.0)?;
let (k1, k3) = (T::one(), T::one());
let use_parallel = height * width > 10_000;
if use_parallel {
sobel_parallel(input, &mut output, axis, k1, k2, k3, &border_mode)?;
} else {
sobel_sequential(input, &mut output, axis, k1, k2, k3, &border_mode)?;
}
Ok(output)
}
#[allow(dead_code)]
fn sobel_sequential<T>(
input: &ArrayView2<T>,
output: &mut Array2<T>,
axis: usize,
k1: T,
k2: T,
k3: T,
mode: &BorderMode,
) -> NdimageResult<()>
where
T: Float + FromPrimitive + Debug + Send + Sync + 'static,
T: SimdUnifiedOps,
{
let (height, width) = input.dim();
for i in 0..height {
for j in 0..width {
let val = if axis == 0 {
let top =
get_pixel_value(input, i as isize - 1, j as isize - 1, mode, Some(T::zero()))
* k1
+ get_pixel_value(input, i as isize - 1, j as isize, mode, Some(T::zero()))
* k2
+ get_pixel_value(
input,
i as isize - 1,
j as isize + 1,
mode,
Some(T::zero()),
) * k3;
let bottom =
get_pixel_value(input, i as isize + 1, j as isize - 1, mode, Some(T::zero()))
* k1
+ get_pixel_value(input, i as isize + 1, j as isize, mode, Some(T::zero()))
* k2
+ get_pixel_value(
input,
i as isize + 1,
j as isize + 1,
mode,
Some(T::zero()),
) * k3;
top - bottom
} else {
let left =
get_pixel_value(input, i as isize - 1, j as isize - 1, mode, Some(T::zero()))
* k1
+ get_pixel_value(input, i as isize, j as isize - 1, mode, Some(T::zero()))
* k2
+ get_pixel_value(
input,
i as isize + 1,
j as isize - 1,
mode,
Some(T::zero()),
) * k3;
let right =
get_pixel_value(input, i as isize - 1, j as isize + 1, mode, Some(T::zero()))
* k1
+ get_pixel_value(input, i as isize, j as isize + 1, mode, Some(T::zero()))
* k2
+ get_pixel_value(
input,
i as isize + 1,
j as isize + 1,
mode,
Some(T::zero()),
) * k3;
right - left
};
output[[i, j]] = val;
}
}
Ok(())
}
#[allow(dead_code)]
fn sobel_parallel<T>(
input: &ArrayView2<T>,
output: &mut Array2<T>,
axis: usize,
k1: T,
k2: T,
k3: T,
mode: &BorderMode,
) -> NdimageResult<()>
where
T: Float + FromPrimitive + Debug + Send + Sync + 'static,
{
let (height, width) = input.dim();
let input_shared = input.to_owned(); let mode_clone = mode.clone();
output
.axis_iter_mut(Axis(0))
.into_par_iter()
.enumerate()
.for_each(|(i, mut row)| {
let input_view = input_shared.view();
let input_ref = &input_view;
for j in 0..width {
let val = if axis == 0 {
let top = get_pixel_value(
input_ref,
i as isize - 1,
j as isize - 1,
&mode_clone,
Some(T::zero()),
) * k1
+ get_pixel_value(
input_ref,
i as isize - 1,
j as isize,
&mode_clone,
Some(T::zero()),
) * k2
+ get_pixel_value(
input_ref,
i as isize - 1,
j as isize + 1,
&mode_clone,
Some(T::zero()),
) * k3;
let bottom = get_pixel_value(
input_ref,
i as isize + 1,
j as isize - 1,
&mode_clone,
Some(T::zero()),
) * k1
+ get_pixel_value(
input_ref,
i as isize + 1,
j as isize,
&mode_clone,
Some(T::zero()),
) * k2
+ get_pixel_value(
input_ref,
i as isize + 1,
j as isize + 1,
&mode_clone,
Some(T::zero()),
) * k3;
top - bottom
} else {
let left = get_pixel_value(
input_ref,
i as isize - 1,
j as isize - 1,
&mode_clone,
Some(T::zero()),
) * k1
+ get_pixel_value(
input_ref,
i as isize,
j as isize - 1,
&mode_clone,
Some(T::zero()),
) * k2
+ get_pixel_value(
input_ref,
i as isize + 1,
j as isize - 1,
&mode_clone,
Some(T::zero()),
) * k3;
let right = get_pixel_value(
input_ref,
i as isize - 1,
j as isize + 1,
&mode_clone,
Some(T::zero()),
) * k1
+ get_pixel_value(
input_ref,
i as isize,
j as isize + 1,
&mode_clone,
Some(T::zero()),
) * k2
+ get_pixel_value(
input_ref,
i as isize + 1,
j as isize + 1,
&mode_clone,
Some(T::zero()),
) * k3;
right - left
};
row[j] = val;
}
});
Ok(())
}
#[allow(dead_code)]
fn get_pixel_value<T>(
input: &ArrayView2<T>,
i: isize,
j: isize,
mode: &BorderMode,
cval: Option<T>,
) -> T
where
T: Float + FromPrimitive + Debug,
{
let (height, width) = input.dim();
let (ni, nj) = match mode {
BorderMode::Constant => {
if i < 0 || i >= height as isize || j < 0 || j >= width as isize {
return cval.unwrap_or(T::zero());
}
(i as usize, j as usize)
}
BorderMode::Nearest => {
let ni = i.clamp(0, (height as isize) - 1) as usize;
let nj = j.clamp(0, (width as isize) - 1) as usize;
(ni, nj)
}
BorderMode::Mirror => {
let ni = if i < 0 {
(-i - 1) as usize
} else if i >= height as isize {
(2 * height as isize - i - 1) as usize
} else {
i as usize
};
let nj = if j < 0 {
(-j - 1) as usize
} else if j >= width as isize {
(2 * width as isize - j - 1) as usize
} else {
j as usize
};
(ni.min(height - 1), nj.min(width - 1))
}
BorderMode::Reflect => {
let ni = if i < 0 {
(-i) as usize
} else if i >= height as isize {
(2 * (height as isize - 1) - i) as usize
} else {
i as usize
};
let nj = if j < 0 {
(-j) as usize
} else if j >= width as isize {
(2 * (width as isize - 1) - j) as usize
} else {
j as usize
};
(ni.min(height - 1), nj.min(width - 1))
}
BorderMode::Wrap => {
let ni = ((i % height as isize + height as isize) % height as isize) as usize;
let nj = ((j % width as isize + width as isize) % width as isize) as usize;
(ni, nj)
}
};
input[[ni, nj]]
}
#[allow(dead_code)]
pub fn laplace_2d_optimized<T>(
input: &ArrayView2<T>,
use_diagonal: bool,
mode: Option<BorderMode>,
) -> NdimageResult<Array2<T>>
where
T: Float + FromPrimitive + Debug + Send + Sync + 'static,
T: SimdUnifiedOps,
{
let border_mode = mode.unwrap_or(BorderMode::Reflect);
let (height, width) = input.dim();
let mut output = Array2::zeros((height, width));
let use_parallel = height * width > 10_000;
if use_parallel {
laplace_parallel(input, &mut output, use_diagonal, &border_mode)?;
} else {
laplace_sequential(input, &mut output, use_diagonal, &border_mode)?;
}
Ok(output)
}
#[allow(dead_code)]
fn laplace_sequential<T>(
input: &ArrayView2<T>,
output: &mut Array2<T>,
use_diagonal: bool,
mode: &BorderMode,
) -> NdimageResult<()>
where
T: Float + FromPrimitive + Debug + Send + Sync + 'static,
T: SimdUnifiedOps,
{
let (height, width) = input.dim();
let eight = safe_f64_to_float::<T>(8.0)?;
let four = safe_f64_to_float::<T>(4.0)?;
for i in 0..height {
for j in 0..width {
let center = input[[i, j]];
if use_diagonal {
let mut sum = T::zero();
for di in -1..=1 {
for dj in -1..=1 {
if di == 0 && dj == 0 {
sum = sum + center * eight;
} else {
sum = sum
- get_pixel_value(
input,
i as isize + di,
j as isize + dj,
mode,
Some(T::zero()),
);
}
}
}
output[[i, j]] = sum;
} else {
let sum = center * four
- get_pixel_value(input, i as isize - 1, j as isize, mode, Some(T::zero()))
- get_pixel_value(input, i as isize + 1, j as isize, mode, Some(T::zero()))
- get_pixel_value(input, i as isize, j as isize - 1, mode, Some(T::zero()))
- get_pixel_value(input, i as isize, j as isize + 1, mode, Some(T::zero()));
output[[i, j]] = sum;
}
}
}
Ok(())
}
#[allow(dead_code)]
fn laplace_parallel<T>(
input: &ArrayView2<T>,
output: &mut Array2<T>,
use_diagonal: bool,
mode: &BorderMode,
) -> NdimageResult<()>
where
T: Float + FromPrimitive + Debug + Send + Sync + 'static,
{
let (height, width) = input.dim();
let input_shared = input.to_owned(); let mode_clone = mode.clone();
let eight = safe_f64_to_float::<T>(8.0)?;
let four = safe_f64_to_float::<T>(4.0)?;
output
.axis_iter_mut(Axis(0))
.into_par_iter()
.enumerate()
.for_each(|(i, mut row)| {
let input_view = input_shared.view();
let input_ref = &input_view;
for j in 0..width {
let center = input_ref[[i, j]];
if use_diagonal {
let mut sum = T::zero();
for di in -1..=1 {
for dj in -1..=1 {
if di == 0 && dj == 0 {
sum = sum + center * eight;
} else {
sum = sum
- get_pixel_value(
input_ref,
i as isize + di,
j as isize + dj,
&mode_clone,
Some(T::zero()),
);
}
}
}
row[j] = sum;
} else {
let sum = center * four
- get_pixel_value(
input_ref,
i as isize - 1,
j as isize,
&mode_clone,
Some(T::zero()),
)
- get_pixel_value(
input_ref,
i as isize + 1,
j as isize,
&mode_clone,
Some(T::zero()),
)
- get_pixel_value(
input_ref,
i as isize,
j as isize - 1,
&mode_clone,
Some(T::zero()),
)
- get_pixel_value(
input_ref,
i as isize,
j as isize + 1,
&mode_clone,
Some(T::zero()),
);
row[j] = sum;
}
}
});
Ok(())
}
#[allow(dead_code)]
pub fn gradient_magnitude_optimized<T>(
grad_x: &ArrayView2<T>,
grad_y: &ArrayView2<T>,
) -> NdimageResult<Array2<T>>
where
T: Float + FromPrimitive + Debug + Send + Sync + 'static,
T: SimdUnifiedOps,
{
if grad_x.dim() != grad_y.dim() {
return Err(NdimageError::InvalidInput(
"Gradient arrays must have the same shape".into(),
));
}
let (height, width) = grad_x.dim();
let mut magnitude = Array2::zeros((height, width));
if height * width > 1000 && T::simd_available() {
let gx_flat = grad_x.as_slice().ok_or_else(|| {
NdimageError::ComputationError("Failed to get contiguous slice from grad_x".into())
})?;
let gy_flat = grad_y.as_slice().ok_or_else(|| {
NdimageError::ComputationError("Failed to get contiguous slice from grad_y".into())
})?;
let mag_flat = magnitude.as_slice_mut().ok_or_else(|| {
NdimageError::ComputationError(
"Failed to get mutable contiguous slice from magnitude".into(),
)
})?;
let gx_view = ArrayView1::from(gx_flat);
let gy_view = ArrayView1::from(gy_flat);
let gx_squared = T::simd_mul(&gx_view, &gx_view);
let gy_squared = T::simd_mul(&gy_view, &gy_view);
let magnitude_squared = T::simd_add(&gx_squared.view(), &gy_squared.view());
let magnitude_result = T::simd_sqrt(&magnitude_squared.view());
mag_flat.copy_from_slice(magnitude_result.as_slice().expect("Operation failed"));
} else {
Zip::from(&mut magnitude)
.and(grad_x)
.and(grad_y)
.for_each(|m, &gx, &gy| {
*m = (gx * gx + gy * gy).sqrt();
});
}
Ok(magnitude)
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
#[test]
fn test_sobel_optimized() {
let input = array![
[0.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 1.0, 0.0],
[0.0, 1.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 0.0]
];
let grad_x = sobel_2d_optimized(&input.view(), 1, Some(BorderMode::Constant))
.expect("sobel_2d_optimized should succeed for test");
assert!(grad_x[[1, 1]] > 0.0);
assert!(grad_x[[1, 2]] < 0.0);
}
#[test]
fn test_laplace_optimized() {
let input = array![[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]];
let result = laplace_2d_optimized(&input.view(), false, None)
.expect("laplace_2d_optimized should succeed for test");
assert!(result[[1, 1]].abs() > 0.0);
}
#[test]
fn test_gradient_magnitude() {
let grad_x = array![[1.0, 0.0], [0.0, 1.0]];
let grad_y = array![[0.0, 1.0], [1.0, 0.0]];
let magnitude = gradient_magnitude_optimized(&grad_x.view(), &grad_y.view())
.expect("gradient_magnitude_optimized should succeed for test");
assert!((magnitude[[0, 0]] - 1.0).abs() < 1e-6);
assert!((magnitude[[0, 1]] - 1.0).abs() < 1e-6);
assert!((magnitude[[1, 0]] - 1.0).abs() < 1e-6);
assert!((magnitude[[1, 1]] - 1.0).abs() < 1e-6);
}
}