use anyhow::{Result, anyhow};
use whisper_rs::{WhisperVadContext, WhisperVadParams, WhisperVadSegments};
/// Runs Whisper VAD over `samples` and attenuates everything outside the detected speech.
///
/// Detected segments are filtered, padded and merged according to `policy`; every sample
/// outside the resulting speech ranges is then multiplied by `policy.non_speech_gain`
/// in place.
///
/// Returns `Ok(true)` when at least one speech range was found and the gain was applied,
/// and `Ok(false)` when no usable speech was detected, in which case `samples` is left
/// untouched.
///
/// NOTE(review): `sample_rate_hz` is only used to map VAD timestamps back to sample
/// indices; nothing here checks that the VAD model accepts that rate — confirm the
/// expected input rate against whisper.cpp before passing non-16 kHz audio.
///
/// # Errors
///
/// Propagates failures from the VAD run and from reading segment timestamps.
pub fn to_speech_only_with_policy(
    ctx: &mut WhisperVadContext,
    sample_rate_hz: u32,
    samples: &mut [f32],
    policy: VadPolicy,
) -> Result<bool> {
    // Cap handed to `set_max_speech_duration`; presumably seconds (whisper.cpp's
    // `max_speech_duration_s`) — confirm against the whisper-rs version in use.
    const MAX_SPEECH_DURATION: f32 = 15.0;

    let mut vad_params = WhisperVadParams::default();
    vad_params.set_max_speech_duration(MAX_SPEECH_DURATION);
    vad_params.set_threshold(policy.threshold);
    // The VAD takes an `i32`; saturate instead of letting `as` wrap huge values negative.
    let min_speech_ms = i32::try_from(policy.min_speech_ms).unwrap_or(i32::MAX);
    vad_params.set_min_speech_duration(min_speech_ms);

    let segments = ctx.segments_from_samples(vad_params, samples)?;
    let Some(ranges) = speech_ranges_with_policy(sample_rate_hz, &segments, samples, policy)?
    else {
        return Ok(false);
    };
    apply_non_speech_gain_in_place(samples, &ranges, policy.non_speech_gain);
    Ok(true)
}
/// Turns the VAD `segments` into sorted, non-overlapping, half-open `(start, end)`
/// sample ranges.
///
/// Each segment is mapped to sample indices and then processed as follows:
/// - It is dropped if it is shorter than `policy.min_speech_ms`, measured before padding.
/// - It is widened by `policy.pre_pad_ms` / `policy.post_pad_ms`, clamped to the buffer.
/// - It is merged into the previous range when it overlaps it or starts no more than
///   `policy.gap_merge_ms` after it ends.
///
/// Returns `Ok(None)` when no segment survives. Errors if a segment's timestamps cannot
/// be read.
fn speech_ranges_with_policy(
    sample_rate_hz: u32,
    segments: &WhisperVadSegments,
    samples: &[f32],
    policy: VadPolicy,
) -> Result<Option<Vec<(usize, usize)>>> {
    let total = samples.len();
    let rate = sample_rate_hz as f32;
    let min_len = ms_to_samples(policy.min_speech_ms, rate);
    let lead = ms_to_samples(policy.pre_pad_ms, rate);
    let tail = ms_to_samples(policy.post_pad_ms, rate);
    let max_gap = ms_to_samples(policy.gap_merge_ms, rate);

    let mut merged: Vec<(usize, usize)> = Vec::new();
    for seg in 0..segments.num_segments() {
        let (raw_start, raw_end) = segment_sample_indexes(segments, seg, rate, total)?;
        if raw_end.saturating_sub(raw_start) < min_len {
            continue;
        }
        let lo = raw_start.saturating_sub(lead);
        let hi = (raw_end + tail).min(total);
        if lo >= hi {
            continue;
        }
        // An overlapping start yields a zero gap, so one comparison covers both the
        // "overlaps" and the "close enough" merge cases.
        match merged.last_mut() {
            Some(last) if lo.saturating_sub(last.1) <= max_gap => last.1 = last.1.max(hi),
            _ => merged.push((lo, hi)),
        }
    }
    Ok((!merged.is_empty()).then_some(merged))
}
/// Multiplies every sample outside the speech `ranges` by `gain`, in place.
///
/// `ranges` are half-open `(start, end)` sample-index pairs, expected in ascending order
/// of `start` (as `speech_ranges_with_policy` produces them). Bounds past `samples.len()`
/// are clamped. An inverted pair (`start > end`) is treated as an empty range at `start`,
/// so no sample is ever scaled twice.
///
/// `gain` is clamped to `[0.0, 1.0]`. A gain of (effectively) `1.0` returns early without
/// touching the buffer. A NaN gain is treated the same way instead of poisoning the audio.
fn apply_non_speech_gain_in_place(samples: &mut [f32], ranges: &[(usize, usize)], gain: f32) {
    // `f32::clamp` passes NaN through, and multiplying by NaN would turn every
    // non-speech sample into NaN; leave the audio untouched instead.
    if gain.is_nan() {
        return;
    }
    let gain = gain.clamp(0.0, 1.0);
    if (gain - 1.0).abs() < f32::EPSILON {
        return;
    }
    let len = samples.len();
    // First index that has not been classified yet; it only ever moves forward.
    let mut cursor = 0usize;
    for &(start, end) in ranges {
        let start = start.min(len);
        // Keep `end >= start` so `cursor` never falls behind a region that was just scaled.
        let end = end.min(len).max(start);
        if start > cursor {
            scale_samples(&mut samples[cursor..start], gain);
        }
        cursor = cursor.max(end);
    }
    if cursor < len {
        scale_samples(&mut samples[cursor..], gain);
    }
}

/// Multiplies every element of `buf` by `gain` in place.
///
/// Zero gain (the mute case) is written as a plain fill, which also yields `+0.0` for
/// negative, infinite or NaN inputs where multiplication would not.
fn scale_samples(buf: &mut [f32], gain: f32) {
    if gain == 0.0 {
        buf.fill(0.0);
        return;
    }
    for s in buf.iter_mut() {
        *s *= gain;
    }
}
/// Converts a duration in milliseconds to the nearest whole number of samples at
/// `sample_rate` Hz.
///
/// Negative or NaN products become `0` through the saturating float-to-int `as` cast.
fn ms_to_samples(ms: u32, sample_rate: f32) -> usize {
    let seconds = ms as f32 / 1000.0;
    let exact = seconds * sample_rate;
    exact.round() as usize
}
/// Maps VAD segment `i` to a `(start, end)` pair of sample indices.
///
/// The segment timestamps are divided by 100 (they are treated as centiseconds) and
/// scaled by `sample_rate`. The start is rounded down and the end rounded up, so the
/// detected span is never clipped. Both indices are clamped to `samples_len`, and `end`
/// is never smaller than `start`.
///
/// # Errors
///
/// Fails if either timestamp is unavailable for segment `i`.
fn segment_sample_indexes(
    segments: &WhisperVadSegments,
    i: i32,
    sample_rate: f32,
    samples_len: usize,
) -> Result<(usize, usize)> {
    let Some(start_cs) = segments.get_segment_start_timestamp(i) else {
        return Err(anyhow!("missing start timestamp for VAD segment {i}"));
    };
    let Some(end_cs) = segments.get_segment_end_timestamp(i) else {
        return Err(anyhow!("missing end timestamp for VAD segment {i}"));
    };
    let first = ((start_cs / 100.0 * sample_rate).floor() as usize).min(samples_len);
    let last = ((end_cs / 100.0 * sample_rate).ceil() as usize).min(samples_len);
    Ok((first, last.max(first)))
}
/// Tuning knobs for VAD-based speech isolation.
///
/// All `*_ms` fields are durations in milliseconds. They are converted to sample counts
/// using the sample rate of the audio being processed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct VadPolicy {
    /// Detection threshold forwarded to `WhisperVadParams::set_threshold`
    /// (presumably a speech-probability cutoff in `0.0..=1.0` — confirm against whisper.cpp).
    pub threshold: f32,
    /// Extra audio kept before each detected speech segment.
    pub pre_pad_ms: u32,
    /// Extra audio kept after each detected speech segment.
    pub post_pad_ms: u32,
    /// Minimum speech duration. It is forwarded to the VAD, and shorter segments (measured
    /// before padding) are also discarded during post-processing.
    pub min_speech_ms: u32,
    /// Padded segments separated by a gap no longer than this are merged into one range.
    pub gap_merge_ms: u32,
    /// Multiplier applied to samples outside speech ranges, clamped to `[0.0, 1.0]`.
    /// `0.0` mutes non-speech and `1.0` leaves it unchanged.
    pub non_speech_gain: f32,
}

/// Default policy: 0.5 threshold, 250 ms padding on both sides, a 250 ms minimum
/// speech length, merging of gaps up to 300 ms, and fully muted non-speech.
pub const DEFAULT_VAD_POLICY: VadPolicy = VadPolicy {
    threshold: 0.5,
    pre_pad_ms: 250,
    post_pad_ms: 250,
    min_speech_ms: 250,
    gap_merge_ms: 300,
    non_speech_gain: 0.0,
};
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn apply_non_speech_gain_in_place_mutes_only_non_speech_ranges() {
        let mut audio = [1.0_f32; 5];
        apply_non_speech_gain_in_place(&mut audio, &[(1, 3)], 0.0);
        assert_eq!(audio, [0.0, 1.0, 1.0, 0.0, 0.0]);
    }

    #[test]
    fn apply_non_speech_gain_in_place_clamps_gain_to_one() {
        let mut audio = [0.25_f32, 0.5, 1.0];
        apply_non_speech_gain_in_place(&mut audio, &[(0, 3)], 2.0);
        assert_eq!(audio, [0.25, 0.5, 1.0]);
    }

    #[test]
    fn scale_samples_fast_path_for_zero_gain() {
        let mut chunk = [0.25_f32, -0.5];
        scale_samples(&mut chunk, 0.0);
        assert_eq!(chunk, [0.0, 0.0]);
    }

    #[test]
    fn ms_to_samples_rounds_to_nearest_sample() {
        for (ms, expected) in [(0, 0), (1, 16), (33, 528)] {
            assert_eq!(ms_to_samples(ms, 16_000.0), expected, "ms = {ms}");
        }
    }
}