use crate::{
audio::{playback::AudioOutputStream, sts::pipeline::voice_pipeline::VoicePipeline},
error::{InvariantViolationPayload, Result},
};
use super::{
barge_in::BargeInDetector,
chunker::{AudioChunker, PreRollBuffer},
config::VoicePipelineConfig,
turn_taking::TurnTakingPolicy,
};
/// Voice-activity detector plugged into [`VoiceSession`].
pub trait VadFrameAdapter {
/// Classifies one block of `f32` samples: `Ok(true)` for speech, `Ok(false)` otherwise.
///
/// `VoiceSession::step` calls this once for every chunk its chunker produces and
/// propagates any `Err` to its own caller.
fn is_speech(&mut self, frame: &[f32]) -> Result<bool>;
}
/// Speech-to-text backend that transcribes one complete user turn at a time.
pub trait SttTurnAdapter {
/// Transcribes the audio of a whole turn and returns the recognized text.
///
/// `VoiceSession` passes the samples it accumulated for the turn (pre-roll, speech
/// chunks and trailing non-speech chunks) when the turn is finalized.
fn transcribe_turn(&mut self, turn_audio: &[f32]) -> Result<String>;
}
/// Text responder that turns the user's transcribed text into the assistant's reply.
pub trait LlmResponderAdapter {
/// Produces the assistant reply for `user_text`.
///
/// `VoiceSession` calls this with the string returned by
/// `SttTurnAdapter::transcribe_turn` for the same turn.
fn respond(&mut self, user_text: &str) -> Result<String>;
}
/// Streaming text-to-speech backend.
pub trait TtsStreamAdapter {
/// Starts synthesizing `text` and returns an iterator over audio chunks.
///
/// Every item is a `Result<Vec<f32>>`. The boxed iterator borrows `self` mutably
/// for `'a`, so the adapter cannot be used again until the iterator is dropped.
/// `VoiceSession` writes each `Ok` chunk to its audio output and returns the
/// first `Err` it encounters.
fn synthesize_stream<'a>(
&'a mut self,
text: &str,
) -> Result<Box<dyn Iterator<Item = Result<Vec<f32>>> + 'a>>;
/// Sample rate reported by this backend.
// NOTE(review): presumably the rate (Hz) of the samples yielded by
// `synthesize_stream`; nothing in this file reads it — confirm with implementors.
fn sample_rate(&self) -> u32;
}
/// Immutable record describing one finalized conversational turn.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TurnEvent {
    /// Chunk count supplied by the creator of the event.
    chunks_consumed: usize,
    /// Text attributed to the user for this turn.
    user_text: String,
    /// Text the assistant answered with.
    assistant_text: String,
    /// Whether a barge-in was flagged while the turn was in progress.
    barge_in_observed: bool,
}

impl TurnEvent {
    /// Assembles an event from its four components; no validation is performed.
    #[must_use]
    pub fn new(
        chunks_consumed: usize,
        user_text: String,
        assistant_text: String,
        barge_in_observed: bool,
    ) -> Self {
        TurnEvent { chunks_consumed, user_text, assistant_text, barge_in_observed }
    }

    /// Chunk count stored at construction time.
    #[inline(always)]
    #[must_use]
    pub fn chunks_consumed(&self) -> usize {
        self.chunks_consumed
    }

    /// Borrowed view of the user's text.
    #[inline(always)]
    #[must_use]
    pub fn user_text(&self) -> &str {
        self.user_text.as_str()
    }

    /// Borrowed view of the assistant's text.
    #[inline(always)]
    #[must_use]
    pub fn assistant_text(&self) -> &str {
        self.assistant_text.as_str()
    }

    /// `true` when a barge-in was recorded for this turn.
    #[inline(always)]
    #[must_use]
    pub fn barge_in_observed(&self) -> bool {
        self.barge_in_observed
    }
}
/// Speech-to-speech session state: the pipeline collaborators plus the bookkeeping
/// needed to segment incoming audio into user turns.
///
/// The type parameters are unconstrained here; the trait bounds live on the `impl`
/// blocks (`V` VAD, `S` STT, `L` LLM, `T` TTS, `C` chunker, `B` barge-in detector,
/// `P` turn-taking policy).
pub struct VoiceSession<V, S, L, T, C, B, P> {
/// Pipeline configuration; the session reads the input sample rate, the pre-roll
/// length, and the barge-in and play-audio switches from it.
config: VoicePipelineConfig,
/// Voice-activity detector, queried once per chunk.
vad: V,
/// Transcriber invoked when a turn is finalized.
stt: S,
/// Responder invoked with the transcription of a finalized turn.
llm: L,
/// Synthesizer used to stream the reply when `config.play_audio()` is true.
tts: T,
/// Splits pushed frames into the chunks that the rest of the session works on.
chunker: C,
/// Barge-in detector, consulted on speech chunks when `config.barge_in()` is true.
barge_in: B,
/// Decides, from the turn audio and the accumulated silence, when the user is done.
turn_policy: P,
/// Non-speech chunks seen while no turn is active; prepended to the turn at speech onset.
preroll: PreRollBuffer,
/// One entry per finalized turn; cleared at the start of `run`.
events: Vec<TurnEvent>,
/// Samples collected for the turn currently being captured.
in_progress_audio: Vec<f32>,
/// `true` from the first speech chunk of a turn until that turn is finalized.
in_speech: bool,
/// Milliseconds of non-speech chunks since the last speech chunk of the active turn.
silence_ms_accum: u32,
/// Set when the barge-in detector fires during the active turn.
current_turn_barge_in: bool,
/// Number of chunks processed by `step` so far.
total_chunks_consumed: usize,
}
impl<V, S, L, T, C, B, P> VoiceSession<V, S, L, T, C, B, P>
where
    V: VadFrameAdapter,
    S: SttTurnAdapter,
    L: LlmResponderAdapter,
    T: TtsStreamAdapter,
    C: AudioChunker,
    B: BargeInDetector,
    P: TurnTakingPolicy,
{
    /// Creates an idle session (no turn in progress, no events recorded).
    ///
    /// The pre-roll buffer is sized to `input_sample_rate * preroll_ms / 1000` samples.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvariantViolation` when `config.input_sample_rate()` is `0`,
    /// because `step` divides by the sample rate.
    // One argument per pipeline collaborator; bundling them would only move the
    // same list into an extra type.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        config: VoicePipelineConfig,
        vad: V,
        stt: S,
        llm: L,
        tts: T,
        chunker: C,
        barge_in: B,
        turn_policy: P,
    ) -> Result<Self> {
        if config.input_sample_rate() == 0 {
            return Err(crate::error::Error::InvariantViolation(
                InvariantViolationPayload::new(
                    "VoiceSession::new: VoicePipelineConfig::input_sample_rate",
                    "must be > 0; the orchestrator's per-chunk silence-ms accounting divides by the sample rate",
                ),
            ));
        }
        let preroll_samples =
            (config.input_sample_rate() as usize) * (config.preroll_ms() as usize) / 1_000;
        Ok(Self {
            config,
            vad,
            stt,
            llm,
            tts,
            chunker,
            barge_in,
            turn_policy,
            preroll: PreRollBuffer::new(preroll_samples),
            events: Vec::new(),
            in_progress_audio: Vec::new(),
            in_speech: false,
            silence_ms_accum: 0,
            current_turn_barge_in: false,
            total_chunks_consumed: 0,
        })
    }

    /// Events recorded for the turns finalized so far, oldest first.
    #[must_use]
    pub fn turn_events(&self) -> &[TurnEvent] {
        &self.events
    }

    /// Number of chunks `step` has processed so far.
    #[must_use]
    pub fn total_chunks_consumed(&self) -> usize {
        self.total_chunks_consumed
    }

    /// Shared access to the voice-activity detector.
    #[must_use]
    pub fn vad(&self) -> &V {
        &self.vad
    }

    /// Shared access to the speech-to-text adapter.
    #[must_use]
    pub fn stt(&self) -> &S {
        &self.stt
    }

    /// Shared access to the responder adapter.
    #[must_use]
    pub fn llm(&self) -> &L {
        &self.llm
    }

    /// Shared access to the text-to-speech adapter.
    #[must_use]
    pub fn tts(&self) -> &T {
        &self.tts
    }

    /// Feeds one frame of input samples through the session.
    ///
    /// The frame is handed to the chunker; every chunk it yields is classified by the
    /// VAD and then either opens/extends the current turn, counts towards the turn's
    /// trailing silence, or (when idle) is stored in the pre-roll buffer. A turn is
    /// finalized as soon as the turn-taking policy reports that the user has finished.
    ///
    /// * `frame` - input samples to push into the chunker.
    /// * `output` - sink that receives synthesized audio for turns finalized here.
    /// * `tts_playing` - forwarded unchanged to the barge-in detector.
    ///
    /// Returns the number of turns finalized during this call.
    ///
    /// # Errors
    ///
    /// Propagates errors from the chunker, the VAD, and turn finalization.
    pub fn step<O: AudioOutputStream>(
        &mut self,
        frame: &[f32],
        output: &mut O,
        tts_playing: bool,
    ) -> Result<usize> {
        let mut turns_finalized = 0;
        let chunks = self.chunker.push_samples(frame)?;
        // Non-zero by construction: `new` rejects a zero sample rate.
        let sample_rate = self.config.input_sample_rate() as u64;
        for chunk in chunks {
            self.total_chunks_consumed += 1;
            // Clamp instead of wrapping when narrowing to u32, matching the
            // saturating accumulation below.
            // NOTE(review): the division floors, so chunks that are not a whole number
            // of milliseconds under-count silence (sub-millisecond chunks count as 0).
            let chunk_ms = ((chunk.len() as u64).saturating_mul(1_000) / sample_rate)
                .min(u64::from(u32::MAX)) as u32;
            let is_speech = self.vad.is_speech(&chunk)?;
            if is_speech {
                if !self.in_speech {
                    // Speech onset: start a new turn seeded with the buffered pre-roll.
                    self.current_turn_barge_in = false;
                    let preroll_snapshot = self.preroll.snapshot();
                    self.in_progress_audio.extend_from_slice(&preroll_snapshot);
                    self.preroll.clear();
                    self.in_speech = true;
                }
                self.in_progress_audio.extend_from_slice(&chunk);
                self.silence_ms_accum = 0;
            } else if self.in_speech {
                // Trailing silence stays part of the turn audio and feeds the policy.
                self.in_progress_audio.extend_from_slice(&chunk);
                self.silence_ms_accum = self.silence_ms_accum.saturating_add(chunk_ms);
                if self
                    .turn_policy
                    .user_finished(&self.in_progress_audio, self.silence_ms_accum)
                {
                    self.finalize_turn(output)?;
                    turns_finalized += 1;
                }
            } else {
                // Idle: remember the audio so a later speech onset can include it.
                self.preroll.append(&chunk);
            }
            if self.config.barge_in()
                && is_speech
                && self.in_speech
                && self.barge_in.detect(&chunk, tts_playing)
            {
                self.current_turn_barge_in = true;
            }
        }
        Ok(turns_finalized)
    }

    /// Finalizes the turn currently being captured, if there is one.
    ///
    /// Samples still buffered inside the chunker are drained first and appended to the
    /// turn when a turn is active; when no turn is active they are discarded.
    ///
    /// Returns `Ok(true)` when a turn was finalized and `Ok(false)` when there was
    /// nothing to flush.
    ///
    /// # Errors
    ///
    /// Propagates errors from turn finalization.
    pub fn flush_in_progress_turn<O: AudioOutputStream>(&mut self, output: &mut O) -> Result<bool> {
        let residual = self.chunker.drain_residual();
        if self.in_speech && !residual.is_empty() {
            self.in_progress_audio.extend_from_slice(&residual);
        }
        if !self.in_speech || self.in_progress_audio.is_empty() {
            return Ok(false);
        }
        self.finalize_turn(output)?;
        Ok(true)
    }

    /// Runs transcription, response generation and (optionally) playback for the
    /// captured turn, then records a [`TurnEvent`].
    ///
    /// The capture state is reset before any adapter is called, so the session is idle
    /// again even when this returns an error; in that case no event is recorded.
    fn finalize_turn<O: AudioOutputStream>(&mut self, output: &mut O) -> Result<()> {
        let turn_audio = std::mem::take(&mut self.in_progress_audio);
        self.in_speech = false;
        self.silence_ms_accum = 0;
        self.preroll.clear();
        self.chunker.reset();
        let barge_in_observed = std::mem::take(&mut self.current_turn_barge_in);

        let user_text = self.stt.transcribe_turn(&turn_audio)?;
        // `respond` only borrows the transcription, so it can be moved into the event
        // afterwards without cloning.
        let assistant_text = self.llm.respond(&user_text)?;

        if self.config.play_audio() {
            let stream = self.tts.synthesize_stream(&assistant_text)?;
            for chunk in stream {
                let samples = chunk?;
                // The sink may accept fewer samples than offered; keep writing the rest.
                let mut written = 0;
                while written < samples.len() {
                    let n = output.write_samples(&samples[written..])?;
                    if n == 0 {
                        // No progress would loop forever; surface it as an error.
                        return Err(crate::error::Error::InvariantViolation(
                            InvariantViolationPayload::new(
                                "VoiceSession: audio sink",
                                "rejected TTS chunk (write_samples returned 0)",
                            ),
                        ));
                    }
                    written += n;
                }
            }
        }

        self.events.push(TurnEvent::new(
            self.total_chunks_consumed,
            user_text,
            assistant_text,
            barge_in_observed,
        ));
        Ok(())
    }
}
/// Drives a [`VoiceSession`] over a whole stream of microphone frames.
impl<V, S, L, T, C, B, P> VoicePipeline for VoiceSession<V, S, L, T, C, B, P>
where
    V: VadFrameAdapter,
    S: SttTurnAdapter,
    L: LlmResponderAdapter,
    T: TtsStreamAdapter,
    C: AudioChunker,
    B: BargeInDetector,
    P: TurnTakingPolicy,
{
    /// Configuration the session was constructed with.
    fn config(&self) -> &VoicePipelineConfig {
        &self.config
    }

    /// Processes every frame of `mic_input`, then finalizes any turn still in
    /// progress and flushes `output`.
    ///
    /// Previously recorded turn events are discarded before the first frame is
    /// processed. The first error from stepping, flushing the turn, or flushing the
    /// output ends the run and is returned.
    fn run<I, O>(&mut self, mic_input: I, mut output: O) -> Result<()>
    where
        I: Iterator<Item = Vec<f32>>,
        O: AudioOutputStream,
    {
        self.events.clear();

        for frame in mic_input {
            // Sample the sink state before stepping; it is what the barge-in
            // detector is told about playback for this frame.
            let playback_active = output.is_running();
            self.step(frame.as_slice(), &mut output, playback_active)?;
        }

        // End of input: whatever was captured so far becomes the last turn.
        self.flush_in_progress_turn(&mut output)?;
        output.flush()?;
        Ok(())
    }
}
#[cfg(test)]
mod tests;