1use crate::operations::traits::AudioVoiceActivityDetection;
8use crate::operations::types::{VadChannelPolicy, VadConfig, VadMethod};
9use crate::{
10 AudioSample, AudioSampleError, AudioSampleResult, AudioSamples, FeatureError, LayoutError,
11 ParameterError, RealFloat, to_precision,
12};
13
/// Framing parameters used to walk a signal frame by frame.
#[derive(Debug, Clone, Copy)]
struct FramePlan {
    /// Number of samples covered by one analysis frame.
    frame_size: usize,
    /// Distance in samples between the starts of consecutive frames.
    hop_size: usize,
    /// When `true`, frames that extend past the end of the signal are still
    /// produced; when `false`, only frames that fit entirely are produced.
    pad_end: bool,
}

impl FramePlan {
    /// Returns an iterator over the start index of every frame for a signal of
    /// `total_len` samples.
    ///
    /// * With `pad_end == true` a start is yielded while it lies inside the
    ///   signal (`start < total_len`).
    /// * With `pad_end == false` a start is yielded only while the whole frame
    ///   fits (`start + frame_size <= total_len`).
    ///
    /// All index arithmetic is overflow-checked: if `start + frame_size` or
    /// `start + hop_size` cannot be represented in `usize`, iteration ends
    /// instead of panicking or wrapping around.
    ///
    /// A `hop_size` of zero never advances, so the same start is yielded
    /// indefinitely; callers are expected to reject it beforehand
    /// (NOTE(review): presumably done by config validation — confirm).
    fn frame_starts(self, total_len: usize) -> impl Iterator<Item = usize> {
        let FramePlan {
            frame_size,
            hop_size,
            pad_end,
        } = self;

        // `None` means the sequence is exhausted (including by overflow).
        let mut next_start = Some(0usize);

        std::iter::from_fn(move || {
            let current = next_start?;

            let in_bounds = if pad_end {
                current < total_len
            } else {
                matches!(current.checked_add(frame_size), Some(end) if end <= total_len)
            };

            if !in_bounds {
                next_start = None;
                return None;
            }

            // A start that would overflow `usize` can never be inside the signal.
            next_start = current.checked_add(hop_size);
            Some(current)
        })
    }
}
46
47impl<'a, T: AudioSample> AudioVoiceActivityDetection<'a, T> for AudioSamples<'a, T> {
48 fn voice_activity_mask<F: RealFloat>(
49 &self,
50 config: &VadConfig<F>,
51 ) -> AudioSampleResult<Vec<bool>> {
52 config.validate()?;
53
54 let plan = FramePlan {
55 frame_size: config.frame_size,
56 hop_size: config.hop_size,
57 pad_end: config.pad_end,
58 };
59
60 let total_len = self.samples_per_channel();
61 if total_len == 0 {
62 return Ok(Vec::new());
63 }
64
65 let energy_threshold_db_f64: f64 = to_precision(config.energy_threshold_db);
66 let energy_threshold_rms: f64 = 10f64.powf(energy_threshold_db_f64 / 20.0);
67
68 let mut mask = Vec::new();
70
71 match self.as_mono() {
72 Some(mono) => {
73 let slice = mono.as_slice().ok_or(AudioSampleError::Layout(
74 LayoutError::NonContiguous {
75 operation: "voice activity detection".to_string(),
76 layout_type: "non-contiguous mono array".to_string(),
77 },
78 ))?;
79
80 for start in plan.frame_starts(total_len) {
81 let decision = vad_decision_for_slice(
82 slice,
83 start,
84 total_len,
85 config,
86 energy_threshold_rms,
87 )?;
88 mask.push(decision);
89 }
90 }
91 None => {
92 let multi = self.as_multi_channel().ok_or(AudioSampleError::Parameter(
93 ParameterError::invalid_value("audio_data", "Audio must be multi-channel"),
94 ))?;
95
96 let multi_view = multi.as_view();
97 let channels = multi_view.nrows();
98 let samples = multi_view.ncols();
99
100 if samples != total_len {
101 return Err(AudioSampleError::Layout(LayoutError::DimensionMismatch {
102 expected_dims: format!("samples_per_channel={total_len}"),
103 actual_dims: format!("ncols={samples}"),
104 operation: "voice activity detection".to_string(),
105 }));
106 }
107
108 for start in plan.frame_starts(total_len) {
109 let decision = vad_decision_for_multi(
110 multi_view,
111 channels,
112 samples,
113 start,
114 config,
115 energy_threshold_rms,
116 )?;
117 mask.push(decision);
118 }
119 }
120 }
121
122 let mut mask = majority_smooth(mask, config.smooth_frames);
124 apply_hangover(&mut mask, config.hangover_frames);
125 remove_short_runs(&mut mask, true, config.min_speech_frames);
126 remove_short_runs(&mut mask, false, config.min_silence_frames);
127
128 Ok(mask)
129 }
130
131 fn speech_regions<F: RealFloat>(
132 &self,
133 config: &VadConfig<F>,
134 ) -> AudioSampleResult<Vec<(usize, usize)>> {
135 let mask = self.voice_activity_mask(config)?;
136 if mask.is_empty() {
137 return Ok(Vec::new());
138 }
139
140 let plan = FramePlan {
141 frame_size: config.frame_size,
142 hop_size: config.hop_size,
143 pad_end: config.pad_end,
144 };
145
146 let total_len = self.samples_per_channel();
147 let mut regions = Vec::new();
148
149 let mut in_region = false;
150 let mut region_start_sample = 0usize;
151 let mut last_frame_end_sample = 0usize;
152
153 for (frame_idx, start) in plan.frame_starts(total_len).enumerate() {
154 let is_speech = mask.get(frame_idx).copied().unwrap_or(false);
155
156 let frame_end = start.saturating_add(config.frame_size).min(total_len);
157 last_frame_end_sample = frame_end;
158
159 if is_speech {
160 if !in_region {
161 in_region = true;
162 region_start_sample = start;
163 }
164 } else if in_region {
165 regions.push((region_start_sample, frame_end));
166 in_region = false;
167 }
168 }
169
170 if in_region {
171 regions.push((region_start_sample, last_frame_end_sample));
172 }
173
174 Ok(merge_overlapping_regions(regions))
175 }
176}
177
178fn vad_decision_for_slice<T: AudioSample, F: RealFloat>(
179 samples: &[T],
180 frame_start: usize,
181 total_len: usize,
182 config: &VadConfig<F>,
183 energy_threshold_rms: f64,
184) -> AudioSampleResult<bool> {
185 let (rms, zcr) = frame_rms_and_zcr(
186 samples,
187 frame_start,
188 total_len,
189 config.frame_size,
190 config.pad_end,
191 );
192 classify(config, rms, zcr, energy_threshold_rms)
193}
194
195fn vad_decision_for_multi<T: AudioSample, F: RealFloat>(
196 multi: ndarray::ArrayView2<'_, T>,
197 channels: usize,
198 samples: usize,
199 frame_start: usize,
200 config: &VadConfig<F>,
201 energy_threshold_rms: f64,
202) -> AudioSampleResult<bool> {
203 let plan_frame_size = config.frame_size;
204
205 match &config.channel_policy {
206 VadChannelPolicy::Channel(ch) => {
207 if *ch >= channels {
208 return Err(AudioSampleError::Parameter(ParameterError::out_of_range(
209 "channel_policy",
210 ch.to_string(),
211 "0".to_string(),
212 (channels.saturating_sub(1)).to_string(),
213 "channel index out of range".to_string(),
214 )));
215 }
216 let row = multi.row(*ch);
217 let slice =
218 row.as_slice()
219 .ok_or(AudioSampleError::Layout(LayoutError::NonContiguous {
220 operation: "voice activity detection".to_string(),
221 layout_type: "non-contiguous channel row".to_string(),
222 }))?;
223 let (rms, zcr) =
224 frame_rms_and_zcr(slice, frame_start, samples, plan_frame_size, config.pad_end);
225 classify(config, rms, zcr, energy_threshold_rms)
226 }
227 VadChannelPolicy::AverageToMono => {
228 if let Some(flat) = multi.as_slice() {
230 let (rms, zcr) = frame_rms_and_zcr_avg(
231 flat,
232 channels,
233 samples,
234 frame_start,
235 plan_frame_size,
236 config.pad_end,
237 );
238 classify(config, rms, zcr, energy_threshold_rms)
239 } else {
240 let (rms, zcr) = frame_rms_and_zcr_avg_indexed(
241 &multi,
242 channels,
243 samples,
244 frame_start,
245 plan_frame_size,
246 config.pad_end,
247 );
248 classify(config, rms, zcr, energy_threshold_rms)
249 }
250 }
251 VadChannelPolicy::AnyChannel | VadChannelPolicy::AllChannels => {
252 let want_any = matches!(config.channel_policy, VadChannelPolicy::AnyChannel);
253 let mut any_active = false;
254 let mut all_active = true;
255
256 for ch in 0..channels {
257 let row = multi.row(ch);
258 let (rms, zcr) = if let Some(slice) = row.as_slice() {
259 frame_rms_and_zcr(slice, frame_start, samples, plan_frame_size, config.pad_end)
260 } else {
261 frame_rms_and_zcr_indexed(
262 &multi,
263 ch,
264 samples,
265 frame_start,
266 plan_frame_size,
267 config.pad_end,
268 )
269 };
270
271 let active = classify(config, rms, zcr, energy_threshold_rms)?;
272 any_active |= active;
273 all_active &= active;
274
275 if want_any && any_active {
276 return Ok(true);
277 }
278 if !want_any && !all_active {
279 return Ok(false);
280 }
281 }
282
283 Ok(if want_any { any_active } else { all_active })
284 }
285 }
286}
287
288fn classify<F: RealFloat>(
289 config: &VadConfig<F>,
290 rms: f64,
291 zcr: f64,
292 energy_threshold_rms: f64,
293) -> AudioSampleResult<bool> {
294 match config.method {
295 VadMethod::Energy => Ok(rms >= energy_threshold_rms),
296 VadMethod::ZeroCrossing => Ok(zcr >= to_precision::<f64, _>(config.zcr_min)
297 && zcr <= to_precision::<f64, _>(config.zcr_max)),
298 VadMethod::Combined => {
299 let zcr_ok = zcr >= to_precision::<f64, _>(config.zcr_min)
300 && zcr <= to_precision::<f64, _>(config.zcr_max);
301 Ok(rms >= energy_threshold_rms && zcr_ok)
302 }
303 VadMethod::Spectral => Err(AudioSampleError::Feature(FeatureError::not_enabled(
304 "spectral-analysis",
305 "VAD spectral mode",
306 ))),
307 }
308}
309
310fn frame_rms_and_zcr<T: AudioSample>(
311 samples: &[T],
312 frame_start: usize,
313 total_len: usize,
314 frame_size: usize,
315 pad_end: bool,
316) -> (f64, f64) {
317 let available = total_len.saturating_sub(frame_start);
318 if available == 0 || frame_size == 0 {
319 return (0.0, 0.0);
320 }
321
322 let frame_len = available.min(frame_size);
323 let denom_len = if pad_end { frame_size } else { frame_len };
324
325 let mut sum_sq = 0.0f64;
326 let mut zc = 0usize;
327
328 let mut prev_sign = 0i8;
329
330 for i in 0..frame_len {
331 let x: f64 = samples[frame_start + i].convert_to();
332 sum_sq += x * x;
333
334 let sign = if x > 0.0 {
335 1
336 } else if x < 0.0 {
337 -1
338 } else {
339 0
340 };
341 if i > 0 && sign != 0 && prev_sign != 0 && sign != prev_sign {
342 zc += 1;
343 }
344 if sign != 0 {
345 prev_sign = sign;
346 }
347 }
348
349 let rms = if denom_len == 0 {
350 0.0
351 } else {
352 (sum_sq / denom_len as f64).sqrt()
353 };
354
355 let zcr_denom = (frame_size.saturating_sub(1)).max(1) as f64;
356 let zcr = zc as f64 / zcr_denom;
357
358 (rms, zcr)
359}
360
361fn frame_rms_and_zcr_avg<T: AudioSample>(
362 flat: &[T],
363 channels: usize,
364 samples: usize,
365 frame_start: usize,
366 frame_size: usize,
367 pad_end: bool,
368) -> (f64, f64) {
369 let available = samples.saturating_sub(frame_start);
370 if available == 0 || frame_size == 0 || channels == 0 {
371 return (0.0, 0.0);
372 }
373
374 let frame_len = available.min(frame_size);
375 let denom_len = if pad_end { frame_size } else { frame_len };
376
377 let mut sum_sq = 0.0f64;
378 let mut zc = 0usize;
379 let mut prev_sign = 0i8;
380
381 for i in 0..frame_len {
382 let col = frame_start + i;
383 let mut acc = 0.0f64;
384 for ch in 0..channels {
385 let idx = ch * samples + col;
386 let v: f64 = flat[idx].convert_to();
387 acc += v;
388 }
389 let x = acc / channels as f64;
390 sum_sq += x * x;
391
392 let sign = if x > 0.0 {
393 1
394 } else if x < 0.0 {
395 -1
396 } else {
397 0
398 };
399 if i > 0 && sign != 0 && prev_sign != 0 && sign != prev_sign {
400 zc += 1;
401 }
402 if sign != 0 {
403 prev_sign = sign;
404 }
405 }
406
407 let rms = if denom_len == 0 {
408 0.0
409 } else {
410 (sum_sq / denom_len as f64).sqrt()
411 };
412
413 let zcr_denom = (frame_size.saturating_sub(1)).max(1) as f64;
414 let zcr = zc as f64 / zcr_denom;
415
416 (rms, zcr)
417}
418
419fn frame_rms_and_zcr_avg_indexed<T: AudioSample>(
420 multi: &ndarray::ArrayView2<'_, T>,
421 channels: usize,
422 samples: usize,
423 frame_start: usize,
424 frame_size: usize,
425 pad_end: bool,
426) -> (f64, f64) {
427 let available = samples.saturating_sub(frame_start);
428 if available == 0 || frame_size == 0 || channels == 0 {
429 return (0.0, 0.0);
430 }
431
432 let frame_len = available.min(frame_size);
433 let denom_len = if pad_end { frame_size } else { frame_len };
434
435 let mut sum_sq = 0.0f64;
436 let mut zc = 0usize;
437 let mut prev_sign = 0i8;
438
439 for i in 0..frame_len {
440 let col = frame_start + i;
441 let mut acc = 0.0f64;
442 for ch in 0..channels {
443 let v: f64 = multi[(ch, col)].convert_to();
444 acc += v;
445 }
446 let x = acc / channels as f64;
447 sum_sq += x * x;
448
449 let sign = if x > 0.0 {
450 1
451 } else if x < 0.0 {
452 -1
453 } else {
454 0
455 };
456 if i > 0 && sign != 0 && prev_sign != 0 && sign != prev_sign {
457 zc += 1;
458 }
459 if sign != 0 {
460 prev_sign = sign;
461 }
462 }
463
464 let rms = if denom_len == 0 {
465 0.0
466 } else {
467 (sum_sq / denom_len as f64).sqrt()
468 };
469
470 let zcr_denom = (frame_size.saturating_sub(1)).max(1) as f64;
471 let zcr = zc as f64 / zcr_denom;
472
473 (rms, zcr)
474}
475
476fn frame_rms_and_zcr_indexed<T: AudioSample>(
477 multi: &ndarray::ArrayView2<'_, T>,
478 ch: usize,
479 samples: usize,
480 frame_start: usize,
481 frame_size: usize,
482 pad_end: bool,
483) -> (f64, f64) {
484 let available = samples.saturating_sub(frame_start);
485 if available == 0 || frame_size == 0 {
486 return (0.0, 0.0);
487 }
488
489 let frame_len = available.min(frame_size);
490 let denom_len = if pad_end { frame_size } else { frame_len };
491
492 let mut sum_sq = 0.0f64;
493 let mut zc = 0usize;
494 let mut prev_sign = 0i8;
495
496 for i in 0..frame_len {
497 let col = frame_start + i;
498 let x: f64 = multi[(ch, col)].convert_to();
499 sum_sq += x * x;
500
501 let sign = if x > 0.0 {
502 1
503 } else if x < 0.0 {
504 -1
505 } else {
506 0
507 };
508 if i > 0 && sign != 0 && prev_sign != 0 && sign != prev_sign {
509 zc += 1;
510 }
511 if sign != 0 {
512 prev_sign = sign;
513 }
514 }
515
516 let rms = if denom_len == 0 {
517 0.0
518 } else {
519 (sum_sq / denom_len as f64).sqrt()
520 };
521
522 let zcr_denom = (frame_size.saturating_sub(1)).max(1) as f64;
523 let zcr = zc as f64 / zcr_denom;
524
525 (rms, zcr)
526}
527
/// Smooths a frame mask with a trailing majority vote.
///
/// Output frame `i` is `true` when at least half of the frames in the window
/// ending at `i` (up to `window` frames, fewer near the start) are `true`;
/// exact ties resolve to `true`.
///
/// The mask is returned unchanged when `window <= 1` or the mask is empty.
///
/// The vote uses a running count over the input itself, so memory use depends
/// only on the mask length; a `window` far larger than the mask is fine.
fn majority_smooth(mask: Vec<bool>, window: usize) -> Vec<bool> {
    if window <= 1 || mask.is_empty() {
        return mask;
    }

    let mut smoothed = Vec::with_capacity(mask.len());
    // Number of `true` frames inside the current trailing window.
    let mut active = 0usize;

    for (i, &is_active) in mask.iter().enumerate() {
        if is_active {
            active += 1;
        }
        // Drop the frame that just slid out of the window. It was counted when
        // it entered, so `active` cannot underflow.
        if i >= window && mask[i - window] {
            active -= 1;
        }

        let span = (i + 1).min(window);
        smoothed.push(active * 2 >= span);
    }

    smoothed
}
555
/// Extends every speech run in `mask` by up to `hangover_frames` frames.
///
/// After a frame that was originally `true`, the next `hangover_frames`
/// `false` frames are switched to `true`. Frames switched on by the hangover
/// do not restart the countdown. Nothing happens when `hangover_frames` is
/// zero or the mask is empty.
fn apply_hangover(mask: &mut [bool], hangover_frames: usize) {
    if hangover_frames == 0 || mask.is_empty() {
        return;
    }

    // Frames of hangover still available after the latest real speech frame.
    let mut remaining = 0usize;

    for frame in mask.iter_mut() {
        match (*frame, remaining) {
            (true, _) => remaining = hangover_frames,
            (false, 0) => {}
            (false, left) => {
                *frame = true;
                remaining = left - 1;
            }
        }
    }
}
571
/// Flips every run of consecutive `value` frames that is shorter than
/// `min_len` to `!value`.
///
/// Runs are found in a single left-to-right pass; a flipped run is not
/// re-examined. Runs touching the start or end of the mask — including a run
/// that spans the whole mask — are treated like any other run. Nothing happens
/// when `min_len` is zero or the mask is empty.
fn remove_short_runs(mask: &mut [bool], value: bool, min_len: usize) {
    if min_len == 0 || mask.is_empty() {
        return;
    }

    let mut cursor = 0usize;

    // Jump to the next frame equal to `value`, then measure the run it starts.
    while let Some(offset) = mask[cursor..].iter().position(|&frame| frame == value) {
        let run_start = cursor + offset;
        let run_len = mask[run_start..]
            .iter()
            .take_while(|&&frame| frame == value)
            .count();
        let run_end = run_start + run_len;

        if run_len < min_len {
            mask[run_start..run_end].fill(!value);
        }

        cursor = run_end;
    }
}
597
/// Sorts `(start, end)` regions by start and merges every region whose start
/// is at or before the end of the one preceding it.
///
/// Regions that merely touch (`start == previous end`) are merged too. The
/// result is ordered by start.
fn merge_overlapping_regions(mut regions: Vec<(usize, usize)>) -> Vec<(usize, usize)> {
    regions.sort_by_key(|&(start, _)| start);

    let mut merged: Vec<(usize, usize)> = Vec::with_capacity(regions.len());

    for (start, end) in regions {
        match merged.last_mut() {
            // Overlaps or touches the region being built: extend it.
            Some(tail) if start <= tail.1 => tail.1 = tail.1.max(end),
            _ => merged.push((start, end)),
        }
    }

    merged
}
620
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    /// An all-zero signal must produce frames, none of them flagged as speech.
    #[test]
    fn vad_silence_is_inactive() {
        let silence = Array1::<f32>::zeros(4096);
        let audio = AudioSamples::new_mono(silence, crate::sample_rate!(44100));
        let cfg = VadConfig::<f32> {
            energy_threshold_db: to_precision(-50.0),
            ..VadConfig::energy_only()
        };

        let mask = audio.voice_activity_mask(&cfg).unwrap();

        assert!(!mask.is_empty());
        assert!(!mask.iter().any(|&active| active));
    }

    /// A half-amplitude 440 Hz tone must trigger at least one active frame.
    #[test]
    fn vad_tone_is_active() {
        let sample_rate_hz = 44100.0f32;
        let tone_hz = 440.0f32;
        let tone: Vec<f32> = (0..4096)
            .map(|n| {
                let seconds = n as f32 / sample_rate_hz;
                (2.0 * std::f32::consts::PI * tone_hz * seconds).sin() * 0.5
            })
            .collect();
        let audio = AudioSamples::new_mono(Array1::from(tone), crate::sample_rate!(44100));

        let cfg = VadConfig::<f32> {
            energy_threshold_db: to_precision(-35.0),
            ..VadConfig::energy_only()
        };

        let mask = audio.voice_activity_mask(&cfg).unwrap();

        assert!(!mask.is_empty());
        assert!(mask.iter().any(|&active| active));
    }

    /// Consecutive regions must never overlap: each ends at or before the next
    /// one starts.
    #[test]
    fn speech_regions_are_non_overlapping() {
        // First half silent, second half a constant non-zero level.
        let mut signal = vec![0.0f32; 4096];
        signal[2048..].fill(0.2);
        let audio = AudioSamples::new_mono(Array1::from(signal), crate::sample_rate!(44100));

        let cfg = VadConfig::<f32> {
            frame_size: 512,
            hop_size: 256,
            energy_threshold_db: to_precision(-45.0),
            ..VadConfig::energy_only()
        };

        let regions = audio.speech_regions(&cfg).unwrap();

        for pair in regions.windows(2) {
            let (earlier, later) = (pair[0], pair[1]);
            assert!(earlier.1 <= later.0);
        }
    }
}