use super::{PreprocessError, Result};
use image::{GrayImage, Luma};
use std::cmp;
/// Applies Contrast Limited Adaptive Histogram Equalization (CLAHE).
///
/// The image is split into `tile_size` x `tile_size` tiles (edge tiles may be
/// smaller). A clipped, normalised CDF is computed per tile, and every output
/// pixel is obtained by bilinearly blending the CDFs of the tiles whose
/// centres surround it. Pixels that lie beyond the outermost tile centres use
/// the nearest tile(s) only.
///
/// # Arguments
///
/// * `image` - source grayscale image.
/// * `clip_limit` - histogram clip limit, expressed as a multiple of the
///   average histogram bin height of a tile. Must be a positive number.
/// * `tile_size` - edge length of a tile in pixels. Must be non-zero.
///
/// # Errors
///
/// Returns `PreprocessError::InvalidParameters` when `tile_size` is zero or
/// when `clip_limit` is NaN, zero or negative.
pub fn clahe(image: &GrayImage, clip_limit: f32, tile_size: u32) -> Result<GrayImage> {
    // NaN compares false against everything, so it has to be rejected explicitly.
    if tile_size == 0 || clip_limit.is_nan() || clip_limit <= 0.0 {
        return Err(PreprocessError::InvalidParameters(
            "Invalid CLAHE parameters".to_string(),
        ));
    }

    let (width, height) = image.dimensions();
    let mut result = GrayImage::new(width, height);
    if width == 0 || height == 0 {
        return Ok(result);
    }

    // Ceiling division written so that it cannot overflow for large tile sizes.
    let tiles_x = (width - 1) / tile_size + 1;
    let tiles_y = (height - 1) / tile_size + 1;

    // tile_cdfs[ty][tx][level] is the equalised value (0.0..=1.0) of `level` in that tile.
    let tile_cdfs: Vec<Vec<Vec<f32>>> = (0..tiles_y)
        .map(|ty| {
            let y_start = ty * tile_size;
            let y_end = cmp::min(y_start.saturating_add(tile_size), height);
            (0..tiles_x)
                .map(|tx| {
                    let x_start = tx * tile_size;
                    let x_end = cmp::min(x_start.saturating_add(tile_size), width);
                    compute_tile_cdf(image, x_start, y_start, x_end, y_end, clip_limit)
                })
                .collect()
        })
        .collect();

    // The horizontal neighbours depend only on x, so compute them once per column.
    let columns: Vec<(usize, usize, f32)> = (0..width)
        .map(|x| bracketing_tiles(x, tile_size, tiles_x))
        .collect();

    let lerp = |a: f32, b: f32, t: f32| a * (1.0 - t) + b * t;

    for y in 0..height {
        let (ty0, ty1, wy) = bracketing_tiles(y, tile_size, tiles_y);
        for (x, &(tx0, tx1, wx)) in (0..width).zip(columns.iter()) {
            let level = image.get_pixel(x, y)[0] as usize;
            let top = lerp(tile_cdfs[ty0][tx0][level], tile_cdfs[ty0][tx1][level], wx);
            let bottom = lerp(tile_cdfs[ty1][tx0][level], tile_cdfs[ty1][tx1][level], wx);
            let value = lerp(top, bottom, wy);
            // Round to nearest instead of truncating, then keep the result inside u8 range.
            let out = (value * 255.0).round().clamp(0.0, 255.0) as u8;
            result.put_pixel(x, y, Luma([out]));
        }
    }

    Ok(result)
}

/// Finds the two tiles along one axis whose centres bracket pixel `coord`.
///
/// Returns `(lower, upper, weight)` where `weight` is the blend factor of
/// `upper` (0.0 means "only `lower`"). Near the image border both indices are
/// the same edge tile, so blending degenerates to that tile's value.
fn bracketing_tiles(coord: u32, tile_size: u32, tile_count: u32) -> (usize, usize, f32) {
    let last = tile_count.saturating_sub(1) as usize;
    // Position of the pixel centre in tile units, measured from the centre of tile 0.
    let pos = (coord as f32 + 0.5) / tile_size as f32 - 0.5;
    if pos <= 0.0 {
        return (0, 0, 0.0);
    }
    let base = pos.floor();
    let weight = pos - base;
    let lower = cmp::min(base as usize, last);
    let upper = cmp::min(lower + 1, last);
    (lower, upper, weight)
}
/// Builds the clipped, normalised cumulative histogram of one tile.
///
/// The tile covers `x_start..x_end` by `y_start..y_end` (end-exclusive).
/// `clip_limit` is a multiple of the tile's average bin height
/// (`pixel_count / 256`); counts above the resulting absolute limit are cut
/// off and the excess is spread evenly over all 256 bins.
///
/// Returns 256 values where entry `i` is the fraction of (redistributed)
/// tile pixels with intensity `<= i`. An empty tile yields all zeros.
fn compute_tile_cdf(
    image: &GrayImage,
    x_start: u32,
    y_start: u32,
    x_end: u32,
    y_end: u32,
    clip_limit: f32,
) -> Vec<f32> {
    let mut histogram = [0u32; 256];
    for y in y_start..y_end {
        for x in x_start..x_end {
            histogram[image.get_pixel(x, y)[0] as usize] += 1;
        }
    }

    let pixel_count: u32 = histogram.iter().sum();
    if pixel_count == 0 {
        return vec![0.0; 256];
    }

    // For small tiles the scaled limit is below 1 and would truncate to 0,
    // which erases the whole histogram and makes every tile produce the same
    // ramp. One count per bin is the smallest limit that keeps information.
    let clip_limit_actual = ((clip_limit * pixel_count as f32 / 256.0) as u32).max(1);

    let mut clipped_total = 0u32;
    for count in histogram.iter_mut() {
        if *count > clip_limit_actual {
            clipped_total += *count - clip_limit_actual;
            *count = clip_limit_actual;
        }
    }

    let redistribute = clipped_total / 256;
    let remainder = (clipped_total % 256) as usize;
    for (i, count) in histogram.iter_mut().enumerate() {
        // Spread the `remainder` leftover counts evenly across the intensity
        // range instead of piling them onto the darkest bins. The differences
        // telescope, so exactly `remainder` bins receive one extra count.
        let extra = (i + 1) * remainder / 256 - i * remainder / 256;
        *count += redistribute + extra as u32;
    }

    let total = pixel_count as f32;
    let mut cumsum = 0u32;
    histogram
        .iter()
        .map(|&count| {
            cumsum += count;
            cumsum as f32 / total
        })
        .collect()
}
/// Shifts every pixel by a constant so that the image mean moves to 128.
///
/// The offset is `128 - mean`; shifted values are rounded to the nearest
/// integer and saturated to `0..=255`, so the resulting mean can deviate from
/// 128 when many pixels saturate. An empty image is returned unchanged in size.
pub fn normalize_brightness(image: &GrayImage) -> GrayImage {
    const TARGET_MEAN: f64 = 128.0;

    let (width, height) = image.dimensions();
    let mut result = GrayImage::new(width, height);

    // 64-bit arithmetic: both `width * height` and the sum of pixel values can
    // exceed u32 for large images (the sum already does above ~16.8 bright megapixels).
    let pixel_count = u64::from(width) * u64::from(height);
    if pixel_count == 0 {
        return result;
    }
    let sum: u64 = image.pixels().map(|p| u64::from(p[0])).sum();
    let mean = sum as f64 / pixel_count as f64;
    let adjustment = (TARGET_MEAN - mean) as f32;

    for (x, y, pixel) in image.enumerate_pixels() {
        // Round rather than truncate so the shift is not biased towards darker values.
        let adjusted = (f32::from(pixel[0]) + adjustment)
            .round()
            .clamp(0.0, 255.0) as u8;
        result.put_pixel(x, y, Luma([adjusted]));
    }
    result
}
pub fn remove_shadows(image: &GrayImage) -> Result<GrayImage> {
let (width, height) = image.dimensions();
let kernel_size = (width.min(height) / 20).max(15) as usize;
let background = estimate_background(image, kernel_size);
let mut result = GrayImage::new(width, height);
for (x, y, pixel) in image.enumerate_pixels() {
let bg = background.get_pixel(x, y)[0] as i32;
let fg = pixel[0] as i32;
let normalized = if bg > 0 {
((fg as f32 / bg as f32) * 255.0).min(255.0) as u8
} else {
fg as u8
};
result.put_pixel(x, y, Luma([normalized]));
}
Ok(result)
}
fn estimate_background(image: &GrayImage, kernel_size: usize) -> GrayImage {
let (width, height) = image.dimensions();
let mut background = GrayImage::new(width, height);
let half_kernel = (kernel_size / 2) as i32;
for y in 0..height {
for x in 0..width {
let mut max_val = 0u8;
for ky in -(half_kernel)..=half_kernel {
for kx in -(half_kernel)..=half_kernel {
let px = (x as i32 + kx).clamp(0, width as i32 - 1) as u32;
let py = (y as i32 + ky).clamp(0, height as i32 - 1) as u32;
let val = image.get_pixel(px, py)[0];
if val > max_val {
max_val = val;
}
}
}
background.put_pixel(x, y, Luma([max_val]));
}
}
background
}
/// Linearly rescales intensities so the darkest pixel becomes 0 and the
/// brightest becomes 255.
///
/// Images with no intensity range to stretch (a single uniform value, or no
/// pixels at all) are returned as an unchanged copy.
pub fn contrast_stretch(image: &GrayImage) -> GrayImage {
    let (min_val, max_val) = image
        .pixels()
        .map(|p| p[0])
        .fold((u8::MAX, u8::MIN), |(lo, hi), v| (lo.min(v), hi.max(v)));

    // `>=` also covers the empty image, where the fold leaves min = 255 and
    // max = 0 and the subtraction below would underflow.
    if min_val >= max_val {
        return image.clone();
    }

    let (width, height) = image.dimensions();
    let mut result = GrayImage::new(width, height);
    let range = f32::from(max_val - min_val);
    for (x, y, pixel) in image.enumerate_pixels() {
        let stretched = (f32::from(pixel[0] - min_val) / range * 255.0) as u8;
        result.put_pixel(x, y, Luma([stretched]));
    }
    result
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 100x100 diagonal gradient: the value at (x, y) is `(x + y) / 2`.
    fn diagonal_gradient() -> GrayImage {
        GrayImage::from_fn(100, 100, |x, y| Luma([((x + y) / 2) as u8]))
    }

    /// Smallest and largest intensity found in `img`.
    fn value_range(img: &GrayImage) -> (u8, u8) {
        img.pixels()
            .map(|p| p[0])
            .fold((255u8, 0u8), |(lo, hi), v| (lo.min(v), hi.max(v)))
    }

    #[test]
    fn test_clahe() {
        let source = diagonal_gradient();
        let outcome = clahe(&source, 2.0, 8);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().dimensions(), source.dimensions());
    }

    #[test]
    fn test_clahe_invalid_params() {
        let source = diagonal_gradient();
        // A zero tile size and a negative clip limit must both be rejected.
        assert!(clahe(&source, 2.0, 0).is_err());
        assert!(clahe(&source, -1.0, 8).is_err());
    }

    #[test]
    fn test_normalize_brightness() {
        let source = diagonal_gradient();
        let shifted = normalize_brightness(&source);
        assert_eq!(shifted.dimensions(), source.dimensions());

        let total: u32 = shifted.pixels().map(|p| p[0] as u32).sum();
        let mean = total as f32 / (100.0 * 100.0);
        assert!((mean - 128.0).abs() < 5.0);
    }

    #[test]
    fn test_remove_shadows() {
        let source = diagonal_gradient();
        let outcome = remove_shadows(&source);
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap().dimensions(), source.dimensions());
    }

    #[test]
    fn test_contrast_stretch() {
        // Narrow input range (100..=119) that should expand to the full 0..=255.
        let narrow = GrayImage::from_fn(100, 100, |x, y| Luma([100 + ((x + y) / 10) as u8]));
        let (lowest, highest) = value_range(&contrast_stretch(&narrow));
        assert_eq!(lowest, 0);
        assert_eq!(highest, 255);
    }

    #[test]
    fn test_contrast_stretch_uniform() {
        let flat = GrayImage::from_pixel(50, 50, Luma([128]));
        let stretched = contrast_stretch(&flat);
        for px in stretched.pixels() {
            assert_eq!(px[0], 128);
        }
    }

    #[test]
    fn test_estimate_background() {
        let source = diagonal_gradient();
        let background = estimate_background(&source, 5);
        assert_eq!(background.dimensions(), source.dimensions());
        // A local maximum can never be darker than the pixel it is centred on.
        for (original, estimate) in source.pixels().zip(background.pixels()) {
            assert!(estimate[0] >= original[0]);
        }
    }

    #[test]
    fn test_clahe_various_tile_sizes() {
        let source = diagonal_gradient();
        for tile_size in [4, 8, 16, 32] {
            assert!(clahe(&source, 2.0, tile_size).is_ok());
        }
    }
}