use super::error::Result;
use super::image_view::{ImageView, OwnedImage};
use super::scalar::Scalar;
/// Pair of derivative images produced by the gradient operators in this module.
///
/// The functions in this module allocate both planes with the dimensions of
/// the input image.
pub struct GradientField {
/// Horizontal derivative (response to intensity changes along `x`).
pub(crate) gx: OwnedImage<Scalar>,
/// Vertical derivative (response to intensity changes along `y`).
pub(crate) gy: OwnedImage<Scalar>,
}
impl GradientField {
#[inline]
pub fn width(&self) -> usize {
self.gx.width()
}
#[inline]
pub fn height(&self) -> usize {
self.gx.height()
}
#[inline]
pub fn gx(&self) -> ImageView<'_, Scalar> {
self.gx.view()
}
#[inline]
pub fn gy(&self) -> ImageView<'_, Scalar> {
self.gy.view()
}
#[inline]
pub fn get(&self, x: usize, y: usize) -> Option<(Scalar, Scalar)> {
let gx = self.gx.get(x, y)?;
let gy = self.gy.get(x, y)?;
Some((gx, gy))
}
#[inline]
pub fn magnitude(&self, x: usize, y: usize) -> Option<Scalar> {
let (gx, gy) = self.get(x, y)?;
Some((gx * gx + gy * gy).sqrt())
}
pub fn max_magnitude(&self) -> Scalar {
let gx = self.gx.data();
let gy = self.gy.data();
gx.iter()
.zip(gy.iter())
.map(|(&x, &y)| x * x + y * y)
.fold(0.0f32, Scalar::max)
.sqrt()
}
}
/// Computes the 3×3 Sobel gradient of an 8-bit image.
///
/// Every interior pixel receives the horizontal (`gx`) and vertical (`gy`)
/// Sobel responses, each divided by 8. The one-pixel border is never written
/// and therefore keeps the value the planes were allocated with by
/// `OwnedImage::zeros`.
///
/// Images with fewer than three rows or columns (including zero-sized ones)
/// have no interior pixels, so the returned field is left untouched.
///
/// # Errors
///
/// Propagates any error returned by `OwnedImage::zeros` while allocating the
/// two output planes.
pub fn sobel_gradient(image: &ImageView<'_, u8>) -> Result<GradientField> {
    let w = image.width();
    let h = image.height();
    let mut gx = OwnedImage::<Scalar>::zeros(w, h)?;
    let mut gy = OwnedImage::<Scalar>::zeros(w, h)?;
    let gx_data = gx.data_mut();
    let gy_data = gy.data_mut();

    // Missing samples are read as 0, then widened to `Scalar`.
    let px = |cx: usize, cy: usize| image.get(cx, cy).copied().unwrap_or(0) as Scalar;

    // `saturating_sub` keeps the ranges empty when a dimension is 0; a plain
    // `h - 1` / `w - 1` would underflow `usize` there.
    for y in 1..h.saturating_sub(1) {
        for x in 1..w.saturating_sub(1) {
            let p00 = px(x - 1, y - 1);
            let p10 = px(x, y - 1);
            let p20 = px(x + 1, y - 1);
            let p01 = px(x - 1, y);
            let p21 = px(x + 1, y);
            let p02 = px(x - 1, y + 1);
            let p12 = px(x, y + 1);
            let p22 = px(x + 1, y + 1);

            let dx = (-p00 + p20 - 2.0 * p01 + 2.0 * p21 - p02 + p22) / 8.0;
            let dy = (-p00 - 2.0 * p10 - p20 + p02 + 2.0 * p12 + p22) / 8.0;

            // Row-major offset into the output planes.
            let idx = y * w + x;
            gx_data[idx] = dx;
            gy_data[idx] = dy;
        }
    }
    Ok(GradientField { gx, gy })
}
/// Computes the 3×3 Sobel gradient of an `f32` image.
///
/// Every interior pixel receives the horizontal (`gx`) and vertical (`gy`)
/// Sobel responses, each divided by 8. The one-pixel border is never written
/// and therefore keeps the value the planes were allocated with by
/// `OwnedImage::zeros`.
///
/// Images with fewer than three rows or columns (including zero-sized ones)
/// have no interior pixels, so the returned field is left untouched.
///
/// # Errors
///
/// Propagates any error returned by `OwnedImage::zeros` while allocating the
/// two output planes.
pub fn sobel_gradient_f32(image: &ImageView<'_, f32>) -> Result<GradientField> {
    let w = image.width();
    let h = image.height();
    let mut gx = OwnedImage::<Scalar>::zeros(w, h)?;
    let mut gy = OwnedImage::<Scalar>::zeros(w, h)?;
    let gx_data = gx.data_mut();
    let gy_data = gy.data_mut();

    // Missing samples are read as 0.0.
    let px = |cx: usize, cy: usize| image.get(cx, cy).copied().unwrap_or(0.0);

    // `saturating_sub` keeps the ranges empty when a dimension is 0; a plain
    // `h - 1` / `w - 1` would underflow `usize` there.
    for y in 1..h.saturating_sub(1) {
        for x in 1..w.saturating_sub(1) {
            let p00 = px(x - 1, y - 1);
            let p10 = px(x, y - 1);
            let p20 = px(x + 1, y - 1);
            let p01 = px(x - 1, y);
            let p21 = px(x + 1, y);
            let p02 = px(x - 1, y + 1);
            let p12 = px(x, y + 1);
            let p22 = px(x + 1, y + 1);

            let dx = (-p00 + p20 - 2.0 * p01 + 2.0 * p21 - p02 + p22) / 8.0;
            let dy = (-p00 - 2.0 * p10 - p20 + p02 + 2.0 * p12 + p22) / 8.0;

            // Row-major offset into the output planes.
            let idx = y * w + x;
            gx_data[idx] = dx;
            gy_data[idx] = dy;
        }
    }
    Ok(GradientField { gx, gy })
}
/// Builds an image whose samples are the Euclidean norm `sqrt(gx² + gy²)` of
/// the corresponding samples in `field`.
///
/// # Errors
///
/// Propagates any error returned by `OwnedImage::zeros` while allocating the
/// output image.
///
/// # Panics
///
/// Panics if any of the three buffers holds fewer than `width * height`
/// samples.
pub fn gradient_magnitude(field: &GradientField) -> Result<OwnedImage<Scalar>> {
    let (w, h) = (field.width(), field.height());
    let mut mag = OwnedImage::<Scalar>::zeros(w, h)?;

    // Restrict all three buffers to the first `w * h` samples up front so the
    // loop below needs no per-element indexing.
    let count = w * h;
    let dxs = &field.gx.data()[..count];
    let dys = &field.gy.data()[..count];
    let out = &mut mag.data_mut()[..count];

    for ((m, &dx), &dy) in out.iter_mut().zip(dxs).zip(dys) {
        *m = (dx * dx + dy * dy).sqrt();
    }
    Ok(mag)
}
/// Selects which 3×3 derivative kernel `compute_gradient` and
/// `compute_gradient_f32` apply.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub enum GradientOperator {
/// Sobel kernel (weights 1, 2, 1; result divided by 8). This is the default.
#[default]
Sobel,
/// Scharr kernel (weights 3, 10, 3; result divided by 32).
Scharr,
}
/// Computes the 3×3 Scharr gradient of an 8-bit image.
///
/// Every interior pixel receives the horizontal (`gx`) and vertical (`gy`)
/// Scharr responses (weights 3, 10, 3), each divided by 32. The one-pixel
/// border is never written and therefore keeps the value the planes were
/// allocated with by `OwnedImage::zeros`.
///
/// Images with fewer than three rows or columns (including zero-sized ones)
/// have no interior pixels, so the returned field is left untouched.
///
/// # Errors
///
/// Propagates any error returned by `OwnedImage::zeros` while allocating the
/// two output planes.
pub fn scharr_gradient(image: &ImageView<'_, u8>) -> Result<GradientField> {
    let w = image.width();
    let h = image.height();
    let mut gx = OwnedImage::<Scalar>::zeros(w, h)?;
    let mut gy = OwnedImage::<Scalar>::zeros(w, h)?;
    let gx_data = gx.data_mut();
    let gy_data = gy.data_mut();

    // Missing samples are read as 0, then widened to `Scalar`.
    let px = |cx: usize, cy: usize| image.get(cx, cy).copied().unwrap_or(0) as Scalar;

    // `saturating_sub` keeps the ranges empty when a dimension is 0; a plain
    // `h - 1` / `w - 1` would underflow `usize` there.
    for y in 1..h.saturating_sub(1) {
        for x in 1..w.saturating_sub(1) {
            let p00 = px(x - 1, y - 1);
            let p10 = px(x, y - 1);
            let p20 = px(x + 1, y - 1);
            let p01 = px(x - 1, y);
            let p21 = px(x + 1, y);
            let p02 = px(x - 1, y + 1);
            let p12 = px(x, y + 1);
            let p22 = px(x + 1, y + 1);

            let dx =
                (-3.0 * p00 + 3.0 * p20 - 10.0 * p01 + 10.0 * p21 - 3.0 * p02 + 3.0 * p22) / 32.0;
            let dy =
                (-3.0 * p00 - 10.0 * p10 - 3.0 * p20 + 3.0 * p02 + 10.0 * p12 + 3.0 * p22) / 32.0;

            // Row-major offset into the output planes.
            let idx = y * w + x;
            gx_data[idx] = dx;
            gy_data[idx] = dy;
        }
    }
    Ok(GradientField { gx, gy })
}
/// Computes the 3×3 Scharr gradient of an `f32` image.
///
/// Every interior pixel receives the horizontal (`gx`) and vertical (`gy`)
/// Scharr responses (weights 3, 10, 3), each divided by 32. The one-pixel
/// border is never written and therefore keeps the value the planes were
/// allocated with by `OwnedImage::zeros`.
///
/// Images with fewer than three rows or columns (including zero-sized ones)
/// have no interior pixels, so the returned field is left untouched.
///
/// # Errors
///
/// Propagates any error returned by `OwnedImage::zeros` while allocating the
/// two output planes.
pub fn scharr_gradient_f32(image: &ImageView<'_, f32>) -> Result<GradientField> {
    let w = image.width();
    let h = image.height();
    let mut gx = OwnedImage::<Scalar>::zeros(w, h)?;
    let mut gy = OwnedImage::<Scalar>::zeros(w, h)?;
    let gx_data = gx.data_mut();
    let gy_data = gy.data_mut();

    // Missing samples are read as 0.0.
    let px = |cx: usize, cy: usize| image.get(cx, cy).copied().unwrap_or(0.0);

    // `saturating_sub` keeps the ranges empty when a dimension is 0; a plain
    // `h - 1` / `w - 1` would underflow `usize` there.
    for y in 1..h.saturating_sub(1) {
        for x in 1..w.saturating_sub(1) {
            let p00 = px(x - 1, y - 1);
            let p10 = px(x, y - 1);
            let p20 = px(x + 1, y - 1);
            let p01 = px(x - 1, y);
            let p21 = px(x + 1, y);
            let p02 = px(x - 1, y + 1);
            let p12 = px(x, y + 1);
            let p22 = px(x + 1, y + 1);

            let dx =
                (-3.0 * p00 + 3.0 * p20 - 10.0 * p01 + 10.0 * p21 - 3.0 * p02 + 3.0 * p22) / 32.0;
            let dy =
                (-3.0 * p00 - 10.0 * p10 - 3.0 * p20 + 3.0 * p02 + 10.0 * p12 + 3.0 * p22) / 32.0;

            // Row-major offset into the output planes.
            let idx = y * w + x;
            gx_data[idx] = dx;
            gy_data[idx] = dy;
        }
    }
    Ok(GradientField { gx, gy })
}
pub fn compute_gradient(
image: &ImageView<'_, u8>,
operator: GradientOperator,
) -> Result<GradientField> {
match operator {
GradientOperator::Sobel => sobel_gradient(image),
GradientOperator::Scharr => scharr_gradient(image),
}
}
pub fn compute_gradient_f32(
image: &ImageView<'_, f32>,
operator: GradientOperator,
) -> Result<GradientField> {
match operator {
GradientOperator::Sobel => sobel_gradient_f32(image),
GradientOperator::Scharr => scharr_gradient_f32(image),
}
}
// Unit tests for the gradient operators. Most fixtures are small step images
// probed at a single interior pixel, since only interior pixels are computed.
#[cfg(test)]
mod tests {
use super::*;
// A 5x3 image with a dark-to-bright step between columns 1 and 2 must give a
// strong positive gx and no gy at the interior pixel (2, 1).
#[test]
fn gradient_of_horizontal_step() {
#[rustfmt::skip]
let data: Vec<u8> = vec![
0, 0, 255, 255, 255,
0, 0, 255, 255, 255,
0, 0, 255, 255, 255,
];
let image = ImageView::from_slice(&data, 5, 3).unwrap();
let grad = sobel_gradient(&image).unwrap();
let (gx, gy) = grad.get(2, 1).unwrap();
assert!(
gx > 30.0,
"expected strong horizontal gradient, got gx={gx}"
);
assert!(
gy.abs() < 1e-6,
"expected zero vertical gradient, got gy={gy}"
);
}
// Transposed case: a 3x5 image with a step between rows 1 and 2 must give a
// strong positive gy and no gx at the interior pixel (1, 2).
#[test]
fn gradient_of_vertical_step() {
#[rustfmt::skip]
let data: Vec<u8> = vec![
0, 0, 0,
0, 0, 0,
255, 255, 255,
255, 255, 255,
255, 255, 255,
];
let image = ImageView::from_slice(&data, 3, 5).unwrap();
let grad = sobel_gradient(&image).unwrap();
let (gx, gy) = grad.get(1, 2).unwrap();
assert!(
gx.abs() < 1e-6,
"expected zero horizontal gradient, got gx={gx}"
);
assert!(gy > 30.0, "expected strong vertical gradient, got gy={gy}");
}
// An all-zero image has zero gradient, so every magnitude sample must be 0.
#[test]
fn gradient_magnitude_computation() {
let data: Vec<u8> = vec![0; 9];
let image = ImageView::from_slice(&data, 3, 3).unwrap();
let grad = sobel_gradient(&image).unwrap();
let mag = gradient_magnitude(&grad).unwrap();
assert!(mag.data().iter().all(|&v| v == 0.0));
}
// The Sobel output field must report the same width and height as the input.
#[test]
fn gradient_field_dimensions() {
let data: Vec<u8> = vec![128; 20];
let image = ImageView::from_slice(&data, 5, 4).unwrap();
let grad = sobel_gradient(&image).unwrap();
assert_eq!(grad.width(), 5);
assert_eq!(grad.height(), 4);
}
// Same horizontal-step fixture as the Sobel test, run through Scharr.
#[test]
fn scharr_gradient_of_horizontal_step() {
#[rustfmt::skip]
let data: Vec<u8> = vec![
0, 0, 255, 255, 255,
0, 0, 255, 255, 255,
0, 0, 255, 255, 255,
];
let image = ImageView::from_slice(&data, 5, 3).unwrap();
let grad = scharr_gradient(&image).unwrap();
let (gx, gy) = grad.get(2, 1).unwrap();
assert!(
gx > 30.0,
"expected strong horizontal gradient, got gx={gx}"
);
assert!(
gy.abs() < 1e-6,
"expected zero vertical gradient, got gy={gy}"
);
}
// A constant image must yield exactly zero in both planes, border included.
#[test]
fn scharr_gradient_zeros_on_uniform() {
let data: Vec<u8> = vec![128; 25];
let image = ImageView::from_slice(&data, 5, 5).unwrap();
let grad = scharr_gradient(&image).unwrap();
assert!(grad.gx().as_slice().iter().all(|&v| v == 0.0));
assert!(grad.gy().as_slice().iter().all(|&v| v == 0.0));
}
// The Scharr output field must report the same width and height as the input.
#[test]
fn scharr_dimensions_match() {
let data: Vec<u8> = vec![128; 20];
let image = ImageView::from_slice(&data, 5, 4).unwrap();
let grad = scharr_gradient(&image).unwrap();
assert_eq!(grad.width(), 5);
assert_eq!(grad.height(), 4);
}
// The u8 and f32 Scharr paths must agree when fed the same pixel values.
#[test]
fn scharr_gradient_f32_matches_u8() {
#[rustfmt::skip]
let data_u8: Vec<u8> = vec![
0, 0, 255, 255, 255,
0, 0, 255, 255, 255,
0, 0, 255, 255, 255,
];
let data_f32: Vec<f32> = data_u8.iter().map(|&v| v as f32).collect();
let img_u8 = ImageView::from_slice(&data_u8, 5, 3).unwrap();
let img_f32 = ImageView::from_slice(&data_f32, 5, 3).unwrap();
let grad_u8 = scharr_gradient(&img_u8).unwrap();
let grad_f32 = scharr_gradient_f32(&img_f32).unwrap();
let (gx_u8, gy_u8) = grad_u8.get(2, 1).unwrap();
let (gx_f32, gy_f32) = grad_f32.get(2, 1).unwrap();
assert!(
(gx_u8 - gx_f32).abs() < 1e-4,
"gx mismatch: u8={gx_u8} f32={gx_f32}"
);
assert!(
(gy_u8 - gy_f32).abs() < 1e-4,
"gy mismatch: u8={gy_u8} f32={gy_f32}"
);
}
// Both operators must produce a nonzero gx on the step fixture when reached
// through the f32 dispatcher.
#[test]
fn compute_gradient_f32_dispatches() {
#[rustfmt::skip]
let data: Vec<f32> = vec![
0.0, 0.0, 255.0, 255.0, 255.0,
0.0, 0.0, 255.0, 255.0, 255.0,
0.0, 0.0, 255.0, 255.0, 255.0,
];
let image = ImageView::from_slice(&data, 5, 3).unwrap();
let sobel = compute_gradient_f32(&image, GradientOperator::Sobel).unwrap();
let scharr = compute_gradient_f32(&image, GradientOperator::Scharr).unwrap();
let (sobel_gx, _) = sobel.get(2, 1).unwrap();
let (scharr_gx, _) = scharr.get(2, 1).unwrap();
assert!(
sobel_gx.abs() > 0.1,
"Sobel f32 gx should be nonzero: {sobel_gx}"
);
assert!(
scharr_gx.abs() > 0.1,
"Scharr f32 gx should be nonzero: {scharr_gx}"
);
}
// A single bright corner pixel is weighted 1/8 by Sobel and 3/32 by Scharr,
// so the centre gx values differ; this shows the dispatcher really selects
// different kernels.
#[test]
fn compute_gradient_dispatches_correctly() {
#[rustfmt::skip]
let data: Vec<u8> = vec![
100, 0, 0,
0, 0, 0,
0, 0, 0,
];
let image = ImageView::from_slice(&data, 3, 3).unwrap();
let sobel = compute_gradient(&image, GradientOperator::Sobel).unwrap();
let scharr = compute_gradient(&image, GradientOperator::Scharr).unwrap();
let (sx, _) = sobel.get(1, 1).unwrap();
let (cx, _) = scharr.get(1, 1).unwrap();
assert!(sx.abs() > 0.1, "Sobel gx should be nonzero: {sx}");
assert!(cx.abs() > 0.1, "Scharr gx should be nonzero: {cx}");
assert!(
(sx - cx).abs() > 0.1,
"Sobel gx ({sx}) and Scharr gx ({cx}) should differ"
);
}
}