use image::{GenericImage, Pixel};
use num::cast::AsPrimitive;
use crate::definitions::Image;
/// Scores how alike two pixels are, for use as the colour (range) term of
/// [`bilateral_filter`].
///
/// Despite the name, the returned value is used directly as a multiplicative
/// weight by `bilateral_filter`, so implementations should return larger values
/// for more similar pixels.
pub trait ColorDistance<P> {
    /// Returns the colour weight of `pixel2` relative to `pixel1`.
    ///
    /// `bilateral_filter` calls this with the window centre as `pixel1` and
    /// each window pixel as `pixel2`.
    fn color_distance(&self, pixel1: &P, pixel2: &P) -> f32;
}
/// A colour weight that applies a Gaussian to the Euclidean distance between two
/// pixels' channel values.
///
/// Construct it with `GaussianEuclideanColorDistance::new(sigma)`, where `sigma`
/// is the standard deviation of the Gaussian.
#[derive(Debug, Clone, Copy)]
pub struct GaussianEuclideanColorDistance {
    /// Squared standard deviation, cached at construction so each evaluation
    /// avoids recomputing it.
    sigma_squared: f32,
}
impl GaussianEuclideanColorDistance {
    /// Creates a colour weight whose Gaussian has standard deviation `sigma`.
    ///
    /// Larger `sigma` values let more dissimilar colours contribute to a
    /// filtered pixel.
    ///
    /// # Panics
    ///
    /// Panics if `sigma` is not strictly positive (this includes NaN).
    pub fn new(sigma: f32) -> Self {
        assert!(
            sigma > 0.0,
            "GaussianEuclideanColorDistance sigma must be positive"
        );
        // For very small sigma the square is subnormal or underflows to 0.0. A zero
        // would make the distance between identical pixels 0/0 = NaN, so clamp to the
        // smallest normal f32. Identical pixels then keep the maximal weight, and any
        // non-negligible difference gets an (effectively) zero weight.
        let sigma_squared = sigma.powi(2).max(f32::MIN_POSITIVE);
        GaussianEuclideanColorDistance { sigma_squared }
    }
}
/// Gaussian weight of the Euclidean distance between two pixels.
///
/// Evaluates `exp(-d² / (2σ²))` with the fast negative-exponent approximation.
/// Here `d²` is the sum of squared differences of the pixels' channels, each
/// converted to `f32`. The weight shrinks towards zero as the colours diverge.
impl<P> ColorDistance<P> for GaussianEuclideanColorDistance
where
    P: Pixel,
    f32: From<P::Subpixel>,
{
    fn color_distance(&self, pixel1: &P, pixel2: &P) -> f32 {
        let squared_distance: f32 = pixel1
            .channels()
            .iter()
            .zip(pixel2.channels())
            .map(|(&lhs, &rhs)| {
                let delta = f32::from(lhs) - f32::from(rhs);
                delta.powi(2)
            })
            .sum();
        let exponent = -0.5 * squared_distance / self.sigma_squared;
        fast_exp_negative(exponent)
    }
}
/// Smooths an image with a bilateral filter, which averages neighbouring pixels
/// while preserving edges.
///
/// Each output pixel is a weighted average over the `(2 * radius + 1)` by
/// `(2 * radius + 1)` window centred on the corresponding input pixel. The weight
/// of a window pixel is the product of two terms:
///
/// * a spatial Gaussian weight of its offset from the centre, with standard
///   deviation `spatial_sigma`, and
/// * the value of `color_distance.color_distance(&center, &window_pixel)`.
///
/// Window positions outside the image are clamped to the nearest edge pixel. The
/// weighted channel averages are converted back to the subpixel type with
/// `AsPrimitive::as_`, i.e. an `as` cast, which truncates for integer subpixels.
///
/// # Panics
///
/// Panics if `P` has more than 4 channels, if either image dimension is zero or
/// greater than `i32::MAX`, or if `spatial_sigma` is not strictly positive
/// (including NaN).
#[must_use = "the function does not modify the original image"]
pub fn bilateral_filter<I, P, C>(
    image: &I,
    radius: u8,
    spatial_sigma: f32,
    color_distance: C,
) -> Image<P>
where
    I: GenericImage<Pixel = P>,
    P: Pixel,
    C: ColorDistance<P>,
    <P as image::Pixel>::Subpixel: 'static,
    f32: From<P::Subpixel> + AsPrimitive<P::Subpixel>,
{
    const MAX_CHANNELS: usize = 4;
    assert!(
        P::CHANNEL_COUNT as usize <= MAX_CHANNELS,
        "bilateral_filter only supports up to 4 channel images"
    );
    // Window coordinates are computed in i32, so both dimensions must fit in it.
    assert!(image.width() <= i32::MAX as u32);
    assert!(image.height() <= i32::MAX as u32);
    assert_ne!(image.width(), 0);
    assert_ne!(image.height(), 0);
    assert!(spatial_sigma > 0.0, "spatial_sigma must be positive");

    // With radius up to 255 the window has 511 * 511 = 261_121 entries, which does
    // not fit in i16, so window arithmetic is done in i32/usize.
    let radius = i32::from(radius);
    let window_len = (2 * radius + 1) as usize;
    // A very small sigma squares to a subnormal or 0.0. A zero would make the centre
    // weight exp(0 / 0) = NaN and poison every output pixel, so clamp to the smallest
    // normal f32.
    let spatial_sigma_squared = spatial_sigma.powi(2).max(f32::MIN_POSITIVE);

    // Row-major table of spatial weights: entry `row * window_len + col` holds the
    // weight for offset (col - radius, row - radius).
    let mut spatial_distance_lookup = Vec::with_capacity(window_len * window_len);
    for w_y in -radius..=radius {
        for w_x in -radius..=radius {
            spatial_distance_lookup.push(gaussian_weight(
                (w_x as f32).powi(2) + (w_y as f32).powi(2),
                spatial_sigma_squared,
            ));
        }
    }

    let (width, height) = image.dimensions();
    // Largest valid coordinates; the dimension asserts above make these casts lossless.
    let (max_x, max_y) = (width as i32 - 1, height as i32 - 1);

    let bilateral_pixel_filter = |x: u32, y: u32| {
        debug_assert!(image.in_bounds(x, y));
        // SAFETY: `Image::from_fn` only yields coordinates inside `width x height`,
        // which are `image`'s own dimensions.
        let center_pixel = unsafe { image.unsafe_get_pixel(x, y) };
        let mut channel_sums = [0f32; MAX_CHANNELS];
        let mut weight_sum = 0f32;
        for (row, w_y) in (-radius..=radius).enumerate() {
            // Clamping replicates the border pixels outside the image.
            let window_y = (w_y + y as i32).clamp(0, max_y) as u32;
            for (col, w_x) in (-radius..=radius).enumerate() {
                let window_x = (w_x + x as i32).clamp(0, max_x) as u32;
                debug_assert!(image.in_bounds(window_x, window_y));
                // SAFETY: both coordinates were clamped into [0, dimension - 1] above.
                let window_pixel = unsafe { image.unsafe_get_pixel(window_x, window_y) };
                let spatial_weight = spatial_distance_lookup[row * window_len + col];
                let color_weight = color_distance.color_distance(&center_pixel, &window_pixel);
                let weight = spatial_weight * color_weight;
                weight_sum += weight;
                for (sum, &channel) in channel_sums.iter_mut().zip(window_pixel.channels()) {
                    *sum += weight * f32::from(channel);
                }
            }
        }

        // Normalise by the total weight. The channel count was asserted to be at most
        // MAX_CHANNELS, so every output channel has a matching sum.
        let mut out_pixel = center_pixel;
        for (out, sum) in out_pixel.channels_mut().iter_mut().zip(channel_sums) {
            *out = (sum / weight_sum).as_();
        }
        out_pixel
    };
    Image::from_fn(width, height, bilateral_pixel_filter)
}
/// Unnormalised Gaussian kernel value `exp(-x² / (2σ²))`.
///
/// Takes the already-squared distance `x_squared` and variance `sigma_squared`,
/// so callers can avoid repeated square roots and squarings.
fn gaussian_weight(x_squared: f32, sigma_squared: f32) -> f32 {
    let exponent = -0.5 * x_squared / sigma_squared;
    f32::exp(exponent)
}
/// Cheap approximation of `x.exp()` for `x <= 0`.
///
/// Builds the result by writing a linear function of `x` directly into the bit
/// pattern of an `f32`. Scaling by `2^23 / ln 2` turns `x` into an offset of the
/// exponent field. The offset starts from the bit pattern of `1.0` (`127 << 23`)
/// minus a small correction term, presumably tuned to balance the approximation
/// error. Inputs so negative that the bit pattern would go below zero yield `0.0`.
#[inline]
fn fast_exp_negative(x: f32) -> f32 {
    debug_assert!(x <= 0.0, "fast_exp_negative only valid for negative inputs");
    // 2^23 / ln(2)
    const SCALE: f32 = 12102203.0;
    // Bit pattern of 1.0_f32, lowered by a correction term.
    const OFFSET: f32 = 1065353216.0 - 486411.0;
    let raw_bits = (SCALE * x + OFFSET) as i32;
    if raw_bits < 0 {
        // A negative i32 reinterpreted as bits would be a negative float; saturate to +0.0.
        0.0
    } else {
        f32::from_bits(raw_bits as u32)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Fixed 3x3 ramp; with radius 1, border pixels see clamped (replicated) neighbours.
    #[cfg_attr(miri, ignore = "assert_pixels_eq fails")]
    #[test]
    fn test_bilateral_filter_greyscale() {
        let image = gray_image!(
            1, 2, 3;
            4, 5, 6;
            7, 8, 9);
        let color_distance = GaussianEuclideanColorDistance::new(10.0);
        let actual = bilateral_filter(&image, 1, 3.0, color_distance);
        let expect = gray_image!(
            2, 2, 3;
            4, 5, 5;
            6, 7, 7);
        assert_pixels_eq!(actual, expect);
    }

    // Walks every f32 in [-87, 0] and compares the approximation with `f32::exp`.
    #[ignore = "exhaustive sweep over ~1.12B f32 values; run with --ignored"]
    #[cfg_attr(miri, ignore = "slow")]
    #[test]
    fn fast_exp_negative_accuracy_sweep() {
        let end = 0.0_f32;
        let mut x = -87.0_f32;

        let mut max_rel_err = 0.0_f32;
        let mut max_rel_at = 0.0_f32;
        let mut max_abs_err = 0.0_f32;
        let mut sum_rel_err = 0.0_f64;
        let mut count_rel: u64 = 0;
        let mut count_total: u64 = 0;

        while x <= end {
            let truth = x.exp();
            let abs_err = (fast_exp_negative(x) - truth).abs();
            max_abs_err = max_abs_err.max(abs_err);
            // Relative error is only scored where the true value is a normal float.
            if truth >= f32::MIN_POSITIVE {
                let rel = abs_err / truth;
                if rel > max_rel_err {
                    max_rel_err = rel;
                    max_rel_at = x;
                }
                sum_rel_err += f64::from(rel);
                count_rel += 1;
            }
            count_total += 1;
            x = x.next_up();
        }

        let mean_rel_err = sum_rel_err / count_rel as f64;
        println!(
            "fast_exp_negative sweep: {count_total} samples ({count_rel} scored) \
             max_rel_err={max_rel_err:e} at x={max_rel_at} \
             mean_rel_err={mean_rel_err:e} max_abs_err={max_abs_err:e}"
        );
        assert!(
            max_rel_err < 0.04,
            "max relative error {max_rel_err} exceeded 4% threshold at x={max_rel_at}"
        );
    }
}
#[cfg(not(miri))]
#[cfg(test)]
mod proptests {
    use super::*;
    use crate::proptest_utils::arbitrary_image;
    use image::Luma;
    use image::Rgb;
    use proptest::prelude::*;
    // Sigmas spanning many orders of magnitude. Values above roughly 1.8e19 square
    // to f32 infinity, so the upper end of the range also exercises an infinite
    // squared sigma.
    const SIGMA_RANGE: std::ops::Range<f32> = 1e-12..1e32;
    // Smoke properties: random images and parameters must not panic, and the
    // output must keep the input's dimensions.
    proptest! {
        #[test]
        fn proptest_bilateral_filter_greyscale(
            img in arbitrary_image::<Luma<u8>>(1..40, 1..40),
            radius in 0..5u8,
            color_sigma in SIGMA_RANGE,
            spatial_sigma in SIGMA_RANGE,
        ) {
            let out = bilateral_filter(&img, radius, spatial_sigma, GaussianEuclideanColorDistance::new(color_sigma));
            prop_assert_eq!(out.dimensions(), img.dimensions());
        }
        // Same property for 3-channel images, exercising the per-channel sums.
        #[test]
        fn proptest_bilateral_filter_rgb(
            img in arbitrary_image::<Rgb<u8>>(1..40, 1..40),
            radius in 0..5u8,
            color_sigma in SIGMA_RANGE,
            spatial_sigma in SIGMA_RANGE,
        ) {
            let out = bilateral_filter(&img, radius, spatial_sigma, GaussianEuclideanColorDistance::new(color_sigma));
            prop_assert_eq!(out.dimensions(), img.dimensions());
        }
    }
}
#[cfg(not(miri))]
#[cfg(test)]
mod benches {
    use super::*;
    use crate::utils::{gray_bench_image, rgb_bench_image};
    use test::{Bencher, black_box};

    // 100x100 single-channel image, radius 5 (11x11 window).
    #[bench]
    fn bench_bilateral_filter_greyscale(b: &mut Bencher) {
        let image = gray_bench_image(100, 100);
        b.iter(|| {
            let color_distance = GaussianEuclideanColorDistance::new(10.0);
            black_box(bilateral_filter(&image, 5, 3., color_distance));
        });
    }

    // 100x100 RGB image with the same filter parameters as the greyscale bench.
    #[bench]
    fn bench_bilateral_filter_rgb(b: &mut Bencher) {
        let image = rgb_bench_image(100, 100);
        b.iter(|| {
            let color_distance = GaussianEuclideanColorDistance::new(10.0);
            black_box(bilateral_filter(&image, 5, 3., color_distance));
        });
    }
}