Skip to main content

oximedia_gpu/ops/
histogram_eq.rs

1//! GPU-accelerated histogram equalization with prefix-sum CDF mapping.
2//!
3//! # Algorithm
4//!
//! 1. **Compute histogram** – count pixel occurrences for each intensity bin
//!    (`0..=255`) over the luma channel, or over each of R, G and B in
//!    per-channel mode.
7//! 2. **Prefix-sum equalization mapping** – derive the CDF and create a
8//!    monotone mapping table `[u8; 256]`.
9//! 3. **Apply tone curve** – map every pixel through the equalization table.
10//!
//! Both luma-only and per-channel (R, G, B independently; alpha is passed
//! through unchanged) equalization are supported via
//! [`HistogramEqualizerConfig`].
//!
//! The current implementation runs on the CPU: histograms are accumulated
//! serially and the per-pixel mapping is parallelized with rayon.  A future GPU
//! compute-shader path can be dropped in behind the same public API.
16
17use crate::{GpuDevice, Result};
18use rayon::prelude::*;
19
20use super::utils;
21
22// ─────────────────────────────────────────────────────────────────────────────
23// Configuration
24// ─────────────────────────────────────────────────────────────────────────────
25
/// Selects which colour channels take part in histogram equalization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EqualizationMode {
    /// Equalize a single BT.601 luma (Y) histogram, then rescale R, G and B
    /// by the ratio of new to old luma.
    ///
    /// Colour ratios are approximately preserved; a channel that would exceed
    /// 255 after rescaling is clamped.
    LumaOnly,
    /// Build and equalize a separate histogram for each of R, G and B.
    ///
    /// Gives the strongest per-channel contrast stretch, at the cost of
    /// possible colour shifts.
    PerChannel,
}

impl Default for EqualizationMode {
    /// Luma-only equalization is the default mode.
    fn default() -> Self {
        EqualizationMode::LumaOnly
    }
}
44
45/// Configuration for [`HistogramEqualizer`].
46#[derive(Debug, Clone)]
47pub struct HistogramEqualizerConfig {
48    /// Which channels to equalize.
49    pub mode: EqualizationMode,
50    /// Clip limit for contrast-limited AHE (0.0 = no clipping, 1.0 = full clip).
51    ///
52    /// Values in (0.0, 1.0) implement a simplified CLAHE-like contrast limit.
53    /// A value of `0.0` gives standard histogram equalization.
54    pub clip_limit: f32,
55}
56
57impl Default for HistogramEqualizerConfig {
58    fn default() -> Self {
59        Self {
60            mode: EqualizationMode::default(),
61            clip_limit: 0.0,
62        }
63    }
64}
65
66// ─────────────────────────────────────────────────────────────────────────────
67// HistogramEqualizer
68// ─────────────────────────────────────────────────────────────────────────────
69
70/// GPU-accelerated (CPU-fallback) histogram equalizer.
71pub struct HistogramEqualizer {
72    config: HistogramEqualizerConfig,
73}
74
75impl HistogramEqualizer {
76    /// Create a new equalizer with the given configuration.
77    #[must_use]
78    pub fn new(config: HistogramEqualizerConfig) -> Self {
79        Self { config }
80    }
81
82    /// Create with default configuration (luma-only, no clip limit).
83    #[must_use]
84    pub fn default_config() -> Self {
85        Self::new(HistogramEqualizerConfig::default())
86    }
87
88    /// Equalize the histogram of an RGBA image.
89    ///
90    /// * `input` / `output` – packed RGBA, 4 bytes per pixel.
91    /// * `device` – reserved for GPU dispatch; CPU path is used automatically.
92    ///
93    /// # Errors
94    ///
95    /// Returns an error if dimensions are invalid or buffers are too small.
96    pub fn equalize(
97        &self,
98        _device: &GpuDevice,
99        input: &[u8],
100        output: &mut [u8],
101        width: u32,
102        height: u32,
103    ) -> Result<()> {
104        self.equalize_cpu(input, output, width, height)
105    }
106
107    /// CPU-only variant — usable without a GPU device.
108    ///
109    /// # Errors
110    ///
111    /// Returns an error if dimensions are invalid or buffers are too small.
112    pub fn equalize_cpu(
113        &self,
114        input: &[u8],
115        output: &mut [u8],
116        width: u32,
117        height: u32,
118    ) -> Result<()> {
119        utils::validate_dimensions(width, height)?;
120        utils::validate_buffer_size(input, width, height, 4)?;
121        utils::validate_buffer_size(output, width, height, 4)?;
122
123        match self.config.mode {
124            EqualizationMode::LumaOnly => self.equalize_luma(input, output, width, height),
125            EqualizationMode::PerChannel => self.equalize_per_channel(input, output, width, height),
126        }
127    }
128
129    // ── luma-only equalization ────────────────────────────────────────────────
130
131    fn equalize_luma(
132        &self,
133        input: &[u8],
134        output: &mut [u8],
135        width: u32,
136        height: u32,
137    ) -> Result<()> {
138        let n_pixels = (width * height) as usize;
139
140        // Step 1: build luma histogram.
141        let mut hist = [0u64; 256];
142        for px in input.chunks_exact(4) {
143            let y = luma_bt601(px[0], px[1], px[2]);
144            hist[y as usize] += 1;
145        }
146
147        // Step 2: apply clip limit.
148        let hist = self.apply_clip_limit(hist, n_pixels);
149
150        // Step 3: compute CDF mapping.
151        let lut = build_equalization_lut(&hist, n_pixels);
152
153        // Step 4: apply mapping — scale all channels proportionally.
154        output
155            .par_chunks_exact_mut(4)
156            .zip(input.par_chunks_exact(4))
157            .for_each(|(out, inn)| {
158                let y_orig = luma_bt601(inn[0], inn[1], inn[2]);
159                let y_eq = lut[y_orig as usize];
160
161                if y_orig == 0 {
162                    out[0] = 0;
163                    out[1] = 0;
164                    out[2] = 0;
165                } else {
166                    // Scale RGB proportionally to the new luma.
167                    let scale = f32::from(y_eq) / f32::from(y_orig);
168                    out[0] = (f32::from(inn[0]) * scale).clamp(0.0, 255.0).round() as u8;
169                    out[1] = (f32::from(inn[1]) * scale).clamp(0.0, 255.0).round() as u8;
170                    out[2] = (f32::from(inn[2]) * scale).clamp(0.0, 255.0).round() as u8;
171                }
172                out[3] = inn[3]; // pass-through alpha
173            });
174
175        Ok(())
176    }
177
178    // ── per-channel equalization ──────────────────────────────────────────────
179
180    fn equalize_per_channel(
181        &self,
182        input: &[u8],
183        output: &mut [u8],
184        width: u32,
185        height: u32,
186    ) -> Result<()> {
187        let n_pixels = (width * height) as usize;
188
189        // Build per-channel histograms.
190        let mut hist_r = [0u64; 256];
191        let mut hist_g = [0u64; 256];
192        let mut hist_b = [0u64; 256];
193
194        for px in input.chunks_exact(4) {
195            hist_r[px[0] as usize] += 1;
196            hist_g[px[1] as usize] += 1;
197            hist_b[px[2] as usize] += 1;
198        }
199
200        let hist_r = self.apply_clip_limit(hist_r, n_pixels);
201        let hist_g = self.apply_clip_limit(hist_g, n_pixels);
202        let hist_b = self.apply_clip_limit(hist_b, n_pixels);
203
204        let lut_r = build_equalization_lut(&hist_r, n_pixels);
205        let lut_g = build_equalization_lut(&hist_g, n_pixels);
206        let lut_b = build_equalization_lut(&hist_b, n_pixels);
207
208        output
209            .par_chunks_exact_mut(4)
210            .zip(input.par_chunks_exact(4))
211            .for_each(|(out, inn)| {
212                out[0] = lut_r[inn[0] as usize];
213                out[1] = lut_g[inn[1] as usize];
214                out[2] = lut_b[inn[2] as usize];
215                out[3] = inn[3];
216            });
217
218        Ok(())
219    }
220
221    // ── clip-limit redistribution ─────────────────────────────────────────────
222
223    fn apply_clip_limit(&self, mut hist: [u64; 256], n_pixels: usize) -> [u64; 256] {
224        let clip = self.config.clip_limit;
225        if clip <= 0.0 {
226            return hist;
227        }
228
229        let max_count = (clip.clamp(0.0, 1.0) * n_pixels as f32 / 256.0).round() as u64;
230        if max_count == 0 {
231            return hist;
232        }
233
234        // Accumulate clipped pixels and redistribute uniformly.
235        let mut clipped_total = 0u64;
236        for bin in &mut hist {
237            if *bin > max_count {
238                clipped_total += *bin - max_count;
239                *bin = max_count;
240            }
241        }
242
243        let redistribute = clipped_total / 256;
244        let remainder = (clipped_total % 256) as usize;
245        for bin in &mut hist {
246            *bin += redistribute;
247        }
248        for bin in hist.iter_mut().take(remainder) {
249            *bin += 1;
250        }
251
252        hist
253    }
254}
255
256// ─────────────────────────────────────────────────────────────────────────────
257// Helper functions
258// ─────────────────────────────────────────────────────────────────────────────
259
/// Compute the BT.601 luma of an RGB pixel, rounded to the nearest integer.
///
/// Evaluated in `f32` as `0.299·R + 0.587·G + 0.114·B`. The result is clamped
/// to `[0, 255]` before rounding so that any floating-point drift cannot push
/// it outside the `u8` range.
#[inline]
fn luma_bt601(r: u8, g: u8, b: u8) -> u8 {
    let y = 0.299 * f32::from(r) + 0.587 * f32::from(g) + 0.114 * f32::from(b);
    y.clamp(0.0, 255.0).round() as u8
}
266
/// Build a 256-entry equalization LUT from a histogram.
///
/// Uses the classic CDF-based formula
/// `lut[v] = round((cdf[v] - cdf_min) / (total - cdf_min) * 255)`,
/// where `cdf_min` is the first non-zero CDF value and `total = cdf[255]` is
/// the histogram's own sample count.
///
/// * Intensities below the first populated bin map to `0`.
/// * If every sample falls in a single bin, that bin (and every bin above it)
///   maps to `255`.
/// * An all-zero histogram yields an all-zero LUT.
///
/// `_n_pixels` is kept for call-site compatibility only. The denominator is
/// derived from the histogram itself, so the LUT stays monotone and spans
/// `[0, 255]` even if the caller's pixel count disagrees with the histogram
/// total.
fn build_equalization_lut(hist: &[u64; 256], _n_pixels: usize) -> [u8; 256] {
    // Inclusive prefix sum: cdf[v] = number of samples with intensity <= v.
    let mut cdf = [0u64; 256];
    let mut running = 0u64;
    for (c, &h) in cdf.iter_mut().zip(hist) {
        running += h;
        *c = running;
    }
    let total = running;

    // First non-zero CDF value; every cdf value is <= total, so no underflow.
    let cdf_min = cdf.iter().copied().find(|&v| v > 0).unwrap_or(0);
    let denom = total - cdf_min;

    let mut lut = [0u8; 256];
    for (out, &c) in lut.iter_mut().zip(&cdf) {
        *out = if c == 0 {
            // Before the first populated bin.
            0
        } else if denom == 0 {
            // All samples share one intensity level — map to maximum.
            255
        } else {
            let num = c - cdf_min;
            ((num as f64 / denom as f64) * 255.0)
                .round()
                .clamp(0.0, 255.0) as u8
        };
    }
    lut
}
309
310// ─────────────────────────────────────────────────────────────────────────────
311// Tests
312// ─────────────────────────────────────────────────────────────────────────────
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Solid grey RGBA image with opaque alpha.
    fn gray_rgba(w: u32, h: u32, v: u8) -> Vec<u8> {
        vec![v, v, v, 255u8].repeat((w * h) as usize)
    }

    /// Grey ramp from 0 to 255 across all pixels, opaque alpha.
    fn gradient_rgba(w: u32, h: u32) -> Vec<u8> {
        (0..(w * h))
            .flat_map(|i| {
                let v = ((i * 255) / (w * h - 1).max(1)) as u8;
                [v, v, v, 255u8]
            })
            .collect()
    }

    // ── build_equalization_lut ────────────────────────────────────────────────

    #[test]
    fn test_lut_uniform_histogram() {
        // Uniform histogram → output should span full [0, 255].
        let hist = [4u64; 256];
        let lut = build_equalization_lut(&hist, 1024);
        assert_eq!(lut[0], 0, "first bin maps to 0");
        assert_eq!(lut[255], 255, "last bin maps to 255");
        // Monotone.
        for i in 1..256 {
            assert!(lut[i] >= lut[i - 1], "LUT must be monotone at {i}");
        }
    }

    #[test]
    fn test_lut_single_value_histogram() {
        // All pixels have the same value → after equalisation they all go to 255.
        let mut hist = [0u64; 256];
        hist[128] = 100;
        let lut = build_equalization_lut(&hist, 100);
        assert_eq!(lut[128], 255, "single-value bin maps to 255");
    }

    #[test]
    fn test_lut_interior_values() {
        // cdf = [1, 2, 4, 4, ...], cdf_min = 1, denom = 3.
        let mut hist = [0u64; 256];
        hist[0] = 1;
        hist[1] = 1;
        hist[2] = 2;
        let lut = build_equalization_lut(&hist, 4);
        assert_eq!(&lut[..4], &[0u8, 85, 255, 255]);
    }

    // ── luma_bt601 ────────────────────────────────────────────────────────────

    #[test]
    fn test_luma_pure_red() {
        let y = luma_bt601(255, 0, 0);
        assert_eq!(y, 76, "BT.601 luma of red ≈ 76");
    }

    #[test]
    fn test_luma_pure_green() {
        let y = luma_bt601(0, 255, 0);
        assert_eq!(y, 150, "BT.601 luma of green ≈ 150");
    }

    #[test]
    fn test_luma_pure_blue() {
        let y = luma_bt601(0, 0, 255);
        assert_eq!(y, 29, "BT.601 luma of blue ≈ 29");
    }

    #[test]
    fn test_luma_white() {
        let y = luma_bt601(255, 255, 255);
        assert_eq!(y, 255, "luma of white = 255");
    }

    #[test]
    fn test_luma_black() {
        let y = luma_bt601(0, 0, 0);
        assert_eq!(y, 0, "luma of black = 0");
    }

    // ── HistogramEqualizer::equalize_cpu ─────────────────────────────────────

    #[test]
    fn test_equalize_constant_image_luma() {
        let w = 8u32;
        let h = 8u32;
        let input = gray_rgba(w, h, 100);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let eq = HistogramEqualizer::default_config();
        eq.equalize_cpu(&input, &mut output, w, h)
            .expect("equalize constant image");
        // Constant image → every pixel equalises to white; alpha is untouched.
        for px in output.chunks_exact(4) {
            assert_eq!(&px[..3], &[255u8, 255, 255], "constant image must map to white");
            assert_eq!(px[3], 255, "alpha must be preserved");
        }
    }

    #[test]
    fn test_equalize_gradient_luma_monotone() {
        let w = 16u32;
        let h = 16u32;
        let input = gradient_rgba(w, h);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let eq = HistogramEqualizer::default_config();
        eq.equalize_cpu(&input, &mut output, w, h)
            .expect("equalize gradient");

        // Output lumas must be non-decreasing (monotone).
        let mut prev_y = 0u8;
        for i in 0..(w * h) as usize {
            let y = luma_bt601(output[i * 4], output[i * 4 + 1], output[i * 4 + 2]);
            assert!(
                y >= prev_y,
                "output luma must be non-decreasing: prev={prev_y}, cur={y}"
            );
            prev_y = y;
        }
    }

    #[test]
    fn test_equalize_per_channel() {
        let w = 8u32;
        let h = 8u32;
        let input = gradient_rgba(w, h);
        let mut output = vec![0u8; (w * h * 4) as usize];
        let eq = HistogramEqualizer::new(HistogramEqualizerConfig {
            mode: EqualizationMode::PerChannel,
            clip_limit: 0.0,
        });
        eq.equalize_cpu(&input, &mut output, w, h)
            .expect("equalize per channel");
        // First pixel should be 0 (min of gradient), last should be 255 (max).
        let n = (w * h) as usize;
        assert_eq!(output[0], 0, "first pixel red = 0 after per-channel eq");
        assert_eq!(
            output[(n - 1) * 4],
            255,
            "last pixel red = 255 after per-channel eq"
        );
    }

    #[test]
    fn test_equalize_alpha_passthrough_luma() {
        let w = 4u32;
        let h = 4u32;
        let input: Vec<u8> = (0..w * h * 4)
            .map(|i| if i % 4 == 3 { 200u8 } else { 128 })
            .collect();
        let mut output = vec![0u8; (w * h * 4) as usize];
        HistogramEqualizer::default_config()
            .equalize_cpu(&input, &mut output, w, h)
            .expect("equalize alpha passthrough luma");
        for i in 0..(w * h) as usize {
            assert_eq!(output[i * 4 + 3], 200, "alpha must pass through");
        }
    }

    #[test]
    fn test_equalize_alpha_passthrough_per_channel() {
        let w = 4u32;
        let h = 4u32;
        let input: Vec<u8> = (0..w * h * 4)
            .map(|i| if i % 4 == 3 { 77u8 } else { 100 })
            .collect();
        let mut output = vec![0u8; (w * h * 4) as usize];
        HistogramEqualizer::new(HistogramEqualizerConfig {
            mode: EqualizationMode::PerChannel,
            clip_limit: 0.0,
        })
        .equalize_cpu(&input, &mut output, w, h)
        .expect("equalize alpha passthrough per channel");
        for i in 0..(w * h) as usize {
            assert_eq!(output[i * 4 + 3], 77, "alpha must pass through");
        }
    }

    #[test]
    fn test_equalize_invalid_dimensions() {
        let input = vec![0u8; 64];
        let mut output = vec![0u8; 64];
        let result = HistogramEqualizer::default_config().equalize_cpu(&input, &mut output, 0, 4);
        assert!(result.is_err());
    }

    #[test]
    fn test_equalize_buffer_too_small() {
        let input = vec![0u8; 4]; // too small for 4×4
        let mut output = vec![0u8; 64];
        let result = HistogramEqualizer::default_config().equalize_cpu(&input, &mut output, 4, 4);
        assert!(result.is_err());
    }

    // ── clip limit ────────────────────────────────────────────────────────────

    #[test]
    fn test_clip_limit_preserves_total() {
        let eq = HistogramEqualizer::new(HistogramEqualizerConfig {
            mode: EqualizationMode::LumaOnly,
            clip_limit: 0.3,
        });
        let mut hist = [0u64; 256];
        hist[100] = 500;
        hist[150] = 300;
        let n = 800usize;
        let clipped = eq.apply_clip_limit(hist, n);
        let total: u64 = clipped.iter().sum();
        assert_eq!(total, n as u64, "clip limit must preserve pixel count");
    }

    #[test]
    fn test_clip_limit_zero_no_change() {
        let eq = HistogramEqualizer::new(HistogramEqualizerConfig {
            mode: EqualizationMode::LumaOnly,
            clip_limit: 0.0,
        });
        let hist = [10u64; 256];
        let result = eq.apply_clip_limit(hist, 2560);
        assert_eq!(hist, result, "zero clip limit must not change histogram");
    }
}
516        let hist = [10u64; 256];
517        let result = eq.apply_clip_limit(hist, 2560);
518        assert_eq!(hist, result, "zero clip limit must not change histogram");
519    }
520}