// audio_samples/operations/vad.rs

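//! Voice Activity Detection (VAD) operations.
//!
//! This module provides fast, frame-based VAD suitable for segmentation and
//! pre-processing. Per-frame measurements are computed in place without
//! allocating; buffers are only allocated for the frame mask, its
//! post-processing and the returned speech regions. Contiguous slice access is
//! preferred, with element-indexed fallbacks for some non-contiguous
//! multi-channel layouts.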
6
7use crate::operations::traits::AudioVoiceActivityDetection;
8use crate::operations::types::{VadChannelPolicy, VadConfig, VadMethod};
9use crate::{
10    AudioSample, AudioSampleError, AudioSampleResult, AudioSamples, FeatureError, LayoutError,
11    ParameterError, RealFloat, to_precision,
12};
13
/// Internal helper describing a frame iteration plan.
///
/// Frames are `frame_size` samples long and start every `hop_size` samples.
/// When `pad_end` is set, frames that run past the end of the signal are still
/// produced (the trailing frames are partial); otherwise only frames that fit
/// entirely inside the signal are produced.
#[derive(Debug, Clone, Copy)]
struct FramePlan {
    frame_size: usize,
    hop_size: usize,
    pad_end: bool,
}

impl FramePlan {
    /// Returns the start index of every frame for a signal of `total_len` samples.
    ///
    /// - With `pad_end`, every start `< total_len` is yielded.
    /// - Without `pad_end`, only starts with `start + frame_size <= total_len`
    ///   are yielded. The bound is evaluated without integer overflow.
    ///
    /// A `hop_size` of 0 would never advance. In that case at most one frame
    /// (starting at 0) is yielded instead of an endless sequence.
    fn frame_starts(self, total_len: usize) -> impl Iterator<Item = usize> {
        let FramePlan {
            frame_size,
            hop_size,
            pad_end,
        } = self;

        // `None` once the sequence is exhausted, which also keeps the iterator fused.
        let mut next_start = Some(0usize);
        std::iter::from_fn(move || {
            let start = next_start?;

            let in_range = if pad_end {
                start < total_len
            } else {
                // Compare against the remaining length so `start + frame_size`
                // can never overflow.
                start <= total_len && frame_size <= total_len - start
            };
            if !in_range {
                next_start = None;
                return None;
            }

            next_start = if hop_size == 0 {
                None
            } else {
                start.checked_add(hop_size)
            };
            Some(start)
        })
    }
}
46
47impl<'a, T: AudioSample> AudioVoiceActivityDetection<'a, T> for AudioSamples<'a, T> {
48    fn voice_activity_mask<F: RealFloat>(
49        &self,
50        config: &VadConfig<F>,
51    ) -> AudioSampleResult<Vec<bool>> {
52        config.validate()?;
53
54        let plan = FramePlan {
55            frame_size: config.frame_size,
56            hop_size: config.hop_size,
57            pad_end: config.pad_end,
58        };
59
60        let total_len = self.samples_per_channel();
61        if total_len == 0 {
62            return Ok(Vec::new());
63        }
64
65        let energy_threshold_db_f64: f64 = to_precision(config.energy_threshold_db);
66        let energy_threshold_rms: f64 = 10f64.powf(energy_threshold_db_f64 / 20.0);
67
68        // Raw decisions per frame.
69        let mut mask = Vec::new();
70
71        match self.as_mono() {
72            Some(mono) => {
73                let slice = mono.as_slice().ok_or(AudioSampleError::Layout(
74                    LayoutError::NonContiguous {
75                        operation: "voice activity detection".to_string(),
76                        layout_type: "non-contiguous mono array".to_string(),
77                    },
78                ))?;
79
80                for start in plan.frame_starts(total_len) {
81                    let decision = vad_decision_for_slice(
82                        slice,
83                        start,
84                        total_len,
85                        config,
86                        energy_threshold_rms,
87                    )?;
88                    mask.push(decision);
89                }
90            }
91            None => {
92                let multi = self.as_multi_channel().ok_or(AudioSampleError::Parameter(
93                    ParameterError::invalid_value("audio_data", "Audio must be multi-channel"),
94                ))?;
95
96                let multi_view = multi.as_view();
97                let channels = multi_view.nrows();
98                let samples = multi_view.ncols();
99
100                if samples != total_len {
101                    return Err(AudioSampleError::Layout(LayoutError::DimensionMismatch {
102                        expected_dims: format!("samples_per_channel={total_len}"),
103                        actual_dims: format!("ncols={samples}"),
104                        operation: "voice activity detection".to_string(),
105                    }));
106                }
107
108                for start in plan.frame_starts(total_len) {
109                    let decision = vad_decision_for_multi(
110                        multi_view,
111                        channels,
112                        samples,
113                        start,
114                        config,
115                        energy_threshold_rms,
116                    )?;
117                    mask.push(decision);
118                }
119            }
120        }
121
122        // Post-processing: smoothing + hangover + min region duration.
123        let mut mask = majority_smooth(mask, config.smooth_frames);
124        apply_hangover(&mut mask, config.hangover_frames);
125        remove_short_runs(&mut mask, true, config.min_speech_frames);
126        remove_short_runs(&mut mask, false, config.min_silence_frames);
127
128        Ok(mask)
129    }
130
131    fn speech_regions<F: RealFloat>(
132        &self,
133        config: &VadConfig<F>,
134    ) -> AudioSampleResult<Vec<(usize, usize)>> {
135        let mask = self.voice_activity_mask(config)?;
136        if mask.is_empty() {
137            return Ok(Vec::new());
138        }
139
140        let plan = FramePlan {
141            frame_size: config.frame_size,
142            hop_size: config.hop_size,
143            pad_end: config.pad_end,
144        };
145
146        let total_len = self.samples_per_channel();
147        let mut regions = Vec::new();
148
149        let mut in_region = false;
150        let mut region_start_sample = 0usize;
151        let mut last_frame_end_sample = 0usize;
152
153        for (frame_idx, start) in plan.frame_starts(total_len).enumerate() {
154            let is_speech = mask.get(frame_idx).copied().unwrap_or(false);
155
156            let frame_end = start.saturating_add(config.frame_size).min(total_len);
157            last_frame_end_sample = frame_end;
158
159            if is_speech {
160                if !in_region {
161                    in_region = true;
162                    region_start_sample = start;
163                }
164            } else if in_region {
165                regions.push((region_start_sample, frame_end));
166                in_region = false;
167            }
168        }
169
170        if in_region {
171            regions.push((region_start_sample, last_frame_end_sample));
172        }
173
174        Ok(merge_overlapping_regions(regions))
175    }
176}
177
178fn vad_decision_for_slice<T: AudioSample, F: RealFloat>(
179    samples: &[T],
180    frame_start: usize,
181    total_len: usize,
182    config: &VadConfig<F>,
183    energy_threshold_rms: f64,
184) -> AudioSampleResult<bool> {
185    let (rms, zcr) = frame_rms_and_zcr(
186        samples,
187        frame_start,
188        total_len,
189        config.frame_size,
190        config.pad_end,
191    );
192    classify(config, rms, zcr, energy_threshold_rms)
193}
194
195fn vad_decision_for_multi<T: AudioSample, F: RealFloat>(
196    multi: ndarray::ArrayView2<'_, T>,
197    channels: usize,
198    samples: usize,
199    frame_start: usize,
200    config: &VadConfig<F>,
201    energy_threshold_rms: f64,
202) -> AudioSampleResult<bool> {
203    let plan_frame_size = config.frame_size;
204
205    match &config.channel_policy {
206        VadChannelPolicy::Channel(ch) => {
207            if *ch >= channels {
208                return Err(AudioSampleError::Parameter(ParameterError::out_of_range(
209                    "channel_policy",
210                    ch.to_string(),
211                    "0".to_string(),
212                    (channels.saturating_sub(1)).to_string(),
213                    "channel index out of range".to_string(),
214                )));
215            }
216            let row = multi.row(*ch);
217            let slice =
218                row.as_slice()
219                    .ok_or(AudioSampleError::Layout(LayoutError::NonContiguous {
220                        operation: "voice activity detection".to_string(),
221                        layout_type: "non-contiguous channel row".to_string(),
222                    }))?;
223            let (rms, zcr) =
224                frame_rms_and_zcr(slice, frame_start, samples, plan_frame_size, config.pad_end);
225            classify(config, rms, zcr, energy_threshold_rms)
226        }
227        VadChannelPolicy::AverageToMono => {
228            // Fast path: contiguous ArrayView2.
229            if let Some(flat) = multi.as_slice() {
230                let (rms, zcr) = frame_rms_and_zcr_avg(
231                    flat,
232                    channels,
233                    samples,
234                    frame_start,
235                    plan_frame_size,
236                    config.pad_end,
237                );
238                classify(config, rms, zcr, energy_threshold_rms)
239            } else {
240                let (rms, zcr) = frame_rms_and_zcr_avg_indexed(
241                    &multi,
242                    channels,
243                    samples,
244                    frame_start,
245                    plan_frame_size,
246                    config.pad_end,
247                );
248                classify(config, rms, zcr, energy_threshold_rms)
249            }
250        }
251        VadChannelPolicy::AnyChannel | VadChannelPolicy::AllChannels => {
252            let want_any = matches!(config.channel_policy, VadChannelPolicy::AnyChannel);
253            let mut any_active = false;
254            let mut all_active = true;
255
256            for ch in 0..channels {
257                let row = multi.row(ch);
258                let (rms, zcr) = if let Some(slice) = row.as_slice() {
259                    frame_rms_and_zcr(slice, frame_start, samples, plan_frame_size, config.pad_end)
260                } else {
261                    frame_rms_and_zcr_indexed(
262                        &multi,
263                        ch,
264                        samples,
265                        frame_start,
266                        plan_frame_size,
267                        config.pad_end,
268                    )
269                };
270
271                let active = classify(config, rms, zcr, energy_threshold_rms)?;
272                any_active |= active;
273                all_active &= active;
274
275                if want_any && any_active {
276                    return Ok(true);
277                }
278                if !want_any && !all_active {
279                    return Ok(false);
280                }
281            }
282
283            Ok(if want_any { any_active } else { all_active })
284        }
285    }
286}
287
288fn classify<F: RealFloat>(
289    config: &VadConfig<F>,
290    rms: f64,
291    zcr: f64,
292    energy_threshold_rms: f64,
293) -> AudioSampleResult<bool> {
294    match config.method {
295        VadMethod::Energy => Ok(rms >= energy_threshold_rms),
296        VadMethod::ZeroCrossing => Ok(zcr >= to_precision::<f64, _>(config.zcr_min)
297            && zcr <= to_precision::<f64, _>(config.zcr_max)),
298        VadMethod::Combined => {
299            let zcr_ok = zcr >= to_precision::<f64, _>(config.zcr_min)
300                && zcr <= to_precision::<f64, _>(config.zcr_max);
301            Ok(rms >= energy_threshold_rms && zcr_ok)
302        }
303        VadMethod::Spectral => Err(AudioSampleError::Feature(FeatureError::not_enabled(
304            "spectral-analysis",
305            "VAD spectral mode",
306        ))),
307    }
308}
309
310fn frame_rms_and_zcr<T: AudioSample>(
311    samples: &[T],
312    frame_start: usize,
313    total_len: usize,
314    frame_size: usize,
315    pad_end: bool,
316) -> (f64, f64) {
317    let available = total_len.saturating_sub(frame_start);
318    if available == 0 || frame_size == 0 {
319        return (0.0, 0.0);
320    }
321
322    let frame_len = available.min(frame_size);
323    let denom_len = if pad_end { frame_size } else { frame_len };
324
325    let mut sum_sq = 0.0f64;
326    let mut zc = 0usize;
327
328    let mut prev_sign = 0i8;
329
330    for i in 0..frame_len {
331        let x: f64 = samples[frame_start + i].convert_to();
332        sum_sq += x * x;
333
334        let sign = if x > 0.0 {
335            1
336        } else if x < 0.0 {
337            -1
338        } else {
339            0
340        };
341        if i > 0 && sign != 0 && prev_sign != 0 && sign != prev_sign {
342            zc += 1;
343        }
344        if sign != 0 {
345            prev_sign = sign;
346        }
347    }
348
349    let rms = if denom_len == 0 {
350        0.0
351    } else {
352        (sum_sq / denom_len as f64).sqrt()
353    };
354
355    let zcr_denom = (frame_size.saturating_sub(1)).max(1) as f64;
356    let zcr = zc as f64 / zcr_denom;
357
358    (rms, zcr)
359}
360
361fn frame_rms_and_zcr_avg<T: AudioSample>(
362    flat: &[T],
363    channels: usize,
364    samples: usize,
365    frame_start: usize,
366    frame_size: usize,
367    pad_end: bool,
368) -> (f64, f64) {
369    let available = samples.saturating_sub(frame_start);
370    if available == 0 || frame_size == 0 || channels == 0 {
371        return (0.0, 0.0);
372    }
373
374    let frame_len = available.min(frame_size);
375    let denom_len = if pad_end { frame_size } else { frame_len };
376
377    let mut sum_sq = 0.0f64;
378    let mut zc = 0usize;
379    let mut prev_sign = 0i8;
380
381    for i in 0..frame_len {
382        let col = frame_start + i;
383        let mut acc = 0.0f64;
384        for ch in 0..channels {
385            let idx = ch * samples + col;
386            let v: f64 = flat[idx].convert_to();
387            acc += v;
388        }
389        let x = acc / channels as f64;
390        sum_sq += x * x;
391
392        let sign = if x > 0.0 {
393            1
394        } else if x < 0.0 {
395            -1
396        } else {
397            0
398        };
399        if i > 0 && sign != 0 && prev_sign != 0 && sign != prev_sign {
400            zc += 1;
401        }
402        if sign != 0 {
403            prev_sign = sign;
404        }
405    }
406
407    let rms = if denom_len == 0 {
408        0.0
409    } else {
410        (sum_sq / denom_len as f64).sqrt()
411    };
412
413    let zcr_denom = (frame_size.saturating_sub(1)).max(1) as f64;
414    let zcr = zc as f64 / zcr_denom;
415
416    (rms, zcr)
417}
418
419fn frame_rms_and_zcr_avg_indexed<T: AudioSample>(
420    multi: &ndarray::ArrayView2<'_, T>,
421    channels: usize,
422    samples: usize,
423    frame_start: usize,
424    frame_size: usize,
425    pad_end: bool,
426) -> (f64, f64) {
427    let available = samples.saturating_sub(frame_start);
428    if available == 0 || frame_size == 0 || channels == 0 {
429        return (0.0, 0.0);
430    }
431
432    let frame_len = available.min(frame_size);
433    let denom_len = if pad_end { frame_size } else { frame_len };
434
435    let mut sum_sq = 0.0f64;
436    let mut zc = 0usize;
437    let mut prev_sign = 0i8;
438
439    for i in 0..frame_len {
440        let col = frame_start + i;
441        let mut acc = 0.0f64;
442        for ch in 0..channels {
443            let v: f64 = multi[(ch, col)].convert_to();
444            acc += v;
445        }
446        let x = acc / channels as f64;
447        sum_sq += x * x;
448
449        let sign = if x > 0.0 {
450            1
451        } else if x < 0.0 {
452            -1
453        } else {
454            0
455        };
456        if i > 0 && sign != 0 && prev_sign != 0 && sign != prev_sign {
457            zc += 1;
458        }
459        if sign != 0 {
460            prev_sign = sign;
461        }
462    }
463
464    let rms = if denom_len == 0 {
465        0.0
466    } else {
467        (sum_sq / denom_len as f64).sqrt()
468    };
469
470    let zcr_denom = (frame_size.saturating_sub(1)).max(1) as f64;
471    let zcr = zc as f64 / zcr_denom;
472
473    (rms, zcr)
474}
475
476fn frame_rms_and_zcr_indexed<T: AudioSample>(
477    multi: &ndarray::ArrayView2<'_, T>,
478    ch: usize,
479    samples: usize,
480    frame_start: usize,
481    frame_size: usize,
482    pad_end: bool,
483) -> (f64, f64) {
484    let available = samples.saturating_sub(frame_start);
485    if available == 0 || frame_size == 0 {
486        return (0.0, 0.0);
487    }
488
489    let frame_len = available.min(frame_size);
490    let denom_len = if pad_end { frame_size } else { frame_len };
491
492    let mut sum_sq = 0.0f64;
493    let mut zc = 0usize;
494    let mut prev_sign = 0i8;
495
496    for i in 0..frame_len {
497        let col = frame_start + i;
498        let x: f64 = multi[(ch, col)].convert_to();
499        sum_sq += x * x;
500
501        let sign = if x > 0.0 {
502            1
503        } else if x < 0.0 {
504            -1
505        } else {
506            0
507        };
508        if i > 0 && sign != 0 && prev_sign != 0 && sign != prev_sign {
509            zc += 1;
510        }
511        if sign != 0 {
512            prev_sign = sign;
513        }
514    }
515
516    let rms = if denom_len == 0 {
517        0.0
518    } else {
519        (sum_sq / denom_len as f64).sqrt()
520    };
521
522    let zcr_denom = (frame_size.saturating_sub(1)).max(1) as f64;
523    let zcr = zc as f64 / zcr_denom;
524
525    (rms, zcr)
526}
527
/// Majority-vote smoothing of a boolean frame mask over a centred window.
///
/// Output frame `i` is `true` when at least half of the input frames in the
/// window around `i` are `true`; ties resolve to `true`. The window covers
/// `window / 2` frames before `i` and `window - 1 - window / 2` frames after
/// it, truncated at both ends of the mask.
///
/// The window is centred so that speech on/offsets stay in place. A trailing
/// window would delay every transition by about `(window - 1) / 2` frames.
///
/// A `window` of 0 or 1, or an empty mask, is returned unchanged.
fn majority_smooth(mask: Vec<bool>, window: usize) -> Vec<bool> {
    if window <= 1 || mask.is_empty() {
        return mask;
    }

    let len = mask.len();
    let before = window / 2;
    let after = window - 1 - before;

    // prefix[k] = number of `true` frames in mask[..k], for O(1) window counts.
    let mut prefix = Vec::with_capacity(len + 1);
    prefix.push(0usize);
    let mut running = 0usize;
    for &active in &mask {
        running += usize::from(active);
        prefix.push(running);
    }

    (0..len)
        .map(|i| {
            let lo = i.saturating_sub(before);
            let hi = i.saturating_add(after).saturating_add(1).min(len);
            let active = prefix[hi] - prefix[lo];
            2 * active >= hi - lo
        })
        .collect()
}
555
/// Extends every speech frame by up to `hangover_frames` following frames.
///
/// After each `true` frame, the next `hangover_frames` `false` frames are set
/// to `true`. Frames filled in this way do not restart the countdown; only
/// frames that were already `true` do.
fn apply_hangover(mask: &mut [bool], hangover_frames: usize) {
    if hangover_frames == 0 || mask.is_empty() {
        return;
    }

    let mut remaining = 0usize;
    for frame in mask.iter_mut() {
        match (*frame, remaining) {
            (true, _) => remaining = hangover_frames,
            (false, 0) => {}
            (false, _) => {
                *frame = true;
                remaining -= 1;
            }
        }
    }
}
571
/// Flips every maximal run of `value` shorter than `min_len` frames to `!value`.
///
/// Runs that touch either end of the mask are treated like interior runs.
/// A `min_len` of 0 leaves the mask unchanged.
fn remove_short_runs(mask: &mut [bool], value: bool, min_len: usize) {
    if min_len == 0 || mask.is_empty() {
        return;
    }

    let mut cursor = 0usize;
    while let Some(offset) = mask[cursor..].iter().position(|&v| v == value) {
        let run_start = cursor + offset;
        let run_len = mask[run_start..]
            .iter()
            .take_while(|&&v| v == value)
            .count();
        let run_end = run_start + run_len;

        if run_len < min_len {
            mask[run_start..run_end].fill(!value);
        }
        cursor = run_end;
    }
}
597
/// Sorts regions by start and merges any that overlap or touch.
///
/// Two regions are merged when the next one starts at or before the current
/// region's end; the merged end is the larger of the two ends. The sort is
/// stable.
fn merge_overlapping_regions(mut regions: Vec<(usize, usize)>) -> Vec<(usize, usize)> {
    if regions.len() <= 1 {
        return regions;
    }

    regions.sort_by(|a, b| a.0.cmp(&b.0));

    let mut merged: Vec<(usize, usize)> = Vec::with_capacity(regions.len());
    for (start, end) in regions {
        match merged.last_mut() {
            Some(last) if start <= last.1 => last.1 = last.1.max(end),
            _ => merged.push((start, end)),
        }
    }
    merged
}
620
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// All-zero mono audio must never be classified as speech.
    #[test]
    fn vad_silence_is_inactive() {
        let audio = AudioSamples::new_mono(Array1::<f32>::zeros(4096), crate::sample_rate!(44100));
        let cfg = VadConfig::<f32> {
            energy_threshold_db: to_precision(-50.0),
            ..VadConfig::energy_only()
        };

        let mask = audio.voice_activity_mask(&cfg).unwrap();
        assert!(!mask.is_empty());
        assert!(mask.iter().all(|&v| !v));
    }

    /// A half-amplitude 440 Hz tone lies well above a -35 dB energy threshold.
    #[test]
    fn vad_tone_is_active() {
        let n = 4096;
        let sr = 44_100.0_f32;
        let freq = 440.0_f32;
        let data: Vec<f32> = (0..n)
            .map(|i| {
                let t = i as f32 / sr;
                (2.0 * std::f32::consts::PI * freq * t).sin() * 0.5
            })
            .collect();
        let audio = AudioSamples::new_mono(Array1::from(data), crate::sample_rate!(44100));

        let cfg = VadConfig::<f32> {
            energy_threshold_db: to_precision(-35.0),
            ..VadConfig::energy_only()
        };

        let mask = audio.voice_activity_mask(&cfg).unwrap();
        assert!(!mask.is_empty());
        assert!(mask.iter().any(|&v| v));
    }

    /// Regions are non-empty ranges inside the signal and do not overlap.
    #[test]
    fn speech_regions_are_non_overlapping() {
        let total = 4096;
        let mut data = vec![0.0f32; total];
        data[total / 2..].fill(0.2);
        let audio = AudioSamples::new_mono(Array1::from(data), crate::sample_rate!(44100));

        let cfg = VadConfig::<f32> {
            frame_size: 512,
            hop_size: 256,
            energy_threshold_db: to_precision(-45.0),
            ..VadConfig::energy_only()
        };

        let regions = audio.speech_regions(&cfg).unwrap();
        for &(start, end) in &regions {
            assert!(start < end && end <= total);
        }
        for w in regions.windows(2) {
            assert!(w[0].1 <= w[1].0);
        }
    }

    /// Post-processing helpers: hangover extension, short-run removal, merging.
    #[test]
    fn post_processing_helpers() {
        let mut mask = vec![true, false, false, false, true, false];
        apply_hangover(&mut mask, 2);
        assert_eq!(mask, vec![true, true, true, false, true, true]);

        let mut mask = vec![true, false, true, true, false, false, false, true];
        remove_short_runs(&mut mask, true, 2);
        assert_eq!(mask, vec![false, false, true, true, false, false, false, false]);

        assert_eq!(
            merge_overlapping_regions(vec![(10, 12), (0, 5), (4, 8), (8, 9)]),
            vec![(0, 9), (10, 12)]
        );
    }
}