mrrs 0.1.1

Multi-rate filtering tools
Documentation
use crate::get_hb_filter;
use crate::naive::triband::triband_cascade_noalloc;
use ndarray::{Array2};
use num_complex::Complex;
use numpy::{IntoPyArray, PyArray2, PyReadonlyArray2, PyUntypedArrayMethods};
use pyo3::prelude::*;
use pyo3::types::PyAny;
//use num_traits::{Float, NumAssignOps};
//use pyo3::PyAnyMethods;

/// Run the triband cascade over a 2-D complex NumPy array.
///
/// Dispatches on the NumPy `dtype.name` of `py_inp`. Only `complex64` is
/// currently implemented. The input is processed by `triband_cascade_noalloc`
/// using the coefficients returned by `get_hb_filter::<f32>(31)`. The result is
/// written into a freshly allocated `complex64` array of shape
/// `(rows - flt.len() + 1, cols * 2 + 1)`, which is returned to Python.
///
/// NOTE(review): `get_hb_filter` presumably designs a half-band filter and `31`
/// is presumably its tap count. Axis 0 looks like the filtered axis, since the
/// output loses `flt.len() - 1` rows there. Confirm both against
/// `triband_cascade_noalloc`.
///
/// # Errors
///
/// * Propagates the error from `getattr` if `py_inp` has no `dtype` or
///   `dtype.name` attribute.
/// * Propagates the extraction error if the `complex64` input cannot be
///   extracted as a 2-D array.
/// * `TypeError` if the dtype is anything other than `complex64`.
/// * `ValueError` if the input has fewer rows than `flt.len()`. In that case no
///   output row can be computed.
#[pyfunction]
fn triband_cascade<'py>(
    py: Python<'py>,
    py_inp: Bound<'py, PyAny>,
) -> PyResult<Py<PyArray2<Complex<f32>>>> {
    let dtype_obj = py_inp.getattr("dtype")?;
    let name_obj = dtype_obj.getattr("name")?;
    let dtype_str: &str = name_obj.extract()?;
    match dtype_str {
        "complex64" => {
            let flt = get_hb_filter::<f32>(31);
            let inp: PyReadonlyArray2<Complex<f32>> = py_inp.extract()?;
            let (in_rows, in_cols) = (inp.shape()[0], inp.shape()[1]);
            // `in_rows - flt.len() + 1` would underflow when the input is shorter
            // than the filter: it panics in debug and yields an absurd shape in
            // release. Reject that case with a regular Python exception instead.
            let out_rows = in_rows.checked_sub(flt.len()).map(|d| d + 1).ok_or_else(|| {
                pyo3::exceptions::PyValueError::new_err(format!(
                    "Input must have at least {} rows (the filter length), got {}.",
                    flt.len(),
                    in_rows
                ))
            })?;
            let mut out = Array2::<Complex<f32>>::zeros((out_rows, in_cols * 2 + 1));
            triband_cascade_noalloc(inp.as_array(), out.view_mut(), &flt, 0);
            Ok(out.into_pyarray(py).into())
        }
        "complex128" => Err(pyo3::exceptions::PyTypeError::new_err(
            "Input must be a 2D NumPy array of complex64; complex128 is on the todo list.",
        )),
        _ => Err(pyo3::exceptions::PyTypeError::new_err(
            "Input must be a 2D NumPy array of complex64 or complex128.",
        )),
    }
}

/// Python extension module `mrrs`, implemented in Rust.
///
/// Registers [`triband_cascade`] as a module-level function.
#[pymodule]
fn mrrs(module: &Bound<'_, PyModule>) -> PyResult<()> {
    let cascade = wrap_pyfunction!(triband_cascade, module)?;
    module.add_function(cascade)
}