1use crate::core::PixelBound;
2use crate::core::{ColourModel, Image, ImageBase};
3use crate::processing::*;
4use ndarray::{prelude::*, Data};
5use ndarray_stats::histogram::{Bins, Edges, Grid};
6use ndarray_stats::HistogramExt;
7use ndarray_stats::QuantileExt;
8use num_traits::cast::FromPrimitive;
9use num_traits::cast::ToPrimitive;
10use num_traits::{Num, NumAssignOps};
11use std::marker::PhantomData;
12
/// Binarises an image using a threshold selected by Otsu's method.
pub trait ThresholdOtsuExt<T> {
    /// Result type of `threshold_otsu`; the implementations in this module
    /// produce boolean masks (`Array3<bool>` / `Image<bool, C>`).
    type Output;

    /// Chooses a threshold with Otsu's method and marks every pixel whose
    /// value is at or above that threshold as `true`.
    ///
    /// # Errors
    ///
    /// The implementations in this module return
    /// `Error::ChannelDimensionMismatch` when the input has more than one
    /// channel.
    fn threshold_otsu(&self) -> Result<Self::Output, Error>;
}
30
/// Binarises an image using its mean pixel value as the threshold.
pub trait ThresholdMeanExt<T> {
    /// Result type of `threshold_mean`; the implementations in this module
    /// produce boolean masks (`Array3<bool>` / `Image<bool, C>`).
    type Output;

    /// Uses the mean of all pixel values as the threshold and marks every
    /// pixel at or above it as `true`.
    ///
    /// # Errors
    ///
    /// The implementations in this module return
    /// `Error::ChannelDimensionMismatch` when the input has more than one
    /// channel.
    fn threshold_mean(&self) -> Result<Self::Output, Error>;
}
47
/// Binarises an image against an explicit, inclusive value range.
pub trait ThresholdApplyExt<T> {
    /// Result type of `threshold_apply`; the implementations in this module
    /// produce boolean masks (`Array3<bool>` / `Image<bool, C>`).
    type Output;

    /// Marks a pixel `true` when its value, converted to `f64`, lies in the
    /// inclusive range `[lower, upper]`. Pass `f64::INFINITY` as `upper` for
    /// a one-sided "at or above `lower`" threshold.
    ///
    /// # Errors
    ///
    /// The implementations in this module return
    /// `Error::ChannelDimensionMismatch` when the input has more than one
    /// channel, and `Error::InvalidParameter` when `lower > upper`.
    fn threshold_apply(&self, lower: f64, upper: f64) -> Result<Self::Output, Error>;
}
74
75impl<T, U, C> ThresholdOtsuExt<T> for ImageBase<U, C>
76where
77 U: Data<Elem = T>,
78 Image<U, C>: Clone,
79 T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound,
80 C: ColourModel,
81{
82 type Output = Image<bool, C>;
83
84 fn threshold_otsu(&self) -> Result<Self::Output, Error> {
85 let data = self.data.threshold_otsu()?;
86 Ok(Self::Output {
87 data,
88 model: PhantomData,
89 })
90 }
91}
92
93impl<T, U> ThresholdOtsuExt<T> for ArrayBase<U, Ix3>
94where
95 U: Data<Elem = T>,
96 T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive,
97{
98 type Output = Array3<bool>;
99
100 fn threshold_otsu(&self) -> Result<Self::Output, Error> {
101 if self.shape()[2] > 1 {
102 Err(Error::ChannelDimensionMismatch)
103 } else {
104 let value = calculate_threshold_otsu(self)?;
105 self.threshold_apply(value, f64::INFINITY)
106 }
107 }
108}
109
110fn calculate_threshold_otsu<T, U>(mat: &ArrayBase<U, Ix3>) -> Result<f64, Error>
120where
121 U: Data<Elem = T>,
122 T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive,
123{
124 let mut threshold = 0.0;
125 let n_bins = 255;
126 for c in mat.axis_iter(Axis(2)) {
127 let scale_factor = (n_bins) as f64 / (c.max().unwrap().to_f64().unwrap());
128 let edges_vec: Vec<u8> = (0..n_bins).collect();
129 let grid = Grid::from(vec![Bins::new(Edges::from(edges_vec))]);
130
131 let flat = Array::from_iter(c.iter()).insert_axis(Axis(1));
133 let flat2 = flat.mapv(|x| ((*x).to_f64().unwrap() * scale_factor).to_u8().unwrap());
134 let hist = flat2.histogram(grid);
135 let counts = hist.counts();
137 let total = counts.sum().to_f64().unwrap();
138 let counts = Array::from_iter(counts.iter());
139 let mut sum_b = 0.0;
142 let mut weight_b = 0.0;
143 let mut maximum = 0.0;
144 let mut level = 0.0;
145 let mut sum_intensity = 0.0;
146 for (index, count) in counts.indexed_iter() {
147 sum_intensity += (index as f64) * (*count).to_f64().unwrap();
148 }
149 for (index, count) in counts.indexed_iter() {
150 weight_b += count.to_f64().unwrap();
151 sum_b += (index as f64) * count.to_f64().unwrap();
152 let weight_f = total - weight_b;
153 if (weight_b > 0.0) && (weight_f > 0.0) {
154 let mean_f = (sum_intensity - sum_b) / weight_f;
155 let val = weight_b
156 * weight_f
157 * ((sum_b / weight_b) - mean_f)
158 * ((sum_b / weight_b) - mean_f);
159 if val > maximum {
160 level = 1.0 + (index as f64);
161 maximum = val;
162 }
163 }
164 }
165 threshold = level / scale_factor;
166 }
167 Ok(threshold)
168}
169
170impl<T, U, C> ThresholdMeanExt<T> for ImageBase<U, C>
171where
172 U: Data<Elem = T>,
173 Image<U, C>: Clone,
174 T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound,
175 C: ColourModel,
176{
177 type Output = Image<bool, C>;
178
179 fn threshold_mean(&self) -> Result<Self::Output, Error> {
180 let data = self.data.threshold_mean()?;
181 Ok(Self::Output {
182 data,
183 model: PhantomData,
184 })
185 }
186}
187
188impl<T, U> ThresholdMeanExt<T> for ArrayBase<U, Ix3>
189where
190 U: Data<Elem = T>,
191 T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive,
192{
193 type Output = Array3<bool>;
194
195 fn threshold_mean(&self) -> Result<Self::Output, Error> {
196 if self.shape()[2] > 1 {
197 Err(Error::ChannelDimensionMismatch)
198 } else {
199 let value = calculate_threshold_mean(self)?;
200 self.threshold_apply(value, f64::INFINITY)
201 }
202 }
203}
204
205fn calculate_threshold_mean<T, U>(array: &ArrayBase<U, Ix3>) -> Result<f64, Error>
206where
207 U: Data<Elem = T>,
208 T: Copy + Clone + Num + NumAssignOps + ToPrimitive + FromPrimitive,
209{
210 Ok(array.sum().to_f64().unwrap() / array.len() as f64)
211}
212
213impl<T, U, C> ThresholdApplyExt<T> for ImageBase<U, C>
214where
215 U: Data<Elem = T>,
216 Image<U, C>: Clone,
217 T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound,
218 C: ColourModel,
219{
220 type Output = Image<bool, C>;
221
222 fn threshold_apply(&self, lower: f64, upper: f64) -> Result<Self::Output, Error> {
223 let data = self.data.threshold_apply(lower, upper)?;
224 Ok(Self::Output {
225 data,
226 model: PhantomData,
227 })
228 }
229}
230
231impl<T, U> ThresholdApplyExt<T> for ArrayBase<U, Ix3>
232where
233 U: Data<Elem = T>,
234 T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive,
235{
236 type Output = Array3<bool>;
237
238 fn threshold_apply(&self, lower: f64, upper: f64) -> Result<Self::Output, Error> {
239 if self.shape()[2] > 1 {
240 Err(Error::ChannelDimensionMismatch)
241 } else if lower > upper {
242 Err(Error::InvalidParameter)
243 } else {
244 Ok(apply_threshold(self, lower, upper))
245 }
246 }
247}
248
249fn apply_threshold<T, U>(data: &ArrayBase<U, Ix3>, lower: f64, upper: f64) -> Array3<bool>
250where
251 U: Data<Elem = T>,
252 T: Copy + Clone + Num + NumAssignOps + ToPrimitive + FromPrimitive,
253{
254 data.mapv(|x| x.to_f64().unwrap() >= lower && x.to_f64().unwrap() <= upper)
255}
256
#[cfg(test)]
mod tests {
    use super::*;
    use assert_approx_eq::assert_approx_eq;
    use ndarray::arr3;
    use noisy_float::types::n64;

    #[test]
    fn threshold_apply_threshold() {
        let data = arr3(&[
            [[0.2], [0.4], [0.0]],
            [[0.7], [0.5], [0.8]],
            [[0.1], [0.6], [0.0]],
        ]);

        let expected = arr3(&[
            [[false], [false], [false]],
            [[true], [true], [true]],
            [[false], [true], [false]],
        ]);

        let result = apply_threshold(&data, 0.5, f64::INFINITY);

        assert_eq!(result, expected);
    }

    #[test]
    fn threshold_apply_threshold_range() {
        let data = arr3(&[
            [[0.2], [0.4], [0.0]],
            [[0.7], [0.5], [0.8]],
            [[0.1], [0.6], [0.0]],
        ]);
        let expected = arr3(&[
            [[false], [true], [false]],
            [[true], [true], [false]],
            [[false], [true], [false]],
        ]);

        let result = apply_threshold(&data, 0.25, 0.75);

        assert_eq!(result, expected);
    }

    #[test]
    fn threshold_apply_rejects_inverted_range() {
        let data = arr3(&[[[1], [2]], [[3], [4]]]);

        assert!(matches!(
            data.threshold_apply(3.0, 1.0),
            Err(Error::InvalidParameter)
        ));
    }

    #[test]
    fn threshold_rejects_multi_channel_input() {
        // Shape (1, 2, 2): two channels per pixel.
        let data = arr3(&[[[1, 2], [3, 4]]]);

        assert!(matches!(
            data.threshold_apply(0.0, 1.0),
            Err(Error::ChannelDimensionMismatch)
        ));
        assert!(matches!(
            data.threshold_mean(),
            Err(Error::ChannelDimensionMismatch)
        ));
        assert!(matches!(
            data.threshold_otsu(),
            Err(Error::ChannelDimensionMismatch)
        ));
    }

    #[test]
    fn threshold_calculate_threshold_otsu_ints() {
        let data = arr3(&[[[2], [4], [0]], [[7], [5], [8]], [[1], [6], [0]]]);
        let result = calculate_threshold_otsu(&data).unwrap();

        let expected = 2.0;

        assert_approx_eq!(result, expected, 5e-1);
    }

    #[test]
    fn threshold_calculate_threshold_otsu_floats() {
        let data = arr3(&[
            [[n64(2.0)], [n64(4.0)], [n64(0.0)]],
            [[n64(7.0)], [n64(5.0)], [n64(8.0)]],
            [[n64(1.0)], [n64(6.0)], [n64(0.0)]],
        ]);

        let result = calculate_threshold_otsu(&data).unwrap();

        let expected = 2.0156;

        assert_approx_eq!(result, expected, 5e-1);
    }

    #[test]
    fn threshold_calculate_threshold_mean_ints() {
        let data = arr3(&[[[4], [4], [4]], [[5], [5], [5]], [[6], [6], [6]]]);

        let result = calculate_threshold_mean(&data).unwrap();
        let expected = 5.0;

        assert_approx_eq!(result, expected, 1e-16);
    }

    #[test]
    fn threshold_calculate_threshold_mean_floats() {
        let data = arr3(&[
            [[4.0], [4.0], [4.0]],
            [[5.0], [5.0], [5.0]],
            [[6.0], [6.0], [6.0]],
        ]);

        let result = calculate_threshold_mean(&data).unwrap();
        let expected = 5.0;

        assert_approx_eq!(result, expected, 1e-16);
    }
}