1#![warn(missing_docs)]
2use ndarray::Array2;
26use numpy::{
27 PyArray1, PyArray2, PyArrayDescrMethods, PyArrayMethods, PyReadonlyArray1,
28 PyUntypedArrayMethods,
29};
30use pyo3::exceptions::{PyTypeError, PyValueError};
31use pyo3::prelude::*;
32use pyo3::types::{PyModule, PyType};
33use realfft::num_complex::Complex32;
34
35pub mod vad;
37
38pub use vad::VadError;
39pub use vad::detector::{VAD, VADModes, VadConfig, VadStateful};
40pub use vad::filterbank::FilterBank;
41
42fn as_f32_array1<'py>(obj: &Bound<'py, PyAny>) -> PyResult<PyReadonlyArray1<'py, f32>> {
43 if let Ok(untyped) = obj.cast::<numpy::PyUntypedArray>() {
44 let dtype = untyped.dtype();
45 if !dtype.is_equiv_to(&numpy::dtype::<f32>(obj.py())) {
46 return Err(PyErr::new::<PyTypeError, _>(format!(
47 "expected a numpy array with dtype float32, but got dtype {}",
48 dtype
49 )));
50 }
51 }
52 obj.cast::<numpy::PyArray1<f32>>()
53 .map_err(PyErr::from)
54 .and_then(|arr| arr.try_readonly().map_err(PyErr::from))
55}
56
57fn map_vad_error(err: vad::VadError) -> PyErr {
58 match err {
59 vad::VadError::UnsupportedSampleRate(rate) => PyErr::new::<PyValueError, _>(format!(
60 "Unsupported sample rate: {rate} Hz. Only 8000 and 16000 Hz are supported."
61 )),
62 vad::VadError::InvalidFrameLength { expected, got } => PyErr::new::<PyValueError, _>(
63 format!("Invalid frame length: expected {expected} samples, got {got}"),
64 ),
65 }
66}
67
68fn parse_mode(mode: i32) -> PyResult<vad::detector::VADModes> {
69 vad::detector::VADModes::from_index(mode).ok_or_else(|| {
70 PyErr::new::<PyValueError, _>(format!(
71 "Unsupported mode value: {mode}. Use fast_vad.mode.permissive, fast_vad.mode.normal, or fast_vad.mode.aggressive."
72 ))
73 })
74}
75
76fn parse_vad_config(
77 threshold_probability: f32,
78 min_speech_ms: usize,
79 min_silence_ms: usize,
80 hangover_ms: usize,
81) -> vad::detector::VadConfig {
82 vad::detector::VadConfig {
83 threshold_probability,
84 min_speech_ms,
85 min_silence_ms,
86 hangover_ms,
87 }
88}
89
90fn add_mode_namespace(m: &Bound<'_, PyModule>) -> PyResult<()> {
91 let mode_module = PyModule::new(m.py(), "mode")?;
92 mode_module.add("permissive", vad::detector::VADModes::Permissive.as_index())?;
93 mode_module.add("normal", vad::detector::VADModes::Normal.as_index())?;
94 mode_module.add("aggressive", vad::detector::VADModes::Aggressive.as_index())?;
95 m.add_submodule(&mode_module)?;
96 Ok(())
97}
98
99fn segments_to_array<'py>(py: Python<'py>, segments: Vec<[usize; 2]>) -> Bound<'py, PyArray2<u64>> {
100 let mut arr = Array2::<u64>::zeros((segments.len(), 2));
101 for (i, [start, end]) in segments.iter().enumerate() {
102 arr[[i, 0]] = *start as u64;
103 arr[[i, 1]] = *end as u64;
104 }
105 PyArray2::from_owned_array(py, arr)
106}
107
108#[pyclass]
110struct FeatureExtractor {
111 feature_extractor: vad::filterbank::FilterBank,
112 window_buf: Vec<f32>,
113 fft_output: Vec<Complex32>,
114 fft_scratch: Vec<Complex32>,
115}
116
117#[pymethods]
118impl FeatureExtractor {
119 #[new]
121 fn new(sample_rate: usize) -> PyResult<Self> {
122 let fe = vad::filterbank::FilterBank::new(sample_rate).map_err(map_vad_error)?;
123 let window_buf = vec![0.0f32; fe.frame_size()];
124 let fft_output = fe.make_output_vec();
125 let fft_scratch = fe.make_scratch_vec();
126 Ok(Self {
127 feature_extractor: fe,
128 window_buf,
129 fft_output,
130 fft_scratch,
131 })
132 }
133
134 #[getter]
136 fn frame_size(&self) -> usize {
137 self.feature_extractor.frame_size()
138 }
139
140 #[getter]
142 fn hann_window<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f32>> {
143 PyArray1::from_slice(py, self.feature_extractor.hann_window())
144 }
145
146 fn extract_features_frame<'py>(
152 &mut self,
153 py: Python<'py>,
154 frame: Bound<'py, PyAny>,
155 ) -> PyResult<Bound<'py, PyArray1<f32>>> {
156 let frame = as_f32_array1(&frame)?;
157 let frame = frame.as_slice()?;
158 let energies = py
159 .detach(|| {
160 self.feature_extractor.process_single_frame(
161 frame,
162 &mut self.window_buf,
163 &mut self.fft_output,
164 &mut self.fft_scratch,
165 )
166 })
167 .map_err(map_vad_error)?;
168 Ok(PyArray1::from_slice(py, &energies.to_array()))
169 }
170
171 fn extract_features<'py>(
176 &self,
177 py: Python<'py>,
178 audio: Bound<'py, PyAny>,
179 ) -> PyResult<Bound<'py, PyArray2<f32>>> {
180 let audio = as_f32_array1(&audio)?;
181 let audio = audio.as_slice()?;
182 let features = py.detach(|| self.feature_extractor.compute_filterbank(audio));
183 let num_frames = features.len();
184
185 let mut arr = Array2::<f32>::zeros((num_frames, vad::constants::NUM_BANDS));
186 for (i, frame) in features.iter().enumerate() {
187 arr.row_mut(i).assign(&ndarray::ArrayView1::from(
188 &frame.to_array() as &[f32; vad::constants::NUM_BANDS]
189 ));
190 }
191 Ok(PyArray2::from_owned_array(py, arr))
192 }
193
194 fn feature_engineer<'py>(
200 &self,
201 py: Python<'py>,
202 audio: Bound<'py, PyAny>,
203 ) -> PyResult<Bound<'py, PyArray2<f32>>> {
204 let audio = as_f32_array1(&audio)?;
205 let audio = audio.as_slice()?;
206 let features = py.detach(|| self.feature_extractor.feature_engineer(audio));
207 let num_frames = features.len();
208 let mut arr = Array2::<f32>::zeros((num_frames, 24));
209 for (i, frame) in features.iter().enumerate() {
210 arr.row_mut(i)
211 .assign(&ndarray::ArrayView1::from(frame as &[f32; 24]));
212 }
213 Ok(PyArray2::from_owned_array(py, arr))
214 }
215
216 fn __repr__(&self) -> String {
217 let sample_rate = match self.feature_extractor.frame_size() {
218 512 => 16000,
219 256 => 8000,
220 _ => 0,
221 };
222 format!("FeatureExtractor(sample_rate={})", sample_rate)
223 }
224
225 fn __str__(&self) -> String {
226 format!("{}", self.feature_extractor)
227 }
228}
229
230#[pyclass(name = "VAD")]
235struct PyVAD {
236 vad: vad::detector::VAD,
237}
238
239#[pymethods]
240impl PyVAD {
241 #[new]
246 fn new(sample_rate: usize) -> PyResult<Self> {
247 Ok(Self {
248 vad: vad::detector::VAD::new(sample_rate).map_err(map_vad_error)?,
249 })
250 }
251
252 #[classmethod]
254 fn with_mode(_cls: &Bound<'_, PyType>, sample_rate: usize, mode: i32) -> PyResult<Self> {
255 let mode = parse_mode(mode)?;
256 Ok(Self {
257 vad: vad::detector::VAD::with_mode(sample_rate, mode).map_err(map_vad_error)?,
258 })
259 }
260
261 #[classmethod]
263 fn with_config(
264 _cls: &Bound<'_, PyType>,
265 sample_rate: usize,
266 threshold_probability: f32,
267 min_speech_ms: usize,
268 min_silence_ms: usize,
269 hangover_ms: usize,
270 ) -> PyResult<Self> {
271 let config = parse_vad_config(
272 threshold_probability,
273 min_speech_ms,
274 min_silence_ms,
275 hangover_ms,
276 );
277 Ok(Self {
278 vad: vad::detector::VAD::with_config(sample_rate, config).map_err(map_vad_error)?,
279 })
280 }
281
282 fn detect<'py>(
284 &self,
285 py: Python<'py>,
286 audio: Bound<'py, PyAny>,
287 ) -> PyResult<Bound<'py, PyArray1<bool>>> {
288 let audio = as_f32_array1(&audio)?;
289 let audio = audio.as_slice()?;
290 let labels = py.detach(|| self.vad.detect(audio));
291 Ok(PyArray1::from_vec(py, labels))
292 }
293
294 fn detect_frames<'py>(
296 &self,
297 py: Python<'py>,
298 audio: Bound<'py, PyAny>,
299 ) -> PyResult<Bound<'py, PyArray1<bool>>> {
300 let audio = as_f32_array1(&audio)?;
301 let audio = audio.as_slice()?;
302 let labels = py.detach(|| self.vad.detect_frames(audio));
303 Ok(PyArray1::from_vec(py, labels))
304 }
305
306 fn detect_segments<'py>(
308 &self,
309 py: Python<'py>,
310 audio: Bound<'py, PyAny>,
311 ) -> PyResult<Bound<'py, PyArray2<u64>>> {
312 let audio = as_f32_array1(&audio)?;
313 let audio = audio.as_slice()?;
314 let segments = py.detach(|| self.vad.detect_segments(audio));
315 Ok(segments_to_array(py, segments))
316 }
317
318 fn __repr__(&self) -> String {
319 format!(
320 "VAD(sample_rate={}, threshold_probability={:.2}, min_speech_ms={}, min_silence_ms={}, hangover_ms={})",
321 self.vad.sample_rate(),
322 self.vad.threshold_probability(),
323 self.vad.min_speech_ms(),
324 self.vad.min_silence_ms(),
325 self.vad.hangover_ms(),
326 )
327 }
328
329 fn __str__(&self) -> String {
330 format!("{}", self.vad)
331 }
332}
333
334#[pyclass(name = "VadStateful")]
339struct PyVadStateful {
340 vad: Box<vad::detector::VadStateful>,
341}
342
343#[pymethods]
344impl PyVadStateful {
345 #[new]
350 fn new(sample_rate: usize) -> PyResult<Self> {
351 Ok(Self {
352 vad: Box::new(vad::detector::VadStateful::new(sample_rate).map_err(map_vad_error)?),
353 })
354 }
355
356 #[classmethod]
358 fn with_mode(_cls: &Bound<'_, PyType>, sample_rate: usize, mode: i32) -> PyResult<Self> {
359 let mode = parse_mode(mode)?;
360 Ok(Self {
361 vad: Box::new(
362 vad::detector::VadStateful::with_mode(sample_rate, mode).map_err(map_vad_error)?,
363 ),
364 })
365 }
366
367 #[classmethod]
369 fn with_config(
370 _cls: &Bound<'_, PyType>,
371 sample_rate: usize,
372 threshold_probability: f32,
373 min_speech_ms: usize,
374 min_silence_ms: usize,
375 hangover_ms: usize,
376 ) -> PyResult<Self> {
377 let config = parse_vad_config(
378 threshold_probability,
379 min_speech_ms,
380 min_silence_ms,
381 hangover_ms,
382 );
383 Ok(Self {
384 vad: Box::new(
385 vad::detector::VadStateful::with_config(sample_rate, config)
386 .map_err(map_vad_error)?,
387 ),
388 })
389 }
390
391 #[getter]
393 fn frame_size(&self) -> usize {
394 self.vad.frame_size()
395 }
396
397 fn detect_frame<'py>(&mut self, py: Python<'py>, frame: Bound<'py, PyAny>) -> PyResult<bool> {
401 let frame = as_f32_array1(&frame)?;
402 let frame = frame.as_slice()?;
403 py.detach(|| self.vad.detect_frame(frame))
404 .map_err(map_vad_error)
405 }
406
407 fn reset_state(&mut self) {
409 self.vad.reset_state();
410 }
411
412 fn __repr__(&self) -> String {
413 format!(
414 "VadStateful(sample_rate={}, threshold_probability={:.2}, min_speech_ms={}, min_silence_ms={}, hangover_ms={})",
415 self.vad.sample_rate(),
416 self.vad.threshold_probability(),
417 self.vad.min_speech_ms(),
418 self.vad.min_silence_ms(),
419 self.vad.hangover_ms(),
420 )
421 }
422
423 fn __str__(&self) -> String {
424 format!("{}", self.vad)
425 }
426}
427
428#[pymodule]
429fn fast_vad(m: &Bound<'_, PyModule>) -> PyResult<()> {
430 m.add("__version__", env!("CARGO_PKG_VERSION"))?;
431 m.add_class::<FeatureExtractor>()?;
432 m.add_class::<PyVAD>()?;
433 m.add_class::<PyVadStateful>()?;
434 add_mode_namespace(m)?;
435 Ok(())
436}