Skip to main content

pitch_core_py/
lib.rs

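//! Python bindings for `pitch-core` (PyO3).
//!
//! DSP backends (`swipe`, `pyin`, `praat_ac`) are always available.
//! Neural backends (`swiftf0`, `crepe`, `rmvpe`, `fcpe`, `pesto`) are gated
//! behind the `onnx` cargo feature; `pesto` additionally requires the `pesto`
//! feature. The set compiled into a given build is exposed to Python as the
//! module attribute `available_backends`.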
6
7use numpy::{IntoPyArray, PyReadonlyArray1};
8use pitch_core::{EstimatorError, PitchEstimator, PitchTracker as CorePitchTracker};
9use pyo3::exceptions::{PyRuntimeError, PyValueError};
10use pyo3::prelude::*;
11use pyo3::types::PyDict;
12
13fn map_err(e: EstimatorError) -> PyErr {
14    match e {
15        EstimatorError::InvalidInput(m) => PyValueError::new_err(m),
16        other => PyRuntimeError::new_err(other.to_string()),
17    }
18}
19
20#[cfg_attr(not(feature = "onnx"), allow(unused_variables))]
21fn build_estimator(
22    algorithm: &str,
23    model_path: &str,
24    mode: &str,
25    markov_step: bool,
26    swipe_max_window: usize,
27) -> PyResult<Box<dyn PitchEstimator>> {
28    match algorithm {
29        "swipe" => {
30            let est = pitch_core::SwipeEstimator::with_max_window(swipe_max_window)
31                .map_err(map_err)?;
32            Ok(Box::new(est))
33        }
34        "pyin" => {
35            let est = pitch_core::PyinEstimator::new().map_err(map_err)?;
36            Ok(Box::new(est))
37        }
38        "praat_ac" => {
39            let est = pitch_core::PraatAcEstimator::new(markov_step).map_err(map_err)?;
40            Ok(Box::new(est))
41        }
42
43        #[cfg(feature = "onnx")]
44        "swiftf0" => {
45            let m = pitch_core_onnx::Mode::parse(mode).map_err(PyValueError::new_err)?;
46            let est = pitch_core_onnx::SwiftF0Estimator::new(model_path, m).map_err(map_err)?;
47            Ok(Box::new(est))
48        }
49        #[cfg(feature = "onnx")]
50        "crepe" => {
51            let est = pitch_core_onnx::CrepeEstimator::new(model_path).map_err(map_err)?;
52            Ok(Box::new(est))
53        }
54        #[cfg(feature = "onnx")]
55        "rmvpe" => {
56            let est = pitch_core_onnx::RmvpeEstimator::new(model_path).map_err(map_err)?;
57            Ok(Box::new(est))
58        }
59        #[cfg(feature = "onnx")]
60        "fcpe" => {
61            let est = pitch_core_onnx::FcpeEstimator::new(model_path).map_err(map_err)?;
62            Ok(Box::new(est))
63        }
64        #[cfg(all(feature = "onnx", feature = "pesto"))]
65        "pesto" => {
66            let est = pitch_core_onnx::PestoEstimator::new(model_path).map_err(map_err)?;
67            Ok(Box::new(est))
68        }
69
70        other => Err(PyValueError::new_err(format!(
71            "unknown or disabled algorithm: {other}; available in this build: {}",
72            available_backends().join(", ")
73        ))),
74    }
75}
76
/// Names of the pitch backends compiled into this build, in a fixed order.
///
/// The DSP backends come first and are always present. The ONNX backends are
/// appended only with the `onnx` feature, and `pesto` only when both `onnx`
/// and `pesto` are enabled. This list must stay in sync with the match arms of
/// `build_estimator`.
fn available_backends() -> Vec<&'static str> {
    let mut backends: Vec<&'static str> = Vec::new();
    backends.extend(["swipe", "pyin", "praat_ac"]);
    #[cfg(feature = "onnx")]
    backends.extend(["swiftf0", "crepe", "rmvpe", "fcpe"]);
    #[cfg(all(feature = "onnx", feature = "pesto"))]
    backends.push("pesto");
    backends
}
91
/// Streaming pitch tracker: construct it with an algorithm name (see
/// `available_backends`) and feed mono float32 audio chunks to `process`.
// NOTE(review): `unsendable` makes PyO3 reject use of an instance from any
// thread other than the one that created it; presumably needed because the
// boxed estimator inside the core tracker is not `Send` — confirm.
#[pyclass(unsendable)]
struct PitchTracker {
    // Core tracker that every Python-facing method delegates to.
    inner: CorePitchTracker,
}
96
97#[pymethods]
98impl PitchTracker {
99    #[new]
100    #[pyo3(signature = (
101        algorithm,
102        model_path = "",
103        input_sample_rate = 48000,
104        mode = "balanced",
105        resample_chunk = 1024,
106        markov_step = false,
107        swipe_max_window = 8192,
108    ))]
109    fn new(
110        algorithm: &str,
111        model_path: &str,
112        input_sample_rate: u32,
113        mode: &str,
114        resample_chunk: usize,
115        markov_step: bool,
116        swipe_max_window: usize,
117    ) -> PyResult<Self> {
118        let est = build_estimator(algorithm, model_path, mode, markov_step, swipe_max_window)?;
119        let inner = CorePitchTracker::from_boxed(est, input_sample_rate, resample_chunk)
120            .map_err(map_err)?;
121        Ok(Self { inner })
122    }
123
124    #[getter]
125    fn algorithm(&self) -> &str {
126        self.inner.algorithm()
127    }
128    #[getter]
129    fn input_sample_rate(&self) -> u32 {
130        self.inner.input_sample_rate()
131    }
132    #[getter]
133    fn target_sample_rate(&self) -> u32 {
134        self.inner.target_sample_rate()
135    }
136
137    fn reset(&mut self) {
138        self.inner.reset();
139    }
140
141    /// Process a chunk of audio at `input_sample_rate` (mono float32).
142    /// Returns a dict with numpy arrays: `pitch_hz`, `confidence`,
143    /// `timestamps_s`, `frame_indices`, `is_preliminary`.
144    fn process<'py>(
145        &mut self,
146        py: Python<'py>,
147        audio: PyReadonlyArray1<'py, f32>,
148    ) -> PyResult<Bound<'py, PyDict>> {
149        let audio_slice = audio.as_slice()?;
150        let frames = self.inner.process(audio_slice).map_err(map_err)?;
151
152        let mut pitch = Vec::with_capacity(frames.len());
153        let mut conf = Vec::with_capacity(frames.len());
154        let mut times = Vec::with_capacity(frames.len());
155        let mut indices = Vec::with_capacity(frames.len());
156        let mut prelim = Vec::with_capacity(frames.len());
157        for f in &frames {
158            pitch.push(f.pitch_hz);
159            conf.push(f.confidence);
160            times.push(f.time_s);
161            indices.push(f.frame_index as i64);
162            prelim.push(f.is_preliminary);
163        }
164
165        let dict = PyDict::new(py);
166        dict.set_item("pitch_hz", pitch.into_pyarray(py))?;
167        dict.set_item("confidence", conf.into_pyarray(py))?;
168        dict.set_item("timestamps_s", times.into_pyarray(py))?;
169        dict.set_item("frame_indices", indices.into_pyarray(py))?;
170        dict.set_item("is_preliminary", prelim)?;
171        Ok(dict)
172    }
173}
174
175/// Module registration. The function name must match `[lib].name` in
176/// Cargo.toml so `import pitch_core_py` resolves the cdylib directly.
177#[pymodule]
178fn pitch_core_py(m: &Bound<'_, PyModule>) -> PyResult<()> {
179    m.add_class::<PitchTracker>()?;
180    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
181    m.add("available_backends", available_backends())?;
182    Ok(())
183}