use crate::DType;
use crate::morphology::traits::binary::StructuringElement;
use crate::signal::traits::nd_filters::{BoundaryMode, NdFilterAlgorithms};
use numr::error::{Error, Result};
use numr::ops::{CompareOps, ConditionalOps, ReduceOps, ScalarOps, UtilityOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Binary erosion of `input`, repeated `iterations` times.
///
/// Every non-zero element of `input` is treated as foreground, and the working tensor is a
/// strict 0/1 tensor with the same shape and dtype as `input`. On each pass an element stays
/// 1 only if the minimum over its `3 x 3 x ... x 3` neighbourhood is 1. Positions outside the
/// tensor are padded with the constant 0, so foreground touching the border is eroded.
/// With `iterations == 0` the binarized input is returned unchanged.
///
/// NOTE(review): `_structure` is currently ignored. The neighbourhood is always the full
/// `3^ndim` cube, whatever structuring element the caller passes.
///
/// # Errors
/// Returns `Error::InvalidArgument` for a 0-dimensional input. Errors from the underlying
/// tensor operations are propagated.
pub fn binary_erosion_impl<R, C>(
    client: &C,
    input: &Tensor<R>,
    _structure: StructuringElement,
    iterations: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: NdFilterAlgorithms<R>
        + ScalarOps<R>
        + CompareOps<R>
        + ConditionalOps<R>
        + UtilityOps<R>
        + RuntimeClient<R>,
{
    let rank = input.ndim();
    if rank == 0 {
        return Err(Error::InvalidArgument {
            arg: "input",
            reason: "binary_erosion requires at least 1D input".to_string(),
        });
    }

    let window = vec![3usize; rank];
    let zeros = client.fill(input.shape(), 0.0, input.dtype())?;
    let ones = client.fill(input.shape(), 1.0, input.dtype())?;

    // Normalise arbitrary non-zero values to exactly 1.
    let foreground = client.ne(input, &zeros)?;
    let binary = client.where_cond(&foreground, &ones, &zeros)?;

    (0..iterations).try_fold(binary, |current, _| {
        let eroded = client.minimum_filter(&current, &window, BoundaryMode::Constant(0.0))?;
        let survives = client.ge(&eroded, &ones)?;
        client.where_cond(&survives, &ones, &zeros)
    })
}
/// Binary dilation of `input`, repeated `iterations` times.
///
/// Every non-zero element of `input` is treated as foreground, and the working tensor is a
/// strict 0/1 tensor with the same shape and dtype as `input`. On each pass an element becomes
/// 1 if the maximum over its `3 x 3 x ... x 3` neighbourhood is greater than 0. Positions
/// outside the tensor are padded with the constant 0, so the border adds no foreground.
/// With `iterations == 0` the binarized input is returned unchanged.
///
/// NOTE(review): `_structure` is currently ignored. The neighbourhood is always the full
/// `3^ndim` cube, whatever structuring element the caller passes.
///
/// # Errors
/// Returns `Error::InvalidArgument` for a 0-dimensional input. Errors from the underlying
/// tensor operations are propagated.
pub fn binary_dilation_impl<R, C>(
    client: &C,
    input: &Tensor<R>,
    _structure: StructuringElement,
    iterations: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: NdFilterAlgorithms<R>
        + ScalarOps<R>
        + CompareOps<R>
        + ConditionalOps<R>
        + UtilityOps<R>
        + RuntimeClient<R>,
{
    let rank = input.ndim();
    if rank == 0 {
        return Err(Error::InvalidArgument {
            arg: "input",
            reason: "binary_dilation requires at least 1D input".to_string(),
        });
    }

    let window = vec![3usize; rank];
    let zeros = client.fill(input.shape(), 0.0, input.dtype())?;
    let ones = client.fill(input.shape(), 1.0, input.dtype())?;

    // Normalise arbitrary non-zero values to exactly 1.
    let foreground = client.ne(input, &zeros)?;
    let binary = client.where_cond(&foreground, &ones, &zeros)?;

    (0..iterations).try_fold(binary, |current, _| {
        let grown = client.maximum_filter(&current, &window, BoundaryMode::Constant(0.0))?;
        let touched = client.gt(&grown, &zeros)?;
        client.where_cond(&touched, &ones, &zeros)
    })
}
/// Binary opening: `binary_erosion_impl` followed by `binary_dilation_impl`, each run with
/// the same `structure` and `iterations`.
///
/// The result is a 0/1 tensor with `input`'s shape and dtype.
///
/// # Errors
/// Propagates any error from the erosion or dilation step, including the
/// `InvalidArgument` raised for 0-dimensional input.
pub fn binary_opening_impl<R, C>(
    client: &C,
    input: &Tensor<R>,
    structure: StructuringElement,
    iterations: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: NdFilterAlgorithms<R>
        + ScalarOps<R>
        + CompareOps<R>
        + ConditionalOps<R>
        + UtilityOps<R>
        + RuntimeClient<R>,
{
    binary_erosion_impl(client, input, structure, iterations)
        .and_then(|shrunk| binary_dilation_impl(client, &shrunk, structure, iterations))
}
/// Binary closing: `binary_dilation_impl` followed by `binary_erosion_impl`, each run with
/// the same `structure` and `iterations`.
///
/// The result is a 0/1 tensor with `input`'s shape and dtype.
///
/// # Errors
/// Propagates any error from the dilation or erosion step, including the
/// `InvalidArgument` raised for 0-dimensional input.
pub fn binary_closing_impl<R, C>(
    client: &C,
    input: &Tensor<R>,
    structure: StructuringElement,
    iterations: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: NdFilterAlgorithms<R>
        + ScalarOps<R>
        + CompareOps<R>
        + ConditionalOps<R>
        + UtilityOps<R>
        + RuntimeClient<R>,
{
    binary_dilation_impl(client, input, structure, iterations)
        .and_then(|grown| binary_erosion_impl(client, &grown, structure, iterations))
}
pub fn binary_fill_holes_impl<R, C>(client: &C, input: &Tensor<R>) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: NdFilterAlgorithms<R>
+ ScalarOps<R>
+ CompareOps<R>
+ ConditionalOps<R>
+ ReduceOps<R>
+ UtilityOps<R>
+ RuntimeClient<R>,
{
let ndim = input.ndim();
let size = vec![3usize; ndim];
let zero = client.fill(input.shape(), 0.0, input.dtype())?;
let one = client.fill(input.shape(), 1.0, input.dtype())?;
let mask = client.ne(input, &zero)?;
let binary = client.where_cond(&mask, &one, &zero)?;
let complement = client.sub(&one, &binary)?;
let mut marker = complement.clone();
let max_iter = input.shape().iter().sum::<usize>(); for _ in 0..max_iter {
let dilated = client.maximum_filter(&marker, &size, BoundaryMode::Constant(1.0))?;
let new_marker = client.minimum(&dilated, &complement)?;
let diff = client.sub(&new_marker, &marker)?;
let diff_abs = client.abs(&diff)?;
let diff_sum = client.sum(&diff_abs, &[], false)?;
let diff_val: Vec<f64> = diff_sum.to_vec();
marker = new_marker;
if diff_val[0] < 1e-10 {
break;
}
}
let holes = client.sub(&complement, &marker)?;
client.add(&binary, &holes)
}