use crate::error::{NdimageError, NdimageResult};
use crate::utils::safe_f64_to_float;
use scirs2_core::ndarray::{Array, Dimension, Ix2};
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
/// Converts a `usize` into the floating-point type `T`.
///
/// # Errors
///
/// Returns [`NdimageError::ComputationError`] when `T::from_usize` yields
/// `None` for `value`.
#[allow(dead_code)]
fn safe_usize_to_float<T: Float + FromPrimitive>(value: usize) -> NdimageResult<T> {
    match T::from_usize(value) {
        Some(converted) => Ok(converted),
        None => Err(NdimageError::ComputationError(format!(
            "Failed to convert usize {} to float type",
            value
        ))),
    }
}
/// Binarizes `image` against a fixed `threshold`.
///
/// Every element strictly greater than `threshold` becomes `T::one()`; every
/// other element becomes `T::zero()`. A NaN element compares false against
/// any threshold and therefore maps to zero. The output has the same shape
/// as the input.
///
/// # Errors
///
/// This implementation never fails; the `Result` wrapper only mirrors the
/// signature of the other thresholding routines in this file.
#[allow(dead_code)]
pub fn threshold_binary<T, D>(image: &Array<T, D>, threshold: T) -> NdimageResult<Array<T, D>>
where
    T: Float + NumAssign + std::fmt::Debug + std::ops::DivAssign + 'static,
    D: Dimension + 'static,
{
    let (foreground, background) = (T::one(), T::zero());
    Ok(image.map(|&pixel| {
        if pixel > threshold {
            foreground
        } else {
            background
        }
    }))
}
#[allow(dead_code)]
pub fn otsu_threshold<T, D>(image: &Array<T, D>, bins: usize) -> NdimageResult<(Array<T, D>, T)>
where
T: Float + NumAssign + std::fmt::Debug + std::ops::DivAssign + FromPrimitive + 'static,
D: Dimension + 'static,
{
let nbins = bins;
let mut min_val = Float::infinity();
let mut max_val = Float::neg_infinity();
for &val in image.iter() {
if val < min_val {
min_val = val;
}
if val > max_val {
max_val = val;
}
}
if min_val == max_val {
let binary = threshold_binary(image, min_val)?;
return Ok((binary, min_val));
}
let mut hist = vec![0; nbins];
let bin_width = (max_val - min_val) / safe_usize_to_float(nbins)?;
for &val in image.iter() {
let bin = ((val - min_val) / bin_width).to_usize().unwrap_or(0);
let bin_index = std::cmp::min(bin, nbins - 1);
hist[bin_index] += 1;
}
let total_pixels = image.len();
let mut cum_sum = vec![0; nbins];
cum_sum[0] = hist[0];
for i in 1..nbins {
cum_sum[i] = cum_sum[i - 1] + hist[i];
}
let mut cum_val = vec![T::zero(); nbins];
for i in 0..nbins {
if i > 0 {
cum_val[i] = cum_val[i - 1] + safe_usize_to_float(i * hist[i])?;
} else {
cum_val[i] = safe_usize_to_float(i * hist[i])?
}
}
let mut max_var = T::zero();
let mut threshold_idx = 0;
for i in 0..(nbins - 1) {
let bg_pixels = cum_sum[i];
let fg_pixels = total_pixels - bg_pixels;
if bg_pixels == 0 || fg_pixels == 0 {
continue;
}
let bg_mean = cum_val[i] / safe_usize_to_float::<T>(bg_pixels)?;
let fg_mean = (cum_val[nbins - 1] - cum_val[i]) / safe_usize_to_float::<T>(fg_pixels)?;
let variance = safe_usize_to_float::<T>(bg_pixels * fg_pixels)?
* (bg_mean - fg_mean)
* (bg_mean - fg_mean);
if variance > max_var {
max_var = variance;
threshold_idx = i;
}
}
let threshold = min_val + safe_usize_to_float::<T>(threshold_idx)? * bin_width;
let binary = threshold_binary(image, threshold)?;
Ok((binary, threshold))
}
/// Strategy used by `adaptive_threshold` to derive the local threshold of a
/// pixel from its surrounding block.
#[derive(Debug, Clone, Copy)]
pub enum AdaptiveMethod {
/// Threshold is the arithmetic mean of the block, minus the constant `c`.
Mean,
/// Threshold is a Gaussian-weighted mean of the block, minus the constant
/// `c`. Weights fall off with the distance to the block's centre pixel,
/// using a sigma of half the block radius.
Gaussian,
}
/// Binarizes a 2-D image with a threshold computed separately for each pixel.
///
/// For every pixel a square block of side `block_size` centred on it is
/// examined (clipped at the image borders, so border pixels see a smaller
/// block). The local threshold is the block statistic selected by `method`
/// minus the constant `c`, and the output is `true` where the pixel is
/// strictly greater than its local threshold.
///
/// * [`AdaptiveMethod::Mean`] uses the arithmetic mean of the block.
/// * [`AdaptiveMethod::Gaussian`] uses a weighted mean with weights
///   `exp(-d² / (2·sigma²))`, where `d` is the distance to the centre pixel
///   and `sigma = (block_size / 2) / 2`. Weights are normalised by their sum,
///   so clipped blocks stay correctly scaled.
///
/// # Errors
///
/// * [`NdimageError::InvalidInput`] if `block_size` is even or less than 3.
/// * An error from the numeric conversion helpers if a value cannot be
///   represented in `T`.
#[allow(dead_code)]
pub fn adaptive_threshold<T>(
    image: &Array<T, Ix2>,
    block_size: usize,
    method: AdaptiveMethod,
    c: T,
) -> NdimageResult<Array<bool, Ix2>>
where
    T: Float + NumAssign + std::fmt::Debug + FromPrimitive,
{
    if block_size % 2 == 0 || block_size < 3 {
        return Err(NdimageError::InvalidInput(
            "block_size must be odd and at least 3".to_string(),
        ));
    }

    let shape = image.raw_dim();
    let (rows, cols) = (shape[0], shape[1]);
    let mut result = Array::from_elem(shape, false);
    let radius = block_size / 2;

    // Gaussian parameters depend only on `block_size`, so compute them once
    // instead of once per neighbour of every pixel.
    let two = safe_f64_to_float::<T>(2.0)?;
    let sigma = safe_usize_to_float::<T>(radius)? / two;
    let two_sigma_sq = two * sigma * sigma;

    for i in 0..rows {
        for j in 0..cols {
            // Block bounds, clipped to the image.
            let start_row = i.saturating_sub(radius);
            let end_row = std::cmp::min(i + radius + 1, rows);
            let start_col = j.saturating_sub(radius);
            let end_col = std::cmp::min(j + radius + 1, cols);
            let neighborhood = image.slice(scirs2_core::ndarray::s![
                start_row..end_row,
                start_col..end_col
            ]);

            let threshold = match method {
                AdaptiveMethod::Mean => {
                    let sum = neighborhood.iter().fold(T::zero(), |acc, &x| acc + x);
                    sum / safe_usize_to_float(neighborhood.len())? - c
                }
                AdaptiveMethod::Gaussian => {
                    // Position of the current pixel inside the clipped block.
                    let center_row = i - start_row;
                    let center_col = j - start_col;
                    let mut weighted_sum = T::zero();
                    let mut weight_sum = T::zero();
                    for ((row, col), &val) in neighborhood.indexed_iter() {
                        let d_row = row.abs_diff(center_row);
                        let d_col = col.abs_diff(center_col);
                        // Only the squared distance is needed by the kernel.
                        let dist_sq = safe_usize_to_float::<T>(d_row * d_row + d_col * d_col)?;
                        let weight = (-dist_sq / two_sigma_sq).exp();
                        weighted_sum += val * weight;
                        weight_sum += weight;
                    }
                    // The centre pixel always contributes weight 1, so
                    // `weight_sum` is never zero.
                    weighted_sum / weight_sum - c
                }
            };

            result[(i, j)] = image[(i, j)] > threshold;
        }
    }

    Ok(result)
}