use zune_core::bit_depth::BitType;
use zune_core::log::trace;
use zune_image::channel::Channel;
use zune_image::errors::ImageErrors;
use zune_image::image::Image;
use zune_image::traits::OperationsTrait;
use crate::pad::{pad, PadMethod};
use crate::utils::z_prefetch;
/// Median filter operation.
///
/// Each sample is replaced by the median of the square window of side
/// `2 * radius + 1` centred on it. `Median::default()` has a radius of `0`.
#[derive(Default, Debug, Clone, Copy)]
pub struct Median {
    /// Half-width of the filter window, in pixels.
    radius: usize
}
impl Median {
#[must_use]
pub fn new(radius: usize) -> Median {
Median { radius }
}
}
impl OperationsTrait for Median {
fn name(&self) -> &'static str {
"Median Filter"
}
fn execute_impl(&self, image: &mut Image) -> Result<(), ImageErrors> {
let (width, height) = image.dimensions();
if self.radius < 2 {
return Ok(());
}
let depth = image.depth();
#[cfg(not(feature = "threads"))]
{
trace!("Running median filter single threaded mode");
for channel in image.get_channels_mut(false) {
let mut new_channel = Channel::new_with_bit_type(channel.len(), depth.bit_type());
match depth.bit_type() {
BitType::U16 => median_u16(
channel.reinterpret_as::<u16>()?,
new_channel.reinterpret_as_mut::<u16>()?,
self.radius,
width,
height
),
BitType::U8 => median_u8(
channel.reinterpret_as::<u8>()?,
new_channel.reinterpret_as_mut::<u8>()?,
self.radius,
width,
height
),
d => {
return Err(ImageErrors::ImageOperationNotImplemented(
self.get_name(),
d
))
}
}
*channel = new_channel;
}
}
#[cfg(feature = "threads")]
{
trace!("Running median filter multithreaded mode");
std::thread::scope(|s| {
let mut errors = vec![];
for channel in image.channels_mut(true) {
let result = s.spawn(|| {
let mut new_channel =
Channel::new_with_bit_type(channel.len(), depth.bit_type());
match depth.bit_type() {
BitType::U16 => median_u16(
channel.reinterpret_as::<u16>()?,
new_channel.reinterpret_as_mut::<u16>()?,
self.radius,
width,
height
),
BitType::U8 => median_u8(
channel.reinterpret_as::<u8>()?,
new_channel.reinterpret_as_mut::<u8>()?,
self.radius,
width,
height
),
d => {
return Err(ImageErrors::ImageOperationNotImplemented(
self.name(),
d
))
}
}
*channel = new_channel;
Ok(())
});
errors.push(result);
}
errors
.into_iter()
.map(|x| x.join().unwrap())
.collect::<Result<Vec<()>, ImageErrors>>()
})?;
}
Ok(())
}
fn supported_types(&self) -> &'static [BitType] {
&[BitType::U8, BitType::U16]
}
}
/// Median-filter a single 16-bit plane.
///
/// Every output sample is the median of the `(2 * radius + 1)²` window centred
/// on it. The input is first padded by `radius` samples on each side with
/// [`PadMethod::Replicate`]; the window is then slid across it by
/// [`spatial_median`].
///
/// The median is located with a 65 536-bin histogram. The histogram is rebuilt
/// at the first pixel of every output row. For the remaining pixels it is
/// updated incrementally: the column that left the window is removed and the
/// column that entered it is added.
///
/// * `in_channel` – source plane, row-major with `width` samples per row
///   (assumed to hold `width * height` samples — the exact contract is enforced
///   by `pad`).
/// * `out_channel` – destination plane; sample `(x, y)` is written at
///   `y * width + x`.
/// * `radius` – half-width of the window; `0` copies the input unchanged.
///
/// # Panics
/// Panics if `out_channel` holds fewer than `width * height` samples, or if the
/// padded buffer is not `(width + 2 * radius) * (height + 2 * radius)` long.
#[allow(clippy::cast_possible_truncation)]
pub fn median_u16(
    in_channel: &[u16], out_channel: &mut [u16], radius: usize, width: usize, height: usize
) {
    let mut histogram_arr = vec![0_u32; 65536];
    // A fixed-size array view means every `usize::from(u16)` index is statically
    // in bounds. The view borrows a heap buffer rather than a 256 KiB stack array.
    let histogram: &mut [u32; 65536] = histogram_arr
        .as_mut_slice()
        .try_into()
        .expect("histogram buffer has exactly 65536 bins");
    let radius_size = (2 * radius) + 1;
    // Left-most column of the previous window, one sample per window row.
    let mut to_be_dropped = vec![0; radius_size];
    // Output samples produced so far; `counter % width == 0` marks a row start.
    let mut counter = 0;

    let func = |array: &[u16]| -> u16 {
        // Zero-based rank of the median; the window length is odd.
        let median_pos = array.len() / 2;

        if (counter % width) == 0 {
            // First sample of a row: the previous window is not adjacent, so rebuild.
            histogram.fill(0);
            for c in array {
                histogram[usize::from(*c)] += 1;
            }
        } else {
            // The window moved one sample right: drop the old left column and add
            // the new right column.
            for x in &to_be_dropped {
                histogram[usize::from(*x)] -= 1;
            }
            for row in array.chunks_exact(radius_size) {
                histogram[usize::from(row[radius_size - 1])] += 1;
            }
        }

        // The median is the first bin where the running count exceeds `median_pos`,
        // i.e. reaches `median_pos + 1` samples. Using `>=` would pick the sample
        // ranked just below the median (and always 0 for a 1x1 window).
        let mut accum = 0;
        let mut median = 0;
        for (pos, v) in histogram.iter().enumerate() {
            accum += *v;
            if accum > median_pos as u32 {
                median = pos as u16;
                break;
            }
        }
        counter += 1;

        debug_assert!(array.chunks_exact(radius_size).len() == to_be_dropped.len());
        for (row, dropped) in array.chunks_exact(radius_size).zip(to_be_dropped.iter_mut()) {
            *dropped = row[0];
        }
        median
    };
    let padded_input = pad(
        in_channel,
        width,
        height,
        radius,
        radius,
        PadMethod::Replicate
    );
    spatial_median(&padded_input, out_channel, radius, width, height, func);
}
/// Median-filter a single 8-bit plane.
///
/// Every output sample is the median of the `(2 * radius + 1)²` window centred
/// on it. The input is first padded by `radius` samples on each side with
/// [`PadMethod::Replicate`]; the window is then slid across it by
/// [`spatial_median`].
///
/// The median is located with a 256-bin histogram. The histogram is rebuilt at
/// the first pixel of every output row. For the remaining pixels it is updated
/// incrementally: the column that left the window is removed and the column
/// that entered it is added.
///
/// * `in_channel` – source plane, row-major with `width` samples per row
///   (assumed to hold `width * height` samples — the exact contract is enforced
///   by `pad`).
/// * `out_channel` – destination plane; sample `(x, y)` is written at
///   `y * width + x`.
/// * `radius` – half-width of the window; `0` copies the input unchanged.
///
/// # Panics
/// Panics if `out_channel` holds fewer than `width * height` samples, or if the
/// padded buffer is not `(width + 2 * radius) * (height + 2 * radius)` long.
#[allow(clippy::cast_possible_truncation)]
pub fn median_u8(
    in_channel: &[u8], out_channel: &mut [u8], radius: usize, width: usize, height: usize
) {
    let mut histogram = [0_u32; 256];
    let radius_size = (2 * radius) + 1;
    // Left-most column of the previous window, one sample per window row.
    let mut to_be_dropped = vec![0; radius_size];
    // Output samples produced so far; `counter % width == 0` marks a row start.
    let mut counter = 0;

    let func = |array: &[u8]| -> u8 {
        // Zero-based rank of the median; the window length is odd.
        let median_pos = array.len() / 2;

        if (counter % width) == 0 {
            // First sample of a row: the previous window is not adjacent, so rebuild.
            histogram.fill(0);
            for c in array {
                histogram[usize::from(*c)] += 1;
            }
        } else {
            // The window moved one sample right: drop the old left column and add
            // the new right column.
            for x in &to_be_dropped {
                histogram[usize::from(*x)] -= 1;
            }
            for row in array.chunks_exact(radius_size) {
                histogram[usize::from(row[radius_size - 1])] += 1;
            }
        }

        // The median is the first bin where the running count exceeds `median_pos`,
        // i.e. reaches `median_pos + 1` samples. Using `>=` would pick the sample
        // ranked just below the median (and always 0 for a 1x1 window).
        let mut accum = 0;
        let mut median = 0;
        for (pos, v) in histogram.iter().enumerate() {
            accum += *v;
            if accum > median_pos as u32 {
                median = pos as u8;
                break;
            }
        }
        counter += 1;

        debug_assert!(array.chunks_exact(radius_size).len() == to_be_dropped.len());
        for (row, dropped) in array.chunks_exact(radius_size).zip(to_be_dropped.iter_mut()) {
            *dropped = row[0];
        }
        median
    };
    let padded_input = pad(
        in_channel,
        width,
        height,
        radius,
        radius,
        PadMethod::Replicate
    );
    spatial_median(&padded_input, out_channel, radius, width, height, func);
}
/// Slide a square window over a padded plane and store `function(window)` for
/// every output sample.
///
/// Despite the name, the reduction is whatever `function` computes; the median
/// routines in this file pass histogram-based median closures.
///
/// * `in_channel` – padded plane of `(width + 2 * radius) * (height + 2 * radius)`
///   samples, row-major.
/// * `out_channel` – destination; sample `(x, y)` is written at `y * width + x`.
/// * `width`, `height` – dimensions of the *unpadded* output.
/// * `function` – receives each `(2 * radius + 1)²` window flattened row by row.
///   It is called in raster order (left to right, then top to bottom); the median
///   closures rely on that order for their incremental histogram updates.
///
/// # Panics
/// Panics if `in_channel` has the wrong length or if `out_channel` is shorter
/// than `width * height`.
pub fn spatial_median<T, F>(
    in_channel: &[T], out_channel: &mut [T], radius: usize, width: usize, height: usize,
    mut function: F
) where
    T: Default + Copy,
    F: FnMut(&[T]) -> T
{
    let window = 2 * radius + 1;
    let padded_width = width + 2 * radius;
    let padded_height = height + 2 * radius;

    assert_eq!(padded_height * padded_width, in_channel.len());

    // Reused scratch buffer holding the current window, row by row.
    let mut scratch = vec![T::default(); window * window];

    for out_y in 0..height {
        for out_x in 0..width {
            for (row, dst) in scratch.chunks_exact_mut(window).enumerate() {
                let src_y = out_y + row;
                let row_start = src_y * padded_width + out_x;
                // Cache hints only. On the last window rows these positions can lie
                // past the end of `in_channel`; this assumes `z_prefetch` tolerates
                // that — confirm in `crate::utils`.
                z_prefetch(in_channel, (src_y + 1) * padded_width + out_x);
                dst.copy_from_slice(&in_channel[row_start..row_start + window]);
                z_prefetch(in_channel, (src_y + 2) * padded_width + out_x);
            }
            out_channel[out_y * width + out_x] = function(&scratch);
        }
    }
}