use std::num::NonZeroUsize;
use numpy::PyArray2;
use pyo3::prelude::*;
mod binaural;
mod dlpack;
mod error;
mod fft2d;
mod functions;
mod params;
mod planner;
mod spectrogram;
pub use error::*;
pub use params::*;
use crate::Chromagram;
/// Python-facing wrapper around a Rust [`Chromagram`], exposed to Python under the
/// class name `Chromagram`.
///
/// Conversions to and from the wrapped value are provided through `From` impls in
/// both directions.
#[pyclass(name = "Chromagram", skip_from_py_object)]
#[derive(Debug)]
pub struct PyChromagram {
// The wrapped Rust chromagram; crate-visible so other code in this crate can read it
// directly (e.g. `__array__`/`__dlpack__` read `inner.data`).
pub(crate) inner: Chromagram,
}
impl From<Chromagram> for PyChromagram {
#[inline]
fn from(inner: Chromagram) -> Self {
Self { inner }
}
}
impl From<PyChromagram> for Chromagram {
#[inline]
fn from(val: PyChromagram) -> Self {
val.inner
}
}
#[pymethods]
impl PyChromagram {
    /// Number of frames, as reported by the wrapped chromagram.
    #[getter]
    fn n_frames(&self) -> NonZeroUsize {
        self.inner.n_frames()
    }

    /// Number of bins per frame, as reported by the wrapped chromagram.
    #[getter]
    fn n_bins(&self) -> NonZeroUsize {
        self.inner.n_bins()
    }

    /// A copy of the wrapped chromagram's parameters, converted to `PyChromaParams`.
    #[getter]
    fn params(&self) -> PyChromaParams {
        PyChromaParams::from(*self.inner.params())
    }

    /// The 12 labels returned by `Chromagram::labels()`, exposed as a class attribute.
    #[classattr]
    const fn labels() -> [&'static str; 12] {
        Chromagram::labels()
    }

    /// NumPy array protocol (`np.asarray(chroma)`): returns the data as a new 2-D ndarray.
    ///
    /// The data is always copied out of the Rust buffer. A request for `copy=False`
    /// therefore cannot be honoured and raises `ValueError`, as the NumPy 2 protocol
    /// requires. `copy=None`/`copy=True` are both satisfied by the fresh array.
    ///
    /// When `dtype` is given, the fresh array is cast with `astype(dtype, copy=False)`, so
    /// a matching dtype does not trigger a second, redundant copy.
    ///
    /// # Errors
    /// `ValueError` for `copy=False`; any error raised by `astype` for an invalid `dtype`.
    #[pyo3(signature = (dtype=None, copy=None))]
    fn __array__<'py>(
        &self,
        py: Python<'py>,
        dtype: Option<&Bound<'py, PyAny>>,
        copy: Option<bool>,
    ) -> PyResult<Py<PyAny>> {
        if copy == Some(false) {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "Chromagram data cannot be exposed to NumPy without a copy",
            ));
        }
        let arr = PyArray2::from_array(py, &self.inner.data);
        match dtype {
            Some(dtype) => {
                // `arr` is already a private copy; let `astype` reuse it when possible.
                let kwargs = pyo3::types::PyDict::new(py);
                kwargs.set_item("copy", false)?;
                let casted: Bound<'py, PyAny> =
                    arr.call_method("astype", (dtype,), Some(&kwargs))?;
                Ok(casted.unbind())
            }
            None => Ok(arr.into_any().unbind()),
        }
    }

    /// DLPack device tuple `(device_type, device_id)`; always `(1, 0)`, i.e. CPU device 0.
    #[staticmethod]
    const fn __dlpack_device__() -> (i32, i32) {
        (1, 0)
    }

    /// DLPack export: returns a capsule wrapping a fresh NumPy copy of the data.
    ///
    /// Because the exported buffer is always a producer-made copy, the capsule is always
    /// flagged `IS_COPIED`, and `copy=False` is rejected with `BufferError`, as the Array
    /// API `__dlpack__` specification requires.
    ///
    /// # Errors
    /// `BufferError` when `stream` is not `None`, `max_version` has major version < 1,
    /// `dl_device` is not the CPU device `(1, 0)`, or `copy=False`. Errors from
    /// `create_dlpack_capsule` are propagated.
    #[pyo3(signature = (*, stream=None, max_version=None, dl_device=None, copy=None))]
    fn __dlpack__<'py>(
        &self,
        py: Python<'py>,
        stream: Option<&Bound<'py, PyAny>>,
        max_version: Option<(u32, u32)>,
        dl_device: Option<(i32, i32)>,
        copy: Option<bool>,
    ) -> PyResult<Bound<'py, pyo3::types::PyCapsule>> {
        use crate::python::dlpack::{DLPACK_FLAG_BITMASK_IS_COPIED, create_dlpack_capsule};
        use pyo3::exceptions::PyBufferError;

        if stream.is_some() {
            return Err(PyBufferError::new_err("stream must be None for CPU tensors"));
        }
        // NOTE(review): the DLPack spec suggests falling back to a legacy (unversioned)
        // capsule for consumers that cannot handle v1+; that lives in `dlpack`, so the
        // existing rejection is kept here.
        if let Some((major, minor)) = max_version.filter(|&(major, _)| major < 1) {
            return Err(PyBufferError::new_err(format!(
                "Unsupported DLPack version: {major}.{minor}"
            )));
        }
        if dl_device.is_some_and(|device| device != Self::__dlpack_device__()) {
            return Err(PyBufferError::new_err(
                "Only CPU device (1, 0) is supported",
            ));
        }
        // `PyArray2::from_array` below always allocates a fresh copy, so a zero-copy
        // export is impossible.
        if copy == Some(false) {
            return Err(PyBufferError::new_err(
                "Chromagram data cannot be exported via DLPack without a copy",
            ));
        }
        let arr = PyArray2::from_array(py, &self.inner.data);
        // The consumer receives sole ownership of a copy, whatever `copy` was.
        create_dlpack_capsule(py, &arr, DLPACK_FLAG_BITMASK_IS_COPIED)
    }
}
/// Populates the Python extension module `m` with this crate's public API.
///
/// The exception classes are added first. Each submodule's `register` function is then
/// called in a fixed order. When the `realfft` feature is enabled, the FFT plan-cache
/// helpers `clear_fft_plan_cache` and `fft_plan_cache_info` are also exposed.
///
/// # Errors
/// Propagates the first `PyErr` raised while adding an attribute or while a submodule
/// registers its items.
pub fn register_module(py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Python-visible name -> exception type object, in module insertion order.
    let exception_types = [
        ("SpectrogramError", py.get_type::<PySpectrogramError>()),
        ("InvalidInputError", py.get_type::<PyInvalidInputError>()),
        ("DimensionMismatchError", py.get_type::<PyDimensionMismatchError>()),
        ("FFTBackendError", py.get_type::<PyFFTBackendError>()),
        ("InternalError", py.get_type::<PyInternalError>()),
    ];
    for (name, exception_type) in exception_types {
        m.add(name, exception_type)?;
    }

    // Each submodule contributes its own classes and functions.
    params::register(py, m)?;
    spectrogram::register(py, m)?;
    planner::register(py, m)?;
    functions::register(py, m)?;
    fft2d::register(py, m)?;
    dlpack::register(py, m)?;
    binaural::register(py, m)?;

    #[cfg(feature = "realfft")]
    {
        m.add_function(wrap_pyfunction!(clear_fft_plan_cache, m)?)?;
        m.add_function(wrap_pyfunction!(fft_plan_cache_info, m)?)?;
    }

    Ok(())
}
/// Clears the FFT plan cache by calling `fft_backend::clear_plan_cache`.
///
/// Only compiled when the `realfft` feature is enabled.
#[cfg(feature = "realfft")]
#[pyfunction]
fn clear_fft_plan_cache() {
    use crate::fft_backend::clear_plan_cache;
    clear_plan_cache();
}
/// Returns the FFT plan-cache statistics reported by `fft_backend::cache_stats`.
///
/// Only compiled when the `realfft` feature is enabled.
// NOTE(review): the meaning of the two counts is defined by
// `crate::fft_backend::cache_stats` (not visible here) — confirm before documenting
// them for Python users.
#[cfg(feature = "realfft")]
#[pyfunction]
fn fft_plan_cache_info() -> (usize, usize) {
    use crate::fft_backend::cache_stats;
    cache_stats()
}