use crate::DType;
use crate::signal::impl_generic::boundary::pad_axis_impl;
use crate::signal::impl_generic::kernels::{edge_kernel_1d, laplace_kernel_1d};
use crate::signal::traits::nd_filters::BoundaryMode;
use crate::signal::validate_signal_dtype;
use numr::error::{Error, Result};
use numr::ops::{ConvOps, PaddingMode, ScalarOps, ShapeOps, TensorOps, UnaryOps, UtilityOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Applies a separable edge filter of the given `kind` (forwarded to `edge_kernel_1d`,
/// e.g. `"sobel"` or `"prewitt"`) to `input`, one axis at a time.
///
/// The kernel built with `edge_kernel_1d(.., true, ..)` is applied along `axis`. The one
/// built with `edge_kernel_1d(.., false, ..)` is applied along every other axis. These are
/// presumably the derivative and smoothing taps respectively; confirm in `edge_kernel_1d`.
/// Each pass goes through `convolve_along_axis_simple`, so the output keeps the input's shape.
///
/// # Errors
/// Returns `Error::InvalidArgument` if `input` has fewer than 2 dimensions or `axis` is out
/// of range. Also propagates any error from kernel construction or convolution.
fn separable_edge_filter_impl<R, C>(
    client: &C,
    input: &Tensor<R>,
    axis: usize,
    kind: &str,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ConvOps<R>
        + ScalarOps<R>
        + ShapeOps<R>
        + TensorOps<R>
        + UnaryOps<R>
        + UtilityOps<R>
        + RuntimeClient<R>,
{
    let ndim = input.ndim();
    let dtype = input.dtype();
    if ndim < 2 {
        return Err(Error::InvalidArgument {
            arg: "input",
            reason: format!("{kind} requires at least 2D input, got {ndim}D"),
        });
    }
    if axis >= ndim {
        return Err(Error::InvalidArgument {
            arg: "axis",
            reason: format!("Axis {axis} out of range for {ndim}D input"),
        });
    }

    // The two kernels do not depend on the pass, so build each exactly once instead of
    // rebuilding one per axis inside the loop.
    let along_axis_kernel = edge_kernel_1d(client, kind, true, dtype)?;
    let cross_axis_kernel = edge_kernel_1d(client, kind, false, dtype)?;

    let mut result = input.clone();
    for ax in 0..ndim {
        let kernel = if ax == axis {
            &along_axis_kernel
        } else {
            &cross_axis_kernel
        };
        result = convolve_along_axis_simple(client, &result, kernel, ax)?;
    }
    Ok(result)
}
/// Convolves `input` with the 1-D `kernel_1d` along a single `axis`, keeping the input's shape.
///
/// The process has four steps:
/// 1. Reflect-pad `axis` by `kernel_len / 2` on both sides.
/// 2. Move `axis` to the last position and flatten everything else into a
///    `[lines, 1, padded_len]` batch.
/// 3. Filter every line with one `conv1d` call using `PaddingMode::Valid`.
/// 4. Reshape the result and permute it back to the original axis order.
///
/// NOTE(review): the output length matches the input length only for odd-length kernels.
/// An even-length kernel would make the final reshape fail.
fn convolve_along_axis_simple<R, C>(
    client: &C,
    input: &Tensor<R>,
    kernel_1d: &Tensor<R>,
    axis: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ConvOps<R> + ScalarOps<R> + ShapeOps<R> + TensorOps<R> + RuntimeClient<R>,
{
    let orig_shape = input.shape().to_vec();
    let rank = orig_shape.len();
    let taps = kernel_1d.shape()[0];
    let half_width = taps / 2;

    // Pad only the filtered axis, so a "valid" convolution returns the original length.
    let padded = pad_axis_impl(
        client,
        input,
        axis as isize,
        half_width,
        half_width,
        BoundaryMode::Reflect,
    )?;
    let padded_dims = padded.shape().to_vec();
    let padded_len = padded_dims[axis];
    // Each index combination over the non-filtered axes is one independent 1-D line.
    let line_count: usize = padded_dims
        .iter()
        .enumerate()
        .filter_map(|(dim, &size)| if dim == axis { None } else { Some(size) })
        .product();

    // Move `axis` to the end, keeping the remaining axes in their original order.
    let perm: Vec<usize> = (0..rank)
        .filter(|&dim| dim != axis)
        .chain(std::iter::once(axis))
        .collect();
    let lines = padded
        .permute(&perm)?
        .contiguous()?
        .reshape(&[line_count, 1, padded_len])?;
    let kernel_3d = kernel_1d.reshape(&[1, 1, taps])?;
    let filtered = client.conv1d(&lines, &kernel_3d, None, 1, PaddingMode::Valid, 1, 1)?;

    // Un-flatten into the permuted layout (filtered axis last, at its original length),
    // then undo the permutation.
    let permuted_shape: Vec<usize> = perm.iter().map(|&dim| orig_shape[dim]).collect();
    let mut inverse = vec![0usize; rank];
    for (position, &dim) in perm.iter().enumerate() {
        inverse[dim] = position;
    }
    filtered
        .reshape(&permuted_shape)?
        .permute(&inverse)?
        .contiguous()
}
/// Sobel filter of `input` along `axis`.
///
/// Checks the dtype, then delegates to `separable_edge_filter_impl` with kind `"sobel"`.
/// That function requires at least 2-D input and an in-range `axis`.
///
/// # Errors
/// Propagates dtype-validation, argument, kernel, and convolution errors.
pub fn sobel_impl<R, C>(client: &C, input: &Tensor<R>, axis: usize) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ConvOps<R>
        + ScalarOps<R>
        + ShapeOps<R>
        + TensorOps<R>
        + UnaryOps<R>
        + UtilityOps<R>
        + RuntimeClient<R>,
{
    // One name serves as both the validation label and the kernel kind.
    const KIND: &str = "sobel";
    validate_signal_dtype(input.dtype(), KIND)?;
    separable_edge_filter_impl(client, input, axis, KIND)
}
/// Prewitt filter of `input` along `axis`.
///
/// Checks the dtype, then delegates to `separable_edge_filter_impl` with kind `"prewitt"`.
/// That function requires at least 2-D input and an in-range `axis`.
///
/// # Errors
/// Propagates dtype-validation, argument, kernel, and convolution errors.
pub fn prewitt_impl<R, C>(client: &C, input: &Tensor<R>, axis: usize) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ConvOps<R>
        + ScalarOps<R>
        + ShapeOps<R>
        + TensorOps<R>
        + UnaryOps<R>
        + UtilityOps<R>
        + RuntimeClient<R>,
{
    // One name serves as both the validation label and the kernel kind.
    const KIND: &str = "prewitt";
    validate_signal_dtype(input.dtype(), KIND)?;
    separable_edge_filter_impl(client, input, axis, KIND)
}
/// Discrete Laplacian of `input`.
///
/// The 1-D `laplace_kernel_1d` stencil is convolved along each axis of the original
/// input, each time with reflect padding and shape preserved. The per-axis responses are
/// then summed.
///
/// # Errors
/// Returns `Error::InvalidArgument` for 0-D input. Also propagates errors from dtype
/// validation, kernel construction, convolution, and addition.
pub fn laplace_impl<R, C>(client: &C, input: &Tensor<R>) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ConvOps<R> + ScalarOps<R> + ShapeOps<R> + TensorOps<R> + UtilityOps<R> + RuntimeClient<R>,
{
    validate_signal_dtype(input.dtype(), "laplace")?;
    let rank = input.ndim();
    if rank == 0 {
        return Err(Error::InvalidArgument {
            arg: "input",
            reason: String::from("laplace requires at least 1D input"),
        });
    }
    let stencil = laplace_kernel_1d(client, input.dtype())?;
    // Every term filters the untouched input; passes are summed, not chained.
    let along = |axis: usize| convolve_along_axis_simple(client, input, &stencil, axis);
    (1..rank).try_fold(along(0)?, |total, axis| {
        let term = along(axis)?;
        client.add(&total, &term)
    })
}
pub fn gaussian_laplace_impl<R, C>(client: &C, input: &Tensor<R>, sigma: f64) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: ConvOps<R>
+ ScalarOps<R>
+ ShapeOps<R>
+ TensorOps<R>
+ UnaryOps<R>
+ UtilityOps<R>
+ RuntimeClient<R>,
{
validate_signal_dtype(input.dtype(), "gaussian_laplace")?;
let ndim = input.ndim();
let sigmas = vec![sigma; ndim];
let orders = vec![0usize; ndim];
use super::nd_filters::gaussian_filter_impl;
let smoothed =
gaussian_filter_impl(client, input, &sigmas, &orders, BoundaryMode::Reflect, 4.0)?;
laplace_impl(client, &smoothed)
}
/// Gradient magnitude of `input` computed with Gaussian derivative filters.
///
/// For each axis, `gaussian_filter_impl` runs with standard deviation `sigma` on every axis,
/// derivative order 1 on that axis only, and a reflect boundary. The result is
/// `sqrt(sum over axes of grad_axis^2)`.
///
/// # Errors
/// Returns `Error::InvalidArgument` for 0-D input. Also propagates errors from dtype
/// validation, filtering, and the element-wise `mul`/`add`/`sqrt` ops.
pub fn gaussian_gradient_magnitude_impl<R, C>(
    client: &C,
    input: &Tensor<R>,
    sigma: f64,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ConvOps<R>
        + ScalarOps<R>
        + ShapeOps<R>
        + TensorOps<R>
        + UnaryOps<R>
        + UtilityOps<R>
        + RuntimeClient<R>,
{
    validate_signal_dtype(input.dtype(), "gaussian_gradient_magnitude")?;
    let ndim = input.ndim();
    if ndim == 0 {
        return Err(Error::InvalidArgument {
            arg: "input",
            reason: "gaussian_gradient_magnitude requires at least 1D input".to_string(),
        });
    }
    use super::nd_filters::gaussian_filter_impl;
    // NOTE(review): presumably the truncation radius in units of sigma (scipy's default
    // `truncate` is 4.0). Confirm against gaussian_filter_impl.
    const TRUNCATE: f64 = 4.0;

    // The same sigma applies on every pass; only the derivative order changes per axis.
    let sigmas = vec![sigma; ndim];
    let squared_gradient = |axis: usize| -> Result<Tensor<R>> {
        let mut orders = vec![0usize; ndim];
        orders[axis] = 1;
        let grad =
            gaussian_filter_impl(client, input, &sigmas, &orders, BoundaryMode::Reflect, TRUNCATE)?;
        client.mul(&grad, &grad)
    };

    // ndim >= 1 is guaranteed above. Seeding with axis 0 therefore needs no `Option` and
    // no second, unreachable "empty input" error path.
    let mut sum_sq = squared_gradient(0)?;
    for axis in 1..ndim {
        let grad_sq = squared_gradient(axis)?;
        sum_sq = client.add(&sum_sq, &grad_sq)?;
    }
    client.sqrt(&sum_sq)
}