use numpy::{
Complex64, PyArray, PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2, ToPyArray,
};
use pyo3::prelude::*;
use crate::fft2d as rust_fft2d;
use crate::fft2d::Fft2dPlanner as RustFft2dPlanner;
use crate::image_ops;
/// Borrows `obj` as a read-only two-dimensional `f64` NumPy array.
///
/// Direct extraction is attempted first. If that fails and `obj` exposes an
/// `__array__` attribute, `obj.__array__()` is called with no arguments and
/// extraction is retried on the returned object.
///
/// `_py` is accepted but not used.
///
/// # Errors
///
/// * `TypeError` ("Failed to extract array: ...") when the value returned by
///   `__array__()` cannot be extracted as a 2-D `f64` array.
/// * `TypeError` when direct extraction fails and `obj` has no `__array__`
///   attribute.
/// * Any error raised by the attribute lookup or by the `__array__()` call
///   itself is propagated unchanged.
fn extract_array<'py>(
    _py: Python<'py>,
    obj: &Bound<'py, PyAny>,
) -> PyResult<PyReadonlyArray2<'py, f64>> {
    use pyo3::exceptions::PyTypeError;

    // Fast path: the object already is a matching array.
    if let Ok(direct) = obj.extract::<PyReadonlyArray2<f64>>() {
        return Ok(direct);
    }

    // Without `__array__` there is nothing left to try.
    if !obj.hasattr("__array__")? {
        return Err(PyTypeError::new_err(
            "Object must be a numpy array or implement __array__()",
        ));
    }

    // Let the object convert itself, then retry the extraction on the result.
    let converted = obj.call_method0("__array__")?;
    converted
        .extract::<PyReadonlyArray2<f64>>()
        .map_err(|err| PyTypeError::new_err(format!("Failed to extract array: {err}")))
}
/// Borrows `obj` as a read-only one-dimensional `f64` NumPy array.
///
/// Direct extraction is attempted first. If that fails and `obj` exposes an
/// `__array__` attribute, `obj.__array__()` is called with no arguments and
/// extraction is retried on the returned object.
///
/// `_py` is accepted but not used.
///
/// # Errors
///
/// * `TypeError` ("Failed to extract 1D array: ...") when the value returned
///   by `__array__()` cannot be extracted as a 1-D `f64` array.
/// * `TypeError` when direct extraction fails and `obj` has no `__array__`
///   attribute.
/// * Any error raised by the attribute lookup or by the `__array__()` call
///   itself is propagated unchanged.
fn extract_array_1d<'py>(
    _py: Python<'py>,
    obj: &Bound<'py, PyAny>,
) -> PyResult<PyReadonlyArray1<'py, f64>> {
    use pyo3::exceptions::PyTypeError;

    // Fast path: the object already is a matching array.
    if let Ok(direct) = obj.extract::<PyReadonlyArray1<f64>>() {
        return Ok(direct);
    }

    // Without `__array__` there is nothing left to try.
    if !obj.hasattr("__array__")? {
        return Err(PyTypeError::new_err(
            "Object must be a 1D numpy array or implement __array__()",
        ));
    }

    // Let the object convert itself, then retry the extraction on the result.
    let converted = obj.call_method0("__array__")?;
    converted
        .extract::<PyReadonlyArray1<f64>>()
        .map_err(|err| PyTypeError::new_err(format!("Failed to extract 1D array: {err}")))
}
/// Python entry point for `crate::fft2d::fft2d`.
///
/// `data` is converted with `extract_array`, so it must be a 2-D `float64`
/// array or an object whose `__array__()` yields one. The transform itself
/// runs inside `py.detach`, and its result is returned as a new 2-D array of
/// `Complex64` values.
///
/// # Errors
///
/// Fails when `data` cannot be extracted, or when `crate::fft2d::fft2d`
/// reports an error (propagated through `?`).
#[pyfunction]
#[inline]
#[pyo3(signature = (data: "numpy.typing.NDArray[numpy.float64]"), text_signature = "(data: numpy.typing.NDArray[numpy.float64])")]
pub fn fft2d(py: Python, data: &Bound<'_, PyAny>) -> PyResult<Py<PyArray2<Complex64>>> {
    let input = extract_array(py, data)?;
    let view = input.as_array();

    let spectrum = py.detach(|| rust_fft2d::fft2d(&view))?;

    // Rebuild every element as a `numpy::Complex64` before handing it to NumPy.
    let spectrum = spectrum.mapv(|value| Complex64::new(value.re, value.im));
    Ok(spectrum.to_pyarray(py).unbind())
}
/// Python entry point for `crate::fft2d::ifft2d`.
///
/// `spectrum` must be a 2-D array of `numpy::Complex64`. In rust-numpy that
/// alias is `Complex<f64>`, which is NumPy's `complex128` dtype, so the
/// advertised type hint names `numpy.complex128`. `output_ncols` is forwarded
/// unchanged to the Rust implementation (presumably the column count of the
/// real-valued output; confirm against `crate::fft2d::ifft2d`).
///
/// The inverse transform runs inside `py.detach` and its result is returned
/// as a new 2-D `float64` array.
///
/// # Errors
///
/// Fails when `crate::fft2d::ifft2d` reports an error (propagated through `?`).
#[pyfunction]
#[inline]
#[pyo3(signature = (spectrum: "numpy.typing.NDArray[numpy.complex128]", output_ncols: "int"), text_signature = "(spectrum: numpy.typing.NDArray[numpy.complex128], output_ncols: int)")]
pub fn ifft2d(
    py: Python,
    spectrum: PyReadonlyArray2<Complex64>,
    output_ncols: usize,
) -> PyResult<Py<PyArray2<f64>>> {
    // Copy into an owned array of `num_complex::Complex` values. The
    // components are already `f64`, so no numeric cast is needed.
    let owned_spectrum = spectrum
        .as_array()
        .mapv(|value| num_complex::Complex::new(value.re, value.im));

    let image = py.detach(|| rust_fft2d::ifft2d(&owned_spectrum, output_ncols))?;
    Ok(image.to_pyarray(py).unbind())
}
/// Python entry point for `crate::fft2d::power_spectrum_2d`.
///
/// `data` is converted with `extract_array`, so it must be a 2-D `float64`
/// array or an object whose `__array__()` yields one. The computation runs
/// inside `py.detach` and the result is returned as a new 2-D `float64` array.
///
/// # Errors
///
/// Fails when `data` cannot be extracted, or when the Rust implementation
/// reports an error (propagated through `?`).
#[pyfunction]
#[inline]
#[pyo3(signature = (data: "numpy.typing.NDArray[numpy.float64]"), text_signature = "(data: numpy.typing.NDArray[numpy.float64])")]
pub fn power_spectrum_2d(py: Python, data: &Bound<'_, PyAny>) -> PyResult<Py<PyArray2<f64>>> {
    let input = extract_array(py, data)?;
    let view = input.as_array();

    let power = py.detach(|| rust_fft2d::power_spectrum_2d(&view))?;
    Ok(power.to_pyarray(py).unbind())
}
/// Python entry point for `crate::fft2d::magnitude_spectrum_2d`.
///
/// `data` is converted with `extract_array`, so it must be a 2-D `float64`
/// array or an object whose `__array__()` yields one. The computation runs
/// inside `py.detach` and the result is returned as a new 2-D `float64` array.
///
/// # Errors
///
/// Fails when `data` cannot be extracted, or when the Rust implementation
/// reports an error (propagated through `?`).
#[pyfunction]
#[inline]
#[pyo3(signature = (data: "numpy.typing.NDArray[numpy.float64]"), text_signature = "(data: numpy.typing.NDArray[numpy.float64])")]
pub fn magnitude_spectrum_2d(py: Python, data: &Bound<'_, PyAny>) -> PyResult<Py<PyArray2<f64>>> {
    let input = extract_array(py, data)?;
    let view = input.as_array();

    let magnitude = py.detach(|| rust_fft2d::magnitude_spectrum_2d(&view))?;
    Ok(magnitude.to_pyarray(py).unbind())
}
/// Python entry point for `crate::fft2d::fftshift` on 2-D arrays.
///
/// `arr` is converted with `extract_array`, copied into an owned array, and
/// handed by value to the Rust implementation. The result is returned as a
/// new 2-D `float64` array; the input array is not modified.
///
/// # Errors
///
/// Fails when `arr` cannot be extracted as a 2-D `float64` array.
#[pyfunction]
#[inline]
#[pyo3(signature = (arr: "numpy.typing.NDArray[numpy.float64]"), text_signature = "(arr: numpy.typing.NDArray[numpy.float64])")]
pub fn fftshift(py: Python, arr: &Bound<'_, PyAny>) -> PyResult<Py<PyArray2<f64>>> {
    let input = extract_array(py, arr)?;

    // The Rust routine takes ownership, so give it a private copy.
    let shifted = rust_fft2d::fftshift(input.as_array().to_owned());
    Ok(shifted.to_pyarray(py).unbind())
}
/// Python entry point for `crate::fft2d::ifftshift` on 2-D arrays.
///
/// `arr` is converted with `extract_array`, copied into an owned array, and
/// handed by value to the Rust implementation. The result is returned as a
/// new 2-D `float64` array; the input array is not modified.
///
/// # Errors
///
/// Fails when `arr` cannot be extracted as a 2-D `float64` array.
#[pyfunction]
#[inline]
#[pyo3(signature = (arr: "numpy.typing.NDArray[numpy.float64]"), text_signature = "(arr: numpy.typing.NDArray[numpy.float64])")]
pub fn ifftshift(py: Python, arr: &Bound<'_, PyAny>) -> PyResult<Py<PyArray2<f64>>> {
    let input = extract_array(py, arr)?;

    // The Rust routine takes ownership, so give it a private copy.
    let shifted = rust_fft2d::ifftshift(input.as_array().to_owned());
    Ok(shifted.to_pyarray(py).unbind())
}
/// Python entry point for `crate::fft2d::fftshift_1d`.
///
/// `arr` is converted with `extract_array_1d`, copied into a `Vec<f64>`, and
/// handed by value to the Rust implementation. The result is returned as a
/// new 1-D `float64` array; the input array is not modified.
///
/// Non-contiguous inputs (for example a strided slice such as `a[::2]`) are
/// accepted: the copy goes through the array view, which honours strides,
/// instead of requiring a contiguous slice.
///
/// # Errors
///
/// Fails when `arr` cannot be extracted as a 1-D `float64` array.
#[pyfunction]
#[inline]
#[pyo3(signature = (arr: "numpy.typing.NDArray[numpy.float64]"), text_signature = "(arr: numpy.typing.NDArray[numpy.float64])")]
pub fn fftshift_1d(py: Python, arr: &Bound<'_, PyAny>) -> PyResult<Py<PyArray1<f64>>> {
    let input = extract_array_1d(py, arr)?;

    // `as_slice()` would fail for strided arrays; copying via the view works
    // for any memory layout, matching how the 2-D shift functions copy.
    let values = input.as_array().to_vec();

    let shifted = rust_fft2d::fftshift_1d(values);
    Ok(PyArray1::from_vec(py, shifted).unbind())
}
/// Python entry point for `crate::fft2d::ifftshift_1d`.
///
/// `arr` is converted with `extract_array_1d`, copied into a `Vec<f64>`, and
/// handed by value to the Rust implementation. The result is returned as a
/// new 1-D `float64` array; the input array is not modified.
///
/// Non-contiguous inputs (for example a strided slice such as `a[::2]`) are
/// accepted: the copy goes through the array view, which honours strides,
/// instead of requiring a contiguous slice.
///
/// # Errors
///
/// Fails when `arr` cannot be extracted as a 1-D `float64` array.
#[pyfunction]
#[inline]
#[pyo3(signature = (arr: "numpy.typing.NDArray[numpy.float64]"), text_signature = "(arr: numpy.typing.NDArray[numpy.float64])")]
pub fn ifftshift_1d(py: Python, arr: &Bound<'_, PyAny>) -> PyResult<Py<PyArray1<f64>>> {
    let input = extract_array_1d(py, arr)?;

    // `as_slice()` would fail for strided arrays; copying via the view works
    // for any memory layout, matching how the 2-D shift functions copy.
    let values = input.as_array().to_vec();

    let shifted = rust_fft2d::ifftshift_1d(values);
    Ok(PyArray1::from_vec(py, shifted).unbind())
}
/// Python entry point for `crate::fft2d::fftfreq`.
///
/// `n` and `d` are forwarded unchanged; `d` defaults to `1.0` when omitted
/// from the Python call. The values produced by the Rust implementation are
/// returned as a new 1-D `float64` array.
#[pyfunction]
#[inline]
#[pyo3(signature = (n, d = 1.0), text_signature = "(n: int, d: float = 1.0)")]
pub fn fftfreq(py: Python, n: usize, d: f64) -> Py<PyArray<f64, numpy::Ix1>> {
    numpy::PyArray1::from_vec(py, rust_fft2d::fftfreq(n, d)).unbind()
}
/// Python entry point for `crate::fft2d::rfftfreq`.
///
/// `n` and `d` are forwarded unchanged; `d` defaults to `1.0` when omitted
/// from the Python call. The values produced by the Rust implementation are
/// returned as a new 1-D `float64` array.
#[pyfunction]
#[inline]
#[pyo3(signature = (n, d = 1.0), text_signature = "(n: int, d: float = 1.0)")]
pub fn rfftfreq(py: Python, n: usize, d: f64) -> Py<numpy::PyArray<f64, numpy::Ix1>> {
    numpy::PyArray1::from_vec(py, rust_fft2d::rfftfreq(n, d)).unbind()
}
/// Python entry point for `crate::image_ops::gaussian_kernel_2d`.
///
/// `size` is converted to a `NonZeroUsize` before being forwarded together
/// with `sigma`. The kernel is built inside `py.detach` and returned as a new
/// 2-D `float64` array.
///
/// # Errors
///
/// * `ValueError` when `size` is zero. Only the zero check happens here; the
///   message also mentions oddness, but no oddness check is made in this
///   wrapper.
/// * Any error reported by the Rust implementation (propagated through `?`).
#[pyfunction]
#[inline]
#[pyo3(signature = (size: "int", sigma: "float"), text_signature = "(size: int, sigma: float)")]
pub fn gaussian_kernel_2d(py: Python, size: usize, sigma: f64) -> PyResult<Py<PyArray2<f64>>> {
    use pyo3::exceptions::PyValueError;

    let Some(size) = std::num::NonZeroUsize::new(size) else {
        return Err(PyValueError::new_err(
            "size must be a non-zero odd integer",
        ));
    };

    let kernel = py.detach(|| image_ops::gaussian_kernel_2d(size, sigma))?;
    Ok(kernel.to_pyarray(py).unbind())
}
/// Python entry point for `crate::image_ops::convolve_fft`.
///
/// Both `image` and `kernel` are converted with `extract_array` (image first),
/// so each must be a 2-D `float64` array or an object whose `__array__()`
/// yields one. The convolution runs inside `py.detach` and the result is
/// returned as a new 2-D `float64` array.
///
/// # Errors
///
/// Fails when either argument cannot be extracted, or when the Rust
/// implementation reports an error (propagated through `?`).
#[pyfunction]
#[inline]
#[pyo3(signature = (image: "numpy.typing.NDArray[numpy.float64]", kernel: "numpy.typing.NDArray[numpy.float64]"), text_signature = "(image: numpy.typing.NDArray[numpy.float64], kernel: numpy.typing.NDArray[numpy.float64])")]
pub fn convolve_fft(
    py: Python,
    image: &Bound<'_, PyAny>,
    kernel: &Bound<'_, PyAny>,
) -> PyResult<Py<PyArray2<f64>>> {
    let image_input = extract_array(py, image)?;
    let kernel_input = extract_array(py, kernel)?;

    let image_pixels = image_input.as_array();
    let kernel_weights = kernel_input.as_array();

    let convolved = py.detach(|| image_ops::convolve_fft(&image_pixels, &kernel_weights))?;
    Ok(convolved.to_pyarray(py).unbind())
}
/// Python entry point for `crate::image_ops::lowpass_filter`.
///
/// `image` is converted with `extract_array`; `cutoff_fraction` is forwarded
/// unchanged. The filter runs inside `py.detach` and the result is returned
/// as a new 2-D `float64` array.
///
/// # Errors
///
/// Fails when `image` cannot be extracted, or when the Rust implementation
/// reports an error (propagated through `?`).
#[pyfunction]
#[inline]
#[pyo3(signature = (image: "numpy.typing.NDArray[numpy.float64]", cutoff_fraction: "float"), text_signature = "(image: numpy.typing.NDArray[numpy.float64], cutoff_fraction: float)")]
pub fn lowpass_filter(
    py: Python,
    image: &Bound<'_, PyAny>,
    cutoff_fraction: f64,
) -> PyResult<Py<PyArray2<f64>>> {
    let input = extract_array(py, image)?;
    let pixels = input.as_array();

    let filtered = py.detach(|| image_ops::lowpass_filter(&pixels, cutoff_fraction))?;
    Ok(filtered.to_pyarray(py).unbind())
}
/// Python entry point for `crate::image_ops::highpass_filter`.
///
/// `image` is converted with `extract_array`; `cutoff_fraction` is forwarded
/// unchanged. The filter runs inside `py.detach` and the result is returned
/// as a new 2-D `float64` array.
///
/// # Errors
///
/// Fails when `image` cannot be extracted, or when the Rust implementation
/// reports an error (propagated through `?`).
#[pyfunction]
#[inline]
#[pyo3(signature = (image: "numpy.typing.NDArray[numpy.float64]", cutoff_fraction: "float"), text_signature = "(image: numpy.typing.NDArray[numpy.float64], cutoff_fraction: float)")]
pub fn highpass_filter(
    py: Python,
    image: &Bound<'_, PyAny>,
    cutoff_fraction: f64,
) -> PyResult<Py<PyArray2<f64>>> {
    let input = extract_array(py, image)?;
    let pixels = input.as_array();

    let filtered = py.detach(|| image_ops::highpass_filter(&pixels, cutoff_fraction))?;
    Ok(filtered.to_pyarray(py).unbind())
}
/// Python entry point for `crate::image_ops::bandpass_filter`.
///
/// `image` is converted with `extract_array`; `low_cutoff` and `high_cutoff`
/// are forwarded unchanged, in that order. The filter runs inside `py.detach`
/// and the result is returned as a new 2-D `float64` array.
///
/// # Errors
///
/// Fails when `image` cannot be extracted, or when the Rust implementation
/// reports an error (propagated through `?`).
#[pyfunction]
#[inline]
#[pyo3(signature = (image: "numpy.typing.NDArray[numpy.float64]", low_cutoff: "float", high_cutoff: "float"), text_signature = "(image: numpy.typing.NDArray[numpy.float64], low_cutoff: float, high_cutoff: float)")]
pub fn bandpass_filter(
    py: Python,
    image: &Bound<'_, PyAny>,
    low_cutoff: f64,
    high_cutoff: f64,
) -> PyResult<Py<PyArray2<f64>>> {
    let input = extract_array(py, image)?;
    let pixels = input.as_array();

    let filtered = py.detach(|| image_ops::bandpass_filter(&pixels, low_cutoff, high_cutoff))?;
    Ok(filtered.to_pyarray(py).unbind())
}
/// Python entry point for `crate::image_ops::detect_edges_fft`.
///
/// `image` is converted with `extract_array`, so it must be a 2-D `float64`
/// array or an object whose `__array__()` yields one. The computation runs
/// inside `py.detach` and the result is returned as a new 2-D `float64` array.
///
/// # Errors
///
/// Fails when `image` cannot be extracted, or when the Rust implementation
/// reports an error (propagated through `?`).
#[pyfunction]
#[inline]
#[pyo3(signature = (image: "numpy.typing.NDArray[numpy.float64]"), text_signature = "(image: numpy.typing.NDArray[numpy.float64])")]
pub fn detect_edges_fft(py: Python, image: &Bound<'_, PyAny>) -> PyResult<Py<PyArray2<f64>>> {
    let input = extract_array(py, image)?;
    let pixels = input.as_array();

    let edges = py.detach(|| image_ops::detect_edges_fft(&pixels))?;
    Ok(edges.to_pyarray(py).unbind())
}
/// Python entry point for `crate::image_ops::sharpen_fft`.
///
/// `image` is converted with `extract_array`; `amount` is forwarded
/// unchanged. The computation runs inside `py.detach` and the result is
/// returned as a new 2-D `float64` array.
///
/// # Errors
///
/// Fails when `image` cannot be extracted, or when the Rust implementation
/// reports an error (propagated through `?`).
#[pyfunction]
#[inline]
#[pyo3(signature = (image: "numpy.typing.NDArray[numpy.float64]", amount: "float"), text_signature = "(image: numpy.typing.NDArray[numpy.float64], amount: float)")]
pub fn sharpen_fft(
    py: Python,
    image: &Bound<'_, PyAny>,
    amount: f64,
) -> PyResult<Py<PyArray2<f64>>> {
    let input = extract_array(py, image)?;
    let pixels = input.as_array();

    let sharpened = py.detach(|| image_ops::sharpen_fft(&pixels, amount))?;
    Ok(sharpened.to_pyarray(py).unbind())
}
/// Python class `Fft2dPlanner`, a thin wrapper around the Rust
/// `crate::fft2d::Fft2dPlanner`.
///
/// Every method forwards to the method of the same name on the wrapped
/// planner, which is borrowed mutably for the call.
#[pyclass(name = "Fft2dPlanner", skip_from_py_object)]
pub struct PyFft2dPlanner {
    /// The wrapped Rust planner that performs the actual work.
    inner: RustFft2dPlanner,
}

#[pymethods]
impl PyFft2dPlanner {
    /// Creates a wrapper around a planner built with `Fft2dPlanner::new()`.
    #[new]
    fn new() -> Self {
        Self {
            inner: RustFft2dPlanner::new(),
        }
    }

    /// Forwards a 2-D `float64` array to the planner's `fft2d`.
    ///
    /// The transform runs inside `py.detach`; the result is returned as a new
    /// 2-D array of `Complex64` values. Errors from the planner are propagated.
    #[pyo3(signature = (data: "numpy.typing.NDArray[numpy.float64]"), text_signature = "(data: numpy.typing.NDArray[numpy.float64])")]
    fn fft2d(
        &mut self,
        py: Python,
        data: PyReadonlyArray2<f64>,
    ) -> PyResult<Py<PyArray2<Complex64>>> {
        let view = data.as_array();
        let spectrum = py.detach(|| self.inner.fft2d(&view))?;

        // Rebuild every element as a `numpy::Complex64` before handing it to NumPy.
        let spectrum = spectrum.mapv(|value| Complex64::new(value.re, value.im));
        Ok(spectrum.to_pyarray(py).unbind())
    }

    /// Forwards a 2-D complex spectrum to the planner's `ifft2d`.
    ///
    /// `spectrum` holds `numpy::Complex64` values. In rust-numpy that alias is
    /// `Complex<f64>`, i.e. NumPy's `complex128` dtype, so the advertised
    /// type hint names `numpy.complex128`. `output_ncols` is forwarded
    /// unchanged to the planner. The transform runs inside `py.detach`; the
    /// result is returned as a new 2-D `float64` array. Errors from the
    /// planner are propagated.
    #[pyo3(signature = (spectrum: "numpy.typing.NDArray[numpy.complex128]", output_ncols: "int"), text_signature = "(spectrum: numpy.typing.NDArray[numpy.complex128], output_ncols: int)")]
    fn ifft2d(
        &mut self,
        py: Python,
        spectrum: PyReadonlyArray2<Complex64>,
        output_ncols: usize,
    ) -> PyResult<Py<PyArray2<f64>>> {
        // Copy into an owned array of `num_complex::Complex` values. The
        // components are already `f64`, so no numeric cast is needed.
        let owned_spectrum = spectrum
            .as_array()
            .mapv(|value| num_complex::Complex::new(value.re, value.im));

        let image = py.detach(|| self.inner.ifft2d(&owned_spectrum.view(), output_ncols))?;
        Ok(image.to_pyarray(py).unbind())
    }

    /// Forwards a 2-D `float64` array to the planner's `power_spectrum_2d`.
    ///
    /// The computation runs inside `py.detach`; the result is returned as a
    /// new 2-D `float64` array. Errors from the planner are propagated.
    #[pyo3(signature = (data: "numpy.typing.NDArray[numpy.float64]"), text_signature = "(data: numpy.typing.NDArray[numpy.float64])")]
    fn power_spectrum_2d(
        &mut self,
        py: Python,
        data: PyReadonlyArray2<f64>,
    ) -> PyResult<Py<PyArray2<f64>>> {
        let view = data.as_array();
        let power = py.detach(|| self.inner.power_spectrum_2d(&view))?;
        Ok(power.to_pyarray(py).unbind())
    }

    /// Forwards a 2-D `float64` array to the planner's `magnitude_spectrum_2d`.
    ///
    /// The computation runs inside `py.detach`; the result is returned as a
    /// new 2-D `float64` array. Errors from the planner are propagated.
    #[pyo3(signature = (data: "numpy.typing.NDArray[numpy.float64]"), text_signature = "(data: numpy.typing.NDArray[numpy.float64])")]
    fn magnitude_spectrum_2d(
        &mut self,
        py: Python,
        data: PyReadonlyArray2<f64>,
    ) -> PyResult<Py<PyArray2<f64>>> {
        let view = data.as_array();
        let magnitude = py.detach(|| self.inner.magnitude_spectrum_2d(&view.view()))?;
        Ok(magnitude.to_pyarray(py).unbind())
    }
}
/// Adds every function and class defined in this file to the Python module `m`.
///
/// Items are registered in a fixed order: the free functions first, then the
/// `Fft2dPlanner` class. `_py` is accepted but not used.
///
/// # Errors
///
/// Returns the first error produced while wrapping or adding an item; items
/// added before the failure stay on the module.
pub fn register(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Transforms and spectra.
    m.add_function(wrap_pyfunction!(fft2d, m)?)?;
    m.add_function(wrap_pyfunction!(ifft2d, m)?)?;
    m.add_function(wrap_pyfunction!(power_spectrum_2d, m)?)?;
    m.add_function(wrap_pyfunction!(magnitude_spectrum_2d, m)?)?;

    // Shift helpers for 2-D and 1-D arrays.
    m.add_function(wrap_pyfunction!(fftshift, m)?)?;
    m.add_function(wrap_pyfunction!(ifftshift, m)?)?;
    m.add_function(wrap_pyfunction!(fftshift_1d, m)?)?;
    m.add_function(wrap_pyfunction!(ifftshift_1d, m)?)?;

    // Frequency-bin helpers.
    m.add_function(wrap_pyfunction!(fftfreq, m)?)?;
    m.add_function(wrap_pyfunction!(rfftfreq, m)?)?;

    // Image operations backed by `crate::image_ops`.
    m.add_function(wrap_pyfunction!(gaussian_kernel_2d, m)?)?;
    m.add_function(wrap_pyfunction!(convolve_fft, m)?)?;
    m.add_function(wrap_pyfunction!(lowpass_filter, m)?)?;
    m.add_function(wrap_pyfunction!(highpass_filter, m)?)?;
    m.add_function(wrap_pyfunction!(bandpass_filter, m)?)?;
    m.add_function(wrap_pyfunction!(detect_edges_fft, m)?)?;
    m.add_function(wrap_pyfunction!(sharpen_fft, m)?)?;

    // The planner class goes last; its result is the function's result.
    m.add_class::<PyFft2dPlanner>()
}