// adk_audio/pipeline/builder.rs
1//! Audio pipeline builder for composing processing topologies.
2
3use std::sync::Arc;
4
5use tokio::sync::{RwLock, mpsc, oneshot};
6
7use crate::error::{AudioError, AudioResult};
8use crate::pipeline::handle::PipelineHandle;
9use crate::pipeline::types::{PipelineInput, PipelineMetrics, PipelineOutput};
10use crate::pipeline::voice_agent::{validate_voice_agent_config, voice_agent_loop};
11use crate::traits::{
12    AudioProcessor, FxChain, MusicProvider, SttProvider, TtsProvider, TtsRequest, VadProcessor,
13};
14
/// Builder for constructing audio pipelines.
///
/// Collects optional providers/processors and assembles them into a running
/// pipeline via the `build_*` methods, each of which spawns the processing
/// task and returns a [`PipelineHandle`].
///
/// # Example
///
/// ```ignore
/// let handle = AudioPipelineBuilder::new()
///     .tts(my_tts)
///     .stt(my_stt)
///     .vad(my_vad)
///     .agent(my_agent)
///     .build_voice_agent()?;
/// ```
pub struct AudioPipelineBuilder {
    /// Text-to-speech provider (required by `build_tts` and `build_voice_agent`).
    tts: Option<Arc<dyn TtsProvider>>,
    /// Speech-to-text provider (required by `build_stt` and `build_voice_agent`).
    stt: Option<Arc<dyn SttProvider>>,
    /// Music generation provider (required by `build_music`).
    music: Option<Arc<dyn MusicProvider>>,
    /// Voice activity detector (required by `build_voice_agent`).
    vad: Option<Arc<dyn VadProcessor>>,
    /// FX chain applied before STT/VAD; also drives `build_transform`.
    pre_fx: Option<FxChain>,
    /// FX chain applied after TTS.
    post_fx: Option<FxChain>,
    /// Agent driving the voice-agent pipeline (required by `build_voice_agent`).
    agent: Option<Arc<dyn adk_core::Agent>>,
    /// Capacity of the input/output mpsc channels (default: 32).
    buffer_size: usize,
    /// Desktop audio capture source (microphone).
    #[cfg(feature = "desktop-audio")]
    capture: Option<crate::desktop::capture::AudioCapture>,
    /// Desktop audio playback sink (speaker).
    #[cfg(feature = "desktop-audio")]
    playback: Option<crate::desktop::playback::AudioPlayback>,
}
43
44impl AudioPipelineBuilder {
45    /// Create a new builder with default settings.
46    pub fn new() -> Self {
47        Self {
48            tts: None,
49            stt: None,
50            music: None,
51            vad: None,
52            pre_fx: None,
53            post_fx: None,
54            agent: None,
55            buffer_size: 32,
56            #[cfg(feature = "desktop-audio")]
57            capture: None,
58            #[cfg(feature = "desktop-audio")]
59            playback: None,
60        }
61    }
62
63    /// Set the TTS provider.
64    pub fn tts(mut self, tts: Arc<dyn TtsProvider>) -> Self {
65        self.tts = Some(tts);
66        self
67    }
68
69    /// Set the STT provider.
70    pub fn stt(mut self, stt: Arc<dyn SttProvider>) -> Self {
71        self.stt = Some(stt);
72        self
73    }
74
75    /// Set the music generation provider.
76    pub fn music(mut self, music: Arc<dyn MusicProvider>) -> Self {
77        self.music = Some(music);
78        self
79    }
80
81    /// Set the VAD processor.
82    pub fn vad(mut self, vad: Arc<dyn VadProcessor>) -> Self {
83        self.vad = Some(vad);
84        self
85    }
86
87    /// Set the pre-processing FX chain (applied before STT/VAD).
88    pub fn pre_fx(mut self, fx: FxChain) -> Self {
89        self.pre_fx = Some(fx);
90        self
91    }
92
93    /// Set the post-processing FX chain (applied after TTS).
94    pub fn post_fx(mut self, fx: FxChain) -> Self {
95        self.post_fx = Some(fx);
96        self
97    }
98
99    /// Set the agent for voice agent pipelines.
100    pub fn agent(mut self, agent: Arc<dyn adk_core::Agent>) -> Self {
101        self.agent = Some(agent);
102        self
103    }
104
105    /// Set the channel buffer size (default: 32).
106    pub fn buffer_size(mut self, size: usize) -> Self {
107        self.buffer_size = size;
108        self
109    }
110
111    /// Set the audio capture source for desktop pipelines.
112    ///
113    /// When both `capture` and `playback` are configured, `build_voice_agent()`
114    /// will store them for the caller to wire into the pipeline's input/output
115    /// channels.
116    ///
117    /// Only available when the `desktop-audio` feature is enabled.
118    #[cfg(feature = "desktop-audio")]
119    pub fn capture(mut self, capture: crate::desktop::capture::AudioCapture) -> Self {
120        self.capture = Some(capture);
121        self
122    }
123
124    /// Set the audio playback sink for desktop pipelines.
125    ///
126    /// When both `capture` and `playback` are configured, `build_voice_agent()`
127    /// will store them for the caller to wire into the pipeline's input/output
128    /// channels.
129    ///
130    /// Only available when the `desktop-audio` feature is enabled.
131    #[cfg(feature = "desktop-audio")]
132    pub fn playback(mut self, playback: crate::desktop::playback::AudioPlayback) -> Self {
133        self.playback = Some(playback);
134        self
135    }
136
137    /// Build a TTS-only pipeline (Text → TTS → Audio).
138    pub fn build_tts(self) -> AudioResult<PipelineHandle> {
139        let tts = self.tts.ok_or_else(|| {
140            AudioError::PipelineClosed("TTS pipeline requires a TtsProvider".into())
141        })?;
142
143        let (input_tx, mut input_rx) = mpsc::channel::<PipelineInput>(self.buffer_size);
144        let (output_tx, output_rx) = mpsc::channel::<PipelineOutput>(self.buffer_size);
145        let metrics = Arc::new(RwLock::new(PipelineMetrics::default()));
146        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
147
148        let m = metrics.clone();
149        tokio::spawn(async move {
150            loop {
151                tokio::select! {
152                    _ = &mut shutdown_rx => break,
153                    input = input_rx.recv() => {
154                        let Some(PipelineInput::Text(text)) = input else {
155                            if input.is_none() { break; }
156                            continue;
157                        };
158                        let request = TtsRequest { text, ..Default::default() };
159                        if let Ok(frame) = tts.synthesize(&request).await {
160                            let mut metrics = m.write().await;
161                            metrics.total_audio_ms += frame.duration_ms as u64;
162                            let _ = output_tx.send(PipelineOutput::Audio(frame)).await;
163                        }
164                    }
165                }
166            }
167        });
168
169        Ok(PipelineHandle::new(input_tx, output_rx, metrics, shutdown_tx))
170    }
171
172    /// Build an STT-only pipeline (Audio → STT → Transcript).
173    pub fn build_stt(self) -> AudioResult<PipelineHandle> {
174        let stt = self.stt.ok_or_else(|| {
175            AudioError::PipelineClosed("STT pipeline requires an SttProvider".into())
176        })?;
177
178        let (input_tx, mut input_rx) = mpsc::channel::<PipelineInput>(self.buffer_size);
179        let (output_tx, output_rx) = mpsc::channel::<PipelineOutput>(self.buffer_size);
180        let metrics = Arc::new(RwLock::new(PipelineMetrics::default()));
181        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
182
183        let m = metrics.clone();
184        tokio::spawn(async move {
185            loop {
186                tokio::select! {
187                    _ = &mut shutdown_rx => break,
188                    input = input_rx.recv() => {
189                        let Some(PipelineInput::Audio(frame)) = input else {
190                            if input.is_none() { break; }
191                            continue;
192                        };
193                        let opts = crate::traits::SttOptions::default();
194                        if let Ok(transcript) = stt.transcribe(&frame, &opts).await {
195                            let mut metrics = m.write().await;
196                            metrics.total_audio_ms += frame.duration_ms as u64;
197                            let _ = output_tx.send(PipelineOutput::Transcript(transcript)).await;
198                        }
199                    }
200                }
201            }
202        });
203
204        Ok(PipelineHandle::new(input_tx, output_rx, metrics, shutdown_tx))
205    }
206
207    /// Build a voice agent pipeline (Audio → VAD → STT → Agent → TTS → Audio).
208    ///
209    /// Requires `tts`, `stt`, `vad`, and `agent` to be set.
210    ///
211    /// When the `desktop-audio` feature is enabled and both `capture` and `playback`
212    /// are configured, the caller should use the returned [`PipelineHandle`] to wire
213    /// the capture stream into `input_tx` and route `output_rx` audio frames to
214    /// playback. Starting capture requires a device ID and [`CaptureConfig`](crate::desktop::CaptureConfig),
215    /// and playback requires a device ID, so the builder stores the instances and
216    /// the caller completes the wiring at runtime.
217    pub fn build_voice_agent(self) -> AudioResult<PipelineHandle> {
218        validate_voice_agent_config(
219            self.tts.is_some(),
220            self.stt.is_some(),
221            self.vad.is_some(),
222            self.agent.is_some(),
223        )?;
224
225        let tts = self.tts.unwrap();
226        let stt = self.stt.unwrap();
227        let vad = self.vad.unwrap();
228        let agent = self.agent.unwrap();
229
230        let (input_tx, input_rx) = mpsc::channel::<PipelineInput>(self.buffer_size);
231        let (output_tx, output_rx) = mpsc::channel::<PipelineOutput>(self.buffer_size);
232        let metrics = Arc::new(RwLock::new(PipelineMetrics::default()));
233        let (shutdown_tx, shutdown_rx) = oneshot::channel();
234
235        let m = metrics.clone();
236        tokio::spawn(voice_agent_loop(
237            input_rx,
238            output_tx,
239            stt,
240            tts,
241            vad,
242            agent,
243            self.pre_fx,
244            self.post_fx,
245            m,
246            shutdown_rx,
247        ));
248
249        // When desktop-audio is enabled and both capture and playback are
250        // configured, wire them into the pipeline's input/output channels.
251        // Starting capture requires a device_id and CaptureConfig, and
252        // playback requires a device_id — these are runtime parameters.
253        // The caller should:
254        //   1. Call `capture.start_capture(device_id, &config)` to get an AudioStream
255        //   2. Spawn a task that reads from the AudioStream and sends
256        //      `PipelineInput::Audio(frame)` into `handle.input_tx`
257        //   3. Spawn a task that reads `PipelineOutput::Audio(frame)` from
258        //      `handle.output_rx` and calls `playback.play(device_id, &frame)`
259        #[cfg(feature = "desktop-audio")]
260        if self.capture.is_some() && self.playback.is_some() {
261            tracing::info!(
262                "desktop audio capture and playback configured — caller must wire \
263                 capture stream to input_tx and output_rx to playback using device \
264                 IDs at runtime"
265            );
266        }
267
268        Ok(PipelineHandle::new(input_tx, output_rx, metrics, shutdown_tx))
269    }
270
271    /// Build a transform-only pipeline (Audio → FxChain → Audio).
272    pub fn build_transform(self) -> AudioResult<PipelineHandle> {
273        let pre_fx = self.pre_fx.unwrap_or_default();
274
275        let (input_tx, mut input_rx) = mpsc::channel::<PipelineInput>(self.buffer_size);
276        let (output_tx, output_rx) = mpsc::channel::<PipelineOutput>(self.buffer_size);
277        let metrics = Arc::new(RwLock::new(PipelineMetrics::default()));
278        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
279
280        let m = metrics.clone();
281        tokio::spawn(async move {
282            loop {
283                tokio::select! {
284                    _ = &mut shutdown_rx => break,
285                    input = input_rx.recv() => {
286                        let Some(PipelineInput::Audio(frame)) = input else {
287                            if input.is_none() { break; }
288                            continue;
289                        };
290                        if let Ok(processed) = pre_fx.process(&frame).await {
291                            let mut metrics = m.write().await;
292                            metrics.total_audio_ms += processed.duration_ms as u64;
293                            let _ = output_tx.send(PipelineOutput::Audio(processed)).await;
294                        }
295                    }
296                }
297            }
298        });
299
300        Ok(PipelineHandle::new(input_tx, output_rx, metrics, shutdown_tx))
301    }
302
303    /// Build a music generation pipeline (Text → MusicProvider → Audio).
304    pub fn build_music(self) -> AudioResult<PipelineHandle> {
305        let music = self.music.ok_or_else(|| {
306            AudioError::PipelineClosed("Music pipeline requires a MusicProvider".into())
307        })?;
308
309        let (input_tx, mut input_rx) = mpsc::channel::<PipelineInput>(self.buffer_size);
310        let (output_tx, output_rx) = mpsc::channel::<PipelineOutput>(self.buffer_size);
311        let metrics = Arc::new(RwLock::new(PipelineMetrics::default()));
312        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
313
314        let m = metrics.clone();
315        tokio::spawn(async move {
316            loop {
317                tokio::select! {
318                    _ = &mut shutdown_rx => break,
319                    input = input_rx.recv() => {
320                        let Some(PipelineInput::Text(prompt)) = input else {
321                            if input.is_none() { break; }
322                            continue;
323                        };
324                        let request = crate::traits::MusicRequest {
325                            prompt,
326                            ..Default::default()
327                        };
328                        if let Ok(frame) = music.generate(&request).await {
329                            let mut metrics = m.write().await;
330                            metrics.total_audio_ms += frame.duration_ms as u64;
331                            let _ = output_tx.send(PipelineOutput::Audio(frame)).await;
332                        }
333                    }
334                }
335            }
336        });
337
338        Ok(PipelineHandle::new(input_tx, output_rx, metrics, shutdown_tx))
339    }
340}
341
342impl Default for AudioPipelineBuilder {
343    fn default() -> Self {
344        Self::new()
345    }
346}