ndarray_vision/processing/
threshold.rs

1use crate::core::PixelBound;
2use crate::core::{ColourModel, Image, ImageBase};
3use crate::processing::*;
4use ndarray::{prelude::*, Data};
5use ndarray_stats::histogram::{Bins, Edges, Grid};
6use ndarray_stats::HistogramExt;
7use ndarray_stats::QuantileExt;
8use num_traits::cast::FromPrimitive;
9use num_traits::cast::ToPrimitive;
10use num_traits::{Num, NumAssignOps};
11use std::marker::PhantomData;
12
13/// Runs the Otsu thresholding algorithm on a type `T`.
14pub trait ThresholdOtsuExt<T> {
15    /// The Otsu thresholding output is a binary image.
16    type Output;
17
18    /// Run the Otsu threshold algorithm.
19    ///
20    /// Due to Otsu threshold algorithm specifying a greyscale image, all
21    /// current implementations assume a single channel image; otherwise, an
22    /// error is returned.
23    ///
24    /// # Errors
25    ///
26    /// Returns a `ChannelDimensionMismatch` error if more than one channel
27    /// exists.
28    fn threshold_otsu(&self) -> Result<Self::Output, Error>;
29}
30
31/// Runs the Mean thresholding algorithm on a type `T`.
32pub trait ThresholdMeanExt<T> {
33    /// The Mean thresholding output is a binary image.
34    type Output;
35
36    /// Run the Mean threshold algorithm.
37    ///
38    /// This assumes the image is a single channel image, i.e., a greyscale
39    /// image; otherwise, an error is returned.
40    ///
41    /// # Errors
42    ///
43    /// Returns a `ChannelDimensionMismatch` error if more than one channel
44    /// exists.
45    fn threshold_mean(&self) -> Result<Self::Output, Error>;
46}
47
48/// Applies an upper and lower limit threshold on a type `T`.
49pub trait ThresholdApplyExt<T> {
50    /// The output is a binary image.
51    type Output;
52
53    /// Apply the threshold with the given limits.
54    ///
55    /// An image is segmented into background and foreground
56    /// elements, where any pixel value within the limits are considered
57    /// foreground elements and any pixels with a value outside the limits are
58    /// considered part of the background. The upper and lower limits are
59    /// inclusive.
60    ///
61    /// If only a lower limit threshold is to be applied, the `f64::INFINITY`
62    /// value can be used for the upper limit.
63    ///
64    /// # Errors
65    ///
66    /// The current implementation assumes a single channel image, i.e.,
67    /// greyscale image. Thus, if more than one channel is present, then
68    /// a `ChannelDimensionMismatch` error occurs.
69    ///
70    /// An `InvalidParameter` error occurs if the `lower` limit is greater than
71    /// the `upper` limit.
72    fn threshold_apply(&self, lower: f64, upper: f64) -> Result<Self::Output, Error>;
73}
74
75impl<T, U, C> ThresholdOtsuExt<T> for ImageBase<U, C>
76where
77    U: Data<Elem = T>,
78    Image<U, C>: Clone,
79    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound,
80    C: ColourModel,
81{
82    type Output = Image<bool, C>;
83
84    fn threshold_otsu(&self) -> Result<Self::Output, Error> {
85        let data = self.data.threshold_otsu()?;
86        Ok(Self::Output {
87            data,
88            model: PhantomData,
89        })
90    }
91}
92
93impl<T, U> ThresholdOtsuExt<T> for ArrayBase<U, Ix3>
94where
95    U: Data<Elem = T>,
96    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive,
97{
98    type Output = Array3<bool>;
99
100    fn threshold_otsu(&self) -> Result<Self::Output, Error> {
101        if self.shape()[2] > 1 {
102            Err(Error::ChannelDimensionMismatch)
103        } else {
104            let value = calculate_threshold_otsu(self)?;
105            self.threshold_apply(value, f64::INFINITY)
106        }
107    }
108}
109
110/// Calculates Otsu's threshold.
111///
112/// Works per channel, but currently assumes greyscale.
113///
114/// See the Errors section for the `ThresholdOtsuExt` trait if the number of
115/// channels is greater than one (1), i.e., single channel; otherwise, we would
116/// need to output all three threshold values.
117///
118/// TODO: Add optional nbins
119fn calculate_threshold_otsu<T, U>(mat: &ArrayBase<U, Ix3>) -> Result<f64, Error>
120where
121    U: Data<Elem = T>,
122    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive,
123{
124    let mut threshold = 0.0;
125    let n_bins = 255;
126    for c in mat.axis_iter(Axis(2)) {
127        let scale_factor = (n_bins) as f64 / (c.max().unwrap().to_f64().unwrap());
128        let edges_vec: Vec<u8> = (0..n_bins).collect();
129        let grid = Grid::from(vec![Bins::new(Edges::from(edges_vec))]);
130
131        // get the histogram
132        let flat = Array::from_iter(c.iter()).insert_axis(Axis(1));
133        let flat2 = flat.mapv(|x| ((*x).to_f64().unwrap() * scale_factor).to_u8().unwrap());
134        let hist = flat2.histogram(grid);
135        // Straight out of wikipedia:
136        let counts = hist.counts();
137        let total = counts.sum().to_f64().unwrap();
138        let counts = Array::from_iter(counts.iter());
139        // NOTE: Could use the cdf generation for skimage-esque implementation
140        // which entails a cumulative sum of the standard histogram
141        let mut sum_b = 0.0;
142        let mut weight_b = 0.0;
143        let mut maximum = 0.0;
144        let mut level = 0.0;
145        let mut sum_intensity = 0.0;
146        for (index, count) in counts.indexed_iter() {
147            sum_intensity += (index as f64) * (*count).to_f64().unwrap();
148        }
149        for (index, count) in counts.indexed_iter() {
150            weight_b += count.to_f64().unwrap();
151            sum_b += (index as f64) * count.to_f64().unwrap();
152            let weight_f = total - weight_b;
153            if (weight_b > 0.0) && (weight_f > 0.0) {
154                let mean_f = (sum_intensity - sum_b) / weight_f;
155                let val = weight_b
156                    * weight_f
157                    * ((sum_b / weight_b) - mean_f)
158                    * ((sum_b / weight_b) - mean_f);
159                if val > maximum {
160                    level = 1.0 + (index as f64);
161                    maximum = val;
162                }
163            }
164        }
165        threshold = level / scale_factor;
166    }
167    Ok(threshold)
168}
169
170impl<T, U, C> ThresholdMeanExt<T> for ImageBase<U, C>
171where
172    U: Data<Elem = T>,
173    Image<U, C>: Clone,
174    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound,
175    C: ColourModel,
176{
177    type Output = Image<bool, C>;
178
179    fn threshold_mean(&self) -> Result<Self::Output, Error> {
180        let data = self.data.threshold_mean()?;
181        Ok(Self::Output {
182            data,
183            model: PhantomData,
184        })
185    }
186}
187
188impl<T, U> ThresholdMeanExt<T> for ArrayBase<U, Ix3>
189where
190    U: Data<Elem = T>,
191    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive,
192{
193    type Output = Array3<bool>;
194
195    fn threshold_mean(&self) -> Result<Self::Output, Error> {
196        if self.shape()[2] > 1 {
197            Err(Error::ChannelDimensionMismatch)
198        } else {
199            let value = calculate_threshold_mean(self)?;
200            self.threshold_apply(value, f64::INFINITY)
201        }
202    }
203}
204
205fn calculate_threshold_mean<T, U>(array: &ArrayBase<U, Ix3>) -> Result<f64, Error>
206where
207    U: Data<Elem = T>,
208    T: Copy + Clone + Num + NumAssignOps + ToPrimitive + FromPrimitive,
209{
210    Ok(array.sum().to_f64().unwrap() / array.len() as f64)
211}
212
213impl<T, U, C> ThresholdApplyExt<T> for ImageBase<U, C>
214where
215    U: Data<Elem = T>,
216    Image<U, C>: Clone,
217    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound,
218    C: ColourModel,
219{
220    type Output = Image<bool, C>;
221
222    fn threshold_apply(&self, lower: f64, upper: f64) -> Result<Self::Output, Error> {
223        let data = self.data.threshold_apply(lower, upper)?;
224        Ok(Self::Output {
225            data,
226            model: PhantomData,
227        })
228    }
229}
230
231impl<T, U> ThresholdApplyExt<T> for ArrayBase<U, Ix3>
232where
233    U: Data<Elem = T>,
234    T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive,
235{
236    type Output = Array3<bool>;
237
238    fn threshold_apply(&self, lower: f64, upper: f64) -> Result<Self::Output, Error> {
239        if self.shape()[2] > 1 {
240            Err(Error::ChannelDimensionMismatch)
241        } else if lower > upper {
242            Err(Error::InvalidParameter)
243        } else {
244            Ok(apply_threshold(self, lower, upper))
245        }
246    }
247}
248
249fn apply_threshold<T, U>(data: &ArrayBase<U, Ix3>, lower: f64, upper: f64) -> Array3<bool>
250where
251    U: Data<Elem = T>,
252    T: Copy + Clone + Num + NumAssignOps + ToPrimitive + FromPrimitive,
253{
254    data.mapv(|x| x.to_f64().unwrap() >= lower && x.to_f64().unwrap() <= upper)
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use assert_approx_eq::assert_approx_eq;
    use ndarray::arr3;
    use noisy_float::types::n64;

    /// Single-channel 3x3 image with values in `[0, 1]`, shared by the
    /// `apply_threshold` tests.
    fn unit_interval_image() -> Array3<f64> {
        arr3(&[
            [[0.2], [0.4], [0.0]],
            [[0.7], [0.5], [0.8]],
            [[0.1], [0.6], [0.0]],
        ])
    }

    /// Single-channel 3x3 integer image shared by the Otsu tests.
    fn otsu_image() -> Array3<i32> {
        arr3(&[[[2], [4], [0]], [[7], [5], [8]], [[1], [6], [0]]])
    }

    /// Rows of 4s, 5s and 6s, whose mean is exactly 5.
    fn mean_image() -> Array3<i32> {
        arr3(&[[[4], [4], [4]], [[5], [5], [5]], [[6], [6], [6]]])
    }

    #[test]
    fn threshold_apply_threshold() {
        let expected = arr3(&[
            [[false], [false], [false]],
            [[true], [true], [true]],
            [[false], [true], [false]],
        ]);

        let result = apply_threshold(&unit_interval_image(), 0.5, f64::INFINITY);
        assert_eq!(result, expected);
    }

    #[test]
    fn threshold_apply_threshold_range() {
        let expected = arr3(&[
            [[false], [true], [false]],
            [[true], [true], [false]],
            [[false], [true], [false]],
        ]);

        let result = apply_threshold(&unit_interval_image(), 0.25, 0.75);
        assert_eq!(result, expected);
    }

    #[test]
    fn threshold_calculate_threshold_otsu_ints() {
        let result = calculate_threshold_otsu(&otsu_image()).unwrap();
        println!("Done {}", result);

        // Reference value from Python's skimage.filters.threshold_otsu on the
        // int array (the float array gives 2.0156...).
        assert_approx_eq!(result, 2.0, 5e-1);
    }

    #[test]
    fn threshold_calculate_threshold_otsu_floats() {
        let data = otsu_image().mapv(|v| n64(f64::from(v)));

        let result = calculate_threshold_otsu(&data).unwrap();

        // Reference value from Python's skimage.filters.threshold_otsu on the
        // float array (the int array gives 2.0).
        assert_approx_eq!(result, 2.0156, 5e-1);
    }

    #[test]
    fn threshold_calculate_threshold_mean_ints() {
        let result = calculate_threshold_mean(&mean_image()).unwrap();
        assert_approx_eq!(result, 5.0, 1e-16);
    }

    #[test]
    fn threshold_calculate_threshold_mean_floats() {
        let data = mean_image().mapv(f64::from);

        let result = calculate_threshold_mean(&data).unwrap();
        assert_approx_eq!(result, 5.0, 1e-16);
    }
}