use crate::definitions::Image;
use image::{GenericImageView, Pixel};
use std::cmp::{max, min};
/// Applies a median filter with a `(2 * x_radius + 1) x (2 * y_radius + 1)`
/// rectangular kernel to `image`.
///
/// Each output sample is the median, computed independently per channel, of
/// the input samples in the kernel centred on the corresponding input pixel.
/// Kernel positions that fall outside the image are clamped to the nearest
/// in-bounds pixel, so edge pixels can be counted more than once.
///
/// Rather than sorting each window, one 256-bin histogram per channel is
/// maintained and updated incrementally. The kernel is swept in a serpentine
/// order: down column 0, one step right, up column 1, one step right, down
/// column 2, and so on. Each step therefore only adds and removes the single
/// row or column that enters or leaves the window.
///
/// Images with zero width or zero height are returned unchanged (cloned).
///
/// # Panics
///
/// Panics if `width + x_radius` or `height + y_radius` exceeds `i32::MAX`.
#[must_use = "the function does not modify the original image"]
pub fn median_filter<P>(image: &Image<P>, x_radius: u32, y_radius: u32) -> Image<P>
where
    P: Pixel<Subpixel = u8>,
{
    let (width, height) = image.dimensions();

    if width == 0 || height == 0 {
        return image.clone();
    }

    // Checked addition: a plain `width + x_radius` overflows `u32` for huge
    // radii. That panics in debug builds and wraps in release builds, where the
    // wrapped sum passes the guard and the later `as i32` casts yield negative
    // radii (empty kernels, garbage output).
    let fits_in_i32 = |extent: u32, radius: u32| {
        matches!(extent.checked_add(radius), Some(sum) if sum <= i32::MAX as u32)
    };
    if !fits_in_i32(width, x_radius) || !fits_in_i32(height, y_radius) {
        panic!("(width + x_radius) and (height + y_radius) must both be <= i32::MAX");
    }

    let mut out = Image::<P>::new(width, height);
    let mut hist = initialise_histogram_for_top_left_pixel(image, x_radius, y_radius);
    slide_down_column(&mut hist, image, &mut out, 0, x_radius, y_radius);

    // Serpentine sweep: even columns are traversed top-to-bottom, odd columns
    // bottom-to-top, so the histogram only ever moves by one pixel at a time.
    for x in 1..width {
        if x % 2 == 0 {
            slide_right(&mut hist, image, x, 0, x_radius, y_radius);
            slide_down_column(&mut hist, image, &mut out, x, x_radius, y_radius);
        } else {
            slide_right(&mut hist, image, x, height - 1, x_radius, y_radius);
            slide_up_column(&mut hist, image, &mut out, x, x_radius, y_radius);
        }
    }

    out
}
/// Builds the per-channel histograms for the kernel centred on pixel `(0, 0)`.
///
/// Kernel cells above or to the left of the image (and beyond the right or
/// bottom edge for tiny images) are clamped to the nearest in-bounds pixel.
/// Those edge pixels are counted once per clamped cell, matching how the
/// sliding helpers treat the border.
fn initialise_histogram_for_top_left_pixel<P>(
    image: &Image<P>,
    x_radius: u32,
    y_radius: u32,
) -> HistSet
where
    P: Pixel<Subpixel = u8>,
{
    let (width, height) = image.dimensions();
    let (half_w, half_h) = (x_radius as i32, y_radius as i32);
    let last_col = width as i32 - 1;
    let last_row = height as i32 - 1;

    let samples_per_channel = (2 * x_radius + 1) * (2 * y_radius + 1);
    let mut hist = HistSet::new(P::CHANNEL_COUNT, samples_per_channel);

    for row in (-half_h..=half_h).map(|dy| min(max(0, dy), last_row) as u32) {
        for col in (-half_w..=half_w).map(|dx| min(max(0, dx), last_col) as u32) {
            // SAFETY: `col`/`row` are clamped into `0..width` / `0..height`
            // (the caller only passes non-empty images). `hist` was created
            // with `P::CHANNEL_COUNT` histograms.
            unsafe { hist.incr(image, col, row) };
        }
    }

    hist
}
/// Moves the kernel one column to the right, so that a histogram describing
/// the kernel centred on `(x - 1, y)` afterwards describes the one centred on
/// `(x, y)`.
///
/// The column leaving the window (`x - rx - 1`) is subtracted and the column
/// entering it (`x + rx`) is added, for every row in `y - ry ..= y + ry`. All
/// coordinates are clamped to the image.
///
/// # Panics
///
/// Panics if `(x, y)` is not inside `image`.
fn slide_right<P>(hist: &mut HistSet, image: &Image<P>, x: u32, y: u32, rx: u32, ry: u32)
where
    P: Pixel<Subpixel = u8>,
{
    let (width, height) = image.dimensions();
    assert!(x < width);
    assert!(y < height);

    let (cx, cy) = (x as i32, y as i32);
    let (half_w, half_h) = (rx as i32, ry as i32);
    let last_row = (height - 1) as i32;

    let leaving_col = max(0, cx - half_w - 1) as u32;
    let entering_col = min(cx + half_w, width as i32 - 1) as u32;

    for row in (cy - half_h)..=(cy + half_h) {
        let row = min(max(0, row), last_row) as u32;
        // SAFETY: both columns and `row` are clamped into the image bounds.
        // Assumes `hist` was built for pixel type `P` (one histogram per
        // channel), as `median_filter` does.
        unsafe {
            hist.decr(image, leaving_col, row);
            hist.incr(image, entering_col, row);
        }
    }
}
/// Writes the medians for column `x` of `out`, moving the kernel from the top
/// row to the bottom row.
///
/// On entry `hist` is expected to describe the kernel centred on `(x, 0)`. On
/// exit it describes the kernel centred on `(x, height - 1)`. At each step the
/// row leaving the window (`y - ry - 1`) is subtracted and the row entering it
/// (`y + ry`) is added. Columns `x - rx ..= x + rx` and both rows are clamped
/// to the image.
///
/// # Panics
///
/// Panics if `x` is not a valid column of `image`.
fn slide_down_column<P>(
    hist: &mut HistSet,
    image: &Image<P>,
    out: &mut Image<P>,
    x: u32,
    rx: u32,
    ry: u32,
) where
    P: Pixel<Subpixel = u8>,
{
    let (width, height) = image.dimensions();
    assert!(x < width);

    let (half_w, half_h) = (rx as i32, ry as i32);
    let last_col = (width - 1) as i32;
    let last_row = height as i32 - 1;
    let cx = x as i32;
    // Clamped x coordinates covered by the kernel in this column.
    let kernel_cols = || ((cx - half_w)..=(cx + half_w)).map(|c| min(max(0, c), last_col) as u32);

    // SAFETY: assumes `hist` holds one histogram per channel of `P`.
    // `(x, 0)` is bounds-checked inside `set_to_median`.
    unsafe { hist.set_to_median(out, x, 0) };

    for y in 1..height {
        let leaving_row = max(0, y as i32 - half_h - 1) as u32;
        let entering_row = min(y as i32 + half_h, last_row) as u32;
        for col in kernel_cols() {
            // SAFETY: `col`, `leaving_row` and `entering_row` are all clamped
            // into the image; `hist` matches `P`'s channel count.
            unsafe {
                hist.decr(image, col, leaving_row);
                hist.incr(image, col, entering_row);
            }
        }
        // SAFETY: as above; `(x, y)` is bounds-checked by `set_to_median`.
        unsafe { hist.set_to_median(out, x, y) };
    }
}
/// Writes the medians for column `x` of `out`, moving the kernel from the
/// bottom row up to the top row.
///
/// On entry `hist` is expected to describe the kernel centred on
/// `(x, height - 1)`. On exit it describes the kernel centred on `(x, 0)`. At
/// each step the row leaving the window (`y + ry + 1`) is subtracted and the
/// row entering it (`y - ry`) is added. Columns `x - rx ..= x + rx` and both
/// rows are clamped to the image.
///
/// # Panics
///
/// Panics if `x` is not a valid column of `image`.
fn slide_up_column<P>(
    hist: &mut HistSet,
    image: &Image<P>,
    out: &mut Image<P>,
    x: u32,
    rx: u32,
    ry: u32,
) where
    P: Pixel<Subpixel = u8>,
{
    let (width, height) = image.dimensions();
    assert!(x < width);

    let (half_w, half_h) = (rx as i32, ry as i32);
    let last_col = (width - 1) as i32;
    let last_row = height as i32 - 1;
    let cx = x as i32;
    // Clamped x coordinates covered by the kernel in this column.
    let kernel_cols = || ((cx - half_w)..=(cx + half_w)).map(|c| min(max(0, c), last_col) as u32);

    // SAFETY: assumes `hist` holds one histogram per channel of `P`.
    // `(x, height - 1)` is bounds-checked inside `set_to_median`.
    unsafe { hist.set_to_median(out, x, height - 1) };

    for y in (0..(height - 1)).rev() {
        let leaving_row = min(y as i32 + half_h + 1, last_row) as u32;
        let entering_row = max(0, y as i32 - half_h) as u32;
        for col in kernel_cols() {
            // SAFETY: `col`, `leaving_row` and `entering_row` are all clamped
            // into the image; `hist` matches `P`'s channel count.
            unsafe {
                hist.decr(image, col, leaving_row);
                hist.incr(image, col, entering_row);
            }
        }
        // SAFETY: as above; `(x, y)` is bounds-checked by `set_to_median`.
        unsafe { hist.set_to_median(out, x, y) };
    }
}
/// Per-channel histograms of the `u8` samples currently inside the median
/// kernel: one 256-bin histogram for each channel of the pixel type.
struct HistSet {
    /// `data[c][v]` counts the samples of channel `c` whose value is `v`.
    data: Vec<[u32; 256]>,
    /// Number of samples each histogram holds for a full kernel. The median is
    /// located relative to this total.
    expected_count: u32,
}

impl HistSet {
    /// Creates `num_channels` empty histograms for a kernel holding
    /// `expected_count` samples per channel.
    fn new(num_channels: u8, expected_count: u32) -> HistSet {
        let data = vec![[0u32; 256]; usize::from(num_channels)];
        HistSet {
            data,
            expected_count,
        }
    }

    /// Adds each channel of the pixel at `(x, y)` to the matching histogram.
    ///
    /// # Safety
    ///
    /// `(x, y)` must lie inside `image`, and `self` must hold at least as many
    /// histograms as `P` has channels.
    unsafe fn incr<P>(&mut self, image: &Image<P>, x: u32, y: u32)
    where
        P: Pixel<Subpixel = u8>,
    {
        // SAFETY: the caller guarantees `(x, y)` is inside `image`.
        let pixel = unsafe { image.unsafe_get_pixel(x, y) };
        for (channel, &value) in pixel.channels().iter().enumerate() {
            // SAFETY: the caller guarantees a histogram exists for `channel`,
            // and any `u8` value indexes within the 256 bins.
            let bin = unsafe {
                self.data
                    .get_unchecked_mut(channel)
                    .get_unchecked_mut(usize::from(value))
            };
            *bin += 1;
        }
    }

    /// Removes each channel of the pixel at `(x, y)` from the matching
    /// histogram.
    ///
    /// # Safety
    ///
    /// `(x, y)` must lie inside `image`, and `self` must hold at least as many
    /// histograms as `P` has channels.
    unsafe fn decr<P>(&mut self, image: &Image<P>, x: u32, y: u32)
    where
        P: Pixel<Subpixel = u8>,
    {
        // SAFETY: the caller guarantees `(x, y)` is inside `image`.
        let pixel = unsafe { image.unsafe_get_pixel(x, y) };
        for (channel, &value) in pixel.channels().iter().enumerate() {
            // SAFETY: the caller guarantees a histogram exists for `channel`,
            // and any `u8` value indexes within the 256 bins.
            let bin = unsafe {
                self.data
                    .get_unchecked_mut(channel)
                    .get_unchecked_mut(usize::from(value))
            };
            *bin -= 1;
        }
    }

    /// Writes the current per-channel medians into the pixel at `(x, y)` of
    /// `image`.
    ///
    /// # Safety
    ///
    /// `self` must hold at least as many histograms as `P` has channels.
    /// `(x, y)` itself is bounds-checked by `get_pixel_mut`.
    unsafe fn set_to_median<P>(&self, image: &mut Image<P>, x: u32, y: u32)
    where
        P: Pixel<Subpixel = u8>,
    {
        let target = image.get_pixel_mut(x, y);
        for (channel, slot) in target.channels_mut().iter_mut().enumerate() {
            // SAFETY: the caller guarantees a histogram exists for `channel`.
            *slot = unsafe { self.channel_median(channel as u8) };
        }
    }

    /// Returns the smallest value `v` for which twice the number of samples
    /// `<= v` in channel `c` reaches `expected_count`. For the odd kernel sizes
    /// used here, that is the middle sample. Falls back to 255 if the
    /// histogram holds too few samples.
    ///
    /// # Safety
    ///
    /// `c` must be a valid histogram index (`c < self.data.len()`).
    unsafe fn channel_median(&self, c: u8) -> u8 {
        // SAFETY: the caller guarantees `c` indexes an existing histogram.
        let hist = unsafe { self.data.get_unchecked(usize::from(c)) };
        let mut seen = 0;
        hist.iter()
            .position(|&count| {
                seen += count;
                2 * seen >= self.expected_count
            })
            .map_or(255, |value| value as u8)
    }
}
#[cfg(not(miri))]
#[cfg(test)]
mod benches {
    use super::*;
    use crate::utils::gray_bench_image;
    use test::{Bencher, black_box};

    /// Defines a benchmark that median-filters a `side` x `side` grayscale
    /// bench image with the given kernel radii.
    macro_rules! bench_median_filter {
        ($name:ident, side: $s:expr, x_radius: $rx:expr, y_radius: $ry:expr) => {
            #[bench]
            fn $name(b: &mut Bencher) {
                let image = gray_bench_image($s, $s);
                b.iter(|| {
                    let filtered = median_filter(&image, $rx, $ry);
                    black_box(filtered);
                })
            }
        };
    }

    // Square kernels.
    bench_median_filter!(bench_median_filter_s100_r1, side: 100, x_radius: 1, y_radius: 1);
    bench_median_filter!(bench_median_filter_s100_r4, side: 100, x_radius: 4, y_radius: 4);
    bench_median_filter!(bench_median_filter_s100_r8, side: 100, x_radius: 8, y_radius: 8);

    // Non-square kernels; names encode the radii as `rx<x_radius>_ry<y_radius>`.
    bench_median_filter!(bench_median_filter_s100_rx1_ry4, side: 100, x_radius: 1, y_radius: 4);
    bench_median_filter!(bench_median_filter_s100_rx1_ry8, side: 100, x_radius: 1, y_radius: 8);
    // Previously ran with y_radius 1, contradicting its name.
    bench_median_filter!(bench_median_filter_s100_rx4_ry8, side: 100, x_radius: 4, y_radius: 8);
    bench_median_filter!(bench_median_filter_s100_rx8_ry1, side: 100, x_radius: 8, y_radius: 1);
}
#[cfg(not(miri))]
#[cfg(test)]
mod proptests {
    use super::*;
    use crate::proptest_utils::arbitrary_image;
    use image::{GrayImage, Luma};
    use proptest::prelude::*;
    use std::cmp::{max, min};

    /// Brute-force oracle for `median_filter` on grayscale images. For every
    /// pixel it gathers the clamped kernel neighbourhood, sorts it and takes
    /// the middle element.
    fn reference_median_filter(image: &GrayImage, x_radius: u32, y_radius: u32) -> GrayImage {
        let (width, height) = image.dimensions();
        if width == 0 || height == 0 {
            return image.clone();
        }

        let (half_w, half_h) = (x_radius as i32, y_radius as i32);
        let clamp_x = |v: i32| min(max(0, v), (width - 1) as i32) as u32;
        let clamp_y = |v: i32| min(max(0, v), (height - 1) as i32) as u32;
        let window_len = (2 * x_radius + 1) as usize * (2 * y_radius + 1) as usize;
        // Reused for every pixel to avoid reallocating the neighbourhood.
        let mut window: Vec<u8> = Vec::with_capacity(window_len);

        GrayImage::from_fn(width, height, |x, y| {
            window.clear();
            for dy in -half_h..=half_h {
                let py = clamp_y(y as i32 + dy);
                window.extend(
                    (-half_w..=half_w).map(|dx| image.get_pixel(clamp_x(x as i32 + dx), py)[0]),
                );
            }
            window.sort_unstable();
            Luma([median(&window)])
        })
    }

    /// Middle element of an already-sorted, odd-length slice.
    fn median(sorted: &[u8]) -> u8 {
        sorted[sorted.len() / 2]
    }

    proptest! {
        #[test]
        fn test_median_filter_matches_reference_implementation(image in arbitrary_image::<Luma<u8>>(0..10, 0..10), x_radius in 0_u32..5, y_radius in 0_u32..5) {
            let expected = reference_median_filter(&image, x_radius, y_radius);
            let actual = median_filter(&image, x_radius, y_radius);
            assert_eq!(actual, expected);
        }
    }
}