active_call/media/
engine.rs

1use super::{
2    INTERNAL_SAMPLERATE,
3    asr_processor::AsrProcessor,
4    denoiser::NoiseReducer,
5    processor::Processor,
6    track::{
7        Track, TrackPacketSender,
8        tts::{SynthesisHandle, TtsTrack},
9    },
10    vad::{VADOption, VadProcessor, VadType},
11};
12use crate::{
13    CallOption, EouOption,
14    event::EventSender,
15    media::TrackId,
16    synthesis::{
17        AliyunTtsClient, DeepegramTtsClient, SynthesisClient, SynthesisOption, SynthesisType,
18        TencentCloudTtsBasicClient, TencentCloudTtsClient, VoiceApiTtsClient,
19    },
20    transcription::{
21        AliyunAsrClientBuilder, TencentCloudAsrClientBuilder, TranscriptionClient,
22        TranscriptionOption, TranscriptionType, VoiceApiAsrClientBuilder,
23    },
24};
25use anyhow::Result;
26use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
27use tokio::sync::mpsc;
28use tokio_util::sync::CancellationToken;
29use tracing::debug;
30
/// Factory function pointer for VAD processors.
///
/// It receives a cancellation token, the event sender and the [`VADOption`]
/// to build from, and returns the boxed [`Processor`] or an error.
/// [`StreamEngine`] stores these keyed by [`VadType`].
pub type FnCreateVadProcessor = fn(
    token: CancellationToken,
    event_sender: EventSender,
    option: VADOption,
) -> Result<Box<dyn Processor>>;
36
/// Factory function pointer for EOU processors.
///
/// It receives a cancellation token, the event sender and the [`EouOption`]
/// to build from, and returns the boxed [`Processor`] or an error.
/// [`StreamEngine`] stores these keyed by a `String` name.
pub type FnCreateEouProcessor = fn(
    token: CancellationToken,
    event_sender: EventSender,
    option: EouOption,
) -> Result<Box<dyn Processor>>;
42
/// Boxed asynchronous factory for transcription (ASR) clients.
///
/// Called with the track id, a cancellation token, the
/// [`TranscriptionOption`] and the event sender, it returns a boxed `Send`
/// future that resolves to the [`TranscriptionClient`] or an error. The
/// factory itself must be `Send + Sync`.
pub type FnCreateAsrClient = Box<
    dyn Fn(
            TrackId,
            CancellationToken,
            TranscriptionOption,
            EventSender,
        ) -> Pin<Box<dyn Future<Output = Result<Box<dyn TranscriptionClient>>> + Send>>
        + Send
        + Sync,
>;
/// Factory function pointer for synthesis (TTS) clients.
///
/// It receives the `streaming` flag and a borrowed [`SynthesisOption`], and
/// returns the boxed [`SynthesisClient`] or an error.
pub type FnCreateTtsClient =
    fn(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>>;
55
/// Boxed asynchronous hook that builds the processors for a track.
///
/// It is called with the engine, the track, a cancellation token, the event
/// sender, the track packet sender and the [`CallOption`]. It returns a boxed
/// `Send` future that resolves to the list of processors or an error. The
/// hook itself must be `Send + Sync`.
pub type CreateProcessorsHook = Box<
    dyn Fn(
            Arc<StreamEngine>,
            &dyn Track,
            CancellationToken,
            EventSender,
            TrackPacketSender,
            CallOption,
        ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>>
        + Send
        + Sync,
>;
69
/// Registry of VAD, EOU, ASR and TTS factories, plus the hook used to build
/// the processors of a track.
pub struct StreamEngine {
    /// VAD processor factories, keyed by [`VadType`].
    vad_creators: HashMap<VadType, FnCreateVadProcessor>,
    /// EOU processor factories, keyed by name.
    eou_creators: HashMap<String, FnCreateEouProcessor>,
    /// Transcription client factories, keyed by [`TranscriptionType`].
    asr_creators: HashMap<TranscriptionType, FnCreateAsrClient>,
    /// Synthesis client factories, keyed by [`SynthesisType`].
    tts_creators: HashMap<SynthesisType, FnCreateTtsClient>,
    /// Hook invoked to build the processors of a track.
    create_processors_hook: Arc<CreateProcessorsHook>,
}
77
78impl Default for StreamEngine {
79    fn default() -> Self {
80        let mut engine = Self::new();
81        engine.register_vad(VadType::Silero, VadProcessor::create);
82        engine.register_vad(VadType::Other("nop".to_string()), VadProcessor::create_nop);
83
84        engine.register_asr(
85            TranscriptionType::TencentCloud,
86            Box::new(TencentCloudAsrClientBuilder::create),
87        );
88        engine.register_asr(
89            TranscriptionType::VoiceApi,
90            Box::new(VoiceApiAsrClientBuilder::create),
91        );
92        engine.register_asr(
93            TranscriptionType::Aliyun,
94            Box::new(AliyunAsrClientBuilder::create),
95        );
96        engine.register_tts(SynthesisType::Aliyun, AliyunTtsClient::create);
97        engine.register_tts(SynthesisType::TencentCloud, TencentCloudTtsClient::create);
98        engine.register_tts(SynthesisType::VoiceApi, VoiceApiTtsClient::create);
99        engine.register_tts(
100            SynthesisType::Other("tencent_basic".to_string()),
101            TencentCloudTtsBasicClient::create,
102        );
103        engine.register_tts(SynthesisType::Deepgram, DeepegramTtsClient::create);
104        engine
105    }
106}
107
108impl StreamEngine {
109    pub fn new() -> Self {
110        Self {
111            vad_creators: HashMap::new(),
112            asr_creators: HashMap::new(),
113            tts_creators: HashMap::new(),
114            eou_creators: HashMap::new(),
115            create_processors_hook: Arc::new(Box::new(Self::default_create_procesors_hook)),
116        }
117    }
118
119    pub fn register_vad(&mut self, vad_type: VadType, creator: FnCreateVadProcessor) -> &mut Self {
120        self.vad_creators.insert(vad_type, creator);
121        self
122    }
123
124    pub fn register_eou(&mut self, name: String, creator: FnCreateEouProcessor) -> &mut Self {
125        self.eou_creators.insert(name, creator);
126        self
127    }
128
129    pub fn register_asr(
130        &mut self,
131        asr_type: TranscriptionType,
132        creator: FnCreateAsrClient,
133    ) -> &mut Self {
134        self.asr_creators.insert(asr_type, creator);
135        self
136    }
137
138    pub fn register_tts(
139        &mut self,
140        tts_type: SynthesisType,
141        creator: FnCreateTtsClient,
142    ) -> &mut Self {
143        self.tts_creators.insert(tts_type, creator);
144        self
145    }
146
147    pub fn create_vad_processor(
148        &self,
149        token: CancellationToken,
150        event_sender: EventSender,
151        option: VADOption,
152    ) -> Result<Box<dyn Processor>> {
153        let creator = self.vad_creators.get(&option.r#type);
154        if let Some(creator) = creator {
155            creator(token, event_sender, option)
156        } else {
157            Err(anyhow::anyhow!("VAD type not found: {}", option.r#type))
158        }
159    }
160    pub fn create_eou_processor(
161        &self,
162        token: CancellationToken,
163        event_sender: EventSender,
164        option: EouOption,
165    ) -> Result<Box<dyn Processor>> {
166        let creator = self
167            .eou_creators
168            .get(&option.r#type.clone().unwrap_or_default());
169        if let Some(creator) = creator {
170            creator(token, event_sender, option)
171        } else {
172            Err(anyhow::anyhow!("EOU type not found: {:?}", option.r#type))
173        }
174    }
175
176    pub async fn create_asr_processor(
177        &self,
178        track_id: TrackId,
179        cancel_token: CancellationToken,
180        option: TranscriptionOption,
181        event_sender: EventSender,
182    ) -> Result<Box<dyn Processor>> {
183        let asr_client = match option.provider {
184            Some(ref provider) => {
185                let creator = self.asr_creators.get(&provider);
186                if let Some(creator) = creator {
187                    creator(track_id, cancel_token, option, event_sender).await?
188                } else {
189                    return Err(anyhow::anyhow!("ASR type not found: {}", provider));
190                }
191            }
192            None => return Err(anyhow::anyhow!("ASR type not found: {:?}", option.provider)),
193        };
194        Ok(Box::new(AsrProcessor { asr_client }))
195    }
196
197    pub async fn create_tts_client(
198        &self,
199        streaming: bool,
200        tts_option: &SynthesisOption,
201    ) -> Result<Box<dyn SynthesisClient>> {
202        match tts_option.provider {
203            Some(ref provider) => {
204                let creator = self.tts_creators.get(&provider);
205                if let Some(creator) = creator {
206                    creator(streaming, tts_option)
207                } else {
208                    Err(anyhow::anyhow!("TTS type not found: {}", provider))
209                }
210            }
211            None => Err(anyhow::anyhow!(
212                "TTS type not found: {:?}",
213                tts_option.provider
214            )),
215        }
216    }
217
218    pub async fn create_processors(
219        engine: Arc<StreamEngine>,
220        track: &dyn Track,
221        cancel_token: CancellationToken,
222        event_sender: EventSender,
223        packet_sender: TrackPacketSender,
224        option: &CallOption,
225    ) -> Result<Vec<Box<dyn Processor>>> {
226        (engine.clone().create_processors_hook)(
227            engine,
228            track,
229            cancel_token,
230            event_sender,
231            packet_sender,
232            option.clone(),
233        )
234        .await
235    }
236
237    pub async fn create_tts_track(
238        engine: Arc<StreamEngine>,
239        cancel_token: CancellationToken,
240        session_id: String,
241        track_id: TrackId,
242        ssrc: u32,
243        play_id: Option<String>,
244        streaming: bool,
245        tts_option: &SynthesisOption,
246    ) -> Result<(SynthesisHandle, Box<dyn Track>)> {
247        let (tx, rx) = mpsc::unbounded_channel();
248        let new_handle = SynthesisHandle::new(tx, play_id.clone());
249        let tts_client = engine.create_tts_client(streaming, tts_option).await?;
250        let tts_track = TtsTrack::new(track_id, session_id, streaming, play_id, rx, tts_client)
251            .with_ssrc(ssrc)
252            .with_cancel_token(cancel_token);
253        Ok((new_handle, Box::new(tts_track) as Box<dyn Track>))
254    }
255
256    pub fn with_processor_hook(&mut self, hook_fn: CreateProcessorsHook) -> &mut Self {
257        self.create_processors_hook = Arc::new(Box::new(hook_fn));
258        self
259    }
260
261    fn default_create_procesors_hook(
262        engine: Arc<StreamEngine>,
263        track: &dyn Track,
264        cancel_token: CancellationToken,
265        event_sender: EventSender,
266        packet_sender: TrackPacketSender,
267        option: CallOption,
268    ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>> {
269        let track_id = track.id().clone();
270        Box::pin(async move {
271            let mut processors = vec![];
272            debug!(track_id = %track_id, "Creating processors for track");
273
274            if let Some(realtime_option) = option.realtime {
275                debug!(track_id = %track_id, "Adding RealtimeProcessor");
276                let realtime_processor = crate::media::realtime_processor::RealtimeProcessor::new(
277                    track_id.clone(),
278                    cancel_token.child_token(),
279                    event_sender.clone(),
280                    packet_sender.clone(),
281                    realtime_option,
282                )?;
283                processors.push(Box::new(realtime_processor) as Box<dyn Processor>);
284                // In realtime mode, we usually don't need separate VAD or ASR processors
285                // as they are handled by the realtime service (OpenAI/Azure)
286                return Ok(processors);
287            }
288
289            match option.denoise {
290                Some(true) => {
291                    debug!(track_id = %track_id, "Adding NoiseReducer processor");
292                    let noise_reducer = NoiseReducer::new(INTERNAL_SAMPLERATE as usize);
293                    processors.push(Box::new(noise_reducer) as Box<dyn Processor>);
294                }
295                _ => {}
296            }
297            match option.vad {
298                Some(mut option) => {
299                    debug!(track_id = %track_id, "Adding VadProcessor processor type={:?}", option.r#type);
300                    option.samplerate = INTERNAL_SAMPLERATE;
301                    let vad_processor: Box<dyn Processor + 'static> = engine.create_vad_processor(
302                        cancel_token.child_token(),
303                        event_sender.clone(),
304                        option.to_owned(),
305                    )?;
306                    processors.push(vad_processor);
307                }
308                None => {}
309            }
310            match option.asr {
311                Some(mut option) => {
312                    debug!(track_id = %track_id, "Adding AsrProcessor processor provider={:?}", option.provider);
313                    option.samplerate = Some(INTERNAL_SAMPLERATE);
314                    let asr_processor = engine
315                        .create_asr_processor(
316                            track_id.clone(),
317                            cancel_token.child_token(),
318                            option.to_owned(),
319                            event_sender.clone(),
320                        )
321                        .await?;
322                    processors.push(asr_processor);
323                }
324                None => {}
325            }
326            match option.eou {
327                Some(ref option) => {
328                    let eou_processor = engine.create_eou_processor(
329                        cancel_token.child_token(),
330                        event_sender.clone(),
331                        option.to_owned(),
332                    )?;
333                    processors.push(eou_processor);
334                }
335                None => {}
336            }
337            match option.inactivity_timeout {
338                Some(timeout_secs) if timeout_secs > 0 => {
339                    let inactivity_processor = crate::media::inactivity::InactivityProcessor::new(
340                        track_id.clone(),
341                        std::time::Duration::from_secs(timeout_secs),
342                        event_sender.clone(),
343                        cancel_token.child_token(),
344                    );
345                    processors.push(Box::new(inactivity_processor) as Box<dyn Processor>);
346                }
347                _ => {}
348            }
349
350            Ok(processors)
351        })
352    }
353}