use crate::cluster::SpeakerCluster;
use crate::types::ClusterConfig;
use crate::embedding::{EmbeddingError, EmbeddingExtractor};
use crate::types::{DiarizationConfig, SpeakerTurn, TimeRange};
use crate::vad::{VadError, VoiceActivityDetector, VadStateMachine, VadEvent};
use crate::window::WindowBuffer;
use crate::VadConfig;
/// Errors that can surface while pushing audio through the streaming pipeline.
#[derive(Debug, thiserror::Error)]
pub enum StreamingError {
    /// Failure reported by the voice-activity-detection backend.
    #[error("VAD error: {0}")]
    Vad(#[from] VadError),
    /// Failure reported by the speaker-embedding extractor.
    #[error("embedding error: {0}")]
    Embedding(#[from] EmbeddingError),
}
/// Incremental speaker-diarization pipeline: VAD-gated windowing over a raw
/// sample stream, per-window embedding extraction, and online clustering.
pub struct StreamingPipeline<V, E> {
    /// Voice-activity detector scoring fixed-size frames.
    vad: V,
    /// Speaker-embedding extractor applied to each analysis window.
    extractor: E,
    /// Online clusterer assigning embeddings to speaker ids.
    cluster: SpeakerCluster,
    /// Diarization configuration (windowing, clustering, speech filtering).
    config: DiarizationConfig,
    /// VAD frame length in samples.
    frame_size: usize,
    /// Audio sample rate in Hz (taken from `config.window.sample_rate`).
    sample_rate: u32,
    /// Samples waiting until a complete VAD frame is available.
    vad_buffer: Vec<f32>,
    /// State machine turning per-frame probabilities into speech events.
    vad_state: VadStateMachine,
    /// Accumulates in-speech samples and yields analysis windows.
    window_buffer: WindowBuffer,
    /// Turn history exposed via `turns()`.
    turns: Vec<SpeakerTurn>,
    /// Total VAD frames processed since construction.
    total_frames: usize,
}
impl<V, E> StreamingPipeline<V, E>
where
    V: VoiceActivityDetector,
    E: EmbeddingExtractor,
{
    /// Builds a pipeline from a VAD backend, an embedding extractor, and the
    /// diarization / VAD configuration.
    ///
    /// # Errors
    /// Returns [`StreamingError::Vad`] if `vad_config.frame_size` is zero.
    pub fn new(
        vad: V,
        extractor: E,
        config: DiarizationConfig,
        vad_config: VadConfig,
    ) -> Result<Self, StreamingError> {
        let frame_size = vad_config.frame_size;
        if frame_size == 0 {
            return Err(VadError::InvalidChunkSize { expected: 1, got: 0 }.into());
        }
        let sample_rate = config.window.sample_rate.get();
        // Convert millisecond/second thresholds into whole VAD frames,
        // rounding up so a configured minimum is never undershot.
        let ms_per_frame = (frame_size as f32 / sample_rate as f32) * 1000.0;
        let min_silence_frames = (vad_config.min_silence_ms / ms_per_frame).ceil() as usize;
        let min_speech_frames =
            ((config.speech_filter.min_speech_secs * 1000.0) / ms_per_frame).ceil() as usize;
        let cluster = SpeakerCluster::new(ClusterConfig {
            threshold: config.cluster.threshold,
            max_speakers: config.cluster.max_speakers,
        });
        let vad_state =
            VadStateMachine::new(vad_config.threshold, min_silence_frames, min_speech_frames);
        // Build the window buffer BEFORE the `Self` literal: the original code
        // read `config.window_samples()` / `config.hop_samples()` in a field
        // initializer placed after the `config` field had already moved
        // `config` into the struct — a use-after-move unless
        // `DiarizationConfig` happens to be `Copy`.
        let window_buffer = WindowBuffer::new(config.window_samples(), config.hop_samples());
        Ok(Self {
            vad,
            extractor,
            cluster,
            config,
            frame_size,
            sample_rate,
            vad_buffer: Vec::new(),
            vad_state,
            window_buffer,
            turns: Vec::new(),
            total_frames: 0,
        })
    }

    /// Feeds raw mono samples into the pipeline and returns the speaker turns
    /// finalized by this call. Finalized turns are also appended to the
    /// history returned by [`Self::turns`]. Leftover samples that do not fill
    /// a whole VAD frame are buffered for the next call.
    pub fn feed(&mut self, samples: &[f32]) -> Result<Vec<SpeakerTurn>, StreamingError> {
        let mut new_turns = Vec::new();
        self.vad_buffer.extend_from_slice(samples);
        let frame_size = self.frame_size;
        // Reusable scratch frame: the original collected a fresh Vec per
        // frame; draining into one reused buffer keeps semantics identical
        // while avoiding an allocation per frame.
        let mut frame: Vec<f32> = Vec::with_capacity(frame_size);
        while self.vad_buffer.len() >= frame_size {
            frame.clear();
            frame.extend(self.vad_buffer.drain(..frame_size));
            let probs = self.vad.process(&frame)?;
            // NOTE(review): assumes `process` yields one probability per
            // `frame_size`-sample frame; if it can return several, the window
            // buffer below would receive the same frame once per probability —
            // confirm against the VAD contract.
            for &prob in &probs {
                let current_frame = self.total_frames;
                self.total_frames += 1;
                if let Some(event) = self.vad_state.advance(prob, current_frame) {
                    match event {
                        VadEvent::SpeechStart { start_frame } => {
                            // New utterance: reset windowing and anchor it at
                            // the sample offset of the reported start frame.
                            self.window_buffer.clear();
                            self.window_buffer.set_next_start(start_frame * frame_size);
                        }
                        VadEvent::SpeechEnd { start_frame, end_frame } => {
                            let seg_end_sample = end_frame * frame_size;
                            let duration_frames = end_frame - start_frame;
                            if duration_frames >= self.vad_state.min_speech_frames() {
                                new_turns.extend(self.flush_window_buffer(seg_end_sample)?);
                            } else {
                                // Segment too short to count as speech: drop it.
                                self.window_buffer.clear();
                            }
                        }
                    }
                }
                if self.vad_state.in_speech() {
                    self.window_buffer.extend(&frame);
                    new_turns.extend(self.try_extract_windows()?);
                }
            }
        }
        // Record finalized turns so `turns()` reflects the full history; the
        // original built `new_turns` but never wrote the `turns` field, so the
        // accessor was permanently empty. (Assumes `SpeakerTurn: Clone` —
        // derive it on the type if missing.)
        self.turns.extend(new_turns.iter().cloned());
        Ok(new_turns)
    }

    /// Finalizes the stream: discards any partial trailing VAD frame (it was
    /// never scored), closes an in-progress speech segment if one exists, and
    /// returns the turns emitted by that closure.
    pub fn flush(&mut self) -> Result<Vec<SpeakerTurn>, StreamingError> {
        let mut new_turns = Vec::new();
        self.vad_buffer.clear();
        if let Some(VadEvent::SpeechEnd { start_frame, end_frame }) =
            self.vad_state.flush(self.total_frames)
        {
            let duration_frames = end_frame - start_frame;
            if duration_frames >= self.vad_state.min_speech_frames() {
                let seg_end_sample = end_frame * self.frame_size;
                new_turns.extend(self.flush_window_buffer(seg_end_sample)?);
            } else {
                self.window_buffer.clear();
            }
        }
        // Keep the history accessor consistent with what this call emitted
        // (see the matching note in `feed`).
        self.turns.extend(new_turns.iter().cloned());
        Ok(new_turns)
    }

    /// Number of distinct speakers discovered so far.
    pub fn num_speakers(&self) -> usize {
        self.cluster.num_speakers()
    }

    /// All speaker turns emitted so far, in emission order.
    pub fn turns(&self) -> &[SpeakerTurn] {
        &self.turns
    }

    /// Pops every complete analysis window currently available and emits one
    /// speaker turn per window, timestamped in seconds.
    fn try_extract_windows(&mut self) -> Result<Vec<SpeakerTurn>, StreamingError> {
        let mut turns = Vec::new();
        let sr_f = self.sample_rate as f64;
        while let Some((start, chunk)) = self.window_buffer.try_pop() {
            let embedding = self.extractor.extract(&chunk, &self.config)?;
            let (speaker, _conf) = self.cluster.assign(&embedding);
            let end = start + chunk.len();
            turns.push(SpeakerTurn {
                speaker,
                time: TimeRange {
                    start: start as f64 / sr_f,
                    end: end as f64 / sr_f,
                },
                text: None,
            });
        }
        Ok(turns)
    }

    /// Flushes the trailing (possibly padded) window at end of a speech
    /// segment and emits a turn for it. `seg_end_sample` clamps the reported
    /// end so padding added by the window buffer never extends the timestamp
    /// past the actual end of speech.
    fn flush_window_buffer(
        &mut self,
        seg_end_sample: usize,
    ) -> Result<Vec<SpeakerTurn>, StreamingError> {
        let mut turns = Vec::new();
        let sr_f = self.sample_rate as f64;
        if let Some((start, padded)) = self.window_buffer.flush() {
            let embedding = self.extractor.extract(&padded, &self.config)?;
            let (speaker, _conf) = self.cluster.assign(&embedding);
            let end = seg_end_sample.min(start + padded.len());
            turns.push(SpeakerTurn {
                speaker,
                time: TimeRange {
                    start: start as f64 / sr_f,
                    end: end as f64 / sr_f,
                },
                text: None,
            });
        }
        Ok(turns)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::DummyExtractor;
    use crate::{EnergyVad, VadConfig};

    const SAMPLE_RATE: f32 = 16000.0;

    /// Pipeline wired with the energy VAD and dummy extractor, using the
    /// default diarization and VAD configurations.
    fn pipeline() -> StreamingPipeline<EnergyVad, DummyExtractor> {
        StreamingPipeline::new(
            EnergyVad::new(-40.0, 16000, 512),
            DummyExtractor::new(256),
            DiarizationConfig::default(),
            VadConfig::default(),
        )
        .unwrap()
    }

    /// Constant-amplitude audio well above the VAD energy threshold.
    fn loud_samples(seconds: f32) -> Vec<f32> {
        vec![0.5f32; (seconds * SAMPLE_RATE) as usize]
    }

    /// Digital silence.
    fn silent_samples(seconds: f32) -> Vec<f32> {
        vec![0.0f32; (seconds * SAMPLE_RATE) as usize]
    }

    #[test]
    fn streaming_pipeline_new_is_empty() {
        let p = pipeline();
        assert_eq!(p.num_speakers(), 0);
        assert_eq!(p.turns().len(), 0);
    }

    #[test]
    fn feed_silence_returns_no_turns() {
        let mut p = pipeline();
        assert!(p.feed(&silent_samples(2.0)).unwrap().is_empty());
        assert!(p.turns().is_empty());
    }

    #[test]
    fn feed_loud_audio_returns_at_least_one_turn() {
        let mut p = pipeline();
        let emitted = p.feed(&loud_samples(5.0)).unwrap();
        assert!(!emitted.is_empty(), "expected at least one turn for 5 s of speech");
    }

    #[test]
    fn flush_after_speech_emits_remaining_turn() {
        let mut p = pipeline();
        p.feed(&loud_samples(1.0)).unwrap();
        let emitted = p.flush().unwrap();
        assert!(!emitted.is_empty(), "flush should emit the trailing partial window");
    }

    #[test]
    fn turns_are_monotonically_ordered() {
        let mut p = pipeline();
        p.feed(&loud_samples(5.0)).unwrap();
        p.flush().unwrap();
        let starts: Vec<f64> = p.turns().iter().map(|t| t.time.start).collect();
        assert!(
            starts.windows(2).all(|pair| pair[0] <= pair[1]),
            "turns must be monotonically ordered"
        );
    }
}