use std::num::NonZeroUsize;
use num_complex::Complex;
use numpy::{PyArray1, PyArray2, PyReadonlyArray1};
use pyo3::prelude::*;
use pyo3::types::PyType;
use crate::{
ChromaNorm, ChromaParams, CqtParams, ErbParams, LogHzParams, LogParams, MelNorm, MelParams,
MfccParams, SpectrogramParams, StftParams, StftResult, WindowType,
};
/// Python-facing wrapper around the crate's [`WindowType`].
///
/// Exposed to Python under the class name `WindowType`.
#[pyclass(name = "WindowType", from_py_object)]
#[derive(Debug, Clone)]
pub struct PyWindowType {
    /// The wrapped native window description.
    pub(crate) inner: WindowType,
}
impl PyWindowType {
    /// Consumes the wrapper and hands back the wrapped [`WindowType`].
    pub fn into_inner(self) -> WindowType {
        let Self { inner } = self;
        inner
    }

    /// Borrows the wrapped [`WindowType`] without consuming the wrapper.
    pub fn as_inner(&self) -> &WindowType {
        &self.inner
    }
}
#[pymethods]
impl PyWindowType {
#[classattr]
const fn rectangular() -> Self {
Self {
inner: WindowType::Rectangular,
}
}
#[classattr]
const fn hanning() -> Self {
Self {
inner: WindowType::Hanning,
}
}
#[classattr]
const fn hamming() -> Self {
Self {
inner: WindowType::Hamming,
}
}
#[classattr]
const fn blackman() -> Self {
Self {
inner: WindowType::Blackman,
}
}
#[classmethod]
#[pyo3(signature = (beta: "float"), text_signature = "(beta: float) -> WindowType")]
const fn kaiser(_cls: &Bound<'_, PyType>, beta: f64) -> Self {
Self {
inner: WindowType::Kaiser { beta },
}
}
#[classmethod]
#[pyo3(signature = (std: "float"), text_signature = "(std: float) -> WindowType")]
const fn gaussian(_cls: &Bound<'_, PyType>, std: f64) -> Self {
Self {
inner: WindowType::Gaussian { std },
}
}
#[classmethod]
#[pyo3(signature = (coefficients, normalize=None), text_signature = "(coefficients, normalize=None) -> WindowType")]
fn custom(
_cls: &Bound<'_, PyType>,
coefficients: PyReadonlyArray1<f64>,
normalize: Option<&str>,
) -> PyResult<Self> {
let vec = coefficients.as_slice()?.to_vec();
let inner = WindowType::custom_with_normalization(vec, normalize)?;
Ok(Self { inner })
}
#[staticmethod]
#[pyo3(signature = (n: "int"), text_signature = "(n: int) -> numpy.ndarray")]
fn make_hanning(py: Python<'_>, n: NonZeroUsize) -> Bound<'_, PyArray1<f64>> {
let window_vec = crate::window::hanning_window(n);
PyArray1::from_vec(py, window_vec.into_vec())
}
#[staticmethod]
#[pyo3(signature = (n: "int"), text_signature = "(n: int) -> numpy.ndarray")]
fn make_hamming(py: Python<'_>, n: NonZeroUsize) -> Bound<'_, PyArray1<f64>> {
let window_vec = crate::window::hamming_window(n);
PyArray1::from_vec(py, window_vec.into_vec())
}
#[staticmethod]
#[pyo3(signature = (n: "int"), text_signature = "(n: int) -> numpy.ndarray")]
fn make_blackman(py: Python<'_>, n: NonZeroUsize) -> Bound<'_, PyArray1<f64>> {
let window_vec = crate::window::blackman_window(n);
PyArray1::from_vec(py, window_vec.into_vec())
}
#[staticmethod]
#[pyo3(signature = (n: "int", beta: "float"), text_signature = "(n: int, beta: float) -> numpy.ndarray")]
fn make_kaiser(py: Python<'_>, n: NonZeroUsize, beta: f64) -> Bound<'_, PyArray1<f64>> {
let window_vec = crate::window::kaiser_window(n, beta);
PyArray1::from_vec(py, window_vec.into_vec())
}
#[staticmethod]
#[pyo3(signature = (n: "int", std: "float"), text_signature = "(n: int, std: float) -> numpy.ndarray")]
fn make_gaussian(py: Python<'_>, n: NonZeroUsize, std: f64) -> Bound<'_, PyArray1<f64>> {
let window_vec = crate::window::gaussian_window(n, std);
PyArray1::from_vec(py, window_vec.into_vec())
}
fn __repr__(&self) -> String {
format!("{}", self.inner)
}
}
impl From<WindowType> for PyWindowType {
    /// Wraps a native [`WindowType`] so it can be handed to Python.
    fn from(inner: WindowType) -> Self {
        Self { inner }
    }
}
/// Python-facing wrapper around the crate's [`StftResult`].
///
/// Exposed to Python under the class name `StftResult`.
#[pyclass(name = "StftResult", from_py_object)]
#[derive(Debug, Clone)]
pub struct PyStftResult {
    /// The wrapped native STFT result.
    pub(crate) inner: StftResult,
}
impl From<StftResult> for PyStftResult {
fn from(inner: StftResult) -> Self {
Self { inner }
}
}
impl From<PyStftResult> for StftResult {
#[inline]
fn from(val: PyStftResult) -> Self {
val.inner
}
}
impl PyStftResult {
pub fn from_inner(inner: StftResult) -> Self {
Self { inner }
}
pub fn into_inner(self) -> StftResult {
self.inner
}
}
#[pymethods]
impl PyStftResult {
    /// Number of bins, as reported by the wrapped result's `n_bins()`.
    #[getter]
    fn n_bins(&self) -> usize {
        let bins = self.inner.n_bins();
        bins.get()
    }

    /// Number of frames, as reported by the wrapped result's `n_frames()`.
    #[getter]
    fn n_frames(&self) -> usize {
        let frames = self.inner.n_frames();
        frames.get()
    }

    /// Value of the wrapped result's `frequency_resolution()`.
    #[getter]
    fn frequency_resolution(&self) -> f64 {
        self.inner.frequency_resolution()
    }

    /// Value of the wrapped result's `time_resolution()`.
    #[getter]
    fn time_resolution(&self) -> f64 {
        self.inner.time_resolution()
    }

    /// A copy of the STFT parameters stored on the wrapped result.
    #[getter]
    fn params(&self) -> PyStftParams {
        let inner = self.inner.params.clone();
        PyStftParams { inner }
    }

    /// The `sample_rate` field stored on the wrapped result.
    #[getter]
    fn sample_rate(&self) -> f64 {
        self.inner.sample_rate
    }

    /// Returns the array produced by the wrapped result's `norm()` as a new
    /// 2-D float64 NumPy array.
    fn norm<'py>(&'py self, py: Python<'py>) -> Bound<'py, PyArray2<f64>> {
        let values = self.inner.norm();
        PyArray2::from_owned_array(py, values)
    }

    /// Returns a copy of the complex STFT data as a new 2-D NumPy array.
    fn data<'py>(&'py self, py: Python<'py>) -> Bound<'py, PyArray2<Complex<f64>>> {
        // The data is cloned so the NumPy array owns its own buffer.
        let copied = self.inner.data.clone();
        PyArray2::from_owned_array(py, copied)
    }
}
/// Python-facing wrapper around the crate's [`StftParams`].
///
/// Exposed to Python under the class name `StftParams`.
#[pyclass(name = "StftParams", from_py_object)]
#[derive(Debug, Clone)]
pub struct PyStftParams {
    /// The wrapped native STFT parameters.
    pub inner: StftParams,
}
#[pymethods]
impl PyStftParams {
    /// Creates STFT parameters from an FFT size, hop size, window and the
    /// `centre` flag (defaults to `True`).
    ///
    /// Any error returned by `StftParams::new` is propagated to Python.
    #[new]
    #[pyo3(signature = (
        n_fft: "int",
        hop_size: "int",
        window: "WindowType",
        centre: "bool" = true
    ), text_signature = "(n_fft: int, hop_size: int, window: WindowType, centre: bool = True)")]
    fn new(
        n_fft: NonZeroUsize,
        hop_size: NonZeroUsize,
        window: PyWindowType,
        centre: bool,
    ) -> PyResult<Self> {
        Ok(Self {
            inner: StftParams::new(n_fft, hop_size, window.inner, centre)?,
        })
    }

    /// The configured FFT size.
    #[getter]
    const fn n_fft(&self) -> NonZeroUsize {
        self.inner.n_fft()
    }

    /// The configured hop size.
    #[getter]
    const fn hop_size(&self) -> NonZeroUsize {
        self.inner.hop_size()
    }

    /// The configured window, wrapped for Python.
    #[getter]
    fn window(&self) -> PyWindowType {
        let inner = self.inner.window();
        PyWindowType { inner }
    }

    /// The configured `centre` flag.
    #[getter]
    const fn centre(&self) -> bool {
        self.inner.centre()
    }

    /// Renders the parameters as `StftParams(n_fft=..., hop_size=..., window=..., centre=...)`.
    fn __repr__(&self) -> String {
        let n_fft = self.n_fft();
        let hop_size = self.hop_size();
        let window = self.window().__repr__();
        let centre = self.centre();
        format!(
            "StftParams(n_fft={}, hop_size={}, window={}, centre={})",
            n_fft, hop_size, window, centre
        )
    }
}
impl From<PyStftParams> for StftParams {
#[inline]
fn from(val: PyStftParams) -> Self {
val.inner
}
}
impl From<StftParams> for PyStftParams {
#[inline]
fn from(inner: StftParams) -> Self {
Self { inner }
}
}
/// Python-facing wrapper around the crate's [`LogParams`].
///
/// Exposed to Python under the class name `LogParams`.
#[pyclass(name = "LogParams", from_py_object)]
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct PyLogParams {
    /// The wrapped native log-scaling parameters.
    pub(crate) inner: LogParams,
}
impl PyLogParams {
    /// Consumes the wrapper and returns the native [`LogParams`].
    #[inline]
    #[must_use]
    pub fn into_inner(self) -> LogParams {
        let Self { inner } = self;
        inner
    }

    /// Borrows the native [`LogParams`] held by this wrapper.
    #[inline]
    #[must_use]
    pub fn as_inner(&self) -> &LogParams {
        let Self { inner } = self;
        inner
    }
}
#[pymethods]
impl PyLogParams {
    /// Creates log-scaling parameters from `floor_db`.
    ///
    /// Any error returned by `LogParams::new` is propagated to Python.
    #[new]
    #[pyo3(signature = (floor_db: "float"), text_signature = "(floor_db: float)")]
    fn new(floor_db: f64) -> PyResult<Self> {
        Ok(Self {
            inner: LogParams::new(floor_db)?,
        })
    }

    /// The configured `floor_db` value.
    #[getter]
    const fn floor_db(&self) -> f64 {
        self.inner.floor_db()
    }

    /// Renders the parameters as `LogParams(floor_db=...)`.
    fn __repr__(&self) -> String {
        let floor_db = self.floor_db();
        format!("LogParams(floor_db={})", floor_db)
    }
}
impl From<PyLogParams> for LogParams {
#[inline]
fn from(val: PyLogParams) -> Self {
val.inner
}
}
impl From<LogParams> for PyLogParams {
#[inline]
fn from(inner: LogParams) -> Self {
Self { inner }
}
}
/// Python-facing wrapper around the crate's [`SpectrogramParams`].
///
/// Exposed to Python under the class name `SpectrogramParams`.
#[pyclass(name = "SpectrogramParams", from_py_object)]
#[derive(Debug, Clone)]
pub struct PySpectrogramParams {
    /// The wrapped native spectrogram parameters.
    pub(crate) inner: SpectrogramParams,
}
#[pymethods]
impl PySpectrogramParams {
    /// Creates spectrogram parameters from STFT parameters and a sample rate.
    ///
    /// The STFT parameters are cloned. Any error returned by
    /// `SpectrogramParams::new` is propagated to Python.
    #[new]
    #[pyo3(signature = (
        stft: "StftParams",
        sample_rate: "float"
    ), text_signature = "(stft: StftParams, sample_rate: float)")]
    fn new(stft: &PyStftParams, sample_rate: f64) -> PyResult<Self> {
        let stft_params = stft.inner.clone();
        Ok(Self {
            inner: SpectrogramParams::new(stft_params, sample_rate)?,
        })
    }

    /// A copy of the STFT parameters, wrapped for Python.
    #[getter]
    fn stft(&self) -> PyStftParams {
        let inner = self.inner.stft().clone();
        PyStftParams { inner }
    }

    /// The configured sample rate in Hz.
    #[getter]
    const fn sample_rate(&self) -> f64 {
        self.inner.sample_rate_hz()
    }

    /// Builds the preset returned by `SpectrogramParams::speech_default`
    /// for the given sample rate; its errors are propagated to Python.
    #[classmethod]
    #[pyo3(signature = (sample_rate: "float"), text_signature = "(sample_rate: float)")]
    fn speech_default(_cls: &Bound<'_, PyType>, sample_rate: f64) -> PyResult<Self> {
        Ok(Self {
            inner: SpectrogramParams::speech_default(sample_rate)?,
        })
    }

    /// Builds the preset returned by `SpectrogramParams::music_default`
    /// for the given sample rate; its errors are propagated to Python.
    #[classmethod]
    #[pyo3(signature = (sample_rate: "float"), text_signature = "(sample_rate: float)")]
    fn music_default(_cls: &Bound<'_, PyType>, sample_rate: f64) -> PyResult<Self> {
        Ok(Self {
            inner: SpectrogramParams::music_default(sample_rate)?,
        })
    }

    /// Renders the parameters as
    /// `SpectrogramParams(sample_rate=..., n_fft=..., hop_size=...)`.
    fn __repr__(&self) -> String {
        let sample_rate = self.sample_rate();
        let stft = self.inner.stft();
        format!(
            "SpectrogramParams(sample_rate={}, n_fft={}, hop_size={})",
            sample_rate,
            stft.n_fft(),
            stft.hop_size()
        )
    }
}
impl From<SpectrogramParams> for PySpectrogramParams {
fn from(inner: SpectrogramParams) -> Self {
Self { inner }
}
}
impl From<PySpectrogramParams> for SpectrogramParams {
#[inline]
fn from(py_params: PySpectrogramParams) -> Self {
py_params.inner
}
}
/// Python-facing mirror of the crate's [`MelNorm`] enum.
///
/// Exposed to Python under the class name `MelNorm`. Each variant maps
/// one-to-one onto the `MelNorm` variant of the same name.
#[pyclass(name = "MelNorm", from_py_object)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PyMelNorm {
    /// Corresponds to `MelNorm::None`.
    None,
    /// Corresponds to `MelNorm::Slaney`.
    Slaney,
    /// Corresponds to `MelNorm::L1`.
    L1,
    /// Corresponds to `MelNorm::L2`.
    L2,
}
#[pymethods]
impl PyMelNorm {
    /// Lower-case class attribute for the `None` variant.
    #[classattr]
    const fn none() -> Self {
        Self::None
    }

    /// Lower-case class attribute for the `Slaney` variant.
    #[classattr]
    const fn slaney() -> Self {
        Self::Slaney
    }

    /// Lower-case class attribute for the `L1` variant.
    #[classattr]
    const fn l1() -> Self {
        Self::L1
    }

    /// Lower-case class attribute for the `L2` variant.
    #[classattr]
    const fn l2() -> Self {
        Self::L2
    }

    /// Renders the variant as `MelNorm.<Variant>`.
    fn __repr__(&self) -> String {
        let text = match self {
            Self::None => "MelNorm.None",
            Self::Slaney => "MelNorm.Slaney",
            Self::L1 => "MelNorm.L1",
            Self::L2 => "MelNorm.L2",
        };
        text.to_string()
    }
}
impl From<PyMelNorm> for MelNorm {
    /// Maps each Python-facing variant onto the native variant of the same
    /// name.
    #[inline]
    fn from(value: PyMelNorm) -> Self {
        match value {
            PyMelNorm::None => MelNorm::None,
            PyMelNorm::Slaney => MelNorm::Slaney,
            PyMelNorm::L1 => MelNorm::L1,
            PyMelNorm::L2 => MelNorm::L2,
        }
    }
}

impl From<MelNorm> for PyMelNorm {
    /// Maps each native variant onto the Python-facing variant of the same
    /// name.
    #[inline]
    fn from(value: MelNorm) -> Self {
        match value {
            MelNorm::None => PyMelNorm::None,
            MelNorm::Slaney => PyMelNorm::Slaney,
            MelNorm::L1 => PyMelNorm::L1,
            MelNorm::L2 => PyMelNorm::L2,
        }
    }
}
/// Python-facing wrapper around the crate's [`MelParams`].
///
/// Exposed to Python under the class name `MelParams`.
#[pyclass(name = "MelParams", from_py_object)]
#[derive(Debug, Clone, Copy)]
pub struct PyMelParams {
    /// The wrapped native mel filterbank parameters.
    pub(crate) inner: MelParams,
}
#[pymethods]
impl PyMelParams {
#[new]
#[pyo3(signature = (
n_mels: "int",
f_min: "float",
f_max: "float",
norm: "MelNorm" = None
), text_signature = "(n_mels: int, f_min: float, f_max: float, norm: MelNorm | str | None = None)")]
fn new(
n_mels: NonZeroUsize,
f_min: f64,
f_max: f64,
norm: Option<&pyo3::Bound<'_, pyo3::PyAny>>,
) -> PyResult<Self> {
let norm_val = if let Some(norm_arg) = norm {
if norm_arg.is_none() {
MelNorm::None
} else if let Ok(s) = norm_arg.extract::<String>() {
match s.to_lowercase().as_str() {
"none" => MelNorm::None,
"slaney" => MelNorm::Slaney,
"l1" => MelNorm::L1,
"l2" => MelNorm::L2,
_ => {
return Err(pyo3::exceptions::PyValueError::new_err(format!(
"Invalid norm string: '{s}'. Must be one of: 'none', 'slaney', 'l1', 'l2'"
)));
}
}
} else if let Ok(py_norm) = norm_arg.extract::<PyMelNorm>() {
py_norm.into()
} else {
return Err(pyo3::exceptions::PyTypeError::new_err(
"norm must be a MelNorm enum, a string, or None",
));
}
} else {
MelNorm::None
};
let inner = MelParams::with_norm(n_mels, f_min, f_max, norm_val)?;
Ok(Self { inner })
}
#[getter]
const fn n_mels(&self) -> NonZeroUsize {
self.inner.n_mels()
}
#[getter]
const fn f_min(&self) -> f64 {
self.inner.f_min()
}
#[getter]
const fn f_max(&self) -> f64 {
self.inner.f_max()
}
#[getter]
fn norm(&self) -> PyMelNorm {
self.inner.norm().into()
}
fn __repr__(&self) -> String {
let norm_str = match self.inner.norm() {
MelNorm::None => "None",
MelNorm::Slaney => "slaney",
MelNorm::L1 => "l1",
MelNorm::L2 => "l2",
};
format!(
"MelParams(n_mels={}, f_min={}, f_max={}, norm='{}')",
self.n_mels(),
self.f_min(),
self.f_max(),
norm_str
)
}
}
impl From<PyMelParams> for MelParams {
#[inline]
fn from(val: PyMelParams) -> Self {
val.inner
}
}
impl From<MelParams> for PyMelParams {
#[inline]
fn from(inner: MelParams) -> Self {
Self { inner }
}
}
/// Python-facing wrapper around the crate's [`ErbParams`].
///
/// Exposed to Python under the class name `ErbParams`.
#[pyclass(name = "ErbParams", from_py_object)]
#[derive(Debug, Clone, Copy)]
pub struct PyErbParams {
    /// The wrapped native ERB filterbank parameters.
    pub(crate) inner: ErbParams,
}
impl PyErbParams {
    /// Consumes the wrapper and returns the native [`ErbParams`].
    #[inline]
    pub fn into_inner(self) -> ErbParams {
        let Self { inner } = self;
        inner
    }

    /// Borrows the native [`ErbParams`] held by this wrapper.
    #[inline]
    pub fn as_inner(&self) -> &ErbParams {
        let Self { inner } = self;
        inner
    }
}
#[pymethods]
impl PyErbParams {
    /// Creates ERB parameters from a filter count and a frequency range.
    ///
    /// Any error returned by `ErbParams::new` is propagated to Python.
    #[new]
    #[pyo3(signature = (
        n_filters: "int",
        f_min: "float",
        f_max: "float"
    ), text_signature = "(n_filters: int, f_min: float, f_max: float)")]
    fn new(n_filters: NonZeroUsize, f_min: f64, f_max: f64) -> PyResult<Self> {
        Ok(Self {
            inner: ErbParams::new(n_filters, f_min, f_max)?,
        })
    }

    /// The configured number of filters.
    #[getter]
    const fn n_filters(&self) -> NonZeroUsize {
        self.inner.n_filters()
    }

    /// The configured `f_min` value.
    #[getter]
    const fn f_min(&self) -> f64 {
        self.inner.f_min()
    }

    /// The configured `f_max` value.
    #[getter]
    const fn f_max(&self) -> f64 {
        self.inner.f_max()
    }

    /// Renders the parameters as
    /// `ErbParams(n_filters=..., f_min=..., f_max=...)`.
    fn __repr__(&self) -> String {
        let n_filters = self.n_filters();
        let f_min = self.f_min();
        let f_max = self.f_max();
        format!(
            "ErbParams(n_filters={}, f_min={}, f_max={})",
            n_filters, f_min, f_max
        )
    }
}
impl From<ErbParams> for PyErbParams {
    /// Wraps native [`ErbParams`] for exposure to Python.
    #[inline]
    fn from(params: ErbParams) -> Self {
        Self { inner: params }
    }
}
/// Python-facing wrapper around the crate's [`LogHzParams`].
///
/// Exposed to Python under the class name `LogHzParams`.
#[pyclass(name = "LogHzParams", from_py_object)]
#[derive(Debug, Clone, Copy)]
pub struct PyLogHzParams {
    /// The wrapped native log-frequency parameters.
    pub(crate) inner: LogHzParams,
}
#[pymethods]
impl PyLogHzParams {
    /// Creates log-frequency parameters from a bin count and a frequency
    /// range.
    ///
    /// Any error returned by `LogHzParams::new` is propagated to Python.
    #[new]
    #[pyo3(signature = (
        n_bins: "int",
        f_min: "float",
        f_max: "float"
    ), text_signature = "(n_bins: int, f_min: float, f_max: float)")]
    fn new(n_bins: NonZeroUsize, f_min: f64, f_max: f64) -> PyResult<Self> {
        Ok(Self {
            inner: LogHzParams::new(n_bins, f_min, f_max)?,
        })
    }

    /// The configured number of bins.
    #[getter]
    const fn n_bins(&self) -> NonZeroUsize {
        self.inner.n_bins()
    }

    /// The configured `f_min` value.
    #[getter]
    const fn f_min(&self) -> f64 {
        self.inner.f_min()
    }

    /// The configured `f_max` value.
    #[getter]
    const fn f_max(&self) -> f64 {
        self.inner.f_max()
    }

    /// Renders the parameters as
    /// `LogHzParams(n_bins=..., f_min=..., f_max=...)`.
    fn __repr__(&self) -> String {
        let n_bins = self.n_bins();
        let f_min = self.f_min();
        let f_max = self.f_max();
        format!(
            "LogHzParams(n_bins={}, f_min={}, f_max={})",
            n_bins, f_min, f_max
        )
    }
}
/// Python-facing wrapper around the crate's [`CqtParams`].
///
/// Exposed to Python under the class name `CqtParams`.
#[pyclass(name = "CqtParams", from_py_object)]
#[derive(Debug, Clone)]
pub struct PyCqtParams {
    /// The wrapped native CQT parameters.
    pub(crate) inner: CqtParams,
}
#[pymethods]
impl PyCqtParams {
    /// Creates CQT parameters from bins per octave, octave count and `f_min`.
    ///
    /// Any error returned by `CqtParams::new` is propagated to Python.
    #[new]
    #[pyo3(signature = (
        bins_per_octave: "int",
        n_octaves: "int",
        f_min: "float"
    ), text_signature = "(bins_per_octave: int, n_octaves: int, f_min: float)")]
    fn new(bins_per_octave: NonZeroUsize, n_octaves: NonZeroUsize, f_min: f64) -> PyResult<Self> {
        Ok(Self {
            inner: CqtParams::new(bins_per_octave, n_octaves, f_min)?,
        })
    }

    /// The bin count reported by the wrapped parameters' `num_bins()`.
    #[getter]
    const fn num_bins(&self) -> NonZeroUsize {
        self.inner.num_bins()
    }

    /// Renders the parameters as `CqtParams(num_bins=...)`.
    fn __repr__(&self) -> String {
        let num_bins = self.num_bins();
        format!("CqtParams(num_bins={})", num_bins)
    }
}
impl From<CqtParams> for PyCqtParams {
#[inline]
fn from(inner: CqtParams) -> Self {
Self { inner }
}
}
impl From<PyCqtParams> for CqtParams {
#[inline]
fn from(val: PyCqtParams) -> Self {
val.inner
}
}
/// Python-facing wrapper around the crate's [`ChromaNorm`].
///
/// Exposed to Python under the class name `ChromaNorm`. The derived
/// `Default` delegates to `ChromaNorm`'s own default.
#[pyclass(name = "ChromaNorm", from_py_object)]
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct PyChromaNorm {
    /// The wrapped native chroma normalization mode.
    pub(crate) inner: ChromaNorm,
}
#[pymethods]
impl PyChromaNorm {
    /// Class attribute wrapping `ChromaNorm::None`.
    #[classattr]
    const fn none() -> Self {
        let inner = ChromaNorm::None;
        Self { inner }
    }

    /// Class attribute wrapping `ChromaNorm::L1`.
    #[classattr]
    const fn l1() -> Self {
        let inner = ChromaNorm::L1;
        Self { inner }
    }

    /// Class attribute wrapping `ChromaNorm::L2`.
    #[classattr]
    const fn l2() -> Self {
        let inner = ChromaNorm::L2;
        Self { inner }
    }

    /// Class attribute wrapping `ChromaNorm::Max`.
    #[classattr]
    const fn max() -> Self {
        let inner = ChromaNorm::Max;
        Self { inner }
    }

    /// Returns the `Debug` rendering of the wrapped normalization mode.
    fn __repr__(&self) -> String {
        let Self { inner } = self;
        format!("{:?}", inner)
    }
}
impl From<ChromaNorm> for PyChromaNorm {
#[inline]
fn from(inner: ChromaNorm) -> Self {
Self { inner }
}
}
impl From<PyChromaNorm> for ChromaNorm {
#[inline]
fn from(val: PyChromaNorm) -> Self {
val.inner
}
}
/// Python-facing wrapper around the crate's [`ChromaParams`].
///
/// Exposed to Python under the class name `ChromaParams`.
#[pyclass(name = "ChromaParams", from_py_object)]
#[derive(Debug, Clone, Copy)]
pub struct PyChromaParams {
    /// The wrapped native chroma parameters.
    pub(crate) inner: ChromaParams,
}
#[pymethods]
impl PyChromaParams {
#[new]
#[pyo3(signature = (
tuning: "float" = 440.0,
f_min: "float" = 32.7,
f_max: "float" = 4186.0,
norm: "ChromaNorm" = None
), text_signature = "(tuning: float = 440.0, f_min: float = 32.7, f_max: float = 4186.0, norm: ChromaNorm = ChromaNorm.None)")]
fn new(tuning: f64, f_min: f64, f_max: f64, norm: Option<PyChromaNorm>) -> PyResult<Self> {
let norm = norm.unwrap_or_default();
let inner = ChromaParams::new(tuning, f_min, f_max, norm.inner)?;
Ok(Self { inner })
}
#[classmethod]
const fn music_standard(_cls: &Bound<'_, PyType>) -> Self {
let inner = ChromaParams::music_standard();
Self { inner }
}
#[getter]
const fn tuning(&self) -> f64 {
self.inner.tuning()
}
#[getter]
const fn f_min(&self) -> f64 {
self.inner.f_min()
}
#[getter]
const fn f_max(&self) -> f64 {
self.inner.f_max()
}
fn __repr__(&self) -> String {
format!(
"ChromaParams(tuning={}, f_min={}, f_max={}, norm={:?})",
self.tuning(),
self.f_min(),
self.f_max(),
self.inner
)
}
}
impl From<ChromaParams> for PyChromaParams {
#[inline]
fn from(inner: ChromaParams) -> Self {
Self { inner }
}
}
impl From<PyChromaParams> for ChromaParams {
#[inline]
fn from(val: PyChromaParams) -> Self {
val.inner
}
}
/// Python-facing wrapper around the crate's [`MfccParams`].
///
/// Exposed to Python under the class name `MfccParams`.
#[pyclass(name = "MfccParams", from_py_object)]
#[derive(Debug, Clone, Copy)]
pub struct PyMfccParams {
    /// The wrapped native MFCC parameters.
    pub(crate) inner: MfccParams,
}
#[pymethods]
impl PyMfccParams {
    /// Creates MFCC parameters with `n_mfcc` coefficients (default 13).
    ///
    /// Raises `ValueError` when `n_mfcc` is zero.
    #[new]
    #[pyo3(signature = (n_mfcc: "int" = 13), text_signature = "(n_mfcc: int = 13)")]
    fn new(n_mfcc: usize) -> PyResult<Self> {
        // Zero is rejected here so Python sees a descriptive message.
        let Some(n_mfcc) = NonZeroUsize::new(n_mfcc) else {
            return Err(pyo3::exceptions::PyValueError::new_err(
                "n_mfcc must be a positive integer",
            ));
        };
        Ok(Self {
            inner: MfccParams::new(n_mfcc),
        })
    }

    /// Builds the preset returned by `MfccParams::speech_standard`.
    #[classmethod]
    const fn speech_standard(_cls: &Bound<'_, PyType>) -> Self {
        Self {
            inner: MfccParams::speech_standard(),
        }
    }

    /// The configured number of MFCC coefficients.
    #[getter]
    const fn n_mfcc(&self) -> NonZeroUsize {
        self.inner.n_mfcc()
    }

    /// Renders the parameters as `MfccParams(n_mfcc=...)`.
    fn __repr__(&self) -> String {
        let n_mfcc = self.n_mfcc();
        format!("MfccParams(n_mfcc={})", n_mfcc)
    }
}
impl From<PyMfccParams> for MfccParams {
#[inline]
fn from(val: PyMfccParams) -> Self {
val.inner
}
}
impl From<MfccParams> for PyMfccParams {
#[inline]
fn from(inner: MfccParams) -> Self {
Self { inner }
}
}
pub fn register(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyWindowType>()?;
m.add_class::<PyStftParams>()?;
m.add_class::<PyLogParams>()?;
m.add_class::<PySpectrogramParams>()?;
m.add_class::<PyMelNorm>()?;
m.add_class::<PyMelParams>()?;
m.add_class::<PyErbParams>()?;
m.add_class::<PyLogHzParams>()?;
m.add_class::<PyCqtParams>()?;
m.add_class::<PyChromaParams>()?;
m.add_class::<PyMfccParams>()?;
Ok(())
}