use super::audio_buffer::{AudioBuffer, AudioBufferConfig, AudioChunk};
use crate::audio::samples_to_wav;
use crate::audio::vad::{VadConfig, VadSession};
use crate::execution::{ExecutionTemplate, ModelMetadata, TemplateExecutor};
use crate::ir::{Envelope, EnvelopeKind};
use std::path::Path;
use std::sync::Arc;
use std::time::Duration;
/// Errors reported while setting up or driving a streaming session.
///
/// Each variant wraps a human-readable description of what went wrong.
#[derive(Debug)]
pub enum StreamError {
    /// A model could not be loaded (not raised by the code visible in this file).
    ModelLoadError(String),
    /// Executing the model failed, or it returned something other than text.
    InferenceError(String),
    /// The requested operation is not allowed in the session's current state.
    InvalidState(String),
    /// The model directory or its metadata is missing or unusable.
    ConfigError(String),
}

impl std::fmt::Display for StreamError {
    /// Renders the wrapped description behind a short label for the error category.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ModelLoadError(detail) => write!(f, "Model load error: {}", detail),
            Self::InferenceError(detail) => write!(f, "Inference error: {}", detail),
            Self::InvalidState(detail) => write!(f, "Invalid state: {}", detail),
            Self::ConfigError(detail) => write!(f, "Config error: {}", detail),
        }
    }
}

impl std::error::Error for StreamError {}

/// Shorthand for a `Result` whose error type is [`StreamError`].
pub type StreamResult<T> = Result<T, StreamError>;
/// VAD (voice activity detection) settings for a streaming session.
#[derive(Debug, Clone)]
pub struct VadStreamConfig {
    /// Whether the session should try to set up VAD at all.
    pub enabled: bool,
    /// Directory holding the VAD model; without it an enabled VAD is skipped with a warning.
    pub model_dir: Option<String>,
    /// Forwarded to `VadConfig::threshold`.
    pub threshold: f32,
    /// Forwarded to `VadConfig::min_silence_frames`.
    pub min_silence_frames: usize,
    /// Forwarded to `VadConfig::padding_frames`.
    pub padding_frames: usize,
}

impl Default for VadStreamConfig {
    /// VAD switched off, no model directory, threshold `0.5`, 8 silence frames, 2 padding frames.
    fn default() -> Self {
        VadStreamConfig {
            threshold: 0.5,
            min_silence_frames: 8,
            padding_frames: 2,
            model_dir: None,
            enabled: false,
        }
    }
}

impl VadStreamConfig {
    /// Enabled VAD that loads its model from `model_dir`; other fields keep their defaults.
    pub fn with_model(model_dir: impl Into<String>) -> Self {
        let mut cfg = Self::enabled();
        cfg.model_dir = Some(model_dir.into());
        cfg
    }

    /// Enabled VAD with default tuning and no model directory set.
    pub fn enabled() -> Self {
        let mut cfg = Self::default();
        cfg.enabled = true;
        cfg
    }
}
/// Settings for a [`StreamSession`].
#[derive(Debug, Clone)]
pub struct StreamConfig {
    /// Audio buffer settings. `StreamSession::new` takes `max_buffer_secs` (and, for
    /// non-Whisper models, `overlap_secs`) from here and fixes the remaining values itself.
    pub buffer_config: AudioBufferConfig,
    /// Non-final chunks shorter than this many seconds are skipped without running the model.
    pub min_chunk_secs: f32,
    /// When `true`, `feed` transcribes chunks as soon as they are ready and
    /// `partial_result` may return text.
    pub enable_partial_results: bool,
    /// Language setting; not read by the code visible in this file.
    pub language: Option<String>,
    /// VAD settings.
    pub vad: VadStreamConfig,
}

impl Default for StreamConfig {
    /// Default buffer, 1-second minimum chunk, partial results on, language `"en"`, VAD off.
    fn default() -> Self {
        let language = Some("en".to_string());
        StreamConfig {
            buffer_config: AudioBufferConfig::default(),
            min_chunk_secs: 1.0,
            enable_partial_results: true,
            language,
            vad: VadStreamConfig::default(),
        }
    }
}

impl StreamConfig {
    /// Default configuration with VAD enabled but no VAD model directory.
    ///
    /// `StreamSession::new` prints a warning and runs without VAD when no model directory
    /// is configured, so prefer [`StreamConfig::with_vad_model`] unless the directory is
    /// filled in afterwards.
    pub fn with_vad() -> Self {
        let mut cfg = Self::default();
        cfg.vad = VadStreamConfig::enabled();
        cfg
    }

    /// Default configuration with VAD enabled and loading its model from `model_dir`.
    pub fn with_vad_model(model_dir: impl Into<String>) -> Self {
        let mut cfg = Self::default();
        cfg.vad = VadStreamConfig::with_model(model_dir);
        cfg
    }
}
/// Lifecycle state of a `StreamSession`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamState {
/// No audio accepted yet: the state of a new session and of one that was `reset`.
Idle,
/// `feed` has accepted audio and `flush` has not been called yet.
Streaming,
/// `flush` has started and is draining the audio still held in the buffer.
Finalizing,
/// The stream has ended; `flush` returns the finished transcript.
Completed,
/// The session has failed; `feed` and `flush` return `InvalidState` while in this state.
// NOTE(review): no code visible in this file assigns this state — confirm where it is set.
Error,
}
/// A piece of transcript made available while a stream is still in progress.
///
/// Produced both by `StreamSession::partial_result` and for the `on_partial` callback;
/// the two fill the fields differently, as noted per field.
#[derive(Debug, Clone)]
pub struct PartialResult {
/// Transcribed text: the whole transcript so far for `partial_result`, a single chunk's
/// text for the callback.
pub text: String,
/// Confidence score if known; the code visible in this file always stores `None`.
pub confidence: Option<f32>,
/// `partial_result` stores `false`; the callback receives the chunk's `is_final` flag.
pub is_stable: bool,
/// Audio duration covered: the accumulated total for `partial_result`, one chunk's
/// duration for the callback.
pub audio_duration: Duration,
/// Number of chunks extracted so far for `partial_result`, the chunk's own sequence
/// number for the callback.
pub chunk_sequence: u64,
}
/// Collects finished transcript segments plus one optional in-progress partial.
#[derive(Debug, Default)]
struct TranscriptAccumulator {
    /// Finished segments, stored trimmed and never blank.
    segments: Vec<String>,
    /// Text that is still subject to change, if any.
    current_partial: Option<String>,
    /// Sum of the durations passed to `add_segment`.
    total_duration: Duration,
}

impl TranscriptAccumulator {
    /// Starts with no segments, no partial and zero duration.
    fn new() -> Self {
        Self::default()
    }

    /// Records a finished segment.
    ///
    /// The text is stored trimmed and dropped entirely when blank; the duration is counted
    /// either way, and any pending partial is discarded.
    fn add_segment(&mut self, text: String, duration: Duration) {
        let cleaned = text.trim();
        if !cleaned.is_empty() {
            self.segments.push(cleaned.to_string());
        }
        self.total_duration += duration;
        self.current_partial = None;
    }

    /// Replaces the pending partial with `text`.
    fn set_partial(&mut self, text: String) {
        self.current_partial = Some(text);
    }

    /// Returns the finished segments followed by the trimmed partial (when it is not blank),
    /// separated by single spaces.
    fn get_full_text(&self) -> String {
        let pending = self
            .current_partial
            .as_deref()
            .map(str::trim)
            .filter(|tail| !tail.is_empty());
        match pending {
            Some(tail) => {
                let mut pieces: Vec<&str> = self.segments.iter().map(String::as_str).collect();
                pieces.push(tail);
                pieces.join(" ")
            }
            None => self.segments.join(" "),
        }
    }

    /// Returns only the finished segments, separated by single spaces.
    fn get_stable_text(&self) -> String {
        self.segments.join(" ")
    }

    /// Forgets all segments, the partial and the accumulated duration.
    fn reset(&mut self) {
        self.current_partial = None;
        self.total_duration = Duration::ZERO;
        self.segments.clear();
    }
}
/// A streaming transcription session bound to one model directory.
///
/// Audio is pushed in with `feed`, cut into chunks by the audio buffer, run through the
/// executor chunk by chunk, and the returned text is collected into a transcript.
pub struct StreamSession {
/// Model description parsed from `model_metadata.json`.
metadata: ModelMetadata,
/// Executor created with the model directory as its base path.
executor: TemplateExecutor,
/// Configuration supplied by the caller.
config: StreamConfig,
/// Buffer that receives fed samples and hands out chunks.
buffer: AudioBuffer,
/// Text and duration accumulated from processed chunks.
transcript: TranscriptAccumulator,
/// Current lifecycle state.
state: StreamState,
/// Message shown when the session is in the `Error` state.
// NOTE(review): the visible code only ever writes `None` here — confirm where it is set.
last_error: Option<String>,
/// Callback invoked with the result of each processed chunk, if one was registered.
on_partial: Option<Arc<dyn Fn(PartialResult) + Send + Sync>>,
/// VAD session, present only when VAD was enabled and could be initialised.
vad: Option<VadSession>,
// NOTE(review): the two fields below are only initialised and cleared in the visible code;
// presumably scratch state for VAD-based segmentation — confirm against the rest of the file.
vad_buffer: Vec<f32>,
vad_speech_start: Option<usize>,
}
impl StreamSession {
pub fn new<P: AsRef<Path>>(model_dir: P, config: StreamConfig) -> StreamResult<Self> {
let model_dir = model_dir.as_ref().to_path_buf();
if !model_dir.exists() {
return Err(StreamError::ConfigError(format!(
"Model directory does not exist: {:?}",
model_dir
)));
}
let metadata_path = model_dir.join("model_metadata.json");
if !metadata_path.exists() {
return Err(StreamError::ConfigError(format!(
"model_metadata.json not found in {:?}",
model_dir
)));
}
let metadata_str = std::fs::read_to_string(&metadata_path)
.map_err(|e| StreamError::ConfigError(format!("Failed to read metadata: {}", e)))?;
let metadata: ModelMetadata = serde_json::from_str(&metadata_str)
.map_err(|e| StreamError::ConfigError(format!("Failed to parse metadata: {}", e)))?;
let executor = TemplateExecutor::with_base_path(model_dir.to_str().unwrap_or("."));
let buffer_config = Self::infer_buffer_config(&metadata, &config);
let buffer = AudioBuffer::with_config(buffer_config);
let vad = if config.vad.enabled {
let vad_model_dir = match &config.vad.model_dir {
Some(dir) => dir.clone(),
None => {
eprintln!("[StreamSession] Warning: VAD enabled but no model_dir specified. VAD disabled.");
return Ok(Self {
metadata,
executor,
config,
buffer,
transcript: TranscriptAccumulator::new(),
state: StreamState::Idle,
last_error: None,
on_partial: None,
vad: None,
vad_buffer: Vec::new(),
vad_speech_start: None,
});
}
};
let vad_config = VadConfig {
threshold: config.vad.threshold,
min_silence_frames: config.vad.min_silence_frames,
padding_frames: config.vad.padding_frames,
..VadConfig::default()
};
match VadSession::new(&vad_model_dir, vad_config) {
Ok(vad) => Some(vad),
Err(e) => {
eprintln!("[StreamSession] Warning: Failed to initialize VAD: {}. Falling back to fixed chunking.", e);
None
}
}
} else {
None
};
Ok(Self {
metadata,
executor,
config,
buffer,
transcript: TranscriptAccumulator::new(),
state: StreamState::Idle,
last_error: None,
on_partial: None,
vad,
vad_buffer: Vec::new(),
vad_speech_start: None,
})
}
fn infer_buffer_config(metadata: &ModelMetadata, config: &StreamConfig) -> AudioBufferConfig {
let is_whisper = match &metadata.execution_template {
ExecutionTemplate::SafeTensors { architecture, .. } => {
architecture.as_deref() == Some("whisper")
}
_ => false,
};
if is_whisper {
AudioBufferConfig {
sample_rate: 16000,
chunk_duration_secs: 5.0,
overlap_secs: 0.5, max_buffer_secs: config.buffer_config.max_buffer_secs,
}
} else {
AudioBufferConfig {
sample_rate: 16000,
chunk_duration_secs: 5.0,
overlap_secs: config.buffer_config.overlap_secs,
max_buffer_secs: config.buffer_config.max_buffer_secs,
}
}
}
pub fn model_id(&self) -> &str {
&self.metadata.model_id
}
pub fn on_partial<F>(&mut self, callback: F)
where
F: Fn(PartialResult) + Send + Sync + 'static,
{
self.on_partial = Some(Arc::new(callback));
}
pub fn feed(&mut self, samples: &[f32]) -> StreamResult<()> {
match self.state {
StreamState::Idle => self.state = StreamState::Streaming,
StreamState::Streaming => {}
StreamState::Finalizing | StreamState::Completed => {
return Err(StreamError::InvalidState(
"Cannot feed after stream ended".to_string(),
));
}
StreamState::Error => {
return Err(StreamError::InvalidState(format!(
"Session in error state: {:?}",
self.last_error
)));
}
}
self.buffer.push(samples);
if self.config.enable_partial_results {
self.process_ready_chunks()?;
}
Ok(())
}
pub fn partial_result(&self) -> Option<PartialResult> {
if !self.config.enable_partial_results {
return None;
}
let text = self.transcript.get_full_text();
if text.is_empty() {
return None;
}
Some(PartialResult {
text,
confidence: None,
is_stable: false,
audio_duration: self.transcript.total_duration,
chunk_sequence: self.buffer.stats().chunks_extracted,
})
}
pub fn flush(&mut self) -> StreamResult<String> {
match self.state {
StreamState::Idle => {
self.state = StreamState::Completed;
return Ok(String::new());
}
StreamState::Streaming => {
self.state = StreamState::Finalizing;
}
StreamState::Finalizing => {}
StreamState::Completed => {
return Ok(self.transcript.get_stable_text());
}
StreamState::Error => {
return Err(StreamError::InvalidState(format!(
"Session in error state: {:?}",
self.last_error
)));
}
}
self.buffer.end_stream();
self.process_all_remaining()?;
self.state = StreamState::Completed;
Ok(self.transcript.get_stable_text())
}
pub fn reset(&mut self) {
self.buffer.reset();
self.transcript.reset();
self.state = StreamState::Idle;
self.last_error = None;
if let Some(ref mut vad) = self.vad {
vad.reset();
}
self.vad_buffer.clear();
self.vad_speech_start = None;
}
pub fn has_vad(&self) -> bool {
self.vad.is_some()
}
pub fn state(&self) -> StreamState {
self.state
}
pub fn stats(&self) -> StreamStats {
let buffer_stats = self.buffer.stats();
StreamStats {
state: self.state,
samples_received: buffer_stats.total_received,
samples_processed: buffer_stats.total_processed,
chunks_processed: buffer_stats.chunks_extracted,
transcript_length: self.transcript.get_full_text().len(),
audio_duration: self.transcript.total_duration,
}
}
fn process_ready_chunks(&mut self) -> StreamResult<()> {
while self.buffer.has_chunk_ready() {
if let Some(chunk) = self.buffer.extract_chunk(false) {
self.process_chunk(chunk)?;
}
}
Ok(())
}
fn process_all_remaining(&mut self) -> StreamResult<()> {
self.process_ready_chunks()?;
while self.buffer.has_audio() {
if let Some(chunk) = self.buffer.extract_chunk(true) {
self.process_chunk(chunk)?;
} else {
break;
}
}
Ok(())
}
fn process_chunk(&mut self, chunk: AudioChunk) -> StreamResult<()> {
let min_samples =
(self.config.min_chunk_secs * self.buffer.config().sample_rate as f32) as usize;
if chunk.samples.len() < min_samples && !chunk.is_final {
return Ok(());
}
let wav_bytes = samples_to_wav(&chunk.samples, self.buffer.config().sample_rate);
let envelope = Envelope::new(EnvelopeKind::Audio(wav_bytes));
let output = self
.executor
.execute(&self.metadata, &envelope, None)
.map_err(|e| StreamError::InferenceError(format!("Execution failed: {}", e)))?;
let text = match output.kind {
EnvelopeKind::Text(text) => text,
_ => {
return Err(StreamError::InferenceError(
"Model did not return text output".to_string(),
));
}
};
self.transcript.add_segment(text.clone(), chunk.duration());
if let Some(ref callback) = self.on_partial {
let result = PartialResult {
text,
confidence: None,
is_stable: chunk.is_final,
audio_duration: chunk.duration(),
chunk_sequence: chunk.sequence,
};
callback(result);
}
Ok(())
}
}
/// Snapshot of a session's progress, as returned by `StreamSession::stats`.
#[derive(Debug, Clone)]
pub struct StreamStats {
/// Lifecycle state at the time of the snapshot.
pub state: StreamState,
/// The audio buffer's `total_received` counter.
pub samples_received: u64,
/// The audio buffer's `total_processed` counter.
pub samples_processed: u64,
/// The audio buffer's `chunks_extracted` counter.
pub chunks_processed: u64,
/// Length in bytes of the full transcript text, including any pending partial.
pub transcript_length: usize,
/// Total audio duration accumulated in the transcript.
pub audio_duration: Duration,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration uses a sample rate of 16000 and has partial results on.
    #[test]
    fn test_stream_config_defaults() {
        let defaults = StreamConfig::default();
        assert_eq!(defaults.buffer_config.sample_rate, 16000);
        assert!(defaults.enable_partial_results);
    }

    /// Segments form the stable text; a pending partial only shows up in the full text.
    #[test]
    fn test_transcript_accumulator() {
        let one_second = Duration::from_secs(1);
        let mut transcript = TranscriptAccumulator::new();
        transcript.add_segment("Hello".to_string(), one_second);
        transcript.add_segment("world".to_string(), one_second);
        assert_eq!(transcript.get_stable_text(), "Hello world");
        assert_eq!(transcript.get_full_text(), "Hello world");

        transcript.set_partial("testing".to_string());
        assert_eq!(transcript.get_full_text(), "Hello world testing");
        assert_eq!(transcript.get_stable_text(), "Hello world");
    }

    /// Resetting drops both the collected text and the accumulated duration.
    #[test]
    fn test_transcript_accumulator_reset() {
        let mut transcript = TranscriptAccumulator::new();
        transcript.add_segment("Hello".to_string(), Duration::from_secs(1));
        transcript.reset();
        assert_eq!(transcript.get_stable_text(), "");
        assert_eq!(transcript.total_duration, Duration::ZERO);
    }
}