use std::collections::VecDeque;
use tokio::sync::mpsc;
use crate::resample::WHISPER_SAMPLE_RATE;
use super::vad::VadState;
/// Seconds of trailing audio that are carried over into the next chunk.
const OVERLAP_SECS: usize = 2;
/// [`OVERLAP_SECS`] expressed as a sample count at [`WHISPER_SAMPLE_RATE`].
const OVERLAP_SAMPLES: usize = WHISPER_SAMPLE_RATE as usize * OVERLAP_SECS;
/// A contiguous span of audio cut from the input stream.
#[derive(Debug, Clone, PartialEq)]
pub struct AudioChunk {
    /// Zero-based position of this chunk in the emitted sequence.
    pub index: usize,
    /// Samples at `WHISPER_SAMPLE_RATE`, including any leading overlap.
    pub samples: Vec<f32>,
    /// `true` when the chunk begins with the tail of the previous chunk
    /// (set for every chunk except the first).
    pub has_leading_overlap: bool,
}
/// Duration thresholds controlling where the chunker cuts the audio stream.
///
/// All durations are whole seconds of buffered audio for the current chunk,
/// including any overlap carried over from the previous chunk.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChunkerConfig {
    /// Cut point used when VAD is disabled or no VAD state is available.
    pub target_duration_secs: u64,
    /// With VAD, the earliest duration at which a silence may trigger a cut.
    pub min_duration_secs: u64,
    /// With VAD, the duration at which a cut is forced even without silence.
    pub max_duration_secs: u64,
    /// Whether VAD state is consulted when choosing cut points.
    pub vad_aware: bool,
}

impl Default for ChunkerConfig {
    /// 90 s target; with VAD, cut on silence anywhere between 60 s and 120 s.
    fn default() -> Self {
        Self {
            target_duration_secs: 90,
            min_duration_secs: 60,
            max_duration_secs: 120,
            vad_aware: true,
        }
    }
}
/// Accumulates samples for the chunk under construction and keeps a sliding
/// window of the most recent `OVERLAP_SAMPLES` samples, which is replayed at
/// the start of the next chunk.
struct ChunkBuffer {
    /// Samples of the chunk being built, beginning with any replayed overlap.
    current_chunk: Vec<f32>,
    /// The last (at most) `OVERLAP_SAMPLES` samples received so far.
    overlap_buffer: VecDeque<f32>,
    /// Index the next emitted chunk will carry.
    chunk_index: usize,
    /// How many leading samples of `current_chunk` were replayed from the
    /// previous chunk, i.e. audio that has already been emitted once.
    overlap_len: usize,
}

impl ChunkBuffer {
    fn new() -> Self {
        Self {
            current_chunk: Vec::new(),
            // `add_samples` never lets the window grow past OVERLAP_SAMPLES.
            overlap_buffer: VecDeque::with_capacity(OVERLAP_SAMPLES),
            chunk_index: 0,
            overlap_len: 0,
        }
    }

    /// Appends `samples` to the current chunk and slides the overlap window.
    fn add_samples(&mut self, samples: &[f32]) {
        self.current_chunk.extend_from_slice(samples);

        // Only the last OVERLAP_SAMPLES of this batch can survive in the
        // window, so evict old samples first and append just that tail. This
        // avoids pushing a large batch and popping it back one element at a
        // time. `excess` never exceeds the window length because
        // `tail.len() <= OVERLAP_SAMPLES`.
        let tail = &samples[samples.len().saturating_sub(OVERLAP_SAMPLES)..];
        let excess = (self.overlap_buffer.len() + tail.len()).saturating_sub(OVERLAP_SAMPLES);
        self.overlap_buffer.drain(..excess);
        self.overlap_buffer.extend(tail.iter().copied());
    }

    /// Whole seconds of audio buffered for the current chunk (overlap
    /// included), rounded down.
    fn duration_secs(&self) -> u64 {
        // Integer division truncates exactly. f32 division can round a count
        // just below a second boundary up to it once buffers get long.
        (self.current_chunk.len() / WHISPER_SAMPLE_RATE as usize) as u64
    }

    /// Emits the buffered audio as a chunk and seeds the next chunk with the
    /// overlap window, so speech cut at the boundary appears whole in at
    /// least one chunk.
    fn create_chunk(&mut self) -> AudioChunk {
        let chunk = AudioChunk {
            index: self.chunk_index,
            samples: std::mem::take(&mut self.current_chunk),
            has_leading_overlap: self.chunk_index > 0,
        };
        self.current_chunk.extend(self.overlap_buffer.iter().copied());
        self.overlap_len = self.current_chunk.len();
        self.chunk_index += 1;
        chunk
    }

    /// Flushes the remaining audio at end of stream.
    ///
    /// Returns `None` when nothing new arrived since the last emitted chunk.
    /// In that case the buffer holds only replayed overlap, and emitting it
    /// would just duplicate audio that was already sent.
    fn create_final_chunk(&mut self) -> Option<AudioChunk> {
        if self.current_chunk.len() <= self.overlap_len {
            return None;
        }
        Some(AudioChunk {
            index: self.chunk_index,
            samples: std::mem::take(&mut self.current_chunk),
            has_leading_overlap: self.chunk_index > 0,
        })
    }
}
pub struct ProgressiveChunker {
config: ChunkerConfig,
buffer: ChunkBuffer,
chunk_tx: mpsc::UnboundedSender<AudioChunk>,
}
impl ProgressiveChunker {
pub fn new(config: ChunkerConfig, chunk_tx: mpsc::UnboundedSender<AudioChunk>) -> Self {
Self {
config,
buffer: ChunkBuffer::new(),
chunk_tx,
}
}
fn should_chunk(&self, vad_state: Option<VadState>) -> bool {
let duration = self.buffer.duration_secs();
if let Some(state) = vad_state
&& self.config.vad_aware
{
if duration >= self.config.min_duration_secs && state.is_silence() {
return true;
}
if duration >= self.config.max_duration_secs {
return true;
}
return false;
}
duration >= self.config.target_duration_secs
}
pub async fn consume_stream(
&mut self,
mut audio_rx: mpsc::UnboundedReceiver<Vec<f32>>,
mut vad_state_rx: Option<mpsc::UnboundedReceiver<VadState>>,
) -> Result<(), String> {
let mut current_vad_state: Option<VadState> = None;
loop {
tokio::select! {
Some(samples) = audio_rx.recv() => {
self.buffer.add_samples(&samples);
if self.should_chunk(current_vad_state) {
let chunk = self.buffer.create_chunk();
crate::verbose!(
"Created chunk {} ({:.1}s)",
chunk.index,
chunk.samples.len() as f32 / WHISPER_SAMPLE_RATE as f32
);
self.chunk_tx.send(chunk).map_err(|e| e.to_string())?;
}
}
Some(state) = async {
match &mut vad_state_rx {
Some(rx) => rx.recv().await,
None => None,
}
} => {
current_vad_state = Some(state);
}
else => {
if let Some(final_chunk) = self.buffer.create_final_chunk() {
crate::verbose!(
"Created final chunk {} ({:.1}s)",
final_chunk.index,
final_chunk.samples.len() as f32 / WHISPER_SAMPLE_RATE as f32
);
self.chunk_tx.send(final_chunk).map_err(|e| e.to_string())?;
}
break;
}
}
}
Ok(())
}
}