use std::sync::Arc;
use tokio::sync::{RwLock, mpsc, oneshot};
use crate::error::{AudioError, AudioResult};
use crate::pipeline::handle::PipelineHandle;
use crate::pipeline::types::{PipelineInput, PipelineMetrics, PipelineOutput};
use crate::pipeline::voice_agent::{validate_voice_agent_config, voice_agent_loop};
use crate::traits::{
AudioProcessor, FxChain, MusicProvider, SttProvider, TtsProvider, TtsRequest, VadProcessor,
};
/// Collects the providers, effect chains and channel sizing used to assemble an
/// audio pipeline; one of the `build_*` methods then spawns the worker task and
/// returns a [`PipelineHandle`].
pub struct AudioPipelineBuilder {
    // --- providers (shared via `Arc` so the spawned worker can own a handle) ---
    /// Text-to-speech backend; required by `build_tts` and `build_voice_agent`.
    tts: Option<Arc<dyn TtsProvider>>,
    /// Speech-to-text backend; required by `build_stt` and `build_voice_agent`.
    stt: Option<Arc<dyn SttProvider>>,
    /// Voice-activity detector; required by `build_voice_agent`.
    vad: Option<Arc<dyn VadProcessor>>,
    /// Conversational agent; required by `build_voice_agent`.
    agent: Option<Arc<dyn adk_core::Agent>>,
    /// Music generation backend; required by `build_music`.
    music: Option<Arc<dyn MusicProvider>>,

    // --- optional effect chains ---
    /// Effects applied to incoming audio (`build_transform`, `build_voice_agent`).
    pre_fx: Option<FxChain>,
    /// Effects passed to the voice-agent loop (`build_voice_agent`).
    post_fx: Option<FxChain>,

    // --- sizing ---
    /// Capacity of both the input and output channels of a built pipeline.
    buffer_size: usize,
}
impl AudioPipelineBuilder {
    /// Creates a builder with no providers or effect chains configured and a
    /// channel capacity of 32.
    pub fn new() -> Self {
        Self {
            tts: None,
            stt: None,
            music: None,
            vad: None,
            pre_fx: None,
            post_fx: None,
            agent: None,
            buffer_size: 32,
        }
    }

    /// Sets the text-to-speech provider used by [`Self::build_tts`] and
    /// [`Self::build_voice_agent`].
    pub fn tts(mut self, tts: Arc<dyn TtsProvider>) -> Self {
        self.tts = Some(tts);
        self
    }

    /// Sets the speech-to-text provider used by [`Self::build_stt`] and
    /// [`Self::build_voice_agent`].
    pub fn stt(mut self, stt: Arc<dyn SttProvider>) -> Self {
        self.stt = Some(stt);
        self
    }

    /// Sets the music provider used by [`Self::build_music`].
    pub fn music(mut self, music: Arc<dyn MusicProvider>) -> Self {
        self.music = Some(music);
        self
    }

    /// Sets the voice-activity detector used by [`Self::build_voice_agent`].
    pub fn vad(mut self, vad: Arc<dyn VadProcessor>) -> Self {
        self.vad = Some(vad);
        self
    }

    /// Sets the effect chain applied by [`Self::build_transform`] and passed to
    /// the voice-agent loop by [`Self::build_voice_agent`].
    pub fn pre_fx(mut self, fx: FxChain) -> Self {
        self.pre_fx = Some(fx);
        self
    }

    /// Sets the effect chain passed to the voice-agent loop by
    /// [`Self::build_voice_agent`]. Not used by any other builder.
    pub fn post_fx(mut self, fx: FxChain) -> Self {
        self.post_fx = Some(fx);
        self
    }

    /// Sets the agent used by [`Self::build_voice_agent`].
    pub fn agent(mut self, agent: Arc<dyn adk_core::Agent>) -> Self {
        self.agent = Some(agent);
        self
    }

    /// Sets the capacity of the input and output channels of built pipelines.
    ///
    /// Tokio's bounded `mpsc::channel` panics when given a capacity of zero, so a
    /// `size` of `0` is raised to `1` instead of deferring a panic to `build_*`.
    pub fn buffer_size(mut self, size: usize) -> Self {
        self.buffer_size = size.max(1);
        self
    }

    /// Builds a text-to-speech pipeline.
    ///
    /// Every [`PipelineInput::Text`] is synthesized with the configured
    /// [`TtsProvider`] and emitted as [`PipelineOutput::Audio`]; the frame's
    /// duration is added to `total_audio_ms`. Other input variants are ignored and
    /// inputs whose synthesis fails produce no output. The worker stops on
    /// shutdown, when the input channel closes, or when the output receiver is
    /// dropped.
    ///
    /// # Errors
    ///
    /// Returns [`AudioError::PipelineClosed`] if no TTS provider was set.
    pub fn build_tts(self) -> AudioResult<PipelineHandle> {
        let tts = self.tts.ok_or_else(|| {
            AudioError::PipelineClosed("TTS pipeline requires a TtsProvider".into())
        })?;
        let capacity = self.buffer_size.max(1);
        let (input_tx, mut input_rx) = mpsc::channel::<PipelineInput>(capacity);
        let (output_tx, output_rx) = mpsc::channel::<PipelineOutput>(capacity);
        let metrics = Arc::new(RwLock::new(PipelineMetrics::default()));
        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
        let m = Arc::clone(&metrics);
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = &mut shutdown_rx => break,
                    input = input_rx.recv() => {
                        let text = match input {
                            Some(PipelineInput::Text(text)) => text,
                            Some(_) => continue,
                            None => break,
                        };
                        let request = TtsRequest { text, ..Default::default() };
                        if let Ok(frame) = tts.synthesize(&request).await {
                            let duration_ms = frame.duration_ms as u64;
                            let output = PipelineOutput::Audio(frame);
                            let forwarded =
                                Self::record_and_forward(&m, &output_tx, duration_ms, output).await;
                            if !forwarded {
                                // Consumer is gone; nothing left to deliver to.
                                break;
                            }
                        }
                    }
                }
            }
        });
        Ok(PipelineHandle::new(input_tx, output_rx, metrics, shutdown_tx))
    }

    /// Builds a speech-to-text pipeline.
    ///
    /// Every [`PipelineInput::Audio`] frame is transcribed with the configured
    /// [`SttProvider`] (default options) and emitted as
    /// [`PipelineOutput::Transcript`]; the *input* frame's duration is added to
    /// `total_audio_ms`. Other input variants are ignored and failed
    /// transcriptions produce no output. The worker stops on shutdown, when the
    /// input channel closes, or when the output receiver is dropped.
    ///
    /// # Errors
    ///
    /// Returns [`AudioError::PipelineClosed`] if no STT provider was set.
    pub fn build_stt(self) -> AudioResult<PipelineHandle> {
        let stt = self.stt.ok_or_else(|| {
            AudioError::PipelineClosed("STT pipeline requires an SttProvider".into())
        })?;
        let capacity = self.buffer_size.max(1);
        let (input_tx, mut input_rx) = mpsc::channel::<PipelineInput>(capacity);
        let (output_tx, output_rx) = mpsc::channel::<PipelineOutput>(capacity);
        let metrics = Arc::new(RwLock::new(PipelineMetrics::default()));
        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
        let m = Arc::clone(&metrics);
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = &mut shutdown_rx => break,
                    input = input_rx.recv() => {
                        let frame = match input {
                            Some(PipelineInput::Audio(frame)) => frame,
                            Some(_) => continue,
                            None => break,
                        };
                        let opts = crate::traits::SttOptions::default();
                        if let Ok(transcript) = stt.transcribe(&frame, &opts).await {
                            let duration_ms = frame.duration_ms as u64;
                            let output = PipelineOutput::Transcript(transcript);
                            let forwarded =
                                Self::record_and_forward(&m, &output_tx, duration_ms, output).await;
                            if !forwarded {
                                break;
                            }
                        }
                    }
                }
            }
        });
        Ok(PipelineHandle::new(input_tx, output_rx, metrics, shutdown_tx))
    }

    /// Builds a full voice-agent pipeline driven by `voice_agent_loop`.
    ///
    /// The configured STT, TTS, VAD, agent and both optional effect chains are
    /// handed to the spawned loop together with the channels and shared metrics.
    ///
    /// # Errors
    ///
    /// Returns whatever error `validate_voice_agent_config` produces (presumably
    /// when TTS, STT, VAD or the agent is unset — TODO confirm). If validation
    /// passes but a component is still missing, returns
    /// [`AudioError::PipelineClosed`] instead of panicking.
    pub fn build_voice_agent(self) -> AudioResult<PipelineHandle> {
        validate_voice_agent_config(
            self.tts.is_some(),
            self.stt.is_some(),
            self.vad.is_some(),
            self.agent.is_some(),
        )?;
        let (Some(tts), Some(stt), Some(vad), Some(agent)) =
            (self.tts, self.stt, self.vad, self.agent)
        else {
            return Err(AudioError::PipelineClosed(
                "voice agent pipeline requires TTS, STT, VAD and an agent".into(),
            ));
        };
        let capacity = self.buffer_size.max(1);
        let (input_tx, input_rx) = mpsc::channel::<PipelineInput>(capacity);
        let (output_tx, output_rx) = mpsc::channel::<PipelineOutput>(capacity);
        let metrics = Arc::new(RwLock::new(PipelineMetrics::default()));
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        tokio::spawn(voice_agent_loop(
            input_rx,
            output_tx,
            stt,
            tts,
            vad,
            agent,
            self.pre_fx,
            self.post_fx,
            Arc::clone(&metrics),
            shutdown_rx,
        ));
        Ok(PipelineHandle::new(input_tx, output_rx, metrics, shutdown_tx))
    }

    /// Builds an effects-only pipeline.
    ///
    /// Every [`PipelineInput::Audio`] frame is run through the `pre_fx` chain (a
    /// default chain if none was set) and emitted as [`PipelineOutput::Audio`];
    /// the processed frame's duration is added to `total_audio_ms`. Frames whose
    /// processing fails produce no output.
    ///
    /// NOTE(review): `post_fx` is not applied here — confirm that is intended.
    pub fn build_transform(self) -> AudioResult<PipelineHandle> {
        let pre_fx = self.pre_fx.unwrap_or_default();
        let capacity = self.buffer_size.max(1);
        let (input_tx, mut input_rx) = mpsc::channel::<PipelineInput>(capacity);
        let (output_tx, output_rx) = mpsc::channel::<PipelineOutput>(capacity);
        let metrics = Arc::new(RwLock::new(PipelineMetrics::default()));
        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
        let m = Arc::clone(&metrics);
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = &mut shutdown_rx => break,
                    input = input_rx.recv() => {
                        let frame = match input {
                            Some(PipelineInput::Audio(frame)) => frame,
                            Some(_) => continue,
                            None => break,
                        };
                        if let Ok(processed) = pre_fx.process(&frame).await {
                            let duration_ms = processed.duration_ms as u64;
                            let output = PipelineOutput::Audio(processed);
                            let forwarded =
                                Self::record_and_forward(&m, &output_tx, duration_ms, output).await;
                            if !forwarded {
                                break;
                            }
                        }
                    }
                }
            }
        });
        Ok(PipelineHandle::new(input_tx, output_rx, metrics, shutdown_tx))
    }

    /// Builds a music-generation pipeline.
    ///
    /// Every [`PipelineInput::Text`] is used as the prompt of a default
    /// `MusicRequest`; the generated frame is emitted as [`PipelineOutput::Audio`]
    /// and its duration added to `total_audio_ms`. Other input variants are
    /// ignored and failed generations produce no output.
    ///
    /// # Errors
    ///
    /// Returns [`AudioError::PipelineClosed`] if no music provider was set.
    pub fn build_music(self) -> AudioResult<PipelineHandle> {
        let music = self.music.ok_or_else(|| {
            AudioError::PipelineClosed("Music pipeline requires a MusicProvider".into())
        })?;
        let capacity = self.buffer_size.max(1);
        let (input_tx, mut input_rx) = mpsc::channel::<PipelineInput>(capacity);
        let (output_tx, output_rx) = mpsc::channel::<PipelineOutput>(capacity);
        let metrics = Arc::new(RwLock::new(PipelineMetrics::default()));
        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
        let m = Arc::clone(&metrics);
        tokio::spawn(async move {
            loop {
                tokio::select! {
                    _ = &mut shutdown_rx => break,
                    input = input_rx.recv() => {
                        let prompt = match input {
                            Some(PipelineInput::Text(prompt)) => prompt,
                            Some(_) => continue,
                            None => break,
                        };
                        let request = crate::traits::MusicRequest {
                            prompt,
                            ..Default::default()
                        };
                        if let Ok(frame) = music.generate(&request).await {
                            let duration_ms = frame.duration_ms as u64;
                            let output = PipelineOutput::Audio(frame);
                            let forwarded =
                                Self::record_and_forward(&m, &output_tx, duration_ms, output).await;
                            if !forwarded {
                                break;
                            }
                        }
                    }
                }
            }
        });
        Ok(PipelineHandle::new(input_tx, output_rx, metrics, shutdown_tx))
    }

    /// Adds `duration_ms` to the shared `total_audio_ms` counter, then forwards
    /// `output` to the pipeline's consumer.
    ///
    /// The write guard is a temporary dropped at the end of its statement, so the
    /// lock is released *before* `send` awaits. Holding it across a send on a full
    /// channel could stall anyone needing the lock (e.g. a reader of the handle's
    /// metrics) behind a send that may itself be waiting for that party to drain
    /// output.
    ///
    /// Returns `false` if the output receiver has been dropped.
    async fn record_and_forward(
        metrics: &RwLock<PipelineMetrics>,
        output_tx: &mpsc::Sender<PipelineOutput>,
        duration_ms: u64,
        output: PipelineOutput,
    ) -> bool {
        metrics.write().await.total_audio_ms += duration_ms;
        output_tx.send(output).await.is_ok()
    }
}
impl Default for AudioPipelineBuilder {
fn default() -> Self {
Self::new()
}
}