1use numpy::{IntoPyArray, PyReadonlyArray1};
8use pitch_core::{EstimatorError, PitchEstimator, PitchTracker as CorePitchTracker};
9use pyo3::exceptions::{PyRuntimeError, PyValueError};
10use pyo3::prelude::*;
11use pyo3::types::PyDict;
12
13fn map_err(e: EstimatorError) -> PyErr {
14 match e {
15 EstimatorError::InvalidInput(m) => PyValueError::new_err(m),
16 other => PyRuntimeError::new_err(other.to_string()),
17 }
18}
19
20#[cfg_attr(not(feature = "onnx"), allow(unused_variables))]
21fn build_estimator(
22 algorithm: &str,
23 model_path: &str,
24 mode: &str,
25 markov_step: bool,
26 swipe_max_window: usize,
27) -> PyResult<Box<dyn PitchEstimator>> {
28 match algorithm {
29 "swipe" => {
30 let est = pitch_core::SwipeEstimator::with_max_window(swipe_max_window)
31 .map_err(map_err)?;
32 Ok(Box::new(est))
33 }
34 "pyin" => {
35 let est = pitch_core::PyinEstimator::new().map_err(map_err)?;
36 Ok(Box::new(est))
37 }
38 "praat_ac" => {
39 let est = pitch_core::PraatAcEstimator::new(markov_step).map_err(map_err)?;
40 Ok(Box::new(est))
41 }
42
43 #[cfg(feature = "onnx")]
44 "swiftf0" => {
45 let m = pitch_core_onnx::Mode::parse(mode).map_err(PyValueError::new_err)?;
46 let est = pitch_core_onnx::SwiftF0Estimator::new(model_path, m).map_err(map_err)?;
47 Ok(Box::new(est))
48 }
49 #[cfg(feature = "onnx")]
50 "crepe" => {
51 let est = pitch_core_onnx::CrepeEstimator::new(model_path).map_err(map_err)?;
52 Ok(Box::new(est))
53 }
54 #[cfg(feature = "onnx")]
55 "rmvpe" => {
56 let est = pitch_core_onnx::RmvpeEstimator::new(model_path).map_err(map_err)?;
57 Ok(Box::new(est))
58 }
59 #[cfg(feature = "onnx")]
60 "fcpe" => {
61 let est = pitch_core_onnx::FcpeEstimator::new(model_path).map_err(map_err)?;
62 Ok(Box::new(est))
63 }
64 #[cfg(all(feature = "onnx", feature = "pesto"))]
65 "pesto" => {
66 let est = pitch_core_onnx::PestoEstimator::new(model_path).map_err(map_err)?;
67 Ok(Box::new(est))
68 }
69
70 other => Err(PyValueError::new_err(format!(
71 "unknown or disabled algorithm: {other}; available in this build: {}",
72 available_backends().join(", ")
73 ))),
74 }
75}
76
/// Names of the pitch backends compiled into this build, in a stable order.
///
/// The native estimators (`swipe`, `pyin`, `praat_ac`) are always present. The
/// `onnx` feature adds `swiftf0`, `crepe`, `rmvpe` and `fcpe`. `pesto` additionally
/// requires the `pesto` feature.
fn available_backends() -> Vec<&'static str> {
    let mut backends: Vec<&'static str> = Vec::new();
    backends.extend(["swipe", "pyin", "praat_ac"]);
    #[cfg(feature = "onnx")]
    backends.extend(["swiftf0", "crepe", "rmvpe", "fcpe"]);
    #[cfg(all(feature = "onnx", feature = "pesto"))]
    backends.push("pesto");
    backends
}
91
92#[pyclass(unsendable)]
93struct PitchTracker {
94 inner: CorePitchTracker,
95}
96
97#[pymethods]
98impl PitchTracker {
99 #[new]
100 #[pyo3(signature = (
101 algorithm,
102 model_path = "",
103 input_sample_rate = 48000,
104 mode = "balanced",
105 resample_chunk = 1024,
106 markov_step = false,
107 swipe_max_window = 8192,
108 ))]
109 fn new(
110 algorithm: &str,
111 model_path: &str,
112 input_sample_rate: u32,
113 mode: &str,
114 resample_chunk: usize,
115 markov_step: bool,
116 swipe_max_window: usize,
117 ) -> PyResult<Self> {
118 let est = build_estimator(algorithm, model_path, mode, markov_step, swipe_max_window)?;
119 let inner = CorePitchTracker::from_boxed(est, input_sample_rate, resample_chunk)
120 .map_err(map_err)?;
121 Ok(Self { inner })
122 }
123
124 #[getter]
125 fn algorithm(&self) -> &str {
126 self.inner.algorithm()
127 }
128 #[getter]
129 fn input_sample_rate(&self) -> u32 {
130 self.inner.input_sample_rate()
131 }
132 #[getter]
133 fn target_sample_rate(&self) -> u32 {
134 self.inner.target_sample_rate()
135 }
136
137 fn reset(&mut self) {
138 self.inner.reset();
139 }
140
141 fn process<'py>(
145 &mut self,
146 py: Python<'py>,
147 audio: PyReadonlyArray1<'py, f32>,
148 ) -> PyResult<Bound<'py, PyDict>> {
149 let audio_slice = audio.as_slice()?;
150 let frames = self.inner.process(audio_slice).map_err(map_err)?;
151
152 let mut pitch = Vec::with_capacity(frames.len());
153 let mut conf = Vec::with_capacity(frames.len());
154 let mut times = Vec::with_capacity(frames.len());
155 let mut indices = Vec::with_capacity(frames.len());
156 let mut prelim = Vec::with_capacity(frames.len());
157 for f in &frames {
158 pitch.push(f.pitch_hz);
159 conf.push(f.confidence);
160 times.push(f.time_s);
161 indices.push(f.frame_index as i64);
162 prelim.push(f.is_preliminary);
163 }
164
165 let dict = PyDict::new(py);
166 dict.set_item("pitch_hz", pitch.into_pyarray(py))?;
167 dict.set_item("confidence", conf.into_pyarray(py))?;
168 dict.set_item("timestamps_s", times.into_pyarray(py))?;
169 dict.set_item("frame_indices", indices.into_pyarray(py))?;
170 dict.set_item("is_preliminary", prelim)?;
171 Ok(dict)
172 }
173}
174
175#[pymodule]
178fn pitch_core_py(m: &Bound<'_, PyModule>) -> PyResult<()> {
179 m.add_class::<PitchTracker>()?;
180 m.add("__version__", env!("CARGO_PKG_VERSION"))?;
181 m.add("available_backends", available_backends())?;
182 Ok(())
183}