use std::num::NonZeroUsize;
use ndarray::Array2;
use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray2};
use pyo3::prelude::*;
use pyo3::types::PyType;
use super::params::PyWindowType;
use non_empty_slice::NonEmptySlice;
use crate::mdct::{MdctParams, imdct, mdct};
/// Python-visible MDCT configuration, exposed to Python under the name `MdctParams`.
///
/// Thin wrapper around the Rust `MdctParams` value. The constructors reject zero sizes
/// and surface any `MdctParams` construction error as a Python `ValueError`.
#[pyclass(name = "MdctParams", from_py_object)]
#[derive(Debug, Clone)]
pub struct PyMdctParams {
    /// The wrapped Rust parameters handed to the transform functions.
    pub(crate) inner: MdctParams,
}

// Python-facing constructors, read-only properties and `__repr__` for `PyMdctParams`.
#[pymethods]
impl PyMdctParams {
    /// Create MDCT parameters from an explicit window size, hop size and window type.
    ///
    /// Raises `ValueError` if `window_size` or `hop_size` is zero, or if the underlying
    /// `MdctParams::new` rejects the combination.
    #[new]
    #[pyo3(signature = (window_size: "int", hop_size: "int", window: "WindowType"), text_signature = "(window_size: int, hop_size: int, window: WindowType) -> MdctParams")]
    fn new(window_size: usize, hop_size: usize, window: PyWindowType) -> PyResult<Self> {
        let Some(window_nz) = NonZeroUsize::new(window_size) else {
            return Err(pyo3::exceptions::PyValueError::new_err("window_size must be > 0"));
        };
        let Some(hop_nz) = NonZeroUsize::new(hop_size) else {
            return Err(pyo3::exceptions::PyValueError::new_err("hop_size must be > 0"));
        };
        match MdctParams::new(window_nz, hop_nz, window.into_inner()) {
            Ok(inner) => Ok(Self { inner }),
            Err(err) => Err(pyo3::exceptions::PyValueError::new_err(err.to_string())),
        }
    }

    /// Create MDCT parameters via `MdctParams::sine_window` for the given window size.
    ///
    /// Raises `ValueError` if `window_size` is zero or if `MdctParams::sine_window` fails.
    #[classmethod]
    #[pyo3(signature = (window_size: "int"), text_signature = "(window_size: int) -> MdctParams")]
    fn sine_window(_cls: &Bound<'_, PyType>, window_size: usize) -> PyResult<Self> {
        let Some(window_nz) = NonZeroUsize::new(window_size) else {
            return Err(pyo3::exceptions::PyValueError::new_err("window_size must be > 0"));
        };
        MdctParams::sine_window(window_nz)
            .map(|inner| Self { inner })
            .map_err(|err| pyo3::exceptions::PyValueError::new_err(err.to_string()))
    }

    /// The window size stored in the wrapped parameters.
    #[getter]
    const fn window_size(&self) -> usize {
        self.inner.window_size.get()
    }

    /// The hop size stored in the wrapped parameters.
    #[getter]
    const fn hop_size(&self) -> usize {
        self.inner.hop_size.get()
    }

    /// The value of `MdctParams::n_coefficients()` for the wrapped parameters.
    #[getter]
    const fn n_coefficients(&self) -> usize {
        self.inner.n_coefficients()
    }

    /// Debug-style representation listing window size, hop size and coefficient count.
    fn __repr__(&self) -> String {
        let (window_size, hop_size, n_coefficients) =
            (self.window_size(), self.hop_size(), self.n_coefficients());
        format!(
            "MdctParams(window_size={window_size}, hop_size={hop_size}, n_coefficients={n_coefficients})"
        )
    }
}
#[pyfunction(name = "mdct")]
#[pyo3(signature = (samples, params), text_signature = "(samples: numpy.typing.NDArray[numpy.float64], params: MdctParams) -> numpy.typing.NDArray[numpy.float64]")]
pub fn py_compute_mdct<'py>(
py: Python<'py>,
samples: &Bound<'py, PyAny>,
params: &PyMdctParams,
) -> PyResult<Bound<'py, PyArray2<f64>>> {
let np = py.import("numpy")?;
let arr = np.call_method1("ascontiguousarray", (samples, "float64"))?;
let arr = arr.cast::<numpy::PyArray1<f64>>()?;
let ro = arr.try_readonly()?;
let slice = ro.as_slice()?;
let samples_ne = NonEmptySlice::new(slice)
.ok_or_else(|| pyo3::exceptions::PyValueError::new_err("samples must not be empty"))?;
let result = py
.detach(|| mdct(samples_ne, ¶ms.inner))
.map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
Ok(PyArray2::from_owned_array(py, result))
}
/// Compute the inverse MDCT of a two-dimensional float64 coefficient array.
///
/// `original_length` is forwarded unchanged to the Rust `imdct` (presumably used to trim the
/// reconstructed signal to that length; confirm against `imdct`). The GIL is released while
/// the transform executes, and the result is returned as a one-dimensional float64 array.
///
/// Raises `ValueError` if the transform reports an error.
#[pyfunction(name = "imdct")]
#[pyo3(signature = (coefficients, params, original_length=None), text_signature = "(coefficients: numpy.typing.NDArray[numpy.float64], params: MdctParams, original_length: int | None = None) -> numpy.typing.NDArray[numpy.float64]")]
pub fn py_compute_imdct<'py>(
    py: Python<'py>,
    coefficients: PyReadonlyArray2<f64>,
    params: &PyMdctParams,
    original_length: Option<usize>,
) -> PyResult<Bound<'py, PyArray1<f64>>> {
    // Copy into a Rust-owned array so the transform works on data it owns while the GIL is
    // released, rather than on a view into the Python buffer.
    let coeffs_arr: Array2<f64> = coefficients.as_array().to_owned();
    let result = py
        .detach(|| imdct(&coeffs_arr, &params.inner, original_length))
        .map_err(|e| pyo3::exceptions::PyValueError::new_err(format!("{e}")))?;
    Ok(PyArray1::from_vec(py, result))
}
/// Register the MDCT class and transform functions on the given Python module.
///
/// Adds the `MdctParams` class, then the `mdct` and `imdct` functions, in that order.
pub fn register(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyMdctParams>()?;
    let forward = wrap_pyfunction!(py_compute_mdct, m)?;
    m.add_function(forward)?;
    let inverse = wrap_pyfunction!(py_compute_imdct, m)?;
    m.add_function(inverse)
}