use crate::{GpuDevice, Result};
use rayon::prelude::*;
use super::utils;
/// Selects which channels drive the equalization.
///
/// The default is [`EqualizationMode::LumaOnly`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum EqualizationMode {
    /// Equalize the BT.601 luma histogram, then scale R, G and B of each pixel by
    /// the same factor (equalized luma / original luma).
    #[default]
    LumaOnly,
    /// Equalize the R, G and B histograms independently of each other.
    PerChannel,
}
#[derive(Debug, Clone)]
pub struct HistogramEqualizerConfig {
pub mode: EqualizationMode,
pub clip_limit: f32,
}
impl Default for HistogramEqualizerConfig {
fn default() -> Self {
Self {
mode: EqualizationMode::default(),
clip_limit: 0.0,
}
}
}
/// Global (whole-image) histogram equalization for tightly packed RGBA8 images.
///
/// The alpha channel is always copied through unchanged; see [`EqualizationMode`]
/// for how the colour channels are treated.
#[derive(Debug, Clone)]
pub struct HistogramEqualizer {
    config: HistogramEqualizerConfig,
}

impl HistogramEqualizer {
    /// Creates an equalizer with the given configuration.
    #[must_use]
    pub fn new(config: HistogramEqualizerConfig) -> Self {
        Self { config }
    }

    /// Creates an equalizer using `HistogramEqualizerConfig::default()`.
    #[must_use]
    pub fn default_config() -> Self {
        Self::new(HistogramEqualizerConfig::default())
    }

    /// Equalizes `input` into `output`.
    ///
    /// `_device` is currently ignored: this always runs [`Self::equalize_cpu`].
    ///
    /// # Errors
    ///
    /// Same as [`Self::equalize_cpu`].
    pub fn equalize(
        &self,
        _device: &GpuDevice,
        input: &[u8],
        output: &mut [u8],
        width: u32,
        height: u32,
    ) -> Result<()> {
        self.equalize_cpu(input, output, width, height)
    }

    /// Equalizes a `width x height` RGBA8 image (4 bytes per pixel) on the CPU.
    ///
    /// Only the first `width * height` pixels of each buffer are read or written;
    /// any trailing bytes are neither counted in the histogram nor modified.
    ///
    /// # Errors
    ///
    /// Propagates the error from `utils::validate_dimensions` if it rejects
    /// `width`/`height`, or from `utils::validate_buffer_size` if it rejects either
    /// buffer for a 4-byte-per-pixel layout.
    pub fn equalize_cpu(
        &self,
        input: &[u8],
        output: &mut [u8],
        width: u32,
        height: u32,
    ) -> Result<()> {
        utils::validate_dimensions(width, height)?;
        utils::validate_buffer_size(input, width, height, 4)?;
        utils::validate_buffer_size(output, width, height, 4)?;

        // Work on exactly `width * height` pixels so the histogram total always
        // equals the pixel count used for clipping and the LUT. The product is taken
        // in `usize` (not `u32`) to avoid overflow, and the `min`s keep the slicing
        // panic-free however strictly the validators check buffer lengths.
        let image_len = (width as usize)
            .saturating_mul(height as usize)
            .saturating_mul(4);
        let len = image_len.min(input.len()).min(output.len());
        let input = &input[..len];
        let output = &mut output[..len];

        match self.config.mode {
            EqualizationMode::LumaOnly => self.equalize_luma(input, output),
            EqualizationMode::PerChannel => self.equalize_per_channel(input, output),
        }
        Ok(())
    }

    /// Equalizes BT.601 luma and rescales each pixel's R, G and B by the ratio
    /// `equalized_luma / original_luma`. Results are clamped to `[0, 255]`.
    ///
    /// Expects `input` and `output` to be the same length; pixels are processed
    /// pairwise in 4-byte chunks.
    fn equalize_luma(&self, input: &[u8], output: &mut [u8]) {
        let n_pixels = input.len() / 4;

        let mut hist = [0u64; 256];
        for px in input.chunks_exact(4) {
            hist[usize::from(luma_bt601(px[0], px[1], px[2]))] += 1;
        }
        let hist = self.apply_clip_limit(hist, n_pixels);
        let lut = build_equalization_lut(&hist, n_pixels);

        output
            .par_chunks_exact_mut(4)
            .zip(input.par_chunks_exact(4))
            .for_each(|(out, inn)| {
                let y_orig = luma_bt601(inn[0], inn[1], inn[2]);
                if y_orig == 0 {
                    // No meaningful ratio exists for zero luma; emit black.
                    out[..3].fill(0);
                } else {
                    let scale = f32::from(lut[usize::from(y_orig)]) / f32::from(y_orig);
                    for (o, &i) in out[..3].iter_mut().zip(&inn[..3]) {
                        *o = (f32::from(i) * scale).clamp(0.0, 255.0).round() as u8;
                    }
                }
                out[3] = inn[3];
            });
    }

    /// Equalizes R, G and B independently, each through its own histogram and LUT.
    ///
    /// Expects `input` and `output` to be the same length.
    fn equalize_per_channel(&self, input: &[u8], output: &mut [u8]) {
        let n_pixels = input.len() / 4;

        // hists[0] = R, hists[1] = G, hists[2] = B.
        let mut hists = [[0u64; 256]; 3];
        for px in input.chunks_exact(4) {
            for (hist, &v) in hists.iter_mut().zip(&px[..3]) {
                hist[usize::from(v)] += 1;
            }
        }
        let luts = hists
            .map(|h| build_equalization_lut(&self.apply_clip_limit(h, n_pixels), n_pixels));

        output
            .par_chunks_exact_mut(4)
            .zip(input.par_chunks_exact(4))
            .for_each(|(out, inn)| {
                for ((o, &i), lut) in out[..3].iter_mut().zip(&inn[..3]).zip(&luts) {
                    *o = lut[usize::from(i)];
                }
                out[3] = inn[3];
            });
    }

    /// Caps each histogram bin and spreads the clipped excess over all 256 bins.
    ///
    /// The cap is `round(clip_limit.clamp(0, 1) * n_pixels / 256)`, i.e. a fraction
    /// of the mean bin count. The histogram is returned unchanged when
    /// `clip_limit <= 0` or when the cap rounds to zero.
    ///
    /// The excess is added back as `excess / 256` per bin, and the first
    /// `excess % 256` bins get one extra count. This keeps the histogram total
    /// unchanged. It is a single pass, so redistributed bins may end up above the
    /// cap again.
    fn apply_clip_limit(&self, mut hist: [u64; 256], n_pixels: usize) -> [u64; 256] {
        let clip = self.config.clip_limit;
        if clip <= 0.0 {
            return hist;
        }
        let max_count = (clip.clamp(0.0, 1.0) * n_pixels as f32 / 256.0).round() as u64;
        if max_count == 0 {
            return hist;
        }

        let mut clipped_total = 0u64;
        for bin in &mut hist {
            if *bin > max_count {
                clipped_total += *bin - max_count;
                *bin = max_count;
            }
        }

        let redistribute = clipped_total / 256;
        let remainder = (clipped_total % 256) as usize;
        for bin in &mut hist {
            *bin += redistribute;
        }
        for bin in hist.iter_mut().take(remainder) {
            *bin += 1;
        }
        hist
    }
}
/// Returns the BT.601 weighted luma of an (R, G, B) byte triple, rounded to the
/// nearest integer and clamped to `[0, 255]`.
#[inline(always)]
fn luma_bt601(r: u8, g: u8, b: u8) -> u8 {
    const KR: f32 = 0.299;
    const KG: f32 = 0.587;
    const KB: f32 = 0.114;
    let [rf, gf, bf] = [r, g, b].map(f32::from);
    // Evaluated left to right, exactly as `KR*r + KG*g + KB*b`.
    let weighted = KR * rf + KG * gf + KB * bf;
    weighted.clamp(0.0, 255.0).round() as u8
}
/// Builds a 256-entry histogram-equalization lookup table.
///
/// Uses the classic mapping
/// `lut[v] = round((cdf[v] - cdf_min) / (total - cdf_min) * 255)`, where:
/// - `cdf_min` is the CDF at the first non-empty bin;
/// - `total` is the histogram's own sum (`cdf[255]`).
///
/// Values whose CDF is still zero (below the first occupied bin) map to 0. If all
/// counts sit in a single bin (`total == cdf_min`), every occupied value maps to
/// 255.
///
/// `_n_pixels` is kept for signature compatibility and is not used. The
/// denominator comes from the histogram itself, so a caller-supplied count that
/// disagrees with the histogram cannot skew the table.
fn build_equalization_lut(hist: &[u64; 256], _n_pixels: usize) -> [u8; 256] {
    let mut cdf = [0u64; 256];
    let mut running = 0u64;
    for (c, &h) in cdf.iter_mut().zip(hist) {
        running += h;
        *c = running;
    }
    let total = running;
    let cdf_min = cdf.iter().copied().find(|&v| v > 0).unwrap_or(0);
    // `cdf_min` is one of the CDF values, so it never exceeds `total`.
    let denom = total - cdf_min;

    let mut lut = [0u8; 256];
    for (out, &c) in lut.iter_mut().zip(&cdf) {
        *out = if c == 0 {
            0
        } else if denom == 0 {
            255
        } else {
            // The CDF is non-decreasing, so any non-zero `c` is >= `cdf_min`.
            let num = c - cdf_min;
            ((num as f64 / denom as f64) * 255.0)
                .round()
                .clamp(0.0, 255.0) as u8
        };
    }
    lut
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `w x h` RGBA image where every pixel is `(v, v, v, 255)`.
    fn gray_rgba(w: u32, h: u32, v: u8) -> Vec<u8> {
        vec![v, v, v, 255u8].repeat((w * h) as usize)
    }

    /// Builds a `w x h` RGBA grey ramp from 0 to 255 in row-major order, opaque alpha.
    fn gradient_rgba(w: u32, h: u32) -> Vec<u8> {
        (0..(w * h))
            .flat_map(|i| {
                let v = ((i * 255) / (w * h - 1).max(1)) as u8;
                [v, v, v, 255u8]
            })
            .collect()
    }

    #[test]
    fn test_lut_uniform_histogram() {
        let hist = [4u64; 256];
        let lut = build_equalization_lut(&hist, 1024);
        assert_eq!(lut[0], 0, "first bin maps to 0");
        assert_eq!(lut[255], 255, "last bin maps to 255");
        for i in 1..256 {
            assert!(lut[i] >= lut[i - 1], "LUT must be monotone at {i}");
        }
    }

    #[test]
    fn test_lut_single_value_histogram() {
        let mut hist = [0u64; 256];
        hist[128] = 100;
        let lut = build_equalization_lut(&hist, 100);
        assert_eq!(lut[128], 255, "single-value bin maps to 255");
    }

    #[test]
    fn test_lut_empty_histogram() {
        let lut = build_equalization_lut(&[0u64; 256], 0);
        assert_eq!(lut, [0u8; 256], "empty histogram maps everything to 0");
    }

    #[test]
    fn test_luma_pure_red() {
        let y = luma_bt601(255, 0, 0);
        assert_eq!(y, 76, "BT.601 luma of red ≈ 76");
    }

    #[test]
    fn test_luma_pure_green() {
        let y = luma_bt601(0, 255, 0);
        assert_eq!(y, 150, "BT.601 luma of green ≈ 150");
    }

    #[test]
    fn test_luma_pure_blue() {
        let y = luma_bt601(0, 0, 255);
        assert_eq!(y, 29, "BT.601 luma of blue ≈ 29");
    }

    #[test]
    fn test_luma_white() {
        let y = luma_bt601(255, 255, 255);
        assert_eq!(y, 255, "luma of white = 255");
    }

    #[test]
    fn test_luma_black() {
        let y = luma_bt601(0, 0, 0);
        assert_eq!(y, 0, "luma of black = 0");
    }

    #[test]
    fn test_equalize_constant_image_luma() {
        let w = 8u32;
        let h = 8u32;
        let input = gray_rgba(w, h, 100);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let eq = HistogramEqualizer::default_config();
        eq.equalize_cpu(&input, &mut output, w, h)
            .expect("equalize constant image");
        for px in output.chunks_exact(4) {
            // A single-bin histogram maps its value to 255 (see
            // `test_lut_single_value_histogram`), and a grey pixel stays grey.
            assert_eq!(&px[..3], &[255u8, 255, 255], "constant grey maps to white");
            assert_eq!(px[3], 255, "alpha must be preserved");
        }
    }

    #[test]
    fn test_equalize_gradient_luma_monotone() {
        let w = 16u32;
        let h = 16u32;
        let input = gradient_rgba(w, h);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let eq = HistogramEqualizer::default_config();
        eq.equalize_cpu(&input, &mut output, w, h)
            .expect("equalize gradient");
        let mut prev_y = 0u8;
        for px in output.chunks_exact(4) {
            let y = luma_bt601(px[0], px[1], px[2]);
            assert!(
                y >= prev_y,
                "output luma must be non-decreasing: prev={prev_y}, cur={y}"
            );
            prev_y = y;
        }
    }

    #[test]
    fn test_equalize_per_channel() {
        let w = 8u32;
        let h = 8u32;
        let input = gradient_rgba(w, h);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let eq = HistogramEqualizer::new(HistogramEqualizerConfig {
            mode: EqualizationMode::PerChannel,
            clip_limit: 0.0,
        });
        eq.equalize_cpu(&input, &mut output, w, h)
            .expect("equalize per channel");
        let n = (w * h) as usize;
        assert_eq!(output[0], 0, "first pixel red = 0 after per-channel eq");
        assert_eq!(
            output[(n - 1) * 4],
            255,
            "last pixel red = 255 after per-channel eq"
        );
    }

    #[test]
    fn test_equalize_alpha_passthrough_luma() {
        let w = 4u32;
        let h = 4u32;
        let input: Vec<u8> = (0..w * h * 4)
            .map(|i| if i % 4 == 3 { 200u8 } else { 128 })
            .collect();
        let mut output = vec![0u8; (w * h * 4) as usize];
        HistogramEqualizer::default_config()
            .equalize_cpu(&input, &mut output, w, h)
            .expect("equalize alpha passthrough luma");
        for px in output.chunks_exact(4) {
            assert_eq!(px[3], 200, "alpha must pass through");
        }
    }

    #[test]
    fn test_equalize_alpha_passthrough_per_channel() {
        let w = 4u32;
        let h = 4u32;
        let input: Vec<u8> = (0..w * h * 4)
            .map(|i| if i % 4 == 3 { 77u8 } else { 100 })
            .collect();
        let mut output = vec![0u8; (w * h * 4) as usize];
        HistogramEqualizer::new(HistogramEqualizerConfig {
            mode: EqualizationMode::PerChannel,
            clip_limit: 0.0,
        })
        .equalize_cpu(&input, &mut output, w, h)
        .expect("equalize alpha passthrough per channel");
        for px in output.chunks_exact(4) {
            assert_eq!(px[3], 77, "alpha must pass through");
        }
    }

    #[test]
    fn test_equalize_invalid_dimensions() {
        let input = vec![0u8; 64];
        let mut output = vec![0u8; 64];
        let result = HistogramEqualizer::default_config().equalize_cpu(&input, &mut output, 0, 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_equalize_buffer_too_small() {
        let input = vec![0u8; 4];
        let mut output = vec![0u8; 64];
        let result = HistogramEqualizer::default_config().equalize_cpu(&input, &mut output, 4, 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_clip_limit_preserves_total() {
        let eq = HistogramEqualizer::new(HistogramEqualizerConfig {
            mode: EqualizationMode::LumaOnly,
            clip_limit: 0.3,
        });
        let mut hist = [0u64; 256];
        hist[100] = 500;
        hist[150] = 300;
        let n = 800usize;
        let clipped = eq.apply_clip_limit(hist, n);
        let total: u64 = clipped.iter().sum();
        assert_eq!(total, n as u64, "clip limit must preserve pixel count");
    }

    #[test]
    fn test_clip_limit_caps_bins() {
        // cap = round(0.3 * 800 / 256) = 1; excess = 499 + 299 = 798
        // -> +3 per bin, and the first 798 % 256 = 30 bins get +1 more.
        let eq = HistogramEqualizer::new(HistogramEqualizerConfig {
            mode: EqualizationMode::LumaOnly,
            clip_limit: 0.3,
        });
        let mut hist = [0u64; 256];
        hist[100] = 500;
        hist[150] = 300;
        let clipped = eq.apply_clip_limit(hist, 800);
        assert!(clipped.iter().all(|&c| c <= 4), "no bin exceeds cap + share");
        assert_eq!(clipped[0], 4);
        assert_eq!(clipped[100], 4);
        assert_eq!(clipped[150], 4);
        assert_eq!(clipped[255], 3);
    }

    #[test]
    fn test_clip_limit_zero_no_change() {
        let eq = HistogramEqualizer::new(HistogramEqualizerConfig {
            mode: EqualizationMode::LumaOnly,
            clip_limit: 0.0,
        });
        let hist = [10u64; 256];
        let result = eq.apply_clip_limit(hist, 2560);
        assert_eq!(hist, result, "zero clip limit must not change histogram");
    }
}