use std::collections::BTreeMap;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures::Stream;
use tokio::sync::{Semaphore, mpsc};
use crate::error::TTSResult;
use crate::provider::TTSSpeechProvider;
use crate::types::{AudioChunk, SpeechRequest};
use super::chunker::{ChunkerConfig, SentenceChunker};
/// Turns a stream of text tokens into an ordered stream of synthesized audio.
///
/// Tokens are grouped into sentences by a `SentenceChunker`; each sentence is
/// synthesized by the shared provider in a spawned task, and results are
/// re-ordered into sentence order by `OrderedAudioStream`.
pub struct StreamingTtsPipeline<T: TTSSpeechProvider + Send + Sync + 'static> {
    // Shared TTS backend; an Arc clone is handed to each spawned synthesis task.
    tts: Arc<T>,
    // Sentence-chunking parameters passed to the producer task.
    config: ChunkerConfig,
}
impl<T: TTSSpeechProvider + Send + Sync + 'static> StreamingTtsPipeline<T> {
    /// Builds a pipeline around `tts` with the default chunker configuration.
    pub fn new(tts: Arc<T>) -> Self {
        Self::with_config(tts, ChunkerConfig::default())
    }

    /// Builds a pipeline around `tts` with an explicit chunker configuration.
    pub fn with_config(tts: Arc<T>, config: ChunkerConfig) -> Self {
        Self { tts, config }
    }

    /// Consumes `token_stream`, synthesizing each detected sentence with the
    /// voice/format settings of `base_request`, and returns a stream of audio
    /// chunks ordered by original sentence position.
    pub fn run<S>(&self, token_stream: S, base_request: SpeechRequest) -> OrderedAudioStream
    where
        S: Stream<Item = String> + Send + 'static,
    {
        let (result_tx, result_rx) = mpsc::channel::<(usize, TTSResult<Vec<AudioChunk>>)>(32);
        let provider = Arc::clone(&self.tts);
        let chunker_config = self.config.clone();
        // Producer runs in the background; the returned stream is fed via the channel.
        tokio::spawn(async move {
            Self::producer_task(token_stream, chunker_config, provider, base_request, result_tx)
                .await;
        });
        OrderedAudioStream::new(result_rx)
    }

    /// Reads tokens, feeds them through the sentence chunker, and spawns one
    /// TTS task per completed sentence, tagging each with a sequence index.
    async fn producer_task<S>(
        token_stream: S,
        config: ChunkerConfig,
        tts: Arc<T>,
        base_request: SpeechRequest,
        result_tx: mpsc::Sender<(usize, TTSResult<Vec<AudioChunk>>)>,
    ) where
        S: Stream<Item = String> + Send + 'static,
    {
        use futures::StreamExt;
        let mut chunker = SentenceChunker::with_config(config);
        let mut next_idx: usize = 0;
        // Single-permit semaphore: spawned tasks run their TTS calls one at a time.
        let gate = Arc::new(Semaphore::new(1));
        let mut tokens = std::pin::pin!(token_stream);
        while let Some(token) = tokens.next().await {
            for sentence in chunker.push_token(&token) {
                Self::spawn_tts_task(
                    next_idx,
                    sentence,
                    Arc::clone(&tts),
                    base_request.clone(),
                    result_tx.clone(),
                    Arc::clone(&gate),
                );
                next_idx += 1;
            }
        }
        // Flush whatever partial sentence remains once the token stream ends.
        if let Some(tail) = chunker.force_flush() {
            Self::spawn_tts_task(next_idx, tail, tts, base_request, result_tx, gate);
        }
    }

    /// Spawns a task that synthesizes `sentence` and sends `(seq_idx, result)`
    /// back on `result_tx`. Send failures are ignored: the consumer is gone.
    fn spawn_tts_task(
        seq_idx: usize,
        sentence: String,
        tts: Arc<T>,
        base_request: SpeechRequest,
        result_tx: mpsc::Sender<(usize, TTSResult<Vec<AudioChunk>>)>,
        semaphore: Arc<Semaphore>,
    ) {
        tokio::spawn(async move {
            // A closed semaphore means the pipeline is shutting down; bail quietly.
            let Ok(_permit) = semaphore.acquire().await else {
                return;
            };
            let request = SpeechRequest {
                text: sentence,
                voice: base_request.voice,
                format: base_request.format,
                sample_rate: base_request.sample_rate,
            };
            // One non-final chunk per sentence; errors are forwarded as-is.
            let result = tts.generate_speech(request).await.map(|response| {
                vec![AudioChunk {
                    samples: response.audio.samples,
                    sample_rate: response.audio.sample_rate,
                    is_final: false,
                }]
            });
            let _ = result_tx.send((seq_idx, result)).await;
        });
    }
}
/// `Stream` adapter that re-orders `(sequence index, result)` pairs arriving
/// out of order on `result_rx` and yields them in ascending sequence order.
pub struct OrderedAudioStream {
    // Results arrive here from concurrently running TTS tasks.
    result_rx: mpsc::Receiver<(usize, TTSResult<Vec<AudioChunk>>)>,
    // Out-of-order results parked until their sequence index becomes next.
    buffer: BTreeMap<usize, TTSResult<Vec<AudioChunk>>>,
    // Next sequence index to emit.
    next_seq: usize,
    // Chunks of the sequence currently being emitted, stored reversed so
    // pop() yields them in original order.
    pending_chunks: Vec<AudioChunk>,
    // Set once the producer side of the channel has been dropped.
    channel_closed: bool,
    // Set once the stream has yielded its terminal `None`.
    done: bool,
    // Highest sequence index observed. NOTE(review): written in poll_next but
    // never read anywhere in this file — candidate for removal.
    max_seq_seen: Option<usize>,
}
impl OrderedAudioStream {
    /// Wraps a receiver of `(sequence index, synthesis result)` pairs.
    fn new(result_rx: mpsc::Receiver<(usize, TTSResult<Vec<AudioChunk>>)>) -> Self {
        Self {
            result_rx,
            buffer: BTreeMap::new(),
            next_seq: 0,
            pending_chunks: Vec::new(),
            channel_closed: false,
            done: false,
            max_seq_seen: None,
        }
    }

    /// True when the producer is gone, results are still buffered, but the
    /// next expected sequence index never arrived (its task died or was
    /// dropped before sending).
    fn is_seq_missing(&self) -> bool {
        if !self.channel_closed {
            return false;
        }
        if self.buffer.is_empty() {
            return false;
        }
        if self.buffer.contains_key(&self.next_seq) {
            return false;
        }
        true
    }

    /// Jumps `next_seq` forward to the smallest buffered index so a gap left
    /// by a failed task cannot stall the stream forever.
    fn skip_to_next_available(&mut self) {
        if let Some(&min_key) = self.buffer.keys().next() {
            self.next_seq = min_key;
        }
    }

    /// Yields the next in-order item if one is ready: first any chunk left
    /// over from the current sequence, then the buffered result for
    /// `next_seq`.
    ///
    /// Fix: previously, a buffered `Ok` result with an empty chunk list
    /// consumed its sequence slot but returned `None`, sending `poll_next`
    /// back to the channel even when the *following* sequence was already
    /// buffered — delaying ready audio until the next channel event. We now
    /// drain consecutive buffered sequences in a loop, skipping empty
    /// results, so every ready in-order item is emitted immediately.
    fn try_drain_buffered(&mut self) -> Option<TTSResult<AudioChunk>> {
        if let Some(chunk) = self.pending_chunks.pop() {
            return Some(Ok(chunk));
        }
        while let Some(result) = self.buffer.remove(&self.next_seq) {
            self.next_seq += 1;
            match result {
                Ok(mut chunks) => {
                    if chunks.is_empty() {
                        // Nothing to emit for this sequence; keep draining.
                        continue;
                    }
                    // Reverse so pop() returns chunks in original order.
                    chunks.reverse();
                    self.pending_chunks = chunks;
                    return self.pending_chunks.pop().map(Ok);
                }
                Err(e) => return Some(Err(e)),
            }
        }
        None
    }
}
impl Stream for OrderedAudioStream {
    type Item = TTSResult<AudioChunk>;

    // Drives reordering. The order of the checks in the loop matters:
    // drain buffered output first, then decide termination, then skip gaps,
    // and only then block on the channel.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        // After termination, keep reporting end-of-stream.
        if this.done {
            return Poll::Ready(None);
        }
        loop {
            // 1. Emit anything already available in sequence order.
            if let Some(item) = this.try_drain_buffered() {
                return Poll::Ready(Some(item));
            }
            // 2. Producer gone and nothing buffered anywhere: terminate.
            if this.channel_closed && this.buffer.is_empty() && this.pending_chunks.is_empty() {
                this.done = true;
                return Poll::Ready(None);
            }
            // 3. Producer gone but the next sequence never arrived (a TTS
            //    task died without sending): skip ahead instead of hanging.
            if this.is_seq_missing() {
                this.skip_to_next_available();
                continue;
            }
            // 4. Pull the next (possibly out-of-order) result from workers.
            match this.result_rx.poll_recv(cx) {
                Poll::Ready(Some((seq_idx, result))) => {
                    // NOTE(review): max_seq_seen is updated here but never
                    // read elsewhere in this file — candidate for removal.
                    this.max_seq_seen = Some(this.max_seq_seen.map_or(seq_idx, |m| m.max(seq_idx)));
                    this.buffer.insert(seq_idx, result);
                }
                Poll::Ready(None) => {
                    this.channel_closed = true;
                }
                // poll_recv registered the waker; we will be polled again.
                Poll::Pending => {
                    return Poll::Pending;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::TTSError;
    use crate::types::{AudioData, AudioFormat, SpeechResponse, VoiceIdentifier};
    use async_trait::async_trait;
    use futures::StreamExt;
    use std::sync::Mutex;

    /// Base request shared by all tests; `text` is filled in per sentence.
    fn test_request() -> SpeechRequest {
        SpeechRequest {
            text: String::default(),
            voice: VoiceIdentifier::new("test"),
            format: AudioFormat::Wav,
            sample_rate: Some(24000),
        }
    }

    /// Chunker config used across tests: flush on the first sentence boundary.
    /// (Previously this literal was repeated in every test.)
    fn small_chunk_config() -> ChunkerConfig {
        ChunkerConfig {
            min_chunk_chars: 1,
            max_chunk_chars: 250,
        }
    }

    /// Provider that sleeps briefly and returns one sample per input char.
    struct MockTtsProvider;

    #[async_trait]
    impl TTSSpeechProvider for MockTtsProvider {
        async fn generate_speech(&self, request: SpeechRequest) -> TTSResult<SpeechResponse> {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            let len = request.text.len();
            Ok(SpeechResponse {
                audio: AudioData {
                    samples: vec![0.5_f32; len],
                    channels: 1,
                    sample_rate: 24000,
                },
                text: request.text,
                duration_ms: len as u64,
            })
        }
    }

    /// Provider that records every synthesized text so tests can assert on
    /// call counts. (Previously duplicated verbatim inside two tests.)
    struct RecordingTts {
        calls: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl TTSSpeechProvider for RecordingTts {
        async fn generate_speech(&self, request: SpeechRequest) -> TTSResult<SpeechResponse> {
            self.calls.lock().unwrap().push(request.text.clone());
            Ok(SpeechResponse {
                audio: AudioData {
                    samples: vec![0.1; 10],
                    channels: 1,
                    sample_rate: 24000,
                },
                text: request.text,
                duration_ms: 10,
            })
        }
    }

    #[tokio::test]
    async fn test_pipeline_single_sentence() {
        let tts = Arc::new(MockTtsProvider);
        let pipeline = StreamingTtsPipeline::with_config(tts, small_chunk_config());
        let tokens = futures::stream::iter(vec!["Hello world.".to_string()]);
        let mut stream = pipeline.run(tokens, test_request());
        let mut chunks = Vec::new();
        while let Some(result) = stream.next().await {
            chunks.push(result.unwrap());
        }
        assert!(!chunks.is_empty());
    }

    #[tokio::test]
    async fn test_pipeline_multiple_sentences() {
        let tts = Arc::new(MockTtsProvider);
        let pipeline = StreamingTtsPipeline::with_config(tts, small_chunk_config());
        let tokens = futures::stream::iter(
            "Hello world. How are you today?"
                .split_inclusive(' ')
                .map(|s| s.to_string())
                .collect::<Vec<_>>(),
        );
        let mut stream = pipeline.run(tokens, test_request());
        let mut chunks = Vec::new();
        while let Some(result) = stream.next().await {
            chunks.push(result.unwrap());
        }
        assert!(
            chunks.len() >= 2,
            "Expected >= 2 chunks, got {}",
            chunks.len()
        );
    }

    #[tokio::test]
    async fn test_pipeline_ordered_output() {
        // Short sentences synthesize SLOWER than long ones, so out-of-order
        // completion is guaranteed; the stream must still emit in order.
        struct SlowShortTts;
        #[async_trait]
        impl TTSSpeechProvider for SlowShortTts {
            async fn generate_speech(&self, request: SpeechRequest) -> TTSResult<SpeechResponse> {
                let delay_ms = if request.text.len() < 20 { 50 } else { 5 };
                tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
                // Encode the text length into the samples as an order marker.
                let marker = request.text.len() as f32;
                Ok(SpeechResponse {
                    audio: AudioData {
                        samples: vec![marker; 10],
                        channels: 1,
                        sample_rate: 24000,
                    },
                    text: request.text,
                    duration_ms: 10,
                })
            }
        }
        let tts = Arc::new(SlowShortTts);
        let pipeline = StreamingTtsPipeline::with_config(tts, small_chunk_config());
        let tokens = futures::stream::iter(vec![
            "Hi! ".to_string(),
            "This is a much longer second sentence for testing. ".to_string(),
        ]);
        let mut stream = pipeline.run(tokens, test_request());
        let mut sample_markers = Vec::new();
        while let Some(result) = stream.next().await {
            let chunk = result.unwrap();
            if !chunk.samples.is_empty() {
                sample_markers.push(chunk.samples[0]);
            }
        }
        assert!(sample_markers.len() >= 2);
        assert!(
            sample_markers[0] < sample_markers[1],
            "Chunks should be in original sentence order: {:?}",
            sample_markers
        );
    }

    #[tokio::test]
    async fn test_pipeline_empty_stream() {
        let tts = Arc::new(MockTtsProvider);
        let pipeline = StreamingTtsPipeline::new(tts);
        let tokens = futures::stream::empty::<String>();
        let mut stream = pipeline.run(tokens, test_request());
        let mut count = 0;
        while let Some(_result) = stream.next().await {
            count += 1;
        }
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn test_pipeline_tts_error_propagation() {
        struct FailingTts;
        #[async_trait]
        impl TTSSpeechProvider for FailingTts {
            async fn generate_speech(&self, _request: SpeechRequest) -> TTSResult<SpeechResponse> {
                Err(TTSError::Other(
                    "synthesis failed".to_string(),
                    "test".to_string(),
                ))
            }
        }
        let tts = Arc::new(FailingTts);
        let pipeline = StreamingTtsPipeline::with_config(tts, small_chunk_config());
        let tokens = futures::stream::iter(vec!["Hello world.".to_string()]);
        let mut stream = pipeline.run(tokens, test_request());
        let result = stream.next().await;
        assert!(result.is_some());
        assert!(result.unwrap().is_err());
    }

    #[tokio::test]
    async fn test_pipeline_no_duplicate_when_tokens_are_clean() {
        let calls = Arc::new(Mutex::new(Vec::new()));
        let tts = Arc::new(RecordingTts {
            calls: Arc::clone(&calls),
        });
        let pipeline = StreamingTtsPipeline::with_config(tts, small_chunk_config());
        let tokens = futures::stream::iter(vec![
            "Hello ".to_string(),
            "there. ".to_string(),
            "How ".to_string(),
            "are ".to_string(),
            "you?".to_string(),
        ]);
        let mut stream = pipeline.run(tokens, test_request());
        while let Some(result) = stream.next().await {
            result.unwrap();
        }
        let synthesized = calls.lock().unwrap().clone();
        assert_eq!(
            synthesized.len(),
            2,
            "Expected exactly 2 TTS calls (one per sentence), got {}: {:?}",
            synthesized.len(),
            synthesized
        );
        assert_eq!(synthesized[0].trim(), "Hello there.");
        assert_eq!(synthesized[1].trim(), "How are you?");
    }

    #[tokio::test]
    async fn test_pipeline_duplicate_when_final_included() {
        let calls = Arc::new(Mutex::new(Vec::new()));
        let tts = Arc::new(RecordingTts {
            calls: Arc::clone(&calls),
        });
        let pipeline = StreamingTtsPipeline::with_config(tts, small_chunk_config());
        let tokens = futures::stream::iter(vec![
            "Hello ".to_string(),
            "there. ".to_string(),
            "How ".to_string(),
            "are ".to_string(),
            "you?".to_string(),
            "Hello there. How are you?".to_string(),
        ]);
        let mut stream = pipeline.run(tokens, test_request());
        while let Some(result) = stream.next().await {
            result.unwrap();
        }
        let synthesized = calls.lock().unwrap().clone();
        assert!(
            synthesized.len() > 2,
            "With duplicate included, expected >2 TTS calls, got {}: {:?}",
            synthesized.len(),
            synthesized
        );
    }

    #[tokio::test]
    async fn test_audio_chunk_carries_sample_rate() {
        let tts = Arc::new(MockTtsProvider);
        let pipeline = StreamingTtsPipeline::with_config(tts, small_chunk_config());
        let tokens = futures::stream::iter(vec!["Hello world.".to_string()]);
        let mut stream = pipeline.run(tokens, test_request());
        if let Some(Ok(chunk)) = stream.next().await {
            assert_eq!(
                chunk.sample_rate, 24000,
                "AudioChunk should carry sample_rate from TTS response"
            );
        } else {
            panic!("Expected at least one audio chunk");
        }
    }

    #[tokio::test]
    async fn test_missing_sequence_does_not_hang() {
        use tokio::time::{Duration, timeout};
        // Sequence 0 never arrives; the stream must skip to sequence 1.
        let (tx, rx) = mpsc::channel::<(usize, TTSResult<Vec<AudioChunk>>)>(32);
        let mut stream = OrderedAudioStream::new(rx);
        let chunk = AudioChunk {
            samples: vec![0.5; 10],
            sample_rate: 24000,
            is_final: false,
        };
        tx.send((1, Ok(vec![chunk.clone()]))).await.unwrap();
        drop(tx);
        let result = timeout(Duration::from_secs(1), stream.next()).await;
        assert!(
            result.is_ok(),
            "Stream should not hang when sequence is missing"
        );
        let item = result.unwrap();
        assert!(item.is_some(), "Should yield buffered sequence 1");
        assert!(item.unwrap().is_ok());
        let result = timeout(Duration::from_secs(1), stream.next()).await;
        assert!(result.is_ok(), "Stream should terminate cleanly");
        assert!(result.unwrap().is_none(), "Stream should be done");
    }

    #[tokio::test]
    async fn test_multiple_missing_sequences() {
        use tokio::time::{Duration, timeout};
        // Sequences 0..=2 never arrive; the stream must jump to sequence 3.
        let (tx, rx) = mpsc::channel::<(usize, TTSResult<Vec<AudioChunk>>)>(32);
        let mut stream = OrderedAudioStream::new(rx);
        let chunk = AudioChunk {
            samples: vec![0.5; 10],
            sample_rate: 24000,
            is_final: false,
        };
        tx.send((3, Ok(vec![chunk]))).await.unwrap();
        drop(tx);
        let result = timeout(Duration::from_secs(1), stream.next()).await;
        assert!(result.is_ok(), "Should not hang with multiple missing seqs");
        assert!(result.unwrap().is_some());
        let result = timeout(Duration::from_secs(1), stream.next()).await;
        assert!(result.is_ok());
        assert!(result.unwrap().is_none());
    }
}