#![forbid(unsafe_code)]
#![allow(clippy::too_many_arguments)]
#![allow(clippy::too_many_lines)]
use crate::error::{CvError, CvResult};
use ndarray::Array4;
use oxionnx::Session;
use std::collections::HashMap;
use std::path::Path;
/// Noise strength the denoiser is told to expect.
///
/// [`NoiseLevel::Blind`] is the default and carries no sigma of its own.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub enum NoiseLevel {
    /// Nominal sigma of 10.
    Low,
    /// Nominal sigma of 25.
    Medium,
    /// Nominal sigma of 40.
    High,
    /// Caller-chosen sigma.
    Custom(f32),
    /// No level supplied; [`NoiseLevel::sigma`] reports `0.0`.
    #[default]
    Blind,
}

impl NoiseLevel {
    /// Nominal noise sigma for this level (`0.0` for [`NoiseLevel::Blind`]).
    #[must_use]
    pub fn sigma(&self) -> f32 {
        match *self {
            Self::Custom(value) => value,
            Self::Blind => 0.0,
            Self::Low => 10.0,
            Self::Medium => 25.0,
            Self::High => 40.0,
        }
    }

    /// Returns `true` only for [`NoiseLevel::Blind`].
    #[must_use]
    pub const fn is_blind(&self) -> bool {
        matches!(*self, Self::Blind)
    }
}
/// Settings for `NeuralDenoiser`: tiling geometry and how strongly the model output replaces
/// the input.
#[derive(Debug, Clone)]
pub struct DenoisingConfig {
    /// Expected noise level.
    /// NOTE(review): `NeuralDenoiser` in this file never reads this field.
    pub noise_level: NoiseLevel,
    /// Tile side length in pixels; images larger than this in either dimension are tiled.
    pub tile_size: u32,
    /// Context pixels added on each side of a tile (clipped at image borders) before inference.
    pub tile_padding: u32,
    /// Blend weight of the denoised colour, expected in `[0.0, 1.0]`.
    pub color_strength: f32,
    /// Blend weight of the denoised luminance, expected in `[0.0, 1.0]`.
    pub luma_strength: f32,
}

impl Default for DenoisingConfig {
    /// Blind noise level, 256-pixel tiles with 16 pixels of padding, full-strength blending.
    fn default() -> Self {
        Self {
            tile_size: 256,
            tile_padding: 16,
            luma_strength: 1.0,
            color_strength: 1.0,
            noise_level: NoiseLevel::Blind,
        }
    }
}
impl DenoisingConfig {
    /// Builds a config for `noise_level` with every other field at its default.
    #[must_use]
    pub fn new(noise_level: NoiseLevel) -> Self {
        Self {
            noise_level,
            ..Self::default()
        }
    }

    /// Sets the tile side length. Not range-checked here; see [`Self::validate`].
    #[must_use]
    pub fn with_tile_size(self, tile_size: u32) -> Self {
        Self { tile_size, ..self }
    }

    /// Sets the per-side tile padding. Not range-checked here; see [`Self::validate`].
    #[must_use]
    pub fn with_tile_padding(self, padding: u32) -> Self {
        Self {
            tile_padding: padding,
            ..self
        }
    }

    /// Sets the colour blend weight, clamped to `[0.0, 1.0]`.
    ///
    /// A NaN input stays NaN after clamping and is rejected by [`Self::validate`].
    #[must_use]
    pub fn with_color_strength(self, strength: f32) -> Self {
        Self {
            color_strength: strength.clamp(0.0, 1.0),
            ..self
        }
    }

    /// Sets the luminance blend weight, clamped to `[0.0, 1.0]`.
    ///
    /// A NaN input stays NaN after clamping and is rejected by [`Self::validate`].
    #[must_use]
    pub fn with_luma_strength(self, strength: f32) -> Self {
        Self {
            luma_strength: strength.clamp(0.0, 1.0),
            ..self
        }
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    ///
    /// Returns an invalid-parameter error for the first offending field, checked in this order:
    /// `tile_size` below 64, `tile_padding` above `tile_size / 4`, then `color_strength` and
    /// `luma_strength` outside `[0.0, 1.0]` (NaN included).
    pub fn validate(&self) -> CvResult<()> {
        if self.tile_size < 64 {
            return Err(CvError::invalid_parameter(
                "tile_size",
                format!("{} (must be >= 64)", self.tile_size),
            ));
        }
        if self.tile_padding > self.tile_size / 4 {
            return Err(CvError::invalid_parameter(
                "tile_padding",
                format!("{} (must be <= tile_size / 4)", self.tile_padding),
            ));
        }
        let strengths = [
            ("color_strength", self.color_strength),
            ("luma_strength", self.luma_strength),
        ];
        match strengths
            .into_iter()
            .find(|(_, value)| !(0.0..=1.0).contains(value))
        {
            Some((name, value)) => Err(CvError::invalid_parameter(
                name,
                format!("{value} (must be in [0.0, 1.0])"),
            )),
            None => Ok(()),
        }
    }
}
/// Progress hook for tiled denoising.
///
/// Called before each tile with `(tile_number, total_tiles)`, where `tile_number` counts from 1.
/// Returning `false` aborts processing with an error.
pub type DenoisingProgressCallback = Box<dyn Fn(usize, usize) -> bool + Send + Sync + 'static>;

/// ONNX-backed image denoiser that tiles inputs larger than the configured tile size.
pub struct NeuralDenoiser {
    /// Loaded inference session.
    session: Session,
    /// Active settings; replaced only through the validating constructors/setters.
    config: DenoisingConfig,
    /// Optional hook consulted before each tile of a tiled run.
    progress_callback: Option<DenoisingProgressCallback>,
}
impl NeuralDenoiser {
    /// Loads the ONNX model at `model_path` (all graph optimisations enabled) with the default
    /// [`DenoisingConfig`].
    ///
    /// # Errors
    ///
    /// Returns a model-load error when the session cannot be built from `model_path`.
    pub fn new(model_path: impl AsRef<Path>) -> CvResult<Self> {
        let session = Session::builder()
            .with_optimization_level(oxionnx::OptLevel::All)
            .load(model_path.as_ref())
            .map_err(|e| CvError::model_load(format!("Failed to load model: {e}")))?;
        Ok(Self {
            session,
            config: DenoisingConfig::default(),
            progress_callback: None,
        })
    }

    /// Loads the model at `model_path` and applies `config`.
    ///
    /// `config` is validated before the model is loaded, so an invalid config fails fast.
    ///
    /// # Errors
    ///
    /// Returns the error from [`DenoisingConfig::validate`] or from [`Self::new`].
    pub fn with_config(model_path: impl AsRef<Path>, config: DenoisingConfig) -> CvResult<Self> {
        config.validate()?;
        let mut denoiser = Self::new(model_path)?;
        denoiser.config = config;
        Ok(denoiser)
    }

    /// Replaces the current configuration after validating it.
    ///
    /// # Errors
    ///
    /// Returns the error from [`DenoisingConfig::validate`]; the previous config is kept.
    pub fn set_config(&mut self, config: DenoisingConfig) -> CvResult<()> {
        config.validate()?;
        self.config = config;
        Ok(())
    }

    /// Installs a progress callback for tiled processing.
    ///
    /// The callback receives `(tile_number, total_tiles)` with a 1-based `tile_number` before
    /// each tile; returning `false` aborts the run. It is not called when the whole image fits
    /// in a single tile.
    pub fn set_progress_callback(&mut self, callback: DenoisingProgressCallback) {
        self.progress_callback = Some(callback);
    }

    /// Denoises a packed, row-major RGB8 image of `width * height * 3` bytes.
    ///
    /// Images no larger than `tile_size` in both dimensions go through the model in one pass;
    /// larger images are processed tile by tile and stitched back together.
    ///
    /// # Errors
    ///
    /// - invalid dimensions when `width` or `height` is zero, or the byte size overflows `usize`;
    /// - insufficient data when `image.len() != width * height * 3`;
    /// - model/runtime errors from inference or an unexpected output shape;
    /// - an abort error when the progress callback returns `false`.
    pub fn denoise(&mut self, image: &[u8], width: u32, height: u32) -> CvResult<Vec<u8>> {
        if width == 0 || height == 0 {
            return Err(CvError::invalid_dimensions(width, height));
        }
        // Checked: a wrapped product could otherwise spuriously match `image.len()`.
        let expected_size = (width as usize)
            .checked_mul(height as usize)
            .and_then(|pixels| pixels.checked_mul(3))
            .ok_or_else(|| CvError::invalid_dimensions(width, height))?;
        if image.len() != expected_size {
            return Err(CvError::insufficient_data(expected_size, image.len()));
        }
        let tile_size = self.config.tile_size;
        if width <= tile_size && height <= tile_size {
            self.denoise_single_tile(image, width, height)
        } else {
            self.denoise_tiled(image, width, height)
        }
    }

    /// Runs the model once over the whole `image`, then blends the result with the original
    /// according to the configured luma/colour strengths.
    fn denoise_single_tile(&mut self, image: &[u8], width: u32, height: u32) -> CvResult<Vec<u8>> {
        let input_tensor = self.preprocess_image(image, width, height)?;
        let flat: Vec<f32> = input_tensor.iter().copied().collect();
        let shape: Vec<usize> = input_tensor.shape().to_vec();
        let tensor = oxionnx::Tensor::new(flat, shape);
        // Fall back to a conventional name if the model does not report its inputs.
        let input_name = self
            .session
            .input_names()
            .first()
            .cloned()
            .unwrap_or_else(|| "input".to_string());
        let mut inputs = HashMap::new();
        inputs.insert(input_name.as_str(), tensor);
        let outputs = self
            .session
            .run(&inputs)
            .map_err(|e| CvError::onnx_runtime(format!("Inference failed: {e}")))?;
        // An empty name (no reported outputs) makes the lookup below fail with an error.
        let output_name = self
            .session
            .output_names()
            .first()
            .cloned()
            .unwrap_or_default();
        let out_tensor = outputs
            .get(&output_name)
            .ok_or_else(|| CvError::onnx_runtime("No output tensor found".to_owned()))?;
        let shape_owned: Vec<i64> = out_tensor.shape.iter().map(|&x| x as i64).collect();
        let data_owned: Vec<f32> = out_tensor.data.clone();
        let denoised = self.postprocess_tensor(&shape_owned, &data_owned, width, height)?;
        self.blend_with_original(image, &denoised, width, height)
    }

    /// Denoises an image larger than one tile.
    ///
    /// The image is partitioned into `tile_size` x `tile_size` core regions (smaller along the
    /// right/bottom edges). Each core is grown by `tile_padding` context pixels per side
    /// (clipped to the image), denoised, and only the core is copied to the output, so every
    /// output pixel comes from exactly one tile.
    fn denoise_tiled(&mut self, image: &[u8], width: u32, height: u32) -> CvResult<Vec<u8>> {
        let tile_size = self.config.tile_size;
        let padding = self.config.tile_padding;
        let tiles_x = width.div_ceil(tile_size);
        let tiles_y = height.div_ceil(tile_size);
        let total_tiles = tiles_x as usize * tiles_y as usize;
        // usize arithmetic: `width * height * 3` can exceed `u32::MAX` for large images.
        let mut output = vec![0u8; width as usize * height as usize * 3];
        for ty in 0..tiles_y {
            for tx in 0..tiles_x {
                let tile_number = ty as usize * tiles_x as usize + tx as usize + 1;
                if let Some(callback) = &self.progress_callback {
                    if !callback(tile_number, total_tiles) {
                        return Err(CvError::detection_failed("Processing aborted by user"));
                    }
                }
                // Core region owned by this tile, in image coordinates.
                let core_x0 = tx * tile_size;
                let core_y0 = ty * tile_size;
                let core_x1 = core_x0.saturating_add(tile_size).min(width);
                let core_y1 = core_y0.saturating_add(tile_size).min(height);
                // Padded region fed to the model. The crop back to the core is derived from the
                // core coordinates, not from `tile_w - padding`: when the last core on an axis is
                // narrower than `padding`, the previous tile's far-side padding is clipped by the
                // image edge, and trimming a full `padding` would leave an unwritten strip.
                let x_start = core_x0.saturating_sub(padding);
                let y_start = core_y0.saturating_sub(padding);
                let x_end = core_x1.saturating_add(padding).min(width);
                let y_end = core_y1.saturating_add(padding).min(height);
                let tile_w = x_end - x_start;
                let tile_h = y_end - y_start;
                let tile =
                    self.extract_tile(image, width, height, x_start, y_start, tile_w, tile_h)?;
                let denoised_tile = self.denoise_single_tile(&tile, tile_w, tile_h)?;
                Self::copy_core_region(
                    &denoised_tile,
                    tile_w,
                    (x_start, y_start),
                    core_x0..core_x1,
                    core_y0..core_y1,
                    &mut output,
                    width,
                );
            }
        }
        Ok(output)
    }

    /// Copies the `tile_w` x `tile_h` window whose top-left corner is (`x`, `y`) out of a
    /// packed RGB8 image of width `src_w`.
    ///
    /// # Errors
    ///
    /// Returns an invalid-ROI error when the window extends past the image.
    fn extract_tile(
        &self,
        src: &[u8],
        src_w: u32,
        src_h: u32,
        x: u32,
        y: u32,
        tile_w: u32,
        tile_h: u32,
    ) -> CvResult<Vec<u8>> {
        // u64 so the bounds check itself cannot overflow.
        if u64::from(x) + u64::from(tile_w) > u64::from(src_w)
            || u64::from(y) + u64::from(tile_h) > u64::from(src_h)
        {
            return Err(CvError::invalid_roi(x, y, tile_w, tile_h));
        }
        let row_bytes = tile_w as usize * 3;
        let src_w = src_w as usize;
        let mut tile = Vec::with_capacity(row_bytes * tile_h as usize);
        for row in y as usize..(y + tile_h) as usize {
            let start = (row * src_w + x as usize) * 3;
            tile.extend_from_slice(&src[start..start + row_bytes]);
        }
        Ok(tile)
    }

    /// Writes the part of a denoised tile that lies inside `core_x` x `core_y` (image
    /// coordinates) into `output`, a packed RGB8 buffer of width `out_w`.
    ///
    /// `tile_origin` is the image position of the tile's top-left pixel, and the tile must
    /// contain the whole core region.
    fn copy_core_region(
        tile: &[u8],
        tile_w: u32,
        tile_origin: (u32, u32),
        core_x: std::ops::Range<u32>,
        core_y: std::ops::Range<u32>,
        output: &mut [u8],
        out_w: u32,
    ) {
        let (tile_x, tile_y) = tile_origin;
        let row_bytes = (core_x.end - core_x.start) as usize * 3;
        let tile_w = tile_w as usize;
        let out_w = out_w as usize;
        let src_col = (core_x.start - tile_x) as usize;
        let dst_col = core_x.start as usize;
        for gy in core_y {
            let src = ((gy - tile_y) as usize * tile_w + src_col) * 3;
            let dst = (gy as usize * out_w + dst_col) * 3;
            output[dst..dst + row_bytes].copy_from_slice(&tile[src..src + row_bytes]);
        }
    }

    /// Mixes `denoised` back toward `original` per pixel.
    ///
    /// Luminance (mean of R, G, B) is interpolated from the original toward the denoised value by
    /// `luma_strength`. Each channel is then a `color_strength` mix between the original rescaled
    /// to that luminance and the denoised channel. When both strengths are at least 0.99 the
    /// denoised image is returned unchanged.
    fn blend_with_original(
        &self,
        original: &[u8],
        denoised: &[u8],
        width: u32,
        height: u32,
    ) -> CvResult<Vec<u8>> {
        let luma_strength = self.config.luma_strength;
        let color_strength = self.config.color_strength;
        if luma_strength >= 0.99 && color_strength >= 0.99 {
            return Ok(denoised.to_vec());
        }
        let pixel_count = width as usize * height as usize;
        let mut result = Vec::with_capacity(pixel_count * 3);
        for (orig, den) in original
            .chunks_exact(3)
            .zip(denoised.chunks_exact(3))
            .take(pixel_count)
        {
            let y_orig = (f32::from(orig[0]) + f32::from(orig[1]) + f32::from(orig[2])) / 3.0;
            let y_denoised = (f32::from(den[0]) + f32::from(den[1]) + f32::from(den[2])) / 3.0;
            let y_blend = y_orig + (y_denoised - y_orig) * luma_strength;
            // Black pixels have no luminance to rescale; leave them as-is.
            let luma_scale = if y_orig > 0.0 { y_blend / y_orig } else { 1.0 };
            for (&o, &d) in orig.iter().zip(den) {
                let mixed = f32::from(o) * luma_scale * (1.0 - color_strength)
                    + f32::from(d) * color_strength;
                // Round to nearest like `postprocess_tensor`; truncation biases channels downward.
                result.push(mixed.round().clamp(0.0, 255.0) as u8);
            }
        }
        Ok(result)
    }

    /// Converts packed RGB8 into a `[1, 3, height, width]` (NCHW) tensor scaled to `[0, 1]`.
    fn preprocess_image(&self, image: &[u8], width: u32, height: u32) -> CvResult<Array4<f32>> {
        let w = width as usize;
        let h = height as usize;
        let mut tensor = Array4::<f32>::zeros((1, 3, h, w));
        for y in 0..h {
            for x in 0..w {
                let idx = (y * w + x) * 3;
                for c in 0..3 {
                    tensor[[0, c, y, x]] = f32::from(image[idx + c]) / 255.0;
                }
            }
        }
        Ok(tensor)
    }

    /// Converts a flat `[1, 3, height, width]` (NCHW) float tensor back to packed RGB8.
    ///
    /// Values are scaled by 255, clamped and rounded, so `[0, 1]` maps onto the byte range.
    ///
    /// # Errors
    ///
    /// Returns a shape mismatch when `shape` is not `[1, 3, height, width]`, and an
    /// insufficient-data error when `data` holds fewer values than that shape implies.
    fn postprocess_tensor(
        &self,
        shape: &[i64],
        data: &[f32],
        width: u32,
        height: u32,
    ) -> CvResult<Vec<u8>> {
        let shape_error = || CvError::ShapeMismatch {
            expected: vec![1, 3, height as usize, width as usize],
            actual: shape.iter().map(|&x| x as usize).collect(),
        };
        if shape.len() != 4 || shape[0] != 1 || shape[1] != 3 {
            return Err(shape_error());
        }
        let h = shape[2] as usize;
        let w = shape[3] as usize;
        if w != width as usize || h != height as usize {
            return Err(shape_error());
        }
        let plane = h * w;
        // A data buffer shorter than its declared shape must not cause an out-of-bounds panic.
        if data.len() < plane * 3 {
            return Err(CvError::insufficient_data(plane * 3, data.len()));
        }
        let (red, rest) = data.split_at(plane);
        let (green, blue) = rest.split_at(plane);
        let to_byte = |v: f32| (v * 255.0).clamp(0.0, 255.0).round() as u8;
        let mut output = Vec::with_capacity(plane * 3);
        for ((&r, &g), &b) in red.iter().zip(green).zip(blue) {
            output.extend_from_slice(&[to_byte(r), to_byte(g), to_byte(b)]);
        }
        Ok(output)
    }

    /// Returns the active configuration.
    #[must_use]
    pub const fn config(&self) -> &DenoisingConfig {
        &self.config
    }
}
/// Model-free estimators of noise strength in packed RGB8 images.
pub mod noise_estimation {
    use super::NoiseLevel;

    /// Standard-deviation gain of the 3x3 Laplacian used below (centre 8, neighbours -1) on
    /// i.i.d. noise: `sqrt(8^2 + 8 * 1^2) = sqrt(72) = 6 * sqrt(2)`.
    const LAPLACIAN_NOISE_GAIN: f32 = 6.0 * std::f32::consts::SQRT_2;

    /// Median absolute value of a zero-mean Gaussian, in units of its standard deviation.
    const GAUSSIAN_MAD: f32 = 0.6745;

    /// Estimates the noise standard deviation of a packed RGB8 image, in 8-bit pixel units.
    ///
    /// Pixels are reduced to luma (mean of the three channels) and filtered with a 3x3
    /// Laplacian, which is zero on flat and linearly varying regions. The median absolute
    /// response is converted to a sigma via the Gaussian MAD ratio and divided by the filter's
    /// noise gain. Returns `0.0` for images narrower or shorter than 3 pixels.
    ///
    /// NOTE: for noise independent across channels, averaging to luma lowers the measured sigma
    /// by about `sqrt(3)`.
    ///
    /// # Panics
    ///
    /// Panics if `image` holds fewer than `width * height * 3` bytes.
    #[must_use]
    pub fn estimate_noise_mad(image: &[u8], width: u32, height: u32) -> f32 {
        if width < 3 || height < 3 {
            return 0.0;
        }
        let w = width as usize;
        let h = height as usize;
        // Luma computed once per pixel instead of once per kernel tap.
        let luma: Vec<f32> = image[..w * h * 3]
            .chunks_exact(3)
            .map(|px| (f32::from(px[0]) + f32::from(px[1]) + f32::from(px[2])) / 3.0)
            .collect();
        let mut responses = Vec::with_capacity((w - 2) * (h - 2));
        for y in 1..h - 1 {
            for x in 1..w - 1 {
                let mut lap = 0.0f32;
                for ny in y - 1..=y + 1 {
                    for nx in x - 1..=x + 1 {
                        let weight = if ny == y && nx == x { 8.0 } else { -1.0 };
                        lap += luma[ny * w + nx] * weight;
                    }
                }
                responses.push(lap.abs());
            }
        }
        // `w, h >= 3` guarantees at least one response. The upper median (index `len / 2`) is
        // selected in O(n) instead of sorting everything.
        let mid = responses.len() / 2;
        let (_, median, _) = responses.select_nth_unstable_by(mid, f32::total_cmp);
        *median / GAUSSIAN_MAD / LAPLACIAN_NOISE_GAIN
    }

    /// Maps a sigma to a [`NoiseLevel`]: below 15 is `Low`, below 30 is `Medium`, otherwise
    /// `High` (NaN also yields `High`).
    #[must_use]
    pub fn classify_noise_level(sigma: f32) -> NoiseLevel {
        if sigma < 15.0 {
            NoiseLevel::Low
        } else if sigma < 30.0 {
            NoiseLevel::Medium
        } else {
            NoiseLevel::High
        }
    }

    /// Runs [`estimate_noise_mad`] on the centred square patch of side
    /// `min(patch_size, width, height)`.
    ///
    /// # Panics
    ///
    /// Panics if `image` holds fewer than `width * height * 3` bytes.
    #[must_use]
    pub fn estimate_noise_patch(image: &[u8], width: u32, height: u32, patch_size: u32) -> f32 {
        let side = patch_size.min(width).min(height);
        let x0 = ((width - side) / 2) as usize;
        let y0 = ((height - side) / 2) as usize;
        let w = width as usize;
        let row_bytes = side as usize * 3;
        let mut patch = Vec::with_capacity(row_bytes * side as usize);
        for y in y0..y0 + side as usize {
            let start = (y * w + x0) * 3;
            patch.extend_from_slice(&image[start..start + row_bytes]);
        }
        estimate_noise_mad(&patch, side, side)
    }
}
/// Runs one `NeuralDenoiser` over several images in sequence.
pub struct BatchDenoiser {
    denoiser: NeuralDenoiser,
    /// Reported by [`BatchDenoiser::batch_size`]; `denoise_batch` still processes images one
    /// at a time.
    batch_size: usize,
}

impl BatchDenoiser {
    /// Loads the model at `model_path` with the default configuration and records `batch_size`.
    ///
    /// # Errors
    ///
    /// Returns the model-load error from `NeuralDenoiser::new`.
    pub fn new(model_path: impl AsRef<Path>, batch_size: usize) -> CvResult<Self> {
        NeuralDenoiser::new(model_path).map(|denoiser| Self {
            denoiser,
            batch_size,
        })
    }

    /// Denoises each `(image, width, height)` entry in order.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first per-image error; earlier results are discarded.
    pub fn denoise_batch(&mut self, images: &[(&[u8], u32, u32)]) -> CvResult<Vec<Vec<u8>>> {
        images
            .iter()
            .map(|&(image, width, height)| self.denoiser.denoise(image, width, height))
            .collect()
    }

    /// The batch size given at construction.
    #[must_use]
    pub const fn batch_size(&self) -> usize {
        self.batch_size
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a packed RGB8 image whose three channels at (x, y) all equal `value(x, y)`.
    fn gray_image(width: u32, height: u32, value: impl Fn(u32, u32) -> u8) -> Vec<u8> {
        let mut image = Vec::with_capacity((width * height * 3) as usize);
        for y in 0..height {
            for x in 0..width {
                let v = value(x, y);
                image.extend_from_slice(&[v, v, v]);
            }
        }
        image
    }

    /// Checkerboard of `100 +/- amplitude`: a deterministic stand-in for high-frequency noise.
    fn checkerboard(width: u32, height: u32, amplitude: u8) -> Vec<u8> {
        gray_image(width, height, |x, y| {
            if (x + y) % 2 == 0 {
                100 + amplitude
            } else {
                100 - amplitude
            }
        })
    }

    #[test]
    fn test_noise_level_sigma() {
        assert_eq!(NoiseLevel::Low.sigma(), 10.0);
        assert_eq!(NoiseLevel::Medium.sigma(), 25.0);
        assert_eq!(NoiseLevel::High.sigma(), 40.0);
        assert_eq!(NoiseLevel::Custom(15.0).sigma(), 15.0);
        assert_eq!(NoiseLevel::Blind.sigma(), 0.0);
    }

    #[test]
    fn test_noise_level_is_blind() {
        assert!(!NoiseLevel::Low.is_blind());
        assert!(!NoiseLevel::Custom(0.0).is_blind());
        assert!(NoiseLevel::Blind.is_blind());
        assert!(NoiseLevel::default().is_blind());
    }

    #[test]
    fn test_denoising_config_default() {
        let config = DenoisingConfig::default();
        assert!(config.noise_level.is_blind());
        assert_eq!(config.tile_size, 256);
        assert_eq!(config.tile_padding, 16);
        assert_eq!(config.color_strength, 1.0);
        assert_eq!(config.luma_strength, 1.0);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_denoising_config_builder() {
        let config = DenoisingConfig::new(NoiseLevel::Medium)
            .with_tile_size(512)
            .with_tile_padding(32)
            .with_color_strength(0.8)
            .with_luma_strength(0.9);
        assert_eq!(config.tile_size, 512);
        assert_eq!(config.tile_padding, 32);
        assert_eq!(config.color_strength, 0.8);
        assert_eq!(config.luma_strength, 0.9);
    }

    #[test]
    fn test_strength_builders_clamp_to_unit_range() {
        let config = DenoisingConfig::default()
            .with_color_strength(1.5)
            .with_luma_strength(-0.5);
        assert_eq!(config.color_strength, 1.0);
        assert_eq!(config.luma_strength, 0.0);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_denoising_config_validation() {
        let config = DenoisingConfig::new(NoiseLevel::Low).with_tile_size(32);
        assert!(config.validate().is_err());
        let config = DenoisingConfig::new(NoiseLevel::Low).with_tile_padding(100);
        assert!(config.validate().is_err());
        let config = DenoisingConfig::default();
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_denoising_config_validation_boundaries() {
        // Padding may be exactly a quarter of the tile, but not more.
        let config = DenoisingConfig::default()
            .with_tile_size(64)
            .with_tile_padding(16);
        assert!(config.validate().is_ok());
        let config = DenoisingConfig::default()
            .with_tile_size(64)
            .with_tile_padding(17);
        assert!(config.validate().is_err());
        // Out-of-range strengths set directly (bypassing the clamping builders) are rejected.
        let config = DenoisingConfig {
            color_strength: 1.5,
            ..DenoisingConfig::default()
        };
        assert!(config.validate().is_err());
        // NaN survives `clamp`, so validation has to catch it.
        let config = DenoisingConfig::default().with_luma_strength(f32::NAN);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_estimate_noise_mad() {
        let image = vec![100u8; 256 * 256 * 3];
        let sigma = noise_estimation::estimate_noise_mad(&image, 256, 256);
        assert!(sigma < 1.0);
    }

    #[test]
    fn test_estimate_noise_mad_too_small() {
        let image = checkerboard(2, 2, 50);
        assert_eq!(noise_estimation::estimate_noise_mad(&image, 2, 2), 0.0);
    }

    #[test]
    fn test_estimate_noise_mad_scales_with_amplitude() {
        let weak = noise_estimation::estimate_noise_mad(&checkerboard(16, 16, 10), 16, 16);
        let strong = noise_estimation::estimate_noise_mad(&checkerboard(16, 16, 20), 16, 16);
        assert!(weak > 0.0);
        assert!((strong / weak - 2.0).abs() < 1e-3);
    }

    #[test]
    fn test_classify_noise_level() {
        assert!(matches!(
            noise_estimation::classify_noise_level(5.0),
            NoiseLevel::Low
        ));
        assert!(matches!(
            noise_estimation::classify_noise_level(20.0),
            NoiseLevel::Medium
        ));
        assert!(matches!(
            noise_estimation::classify_noise_level(35.0),
            NoiseLevel::High
        ));
        // Thresholds are exclusive upper bounds.
        assert!(matches!(
            noise_estimation::classify_noise_level(15.0),
            NoiseLevel::Medium
        ));
        assert!(matches!(
            noise_estimation::classify_noise_level(30.0),
            NoiseLevel::High
        ));
    }

    #[test]
    fn test_estimate_noise_patch() {
        let image = vec![100u8; 512 * 512 * 3];
        let sigma = noise_estimation::estimate_noise_patch(&image, 512, 512, 128);
        assert!(sigma < 1.0);
    }

    #[test]
    fn test_estimate_noise_patch_clamps_oversized_patch() {
        let image = gray_image(16, 8, |_, _| 42);
        let sigma = noise_estimation::estimate_noise_patch(&image, 16, 8, 1000);
        assert_eq!(sigma, 0.0);
    }

    #[test]
    fn test_preprocess_postprocess_roundtrip() {
        let width: u32 = 8;
        let height: u32 = 8;
        let w = width as usize;
        let h = height as usize;
        let input: Vec<u8> = (0..(w * h * 3)).map(|i| (i % 256) as u8).collect();
        let mut tensor = Array4::<f32>::zeros((1, 3, h, w));
        for y in 0..h {
            for x in 0..w {
                let idx = (y * w + x) * 3;
                for c in 0..3 {
                    tensor[[0, c, y, x]] = f32::from(input[idx + c]) / 255.0;
                }
            }
        }
        assert_eq!(tensor.shape(), &[1, 3, h, w]);
        let shape_i64: Vec<i64> = tensor.shape().iter().map(|&x| x as i64).collect();
        let data_f32: Vec<f32> = tensor.iter().copied().collect();
        assert_eq!(shape_i64.len(), 4);
        assert_eq!(shape_i64[0], 1);
        assert_eq!(shape_i64[1], 3);
        let out_h = shape_i64[2] as usize;
        let out_w = shape_i64[3] as usize;
        let plane = out_h * out_w;
        let mut output = vec![0u8; plane * 3];
        for y in 0..out_h {
            for x in 0..out_w {
                let idx = (y * out_w + x) * 3;
                let r_idx = y * out_w + x;
                let g_idx = plane + r_idx;
                let b_idx = 2 * plane + r_idx;
                output[idx] = (data_f32[r_idx] * 255.0).clamp(0.0, 255.0).round() as u8;
                output[idx + 1] = (data_f32[g_idx] * 255.0).clamp(0.0, 255.0).round() as u8;
                output[idx + 2] = (data_f32[b_idx] * 255.0).clamp(0.0, 255.0).round() as u8;
            }
        }
        assert_eq!(output.len(), input.len());
        for (a, b) in input.iter().zip(output.iter()) {
            assert!(
                (i32::from(*a) - i32::from(*b)).abs() <= 1,
                "Values differ: {a} vs {b}"
            );
        }
    }

    #[test]
    fn test_extract_tile() {
        let width: u32 = 10;
        let height: u32 = 10;
        let image: Vec<u8> = (0..(width * height * 3) as usize)
            .map(|i| (i % 256) as u8)
            .collect();
        let (x, y, tile_w, tile_h) = (2u32, 2u32, 4u32, 4u32);
        assert!(x + tile_w <= width && y + tile_h <= height);
        let row_bytes = (tile_w * 3) as usize;
        let mut tile = Vec::with_capacity(row_bytes * tile_h as usize);
        for row in y..y + tile_h {
            let start = ((row * width + x) * 3) as usize;
            tile.extend_from_slice(&image[start..start + row_bytes]);
        }
        assert_eq!(tile.len(), 4 * 4 * 3);
        // The first tile pixel is source pixel (x, y).
        assert_eq!(tile[..3], image[((y * width + x) * 3) as usize..][..3]);
    }

    /// Stand-alone copy of the per-pixel luma/colour blend formula (the real method needs a
    /// loaded model); checks every channel lands between the original and denoised values.
    #[test]
    fn test_blend_with_original() {
        let original = [100u8, 100, 100, 200, 200, 200];
        let denoised = [50u8, 50, 50, 150, 150, 150];
        let luma_strength: f32 = 0.5;
        let color_strength: f32 = 0.5;
        let mut result = Vec::with_capacity(denoised.len());
        for (orig, den) in original.chunks_exact(3).zip(denoised.chunks_exact(3)) {
            let y_orig = orig.iter().map(|&v| f32::from(v)).sum::<f32>() / 3.0;
            let y_denoised = den.iter().map(|&v| f32::from(v)).sum::<f32>() / 3.0;
            let y_blend = y_orig + (y_denoised - y_orig) * luma_strength;
            let luma_scale = if y_orig > 0.0 { y_blend / y_orig } else { 1.0 };
            for (&o, &d) in orig.iter().zip(den) {
                let mixed = f32::from(o) * luma_scale * (1.0 - color_strength)
                    + f32::from(d) * color_strength;
                result.push(mixed.clamp(0.0, 255.0) as u8);
            }
        }
        assert_eq!(result.len(), 6);
        for ((&r, &d), &o) in result.iter().zip(&denoised).zip(&original) {
            assert!(r >= d.min(o));
            assert!(r <= d.max(o));
        }
    }

    #[test]
    fn test_batch_denoiser() {
        // Constructing a `BatchDenoiser` requires a model file, so only the public API shape is
        // checked here (at compile time).
        let _batch_size: fn(&BatchDenoiser) -> usize = BatchDenoiser::batch_size;
        let _denoise: fn(&mut BatchDenoiser, &[(&[u8], u32, u32)]) -> CvResult<Vec<Vec<u8>>> =
            BatchDenoiser::denoise_batch;
    }
}