use std::num::NonZeroUsize;
use non_empty_slice::NonEmptySlice;
use numpy::{PyArray1, PyArray2, PyArrayMethods};
use pyo3::prelude::*;
use crate::binaural::{
ILDSpectrogramParams, ILRSpectrogramParams, IPDSpectrogramParams,
ITDSpectrogramParams,
compute_ild_spectrogram, compute_ilr_spectrogram, compute_ilr_spectrogram_diff,
compute_ipd_spectrogram, compute_itd_spectrogram, compute_itd_spectrogram_diff,
};
use crate::{StftPlan, python::PySpectrogramParams};
#[pyclass(name = "ITDSpectrogramParams", from_py_object)]
#[derive(Debug, Clone)]
pub struct PyITDSpectrogramParams {
pub(crate) inner: ITDSpectrogramParams,
}
#[pymethods]
impl PyITDSpectrogramParams {
#[new]
#[pyo3(signature = (spectrogram_params: "SpectrogramParams", start_freq: "float" = 50.0, end_freq: "float" = 620.0, magphase_power: "Optional[int]" = 1), text_signature = "(spectrogram_params: SpectrogramParams, start_freq: float = 50.0, end_freq: float = 620.0, magphase_power: Optional[int] = 1) -> ITDSpectrogramParams")]
fn new(
spectrogram_params: PySpectrogramParams,
start_freq: Option<f64>,
end_freq: Option<f64>,
magphase_power: Option<usize>,
) -> Self {
let inner = ITDSpectrogramParams {
spectrogram_params: spectrogram_params.into(),
start_freq: start_freq.unwrap_or(50.0),
end_freq: end_freq.unwrap_or(620.0),
magphase_power: magphase_power
.and_then(|p| NonZeroUsize::new(p))
.unwrap_or_else(|| crate::nzu!(1)),
};
Self { inner }
}
#[getter]
fn spectrogram_params(&self) -> PySpectrogramParams {
PySpectrogramParams::from(self.inner.spectrogram_params.clone())
}
#[getter]
fn start_freq(&self) -> f64 {
self.inner.start_freq
}
#[getter]
fn end_freq(&self) -> f64 {
self.inner.end_freq
}
#[getter]
fn magphase_power(&self) -> NonZeroUsize {
self.inner.magphase_power
}
}
impl From<ITDSpectrogramParams> for PyITDSpectrogramParams {
#[inline]
fn from(inner: ITDSpectrogramParams) -> Self {
Self { inner }
}
}
impl From<PyITDSpectrogramParams> for ITDSpectrogramParams {
#[inline]
fn from(val: PyITDSpectrogramParams) -> Self {
val.inner
}
}
/// Computes the interaural time difference (ITD) spectrogram of a stereo signal.
///
/// `audio` holds two 1-D `float64` arrays (left, right). Each must be contiguous and
/// non-empty; otherwise `ValueError` is raised. Failures while building the STFT plan
/// or computing the spectrogram raise `RuntimeError`. Returns a 2-D `float64` array.
#[pyfunction(name = "compute_itd_spectrogram")]
#[pyo3(signature = (audio: "list[numpy.typing.NDArray[numpy.float64]]", params: "ITDSpectrogramParams"), text_signature = "(audio: list[numpy.typing.NDArray[numpy.float64]], params: ITDSpectrogramParams) -> numpy.typing.NDArray[numpy.float64]")]
fn py_compute_itd_spectrogram<'py>(
    py: Python<'py>,
    audio: [Bound<'py, PyArray1<f64>>; 2],
    params: &'py PyITDSpectrogramParams,
) -> PyResult<Bound<'py, PyArray2<f64>>> {
    /// Borrows one channel as a contiguous, non-empty slice, raising `ValueError` otherwise.
    fn channel<'a>(
        array: &'a Bound<'_, PyArray1<f64>>,
        not_contiguous_msg: &'static str,
        empty_msg: &'static str,
    ) -> PyResult<&'a [f64]> {
        // SAFETY: the slice is only read while the GIL is held and no Python code runs,
        // so Python cannot mutate the buffer underneath it.
        // NOTE(review): assumes no Rust code holds a mutable borrow of the same array;
        // `try_readonly()` would enforce that at runtime.
        let slice = unsafe { array.as_slice() }
            .map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>(not_contiguous_msg))?;
        // An empty array would violate `NonEmptySlice::new_unchecked`'s contract (UB).
        if slice.is_empty() {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(empty_msg));
        }
        Ok(slice)
    }

    let mut plan: StftPlan = StftPlan::new(&params.inner.spectrogram_params).map_err(|e| {
        PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
            "Failed to create STFT plan: {e}"
        ))
    })?;
    let left = channel(
        &audio[0],
        "Left audio array must be contiguous and of type float64.",
        "Left audio array must not be empty.",
    )?;
    let right = channel(
        &audio[1],
        "Right audio array must be contiguous and of type float64.",
        "Right audio array must not be empty.",
    )?;
    // SAFETY: `channel` rejected empty arrays, so both slices are non-empty.
    let audio_slices = unsafe {
        [
            NonEmptySlice::new_unchecked(left),
            NonEmptySlice::new_unchecked(right),
        ]
    };
    let itd_spectrogram =
        compute_itd_spectrogram(audio_slices, &params.inner, &mut plan).map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Failed to compute ITD spectrogram: {e}"
            ))
        })?;
    Ok(PyArray2::from_owned_array(py, itd_spectrogram.data))
}
/// Python-facing wrapper around the crate's `IPDSpectrogramParams`.
///
/// Exposed to Python as `IPDSpectrogramParams`. Validation of the arguments is
/// delegated to `IPDSpectrogramParams::new`.
#[pyclass(name = "IPDSpectrogramParams", from_py_object)]
#[derive(Debug, Clone)]
pub struct PyIPDSpectrogramParams {
    pub(crate) inner: IPDSpectrogramParams,
}
#[pymethods]
impl PyIPDSpectrogramParams {
    /// Builds IPD spectrogram parameters.
    ///
    /// `None` arguments fall back to `start_freq = 50.0`, `end_freq = 620.0` and
    /// `wrapped = false`. Any error from the core constructor is raised as `ValueError`.
    #[new]
    #[pyo3(signature = (spectrogram_params, start_freq = 50.0, end_freq = 620.0, wrapped = false), text_signature = "(spectrogram_params: SpectrogramParams, start_freq: float = 50.0, end_freq: float = 620.0, wrapped: bool = False) -> IPDSpectrogramParams")]
    fn new(
        spectrogram_params: PySpectrogramParams,
        start_freq: Option<f64>,
        end_freq: Option<f64>,
        wrapped: Option<bool>,
    ) -> PyResult<Self> {
        let low = start_freq.unwrap_or(50.0);
        let high = end_freq.unwrap_or(620.0);
        let wrap = wrapped.unwrap_or(false);
        IPDSpectrogramParams::new(spectrogram_params.into(), low, high, wrap)
            .map(|inner| Self { inner })
            .map_err(|err| pyo3::exceptions::PyValueError::new_err(err.to_string()))
    }
    /// The STFT/spectrogram parameters (returned as a copy).
    #[getter]
    fn spectrogram_params(&self) -> PySpectrogramParams {
        self.inner.spectrogram_params.clone().into()
    }
    /// Lower frequency bound stored in the parameters.
    #[getter]
    fn start_freq(&self) -> f64 {
        self.inner.start_freq
    }
    /// Upper frequency bound stored in the parameters.
    #[getter]
    fn end_freq(&self) -> f64 {
        self.inner.end_freq
    }
    /// Whether the `wrapped` flag is set.
    #[getter]
    fn wrapped(&self) -> bool {
        self.inner.wrapped
    }
}
/// Computes the interaural phase difference (IPD) spectrogram of a stereo signal.
///
/// `audio` holds two 1-D `float64` arrays (left, right). Each must be contiguous and
/// non-empty; otherwise `ValueError` is raised. Failures while building the STFT plan
/// or computing the spectrogram raise `RuntimeError`. Returns a 2-D `float64` array.
#[pyfunction(name = "compute_ipd_spectrogram")]
#[pyo3(signature = (audio: "list[numpy.typing.NDArray[numpy.float64]]", params: "IPDSpectrogramParams"), text_signature = "(audio: list[numpy.typing.NDArray[numpy.float64]], params: IPDSpectrogramParams) -> numpy.typing.NDArray[numpy.float64]")]
fn py_compute_ipd_spectrogram<'py>(
    py: Python<'py>,
    audio: [Bound<'py, PyArray1<f64>>; 2],
    params: &'py PyIPDSpectrogramParams,
) -> PyResult<Bound<'py, PyArray2<f64>>> {
    /// Borrows one channel as a contiguous, non-empty slice, raising `ValueError` otherwise.
    fn channel<'a>(
        array: &'a Bound<'_, PyArray1<f64>>,
        not_contiguous_msg: &'static str,
        empty_msg: &'static str,
    ) -> PyResult<&'a [f64]> {
        // SAFETY: the slice is only read while the GIL is held and no Python code runs,
        // so Python cannot mutate the buffer underneath it.
        // NOTE(review): assumes no Rust code holds a mutable borrow of the same array;
        // `try_readonly()` would enforce that at runtime.
        let slice = unsafe { array.as_slice() }
            .map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>(not_contiguous_msg))?;
        // An empty array would violate `NonEmptySlice::new_unchecked`'s contract (UB).
        if slice.is_empty() {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(empty_msg));
        }
        Ok(slice)
    }

    let mut plan = StftPlan::new(&params.inner.spectrogram_params).map_err(|e| {
        PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
            "Failed to create STFT plan: {e}"
        ))
    })?;
    let left = channel(
        &audio[0],
        "Left audio array must be contiguous and of type float64.",
        "Left audio array must not be empty.",
    )?;
    let right = channel(
        &audio[1],
        "Right audio array must be contiguous and of type float64.",
        "Right audio array must not be empty.",
    )?;
    // SAFETY: `channel` rejected empty arrays, so both slices are non-empty.
    let audio_slices = unsafe {
        [
            NonEmptySlice::new_unchecked(left),
            NonEmptySlice::new_unchecked(right),
        ]
    };
    let ipd_spectrogram = compute_ipd_spectrogram(audio_slices, &params.inner, &mut plan)
        .map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Failed to compute IPD spectrogram: {e}"
            ))
        })?;
    Ok(PyArray2::from_owned_array(py, ipd_spectrogram.data))
}
/// Python-facing wrapper around the crate's `ILDSpectrogramParams`.
///
/// Exposed to Python as `ILDSpectrogramParams`. Validation of the arguments is
/// delegated to `ILDSpectrogramParams::new`.
#[pyclass(name = "ILDSpectrogramParams", from_py_object)]
#[derive(Debug, Clone)]
pub struct PyILDSpectrogramParams {
    pub(crate) inner: ILDSpectrogramParams,
}
#[pymethods]
impl PyILDSpectrogramParams {
    /// Builds ILD spectrogram parameters.
    ///
    /// `None` arguments fall back to `start_freq = 1700.0` and `end_freq = 4600.0`.
    /// Any error from the core constructor is raised as `ValueError`.
    #[new]
    #[pyo3(signature = (spectrogram_params: "SpectrogramParams", start_freq: "float" = 1700.0, end_freq: "float" = 4600.0), text_signature = "(spectrogram_params: SpectrogramParams, start_freq: float = 1700.0, end_freq: float = 4600.0) -> ILDSpectrogramParams")]
    fn new(
        spectrogram_params: PySpectrogramParams,
        start_freq: Option<f64>,
        end_freq: Option<f64>,
    ) -> PyResult<Self> {
        let low = start_freq.unwrap_or(1700.0);
        let high = end_freq.unwrap_or(4600.0);
        ILDSpectrogramParams::new(spectrogram_params.into(), low, high)
            .map(|inner| Self { inner })
            .map_err(|err| pyo3::exceptions::PyValueError::new_err(err.to_string()))
    }
    /// The STFT/spectrogram parameters (returned as a copy).
    #[getter]
    fn spectrogram_params(&self) -> PySpectrogramParams {
        self.inner.spectrogram_params.clone().into()
    }
    /// Lower frequency bound stored in the parameters.
    #[getter]
    fn start_freq(&self) -> f64 {
        self.inner.start_freq
    }
    /// Upper frequency bound stored in the parameters.
    #[getter]
    fn end_freq(&self) -> f64 {
        self.inner.end_freq
    }
}
/// Computes the interaural level difference (ILD) spectrogram of a stereo signal.
///
/// `audio` holds two 1-D `float64` arrays (left, right). Each must be contiguous and
/// non-empty; otherwise `ValueError` is raised. Failures while building the STFT plan
/// or computing the spectrogram raise `RuntimeError`. Returns a 2-D `float64` array.
#[pyfunction(name = "compute_ild_spectrogram")]
#[pyo3(signature = (audio: "list[numpy.typing.NDArray[numpy.float64]]", params: "ILDSpectrogramParams"), text_signature = "(audio: list[numpy.typing.NDArray[numpy.float64]], params: ILDSpectrogramParams) -> numpy.typing.NDArray[numpy.float64]")]
fn py_compute_ild_spectrogram<'py>(
    py: Python<'py>,
    audio: [Bound<'py, PyArray1<f64>>; 2],
    params: &'py PyILDSpectrogramParams,
) -> PyResult<Bound<'py, PyArray2<f64>>> {
    /// Borrows one channel as a contiguous, non-empty slice, raising `ValueError` otherwise.
    fn channel<'a>(
        array: &'a Bound<'_, PyArray1<f64>>,
        not_contiguous_msg: &'static str,
        empty_msg: &'static str,
    ) -> PyResult<&'a [f64]> {
        // SAFETY: the slice is only read while the GIL is held and no Python code runs,
        // so Python cannot mutate the buffer underneath it.
        // NOTE(review): assumes no Rust code holds a mutable borrow of the same array;
        // `try_readonly()` would enforce that at runtime.
        let slice = unsafe { array.as_slice() }
            .map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>(not_contiguous_msg))?;
        // An empty array would violate `NonEmptySlice::new_unchecked`'s contract (UB).
        if slice.is_empty() {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(empty_msg));
        }
        Ok(slice)
    }

    let mut plan = StftPlan::new(&params.inner.spectrogram_params).map_err(|e| {
        PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
            "Failed to create STFT plan: {e}"
        ))
    })?;
    let left = channel(
        &audio[0],
        "Left audio array must be contiguous and of type float64.",
        "Left audio array must not be empty.",
    )?;
    let right = channel(
        &audio[1],
        "Right audio array must be contiguous and of type float64.",
        "Right audio array must not be empty.",
    )?;
    // SAFETY: `channel` rejected empty arrays, so both slices are non-empty.
    let audio_slices = unsafe {
        [
            NonEmptySlice::new_unchecked(left),
            NonEmptySlice::new_unchecked(right),
        ]
    };
    let ild_spectrogram = compute_ild_spectrogram(audio_slices, &params.inner, &mut plan)
        .map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Failed to compute ILD spectrogram: {e}"
            ))
        })?;
    Ok(PyArray2::from_owned_array(py, ild_spectrogram.data))
}
/// Python-facing wrapper around the crate's `ILRSpectrogramParams`.
///
/// Exposed to Python as `ILRSpectrogramParams`. Validation of the arguments is
/// delegated to `ILRSpectrogramParams::new`.
#[pyclass(name = "ILRSpectrogramParams", from_py_object)]
#[derive(Debug, Clone)]
pub struct PyILRSpectrogramParams {
    pub(crate) inner: ILRSpectrogramParams,
}
#[pymethods]
impl PyILRSpectrogramParams {
    /// Builds ILR spectrogram parameters.
    ///
    /// `None` arguments fall back to `start_freq = 1700.0` and `end_freq = 4600.0`.
    /// Any error from the core constructor is raised as `ValueError`.
    #[new]
    #[pyo3(signature = (spectrogram_params: "SpectrogramParams", start_freq: "float" = 1700.0, end_freq: "float" = 4600.0), text_signature = "(spectrogram_params: SpectrogramParams, start_freq: float = 1700.0, end_freq: float = 4600.0) -> ILRSpectrogramParams")]
    fn new(
        spectrogram_params: PySpectrogramParams,
        start_freq: Option<f64>,
        end_freq: Option<f64>,
    ) -> PyResult<Self> {
        let low = start_freq.unwrap_or(1700.0);
        let high = end_freq.unwrap_or(4600.0);
        ILRSpectrogramParams::new(spectrogram_params.into(), low, high)
            .map(|inner| Self { inner })
            .map_err(|err| pyo3::exceptions::PyValueError::new_err(err.to_string()))
    }
    /// The STFT/spectrogram parameters (returned as a copy).
    #[getter]
    fn spectrogram_params(&self) -> PySpectrogramParams {
        self.inner.spectrogram_params.clone().into()
    }
    /// Lower frequency bound stored in the parameters.
    #[getter]
    fn start_freq(&self) -> f64 {
        self.inner.start_freq
    }
    /// Upper frequency bound stored in the parameters.
    #[getter]
    fn end_freq(&self) -> f64 {
        self.inner.end_freq
    }
}
/// Computes the interaural level ratio (ILR) spectrogram of a stereo signal.
///
/// `audio` holds two 1-D `float64` arrays (left, right). Each must be contiguous and
/// non-empty; otherwise `ValueError` is raised. Failures while building the STFT plan
/// or computing the spectrogram raise `RuntimeError`. Returns a 2-D `float64` array.
#[pyfunction(name = "compute_ilr_spectrogram")]
#[pyo3(signature = (audio, params), text_signature = "(audio: list[numpy.typing.NDArray[numpy.float64]], params: ILRSpectrogramParams) -> numpy.typing.NDArray[numpy.float64]")]
fn py_compute_ilr_spectrogram<'py>(
    py: Python<'py>,
    audio: [Bound<'py, PyArray1<f64>>; 2],
    params: &'py PyILRSpectrogramParams,
) -> PyResult<Bound<'py, PyArray2<f64>>> {
    /// Borrows one channel as a contiguous, non-empty slice, raising `ValueError` otherwise.
    fn channel<'a>(
        array: &'a Bound<'_, PyArray1<f64>>,
        not_contiguous_msg: &'static str,
        empty_msg: &'static str,
    ) -> PyResult<&'a [f64]> {
        // SAFETY: the slice is only read while the GIL is held and no Python code runs,
        // so Python cannot mutate the buffer underneath it.
        // NOTE(review): assumes no Rust code holds a mutable borrow of the same array;
        // `try_readonly()` would enforce that at runtime.
        let slice = unsafe { array.as_slice() }
            .map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>(not_contiguous_msg))?;
        // An empty array would violate `NonEmptySlice::new_unchecked`'s contract (UB).
        if slice.is_empty() {
            return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(empty_msg));
        }
        Ok(slice)
    }

    let mut plan = StftPlan::new(&params.inner.spectrogram_params).map_err(|e| {
        PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
            "Failed to create STFT plan: {e}"
        ))
    })?;
    let left = channel(
        &audio[0],
        "Left audio array must be contiguous and of type float64.",
        "Left audio array must not be empty.",
    )?;
    let right = channel(
        &audio[1],
        "Right audio array must be contiguous and of type float64.",
        "Right audio array must not be empty.",
    )?;
    // SAFETY: `channel` rejected empty arrays, so both slices are non-empty.
    let audio_slices = unsafe {
        [
            NonEmptySlice::new_unchecked(left),
            NonEmptySlice::new_unchecked(right),
        ]
    };
    let ilr_spectrogram = compute_ilr_spectrogram(audio_slices, &params.inner, &mut plan)
        .map_err(|e| {
            PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!(
                "Failed to compute ILR spectrogram: {e}"
            ))
        })?;
    Ok(PyArray2::from_owned_array(py, ilr_spectrogram.data))
}
#[pyfunction(name = "compute_itd_spectrogram_diff")]
#[pyo3(signature = (reference, test, params), text_signature = "(reference: list[numpy.typing.NDArray[numpy.float64]], test: list[numpy.typing.NDArray[numpy.float64]], params: ITDSpectrogramParams) -> tuple[numpy.typing.NDArray[numpy.float64], float, float]")]
fn py_compute_itd_spectrogram_diff<'py>(
py: Python<'py>,
reference: [Bound<'py, PyArray1<f64>>; 2],
test: [Bound<'py, PyArray1<f64>>; 2],
params: &'py PyITDSpectrogramParams,
) -> PyResult<(Bound<'py, PyArray1<f64>>, f64, f64)> {
let mut plan = StftPlan::new(¶ms.inner.spectrogram_params).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("Failed to create STFT plan: {}", e))
})?;
let left_ref = unsafe { NonEmptySlice::new_unchecked(reference[0].as_slice().map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>("Left reference array must be contiguous float64."))?) };
let right_ref = unsafe { NonEmptySlice::new_unchecked(reference[1].as_slice().map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>("Right reference array must be contiguous float64."))?) };
let left_test = unsafe { NonEmptySlice::new_unchecked(test[0].as_slice().map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>("Left test array must be contiguous float64."))?) };
let right_test = unsafe { NonEmptySlice::new_unchecked(test[1].as_slice().map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>("Right test array must be contiguous float64."))?) };
let (time_diff, mean_deg, mean_itd) = compute_itd_spectrogram_diff(
[left_ref, right_ref], [left_test, right_test], ¶ms.inner, &mut plan
).map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("Failed to compute ITD diff: {}", e)))?;
Ok((PyArray1::from_owned_array(py, time_diff).into(), mean_deg, mean_itd))
}
#[pyfunction(name = "compute_ilr_spectrogram_diff")]
#[pyo3(signature = (reference, test, params), text_signature = "(reference: list[numpy.typing.NDArray[numpy.float64]], test: list[numpy.typing.NDArray[numpy.float64]], params: ILRSpectrogramParams) -> tuple[numpy.typing.NDArray[numpy.float64], float]")]
fn py_compute_ilr_spectrogram_diff<'py>(
py: Python<'py>,
reference: [Bound<'py, PyArray1<f64>>; 2],
test: [Bound<'py, PyArray1<f64>>; 2],
params: &'py PyILRSpectrogramParams,
) -> PyResult<(Bound<'py, PyArray1<f64>>, f64)> {
let mut plan = StftPlan::new(¶ms.inner.spectrogram_params).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("Failed to create STFT plan: {}", e))
})?;
let left_ref = unsafe { NonEmptySlice::new_unchecked(reference[0].as_slice().map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>("Left reference array must be contiguous float64."))?) };
let right_ref = unsafe { NonEmptySlice::new_unchecked(reference[1].as_slice().map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>("Right reference array must be contiguous float64."))?) };
let left_test = unsafe { NonEmptySlice::new_unchecked(test[0].as_slice().map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>("Left test array must be contiguous float64."))?) };
let right_test = unsafe { NonEmptySlice::new_unchecked(test[1].as_slice().map_err(|_| PyErr::new::<pyo3::exceptions::PyValueError, _>("Right test array must be contiguous float64."))?) };
let (time_diff, mean_diff) = compute_ilr_spectrogram_diff(
[left_ref, right_ref], [left_test, right_test], ¶ms.inner, &mut plan
).map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(format!("Failed to compute ILR diff: {}", e)))?;
Ok((PyArray1::from_owned_array(py, time_diff).into(), mean_diff))
}
/// Registers every binaural spectrogram function and parameter class on module `m`.
///
/// Items are added in a fixed order: ITD, IPD, ILD, then ILR. Any registration
/// failure is propagated to the caller.
pub fn register(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Interaural time difference.
    let itd = wrap_pyfunction!(py_compute_itd_spectrogram, m)?;
    m.add_function(itd)?;
    let itd_diff = wrap_pyfunction!(py_compute_itd_spectrogram_diff, m)?;
    m.add_function(itd_diff)?;
    m.add_class::<PyITDSpectrogramParams>()?;

    // Interaural phase difference.
    let ipd = wrap_pyfunction!(py_compute_ipd_spectrogram, m)?;
    m.add_function(ipd)?;
    m.add_class::<PyIPDSpectrogramParams>()?;

    // Interaural level difference.
    let ild = wrap_pyfunction!(py_compute_ild_spectrogram, m)?;
    m.add_function(ild)?;
    m.add_class::<PyILDSpectrogramParams>()?;

    // Interaural level ratio.
    let ilr = wrap_pyfunction!(py_compute_ilr_spectrogram, m)?;
    m.add_function(ilr)?;
    let ilr_diff = wrap_pyfunction!(py_compute_ilr_spectrogram_diff, m)?;
    m.add_function(ilr_diff)?;
    m.add_class::<PyILRSpectrogramParams>()?;

    Ok(())
}