ndarray_vision/enhancement/
histogram_equalisation.rs

1use crate::core::*;
2use ndarray::{prelude::*, DataMut};
3use ndarray_stats::{histogram::Grid, HistogramExt};
4use num_traits::cast::{FromPrimitive, ToPrimitive};
5use num_traits::{Num, NumAssignOps};
6
7/// Extension trait to implement histogram equalisation on other types
8pub trait HistogramEqExt<A>
9where
10    A: Ord,
11{
12    type Output;
13    /// Equalises an image histogram returning a new image.
14    /// Grids should be for a 1xN image as the image is flattened during processing
15    fn equalise_hist(&self, grid: Grid<A>) -> Self::Output;
16
17    /// Equalises an image histogram inplace
18    /// Grids should be for a 1xN image as the image is flattened during processing
19    fn equalise_hist_inplace(&mut self, grid: Grid<A>);
20}
21
22impl<T, U> HistogramEqExt<T> for ArrayBase<U, Ix3>
23where
24    U: DataMut<Elem = T>,
25    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound,
26{
27    type Output = Array<T, Ix3>;
28
29    fn equalise_hist(&self, grid: Grid<T>) -> Self::Output {
30        let mut result = self.to_owned();
31        result.equalise_hist_inplace(grid);
32        result
33    }
34
35    fn equalise_hist_inplace(&mut self, grid: Grid<T>) {
36        for mut c in self.axis_iter_mut(Axis(2)) {
37            // get the histogram
38            let flat = Array::from_iter(c.iter()).mapv(|x| *x).insert_axis(Axis(1));
39            let hist = flat.histogram(grid.clone());
40            // get cdf
41            let mut running_total = 0;
42            let mut min = 0.0;
43            let cdf = hist.counts().mapv(|x| {
44                running_total += x;
45                if min == 0.0 && running_total > 0 {
46                    min = running_total as f32;
47                }
48                running_total as f32
49            });
50
51            // Rescale cdf writing back new values
52            let scale = (T::max_pixel() - T::min_pixel())
53                .to_f32()
54                .unwrap_or_default();
55            let denominator = flat.len() as f32 - min;
56            c.mapv_inplace(|x| {
57                let index = match grid.index_of(&arr1(&[x])) {
58                    Some(i) => {
59                        if i.is_empty() {
60                            0
61                        } else {
62                            i[0]
63                        }
64                    }
65                    None => 0,
66                };
67                let mut f_res = ((cdf[index] - min) / denominator) * scale;
68                if T::is_integral() {
69                    f_res = f_res.round();
70                }
71                T::from_f32(f_res).unwrap_or_else(T::zero) + T::min_pixel()
72            });
73        }
74    }
75}
76
77impl<T, U, C> HistogramEqExt<T> for ImageBase<U, C>
78where
79    U: DataMut<Elem = T>,
80    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound,
81    C: ColourModel,
82{
83    type Output = Image<T, C>;
84
85    fn equalise_hist(&self, grid: Grid<T>) -> Self::Output {
86        let mut result = self.to_owned();
87        result.equalise_hist_inplace(grid);
88        result
89    }
90
91    fn equalise_hist_inplace(&mut self, grid: Grid<T>) {
92        self.data.equalise_hist_inplace(grid);
93    }
94}
95
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::Gray;
    use ndarray_stats::histogram::{Bins, Edges};

    /// 8x8 greyscale sample taken from Wikipedia, eight values per line.
    #[rustfmt::skip]
    const WIKI_INPUT: [u8; 64] = [
        52, 55, 61,  59,  70,  61, 76, 61,
        62, 59, 55, 104,  94,  85, 59, 71,
        63, 65, 66, 113, 144, 104, 63, 72,
        64, 70, 70, 126, 154, 109, 71, 69,
        67, 73, 68, 106, 122,  88, 68, 68,
        68, 79, 60,  79,  77,  66, 58, 75,
        69, 85, 64,  58,  55,  61, 65, 83,
        70, 87, 69,  68,  65,  73, 78, 90,
    ];

    /// Expected result of equalising `WIKI_INPUT`, laid out the same way.
    #[rustfmt::skip]
    const WIKI_EXPECTED: [u8; 64] = [
          0,  12,  53,  32, 146,  53, 174,  53,
         57,  32,  12, 227, 219, 202,  32, 154,
         65,  85,  93, 239, 251, 227,  65, 158,
         73, 146, 146, 247, 255, 235, 154, 130,
         97, 166, 117, 231, 243, 210, 117, 117,
        117, 190,  36, 190, 178,  93,  20, 170,
        130, 202,  73,  20,  12,  53,  85, 194,
        146, 206, 130, 117,  85, 166, 182, 215,
    ];

    #[test]
    fn hist_eq_test() {
        // Edges 0, 1, ..., 254: one unit-wide (half-open) bin per grey level.
        let edges: Vec<u8> = (0..255).collect();
        let grid = Grid::from(vec![Bins::new(Edges::from(edges))]);

        let image = Image::<u8, Gray>::from_shape_data(8, 8, WIKI_INPUT.to_vec());
        let expected = Image::<u8, Gray>::from_shape_data(8, 8, WIKI_EXPECTED.to_vec());

        assert_eq!(expected, image.equalise_hist(grid));
    }
}
128        assert_eq!(expected, equalised);
129    }
130}