use crate::vad::{WhisperVadProcessor, VadParams};
use crate::error::Result;
use std::path::Path;
/// Configuration for VAD processing followed by segment aggregation.
///
/// Wraps the base [`VadParams`] and adds the settings that control how
/// detected segments are merged before audio chunks are extracted.
#[derive(Debug, Clone)]
pub struct EnhancedVadParams {
/// Parameters handed to the wrapped VAD processor when detecting segments.
pub base: VadParams,
/// Upper bound, in seconds, on the span created by merging two segments.
/// A merge is skipped when the combined span would exceed this value; a
/// single detected segment that is already longer is not split.
pub max_segment_duration_s: f32,
/// When `false`, detected segments are never merged.
pub merge_segments: bool,
/// Segments separated by a gap strictly shorter than this many
/// milliseconds are candidates for merging.
pub min_gap_ms: i32,
}
impl Default for EnhancedVadParams {
fn default() -> Self {
Self {
base: VadParams::default(),
max_segment_duration_s: 30.0,
merge_segments: true,
min_gap_ms: 100,
}
}
}
/// Wrapper around [`WhisperVadProcessor`] that adds segment aggregation and
/// extraction of the matching audio chunks.
pub struct EnhancedWhisperVadProcessor {
/// The wrapped VAD processor that performs the actual segment detection.
inner: WhisperVadProcessor,
}
impl EnhancedWhisperVadProcessor {
    /// Sample rate, in Hz, used to convert segment times (seconds) into
    /// sample indices when cutting chunks out of the input audio.
    // NOTE(review): the input is presumably 16 kHz audio - confirm against callers.
    const SAMPLE_RATE_HZ: f32 = 16_000.0;

    /// Creates a processor backed by the VAD model at `model_path`.
    ///
    /// # Errors
    ///
    /// Returns whatever error the wrapped `WhisperVadProcessor::new` reports.
    pub fn new<P: AsRef<Path>>(model_path: P) -> Result<Self> {
        let inner = WhisperVadProcessor::new(model_path)?;
        Ok(Self { inner })
    }

    /// Detects speech segments in `audio`, aggregates them according to
    /// `params`, and copies the audio of every aggregated segment into its
    /// own [`AudioChunk`].
    ///
    /// # Errors
    ///
    /// Returns the error produced by the wrapped processor's
    /// `segments_from_samples` call.
    pub fn process_with_aggregation(
        &mut self,
        audio: &[f32],
        params: &EnhancedVadParams,
    ) -> Result<Vec<AudioChunk>> {
        let segments = self.inner.segments_from_samples(audio, &params.base)?;
        let raw_segments = segments.get_all_segments();
        let aggregated = self.aggregate_segments(
            raw_segments,
            params.max_segment_duration_s,
            params.min_gap_ms,
            params.merge_segments,
        );
        Ok(self.extract_audio_chunks(audio, aggregated, Self::SAMPLE_RATE_HZ))
    }

    /// Merges neighbouring `(start, end)` segments, given in seconds.
    ///
    /// Walking the segments in order, a segment is folded into the span being
    /// built when `merge` is `true`, the gap to the span's end is strictly
    /// shorter than `min_gap_ms`, and the combined span would not exceed
    /// `max_duration` seconds. Otherwise the span is emitted and a new one is
    /// started. Segments are never split, so a single input segment longer
    /// than `max_duration` is returned unchanged.
    ///
    /// Returns an empty vector for empty input.
    #[doc(hidden)]
    pub fn aggregate_segments(
        &self,
        segments: Vec<(f32, f32)>,
        max_duration: f32,
        min_gap_ms: i32,
        merge: bool,
    ) -> Vec<(f32, f32)> {
        let mut remaining = segments.into_iter();
        let (mut current_start, mut current_end) = match remaining.next() {
            Some(first) => first,
            None => return Vec::new(),
        };

        let min_gap = min_gap_ms as f32 / 1000.0;
        let mut aggregated = Vec::new();

        for (start, end) in remaining {
            let gap = start - current_end;
            let combined_duration = end - current_start;
            if merge && gap < min_gap && combined_duration <= max_duration {
                // An overlapping segment may end before the current span
                // does; never let a merge shrink the span.
                current_end = current_end.max(end);
            } else {
                aggregated.push((current_start, current_end));
                current_start = start;
                current_end = end;
            }
        }

        aggregated.push((current_start, current_end));
        aggregated
    }

    /// Copies the samples covered by each `(start, end)` segment (seconds)
    /// out of `audio`, converting times to indices with `sample_rate`.
    ///
    /// Both indices are clamped to the buffer, so a segment that begins past
    /// the end of `audio` (or whose start exceeds its end) yields an empty
    /// chunk instead of panicking. The reported offset and duration are the
    /// unclamped segment times.
    fn extract_audio_chunks(
        &self,
        audio: &[f32],
        segments: Vec<(f32, f32)>,
        sample_rate: f32,
    ) -> Vec<AudioChunk> {
        segments
            .into_iter()
            .map(|(start, end)| {
                // Float-to-integer `as` casts saturate, so negative or NaN
                // times become index 0 rather than wrapping.
                let end_sample = ((end * sample_rate) as usize).min(audio.len());
                let start_sample = ((start * sample_rate) as usize).min(end_sample);
                AudioChunk {
                    audio: audio[start_sample..end_sample].to_vec(),
                    offset_seconds: start,
                    duration_seconds: end - start,
                    metadata: ChunkMetadata {
                        original_start: start,
                        original_end: end,
                        sample_offset: start_sample,
                    },
                }
            })
            .collect()
    }
}
/// A slice of the input audio that corresponds to one aggregated segment.
#[derive(Debug, Clone)]
pub struct AudioChunk {
/// Samples copied out of the input audio for this segment.
pub audio: Vec<f32>,
/// Start time of the segment, in seconds.
pub offset_seconds: f32,
/// Segment end minus segment start, in seconds. This comes from the
/// segment times and is not recomputed from the number of copied samples.
pub duration_seconds: f32,
/// Additional bookkeeping about where the chunk came from.
pub metadata: ChunkMetadata,
}
/// Provenance information attached to an [`AudioChunk`].
#[derive(Debug, Clone)]
pub struct ChunkMetadata {
/// Start time, in seconds, of the segment the chunk was cut from.
pub original_start: f32,
/// End time, in seconds, of the segment the chunk was cut from.
pub original_end: f32,
/// Index into the input audio at which the chunk's samples begin.
pub sample_offset: usize,
}
/// Chainable builder for [`EnhancedVadParams`].
pub struct EnhancedVadParamsBuilder {
/// The parameter set being assembled; returned as-is by `build`.
params: EnhancedVadParams,
}
impl EnhancedVadParamsBuilder {
pub fn new() -> Self {
Self {
params: EnhancedVadParams::default(),
}
}
pub fn threshold(mut self, threshold: f32) -> Self {
self.params.base.threshold = threshold;
self
}
pub fn max_segment_duration(mut self, seconds: f32) -> Self {
self.params.max_segment_duration_s = seconds;
self
}
pub fn merge_segments(mut self, merge: bool) -> Self {
self.params.merge_segments = merge;
self
}
pub fn min_gap_ms(mut self, ms: i32) -> Self {
self.params.min_gap_ms = ms;
self
}
pub fn speech_pad_ms(mut self, ms: i32) -> Self {
self.params.base.speech_pad_ms = ms;
self
}
pub fn build(self) -> EnhancedVadParams {
self.params
}
}
impl Default for EnhancedVadParamsBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::mem::ManuallyDrop;

    /// Builds a processor without loading a VAD model, for tests that only
    /// exercise `aggregate_segments`.
    ///
    /// The value is wrapped in `ManuallyDrop` so that no destructor ever runs
    /// on the fabricated inner processor.
    fn model_free_processor() -> ManuallyDrop<EnhancedWhisperVadProcessor> {
        ManuallyDrop::new(EnhancedWhisperVadProcessor {
            // SAFETY: `aggregate_segments` never reads `inner`, and the value
            // is never dropped. This still relies on the all-zero bit pattern
            // being valid for `WhisperVadProcessor` - TODO confirm against
            // its definition, or construct it properly if a model-free
            // constructor becomes available.
            inner: unsafe { std::mem::zeroed() },
        })
    }

    #[test]
    fn test_segment_aggregation() {
        let processor = model_free_processor();
        // Gaps are kept well away from the 100 ms threshold so the outcome
        // does not depend on f32 rounding of the literals.
        let segments = vec![
            (0.0, 2.0),
            (2.05, 4.0),  // 50 ms gap: merged with the first segment
            (4.5, 6.0),   // 500 ms gap: kept separate
            (10.0, 12.0), // 4 s gap: kept separate
        ];
        let aggregated = processor.aggregate_segments(segments, 30.0, 100, true);
        assert_eq!(aggregated.len(), 3);
        assert_eq!(aggregated[0], (0.0, 4.0));
        assert_eq!(aggregated[1], (4.5, 6.0));
        assert_eq!(aggregated[2], (10.0, 12.0));
    }

    #[test]
    fn test_max_duration_split() {
        let processor = model_free_processor();
        // The 50 ms gap alone would allow a merge; only the 30 s cap on the
        // combined span (40 s here) keeps the segments apart.
        let segments = vec![(0.0, 20.0), (20.05, 40.0)];
        let aggregated = processor.aggregate_segments(segments, 30.0, 100, true);
        assert_eq!(aggregated.len(), 2);
        assert_eq!(aggregated[0], (0.0, 20.0));
        assert_eq!(aggregated[1], (20.05, 40.0));
    }

    #[test]
    fn test_merge_disabled_keeps_segments() {
        let processor = model_free_processor();
        let segments = vec![(0.0, 2.0), (2.05, 4.0)];
        let aggregated = processor.aggregate_segments(segments, 30.0, 100, false);
        assert_eq!(aggregated, vec![(0.0, 2.0), (2.05, 4.0)]);
    }

    #[test]
    fn test_enhanced_vad_params_builder() {
        let params = EnhancedVadParamsBuilder::new()
            .threshold(0.6)
            .max_segment_duration(25.0)
            .merge_segments(false)
            .min_gap_ms(200)
            .speech_pad_ms(30)
            .build();
        assert_eq!(params.base.threshold, 0.6);
        assert_eq!(params.base.speech_pad_ms, 30);
        assert_eq!(params.max_segment_duration_s, 25.0);
        assert!(!params.merge_segments);
        assert_eq!(params.min_gap_ms, 200);
    }
}