use crate::morphology::traits::binary::StructuringElement;
use crate::morphology::traits::measurements::RegionProperties;
use numr::error::{Error, Result};
use numr::ops::ScatterReduceOp;
use numr::ops::{
BinaryOps, CompareOps, ConditionalOps, IndexingOps, ReduceOps, ScalarOps, ShapeOps,
TypeConversionOps, UnaryOps, UtilityOps,
};
use numr::prelude::DType;
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Pads `tensor` along the single axis `axis`, adding `before` elements in front and
/// `after` elements behind, all filled with `value`; every other axis is left unpadded.
///
/// `client.pad` receives a flat list of `[before, after]` pairs; the pair for `axis` is
/// written at slot `ndim - 1 - axis`, i.e. the list is presumably ordered from the last
/// axis to the first (PyTorch `F.pad` style) — NOTE(review): confirm against numr.
///
/// Panics if `axis >= tensor.ndim()`.
fn pad_single_axis<R, C>(
    client: &C,
    tensor: &Tensor<R>,
    axis: usize,
    before: usize,
    after: usize,
    value: f64,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ShapeOps<R> + RuntimeClient<R>,
{
    let rank = tensor.ndim();
    let slot = (rank - 1 - axis) * 2;
    let mut widths = vec![0usize; 2 * rank];
    widths[slot..slot + 2].copy_from_slice(&[before, after]);
    client.pad(tensor, &widths, value)
}
/// Labels the connected components of the non-zero elements of `input`.
///
/// Every non-zero element starts with a unique id (its row-major flat index + 1).
/// Ids are then propagated by repeated element-wise minimum over neighbours until an
/// iteration changes nothing, so each component ends up carrying the smallest id it
/// contains. Those ids are finally renumbered to `1..=num_labels` in ascending order
/// (equivalently, by each component's first element in row-major order). Background
/// elements are `0`.
///
/// Connectivity: `StructuringElement::Full` links all `3^ndim - 1` neighbours (faces,
/// edges and corners). Every other structuring element links only the `2 * ndim` face
/// neighbours.
///
/// Returns the label tensor (cast back to `input`'s dtype) and the number of labels.
///
/// # Errors
/// Returns `Error::InvalidArgument` for a 0-d input. Errors from the underlying tensor
/// operations are propagated.
pub fn label_impl<R, C>(
    client: &C,
    input: &Tensor<R>,
    structure: StructuringElement,
) -> Result<(Tensor<R>, usize)>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R>
        + BinaryOps<R>
        + UnaryOps<R>
        + CompareOps<R>
        + ConditionalOps<R>
        + ShapeOps<R>
        + ReduceOps<R>
        + UtilityOps<R>
        + RuntimeClient<R>,
{
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(Error::InvalidArgument {
            arg: "input",
            reason: "label requires at least 1D input".to_string(),
        });
    }
    let shape = input.shape().to_vec();
    let total: usize = shape.iter().product();
    let device = input.device();
    let dtype = input.dtype();

    // Compare in F64 so the foreground test does not depend on `input`'s dtype
    // matching the F64 zero buffer below.
    let input_f64 = client.cast(input, DType::F64)?;
    let zero_tensor = Tensor::from_slice(&vec![0.0; total], &shape, device);
    let fg_mask = client.ne(&input_f64, &zero_tensor)?;

    // `inf_val` exceeds every id, so background cells and padding never win a minimum.
    let inf_val = (total + 1) as f64;
    let inf_tensor = Tensor::from_slice(&vec![inf_val; total], &shape, device);
    let ids = client
        .arange(1.0, inf_val, 1.0, DType::F64)?
        .reshape(&shape)?;
    let mut labels = client.where_cond(&fg_mask, &ids, &inf_tensor)?;

    let full = matches!(structure, StructuringElement::Full);
    let neighbor_starts = label_neighbor_starts(&shape, full);

    // Each iteration moves every minimum at least one more step along its component,
    // and no shortest path is longer than `total - 1` steps, so `total` iterations suffice.
    for _ in 0..total {
        // All neighbour views come from this single masked snapshot (background =
        // `inf_val`), so an id can only move between *foreground* neighbours. Updating
        // `labels` in place between sub-passes would let ids leak through background
        // cells and merge components that are not actually connected.
        let mut padded = labels.clone();
        for axis in 0..ndim {
            padded = pad_single_axis(client, &padded, axis, 1, 1, inf_val)?;
        }
        let mut next = labels.clone();
        for starts in &neighbor_starts {
            let mut view = padded.clone();
            for (axis, (&start, &len)) in starts.iter().zip(&shape).enumerate() {
                view = view.narrow(axis as isize, start, len)?;
            }
            next = client.minimum(&next, &view)?;
        }
        let next = client.where_cond(&fg_mask, &next, &inf_tensor)?;

        let diff = client.abs(&client.sub(&next, &labels)?)?;
        let changed: Vec<f64> = client.sum(&diff, &[], false)?.to_vec();
        labels = next;
        if changed[0] < 0.5 {
            break;
        }
    }

    // Renumber the surviving component ids to 1..=num_labels in ascending order.
    let label_data: Vec<f64> = labels.to_vec();
    let mut roots: Vec<f64> = label_data
        .iter()
        .copied()
        .filter(|&v| v < inf_val)
        .collect();
    roots.sort_unstable_by(f64::total_cmp);
    roots.dedup();
    let num_labels = roots.len();
    // Background (`inf_val`) is never a root, so it maps to 0.
    let renumbered: Vec<f64> = label_data
        .iter()
        .map(|&v| {
            roots
                .binary_search_by(|r| r.total_cmp(&v))
                .map_or(0.0, |i| (i + 1) as f64)
        })
        .collect();

    let tensor = Tensor::from_slice(&renumbered, &shape, device);
    let tensor = client.cast(&tensor, dtype)?;
    Ok((tensor, num_labels))
}

/// Lists the neighbour offsets used by `label_impl`, one entry per offset.
///
/// Each entry gives, per axis, the `narrow` start into a tensor that has been padded by
/// one cell on both sides of every axis: `0` selects the previous cell, `1` the same
/// position, and `2` the next cell. The all-`1` (self) offset is excluded. So are offsets
/// that shift along an axis of length <= 1, because those views would contain only padding.
///
/// With `full == false` only face neighbours (exactly one shifted axis) are produced.
/// Otherwise all `3^ndim - 1` neighbours are produced.
fn label_neighbor_starts(shape: &[usize], full: bool) -> Vec<Vec<usize>> {
    let ndim = shape.len();
    let mut out = Vec::new();
    if !full {
        for (axis, &len) in shape.iter().enumerate() {
            if len <= 1 {
                continue;
            }
            for start in [0usize, 2] {
                let mut starts = vec![1usize; ndim];
                starts[axis] = start;
                out.push(starts);
            }
        }
        return out;
    }
    // Enumerate {0, 1, 2}^ndim as base-3 numbers.
    let combos = 3usize.pow(ndim as u32);
    'combo: for code in 0..combos {
        let mut rest = code;
        let mut starts = Vec::with_capacity(ndim);
        for &len in shape {
            let start = rest % 3;
            rest /= 3;
            if start != 1 && len <= 1 {
                continue 'combo;
            }
            starts.push(start);
        }
        if starts.iter().any(|&s| s != 1) {
            out.push(starts);
        }
    }
    out
}
/// Computes the area and axis-aligned bounding box of every label `1..=num_labels`.
///
/// The label tensor is read back to the host as `f64` and each value is truncated to
/// `usize`. Values of `0` or greater than `num_labels` are skipped. The result holds one
/// entry per label, in label order. Each `bbox` stores the `ndim` per-axis minimum
/// coordinates followed by the `ndim` per-axis maximum coordinates (both inclusive).
/// A label with no elements keeps `area == 0` and its initial box (`usize::MAX` minima,
/// `0` maxima).
///
/// NOTE(review): `to_vec::<f64>()` presumably expects an F64 label tensor. Confirm how
/// other dtypes behave.
pub fn find_objects_impl<R, C>(
    _client: &C,
    labels: &Tensor<R>,
    num_labels: usize,
) -> Result<Vec<RegionProperties>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>,
{
    let dims = labels.shape().to_vec();
    let rank = labels.ndim();
    let values: Vec<f64> = labels.to_vec();
    let numel: usize = dims.iter().product();

    let empty_box = || {
        let mut bbox = vec![usize::MAX; rank];
        bbox.resize(2 * rank, 0);
        bbox
    };
    let mut regions = Vec::with_capacity(num_labels);
    for label in 1..=num_labels {
        regions.push(RegionProperties {
            label,
            area: 0,
            bbox: empty_box(),
        });
    }

    for (linear, &raw) in values.iter().enumerate().take(numel) {
        let id = raw as usize;
        if id == 0 || id > num_labels {
            continue;
        }
        let region = &mut regions[id - 1];
        region.area += 1;
        // Unravel the row-major linear index, innermost axis first.
        let mut rest = linear;
        for axis in (0..rank).rev() {
            let extent = dims[axis];
            let coord = rest % extent;
            rest /= extent;
            let lo = &mut region.bbox[axis];
            *lo = (*lo).min(coord);
            let hi = &mut region.bbox[rank + axis];
            *hi = (*hi).max(coord);
        }
    }
    Ok(regions)
}
/// Sums `input` over each labelled region.
///
/// Returns a 1-D tensor of length `num_labels` (cast to `input`'s dtype). Entry `k` is
/// the sum of `input` over the elements whose label is `k + 1`. Background elements
/// (label `0`) are ignored, and labels with no elements sum to `0`.
///
/// # Errors
/// Propagates errors from the underlying tensor operations. This includes the case
/// where `labels` cannot be reshaped to `input`'s element count, and presumably a label
/// greater than `num_labels`.
pub fn sum_labels_impl<R, C>(
    client: &C,
    input: &Tensor<R>,
    labels: &Tensor<R>,
    num_labels: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + IndexingOps<R> + TypeConversionOps<R> + UtilityOps<R> + RuntimeClient<R>,
{
    let device = input.device();
    let dtype = input.dtype();
    let total: usize = input.shape().iter().product();
    // Accumulate in F64 so `src` matches the F64 destination buffer.
    let flat_input = client.cast(&input.reshape(&[total])?, DType::F64)?;
    // Index by the raw label into `num_labels + 1` buckets. Bucket 0 absorbs background
    // elements, which would otherwise map to the out-of-range index -1.
    let indices = client.cast(&labels.reshape(&[total])?, DType::I64)?;
    let buckets = num_labels + 1;
    let dst = Tensor::from_slice(&vec![0.0; buckets], &[buckets], device);
    let sums =
        client.scatter_reduce(&dst, 0, &indices, &flat_input, ScatterReduceOp::Sum, true)?;
    let sums = sums.narrow(0, 1, num_labels)?;
    client.cast(&sums, dtype)
}
/// Averages `input` over each labelled region.
///
/// Returns a 1-D tensor of length `num_labels` (cast to `input`'s dtype). Entry `k` is
/// the mean of `input` over the elements whose label is `k + 1`. Background elements
/// (label `0`) are ignored, and labels with no elements yield `0`.
///
/// # Errors
/// Propagates errors from the underlying tensor operations. This includes the case
/// where `labels` cannot be reshaped to `input`'s element count, and presumably a label
/// greater than `num_labels`.
pub fn mean_labels_impl<R, C>(
    client: &C,
    input: &Tensor<R>,
    labels: &Tensor<R>,
    num_labels: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + IndexingOps<R> + TypeConversionOps<R> + UtilityOps<R> + RuntimeClient<R>,
{
    let device = input.device();
    let dtype = input.dtype();
    let total: usize = input.shape().iter().product();
    // Accumulate in F64 so `src` matches the F64 destination buffer.
    let flat_input = client.cast(&input.reshape(&[total])?, DType::F64)?;
    // Index by the raw label into `num_labels + 1` buckets. Bucket 0 absorbs background
    // elements, which would otherwise map to the out-of-range index -1.
    let indices = client.cast(&labels.reshape(&[total])?, DType::I64)?;
    let buckets = num_labels + 1;
    let dst = Tensor::from_slice(&vec![0.0; buckets], &[buckets], device);
    // `include_self = false`: the zero placeholders in `dst` must not take part in the
    // mean. Otherwise every mean would be sum / (count + 1). Buckets that receive no
    // elements keep their initial 0.
    let means =
        client.scatter_reduce(&dst, 0, &indices, &flat_input, ScatterReduceOp::Mean, false)?;
    let means = means.narrow(0, 1, num_labels)?;
    client.cast(&means, dtype)
}
pub fn center_of_mass_impl<R, C>(
client: &C,
input: &Tensor<R>,
labels: &Tensor<R>,
num_labels: usize,
) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: ScalarOps<R>
+ BinaryOps<R>
+ UnaryOps<R>
+ IndexingOps<R>
+ TypeConversionOps<R>
+ ConditionalOps<R>
+ CompareOps<R>
+ ShapeOps<R>
+ UtilityOps<R>
+ RuntimeClient<R>,
{
let ndim = input.ndim();
let shape = input.shape().to_vec();
let device = input.device();
let dtype = input.dtype();
let total: usize = shape.iter().product();
let flat_input = input.reshape(&[total])?;
let flat_labels = labels.reshape(&[total])?;
let indices = client.add_scalar(&flat_labels, -1.0)?;
let indices = client.cast(&indices, DType::I64)?;
let dst_zeros = Tensor::from_slice(&vec![0.0; num_labels], &[num_labels], device);
let total_weights = client.scatter_reduce(
&dst_zeros,
0,
&indices,
&flat_input,
ScatterReduceOp::Sum,
true,
)?;
let mut results = Vec::with_capacity(ndim);
for d in 0..ndim {
let stride: usize = shape[d + 1..].iter().product();
let flat_indices = client.arange(0.0, total as f64, 1.0, DType::F64)?;
let divided = client.mul_scalar(&flat_indices, 1.0 / stride as f64)?;
let floored = client.floor(÷d)?;
let shape_d = shape[d] as f64;
let scaled = client.mul_scalar(&floored, 1.0 / shape_d)?;
let floored2 = client.floor(&scaled)?;
let subtract = client.mul_scalar(&floored2, shape_d)?;
let coords = client.sub(&floored, &subtract)?;
let weighted = client.mul(&flat_input, &coords)?;
let weighted_sum = client.scatter_reduce(
&dst_zeros,
0,
&indices,
&weighted,
ScatterReduceOp::Sum,
true,
)?;
let zero_mask = client.eq(
&total_weights,
&Tensor::from_slice(&vec![0.0; num_labels], &[num_labels], device),
)?;
let center = client.div(&weighted_sum, &total_weights)?;
let zero_tensor = Tensor::from_slice(&vec![0.0; num_labels], &[num_labels], device);
let center = client.where_cond(&zero_mask, &zero_tensor, ¢er)?;
results.push(center);
}
let mut stacked = results[0].reshape(&[num_labels, 1])?;
for r in &results[1..] {
let col = r.reshape(&[num_labels, 1])?;
stacked = client.cat(&[&stacked, &col], 1)?;
}
client.cast(&stacked, dtype)
}