use crate::core::PixelBound;
use crate::core::{ColourModel, Image, ImageBase};
use crate::processing::*;
use ndarray::{prelude::*, Data};
use ndarray_stats::histogram::{Bins, Edges, Grid};
use ndarray_stats::HistogramExt;
use ndarray_stats::QuantileExt;
use num_traits::cast::FromPrimitive;
use num_traits::cast::ToPrimitive;
use num_traits::{Num, NumAssignOps};
use std::marker::PhantomData;
#[cfg(test)]
use assert_approx_eq::assert_approx_eq;
#[cfg(test)]
use noisy_float::types::n64;
pub trait ThresholdOtsuExt<T> {
type Output;
fn threshold_otsu(&self) -> Result<Self::Output, Error>;
}
pub trait ThresholdMeanExt<T> {
type Output;
fn threshold_mean(&self) -> Result<Self::Output, Error>;
}
impl<T, U, C> ThresholdOtsuExt<T> for ImageBase<U, C>
where
U: Data<Elem = T>,
Image<U, C>: Clone,
T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound,
C: ColourModel,
{
type Output = Image<bool, C>;
fn threshold_otsu(&self) -> Result<Self::Output, Error> {
let data = self.data.threshold_otsu()?;
Ok(Self::Output {
data,
model: PhantomData,
})
}
}
impl<T, U> ThresholdOtsuExt<T> for ArrayBase<U, Ix3>
where
U: Data<Elem = T>,
T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive,
{
type Output = Array3<bool>;
fn threshold_otsu(&self) -> Result<Self::Output, Error> {
if self.shape()[2] > 1 {
Err(Error::ChannelDimensionMismatch)
} else {
let value = calculate_threshold_otsu(self)?;
let mask = apply_threshold(self, value);
Ok(mask)
}
}
}
fn calculate_threshold_otsu<T, U>(mat: &ArrayBase<U, Ix3>) -> Result<f64, Error>
where
U: Data<Elem = T>,
T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive,
{
let mut threshold = 0.0;
let n_bins = 255;
for c in mat.axis_iter(Axis(2)) {
let scale_factor = (n_bins) as f64 / (c.max().unwrap().to_f64().unwrap());
let edges_vec: Vec<u8> = (0..n_bins).collect();
let grid = Grid::from(vec![Bins::new(Edges::from(edges_vec))]);
let flat = Array::from_iter(c.iter()).insert_axis(Axis(1));
let flat2 = flat.mapv(|x| ((*x).to_f64().unwrap() * scale_factor).to_u8().unwrap());
let hist = flat2.histogram(grid);
let counts = hist.counts();
let total = counts.sum().to_f64().unwrap();
let counts = Array::from_iter(counts.iter());
let mut sum_b = 0.0;
let mut weight_b = 0.0;
let mut maximum = 0.0;
let mut level = 0.0;
let mut sum_intensity = 0.0;
for (index, count) in counts.indexed_iter() {
sum_intensity += (index as f64) * (*count).to_f64().unwrap();
}
for (index, count) in counts.indexed_iter() {
weight_b = weight_b + count.to_f64().unwrap();
sum_b = sum_b + (index as f64) * count.to_f64().unwrap();
let weight_f = total - weight_b;
if (weight_b > 0.0) && (weight_f > 0.0) {
let mean_f = (sum_intensity - sum_b) / weight_f;
let val = weight_b
* weight_f
* ((sum_b / weight_b) - mean_f)
* ((sum_b / weight_b) - mean_f);
if val > maximum {
level = 1.0 + (index as f64);
maximum = val;
}
}
}
threshold = level as f64 / scale_factor;
}
Ok(threshold)
}
impl<T, U, C> ThresholdMeanExt<T> for ImageBase<U, C>
where
U: Data<Elem = T>,
Image<U, C>: Clone,
T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive + PixelBound,
C: ColourModel,
{
type Output = Image<bool, C>;
fn threshold_mean(&self) -> Result<Self::Output, Error> {
let data = self.data.threshold_mean()?;
Ok(Self::Output {
data,
model: PhantomData,
})
}
}
impl<T, U> ThresholdMeanExt<T> for ArrayBase<U, Ix3>
where
U: Data<Elem = T>,
T: Copy + Clone + Ord + Num + NumAssignOps + ToPrimitive + FromPrimitive,
{
type Output = Array3<bool>;
fn threshold_mean(&self) -> Result<Self::Output, Error> {
if self.shape()[2] > 1 {
Err(Error::ChannelDimensionMismatch)
} else {
let value = calculate_threshold_mean(self)?;
let mask = apply_threshold(self, value);
Ok(mask)
}
}
}
fn calculate_threshold_mean<T, U>(array: &ArrayBase<U, Ix3>) -> Result<f64, Error>
where
U: Data<Elem = T>,
T: Copy + Clone + Num + NumAssignOps + ToPrimitive + FromPrimitive,
{
Ok(array.sum().to_f64().unwrap() / array.len() as f64)
}
fn apply_threshold<T, U>(data: &ArrayBase<U, Ix3>, threshold: f64) -> Array3<bool>
where
U: Data<Elem = T>,
T: Copy + Clone + Num + NumAssignOps + ToPrimitive + FromPrimitive,
{
let result = data.mapv(|x| x.to_f64().unwrap() >= threshold);
result
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::arr3;
#[test]
fn threshold_apply_threshold() {
let data = arr3(&[
[[0.2], [0.4], [0.0]],
[[0.7], [0.5], [0.8]],
[[0.1], [0.6], [0.0]],
]);
let expected = arr3(&[
[[false], [false], [false]],
[[true], [true], [true]],
[[false], [true], [false]],
]);
let result = apply_threshold(&data, 0.5);
assert_eq!(result, expected);
}
#[test]
fn threshold_calculate_threshold_otsu_ints() {
let data = arr3(&[[[2], [4], [0]], [[7], [5], [8]], [[1], [6], [0]]]);
let result = calculate_threshold_otsu(&data).unwrap();
println!("Done {}", result);
let expected = 2.0;
assert_approx_eq!(result, expected, 5e-1);
}
#[test]
fn threshold_calculate_threshold_otsu_floats() {
let data = arr3(&[
[[n64(2.0)], [n64(4.0)], [n64(0.0)]],
[[n64(7.0)], [n64(5.0)], [n64(8.0)]],
[[n64(1.0)], [n64(6.0)], [n64(0.0)]],
]);
let result = calculate_threshold_otsu(&data).unwrap();
let expected = 2.0156;
assert_approx_eq!(result, expected, 5e-1);
}
#[test]
fn threshold_calculate_threshold_mean_ints() {
let data = arr3(&[[[4], [4], [4]], [[5], [5], [5]], [[6], [6], [6]]]);
let result = calculate_threshold_mean(&data).unwrap();
let expected = 5.0;
assert_approx_eq!(result, expected, 1e-16);
}
#[test]
fn threshold_calculate_threshold_mean_floats() {
let data = arr3(&[
[[4.0], [4.0], [4.0]],
[[5.0], [5.0], [5.0]],
[[6.0], [6.0], [6.0]],
]);
let result = calculate_threshold_mean(&data).unwrap();
let expected = 5.0;
assert_approx_eq!(result, expected, 1e-16);
}
}