use ndarray::{Array2, ArrayView2};
use std::f64::consts::PI;
use std::num::NonZeroUsize;
use crate::fft2d::{fft2d, ifft2d};
use crate::{SpectrogramError, SpectrogramResult};
/// Convolves `image` with `kernel` by multiplying their 2-D spectra.
///
/// The kernel is first embedded in a zero array of the image's shape by
/// `pad_kernel_for_fft`. Both arrays are then transformed with `fft2d`, the
/// spectra are multiplied element-wise, and the product is transformed back
/// with `ifft2d`, so the result has the layout `ifft2d` produces for an image
/// with `image.ncols()` columns.
///
/// # Errors
///
/// Returns an invalid-input error when the kernel is larger than the image
/// along either axis, or when either kernel dimension is zero. Errors from
/// `fft2d` / `ifft2d` are propagated unchanged.
#[inline]
pub fn convolve_fft(
    image: &ArrayView2<f64>,
    kernel: &ArrayView2<f64>,
) -> SpectrogramResult<Array2<f64>> {
    let image_shape = image.dim();
    let kernel_shape = kernel.dim();

    // The size check runs before the emptiness check; keep that order so the
    // reported message is unchanged when both conditions hold.
    let kernel_too_large = kernel_shape.0 > image_shape.0 || kernel_shape.1 > image_shape.1;
    if kernel_too_large {
        return Err(SpectrogramError::invalid_input(
            "kernel dimensions must not exceed image dimensions",
        ));
    }

    let kernel_is_empty = kernel_shape.0 == 0 || kernel_shape.1 == 0;
    if kernel_is_empty {
        return Err(SpectrogramError::invalid_input(
            "kernel dimensions must be > 0",
        ));
    }

    let wrapped_kernel = pad_kernel_for_fft(kernel, image_shape);
    let image_spectrum = fft2d(image)?;
    let kernel_spectrum = fft2d(&wrapped_kernel.view())?;

    // Element-wise product of the two spectra, then back to the spatial domain.
    ifft2d(&(&image_spectrum * &kernel_spectrum), image_shape.1)
}
/// Embeds `kernel` in a zero-filled array of `target_shape`, wrapped around the origin.
///
/// The kernel element at `(ker_rows / 2, ker_cols / 2)` is written to `[0, 0]`.
/// Every other element is written at its offset from that element, taken
/// modulo the target size (Euclidean remainder), so elements before the
/// centre wrap to the far end of each axis.
fn pad_kernel_for_fft(kernel: &ArrayView2<f64>, target_shape: (usize, usize)) -> Array2<f64> {
    let (kernel_rows, kernel_cols) = kernel.dim();
    let row_modulus = target_shape.0 as isize;
    let col_modulus = target_shape.1 as isize;
    let row_shift = (kernel_rows / 2) as isize;
    let col_shift = (kernel_cols / 2) as isize;

    let mut padded = Array2::<f64>::zeros(target_shape);
    for ((row, col), &value) in kernel.indexed_iter() {
        // `rem_euclid` keeps the wrapped index non-negative for negative offsets.
        let wrapped_row = (row as isize - row_shift).rem_euclid(row_modulus) as usize;
        let wrapped_col = (col as isize - col_shift).rem_euclid(col_modulus) as usize;
        padded[[wrapped_row, wrapped_col]] = value;
    }
    padded
}
/// Builds a square 2-D Gaussian kernel whose entries sum to 1.
///
/// Each entry is first set to `exp(-(x² + y²) / (2σ²)) / (2πσ²)`, where `x`
/// and `y` are the offsets from the central element `(size / 2, size / 2)`.
/// The whole kernel is then divided by the sum of its entries.
///
/// # Errors
///
/// Returns an invalid-input error when:
/// - `size` is even,
/// - `sigma` is NaN or infinite,
/// - `sigma` is zero or negative, or
/// - `sigma` is so small or so large that the weights overflow or underflow
///   and their sum cannot be used for normalization.
#[inline]
pub fn gaussian_kernel_2d(size: NonZeroUsize, sigma: f64) -> SpectrogramResult<Array2<f64>> {
    if size.get().is_multiple_of(2) {
        return Err(SpectrogramError::invalid_input(
            "kernel size must be odd and > 0",
        ));
    }
    // `sigma <= 0.0` is false for NaN, so NaN must be rejected explicitly;
    // otherwise it would propagate into every kernel entry.
    if !sigma.is_finite() {
        return Err(SpectrogramError::invalid_input("sigma must be finite"));
    }
    if sigma <= 0.0 {
        return Err(SpectrogramError::invalid_input("sigma must be > 0"));
    }

    let size = size.get();
    let center = (size / 2) as f64;
    let variance = sigma * sigma;
    let coeff = 1.0 / (2.0 * PI * variance);

    let mut kernel = Array2::<f64>::zeros((size, size));
    for i in 0..size {
        for j in 0..size {
            let x = i as f64 - center;
            let y = j as f64 - center;
            let exponent = -(x * x + y * y) / (2.0 * variance);
            kernel[[i, j]] = coeff * exponent.exp();
        }
    }

    let sum: f64 = kernel.iter().sum();
    // For extreme (but finite) sigma the variance or coefficient can overflow
    // or underflow, leaving a sum of 0, inf or NaN. Dividing by that would
    // return a kernel full of NaN inside `Ok`, so report it instead.
    if !sum.is_finite() || sum <= 0.0 {
        return Err(SpectrogramError::invalid_input(
            "sigma is too small or too large to produce a normalizable kernel",
        ));
    }
    kernel.mapv_inplace(|v| v / sum);
    Ok(kernel)
}
/// Builds a 0/1 mask of `shape` that is 1.0 inside a disc around index `[0, 0]`.
///
/// Along each axis, an index up to and including `len / 2` is used as its own
/// frequency; larger indices are folded to `len - index`. An element is set
/// to 1.0 when the squared distance of its folded `(row, col)` frequencies is
/// at most `(min(nrows / 2, ncols / 2) * cutoff_fraction)²`, and 0.0 otherwise.
///
/// NOTE(review): the callers in this file pass `fft2d(..).dim()` as `shape`.
/// If `fft2d` returns a half-spectrum along the column axis, folding the
/// column index as done here may not match that layout — confirm against
/// `fft2d`.
fn create_lowpass_mask(shape: (usize, usize), cutoff_fraction: f64) -> Array2<f64> {
    let (row_count, col_count) = shape;
    let half_rows = row_count / 2;
    let half_cols = col_count / 2;

    // Squared cutoff radius, measured against the shorter half-axis.
    let limiting_freq = (half_rows as f64).min(half_cols as f64);
    let radius_sq = (limiting_freq * cutoff_fraction).powi(2);

    // Indices past the midpoint are folded back to their distance from `len`.
    let folded = |index: usize, half: usize, len: usize| -> f64 {
        if index <= half {
            index as f64
        } else {
            (index as f64 - len as f64).abs()
        }
    };

    Array2::from_shape_fn(shape, |(row, col)| {
        let row_freq = folded(row, half_rows, row_count);
        let col_freq = folded(col, half_cols, col_count);
        let dist_sq = col_freq.mul_add(col_freq, row_freq.powi(2));
        if dist_sq <= radius_sq {
            1.0
        } else {
            0.0
        }
    })
}
/// Low-pass filters `image` in the frequency domain.
///
/// The image is transformed with `fft2d`, multiplied element-wise by the 0/1
/// mask from `create_lowpass_mask` (sized to the spectrum), and transformed
/// back with `ifft2d` using the image's column count.
///
/// `cutoff_fraction` scales the mask radius relative to the shorter
/// half-axis of the spectrum.
///
/// # Errors
///
/// Returns an invalid-input error when `cutoff_fraction` is outside
/// `0.0..=1.0` (NaN included). Errors from `fft2d` / `ifft2d` are propagated.
#[inline]
pub fn lowpass_filter(
    image: &ArrayView2<f64>,
    cutoff_fraction: f64,
) -> SpectrogramResult<Array2<f64>> {
    let cutoff_in_range = (0.0..=1.0).contains(&cutoff_fraction);
    if !cutoff_in_range {
        return Err(SpectrogramError::invalid_input(
            "cutoff_fraction must be between 0.0 and 1.0",
        ));
    }

    let spectrum = fft2d(image)?;
    // Promote the real-valued mask to complex so it can multiply the spectrum.
    let complex_mask = create_lowpass_mask(spectrum.dim(), cutoff_fraction)
        .mapv(|keep| num_complex::Complex::new(keep, 0.0));
    let attenuated = &spectrum * &complex_mask;

    ifft2d(&attenuated, image.ncols())
}
/// High-pass filters `image` in the frequency domain.
///
/// Uses the complement (`1.0 - mask`) of the mask built by
/// `create_lowpass_mask`: the spectrum from `fft2d` is multiplied
/// element-wise by that complement and transformed back with `ifft2d` using
/// the image's column count.
///
/// # Errors
///
/// Returns an invalid-input error when `cutoff_fraction` is outside
/// `0.0..=1.0` (NaN included). Errors from `fft2d` / `ifft2d` are propagated.
#[inline]
pub fn highpass_filter(
    image: &ArrayView2<f64>,
    cutoff_fraction: f64,
) -> SpectrogramResult<Array2<f64>> {
    let cutoff_in_range = (0.0..=1.0).contains(&cutoff_fraction);
    if !cutoff_in_range {
        return Err(SpectrogramError::invalid_input(
            "cutoff_fraction must be between 0.0 and 1.0",
        ));
    }

    let spectrum = fft2d(image)?;
    // Invert the low-pass mask and promote it to complex in a single pass.
    let complex_mask = create_lowpass_mask(spectrum.dim(), cutoff_fraction)
        .mapv(|keep| num_complex::Complex::new(1.0 - keep, 0.0));
    let attenuated = &spectrum * &complex_mask;

    ifft2d(&attenuated, image.ncols())
}
/// Band-pass filters `image` in the frequency domain.
///
/// The pass band is the difference between two masks from
/// `create_lowpass_mask`: the one for `high_cutoff` minus the one for
/// `low_cutoff`. The spectrum from `fft2d` is multiplied element-wise by that
/// difference and transformed back with `ifft2d` using the image's column
/// count.
///
/// # Errors
///
/// Returns an invalid-input error when either cutoff is outside `0.0..=1.0`
/// (NaN included), or when `low_cutoff >= high_cutoff`. Errors from `fft2d` /
/// `ifft2d` are propagated.
#[inline]
pub fn bandpass_filter(
    image: &ArrayView2<f64>,
    low_cutoff: f64,
    high_cutoff: f64,
) -> SpectrogramResult<Array2<f64>> {
    let unit_interval = 0.0..=1.0;
    let both_in_range = unit_interval.contains(&low_cutoff) && unit_interval.contains(&high_cutoff);
    if !both_in_range {
        return Err(SpectrogramError::invalid_input(
            "cutoff fractions must be between 0.0 and 1.0",
        ));
    }
    if low_cutoff >= high_cutoff {
        return Err(SpectrogramError::invalid_input(
            "high_cutoff must be greater than low_cutoff",
        ));
    }

    let spectrum = fft2d(image)?;
    let spectrum_shape = spectrum.dim();
    let inner_disc = create_lowpass_mask(spectrum_shape, low_cutoff);
    let outer_disc = create_lowpass_mask(spectrum_shape, high_cutoff);

    // Outer disc minus inner disc, promoted to complex for the multiplication.
    let complex_mask =
        (&outer_disc - &inner_disc).mapv(|keep| num_complex::Complex::new(keep, 0.0));
    let attenuated = &spectrum * &complex_mask;

    ifft2d(&attenuated, image.ncols())
}
/// Edge detection via `highpass_filter` with a fixed cutoff fraction of 0.1.
///
/// # Errors
///
/// Propagates any error returned by `highpass_filter`.
#[inline]
pub fn detect_edges_fft(image: &ArrayView2<f64>) -> SpectrogramResult<Array2<f64>> {
    /// Cutoff fraction handed to the high-pass filter.
    const EDGE_CUTOFF_FRACTION: f64 = 0.1;

    highpass_filter(image, EDGE_CUTOFF_FRACTION)
}
/// Sharpens `image` by adding its high-pass component scaled by `amount`.
///
/// Computes `image + amount * highpass_filter(image, 0.2)`. An `amount` of
/// `0.0` therefore adds nothing to the image.
///
/// # Errors
///
/// Returns an invalid-input error when `amount` is negative, NaN, or
/// infinite. Errors from `highpass_filter` are propagated.
#[inline]
pub fn sharpen_fft(image: &ArrayView2<f64>, amount: f64) -> SpectrogramResult<Array2<f64>> {
    // `amount < 0.0` is false for NaN, so NaN has to be rejected explicitly;
    // otherwise every output sample would become NaN.
    if amount.is_nan() || amount < 0.0 {
        return Err(SpectrogramError::invalid_input("amount must be >= 0"));
    }
    // An infinite gain yields only inf/NaN samples, never a usable image.
    if amount.is_infinite() {
        return Err(SpectrogramError::invalid_input("amount must be finite"));
    }
    let high_freq = highpass_filter(image, 0.2)?;
    Ok(image + &(high_freq * amount))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::nzu;

    /// A 5x5 Gaussian kernel must sum to one after normalization.
    #[test]
    fn test_gaussian_kernel_normalized() {
        let kernel = gaussian_kernel_2d(nzu!(5), 1.0).unwrap();
        let total = kernel.iter().sum::<f64>();
        assert!((total - 1.0).abs() < 1e-10, "kernel should sum to 1.0");
    }

    /// Every entry must equal the entry reflected through the kernel centre.
    #[test]
    fn test_gaussian_kernel_symmetric() {
        let kernel = gaussian_kernel_2d(nzu!(5), 1.0).unwrap();
        // Reflecting index `k` through the centre (2) of a 5-wide axis gives `4 - k`.
        let last = 4;
        for ((i, j), &value) in kernel.indexed_iter() {
            let reflected = kernel[[last - i, last - j]];
            assert!(
                (value - reflected).abs() < 1e-10,
                "kernel should be symmetric"
            );
        }
    }

    /// Convolving with a centred unit impulse must reproduce the interior of the image.
    #[test]
    fn test_convolve_fft_identity() {
        let image = Array2::<f64>::from_shape_fn((64, 64), |(i, j)| i as f64 + j as f64);
        let mut impulse = Array2::<f64>::zeros((3, 3));
        impulse[[1, 1]] = 1.0;

        let convolved = convolve_fft(&image.view(), &impulse.view()).unwrap();

        // Only the interior is compared; the outermost rows and columns are skipped.
        for i in 1..63 {
            for j in 1..63 {
                let difference = convolved[[i, j]] - image[[i, j]];
                assert!(difference.abs() < 1e-6);
            }
        }
    }

    /// Low-pass filtering must lower the mean square of an oscillating image.
    #[test]
    fn test_lowpass_removes_high_freq() {
        let image = Array2::<f64>::from_shape_fn((64, 64), |(i, j)| {
            ((i as f64 * 0.5).sin() + (j as f64 * 0.5).cos()) * 50.0
        });
        let filtered = lowpass_filter(&image.view(), 0.2).unwrap();

        let mean_square =
            |data: &Array2<f64>| data.iter().map(|&x| x * x).sum::<f64>() / (64.0 * 64.0);
        let original_var = mean_square(&image);
        let filtered_var = mean_square(&filtered);

        assert!(
            filtered_var < original_var,
            "low-pass should reduce variance"
        );
    }

    /// A constant image has no high-frequency content, so the high-pass output is ~zero.
    #[test]
    fn test_highpass_emphasizes_edges() {
        let flat = Array2::<f64>::from_elem((64, 64), 100.0);
        let response = highpass_filter(&flat.view(), 0.1).unwrap();
        let peak = response
            .iter()
            .map(|&x| x.abs())
            .fold(0.0_f64, |acc, magnitude| acc.max(magnitude));
        assert!(peak < 1.0, "high-pass of constant should be ~zero");
    }

    /// Only an ordered pair of cutoffs inside `0.0..=1.0` is accepted.
    #[test]
    fn test_bandpass_bounds() {
        let image = Array2::<f64>::from_elem((64, 64), 1.0);
        // (low_cutoff, high_cutoff, expected to succeed)
        let cases = [
            (0.2, 0.8, true),
            (0.8, 0.2, false),
            (0.5, 0.5, false),
            (-0.1, 0.5, false),
            (0.5, 1.5, false),
        ];
        for &(low, high, should_succeed) in &cases {
            let outcome = bandpass_filter(&image.view(), low, high);
            assert!(outcome.is_ok() == should_succeed);
        }
    }
}