use crate::get_hb_filter;
use crate::naive::triband::triband_cascade_noalloc;
use ndarray::{Array2};
use num_complex::Complex;
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2, PyUntypedArrayMethods};
use pyo3::prelude::*;
use pyo3::types::PyAny;
/// Run the triband cascade on a 2-D complex NumPy array.
///
/// Only inputs whose `dtype.name` is `"complex64"` are currently processed.
/// The input is passed, together with the coefficients returned by
/// `get_hb_filter::<f32>(31)`, to `triband_cascade_noalloc`. The result is
/// written into a freshly allocated `complex64` array of shape
/// `(rows - filter_len + 1, 2 * cols + 1)`, which is returned.
///
/// # Errors
///
/// * Propagates the Python error (e.g. `AttributeError`) if the object has no
///   `dtype` or `dtype.name` attribute.
/// * `TypeError` if the dtype is `complex128` (not yet supported) or any other
///   dtype, or if the object cannot be extracted as a read-only 2-D
///   `complex64` array.
/// * `ValueError` if the input has fewer rows than the filter has
///   coefficients, because the output would have no valid rows.
#[pyfunction]
fn triband_cascade<'py>(
    py: Python<'py>,
    py_inp: Bound<'py, PyAny>,
) -> PyResult<Py<PyArray2<Complex<f32>>>> {
    let dtype_obj = py_inp.getattr("dtype")?;
    let name_obj = dtype_obj.getattr("name")?;
    let dtype_str: &str = name_obj.extract()?;
    match dtype_str {
        "complex64" => {
            // NOTE(review): 31 is assumed to be the half-band filter length; confirm
            // against `get_hb_filter`.
            let flt = get_hb_filter::<f32>(31);
            let inp: PyReadonlyArray2<Complex<f32>> = py_inp.extract()?;
            let rows = inp.shape()[0];
            let cols = inp.shape()[1];
            // `rows - flt.len() + 1` in plain usize arithmetic would panic in debug
            // builds (or wrap to a huge allocation in release) when the input is
            // shorter than the filter, so reject that case with a Python error.
            let out_rows = rows
                .checked_sub(flt.len())
                .map(|extra| extra + 1)
                .ok_or_else(|| {
                    pyo3::exceptions::PyValueError::new_err(format!(
                        "Input has {} rows but the filter has {} coefficients; \
                         at least {} rows are required.",
                        rows,
                        flt.len(),
                        flt.len()
                    ))
                })?;
            let mut out = Array2::<Complex<f32>>::zeros((out_rows, cols * 2 + 1));
            // NOTE(review): meaning of the trailing `0` argument is defined by
            // `triband_cascade_noalloc`; presumably a starting offset — confirm.
            triband_cascade_noalloc(inp.as_array(), out.view_mut(), &flt, 0);
            Ok(out.into_pyarray(py).into())
        }
        "complex128" => Err(pyo3::exceptions::PyTypeError::new_err(
            "Input must be a 2D NumPy array of complex64; complex128 is on the todo list.",
        )),
        _ => Err(pyo3::exceptions::PyTypeError::new_err(
            "Input must be a 2D NumPy array of complex64 or complex128.",
        )),
    }
}
/// Python extension module `mrrs`; registers the `triband_cascade` function.
#[pymodule]
fn mrrs(module: &Bound<'_, PyModule>) -> PyResult<()> {
    let cascade_fn = wrap_pyfunction!(triband_cascade, module)?;
    module.add_function(cascade_fn)
}