use crate::error::Result;
use scirs2_core::ndarray::{s, Array2, Array3};
use scirs2_core::parallel_ops::*;
// The parallel denoiser may assemble its output without a lock; this import is kept
// (rather than deleted) and allowed so an unused import does not break `-D warnings` builds.
#[allow(unused_imports)]
use std::sync::Mutex;
#[allow(dead_code)]
/// Denoises a single-channel image with the Non-Local Means algorithm.
///
/// For every pixel, the `template_window_size`² patch centred on it is compared with the patch
/// centred on every pixel of the surrounding `search_window_size`² window. Each candidate
/// centre contributes its value with weight `exp(-d / h²)`, where `d` is the mean squared
/// difference between the two patches. The output pixel is the weighted average. The image is
/// reflect-padded first, so windows at the borders are fully defined.
///
/// # Arguments
/// * `input` - image to denoise.
/// * `h` - filter strength; larger values smooth more. Only `h²` is used, so the sign of `h`
///   does not matter. If `h²` is zero or NaN, the filter degenerates to the identity and a
///   copy of `input` is returned.
/// * `template_window_size` - side length of the compared patches; even values are rounded up
///   to the next odd value.
/// * `search_window_size` - side length of the search window; even values are rounded up to
///   the next odd value.
///
/// # Errors
/// This function currently never fails; the `Result` return type is kept for API consistency.
pub fn nlm_denoise(
    input: &Array2<f32>,
    h: f32,
    template_window_size: usize,
    search_window_size: usize,
) -> Result<Array2<f32>> {
    let h_squared = h * h;
    // With h² == 0 (or NaN) the self-patch weight is 0/0 = NaN, which poisons every weighted
    // sum and makes each pixel fall back to its input value. Return that identity result
    // directly instead of reaching it through NaN arithmetic.
    if h_squared <= 0.0 || h_squared.is_nan() {
        return Ok(input.to_owned());
    }

    let (height, width) = input.dim();
    // Windows must be odd so they have a well-defined centre pixel.
    let template_size = if template_window_size.is_multiple_of(2) {
        template_window_size + 1
    } else {
        template_window_size
    };
    let search_size = if search_window_size.is_multiple_of(2) {
        search_window_size + 1
    } else {
        search_window_size
    };
    let template_radius = template_size / 2;
    let search_radius = search_size / 2;
    // Padding by both radii guarantees that every candidate patch lies entirely inside
    // `padded`, so the search loops below need no clamping.
    let padsize = search_radius + template_radius;
    let padded = pad_reflect(input, padsize);
    let patch_area = (template_size * template_size) as f32;

    let mut output = Array2::zeros((height, width));
    for y in 0..height {
        for x in 0..width {
            let (py, px) = (y + padsize, x + padsize);
            let template_patch = padded.slice(s![
                (py - template_radius)..=(py + template_radius),
                (px - template_radius)..=(px + template_radius)
            ]);
            let mut weight_sum = 0.0f32;
            let mut weighted_value_sum = 0.0f32;
            for sy in (py - search_radius)..=(py + search_radius) {
                for sx in (px - search_radius)..=(px + search_radius) {
                    let compare_patch = padded.slice(s![
                        (sy - template_radius)..=(sy + template_radius),
                        (sx - template_radius)..=(sx + template_radius)
                    ]);
                    // Mean squared difference; views iterate in row-major order, so the
                    // summation order matches a row-by-row index loop.
                    let distance = template_patch
                        .iter()
                        .zip(compare_patch.iter())
                        .fold(0.0f32, |acc, (&a, &b)| {
                            let diff = a - b;
                            acc + diff * diff
                        })
                        / patch_area;
                    let weight = (-distance / h_squared).exp();
                    weight_sum += weight;
                    weighted_value_sum += weight * padded[[sy, sx]];
                }
            }
            // The self-patch contributes weight 1, so this only fails for non-finite
            // inputs (NaN weights); keep the original pixel in that case.
            output[[y, x]] = if weight_sum > 0.0 {
                weighted_value_sum / weight_sum
            } else {
                input[[y, x]]
            };
        }
    }
    Ok(output)
}
#[allow(dead_code)]
pub fn nlm_denoise_color(
input: &Array3<f32>,
h: f32,
template_window_size: usize,
search_window_size: usize,
) -> Result<Array3<f32>> {
let (height, width, channels) = input.dim();
if channels != 3 {
return Err(crate::error::VisionError::InvalidParameter(
"Input must be an HxWx3 color image".to_string(),
));
}
let mut output = Array3::zeros((height, width, 3));
for c in 0..3 {
let channel = input.slice(s![.., .., c]).to_owned();
let denoised = nlm_denoise(&channel, h, template_window_size, search_window_size)?;
output.slice_mut(s![.., .., c]).assign(&denoised);
}
Ok(output)
}
#[allow(dead_code)]
/// Parallel counterpart of [`nlm_denoise`]. It performs the same per-pixel computation in the
/// same floating-point order, so it produces the same values.
///
/// Each output pixel is independent, so pixels are evaluated with `par_iter` and the results
/// are collected in order. No shared, locked output buffer is involved.
///
/// # Arguments
/// * `input` - image to denoise.
/// * `h` - filter strength; only `h²` is used. If `h²` is zero or NaN, a copy of `input` is
///   returned (the filter degenerates to the identity).
/// * `template_window_size` - patch side length; even values are rounded up to the next odd.
/// * `search_window_size` - search window side length; even values are rounded up likewise.
///
/// # Errors
/// This function currently never fails; the `Result` return type is kept for API consistency.
pub fn nlm_denoise_parallel(
    input: &Array2<f32>,
    h: f32,
    template_window_size: usize,
    search_window_size: usize,
) -> Result<Array2<f32>> {
    let h_squared = h * h;
    // h² == 0 (or NaN) makes the self-patch weight 0/0 = NaN, so every pixel would fall back
    // to its input value; return that identity result directly.
    if h_squared <= 0.0 || h_squared.is_nan() {
        return Ok(input.to_owned());
    }

    let (height, width) = input.dim();
    // Windows must be odd so they have a well-defined centre pixel.
    let template_size = if template_window_size.is_multiple_of(2) {
        template_window_size + 1
    } else {
        template_window_size
    };
    let search_size = if search_window_size.is_multiple_of(2) {
        search_window_size + 1
    } else {
        search_window_size
    };
    let template_radius = template_size / 2;
    let search_radius = search_size / 2;
    // Padding by both radii keeps every candidate patch inside `padded`; no clamping needed.
    let padsize = search_radius + template_radius;
    let padded = pad_reflect(input, padsize);
    let patch_area = (template_size * template_size) as f32;

    // Row-major pixel order, matching the memory order of the output array.
    let pixels: Vec<(usize, usize)> = (0..height)
        .flat_map(|y| (0..width).map(move |x| (y, x)))
        .collect();
    let values: Vec<f32> = pixels
        .par_iter()
        .map(|&(y, x)| {
            let (py, px) = (y + padsize, x + padsize);
            let template_patch = padded.slice(s![
                (py - template_radius)..=(py + template_radius),
                (px - template_radius)..=(px + template_radius)
            ]);
            let mut weight_sum = 0.0f32;
            let mut weighted_value_sum = 0.0f32;
            for sy in (py - search_radius)..=(py + search_radius) {
                for sx in (px - search_radius)..=(px + search_radius) {
                    let compare_patch = padded.slice(s![
                        (sy - template_radius)..=(sy + template_radius),
                        (sx - template_radius)..=(sx + template_radius)
                    ]);
                    // Mean squared difference, summed in row-major order.
                    let distance = template_patch
                        .iter()
                        .zip(compare_patch.iter())
                        .fold(0.0f32, |acc, (&a, &b)| {
                            let diff = a - b;
                            acc + diff * diff
                        })
                        / patch_area;
                    let weight = (-distance / h_squared).exp();
                    weight_sum += weight;
                    weighted_value_sum += weight * padded[[sy, sx]];
                }
            }
            // Only non-finite inputs can make this fail; keep the original pixel then.
            if weight_sum > 0.0 {
                weighted_value_sum / weight_sum
            } else {
                input[[y, x]]
            }
        })
        .collect();
    // `collect` preserves the order of `pixels`, so `values` is the row-major image.
    Ok(Array2::from_shape_vec((height, width), values)
        .expect("exactly one value is computed per pixel, so len == height * width"))
}
#[allow(dead_code)]
/// Pads `array` by `padsize` on every side using symmetric reflection, which repeats the edge
/// pixel: a row `[a b c]` padded by 2 becomes `[b a | a b c | c b]`.
///
/// Padding wider than the image keeps folding with period `2 * len` along each axis. An image
/// with zero rows or zero columns yields an all-zero padded array of the enlarged shape.
fn pad_reflect(array: &Array2<f32>, padsize: usize) -> Array2<f32> {
    /// Maps an index of the padded axis back onto `0..len` by symmetric reflection.
    /// Requires `len >= 1`.
    fn reflect_index(padded_index: usize, padsize: usize, len: usize) -> usize {
        let period = (len as isize) * 2;
        let folded = (padded_index as isize - padsize as isize).rem_euclid(period);
        if folded < len as isize {
            folded as usize
        } else {
            (period - 1 - folded) as usize
        }
    }

    let (rows, cols) = array.dim();
    let padded_shape = (rows + 2 * padsize, cols + 2 * padsize);
    if rows == 0 || cols == 0 {
        // Nothing to reflect: the padded array stays zero-filled.
        return Array2::zeros(padded_shape);
    }
    Array2::from_shape_fn(padded_shape, |(r, c)| {
        array[[
            reflect_index(r, padsize, rows),
            reflect_index(c, padsize, cols),
        ]]
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Deterministic, non-constant test image.
    fn ramp(height: usize, width: usize) -> Array2<f32> {
        Array2::from_shape_fn((height, width), |(i, j)| ((i * 3 + j * 5) % 7) as f32)
    }

    #[test]
    fn test_pad_reflect() {
        let array =
            Array2::from_shape_vec((3, 3), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0])
                .expect("Operation failed");
        let padded = pad_reflect(&array, 1);
        assert_eq!(padded.dim(), (5, 5));
        for i in 0..3 {
            for j in 0..3 {
                assert_eq!(padded[[i + 1, j + 1]], array[[i, j]]);
            }
        }
        assert_eq!(padded[[0, 0]], padded[[1, 1]]);
        assert_eq!(padded[[0, 1]], padded[[1, 1]]);
        assert_eq!(padded[[1, 0]], padded[[1, 1]]);
    }

    #[test]
    fn test_pad_reflect_full_layout() {
        // Symmetric reflection repeats the edge pixel: [1 2 3] -> [1 | 1 2 3 | 3].
        let array = Array2::<f32>::from_shape_vec(
            (3, 3),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0],
        )
        .expect("valid shape");
        let expected = Array2::<f32>::from_shape_vec(
            (5, 5),
            vec![
                1.0, 1.0, 2.0, 3.0, 3.0, //
                1.0, 1.0, 2.0, 3.0, 3.0, //
                4.0, 4.0, 5.0, 6.0, 6.0, //
                7.0, 7.0, 8.0, 9.0, 9.0, //
                7.0, 7.0, 8.0, 9.0, 9.0,
            ],
        )
        .expect("valid shape");
        assert_eq!(pad_reflect(&array, 1), expected);
    }

    #[test]
    fn test_pad_reflect_pad_larger_than_image() {
        // Padding wider than the image keeps folding with period 2 * len.
        let array =
            Array2::<f32>::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).expect("valid shape");
        let padded = pad_reflect(&array, 3);
        assert_eq!(padded.dim(), (8, 8));
        let source = [1, 1, 0, 0, 1, 1, 0, 0];
        for r in 0..8 {
            for c in 0..8 {
                assert_eq!(padded[[r, c]], array[[source[r], source[c]]]);
            }
        }
    }

    #[test]
    fn test_nlm_denoise_constant() {
        let input = Array2::ones((10, 10));
        let result = nlm_denoise(&input, 0.1, 3, 5).expect("Operation failed");
        for val in result.iter() {
            assert!((val - 1.0).abs() < 0.01);
        }
    }

    #[test]
    fn test_nlm_denoise_dimensions() {
        let input = Array2::zeros((20, 30));
        let result = nlm_denoise(&input, 0.1, 7, 21).expect("Operation failed");
        assert_eq!(result.dim(), (20, 30));
    }

    #[test]
    fn test_nlm_denoise_zero_strength_is_identity() {
        let input = ramp(8, 8);
        let serial = nlm_denoise(&input, 0.0, 3, 5).expect("Operation failed");
        let parallel = nlm_denoise_parallel(&input, 0.0, 3, 5).expect("Operation failed");
        assert_eq!(serial, input);
        assert_eq!(parallel, input);
    }

    #[test]
    fn test_nlm_denoise_even_windows_round_up() {
        let input = ramp(9, 11);
        let even = nlm_denoise(&input, 0.5, 2, 4).expect("Operation failed");
        let odd = nlm_denoise(&input, 0.5, 3, 5).expect("Operation failed");
        assert_eq!(even, odd);
    }

    #[test]
    fn test_nlm_denoise_single_pixel_with_large_windows() {
        // Padding far wider than the image must still be well defined.
        let input = Array2::from_elem((1, 1), 5.0f32);
        let result = nlm_denoise(&input, 0.1, 7, 21).expect("Operation failed");
        assert_eq!(result.dim(), (1, 1));
        assert!((result[[0, 0]] - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_nlm_denoise_empty_image() {
        let input = Array2::<f32>::zeros((0, 4));
        let result = nlm_denoise(&input, 0.1, 3, 5).expect("Operation failed");
        assert_eq!(result.dim(), (0, 4));
        let parallel = nlm_denoise_parallel(&input, 0.1, 3, 5).expect("Operation failed");
        assert_eq!(parallel.dim(), (0, 4));
    }

    #[test]
    fn test_nlm_denoise_color() {
        let input = Array3::ones((10, 10, 3));
        let result = nlm_denoise_color(&input, 0.1, 3, 5).expect("Operation failed");
        assert_eq!(result.dim(), (10, 10, 3));
    }

    #[test]
    fn test_nlm_denoise_parallel() {
        let input = Array2::from_shape_fn((20, 20), |(i, j)| ((i + j) % 2) as f32);
        let serial = nlm_denoise(&input, 0.1, 3, 7).expect("Operation failed");
        let parallel = nlm_denoise_parallel(&input, 0.1, 3, 7).expect("Operation failed");
        for (s, p) in serial.iter().zip(parallel.iter()) {
            assert!((s - p).abs() < 1e-5);
        }
    }

    #[test]
    fn test_invalid_color_channels() {
        let input = Array3::zeros((10, 10, 4));
        let result = nlm_denoise_color(&input, 0.1, 3, 5);
        assert!(result.is_err());
    }
}