1use crate::{GpuDevice, GpuError, Result};
15use bytemuck::{Pod, Zeroable};
16use rayon::prelude::*;
17
/// Which denoising filter to run, together with that filter's tuning knobs.
///
/// Every variant holds only plain numbers, so values are cheap to copy and
/// can be compared directly.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum DenoiseAlgorithm {
    /// Separable Gaussian blur.
    Gaussian {
        /// Standard deviation of the blur, in pixels. The CPU implementation
        /// uses a kernel radius of `ceil(3 * sigma)`.
        sigma: f32,
    },

    /// Non-local means: averages pixels whose surrounding patches look alike.
    NonLocalMeans {
        /// Filter strength; larger values make dissimilar patches count more,
        /// i.e. smooth more aggressively.
        h: f32,
        /// Radius of the square patch compared around each pixel.
        patch_radius: u32,
        /// Radius of the square window searched for similar patches.
        search_radius: u32,
    },

    /// Edge-preserving bilateral filter.
    BilateralFilter {
        /// Standard deviation of the spatial (distance) weight, in pixels.
        sigma_spatial: f32,
        /// Standard deviation of the range (intensity-difference) weight, in
        /// 8-bit intensity units.
        sigma_range: f32,
    },
}
61
/// Packed parameter block for a Gaussian denoise pass.
///
/// `#[repr(C)]` together with `bytemuck::Pod` pins the field order and lets the
/// value be reinterpreted as raw bytes, so fields must not be reordered. All
/// fields are 4 bytes wide, giving a 32-byte struct with no implicit padding.
///
/// NOTE(review): not constructed anywhere in the visible part of this file (the
/// CPU path takes plain arguments); presumably it mirrors a GPU shader's
/// parameter layout — confirm against the shader before changing it.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct GaussianDenoiseParams {
    width: u32,
    height: u32,
    kernel_radius: u32,
    // Explicit padding: keeps the f32 fields starting at byte offset 16.
    _pad: u32,
    sigma: f32,
    // Presumably the precomputed 1 / (2 * sigma^2) — TODO confirm with the producer.
    inv_two_sigma_sq: f32,
    // Rounds the struct size up to 32 bytes.
    _pad2: [f32; 2],
}
77
/// Packed parameter block for a bilateral-filter pass.
///
/// `#[repr(C)]` together with `bytemuck::Pod` pins the field order and lets the
/// value be reinterpreted as raw bytes, so fields must not be reordered. Eight
/// 4-byte fields give a 32-byte struct with no implicit padding.
///
/// NOTE(review): not constructed anywhere in the visible part of this file (the
/// CPU path takes plain arguments); presumably it mirrors a GPU shader's
/// parameter layout — confirm against the shader before changing it.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct BilateralParams {
    width: u32,
    height: u32,
    kernel_radius: u32,
    // Explicit padding: keeps the f32 fields starting at byte offset 16.
    _pad: u32,
    sigma_spatial: f32,
    sigma_range: f32,
    // Presumably the precomputed 1 / (2 * sigma^2) terms, matching what the CPU
    // bilateral filter computes locally — TODO confirm with the producer.
    inv_two_sigma_s_sq: f32,
    inv_two_sigma_r_sq: f32,
}
90
/// Packed parameter block for a non-local-means pass.
///
/// `#[repr(C)]` together with `bytemuck::Pod` pins the field order and lets the
/// value be reinterpreted as raw bytes, so fields must not be reordered. All
/// fields are 4 bytes wide, giving a 32-byte struct with no implicit padding.
///
/// NOTE(review): not constructed anywhere in the visible part of this file (the
/// CPU path takes plain arguments); presumably it mirrors a GPU shader's
/// parameter layout — confirm against the shader before changing it.
#[repr(C)]
#[derive(Clone, Copy, Pod, Zeroable)]
struct NlmParams {
    width: u32,
    height: u32,
    patch_radius: u32,
    search_radius: u32,
    // Presumably h * h, as the CPU NLM path computes — TODO confirm.
    h_sq: f32,
    // Presumably 1 / (2 * patch_radius + 1)^2 — TODO confirm.
    inv_patch_area: f32,
    // Rounds the struct size up to 32 bytes.
    _pad: [f32; 2],
}
102
/// Stateless entry point for the denoising filters.
///
/// Has no fields; all functionality is exposed through associated functions
/// such as `DenoiseOperation::denoise` and `DenoiseOperation::auto_denoise`.
pub struct DenoiseOperation;
116
117impl DenoiseOperation {
118 pub fn denoise(
127 device: &GpuDevice,
128 input: &[u8],
129 output: &mut [u8],
130 width: u32,
131 height: u32,
132 algorithm: DenoiseAlgorithm,
133 ) -> Result<()> {
134 super::utils::validate_dimensions(width, height)?;
135 super::utils::validate_buffer_size(input, width, height, 4)?;
136 super::utils::validate_buffer_size(output, width, height, 4)?;
137
138 match algorithm {
139 DenoiseAlgorithm::Gaussian { sigma } => {
140 Self::denoise_gaussian_cpu(input, output, width, height, sigma)
141 }
142 DenoiseAlgorithm::BilateralFilter {
143 sigma_spatial,
144 sigma_range,
145 } => Self::denoise_bilateral_cpu(
146 input,
147 output,
148 width,
149 height,
150 sigma_spatial,
151 sigma_range,
152 ),
153 DenoiseAlgorithm::NonLocalMeans {
154 h,
155 patch_radius,
156 search_radius,
157 } => {
158 Self::denoise_nlm_cpu(input, output, width, height, h, patch_radius, search_radius)
159 }
160 }
161 .map(|()| {
163 let _ = device;
164 })
165 }
166
167 #[allow(clippy::cast_possible_truncation)]
174 fn denoise_gaussian_cpu(
175 input: &[u8],
176 output: &mut [u8],
177 width: u32,
178 height: u32,
179 sigma: f32,
180 ) -> Result<()> {
181 let w = width as usize;
182 let h = height as usize;
183 let radius = (3.0 * sigma).ceil() as usize;
184
185 let kernel_size = 2 * radius + 1;
187 let mut kernel = vec![0.0f32; kernel_size];
188 let two_sigma_sq = 2.0 * sigma * sigma;
189 let mut sum = 0.0f32;
190 for (i, k) in kernel.iter_mut().enumerate() {
191 let x = i as f32 - radius as f32;
192 *k = (-(x * x) / two_sigma_sq).exp();
193 sum += *k;
194 }
195 for k in &mut kernel {
196 *k /= sum;
197 }
198
199 let mut temp = vec![0u8; input.len()];
201 temp.par_chunks_exact_mut(w * 4)
202 .enumerate()
203 .for_each(|(y, row)| {
204 if y >= h {
205 return;
206 }
207 for x in 0..w {
208 for c in 0..4usize {
209 let mut acc = 0.0f32;
210 for (ki, &kv) in kernel.iter().enumerate() {
211 let sx = (x as i64 + ki as i64 - radius as i64).clamp(0, w as i64 - 1)
212 as usize;
213 acc += kv * f32::from(input[(y * w + sx) * 4 + c]);
214 }
215 row[x * 4 + c] = acc.round().clamp(0.0, 255.0) as u8;
216 }
217 }
218 });
219
220 output
222 .par_chunks_exact_mut(4)
223 .enumerate()
224 .for_each(|(i, pixel)| {
225 let x = i % w;
226 let y = i / w;
227 if y >= h {
228 return;
229 }
230 for c in 0..4usize {
231 let mut acc = 0.0f32;
232 for (ki, &kv) in kernel.iter().enumerate() {
233 let sy =
234 (y as i64 + ki as i64 - radius as i64).clamp(0, h as i64 - 1) as usize;
235 acc += kv * f32::from(temp[(sy * w + x) * 4 + c]);
236 }
237 pixel[c] = acc.round().clamp(0.0, 255.0) as u8;
238 }
239 });
240
241 Ok(())
242 }
243
244 #[allow(clippy::cast_possible_truncation)]
246 fn denoise_bilateral_cpu(
247 input: &[u8],
248 output: &mut [u8],
249 width: u32,
250 height: u32,
251 sigma_spatial: f32,
252 sigma_range: f32,
253 ) -> Result<()> {
254 let w = width as usize;
255 let h = height as usize;
256 let radius = (3.0 * sigma_spatial).ceil() as usize;
257 let inv_two_ss_sq = 1.0 / (2.0 * sigma_spatial * sigma_spatial);
258 let inv_two_sr_sq = 1.0 / (2.0 * sigma_range * sigma_range);
259
260 output
261 .par_chunks_exact_mut(4)
262 .enumerate()
263 .for_each(|(i, pixel)| {
264 let x = i % w;
265 let y = i / w;
266 if y >= h {
267 return;
268 }
269
270 let center = [
271 f32::from(input[(y * w + x) * 4]),
272 f32::from(input[(y * w + x) * 4 + 1]),
273 f32::from(input[(y * w + x) * 4 + 2]),
274 f32::from(input[(y * w + x) * 4 + 3]),
275 ];
276
277 let mut acc = [0.0f32; 4];
278 let mut weight_sum = 0.0f32;
279
280 for dy in -(radius as i64)..=(radius as i64) {
281 for dx in -(radius as i64)..=(radius as i64) {
282 let sx = (x as i64 + dx).clamp(0, w as i64 - 1) as usize;
283 let sy = (y as i64 + dy).clamp(0, h as i64 - 1) as usize;
284
285 let spatial_dist_sq = (dx * dx + dy * dy) as f32;
286 let w_spatial = (-spatial_dist_sq * inv_two_ss_sq).exp();
287
288 let neighbor = [
289 f32::from(input[(sy * w + sx) * 4]),
290 f32::from(input[(sy * w + sx) * 4 + 1]),
291 f32::from(input[(sy * w + sx) * 4 + 2]),
292 f32::from(input[(sy * w + sx) * 4 + 3]),
293 ];
294
295 let range_dist_sq = (0..3)
296 .map(|c| (center[c] - neighbor[c]).powi(2))
297 .sum::<f32>();
298 let w_range = (-range_dist_sq * inv_two_sr_sq).exp();
299
300 let w_total = w_spatial * w_range;
301 weight_sum += w_total;
302
303 for c in 0..4 {
304 acc[c] += w_total * neighbor[c];
305 }
306 }
307 }
308
309 if weight_sum > 0.0 {
310 for c in 0..4 {
311 pixel[c] = (acc[c] / weight_sum).round().clamp(0.0, 255.0) as u8;
312 }
313 } else {
314 pixel.copy_from_slice(&input[i * 4..i * 4 + 4]);
315 }
316 });
317
318 Ok(())
319 }
320
321 #[allow(clippy::cast_possible_truncation)]
325 fn denoise_nlm_cpu(
326 input: &[u8],
327 output: &mut [u8],
328 width: u32,
329 height: u32,
330 h: f32,
331 patch_radius: u32,
332 search_radius: u32,
333 ) -> Result<()> {
334 if h <= 0.0 {
335 return Err(GpuError::Internal(
336 "NLM filter strength h must be positive".to_string(),
337 ));
338 }
339
340 let w = width as usize;
341 let ht = height as usize;
342 let pr = patch_radius as usize;
343 let sr = search_radius as usize;
344 let h_sq = h * h;
345 let patch_area = ((2 * pr + 1) * (2 * pr + 1)) as f32;
346 let inv_h_sq_patch = 1.0 / (h_sq * patch_area);
347
348 output
349 .par_chunks_exact_mut(4)
350 .enumerate()
351 .for_each(|(i, pixel)| {
352 let px = i % w;
353 let py = i / w;
354 if py >= ht {
355 return;
356 }
357
358 let mut acc = [0.0f32; 4];
359 let mut weight_sum = 0.0f32;
360
361 for qy in
363 (py as i64 - sr as i64).max(0)..=(py as i64 + sr as i64).min(ht as i64 - 1)
364 {
365 for qx in
366 (px as i64 - sr as i64).max(0)..=(px as i64 + sr as i64).min(w as i64 - 1)
367 {
368 let mut patch_dist_sq = 0.0f32;
370 for ky in -(pr as i64)..=(pr as i64) {
371 for kx in -(pr as i64)..=(pr as i64) {
372 let p_x = (px as i64 + kx).clamp(0, w as i64 - 1) as usize;
373 let p_y = (py as i64 + ky).clamp(0, ht as i64 - 1) as usize;
374 let q_x = (qx + kx).clamp(0, w as i64 - 1) as usize;
375 let q_y = (qy + ky).clamp(0, ht as i64 - 1) as usize;
376
377 let diff = f32::from(input[(p_y * w + p_x) * 4])
379 - f32::from(input[(q_y * w + q_x) * 4]);
380 patch_dist_sq += diff * diff;
381 }
382 }
383
384 let w_nlm = (-patch_dist_sq * inv_h_sq_patch).exp();
385 weight_sum += w_nlm;
386
387 for c in 0..4 {
388 acc[c] +=
389 w_nlm * f32::from(input[(qy as usize * w + qx as usize) * 4 + c]);
390 }
391 }
392 }
393
394 if weight_sum > 0.0 {
395 for c in 0..4 {
396 pixel[c] = (acc[c] / weight_sum).round().clamp(0.0, 255.0) as u8;
397 }
398 } else {
399 pixel.copy_from_slice(&input[i * 4..i * 4 + 4]);
400 }
401 });
402
403 Ok(())
404 }
405
406 #[allow(dead_code)]
408 fn check_sigma(sigma: f32, name: &str) -> Result<()> {
409 if sigma <= 0.0 {
410 Err(GpuError::Internal(format!(
411 "{name} must be positive, got {sigma}"
412 )))
413 } else {
414 Ok(())
415 }
416 }
417
418 pub fn auto_denoise(
427 device: &GpuDevice,
428 input: &[u8],
429 output: &mut [u8],
430 width: u32,
431 height: u32,
432 noise_level: f32,
433 ) -> Result<()> {
434 let sigma = noise_level.clamp(0.0, 1.0) * 3.0 + 0.5;
435 Self::denoise(
436 device,
437 input,
438 output,
439 width,
440 height,
441 DenoiseAlgorithm::Gaussian { sigma },
442 )
443 }
444}
445
446#[derive(Debug, Clone)]
452pub struct DenoiseKernel {
453 algorithm: DenoiseAlgorithm,
454}
455
456impl DenoiseKernel {
457 #[must_use]
459 pub fn new(algorithm: DenoiseAlgorithm) -> Self {
460 Self { algorithm }
461 }
462
463 #[must_use]
465 pub fn gaussian(sigma: f32) -> Self {
466 Self::new(DenoiseAlgorithm::Gaussian { sigma })
467 }
468
469 #[must_use]
471 pub fn bilateral(sigma_spatial: f32, sigma_range: f32) -> Self {
472 Self::new(DenoiseAlgorithm::BilateralFilter {
473 sigma_spatial,
474 sigma_range,
475 })
476 }
477
478 #[must_use]
480 pub fn nlm(h: f32, patch_radius: u32, search_radius: u32) -> Self {
481 Self::new(DenoiseAlgorithm::NonLocalMeans {
482 h,
483 patch_radius,
484 search_radius,
485 })
486 }
487
488 pub fn apply(
494 &self,
495 device: &GpuDevice,
496 input: &[u8],
497 output: &mut [u8],
498 width: u32,
499 height: u32,
500 ) -> Result<()> {
501 DenoiseOperation::denoise(device, input, output, width, height, self.algorithm)
502 }
503
504 #[must_use]
506 pub fn algorithm(&self) -> DenoiseAlgorithm {
507 self.algorithm
508 }
509
510 #[must_use]
512 pub fn estimate_gflops(&self, width: u32, height: u32) -> f64 {
513 let pixels = u64::from(width) * u64::from(height);
514 let ops: u64 = match self.algorithm {
515 DenoiseAlgorithm::Gaussian { sigma } => {
516 let r = (3.0 * sigma).ceil() as u64;
517 let k = 2 * r + 1;
518 pixels * k * 4 * 4 }
520 DenoiseAlgorithm::BilateralFilter { sigma_spatial, .. } => {
521 let r = (3.0 * sigma_spatial).ceil() as u64;
522 let k = (2 * r + 1).pow(2);
523 pixels * k * 12 * 4 }
525 DenoiseAlgorithm::NonLocalMeans {
526 patch_radius,
527 search_radius,
528 ..
529 } => {
530 let pr = u64::from(2 * patch_radius + 1).pow(2);
531 let sr = u64::from(2 * search_radius + 1).pow(2);
532 pixels * sr * pr * 5 }
534 };
535 ops as f64 / 1e9
536 }
537}
538
#[cfg(test)]
mod tests {
    use super::*;

    /// Uniform image where every byte (all four channels) equals `value`.
    fn gray_image(w: u32, h: u32, value: u8) -> Vec<u8> {
        vec![value; (w * h * 4) as usize]
    }

    /// Deterministic high-frequency pattern: byte `i` is `37 * i mod 256`, so
    /// same-channel horizontal neighbours differ by 148 (mod 256).
    fn noisy_image(w: u32, h: u32) -> Vec<u8> {
        (0..(w * h * 4))
            .map(|i| (i as u8).wrapping_mul(37))
            .collect()
    }

    /// Sum of absolute differences between horizontally adjacent pixels,
    /// compared channel by channel within each row of `w` pixels.
    fn horizontal_variation(buf: &[u8], w: usize) -> u64 {
        buf.chunks_exact(w * 4)
            .flat_map(|row| row.iter().zip(&row[4..]))
            .map(|(&a, &b)| u64::from(a.abs_diff(b)))
            .sum()
    }

    #[test]
    fn test_gaussian_denoise_cpu_constant_image() {
        let w = 16u32;
        let h = 16u32;
        let input = gray_image(w, h, 200);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let result = DenoiseOperation::denoise_gaussian_cpu(&input, &mut output, w, h, 1.5);
        assert!(result.is_ok());
        // A normalised blur must leave a constant image unchanged.
        for &v in &output {
            assert_eq!(v, 200);
        }
    }

    #[test]
    fn test_gaussian_denoise_cpu_noisy() {
        let w = 32u32;
        let h = 32u32;
        let input = noisy_image(w, h);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let result = DenoiseOperation::denoise_gaussian_cpu(&input, &mut output, w, h, 2.0);
        assert!(result.is_ok());
        // The blur must substantially reduce pixel-to-pixel variation; an
        // identity or zeroing filter would not pass both checks.
        assert!(output.iter().any(|&v| v > 0));
        let tv_in = horizontal_variation(&input, w as usize);
        let tv_out = horizontal_variation(&output, w as usize);
        assert!(
            tv_out * 2 < tv_in,
            "blur did not smooth enough: {tv_out} vs {tv_in}"
        );
    }

    #[test]
    fn test_bilateral_denoise_cpu_constant() {
        let w = 8u32;
        let h = 8u32;
        let input = gray_image(w, h, 100);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let result = DenoiseOperation::denoise_bilateral_cpu(&input, &mut output, w, h, 1.5, 30.0);
        assert!(result.is_ok());
        for &v in &output {
            assert_eq!(v, 100);
        }
    }

    #[test]
    fn test_bilateral_preserves_strong_edge() {
        let w = 8u32;
        let h = 8u32;
        // Left half black, right half bright; alpha opaque everywhere.
        let input: Vec<u8> = (0..w * h)
            .flat_map(|i| {
                let v: u8 = if i % w < w / 2 { 0 } else { 200 };
                [v, v, v, 255]
            })
            .collect();
        let mut output = vec![0u8; (w * h * 4) as usize];
        let result = DenoiseOperation::denoise_bilateral_cpu(&input, &mut output, w, h, 1.5, 10.0);
        assert!(result.is_ok());
        // The 200-level step is far beyond sigma_range, so cross-edge weights
        // underflow to zero and both halves are reproduced exactly.
        assert_eq!(output, input);
    }

    #[test]
    fn test_nlm_denoise_cpu_constant() {
        let w = 8u32;
        let h = 8u32;
        let input = gray_image(w, h, 150);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let result = DenoiseOperation::denoise_nlm_cpu(&input, &mut output, w, h, 10.0, 2, 5);
        assert!(result.is_ok());
        for &v in &output {
            assert_eq!(v, 150);
        }
    }

    #[test]
    fn test_nlm_denoise_invalid_h() {
        let w = 4u32;
        let h = 4u32;
        let input = gray_image(w, h, 0);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let result = DenoiseOperation::denoise_nlm_cpu(&input, &mut output, w, h, 0.0, 1, 3);
        assert!(result.is_err());
    }

    #[test]
    fn test_denoise_kernel_gaussian() {
        let k = DenoiseKernel::gaussian(1.0);
        assert_eq!(k.algorithm(), DenoiseAlgorithm::Gaussian { sigma: 1.0 });
    }

    #[test]
    fn test_denoise_kernel_bilateral() {
        let k = DenoiseKernel::bilateral(2.0, 25.0);
        assert_eq!(
            k.algorithm(),
            DenoiseAlgorithm::BilateralFilter {
                sigma_spatial: 2.0,
                sigma_range: 25.0,
            }
        );
    }

    #[test]
    fn test_denoise_kernel_nlm() {
        let k = DenoiseKernel::nlm(10.0, 3, 10);
        assert_eq!(
            k.algorithm(),
            DenoiseAlgorithm::NonLocalMeans {
                h: 10.0,
                patch_radius: 3,
                search_radius: 10,
            }
        );
    }

    #[test]
    fn test_estimate_gflops_not_zero() {
        let k = DenoiseKernel::gaussian(1.5);
        assert!(k.estimate_gflops(1920, 1080) > 0.0);

        let k2 = DenoiseKernel::nlm(10.0, 3, 10);
        assert!(k2.estimate_gflops(1920, 1080) > 0.0);
    }

    #[test]
    fn test_estimate_gflops_exact_counts() {
        // radius = ceil(4.5) = 5 -> 11 taps; 1920 * 1080 * 11 * 16 ops.
        let g = DenoiseKernel::gaussian(1.5).estimate_gflops(1920, 1080);
        assert!((g - 0.364_953_6).abs() < 1e-9, "gaussian: {g}");

        // (2*10+1)^2 * (2*3+1)^2 * 5 = 108_045 ops per pixel.
        let n = DenoiseKernel::nlm(10.0, 3, 10).estimate_gflops(1920, 1080);
        assert!((n - 224.042_112).abs() < 1e-6, "nlm: {n}");
    }
}