// fast_vad/src/lib.rs — Python bindings (PyO3) for the fast_vad crate.
1#![warn(missing_docs)]
2//! Fast voice activity detection for 8 kHz and 16 kHz mono audio.
3//!
4//! This crate provides three main entry points:
5//! - [`VAD`] for batch detection over a full buffer.
6//! - [`VadStateful`] for streaming detection one frame at a time.
7//! - [`FilterBank`] for direct access to the underlying 8-band log-energy features.
8//!
9//! The detector operates on fixed 32 ms frames:
10//! - 512 samples at 16 kHz
11//! - 256 samples at 8 kHz
12//!
13//! # Example
14//!
15//! ```rust
16//! use fast_vad::{VAD, VADModes};
17//!
18//! let audio = vec![0.0f32; 16_000];
19//! let vad = VAD::with_mode(16_000, VADModes::Normal)?;
20//! let sample_labels = vad.detect(&audio);
21//! assert_eq!(sample_labels.len(), audio.len());
22//! # Ok::<(), fast_vad::VadError>(())
23//! ```
24
25use ndarray::Array2;
26use numpy::{
27    PyArray1, PyArray2, PyArrayDescrMethods, PyArrayMethods, PyReadonlyArray1,
28    PyUntypedArrayMethods,
29};
30use pyo3::exceptions::{PyTypeError, PyValueError};
31use pyo3::prelude::*;
32use pyo3::types::{PyModule, PyType};
33use realfft::num_complex::Complex32;
34
35/// Core Rust implementation modules for detection and feature extraction.
36pub mod vad;
37
38pub use vad::VadError;
39pub use vad::detector::{VAD, VADModes, VadConfig, VadStateful};
40pub use vad::filterbank::FilterBank;
41
42fn as_f32_array1<'py>(obj: &Bound<'py, PyAny>) -> PyResult<PyReadonlyArray1<'py, f32>> {
43    if let Ok(untyped) = obj.cast::<numpy::PyUntypedArray>() {
44        let dtype = untyped.dtype();
45        if !dtype.is_equiv_to(&numpy::dtype::<f32>(obj.py())) {
46            return Err(PyErr::new::<PyTypeError, _>(format!(
47                "expected a numpy array with dtype float32, but got dtype {}",
48                dtype
49            )));
50        }
51    }
52    obj.cast::<numpy::PyArray1<f32>>()
53        .map_err(PyErr::from)
54        .and_then(|arr| arr.try_readonly().map_err(PyErr::from))
55}
56
57fn map_vad_error(err: vad::VadError) -> PyErr {
58    match err {
59        vad::VadError::UnsupportedSampleRate(rate) => PyErr::new::<PyValueError, _>(format!(
60            "Unsupported sample rate: {rate} Hz. Only 8000 and 16000 Hz are supported."
61        )),
62        vad::VadError::InvalidFrameLength { expected, got } => PyErr::new::<PyValueError, _>(
63            format!("Invalid frame length: expected {expected} samples, got {got}"),
64        ),
65    }
66}
67
68fn parse_mode(mode: i32) -> PyResult<vad::detector::VADModes> {
69    vad::detector::VADModes::from_index(mode).ok_or_else(|| {
70        PyErr::new::<PyValueError, _>(format!(
71            "Unsupported mode value: {mode}. Use fast_vad.mode.permissive, fast_vad.mode.normal, or fast_vad.mode.aggressive."
72        ))
73    })
74}
75
76fn parse_vad_config(
77    threshold_probability: f32,
78    min_speech_ms: usize,
79    min_silence_ms: usize,
80    hangover_ms: usize,
81) -> vad::detector::VadConfig {
82    vad::detector::VadConfig {
83        threshold_probability,
84        min_speech_ms,
85        min_silence_ms,
86        hangover_ms,
87    }
88}
89
90fn add_mode_namespace(m: &Bound<'_, PyModule>) -> PyResult<()> {
91    let mode_module = PyModule::new(m.py(), "mode")?;
92    mode_module.add("permissive", vad::detector::VADModes::Permissive.as_index())?;
93    mode_module.add("normal", vad::detector::VADModes::Normal.as_index())?;
94    mode_module.add("aggressive", vad::detector::VADModes::Aggressive.as_index())?;
95    m.add_submodule(&mode_module)?;
96    Ok(())
97}
98
99fn segments_to_array<'py>(py: Python<'py>, segments: Vec<[usize; 2]>) -> Bound<'py, PyArray2<u64>> {
100    let mut arr = Array2::<u64>::zeros((segments.len(), 2));
101    for (i, [start, end]) in segments.iter().enumerate() {
102        arr[[i, 0]] = *start as u64;
103        arr[[i, 1]] = *end as u64;
104    }
105    PyArray2::from_owned_array(py, arr)
106}
107
/// Computes log-filterbank features from raw audio.
#[pyclass]
struct FeatureExtractor {
    // Underlying Rust filterbank that performs the actual analysis.
    feature_extractor: vad::filterbank::FilterBank,
    // Scratch buffers reused across calls so per-frame extraction does not allocate:
    // windowed samples, FFT output bins, and the FFT's internal scratch space.
    window_buf: Vec<f32>,
    fft_output: Vec<Complex32>,
    fft_scratch: Vec<Complex32>,
}
116
117#[pymethods]
118impl FeatureExtractor {
119    /// Creates a `FeatureExtractor` for the given sample rate (8000 or 16000 Hz).
120    #[new]
121    fn new(sample_rate: usize) -> PyResult<Self> {
122        let fe = vad::filterbank::FilterBank::new(sample_rate).map_err(map_vad_error)?;
123        let window_buf = vec![0.0f32; fe.frame_size()];
124        let fft_output = fe.make_output_vec();
125        let fft_scratch = fe.make_scratch_vec();
126        Ok(Self {
127            feature_extractor: fe,
128            window_buf,
129            fft_output,
130            fft_scratch,
131        })
132    }
133
134    /// Number of samples per analysis frame.
135    #[getter]
136    fn frame_size(&self) -> usize {
137        self.feature_extractor.frame_size()
138    }
139
140    /// Hann window applied to each frame before FFT.
141    #[getter]
142    fn hann_window<'py>(&self, py: Python<'py>) -> Bound<'py, PyArray1<f32>> {
143        PyArray1::from_slice(py, self.feature_extractor.hann_window())
144    }
145
146    /// Extracts filterbank features from a single frame of exactly `frame_size` samples.
147    ///
148    /// Returns a float32 array of shape `(8,)`.
149    ///
150    /// Raises `ValueError` if `len(frame) != frame_size`.
151    fn extract_features_frame<'py>(
152        &mut self,
153        py: Python<'py>,
154        frame: Bound<'py, PyAny>,
155    ) -> PyResult<Bound<'py, PyArray1<f32>>> {
156        let frame = as_f32_array1(&frame)?;
157        let frame = frame.as_slice()?;
158        let energies = py
159            .detach(|| {
160                self.feature_extractor.process_single_frame(
161                    frame,
162                    &mut self.window_buf,
163                    &mut self.fft_output,
164                    &mut self.fft_scratch,
165                )
166            })
167            .map_err(map_vad_error)?;
168        Ok(PyArray1::from_slice(py, &energies.to_array()))
169    }
170
171    /// Extracts filterbank features from `audio`.
172    ///
173    /// Returns a `(num_frames, 8)` float32 array. Trailing samples that do not
174    /// fill a complete frame are discarded.
175    fn extract_features<'py>(
176        &self,
177        py: Python<'py>,
178        audio: Bound<'py, PyAny>,
179    ) -> PyResult<Bound<'py, PyArray2<f32>>> {
180        let audio = as_f32_array1(&audio)?;
181        let audio = audio.as_slice()?;
182        let features = py.detach(|| self.feature_extractor.compute_filterbank(audio));
183        let num_frames = features.len();
184
185        let mut arr = Array2::<f32>::zeros((num_frames, vad::constants::NUM_BANDS));
186        for (i, frame) in features.iter().enumerate() {
187            arr.row_mut(i).assign(&ndarray::ArrayView1::from(
188                &frame.to_array() as &[f32; vad::constants::NUM_BANDS]
189            ));
190        }
191        Ok(PyArray2::from_owned_array(py, arr))
192    }
193
194    /// Computes 24-dimensional features for each frame in `audio`.
195    ///
196    /// Each row contains 8 log-energy values, 8 first-order deltas, and 8 second-order deltas.
197    /// Returns a `(num_frames, 24)` float32 array. Trailing samples that do not
198    /// fill a complete frame are discarded.
199    fn feature_engineer<'py>(
200        &self,
201        py: Python<'py>,
202        audio: Bound<'py, PyAny>,
203    ) -> PyResult<Bound<'py, PyArray2<f32>>> {
204        let audio = as_f32_array1(&audio)?;
205        let audio = audio.as_slice()?;
206        let features = py.detach(|| self.feature_extractor.feature_engineer(audio));
207        let num_frames = features.len();
208        let mut arr = Array2::<f32>::zeros((num_frames, 24));
209        for (i, frame) in features.iter().enumerate() {
210            arr.row_mut(i)
211                .assign(&ndarray::ArrayView1::from(frame as &[f32; 24]));
212        }
213        Ok(PyArray2::from_owned_array(py, arr))
214    }
215
216    fn __repr__(&self) -> String {
217        let sample_rate = match self.feature_extractor.frame_size() {
218            512 => 16000,
219            256 => 8000,
220            _ => 0,
221        };
222        format!("FeatureExtractor(sample_rate={})", sample_rate)
223    }
224
225    fn __str__(&self) -> String {
226        format!("{}", self.feature_extractor)
227    }
228}
229
/// Batch voice activity detector.
///
/// Config is fixed at construction time.
/// Use [`VAD.with_mode`] or [`VAD.with_config`] to control detection behaviour.
#[pyclass(name = "VAD")]
struct PyVAD {
    // Wrapped Rust detector holding the sample rate and tuning parameters.
    vad: vad::detector::VAD,
}
238
239#[pymethods]
240impl PyVAD {
241    /// Creates a `VAD` with the default Normal mode.
242    ///
243    /// Args:
244    ///     sample_rate: Audio sample rate in Hz. Supported values: 8000, 16000.
245    #[new]
246    fn new(sample_rate: usize) -> PyResult<Self> {
247        Ok(Self {
248            vad: vad::detector::VAD::new(sample_rate).map_err(map_vad_error)?,
249        })
250    }
251
252    /// Creates a `VAD` with an explicit detection mode.
253    #[classmethod]
254    fn with_mode(_cls: &Bound<'_, PyType>, sample_rate: usize, mode: i32) -> PyResult<Self> {
255        let mode = parse_mode(mode)?;
256        Ok(Self {
257            vad: vad::detector::VAD::with_mode(sample_rate, mode).map_err(map_vad_error)?,
258        })
259    }
260
261    /// Creates a `VAD` with custom detection parameters.
262    #[classmethod]
263    fn with_config(
264        _cls: &Bound<'_, PyType>,
265        sample_rate: usize,
266        threshold_probability: f32,
267        min_speech_ms: usize,
268        min_silence_ms: usize,
269        hangover_ms: usize,
270    ) -> PyResult<Self> {
271        let config = parse_vad_config(
272            threshold_probability,
273            min_speech_ms,
274            min_silence_ms,
275            hangover_ms,
276        );
277        Ok(Self {
278            vad: vad::detector::VAD::with_config(sample_rate, config).map_err(map_vad_error)?,
279        })
280    }
281
282    /// Returns one `bool` per sample indicating speech presence.
283    fn detect<'py>(
284        &self,
285        py: Python<'py>,
286        audio: Bound<'py, PyAny>,
287    ) -> PyResult<Bound<'py, PyArray1<bool>>> {
288        let audio = as_f32_array1(&audio)?;
289        let audio = audio.as_slice()?;
290        let labels = py.detach(|| self.vad.detect(audio));
291        Ok(PyArray1::from_vec(py, labels))
292    }
293
294    /// Returns one `bool` per frame indicating speech presence.
295    fn detect_frames<'py>(
296        &self,
297        py: Python<'py>,
298        audio: Bound<'py, PyAny>,
299    ) -> PyResult<Bound<'py, PyArray1<bool>>> {
300        let audio = as_f32_array1(&audio)?;
301        let audio = audio.as_slice()?;
302        let labels = py.detach(|| self.vad.detect_frames(audio));
303        Ok(PyArray1::from_vec(py, labels))
304    }
305
306    /// Returns a `(N, 2)` uint64 array of `[start, end]` sample indices for each speech segment.
307    fn detect_segments<'py>(
308        &self,
309        py: Python<'py>,
310        audio: Bound<'py, PyAny>,
311    ) -> PyResult<Bound<'py, PyArray2<u64>>> {
312        let audio = as_f32_array1(&audio)?;
313        let audio = audio.as_slice()?;
314        let segments = py.detach(|| self.vad.detect_segments(audio));
315        Ok(segments_to_array(py, segments))
316    }
317
318    fn __repr__(&self) -> String {
319        format!(
320            "VAD(sample_rate={}, threshold_probability={:.2}, min_speech_ms={}, min_silence_ms={}, hangover_ms={})",
321            self.vad.sample_rate(),
322            self.vad.threshold_probability(),
323            self.vad.min_speech_ms(),
324            self.vad.min_silence_ms(),
325            self.vad.hangover_ms(),
326        )
327    }
328
329    fn __str__(&self) -> String {
330        format!("{}", self.vad)
331    }
332}
333
/// Streaming voice activity detector that processes one frame at a time.
///
/// Config is fixed at construction time.
/// Use [`VadStateful.with_mode`] or [`VadStateful.with_config`] to control detection behaviour.
#[pyclass(name = "VadStateful")]
struct PyVadStateful {
    // NOTE(review): boxed presumably to keep the pyclass allocation small if
    // VadStateful carries sizable internal state — confirm against its definition.
    vad: Box<vad::detector::VadStateful>,
}
342
343#[pymethods]
344impl PyVadStateful {
345    /// Creates a `VadStateful` with the default Normal mode.
346    ///
347    /// Args:
348    ///     sample_rate: Audio sample rate in Hz. Supported values: 8000, 16000.
349    #[new]
350    fn new(sample_rate: usize) -> PyResult<Self> {
351        Ok(Self {
352            vad: Box::new(vad::detector::VadStateful::new(sample_rate).map_err(map_vad_error)?),
353        })
354    }
355
356    /// Creates a `VadStateful` with an explicit detection mode.
357    #[classmethod]
358    fn with_mode(_cls: &Bound<'_, PyType>, sample_rate: usize, mode: i32) -> PyResult<Self> {
359        let mode = parse_mode(mode)?;
360        Ok(Self {
361            vad: Box::new(
362                vad::detector::VadStateful::with_mode(sample_rate, mode).map_err(map_vad_error)?,
363            ),
364        })
365    }
366
367    /// Creates a `VadStateful` with custom detection parameters.
368    #[classmethod]
369    fn with_config(
370        _cls: &Bound<'_, PyType>,
371        sample_rate: usize,
372        threshold_probability: f32,
373        min_speech_ms: usize,
374        min_silence_ms: usize,
375        hangover_ms: usize,
376    ) -> PyResult<Self> {
377        let config = parse_vad_config(
378            threshold_probability,
379            min_speech_ms,
380            min_silence_ms,
381            hangover_ms,
382        );
383        Ok(Self {
384            vad: Box::new(
385                vad::detector::VadStateful::with_config(sample_rate, config)
386                    .map_err(map_vad_error)?,
387            ),
388        })
389    }
390
391    /// Number of samples per frame expected by `detect_frame`.
392    #[getter]
393    fn frame_size(&self) -> usize {
394        self.vad.frame_size()
395    }
396
397    /// Processes one frame and returns whether speech is active.
398    ///
399    /// `frame` must contain exactly `frame_size` samples.
400    fn detect_frame<'py>(&mut self, py: Python<'py>, frame: Bound<'py, PyAny>) -> PyResult<bool> {
401        let frame = as_f32_array1(&frame)?;
402        let frame = frame.as_slice()?;
403        py.detach(|| self.vad.detect_frame(frame))
404            .map_err(map_vad_error)
405    }
406
407    /// Resets internal state so the detector can be reused for a new stream.
408    fn reset_state(&mut self) {
409        self.vad.reset_state();
410    }
411
412    fn __repr__(&self) -> String {
413        format!(
414            "VadStateful(sample_rate={}, threshold_probability={:.2}, min_speech_ms={}, min_silence_ms={}, hangover_ms={})",
415            self.vad.sample_rate(),
416            self.vad.threshold_probability(),
417            self.vad.min_speech_ms(),
418            self.vad.min_silence_ms(),
419            self.vad.hangover_ms(),
420        )
421    }
422
423    fn __str__(&self) -> String {
424        format!("{}", self.vad)
425    }
426}
427
428#[pymodule]
429fn fast_vad(m: &Bound<'_, PyModule>) -> PyResult<()> {
430    m.add("__version__", env!("CARGO_PKG_VERSION"))?;
431    m.add_class::<FeatureExtractor>()?;
432    m.add_class::<PyVAD>()?;
433    m.add_class::<PyVadStateful>()?;
434    add_mode_namespace(m)?;
435    Ok(())
436}