Skip to main content

oximedia_gpu/
histogram_equalization.rs

//! Histogram equalization for 8-bit luma images.
//!
//! The routines in this module run on the CPU and process pixels
//! sequentially.  Two contrast-enhancement algorithms are provided for
//! single-channel (luma) images:
//!
//! * **Global equalization** – [`HistogramEqualizer::equalize_luma`] applies a
//!   single CDF-based tone mapping to the entire image.
//!
//! * **CLAHE** – [`HistogramEqualizer::clahe`] divides the image into square
//!   tiles, clips each local histogram at `clip_limit` times the tile's
//!   average bin count, computes tile-local equalization tables, and
//!   bilinearly interpolates the four nearest tile tables for each output
//!   pixel.

14// ─── ClaheConfig ──────────────────────────────────────────────────────────────
15
/// Configuration for Contrast Limited Adaptive Histogram Equalization.
///
/// `clip_limit` and `tile_size` correspond to the arguments of the same name
/// taken by [`HistogramEqualizer::clahe`].
#[derive(Debug, Clone, PartialEq)]
pub struct ClaheConfig {
    /// Histogram clip limit, expressed as a multiple of a tile's average
    /// per-bin count.  Values of 2.0–4.0 are typical; higher values produce
    /// stronger contrast enhancement.
    pub clip_limit: f32,
    /// Tile edge length in pixels.  Typical values: 8, 16, 32.
    pub tile_size: u32,
    /// Requests parallel tile processing.
    ///
    /// NOTE: the CLAHE implementation in this module currently processes
    /// tiles sequentially and does not read this flag.
    pub use_parallel: bool,
}

impl Default for ClaheConfig {
    /// Returns `clip_limit = 2.0`, `tile_size = 8` and `use_parallel = true`.
    fn default() -> Self {
        Self {
            clip_limit: 2.0,
            tile_size: 8,
            use_parallel: true,
        }
    }
}
37
38// ─── EqualizationStats ────────────────────────────────────────────────────────
39
40/// Descriptive statistics comparing an original and an equalized image.
41#[derive(Debug, Clone)]
42pub struct EqualizationStats {
43    /// Mean of original pixel values.
44    pub original_mean: f64,
45    /// Mean of equalized pixel values.
46    pub equalized_mean: f64,
47    /// Standard deviation of original pixel values.
48    pub original_std_dev: f64,
49    /// Standard deviation of equalized pixel values.
50    pub equalized_std_dev: f64,
51}
52
53impl EqualizationStats {
54    /// Compute statistics from a pair of same-length byte slices.
55    ///
56    /// Both slices are interpreted as 8-bit luma values.  If either slice is
57    /// empty, all statistics default to 0.0.
58    #[must_use]
59    pub fn compute(original: &[u8], equalized: &[u8]) -> Self {
60        let (orig_mean, orig_std) = mean_stddev(original);
61        let (eq_mean, eq_std) = mean_stddev(equalized);
62        Self {
63            original_mean: orig_mean,
64            equalized_mean: eq_mean,
65            original_std_dev: orig_std,
66            equalized_std_dev: eq_std,
67        }
68    }
69}
70
71// ─── HistogramEqualizer ───────────────────────────────────────────────────────
72
/// Histogram equalization algorithms for 8-bit luma images.
///
/// The algorithms are exposed as associated functions
/// ([`HistogramEqualizer::equalize_luma`], [`HistogramEqualizer::clahe`]) with
/// thin instance-method wrappers.
///
/// Note that [`HistogramEqualizer::new`] sets `use_parallel` to `true`, while
/// the derived [`Default`] implementation sets it to `false`.
#[derive(Debug, Clone, Default)]
pub struct HistogramEqualizer {
    /// Requests parallel tile processing in CLAHE.
    ///
    /// NOTE: the current implementation processes tiles sequentially and does
    /// not read this flag.
    pub use_parallel: bool,
}
79
80impl HistogramEqualizer {
81    /// Construct a new equalizer with parallel processing enabled.
82    #[must_use]
83    pub fn new() -> Self {
84        Self { use_parallel: true }
85    }
86
87    // ── Global equalization ───────────────────────────────────────────────────
88
89    /// Apply global histogram equalization to a single-channel (luma) image.
90    ///
91    /// If the frame contains a single unique value, the input is returned
92    /// unchanged.
93    ///
94    /// # Arguments
95    ///
96    /// * `frame` – packed 8-bit luma bytes, row-major.
97    /// * `width` / `height` – image dimensions (informational; total pixels is
98    ///   `frame.len()`).
99    #[must_use]
100    pub fn equalize_luma(frame: &[u8], width: u32, height: u32) -> Vec<u8> {
101        let _ = (width, height); // dimensions for future use
102        if frame.is_empty() {
103            return Vec::new();
104        }
105        let lut = build_global_lut(frame);
106        frame.iter().map(|&p| lut[usize::from(p)]).collect()
107    }
108
109    /// Instance method wrapping the static [`equalize_luma`].
110    ///
111    /// [`equalize_luma`]: HistogramEqualizer::equalize_luma
112    #[must_use]
113    pub fn equalize_luma_instance(&self, frame: &[u8], width: u32, height: u32) -> Vec<u8> {
114        Self::equalize_luma(frame, width, height)
115    }
116
117    // ── CLAHE ─────────────────────────────────────────────────────────────────
118
119    /// Apply Contrast Limited Adaptive Histogram Equalization.
120    ///
121    /// The image is partitioned into a `tile_size × tile_size` grid.  Each
122    /// tile's histogram is clipped at `clip_limit`, redistributed, and used to
123    /// derive a local look-up table.  Each output pixel is produced by
124    /// bilinear interpolation of the four surrounding tile LUTs.
125    ///
126    /// # Arguments
127    ///
128    /// * `frame` – packed 8-bit luma bytes, row-major.
129    /// * `width` / `height` – image dimensions.
130    /// * `clip_limit` – histogram clip ratio; 1.0 = fully clipped (equivalent
131    ///   to global equalization), larger values allow more local contrast.
132    /// * `tile_size` – tile edge length in pixels.
133    #[must_use]
134    pub fn clahe(
135        frame: &[u8],
136        width: u32,
137        height: u32,
138        clip_limit: f32,
139        tile_size: u32,
140    ) -> Vec<u8> {
141        if frame.is_empty() || tile_size == 0 || width == 0 || height == 0 {
142            return frame.to_vec();
143        }
144
145        // If tile_size covers the whole image, fall back to global equalization.
146        if tile_size >= width || tile_size >= height {
147            return Self::equalize_luma(frame, width, height);
148        }
149
150        let w = width as usize;
151        let h = height as usize;
152        let ts = tile_size as usize;
153
154        // Number of tiles in each dimension.
155        let tiles_x = (w + ts - 1) / ts;
156        let tiles_y = (h + ts - 1) / ts;
157
158        // Build per-tile LUTs.
159        let tile_luts = build_tile_luts(frame, w, h, ts, tiles_x, tiles_y, clip_limit);
160
161        // Produce output by bilinear interpolation.
162        interpolate_output(frame, w, h, ts, tiles_x, tiles_y, &tile_luts)
163    }
164
165    /// Instance method wrapping the static [`clahe`].
166    ///
167    /// [`clahe`]: HistogramEqualizer::clahe
168    #[must_use]
169    pub fn clahe_instance(
170        &self,
171        frame: &[u8],
172        width: u32,
173        height: u32,
174        clip_limit: f32,
175        tile_size: u32,
176    ) -> Vec<u8> {
177        Self::clahe(frame, width, height, clip_limit, tile_size)
178    }
179}
180
181// ─── Private helpers ──────────────────────────────────────────────────────────
182
/// Count how many times each of the 256 possible byte values occurs in `data`.
///
/// Index `v` of the returned array holds the number of bytes equal to `v`.
fn compute_histogram(data: &[u8]) -> [u32; 256] {
    let mut bins = [0u32; 256];
    data.iter().for_each(|&value| bins[usize::from(value)] += 1);
    bins
}
191
/// Clip every bin of `hist` to at most `clip_limit` counts and spread the
/// removed excess evenly across all 256 bins.
///
/// `clip_limit` is an absolute count; callers convert their clip ratio
/// (a multiple of the average bin count) before calling.  A `clip_limit` of 0
/// disables clipping and leaves `hist` untouched.
///
/// The total count is preserved: each bin receives `excess / 256` extra
/// counts, and the first `excess % 256` bins receive one more.  This is a
/// single-pass redistribution, so a bin may end up exceeding `clip_limit` by
/// up to `excess / 256 + 1`.
fn clip_histogram(hist: &mut [u32; 256], clip_limit: u32) {
    if clip_limit == 0 {
        return;
    }
    let mut excess: u64 = 0;
    for bin in hist.iter_mut() {
        if *bin > clip_limit {
            excess += u64::from(*bin - clip_limit);
            *bin = clip_limit;
        }
    }
    // `excess` is at most 256 * u32::MAX, so the per-bin share fits in a u32.
    let add_per_bin = (excess / 256) as u32;
    let remainder = (excess % 256) as usize;
    for (i, bin) in hist.iter_mut().enumerate() {
        *bin += add_per_bin;
        // The leftover counts go to the lowest-valued bins.
        if i < remainder {
            *bin += 1;
        }
    }
}
216
/// Turn a histogram into its cumulative distribution.
///
/// Entry `v` of the result is the sum of `hist[0..=v]`.  Accumulation
/// saturates at `u32::MAX` instead of wrapping.
fn compute_cdf(hist: &[u32; 256]) -> [u32; 256] {
    let mut cumulative = [0u32; 256];
    let mut total = 0u32;
    for (slot, &count) in cumulative.iter_mut().zip(hist) {
        total = total.saturating_add(count);
        *slot = total;
    }
    cumulative
}
227
/// Map a cumulative histogram onto the full 0..=255 output range.
///
/// With `cdf_min` the first non-zero CDF entry, each entry becomes
/// `round((cdf[v] - cdf_min) / (total_pixels - cdf_min) * 255)`.  When that
/// denominator is below one (e.g. a single distinct input value), the
/// identity table is returned instead.
fn build_lut(cdf: &[u32; 256], total_pixels: u32) -> [u8; 256] {
    let first_nonzero = cdf.iter().copied().find(|&count| count > 0).unwrap_or(0);
    let span = f64::from(total_pixels.saturating_sub(first_nonzero));
    let degenerate = span < 1.0;

    let mut table = [0u8; 256];
    for (value, (entry, &count)) in table.iter_mut().zip(cdf).enumerate() {
        *entry = if degenerate {
            value as u8
        } else {
            let fraction = f64::from(count.saturating_sub(first_nonzero)) / span;
            (fraction * 255.0).round().clamp(0.0, 255.0) as u8
        };
    }
    table
}
246
247/// Build a global equalisation LUT for the full image.
248fn build_global_lut(frame: &[u8]) -> [u8; 256] {
249    let hist = compute_histogram(frame);
250    let cdf = compute_cdf(&hist);
251    build_lut(&cdf, frame.len() as u32)
252}
253
254/// Build one LUT per tile.
255fn build_tile_luts(
256    frame: &[u8],
257    w: usize,
258    h: usize,
259    ts: usize,
260    tiles_x: usize,
261    tiles_y: usize,
262    clip_limit: f32,
263) -> Vec<[u8; 256]> {
264    let num_tiles = tiles_x * tiles_y;
265    let mut luts: Vec<[u8; 256]> = vec![[0u8; 256]; num_tiles];
266
267    for ty in 0..tiles_y {
268        for tx in 0..tiles_x {
269            let tile_idx = ty * tiles_x + tx;
270
271            // Tile pixel bounds (clamped to image edges).
272            let x0 = tx * ts;
273            let y0 = ty * ts;
274            let x1 = (x0 + ts).min(w);
275            let y1 = (y0 + ts).min(h);
276            let tile_pixels = (x1 - x0) * (y1 - y0);
277
278            // Collect tile pixel values into a histogram.
279            let mut hist = [0u32; 256];
280            for row in y0..y1 {
281                for col in x0..x1 {
282                    let p = frame[row * w + col];
283                    hist[usize::from(p)] += 1;
284                }
285            }
286
287            // Clip limit is expressed as a ratio × average bin count.
288            let avg_bin = (tile_pixels as f32 / 256.0).max(1.0);
289            let clip_abs = ((clip_limit * avg_bin).round() as u32).max(1);
290            clip_histogram(&mut hist, clip_abs);
291
292            let cdf = compute_cdf(&hist);
293            luts[tile_idx] = build_lut(&cdf, tile_pixels as u32);
294        }
295    }
296
297    luts
298}
299
/// Blend the tile tables bilinearly to produce the equalized image.
///
/// Each tile's reference point is its centre.  For every pixel, the four
/// nearest tile centres are located (clamped to the grid), the pixel's value
/// is looked up in each of their tables, and the four results are weighted by
/// the pixel's fractional distance to those centres.  Pixels outside the
/// outermost centres use the nearest edge tile(s) only.
fn interpolate_output(
    frame: &[u8],
    w: usize,
    h: usize,
    ts: usize,
    tiles_x: usize,
    tiles_y: usize,
    tile_luts: &[[u8; 256]],
) -> Vec<u8> {
    // For one axis: (lower tile index, upper tile index, weight of the upper tile).
    let neighbours = |coord: usize, tile_count: usize| {
        let grid = ((coord as f64 + 0.5) / ts as f64) - 0.5;
        let lower = (grid.floor() as isize).clamp(0, tile_count as isize - 1) as usize;
        let upper = (lower + 1).min(tile_count - 1);
        let weight = (grid - lower as f64).clamp(0.0, 1.0);
        (lower, upper, weight)
    };

    let mut result = vec![0u8; frame.len()];

    for y in 0..h {
        for x in 0..w {
            let idx = y * w + x;
            let level = usize::from(frame[idx]);

            let (left, right, wx) = neighbours(x, tiles_x);
            let (top, bottom, wy) = neighbours(y, tiles_y);

            let top_left = f64::from(tile_luts[top * tiles_x + left][level]);
            let top_right = f64::from(tile_luts[top * tiles_x + right][level]);
            let bottom_left = f64::from(tile_luts[bottom * tiles_x + left][level]);
            let bottom_right = f64::from(tile_luts[bottom * tiles_x + right][level]);

            let blended = top_left * (1.0 - wx) * (1.0 - wy)
                + top_right * wx * (1.0 - wy)
                + bottom_left * (1.0 - wx) * wy
                + bottom_right * wx * wy;

            result[idx] = blended.round().clamp(0.0, 255.0) as u8;
        }
    }

    result
}
348
/// Return the arithmetic mean and population standard deviation of `data`.
///
/// An empty slice yields `(0.0, 0.0)`.
fn mean_stddev(data: &[u8]) -> (f64, f64) {
    let count = data.len();
    if count == 0 {
        return (0.0, 0.0);
    }
    let n = count as f64;
    let values = || data.iter().copied().map(f64::from);
    let mean = values().sum::<f64>() / n;
    let spread = values().map(|x| (x - mean) * (x - mean)).sum::<f64>() / n;
    (mean, spread.sqrt())
}
366
367// ─── Tests ───────────────────────────────────────────────────────────────────
368
#[cfg(test)]
mod tests {
    use super::*;

    // ── equalize_luma ─────────────────────────────────────────────────────────

    #[test]
    fn test_equalize_luma_empty() {
        let result = HistogramEqualizer::equalize_luma(&[], 0, 0);
        assert!(result.is_empty());
    }

    #[test]
    fn test_equalize_luma_all_same_value_unchanged() {
        let frame = vec![128u8; 100];
        let out = HistogramEqualizer::equalize_luma(&frame, 10, 10);
        // Single unique value: the CDF denominator is 0, so the LUT falls back
        // to the identity mapping and the frame comes back unchanged.
        assert_eq!(out, frame);
    }

    #[test]
    fn test_equalize_luma_ramp_spreads_contrast() {
        // Ramp from 0..=99 – after equalization the range should span all of
        // 0..=255.
        let frame: Vec<u8> = (0..100u8).collect();
        let out = HistogramEqualizer::equalize_luma(&frame, 100, 1);
        assert_eq!(out.len(), 100);
        let min = *out.iter().min().expect("non-empty output");
        let max = *out.iter().max().expect("non-empty output");
        assert!(max > min, "equalization should spread values");
        // The darkest value maps to 0 and the brightest to 255.
        assert_eq!(min, 0);
        assert_eq!(max, 255);
    }

    #[test]
    fn test_equalize_luma_single_pixel() {
        let frame = vec![77u8];
        let out = HistogramEqualizer::equalize_luma(&frame, 1, 1);
        // A single pixel is a single unique value → identity mapping.
        assert_eq!(out, frame);
    }

    #[test]
    fn test_equalize_luma_two_value_image() {
        // Half 0, half 255.
        let frame: Vec<u8> = (0..256).map(|i| if i < 128 { 0 } else { 255 }).collect();
        let out = HistogramEqualizer::equalize_luma(&frame, 256, 1);
        assert_eq!(out.len(), 256);
        // The two levels already sit at the extremes and stay there.
        assert_eq!(out, frame);
    }

    #[test]
    fn test_equalize_luma_preserves_size() {
        let frame: Vec<u8> = (0..=255).cycle().take(512).map(|v| v as u8).collect();
        let out = HistogramEqualizer::equalize_luma(&frame, 32, 16);
        assert_eq!(out.len(), 512);
    }

    #[test]
    fn test_equalize_luma_all_zeros() {
        let frame = vec![0u8; 64];
        let out = HistogramEqualizer::equalize_luma(&frame, 8, 8);
        assert_eq!(out.len(), 64);
        // Single unique value → identity mapping → still all zeros.
        assert_eq!(out, frame);
    }

    #[test]
    fn test_equalize_luma_already_equalized() {
        // 256 unique values 0..=255 – already equalized.
        let frame: Vec<u8> = (0u8..=255).collect();
        let out = HistogramEqualizer::equalize_luma(&frame, 256, 1);
        assert_eq!(out.len(), 256);
        assert_eq!(out[0], 0);
        assert_eq!(out[255], 255);
    }

    // ── equalize_luma_instance ────────────────────────────────────────────────

    #[test]
    fn test_equalize_luma_instance_method() {
        let eq = HistogramEqualizer::new();
        let frame: Vec<u8> = (0u8..=255).collect();
        let out = eq.equalize_luma_instance(&frame, 256, 1);
        assert_eq!(out.len(), 256);
        assert_eq!(out, HistogramEqualizer::equalize_luma(&frame, 256, 1));
    }

    // ── constructors ──────────────────────────────────────────────────────────

    #[test]
    fn test_new_and_default_parallel_flag() {
        assert!(HistogramEqualizer::new().use_parallel);
        assert!(!HistogramEqualizer::default().use_parallel);
    }

    // ── clahe ─────────────────────────────────────────────────────────────────

    #[test]
    fn test_clahe_basic_8x8_tile() {
        let w = 32u32;
        let h = 32u32;
        let frame: Vec<u8> = (0u8..=255).cycle().take((w * h) as usize).collect();
        let out = HistogramEqualizer::clahe(&frame, w, h, 2.0, 8);
        assert_eq!(out.len(), (w * h) as usize);
    }

    #[test]
    fn test_clahe_preserves_size() {
        let frame: Vec<u8> = vec![128u8; 256];
        let out = HistogramEqualizer::clahe(&frame, 16, 16, 2.0, 8);
        assert_eq!(out.len(), 256);
    }

    #[test]
    fn test_clahe_strong_clip() {
        let w = 64u32;
        let h = 64u32;
        let total = (w * h) as usize;
        let frame: Vec<u8> = (0..total).map(|i| (i % 256) as u8).collect();
        let out = HistogramEqualizer::clahe(&frame, w, h, 1.0, 8);
        assert_eq!(out.len(), total);
    }

    #[test]
    fn test_clahe_mild_clip() {
        let w = 32u32;
        let h = 32u32;
        let frame: Vec<u8> = (0u8..=255).cycle().take((w * h) as usize).collect();
        let out = HistogramEqualizer::clahe(&frame, w, h, 4.0, 8);
        assert_eq!(out.len(), (w * h) as usize);
    }

    #[test]
    fn test_clahe_tile_size_larger_than_image_falls_back() {
        let w = 4u32;
        let h = 4u32;
        let frame: Vec<u8> = (0u8..16).collect();
        // tile_size = 32 > image width → falls back to global equalization
        let out_clahe = HistogramEqualizer::clahe(&frame, w, h, 2.0, 32);
        let out_global = HistogramEqualizer::equalize_luma(&frame, w, h);
        assert_eq!(out_clahe, out_global);
    }

    #[test]
    fn test_clahe_tile_size_zero_returns_unchanged() {
        let frame = vec![100u8; 64];
        let out = HistogramEqualizer::clahe(&frame, 8, 8, 2.0, 0);
        assert_eq!(out, frame);
    }

    #[test]
    fn test_clahe_empty_frame() {
        let out = HistogramEqualizer::clahe(&[], 0, 0, 2.0, 8);
        assert!(out.is_empty());
    }

    // ── clahe_instance ────────────────────────────────────────────────────────

    #[test]
    fn test_clahe_instance_method() {
        let eq = HistogramEqualizer::new();
        let w = 16u32;
        let h = 16u32;
        let frame: Vec<u8> = (0u8..=255).cycle().take((w * h) as usize).collect();
        let out = eq.clahe_instance(&frame, w, h, 2.0, 8);
        assert_eq!(out.len(), (w * h) as usize);
        assert_eq!(out, HistogramEqualizer::clahe(&frame, w, h, 2.0, 8));
    }

    // ── EqualizationStats ─────────────────────────────────────────────────────

    #[test]
    fn test_equalization_stats_compute() {
        let original: Vec<u8> = vec![0, 0, 255, 255];
        let equalized: Vec<u8> = vec![0, 85, 170, 255];
        let stats = EqualizationStats::compute(&original, &equalized);
        assert!((stats.original_mean - 127.5).abs() < 1.0);
        assert!((stats.equalized_mean - 127.5).abs() < 1.0);
        assert!((stats.original_std_dev - 127.5).abs() < 1e-9);
        assert!(stats.equalized_std_dev > 0.0);
    }

    #[test]
    fn test_equalization_stats_empty() {
        let stats = EqualizationStats::compute(&[], &[]);
        assert_eq!(stats.original_mean, 0.0);
        assert_eq!(stats.equalized_mean, 0.0);
        assert_eq!(stats.original_std_dev, 0.0);
        assert_eq!(stats.equalized_std_dev, 0.0);
    }

    // ── ClaheConfig ───────────────────────────────────────────────────────────

    #[test]
    fn test_clahe_config_defaults() {
        let cfg = ClaheConfig::default();
        assert!((cfg.clip_limit - 2.0).abs() < 1e-6);
        assert_eq!(cfg.tile_size, 8);
        assert!(cfg.use_parallel);
    }
}