use crate::stt::SttProvider;
use crate::Result;
use std::sync::Arc;
/// Tuning parameters for [`ChunkOverlapStreamer`].
#[derive(Debug, Clone)]
pub struct StreamingConfig {
    /// Amount of newly fed audio, in milliseconds, that must accumulate before the
    /// streamer re-transcribes its buffered window.
    pub chunk_increment_ms: u32,
    /// Maximum amount of audio, in milliseconds, kept in the buffer; the oldest
    /// samples are discarded once this length is exceeded.
    pub max_window_ms: u32,
}

impl Default for StreamingConfig {
    /// Decode every 1.5 s of new audio over a rolling window of at most 30 s.
    fn default() -> Self {
        let chunk_increment_ms = 1_500;
        let max_window_ms = 30 * 1_000;
        StreamingConfig {
            chunk_increment_ms,
            max_window_ms,
        }
    }
}
/// One intermediate transcription result produced by `ChunkOverlapStreamer::feed`.
#[derive(Debug, Clone, PartialEq)]
pub struct StreamingPartial {
    /// Transcript of the entire buffered window at the time of this decode.
    pub text: String,
    /// Character-wise common prefix of `text` and the previous partial's text;
    /// empty for the first partial. NOTE(review): this is a raw character prefix,
    /// so it may end in the middle of a word (e.g. "hell" for "hello"/"hellp").
    pub stable_prefix: String,
    /// Length of the buffered audio, in milliseconds, when this partial was decoded.
    pub duration_ms: u64,
}
/// Incremental transcriber that re-transcribes its whole (bounded) audio buffer
/// through an [`SttProvider`] each time enough new audio has been fed, and reports
/// which part of the transcript stayed unchanged between decodes.
pub struct ChunkOverlapStreamer {
    /// Transcription backend; always invoked on the entire buffer.
    provider: Arc<dyn SttProvider>,
    /// Decode cadence and window size.
    config: StreamingConfig,
    /// Samples per second of the fed audio; used to convert milliseconds to samples.
    sample_rate: u32,
    /// Most recent audio, trimmed from the front to at most `config.max_window_ms`.
    buffer: Vec<f32>,
    /// Text of the previous partial, used to compute the next `stable_prefix`.
    last_partial: Option<String>,
    /// `buffer.len()` at the last decode, shifted down (saturating) whenever the
    /// buffer is trimmed, so `buffer.len() - last_decode_len` counts the buffered
    /// samples that arrived after the last decode.
    last_decode_len: usize,
}
impl ChunkOverlapStreamer {
    /// Creates a streamer that transcribes audio sampled at `sample_rate` Hz through
    /// `provider`, using `config` for the decode cadence and window size.
    ///
    /// The buffer starts empty; nothing is transcribed until `feed` has accumulated
    /// enough new audio.
    pub fn new(
        provider: Arc<dyn SttProvider>,
        sample_rate: u32,
        config: StreamingConfig,
    ) -> Self {
        Self {
            provider,
            config,
            sample_rate,
            buffer: Vec::new(),
            last_partial: None,
            last_decode_len: 0,
        }
    }

    /// Appends `samples` to the rolling window and, once enough new audio has
    /// arrived since the previous decode, re-transcribes the whole window.
    ///
    /// The decode threshold is `chunk_increment_ms` worth of samples, capped at the
    /// window size (`max_window_ms`). Without the cap, a window smaller than the
    /// increment could never hold enough new samples, and no partial would ever be
    /// emitted. A call that adds no new audio never triggers a decode.
    ///
    /// Returns `Ok(None)` when no decode happened. Otherwise returns the transcript,
    /// its common prefix with the previous partial (empty for the first partial),
    /// and the buffered duration in milliseconds.
    ///
    /// # Errors
    ///
    /// Propagates any error from the provider's `transcribe`. The decode position is
    /// advanced before transcribing. After a failure, the next decode therefore
    /// waits for another full increment of new audio.
    pub async fn feed(&mut self, samples: &[f32]) -> Result<Option<StreamingPartial>> {
        self.buffer.extend_from_slice(samples);
        self.trim_to_max_window();

        // After trimming, at most `max_samples` undecoded samples can be buffered,
        // so a larger threshold would suppress every partial.
        let max_samples = self.ms_to_samples(self.config.max_window_ms);
        let increment_samples = self
            .ms_to_samples(self.config.chunk_increment_ms)
            .min(max_samples);
        let new_samples = self.buffer.len().saturating_sub(self.last_decode_len);
        // `new_samples == 0` also covers a zero increment (or zero sample rate):
        // re-transcribing unchanged audio would only repeat the previous result.
        if new_samples == 0 || new_samples < increment_samples {
            return Ok(None);
        }

        self.last_decode_len = self.buffer.len();
        let text = self
            .provider
            .transcribe(&self.buffer, self.sample_rate)
            .await?;

        let stable_prefix = self
            .last_partial
            .as_deref()
            .map(|prev| longest_common_prefix(prev, &text))
            .unwrap_or_default();
        let duration_ms = self.buffer_duration_ms();
        self.last_partial = Some(text.clone());

        Ok(Some(StreamingPartial {
            text,
            stable_prefix,
            duration_ms,
        }))
    }

    /// Transcribes whatever audio is still buffered and resets the streamer to its
    /// freshly-constructed state (empty buffer, no previous partial).
    ///
    /// Returns an empty string without calling the provider if the buffer is empty.
    ///
    /// # Errors
    ///
    /// Propagates any error from the provider's `transcribe`. In that case the
    /// buffered audio and streaming state are left untouched, so the call can be
    /// retried.
    pub async fn finalize(&mut self) -> Result<String> {
        let text = if self.buffer.is_empty() {
            String::new()
        } else {
            self.provider
                .transcribe(&self.buffer, self.sample_rate)
                .await?
        };
        self.buffer.clear();
        self.last_partial = None;
        self.last_decode_len = 0;
        Ok(text)
    }

    /// Duration of the currently buffered audio in whole milliseconds (rounded
    /// down), or 0 when `sample_rate` is 0.
    pub fn buffer_duration_ms(&self) -> u64 {
        if self.sample_rate == 0 {
            return 0;
        }
        (self.buffer.len() as u64 * 1000) / u64::from(self.sample_rate)
    }

    /// Drops the oldest samples so the buffer holds at most `max_window_ms` of audio.
    /// `last_decode_len` is shifted by the same amount, so that it keeps
    /// referring to the same audio position.
    fn trim_to_max_window(&mut self) {
        let max_samples = self.ms_to_samples(self.config.max_window_ms);
        if self.buffer.len() > max_samples {
            let drop = self.buffer.len() - max_samples;
            self.buffer.drain(..drop);
            self.last_decode_len = self.last_decode_len.saturating_sub(drop);
        }
    }

    /// Converts a duration in milliseconds to a sample count at `sample_rate`,
    /// rounding down. The result saturates at `usize::MAX` instead of silently
    /// truncating on narrow targets.
    fn ms_to_samples(&self, ms: u32) -> usize {
        let samples = u64::from(ms) * u64::from(self.sample_rate) / 1000;
        usize::try_from(samples).unwrap_or(usize::MAX)
    }
}
/// Returns, as an owned string, the longest run of leading `char`s shared by
/// `a` and `b`.
///
/// Comparison is per Unicode scalar value, so the cut always lands on a character
/// boundary. The result is empty when the first characters differ or either input
/// is empty.
fn longest_common_prefix(a: &str, b: &str) -> String {
    // Sum the UTF-8 widths of the matching chars to find the byte end of the prefix in `a`.
    let shared_bytes: usize = a
        .chars()
        .zip(b.chars())
        .take_while(|(left, right)| left == right)
        .map(|(shared, _)| shared.len_utf8())
        .sum();
    a[..shared_bytes].to_owned()
}
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Mutex;

    /// Test double that returns pre-scripted transcripts in order, one per
    /// `transcribe` call, and an empty string once the script is exhausted.
    struct ScriptedProvider {
        // Stored reversed so `pop` yields the next scripted output.
        outputs: Mutex<Vec<String>>,
    }

    impl ScriptedProvider {
        fn new(outputs: Vec<&str>) -> Self {
            Self {
                outputs: Mutex::new(outputs.into_iter().rev().map(str::to_string).collect()),
            }
        }

        /// Number of scripted outputs not yet consumed by `transcribe`.
        fn remaining(&self) -> usize {
            self.outputs.lock().unwrap().len()
        }
    }

    #[async_trait]
    impl SttProvider for ScriptedProvider {
        async fn transcribe(&self, _samples: &[f32], _sr: u32) -> Result<String> {
            let mut outs = self.outputs.lock().unwrap();
            Ok(outs.pop().unwrap_or_default())
        }
    }

    #[test]
    fn lcp_handles_disjoint_prefix() {
        assert_eq!(longest_common_prefix("hello", "world"), "");
        assert_eq!(longest_common_prefix("hello", "hellp"), "hell");
        assert_eq!(longest_common_prefix("hello", "hello!"), "hello");
        assert_eq!(longest_common_prefix("", "anything"), "");
    }

    #[test]
    fn lcp_is_char_aware_for_multibyte() {
        assert_eq!(longest_common_prefix("héllo", "héllo!"), "héllo");
        assert_eq!(longest_common_prefix("hé", "hè"), "h");
    }

    #[tokio::test]
    async fn feed_returns_none_below_threshold() {
        let scripted = Arc::new(ScriptedProvider::new(vec!["never"]));
        let provider: Arc<dyn SttProvider> = scripted.clone();
        let config = StreamingConfig {
            chunk_increment_ms: 500,
            ..Default::default()
        };
        let mut streamer = ChunkOverlapStreamer::new(provider, 16_000, config);
        let result = streamer.feed(&vec![0.0f32; 1600]).await.unwrap();
        assert!(result.is_none());
        // Below the threshold the provider must not be consulted at all.
        assert_eq!(scripted.remaining(), 1);
    }

    #[tokio::test]
    async fn feed_accumulates_samples_across_calls() {
        let provider: Arc<dyn SttProvider> = Arc::new(ScriptedProvider::new(vec!["first"]));
        let config = StreamingConfig {
            chunk_increment_ms: 500,
            ..Default::default()
        };
        let mut streamer = ChunkOverlapStreamer::new(provider, 16_000, config);
        // 300 ms per call: only the second call reaches the 500 ms threshold.
        assert!(streamer.feed(&vec![0.0f32; 4800]).await.unwrap().is_none());
        let partial = streamer.feed(&vec![0.0f32; 4800]).await.unwrap().unwrap();
        assert_eq!(partial.text, "first");
        assert_eq!(partial.stable_prefix, "");
        assert_eq!(partial.duration_ms, 600);
    }

    #[tokio::test]
    async fn feed_emits_partial_above_threshold_with_stable_prefix() {
        let provider: Arc<dyn SttProvider> = Arc::new(ScriptedProvider::new(vec![
            "hello",
            "hello world",
            "hello world how",
        ]));
        let config = StreamingConfig {
            chunk_increment_ms: 500,
            ..Default::default()
        };
        let mut streamer = ChunkOverlapStreamer::new(provider, 16_000, config);
        let p1 = streamer.feed(&vec![0.0f32; 9600]).await.unwrap().unwrap();
        assert_eq!(p1.text, "hello");
        assert_eq!(p1.stable_prefix, "");
        assert_eq!(p1.duration_ms, 600);
        let p2 = streamer.feed(&vec![0.0f32; 9600]).await.unwrap().unwrap();
        assert_eq!(p2.text, "hello world");
        assert_eq!(p2.stable_prefix, "hello");
        assert_eq!(p2.duration_ms, 1200);
        let p3 = streamer.feed(&vec![0.0f32; 9600]).await.unwrap().unwrap();
        assert_eq!(p3.text, "hello world how");
        assert_eq!(p3.stable_prefix, "hello world");
        assert_eq!(p3.duration_ms, 1800);
    }

    #[tokio::test]
    async fn finalize_returns_full_transcript_and_resets() {
        let provider: Arc<dyn SttProvider> =
            Arc::new(ScriptedProvider::new(vec!["hello world how are you"]));
        let mut streamer = ChunkOverlapStreamer::new(provider, 16_000, StreamingConfig::default());
        streamer.buffer.extend(vec![0.0f32; 16_000]);
        let final_text = streamer.finalize().await.unwrap();
        assert_eq!(final_text, "hello world how are you");
        assert!(streamer.buffer.is_empty());
        assert!(streamer.last_partial.is_none());
        assert_eq!(streamer.last_decode_len, 0);
    }

    #[tokio::test]
    async fn finalize_on_empty_buffer_returns_empty_string() {
        let scripted = Arc::new(ScriptedProvider::new(vec!["unused"]));
        let provider: Arc<dyn SttProvider> = scripted.clone();
        let mut streamer = ChunkOverlapStreamer::new(provider, 16_000, StreamingConfig::default());
        let final_text = streamer.finalize().await.unwrap();
        assert_eq!(final_text, "");
        // An empty buffer must short-circuit without calling the provider.
        assert_eq!(scripted.remaining(), 1);
    }

    #[tokio::test]
    async fn buffer_trims_to_max_window() {
        let provider: Arc<dyn SttProvider> = Arc::new(ScriptedProvider::new(vec!["x"; 100]));
        let config = StreamingConfig {
            chunk_increment_ms: 100,
            max_window_ms: 1000,
        };
        let mut streamer = ChunkOverlapStreamer::new(provider, 16_000, config);
        // Tag each 100 ms chunk with its index so we can see which audio survives.
        for chunk in 0..30u8 {
            let _ = streamer.feed(&vec![f32::from(chunk); 1600]).await.unwrap();
        }
        // 30 x 100 ms into a 1000 ms window: exactly the newest 10 chunks remain.
        assert_eq!(streamer.buffer.len(), 16_000);
        assert_eq!(streamer.buffer.first().copied(), Some(20.0));
        assert_eq!(streamer.buffer.last().copied(), Some(29.0));
        assert_eq!(streamer.buffer_duration_ms(), 1000);
    }

    #[test]
    fn buffer_duration_ms_converts_and_guards_zero_rate() {
        let provider: Arc<dyn SttProvider> = Arc::new(ScriptedProvider::new(vec![]));
        let mut streamer =
            ChunkOverlapStreamer::new(provider.clone(), 16_000, StreamingConfig::default());
        streamer.buffer.extend_from_slice(&[0.0; 8_000]);
        assert_eq!(streamer.buffer_duration_ms(), 500);

        let mut no_rate = ChunkOverlapStreamer::new(provider, 0, StreamingConfig::default());
        no_rate.buffer.extend_from_slice(&[0.0; 8_000]);
        assert_eq!(no_rate.buffer_duration_ms(), 0);
    }
}