use scirs2_core::ndarray::{Array, Array1, Dimension};
use scirs2_core::numeric::{Float, FromPrimitive, NumAssign};
use std::fmt::Debug;
use crate::error::{NdimageError, NdimageResult};
/// Converts a `usize` into the floating-point type `T`.
///
/// # Errors
///
/// Returns [`NdimageError::ComputationError`] when `T::from_usize` reports that the value
/// cannot be represented in `T`.
#[allow(dead_code)]
fn safe_usize_to_float<T: Float + FromPrimitive>(value: usize) -> NdimageResult<T> {
    match T::from_usize(value) {
        Some(converted) => Ok(converted),
        None => Err(NdimageError::ComputationError(format!(
            "Failed to convert usize {} to float type",
            value
        ))),
    }
}
/// Sums the values of `input` inside each labeled region.
///
/// Which labels are reported:
/// - If `index` is `Some`, its distinct values are used.
/// - Otherwise, the distinct values present in `labels` are used.
///
/// In both cases the labels are sorted ascending and label `0` (background) is dropped.
/// Element `i` of the result is the sum for the `i`-th label in that sorted list. A requested
/// label that never occurs in `labels` sums to zero. When no non-zero label remains, an empty
/// array is returned.
///
/// # Errors
///
/// * [`NdimageError::InvalidInput`] if `input` is 0-dimensional.
/// * [`NdimageError::DimensionError`] if `labels` and `input` differ in shape.
#[allow(dead_code)]
pub fn sum_labels<T, D>(
    input: &Array<T, D>,
    labels: &Array<usize, D>,
    index: Option<&[usize]>,
) -> NdimageResult<Array1<T>>
where
    T: Float + FromPrimitive + Debug + NumAssign + std::ops::DivAssign + 'static,
    D: Dimension + 'static,
{
    if input.ndim() == 0 {
        return Err(NdimageError::InvalidInput(
            "Input array cannot be 0-dimensional".into(),
        ));
    }
    if labels.shape() != input.shape() {
        return Err(NdimageError::DimensionError(
            "Labels array must have same shape as input array".to_string(),
        ));
    }

    // Distinct candidate labels, either requested explicitly or discovered in `labels`.
    let candidates: std::collections::HashSet<usize> = match index {
        Some(requested) => requested.iter().copied().collect(),
        None => labels.iter().copied().collect(),
    };
    let mut region_labels: Vec<usize> = candidates.into_iter().collect();
    region_labels.sort();
    // Background (label 0) is never reported.
    region_labels.retain(|&label| label != 0);
    if region_labels.is_empty() {
        return Ok(Array1::<T>::zeros(0));
    }

    // Maps each reported label to its output position.
    let slot_of: std::collections::HashMap<usize, usize> = region_labels
        .iter()
        .enumerate()
        .map(|(slot, &label)| (label, slot))
        .collect();

    let mut totals = vec![T::zero(); region_labels.len()];
    for (&value, label) in input.iter().zip(labels.iter()) {
        if let Some(&slot) = slot_of.get(label) {
            totals[slot] += value;
        }
    }
    Ok(Array1::from_vec(totals))
}
/// Computes the arithmetic mean of `input` inside each labeled region.
///
/// The per-label sums come from `sum_labels` and the per-label pixel counts from
/// `count_labels`. Both derive the same list of labels:
/// - the distinct values of `index` if it is `Some`, otherwise those present in `labels`;
/// - sorted ascending, with label `0` dropped.
///
/// Element `i` of the result is the mean for the `i`-th label of that list. A requested label
/// with no pixels reports a mean of zero.
///
/// # Errors
///
/// * [`NdimageError::InvalidInput`] if `input` is 0-dimensional, or if the sums and counts
///   unexpectedly differ in length.
/// * [`NdimageError::DimensionError`] if `labels` and `input` differ in shape.
/// * [`NdimageError::ComputationError`] if a pixel count cannot be converted to `T`.
#[allow(dead_code)]
pub fn mean_labels<T, D>(
    input: &Array<T, D>,
    labels: &Array<usize, D>,
    index: Option<&[usize]>,
) -> NdimageResult<Array1<T>>
where
    T: Float + FromPrimitive + Debug + NumAssign + std::ops::DivAssign + 'static,
    D: Dimension + 'static,
{
    if input.ndim() == 0 {
        return Err(NdimageError::InvalidInput(
            "Input array cannot be 0-dimensional".into(),
        ));
    }
    if labels.shape() != input.shape() {
        return Err(NdimageError::DimensionError(
            "Labels array must have same shape as input array".to_string(),
        ));
    }
    let sums = sum_labels(input, labels, index)?;
    let counts = count_labels(labels, index)?;
    // Defensive invariant check: both helpers should produce one entry per reported label.
    if sums.len() != counts.len() {
        return Err(NdimageError::InvalidInput(
            "Mismatch between sums and counts arrays".into(),
        ));
    }
    let means = sums
        .iter()
        .zip(counts.iter())
        .map(|(&sum, &count)| {
            if count > 0 {
                // Propagate a failed conversion instead of silently dividing by one,
                // which would report the raw sum as the mean.
                Ok(sum / safe_usize_to_float::<T>(count)?)
            } else {
                Ok(T::zero())
            }
        })
        .collect::<NdimageResult<Vec<T>>>()?;
    Ok(Array1::from_vec(means))
}
/// Computes the sample variance of `input` inside each labeled region.
///
/// Which labels are reported:
/// - If `index` is `Some`, its distinct values are used.
/// - Otherwise, the distinct values present in `labels` are used.
///
/// The labels are sorted ascending and label `0` (background) is dropped. Element `i` of the
/// result corresponds to the `i`-th label of that list. An empty array is returned when no
/// non-zero label remains.
///
/// The computation takes two passes: per-region means first, then squared deviations. It uses
/// the unbiased `n - 1` denominator. Note that `scipy.ndimage.variance` divides by `n`
/// instead. Regions with fewer than two pixels report a variance of zero.
///
/// # Errors
///
/// * [`NdimageError::InvalidInput`] if `input` is 0-dimensional.
/// * [`NdimageError::DimensionError`] if `labels` and `input` differ in shape.
/// * [`NdimageError::ComputationError`] if a pixel count cannot be converted to `T`.
#[allow(dead_code)]
pub fn variance_labels<T, D>(
    input: &Array<T, D>,
    labels: &Array<usize, D>,
    index: Option<&[usize]>,
) -> NdimageResult<Array1<T>>
where
    T: Float + FromPrimitive + Debug + NumAssign + std::ops::DivAssign + 'static,
    D: Dimension + 'static,
{
    if input.ndim() == 0 {
        return Err(NdimageError::InvalidInput(
            "Input array cannot be 0-dimensional".into(),
        ));
    }
    if labels.shape() != input.shape() {
        return Err(NdimageError::DimensionError(
            "Labels array must have same shape as input array".to_string(),
        ));
    }
    let unique_labels: std::collections::HashSet<usize> = match index {
        Some(idx) => idx.iter().copied().collect(),
        None => labels.iter().copied().collect(),
    };
    let mut sorted_labels: Vec<usize> = unique_labels.into_iter().collect();
    sorted_labels.sort_unstable();
    // Label 0 is background; after sorting it can only be the first entry.
    if sorted_labels.first() == Some(&0) {
        sorted_labels.remove(0);
    }
    if sorted_labels.is_empty() {
        return Ok(Array1::<T>::zeros(0));
    }
    let label_to_idx: std::collections::HashMap<usize, usize> = sorted_labels
        .iter()
        .enumerate()
        .map(|(i, &label)| (label, i))
        .collect();
    let n_regions = sorted_labels.len();

    let mut sums = vec![T::zero(); n_regions];
    let mut counts = vec![0usize; n_regions];
    for (&value, label) in input.iter().zip(labels.iter()) {
        if let Some(&idx) = label_to_idx.get(label) {
            sums[idx] += value;
            counts[idx] += 1;
        }
    }

    // First pass: per-region means. A failed count conversion is reported as an error
    // rather than silently replaced by 1, which would corrupt every downstream value.
    let mut means = Vec::with_capacity(n_regions);
    for (&sum, &count) in sums.iter().zip(&counts) {
        means.push(if count > 0 {
            sum / safe_usize_to_float::<T>(count)?
        } else {
            T::zero()
        });
    }

    // Second pass: sum of squared deviations from each region's mean.
    let mut variance_sums = vec![T::zero(); n_regions];
    for (&value, label) in input.iter().zip(labels.iter()) {
        if let Some(&idx) = label_to_idx.get(label) {
            let diff = value - means[idx];
            variance_sums[idx] += diff * diff;
        }
    }

    // Unbiased estimator: divide by (n - 1); undefined for n < 2, reported as zero.
    let mut variances = Vec::with_capacity(n_regions);
    for (&var_sum, &count) in variance_sums.iter().zip(&counts) {
        variances.push(if count > 1 {
            var_sum / safe_usize_to_float::<T>(count - 1)?
        } else {
            T::zero()
        });
    }
    Ok(Array1::from_vec(variances))
}
/// Counts how many elements carry each non-zero label.
///
/// Which labels are reported:
/// - If `index` is `Some`, its distinct values are used.
/// - Otherwise, the distinct values present in `labels` are used.
///
/// The labels are sorted ascending and label `0` (background) is excluded. Element `i` of the
/// result is the count for the `i`-th label of that list. A requested label that never occurs
/// counts as zero. An empty array is returned when no non-zero label remains.
///
/// # Errors
///
/// Returns [`NdimageError::InvalidInput`] if `labels` is 0-dimensional.
#[allow(dead_code)]
pub fn count_labels<D>(
    labels: &Array<usize, D>,
    index: Option<&[usize]>,
) -> NdimageResult<Array1<usize>>
where
    D: Dimension,
{
    if labels.ndim() == 0 {
        return Err(NdimageError::InvalidInput(
            "Labels array cannot be 0-dimensional".into(),
        ));
    }

    let distinct: std::collections::HashSet<usize> = match index {
        Some(requested) => requested.iter().copied().collect(),
        None => labels.iter().copied().collect(),
    };
    let mut regions: Vec<usize> = distinct.into_iter().filter(|&label| label != 0).collect();
    if regions.is_empty() {
        return Ok(Array1::<usize>::zeros(0));
    }
    regions.sort_unstable();

    // Output position of every reported label.
    let position: std::collections::HashMap<usize, usize> = regions
        .iter()
        .enumerate()
        .map(|(pos, &label)| (label, pos))
        .collect();

    let mut tally = vec![0usize; regions.len()];
    labels
        .iter()
        .filter_map(|label| position.get(label))
        .for_each(|&pos| tally[pos] += 1);

    Ok(Array1::from_vec(tally))
}
#[allow(dead_code)]
pub fn histogram<T, D>(
input: &Array<T, D>,
min: T,
max: T,
bins: usize,
labels: Option<&Array<usize, D>>,
_index: Option<&[usize]>,
) -> NdimageResult<(Array1<usize>, Array1<T>)>
where
T: Float + FromPrimitive + Debug + NumAssign + std::ops::DivAssign + 'static,
D: Dimension + 'static,
{
if input.ndim() == 0 {
return Err(NdimageError::InvalidInput(
"Input array cannot be 0-dimensional".into(),
));
}
if min >= max {
return Err(NdimageError::InvalidInput(format!(
"min must be less than max (got min={:?}, max={:?})",
min, max
)));
}
if bins == 0 {
return Err(NdimageError::InvalidInput(
"bins must be greater than 0".into(),
));
}
if let Some(lab) = labels {
if lab.shape() != input.shape() {
return Err(NdimageError::DimensionError(
"Labels array must have same shape as input array".to_string(),
));
}
}
let bin_width = (max - min) / T::from_usize(bins).expect("Operation failed");
let mut edges = Array1::<T>::zeros(bins + 1);
for i in 0..=bins {
edges[i] = min + T::from_usize(i).expect("Operation failed") * bin_width;
}
let mut hist = Array1::<usize>::zeros(bins);
match labels {
None => {
for &value in input.iter() {
if value >= min && value < max {
let bin_idx = ((value - min) / bin_width).to_usize().unwrap_or(0);
let bin_idx = bin_idx.min(bins - 1); hist[bin_idx] += 1;
} else if value == max {
hist[bins - 1] += 1;
}
}
}
Some(label_array) => {
for (value, &label) in input.iter().zip(label_array.iter()) {
if label > 0 && *value >= min && *value < max {
let bin_idx = ((*value - min) / bin_width).to_usize().unwrap_or(0);
let bin_idx = bin_idx.min(bins - 1);
hist[bin_idx] += 1;
} else if label > 0 && *value == max {
hist[bins - 1] += 1;
}
}
}
}
Ok((hist, edges))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::{array, Array2, Array3};

    #[test]
    fn test_sum_labels_basic() {
        let input: Array2<f64> = Array2::eye(3);
        let labels: Array2<usize> = Array2::from_elem((3, 3), 1);
        let result =
            sum_labels(&input, &labels, None).expect("sum_labels should succeed for basic test");
        assert!(!result.is_empty());
        assert_eq!(result.len(), 1);
        assert_abs_diff_eq!(result[0], 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_sum_labels_multiple_regions() {
        let input = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]];
        let labels = array![[1, 1, 2], [1, 2, 2], [3, 3, 3]];
        let sums = sum_labels(&input, &labels, None)
            .expect("sum_labels should succeed for multiple regions test");
        assert_eq!(sums.len(), 3);
        assert_abs_diff_eq!(sums[0], 1.0 + 2.0 + 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sums[1], 3.0 + 5.0 + 6.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sums[2], 7.0 + 8.0 + 9.0, epsilon = 1e-10);
    }

    #[test]
    fn test_sum_labels_with_background() {
        let input = array![[1.0, 2.0], [3.0, 4.0]];
        let labels = array![[0, 1], [1, 2]];
        let sums = sum_labels(&input, &labels, None)
            .expect("sum_labels should succeed with background test");
        // Label 0 is background and excluded.
        assert_eq!(sums.len(), 2);
        assert_abs_diff_eq!(sums[0], 2.0 + 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sums[1], 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_sum_labels_selective_index() {
        let input = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let labels = array![[1, 2, 3], [1, 2, 3]];
        let sums = sum_labels(&input, &labels, Some(&[1, 3]))
            .expect("sum_labels should succeed with selective index test");
        assert_eq!(sums.len(), 2);
        assert_abs_diff_eq!(sums[0], 1.0 + 4.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sums[1], 3.0 + 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_sum_labels_edge_cases() {
        // Only background: empty result.
        let input = array![[1.0, 2.0]];
        let labels = array![[0, 0]];
        let sums = sum_labels(&input, &labels, None)
            .expect("sum_labels should succeed for empty result test");
        assert_eq!(sums.len(), 0);

        // One pixel per label.
        let input2 = array![[1.0, 2.0, 3.0]];
        let labels2 = array![[1, 2, 3]];
        let sums2 = sum_labels(&input2, &labels2, None)
            .expect("sum_labels should succeed for single pixel test");
        assert_eq!(sums2.len(), 3);
        assert_abs_diff_eq!(sums2[0], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sums2[1], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(sums2[2], 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_sum_labels_3d() {
        let input = Array3::from_shape_fn((2, 2, 2), |(i, j, k)| (i + j + k) as f64);
        let labels = Array3::from_shape_fn((2, 2, 2), |(i, j, _k)| if i == j { 1 } else { 2 });
        let sums =
            sum_labels(&input, &labels, None).expect("sum_labels should succeed for 3D test");
        assert_eq!(sums.len(), 2);
        // Label 1 covers (0,0,k) and (1,1,k): values 0, 1, 2, 3.
        assert_abs_diff_eq!(sums[0], 6.0, epsilon = 1e-10);
        // Label 2 covers (0,1,k) and (1,0,k): values 1, 2, 1, 2.
        assert_abs_diff_eq!(sums[1], 6.0, epsilon = 1e-10);
    }

    #[test]
    fn test_mean_labels_basic() {
        let input: Array2<f64> = Array2::eye(3);
        let labels: Array2<usize> = Array2::from_elem((3, 3), 1);
        let result =
            mean_labels(&input, &labels, None).expect("mean_labels should succeed for basic test");
        assert!(!result.is_empty());
        assert_eq!(result.len(), 1);
        assert_abs_diff_eq!(result[0], 3.0 / 9.0, epsilon = 1e-10);
    }

    #[test]
    fn test_mean_labels_multiple_regions() {
        let input = array![[2.0, 4.0, 6.0], [8.0, 10.0, 12.0]];
        let labels = array![[1, 1, 2], [1, 2, 2]];
        let means = mean_labels(&input, &labels, None)
            .expect("mean_labels should succeed for multiple regions test");
        assert_eq!(means.len(), 2);
        assert_abs_diff_eq!(means[0], (2.0 + 4.0 + 8.0) / 3.0, epsilon = 1e-10);
        assert_abs_diff_eq!(means[1], (6.0 + 10.0 + 12.0) / 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_variance_labels_basic() {
        let input = array![[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]];
        let labels = array![[1, 1, 2], [1, 2, 2]];
        let variances = variance_labels(&input, &labels, None)
            .expect("variance_labels should succeed for basic test");
        assert_eq!(variances.len(), 2);
        // Sample variance (n - 1 denominator) of {1, 3, 2} and {5, 4, 6}.
        assert_abs_diff_eq!(variances[0], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(variances[1], 1.0, epsilon = 1e-10);
    }

    #[test]
    fn test_variance_labels_zero_variance() {
        let input = array![[5.0, 5.0, 3.0], [5.0, 3.0, 3.0]];
        let labels = array![[1, 1, 2], [1, 2, 2]];
        let variances = variance_labels(&input, &labels, None)
            .expect("variance_labels should succeed for zero variance test");
        assert_eq!(variances.len(), 2);
        assert_abs_diff_eq!(variances[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(variances[1], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_variance_labels_single_pixel() {
        let input = array![[1.0, 2.0, 3.0]];
        let labels = array![[1, 2, 3]];
        let variances = variance_labels(&input, &labels, None)
            .expect("variance_labels should succeed for single pixel test");
        assert_eq!(variances.len(), 3);
        assert_abs_diff_eq!(variances[0], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(variances[1], 0.0, epsilon = 1e-10);
        assert_abs_diff_eq!(variances[2], 0.0, epsilon = 1e-10);
    }

    #[test]
    fn test_count_labels_basic() {
        let labels: Array2<usize> = Array2::from_elem((3, 3), 1);
        let result =
            count_labels(&labels, None).expect("count_labels should succeed for basic test");
        assert!(!result.is_empty());
        assert_eq!(result.len(), 1);
        assert_eq!(result[0], 9);
    }

    #[test]
    fn test_count_labels_multiple_regions() {
        let labels = array![[1, 1, 2, 2], [1, 3, 3, 2], [4, 4, 4, 4]];
        let counts = count_labels(&labels, None)
            .expect("count_labels should succeed for multiple regions test");
        assert_eq!(counts.len(), 4);
        assert_eq!(counts[0], 3);
        assert_eq!(counts[1], 3);
        assert_eq!(counts[2], 2);
        assert_eq!(counts[3], 4);
    }

    #[test]
    fn test_count_labels_with_background() {
        let labels = array![[0, 1, 1], [0, 2, 2], [0, 0, 3]];
        let counts =
            count_labels(&labels, None).expect("count_labels should succeed with background test");
        assert_eq!(counts.len(), 3);
        assert_eq!(counts[0], 2);
        assert_eq!(counts[1], 2);
        assert_eq!(counts[2], 1);
    }

    #[test]
    fn test_error_handling() {
        // Shape mismatch: (1, 2) input vs (2, 1) labels.
        let input = array![[1.0, 2.0]];
        let labels = array![[1], [2]];
        assert!(sum_labels(&input, &labels, None).is_err());
        assert!(mean_labels(&input, &labels, None).is_err());
        assert!(variance_labels(&input, &labels, None).is_err());

        // 0-dimensional arrays are rejected.
        let input_0d = scirs2_core::ndarray::arr0(1.0);
        let labels_0d = scirs2_core::ndarray::arr0(1);
        assert!(sum_labels(&input_0d, &labels_0d, None).is_err());
        assert!(mean_labels(&input_0d, &labels_0d, None).is_err());
        assert!(variance_labels(&input_0d, &labels_0d, None).is_err());
        assert!(count_labels(&labels_0d, None).is_err());
    }

    #[test]
    fn test_high_dimensional_arrays() {
        let input = Array::from_shape_fn((2, 2, 2, 2), |(i, j, k, l)| (i + j + k + l) as f64);
        let labels = Array::from_shape_fn((2, 2, 2, 2), |(i, j, _k, _l)| i + j + 1);
        let sums =
            sum_labels(&input, &labels, None).expect("sum_labels should succeed for 4D test");
        let means =
            mean_labels(&input, &labels, None).expect("mean_labels should succeed for 4D test");
        let variances = variance_labels(&input, &labels, None)
            .expect("variance_labels should succeed for 4D test");
        let counts = count_labels(&labels, None).expect("count_labels should succeed for 4D test");
        assert!(!sums.is_empty());
        assert!(!means.is_empty());
        assert!(!variances.is_empty());
        assert!(!counts.is_empty());
        assert_eq!(sums.len(), means.len());
        assert_eq!(means.len(), variances.len());
        assert_eq!(variances.len(), counts.len());
        // Labels are i + j + 1: label 1 at (0,0), label 2 at (0,1)/(1,0), label 3 at (1,1).
        assert_eq!(counts.to_vec(), vec![4, 8, 4]);
        assert_abs_diff_eq!(means[0], 1.0, epsilon = 1e-10);
        assert_abs_diff_eq!(means[1], 2.0, epsilon = 1e-10);
        assert_abs_diff_eq!(means[2], 3.0, epsilon = 1e-10);
    }

    #[test]
    fn test_histogram_basic() {
        let input = array![[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]];
        let (hist, edges) = histogram(&input, 0.0, 2.0, 4, None, None)
            .expect("histogram should succeed for basic test");
        // 2.0 == max lands in the last (closed) bin; 2.5 is out of range.
        assert_eq!(hist.to_vec(), vec![1, 1, 1, 2]);
        assert_eq!(edges.len(), 5);
        for (edge, expected) in edges.iter().zip([0.0, 0.5, 1.0, 1.5, 2.0].iter()) {
            assert_abs_diff_eq!(*edge, *expected, epsilon = 1e-12);
        }
    }

    #[test]
    fn test_histogram_with_labels() {
        let input = array![[0.0, 0.5, 1.0], [1.5, 2.0, 2.5]];
        let labels = array![[0, 1, 1], [1, 0, 1]];
        let (hist, _) = histogram(&input, 0.0, 2.0, 4, Some(&labels), None)
            .expect("histogram should succeed with labels test");
        // Elements labelled 0 (0.0 and 2.0) are excluded.
        assert_eq!(hist.to_vec(), vec![0, 1, 1, 1]);
    }

    #[test]
    fn test_histogram_errors() {
        let input = array![[1.0, 2.0]];
        assert!(histogram(&input, 2.0, 1.0, 4, None, None).is_err());
        assert!(histogram(&input, 1.0, 1.0, 4, None, None).is_err());
        assert!(histogram(&input, 0.0, 2.0, 0, None, None).is_err());
        let bad_labels = array![[1], [2]];
        assert!(histogram(&input, 0.0, 2.0, 4, Some(&bad_labels), None).is_err());
    }
}