use image::{GenericImage, Pixel};
use itertools::Itertools;
use num::cast::AsPrimitive;
use crate::definitions::Image;
/// A per-pixel-pair color term used by [`bilateral_filter`].
///
/// `bilateral_filter` multiplies the returned value with the spatial weight of a
/// window pixel and uses the product directly as that pixel's weight in the
/// average. Despite the name, implementations should therefore return a
/// *similarity*: larger values give `pixel2` more influence on the result.
pub trait ColorDistance<P> {
/// Returns the color weight of `pixel2` relative to `pixel1`.
///
/// In `bilateral_filter`, `pixel1` is the centre pixel of the window and
/// `pixel2` is the window pixel being weighted.
fn color_distance(&self, pixel1: &P, pixel2: &P) -> f32;
}
/// A [`ColorDistance`] that weights a pair of pixels by an unnormalized Gaussian
/// of the Euclidean distance between their channel values.
///
/// The weight is `exp(-d² / (2σ²))`, where `d` is the Euclidean distance between
/// the two pixels' channel vectors and `σ` is the sigma passed to
/// [`new`](Self::new).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct GaussianEuclideanColorDistance {
    /// Square of the Gaussian's standard deviation, cached so it is not
    /// recomputed for every pixel pair.
    sigma_squared: f32,
}

impl GaussianEuclideanColorDistance {
    /// Creates a color distance whose Gaussian has standard deviation `sigma`.
    ///
    /// Only `sigma²` is stored, so the sign of `sigma` has no effect.
    pub fn new(sigma: f32) -> Self {
        GaussianEuclideanColorDistance {
            sigma_squared: sigma.powi(2),
        }
    }
}
impl<P> ColorDistance<P> for GaussianEuclideanColorDistance
where
    P: Pixel,
    f32: From<P::Subpixel>,
{
    /// Accumulates the squared per-channel differences of the two pixels (the
    /// squared Euclidean distance between them) and maps that through an
    /// unnormalized Gaussian with this instance's sigma.
    ///
    /// For a finite, non-zero sigma, identical pixels yield `1.0`.
    fn color_distance(&self, pixel1: &P, pixel2: &P) -> f32 {
        let mut squared_distance = 0.0_f32;
        for (&a, &b) in pixel1.channels().iter().zip(pixel2.channels()) {
            squared_distance += (f32::from(a) - f32::from(b)).powi(2);
        }
        gaussian_weight(squared_distance, self.sigma_squared)
    }
}
/// Applies a bilateral filter: an edge-preserving smoothing in which each output
/// pixel is a weighted average of the input pixels around it.
///
/// The window is the `(2 * radius + 1) x (2 * radius + 1)` square centred on the
/// output pixel. The weight of each window pixel is the product of
///
/// - a spatial weight: an unnormalized Gaussian, with standard deviation
///   `spatial_sigma`, of the pixel's Euclidean offset from the centre; and
/// - a color weight: `color_distance.color_distance(&centre, &window_pixel)`.
///
/// Window positions that fall outside the image are clamped to the nearest
/// edge pixel, so border pixels are replicated.
///
/// # Panics
///
/// Panics if the image width or height is zero or greater than `i32::MAX`.
#[must_use = "the function does not modify the original image"]
#[allow(clippy::doc_overindented_list_items)]
pub fn bilateral_filter<I, P, C>(
    image: &I,
    radius: u8,
    spatial_sigma: f32,
    color_distance: C,
) -> Image<P>
where
    I: GenericImage<Pixel = P>,
    P: Pixel,
    C: ColorDistance<P>,
    <P as image::Pixel>::Subpixel: 'static,
    f32: From<P::Subpixel> + AsPrimitive<P::Subpixel>,
{
    let (width, height) = image.dimensions();
    // Coordinates are offset and clamped in `i32` arithmetic below, so both
    // dimensions must be representable as `i32`.
    assert!(width <= i32::MAX as u32);
    assert!(height <= i32::MAX as u32);
    assert_ne!(width, 0);
    assert_ne!(height, 0);

    let radius = i16::from(radius);
    let window_range = -radius..=radius;
    let spatial_sigma_squared = spatial_sigma.powi(2);

    // Spatial weights depend only on the offset within the window, so compute
    // them once, in the same row-major (w_y, w_x) order in which each window is
    // traversed below. The per-pixel loop zips with this table rather than
    // computing an index, which avoids `i16` overflow for large radii.
    let spatial_weights = window_range
        .clone()
        .cartesian_product(window_range.clone())
        .map(|(w_y, w_x)| {
            let offset_squared =
                <f32 as From<i16>>::from(w_x).powi(2) + <f32 as From<i16>>::from(w_y).powi(2);
            gaussian_weight(offset_squared, spatial_sigma_squared)
        })
        .collect_vec();

    let max_x = width as i32 - 1;
    let max_y = height as i32 - 1;
    let bilateral_pixel_filter = |x: u32, y: u32| {
        debug_assert!(image.in_bounds(x, y));
        // SAFETY: `Image::from_fn(width, height, ..)` only calls this closure with
        // `x < width` and `y < height`, which are `image`'s own dimensions.
        let center_pixel = unsafe { image.unsafe_get_pixel(x, y) };
        let weights_and_values = window_range
            .clone()
            .cartesian_product(window_range.clone())
            .zip(&spatial_weights)
            .map(|((w_y, w_x), &spatial_weight)| {
                // Clamp into the image so out-of-range positions replicate the border.
                let window_y = (i32::from(w_y) + y as i32).clamp(0, max_y) as u32;
                let window_x = (i32::from(w_x) + x as i32).clamp(0, max_x) as u32;
                debug_assert!(image.in_bounds(window_x, window_y));
                // SAFETY: both coordinates were clamped to `0..=dimension - 1` above.
                let window_pixel = unsafe { image.unsafe_get_pixel(window_x, window_y) };
                let color_weight = color_distance.color_distance(&center_pixel, &window_pixel);
                (spatial_weight * color_weight, window_pixel)
            });
        weighted_average(weights_and_values)
    };
    Image::from_fn(width, height, bilateral_pixel_filter)
}
/// Computes the weighted mean of a sequence of pixels, channel by channel.
///
/// Each output channel is `Σ(wᵢ · cᵢ) / Σwᵢ`, accumulated in `f32` and converted
/// back to the subpixel type with [`AsPrimitive::as_`] (an `as`-style cast).
///
/// # Panics
///
/// Panics if `weights_and_values` yields no items.
fn weighted_average<P>(mut weights_and_values: impl Iterator<Item = (f32, P)>) -> P
where
    P: Pixel,
    <P as image::Pixel>::Subpixel: 'static,
    f32: From<P::Subpixel> + AsPrimitive<P::Subpixel>,
{
    let (first_weight, first_value) = weights_and_values
        .next()
        .expect("cannot find a weighted average given no weights and values");

    // Accumulate into a single buffer instead of allocating a `Vec` per sample;
    // additions happen in the same left-to-right order as a reduce would use.
    let mut weights_sum = first_weight;
    let mut weighted_channel_sums = first_value
        .channels()
        .iter()
        .map(|&s| first_weight * f32::from(s))
        .collect_vec();
    for (weight, value) in weights_and_values {
        weights_sum += weight;
        for (sum, &s) in weighted_channel_sums.iter_mut().zip_eq(value.channels()) {
            *sum += weight * f32::from(s);
        }
    }

    let channel_averages = weighted_channel_sums
        .iter()
        .map(|sum| sum / weights_sum)
        .map(<f32 as AsPrimitive<P::Subpixel>>::as_)
        .collect_vec();
    *P::from_slice(&channel_averages)
}
/// Evaluates an unnormalized Gaussian, `exp(-x² / (2σ²))`, from `x²` and `σ²`.
///
/// No normalization constant is applied, so for a finite, non-zero
/// `sigma_squared` the peak value (at `x_squared == 0`) is `1.0`.
fn gaussian_weight(x_squared: f32, sigma_squared: f32) -> f32 {
    let exponent = -0.5 * x_squared / sigma_squared;
    f32::exp(exponent)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Filters a small 3x3 gradient with a 3x3 window and compares the result
    /// against known output.
    #[cfg_attr(miri, ignore = "assert_pixels_eq fails")]
    #[test]
    fn test_bilateral_filter_greyscale() {
        let input = gray_image!(
            1, 2, 3;
            4, 5, 6;
            7, 8, 9);
        let color_distance = GaussianEuclideanColorDistance::new(10.0);
        let filtered = bilateral_filter(&input, 1, 3.0, color_distance);
        let expected = gray_image!(
            2, 2, 3;
            4, 5, 5;
            6, 7, 7);
        assert_pixels_eq!(filtered, expected);
    }
}
// Property tests: for arbitrary small greyscale and RGB images, radii 0..5 and
// arbitrary `f32` sigmas, `bilateral_filter` must not panic and must return an
// image with the same dimensions as its input.
#[cfg(not(miri))]
#[cfg(test)]
mod proptests {
use super::*;
use crate::proptest_utils::arbitrary_image;
use image::Luma;
use image::Rgb;
use proptest::prelude::*;
proptest! {
// Single-channel (Luma<u8>) input.
#[test]
fn proptest_bilateral_filter_greyscale(
img in arbitrary_image::<Luma<u8>>(1..40, 1..40),
radius in 0..5u8,
color_sigma in any::<f32>(),
spatial_sigma in any::<f32>(),
) {
let out = bilateral_filter(&img, radius, spatial_sigma, GaussianEuclideanColorDistance::new(color_sigma));
assert_eq!(out.dimensions(), img.dimensions());
}
// Three-channel (Rgb<u8>) input.
#[test]
fn proptest_bilateral_filter_rgb(
img in arbitrary_image::<Rgb<u8>>(1..40, 1..40),
radius in 0..5u8,
color_sigma in any::<f32>(),
spatial_sigma in any::<f32>(),
) {
let out = bilateral_filter(&img, radius, spatial_sigma, GaussianEuclideanColorDistance::new(color_sigma));
assert_eq!(out.dimensions(), img.dimensions());
}
}
}
// Benchmarks: filter a 100x100 benchmark image with radius 5 (an 11x11 window),
// spatial sigma 3 and color sigma 10, for greyscale and RGB inputs.
#[cfg(not(miri))]
#[cfg(test)]
mod benches {
use super::*;
use crate::utils::{gray_bench_image, rgb_bench_image};
use test::{Bencher, black_box};
#[bench]
fn bench_bilateral_filter_greyscale(b: &mut Bencher) {
let image = gray_bench_image(100, 100);
b.iter(|| {
let filtered =
bilateral_filter(&image, 5, 3., GaussianEuclideanColorDistance::new(10.0));
// Keep the result observable so the optimizer cannot drop the work.
black_box(filtered);
});
}
#[bench]
fn bench_bilateral_filter_rgb(b: &mut Bencher) {
let image = rgb_bench_image(100, 100);
b.iter(|| {
let filtered =
bilateral_filter(&image, 5, 3., GaussianEuclideanColorDistance::new(10.0));
// Keep the result observable so the optimizer cannot drop the work.
black_box(filtered);
});
}
}