active_call/media/
engine.rs

1use super::{
2    INTERNAL_SAMPLERATE,
3    asr_processor::AsrProcessor,
4    denoiser::NoiseReducer,
5    processor::Processor,
6    track::{
7        Track, TrackPacketSender,
8        tts::{SynthesisHandle, TtsTrack},
9    },
10    vad::{VADOption, VadProcessor, VadType},
11};
12use crate::{
13    CallOption, EouOption,
14    event::EventSender,
15    media::TrackId,
16    synthesis::{
17        AliyunTtsClient, DeepegramTtsClient, SynthesisClient, SynthesisOption, SynthesisType,
18        TencentCloudTtsBasicClient, TencentCloudTtsClient,
19    },
20    transcription::{
21        AliyunAsrClientBuilder, TencentCloudAsrClientBuilder, TranscriptionClient,
22        TranscriptionOption, TranscriptionType,
23    },
24};
25
26#[cfg(feature = "offline")]
27use crate::{
28    synthesis::SupertonicTtsClient,
29    transcription::SensevoiceAsrClientBuilder,
30};
31
32use anyhow::Result;
33use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
34use tokio::sync::mpsc;
35use tokio_util::sync::CancellationToken;
36use tracing::debug;
37
38pub type FnCreateVadProcessor = fn(
39    token: CancellationToken,
40    event_sender: EventSender,
41    option: VADOption,
42) -> Result<Box<dyn Processor>>;
43
44pub type FnCreateEouProcessor = fn(
45    token: CancellationToken,
46    event_sender: EventSender,
47    option: EouOption,
48) -> Result<Box<dyn Processor>>;
49
50pub type FnCreateAsrClient = Box<
51    dyn Fn(
52            TrackId,
53            CancellationToken,
54            TranscriptionOption,
55            EventSender,
56        ) -> Pin<Box<dyn Future<Output = Result<Box<dyn TranscriptionClient>>> + Send>>
57        + Send
58        + Sync,
59>;
60pub type FnCreateTtsClient =
61    fn(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>>;
62
63// Define hook types
64pub type CreateProcessorsHook = Box<
65    dyn Fn(
66            Arc<StreamEngine>,
67            &dyn Track,
68            CancellationToken,
69            EventSender,
70            TrackPacketSender,
71            CallOption,
72        ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>>
73        + Send
74        + Sync,
75>;
76
77pub struct StreamEngine {
78    vad_creators: HashMap<VadType, FnCreateVadProcessor>,
79    eou_creators: HashMap<String, FnCreateEouProcessor>,
80    asr_creators: HashMap<TranscriptionType, FnCreateAsrClient>,
81    tts_creators: HashMap<SynthesisType, FnCreateTtsClient>,
82    create_processors_hook: Arc<CreateProcessorsHook>,
83}
84
85impl Default for StreamEngine {
86    fn default() -> Self {
87        let mut engine = Self::new();
88        engine.register_vad(VadType::Silero, VadProcessor::create);
89        engine.register_vad(VadType::Other("nop".to_string()), VadProcessor::create_nop);
90
91        engine.register_asr(
92            TranscriptionType::TencentCloud,
93            Box::new(TencentCloudAsrClientBuilder::create),
94        );
95        engine.register_asr(
96            TranscriptionType::Aliyun,
97            Box::new(AliyunAsrClientBuilder::create),
98        );
99        
100        #[cfg(feature = "offline")]
101        engine.register_asr(
102            TranscriptionType::Sensevoice,
103            Box::new(SensevoiceAsrClientBuilder::create),
104        );
105        
106        engine.register_tts(SynthesisType::Aliyun, AliyunTtsClient::create);
107        engine.register_tts(SynthesisType::TencentCloud, TencentCloudTtsClient::create);
108        engine.register_tts(
109            SynthesisType::Other("tencent_basic".to_string()),
110            TencentCloudTtsBasicClient::create,
111        );
112        engine.register_tts(SynthesisType::Deepgram, DeepegramTtsClient::create);
113        
114        #[cfg(feature = "offline")]
115        engine.register_tts(SynthesisType::Supertonic, SupertonicTtsClient::create);
116        
117        engine
118    }
119}
120
121impl StreamEngine {
122    pub fn new() -> Self {
123        Self {
124            vad_creators: HashMap::new(),
125            asr_creators: HashMap::new(),
126            tts_creators: HashMap::new(),
127            eou_creators: HashMap::new(),
128            create_processors_hook: Arc::new(Box::new(Self::default_create_procesors_hook)),
129        }
130    }
131
132    pub fn register_vad(&mut self, vad_type: VadType, creator: FnCreateVadProcessor) -> &mut Self {
133        self.vad_creators.insert(vad_type, creator);
134        self
135    }
136
137    pub fn register_eou(&mut self, name: String, creator: FnCreateEouProcessor) -> &mut Self {
138        self.eou_creators.insert(name, creator);
139        self
140    }
141
142    pub fn register_asr(
143        &mut self,
144        asr_type: TranscriptionType,
145        creator: FnCreateAsrClient,
146    ) -> &mut Self {
147        self.asr_creators.insert(asr_type, creator);
148        self
149    }
150
151    pub fn register_tts(
152        &mut self,
153        tts_type: SynthesisType,
154        creator: FnCreateTtsClient,
155    ) -> &mut Self {
156        self.tts_creators.insert(tts_type, creator);
157        self
158    }
159
160    pub fn create_vad_processor(
161        &self,
162        token: CancellationToken,
163        event_sender: EventSender,
164        option: VADOption,
165    ) -> Result<Box<dyn Processor>> {
166        let creator = self.vad_creators.get(&option.r#type);
167        if let Some(creator) = creator {
168            creator(token, event_sender, option)
169        } else {
170            Err(anyhow::anyhow!("VAD type not found: {}", option.r#type))
171        }
172    }
173    pub fn create_eou_processor(
174        &self,
175        token: CancellationToken,
176        event_sender: EventSender,
177        option: EouOption,
178    ) -> Result<Box<dyn Processor>> {
179        let creator = self
180            .eou_creators
181            .get(&option.r#type.clone().unwrap_or_default());
182        if let Some(creator) = creator {
183            creator(token, event_sender, option)
184        } else {
185            Err(anyhow::anyhow!("EOU type not found: {:?}", option.r#type))
186        }
187    }
188
189    pub async fn create_asr_processor(
190        &self,
191        track_id: TrackId,
192        cancel_token: CancellationToken,
193        option: TranscriptionOption,
194        event_sender: EventSender,
195    ) -> Result<Box<dyn Processor>> {
196        let asr_client = match option.provider {
197            Some(ref provider) => {
198                let creator = self.asr_creators.get(&provider);
199                if let Some(creator) = creator {
200                    creator(track_id, cancel_token, option, event_sender).await?
201                } else {
202                    return Err(anyhow::anyhow!("ASR type not found: {}", provider));
203                }
204            }
205            None => return Err(anyhow::anyhow!("ASR type not found: {:?}", option.provider)),
206        };
207        Ok(Box::new(AsrProcessor { asr_client }))
208    }
209
210    pub async fn create_tts_client(
211        &self,
212        streaming: bool,
213        tts_option: &SynthesisOption,
214    ) -> Result<Box<dyn SynthesisClient>> {
215        match tts_option.provider {
216            Some(ref provider) => {
217                let creator = self.tts_creators.get(&provider);
218                if let Some(creator) = creator {
219                    creator(streaming, tts_option)
220                } else {
221                    Err(anyhow::anyhow!("TTS type not found: {}", provider))
222                }
223            }
224            None => Err(anyhow::anyhow!(
225                "TTS type not found: {:?}",
226                tts_option.provider
227            )),
228        }
229    }
230
231    pub async fn create_processors(
232        engine: Arc<StreamEngine>,
233        track: &dyn Track,
234        cancel_token: CancellationToken,
235        event_sender: EventSender,
236        packet_sender: TrackPacketSender,
237        option: &CallOption,
238    ) -> Result<Vec<Box<dyn Processor>>> {
239        (engine.clone().create_processors_hook)(
240            engine,
241            track,
242            cancel_token,
243            event_sender,
244            packet_sender,
245            option.clone(),
246        )
247        .await
248    }
249
250    pub async fn create_tts_track(
251        engine: Arc<StreamEngine>,
252        cancel_token: CancellationToken,
253        session_id: String,
254        track_id: TrackId,
255        ssrc: u32,
256        play_id: Option<String>,
257        streaming: bool,
258        tts_option: &SynthesisOption,
259    ) -> Result<(SynthesisHandle, Box<dyn Track>)> {
260        let (tx, rx) = mpsc::unbounded_channel();
261        let new_handle = SynthesisHandle::new(tx, play_id.clone());
262        let tts_client = engine.create_tts_client(streaming, tts_option).await?;
263        let tts_track = TtsTrack::new(track_id, session_id, streaming, play_id, rx, tts_client)
264            .with_ssrc(ssrc)
265            .with_cancel_token(cancel_token);
266        Ok((new_handle, Box::new(tts_track) as Box<dyn Track>))
267    }
268
269    pub fn with_processor_hook(&mut self, hook_fn: CreateProcessorsHook) -> &mut Self {
270        self.create_processors_hook = Arc::new(Box::new(hook_fn));
271        self
272    }
273
274    fn default_create_procesors_hook(
275        engine: Arc<StreamEngine>,
276        track: &dyn Track,
277        cancel_token: CancellationToken,
278        event_sender: EventSender,
279        packet_sender: TrackPacketSender,
280        option: CallOption,
281    ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>> {
282        let track_id = track.id().clone();
283        Box::pin(async move {
284            let mut processors = vec![];
285            debug!(track_id = %track_id, "Creating processors for track");
286
287            if let Some(realtime_option) = option.realtime {
288                debug!(track_id = %track_id, "Adding RealtimeProcessor");
289                let realtime_processor = crate::media::realtime_processor::RealtimeProcessor::new(
290                    track_id.clone(),
291                    cancel_token.child_token(),
292                    event_sender.clone(),
293                    packet_sender.clone(),
294                    realtime_option,
295                )?;
296                processors.push(Box::new(realtime_processor) as Box<dyn Processor>);
297                // In realtime mode, we usually don't need separate VAD or ASR processors
298                // as they are handled by the realtime service (OpenAI/Azure)
299                return Ok(processors);
300            }
301
302            match option.denoise {
303                Some(true) => {
304                    debug!(track_id = %track_id, "Adding NoiseReducer processor");
305                    let noise_reducer = NoiseReducer::new(INTERNAL_SAMPLERATE as usize);
306                    processors.push(Box::new(noise_reducer) as Box<dyn Processor>);
307                }
308                _ => {}
309            }
310            match option.vad {
311                Some(mut option) => {
312                    debug!(track_id = %track_id, "Adding VadProcessor processor type={:?}", option.r#type);
313                    option.samplerate = INTERNAL_SAMPLERATE;
314                    let vad_processor: Box<dyn Processor + 'static> = engine.create_vad_processor(
315                        cancel_token.child_token(),
316                        event_sender.clone(),
317                        option.to_owned(),
318                    )?;
319                    processors.push(vad_processor);
320                }
321                None => {}
322            }
323            match option.asr {
324                Some(mut option) => {
325                    debug!(track_id = %track_id, "Adding AsrProcessor processor provider={:?}", option.provider);
326                    option.samplerate = Some(INTERNAL_SAMPLERATE);
327                    let asr_processor = engine
328                        .create_asr_processor(
329                            track_id.clone(),
330                            cancel_token.child_token(),
331                            option.to_owned(),
332                            event_sender.clone(),
333                        )
334                        .await?;
335                    processors.push(asr_processor);
336                }
337                None => {}
338            }
339            match option.eou {
340                Some(ref option) => {
341                    let eou_processor = engine.create_eou_processor(
342                        cancel_token.child_token(),
343                        event_sender.clone(),
344                        option.to_owned(),
345                    )?;
346                    processors.push(eou_processor);
347                }
348                None => {}
349            }
350            match option.inactivity_timeout {
351                Some(timeout_secs) if timeout_secs > 0 => {
352                    let inactivity_processor = crate::media::inactivity::InactivityProcessor::new(
353                        track_id.clone(),
354                        std::time::Duration::from_secs(timeout_secs),
355                        event_sender.clone(),
356                        cancel_token.child_token(),
357                    );
358                    processors.push(Box::new(inactivity_processor) as Box<dyn Processor>);
359                }
360                _ => {}
361            }
362
363            Ok(processors)
364        })
365    }
366}