use numr::dtype::DType;
use numr::error::{Error, Result};
use numr::ops::{ScalarOps, UnaryOps, UtilityOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub fn gaussian_kernel_1d<R, C>(
client: &C,
sigma: f64,
order: usize,
truncate: f64,
dtype: DType,
) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: ScalarOps<R> + UnaryOps<R> + UtilityOps<R> + RuntimeClient<R>,
{
if sigma <= 0.0 {
return Err(Error::InvalidArgument {
arg: "sigma",
reason: "Gaussian sigma must be positive".to_string(),
});
}
if order > 3 {
return Err(Error::InvalidArgument {
arg: "order",
reason: "Gaussian derivative order must be 0-3".to_string(),
});
}
let radius = (truncate * sigma + 0.5) as usize;
let size = 2 * radius + 1;
if size == 0 {
return Err(Error::InvalidArgument {
arg: "sigma",
reason: "Gaussian kernel size is zero (sigma too small)".to_string(),
});
}
let positions = client.arange(0.0, size as f64, 1.0, dtype)?;
let positions = client.add_scalar(&positions, -(radius as f64))?;
let x2 = client.mul(&positions, &positions)?;
let neg_half_inv_sigma2 = -0.5 / (sigma * sigma);
let scaled = client.mul_scalar(&x2, neg_half_inv_sigma2)?;
let gaussian = client.exp(&scaled)?;
match order {
0 => {
let sum = client.sum(&gaussian, &[0], false)?;
client.div(&gaussian, &sum)
}
1 => {
let inv_sigma2 = -1.0 / (sigma * sigma);
let factor = client.mul_scalar(&positions, inv_sigma2)?;
let deriv = client.mul(&factor, &gaussian)?;
let weighted = client.mul(&positions, &deriv)?;
let weight_sum = client.sum(&weighted, &[0], false)?;
let neg_sum = client.mul_scalar(&weight_sum, -1.0)?;
client.div(&deriv, &neg_sum)
}
2 => {
let inv_sigma2 = 1.0 / (sigma * sigma);
let inv_sigma4 = inv_sigma2 * inv_sigma2;
let term1 = client.mul_scalar(&x2, inv_sigma4)?;
let term2 = client.add_scalar(&term1, -inv_sigma2)?;
let deriv = client.mul(&term2, &gaussian)?;
let x2_weighted = client.mul(&x2, &deriv)?;
let x2_sum = client.sum(&x2_weighted, &[0], false)?;
client.div(&deriv, &x2_sum)
}
_ => {
let inv_sigma2 = 1.0 / (sigma * sigma);
let inv_sigma4 = inv_sigma2 * inv_sigma2;
let inv_sigma6 = inv_sigma4 * inv_sigma2;
let x3 = client.mul(&x2, &positions)?;
let t1 = client.mul_scalar(&x3, -inv_sigma6)?;
let t2 = client.mul_scalar(&positions, 3.0 * inv_sigma4)?;
let factor = client.add(&t1, &t2)?;
let deriv = client.mul(&factor, &gaussian)?;
Ok(deriv)
}
}
}
pub fn uniform_kernel_1d<R, C>(client: &C, size: usize, dtype: DType) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: ScalarOps<R> + UtilityOps<R> + RuntimeClient<R>,
{
if size == 0 {
return Err(Error::InvalidArgument {
arg: "size",
reason: "Uniform kernel size must be positive".to_string(),
});
}
let value = 1.0 / size as f64;
client.fill(&[size], value, dtype)
}
pub fn edge_kernel_1d<R, C>(
client: &C,
kind: &str,
derivative: bool,
dtype: DType,
) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: ScalarOps<R> + UnaryOps<R> + UtilityOps<R> + RuntimeClient<R>,
{
if derivative {
let positions = client.arange(0.0, 3.0, 1.0, dtype)?;
client.add_scalar(&positions, -1.0)
} else {
match kind {
"sobel" => {
let positions = client.arange(0.0, 3.0, 1.0, dtype)?;
let ones = client.fill(&[3], 1.0, dtype)?;
let centered = client.sub(&positions, &ones)?; let abs_val = client.abs(¢ered)?; let scaled = client.mul_scalar(&abs_val, -1.0)?; client.add_scalar(&scaled, 2.0) }
"prewitt" => {
client.fill(&[3], 1.0, dtype)
}
_ => Err(Error::InvalidArgument {
arg: "kind",
reason: format!("Unknown edge kernel kind: {kind}. Use 'sobel' or 'prewitt'"),
}),
}
}
}
/// Builds the 3-tap second-difference stencil `[1, -2, 1]`.
///
/// The taps are generated with the client's tensor ops as `3·|x| - 2` for the offsets
/// `x ∈ {-1, 0, 1}`, using the requested `dtype`.
///
/// # Errors
///
/// Propagates errors from the underlying tensor operations.
pub fn laplace_kernel_1d<R, C>(client: &C, dtype: DType) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: ScalarOps<R> + UnaryOps<R> + UtilityOps<R> + RuntimeClient<R>,
{
    // Offsets -1, 0, 1 relative to the kernel centre.
    let positions = client.arange(0.0, 3.0, 1.0, dtype)?;
    let offsets = client.add_scalar(&positions, -1.0)?;
    // |x| = [1, 0, 1]  ->  3·|x| - 2 = [1, -2, 1].
    let abs_offsets = client.abs(&offsets)?;
    let tripled = client.mul_scalar(&abs_offsets, 3.0)?;
    client.add_scalar(&tripled, -2.0)
}
#[cfg(test)]
mod tests {
    /// Half-up rounding of `truncate * sigma` gives radius 4 for sigma = 1, truncate = 4,
    /// hence 9 taps.
    #[test]
    fn test_gaussian_kernel_size_calculation() {
        let (truncate, sigma) = (4.0_f64, 1.0_f64);
        let half_width = (truncate * sigma + 0.5) as usize;
        let taps = half_width * 2 + 1;
        assert_eq!(taps, 9);
    }

    /// `base - offset` with a centre offset of 3 yields the `[1, -2, 1]` stencil.
    #[test]
    fn test_laplace_formula_correctness() {
        let base = [1.0_f64; 3];
        let offset = [0.0_f64, 3.0, 0.0];
        let mut stencil = [0.0_f64; 3];
        for ((slot, b), o) in stencil.iter_mut().zip(base.iter()).zip(offset.iter()) {
            *slot = b - o;
        }
        assert_eq!(stencil, [1.0, -2.0, 1.0]);
    }

    /// `2 - |x - 1|` over `x = 0, 1, 2` yields the Sobel smoothing taps `[1, 2, 1]`.
    #[test]
    fn test_sobel_formula_correctness() {
        let expected = vec![1.0, 2.0, 1.0];
        let taps: Vec<f64> = [0.0_f64, 1.0, 2.0]
            .iter()
            .map(|&x| 2.0 - (x - 1.0).abs())
            .collect();
        assert_eq!(taps, expected);
    }
}