active_call/media/
engine.rs

1use super::{
2    asr_processor::AsrProcessor,
3    denoiser::NoiseReducer,
4    processor::Processor,
5    track::{
6        Track,
7        tts::{SynthesisHandle, TtsTrack},
8    },
9    vad::{VADOption, VadProcessor, VadType},
10};
11use crate::{
12    CallOption, EouOption,
13    event::EventSender,
14    media::TrackId,
15    synthesis::{
16        AliyunTtsClient, DeepegramTtsClient, SynthesisClient, SynthesisOption, SynthesisType,
17        TencentCloudTtsBasicClient, TencentCloudTtsClient, VoiceApiTtsClient,
18    },
19    transcription::{
20        AliyunAsrClientBuilder, TencentCloudAsrClientBuilder, TranscriptionClient,
21        TranscriptionOption, TranscriptionType, VoiceApiAsrClientBuilder,
22    },
23};
24use anyhow::Result;
25use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
26use tokio::sync::mpsc;
27use tokio_util::sync::CancellationToken;
28
29pub type FnCreateVadProcessor = fn(
30    token: CancellationToken,
31    event_sender: EventSender,
32    option: VADOption,
33) -> Result<Box<dyn Processor>>;
34
35pub type FnCreateEouProcessor = fn(
36    token: CancellationToken,
37    event_sender: EventSender,
38    option: EouOption,
39) -> Result<Box<dyn Processor>>;
40
41pub type FnCreateAsrClient = Box<
42    dyn Fn(
43            TrackId,
44            CancellationToken,
45            TranscriptionOption,
46            EventSender,
47        ) -> Pin<Box<dyn Future<Output = Result<Box<dyn TranscriptionClient>>> + Send>>
48        + Send
49        + Sync,
50>;
51pub type FnCreateTtsClient =
52    fn(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>>;
53
54// Define hook types
55pub type CreateProcessorsHook = Box<
56    dyn Fn(
57            Arc<StreamEngine>,
58            &dyn Track,
59            CancellationToken,
60            EventSender,
61            CallOption,
62        ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>>
63        + Send
64        + Sync,
65>;
66
67pub struct StreamEngine {
68    vad_creators: HashMap<VadType, FnCreateVadProcessor>,
69    eou_creators: HashMap<String, FnCreateEouProcessor>,
70    asr_creators: HashMap<TranscriptionType, FnCreateAsrClient>,
71    tts_creators: HashMap<SynthesisType, FnCreateTtsClient>,
72    create_processors_hook: Arc<CreateProcessorsHook>,
73}
74
75impl Default for StreamEngine {
76    fn default() -> Self {
77        let mut engine = Self::new();
78        engine.register_vad(VadType::Silero, VadProcessor::create);
79        engine.register_vad(VadType::Ten, VadProcessor::create);
80        engine.register_vad(VadType::Other("nop".to_string()), VadProcessor::create_nop);
81
82        engine.register_asr(
83            TranscriptionType::TencentCloud,
84            Box::new(TencentCloudAsrClientBuilder::create),
85        );
86        engine.register_asr(
87            TranscriptionType::VoiceApi,
88            Box::new(VoiceApiAsrClientBuilder::create),
89        );
90        engine.register_asr(
91            TranscriptionType::Aliyun,
92            Box::new(AliyunAsrClientBuilder::create),
93        );
94        engine.register_tts(SynthesisType::Aliyun, AliyunTtsClient::create);
95        engine.register_tts(SynthesisType::TencentCloud, TencentCloudTtsClient::create);
96        engine.register_tts(SynthesisType::VoiceApi, VoiceApiTtsClient::create);
97        engine.register_tts(
98            SynthesisType::Other("tencent_basic".to_string()),
99            TencentCloudTtsBasicClient::create,
100        );
101        engine.register_tts(SynthesisType::Deepgram, DeepegramTtsClient::create);
102        engine
103    }
104}
105
106impl StreamEngine {
107    pub fn new() -> Self {
108        Self {
109            vad_creators: HashMap::new(),
110            asr_creators: HashMap::new(),
111            tts_creators: HashMap::new(),
112            eou_creators: HashMap::new(),
113            create_processors_hook: Arc::new(Box::new(Self::default_create_procesors_hook)),
114        }
115    }
116
117    pub fn register_vad(&mut self, vad_type: VadType, creator: FnCreateVadProcessor) -> &mut Self {
118        self.vad_creators.insert(vad_type, creator);
119        self
120    }
121
122    pub fn register_eou(&mut self, name: String, creator: FnCreateEouProcessor) -> &mut Self {
123        self.eou_creators.insert(name, creator);
124        self
125    }
126
127    pub fn register_asr(
128        &mut self,
129        asr_type: TranscriptionType,
130        creator: FnCreateAsrClient,
131    ) -> &mut Self {
132        self.asr_creators.insert(asr_type, creator);
133        self
134    }
135
136    pub fn register_tts(
137        &mut self,
138        tts_type: SynthesisType,
139        creator: FnCreateTtsClient,
140    ) -> &mut Self {
141        self.tts_creators.insert(tts_type, creator);
142        self
143    }
144
145    pub fn create_vad_processor(
146        &self,
147        token: CancellationToken,
148        event_sender: EventSender,
149        option: VADOption,
150    ) -> Result<Box<dyn Processor>> {
151        let creator = self.vad_creators.get(&option.r#type);
152        if let Some(creator) = creator {
153            creator(token, event_sender, option)
154        } else {
155            Err(anyhow::anyhow!("VAD type not found: {}", option.r#type))
156        }
157    }
158    pub fn create_eou_processor(
159        &self,
160        token: CancellationToken,
161        event_sender: EventSender,
162        option: EouOption,
163    ) -> Result<Box<dyn Processor>> {
164        let creator = self
165            .eou_creators
166            .get(&option.r#type.clone().unwrap_or_default());
167        if let Some(creator) = creator {
168            creator(token, event_sender, option)
169        } else {
170            Err(anyhow::anyhow!("EOU type not found: {:?}", option.r#type))
171        }
172    }
173
174    pub async fn create_asr_processor(
175        &self,
176        track_id: TrackId,
177        cancel_token: CancellationToken,
178        option: TranscriptionOption,
179        event_sender: EventSender,
180    ) -> Result<Box<dyn Processor>> {
181        let asr_client = match option.provider {
182            Some(ref provider) => {
183                let creator = self.asr_creators.get(&provider);
184                if let Some(creator) = creator {
185                    creator(track_id, cancel_token, option, event_sender).await?
186                } else {
187                    return Err(anyhow::anyhow!("ASR type not found: {}", provider));
188                }
189            }
190            None => return Err(anyhow::anyhow!("ASR type not found: {:?}", option.provider)),
191        };
192        Ok(Box::new(AsrProcessor { asr_client }))
193    }
194
195    pub async fn create_tts_client(
196        &self,
197        streaming: bool,
198        tts_option: &SynthesisOption,
199    ) -> Result<Box<dyn SynthesisClient>> {
200        match tts_option.provider {
201            Some(ref provider) => {
202                let creator = self.tts_creators.get(&provider);
203                if let Some(creator) = creator {
204                    creator(streaming, tts_option)
205                } else {
206                    Err(anyhow::anyhow!("TTS type not found: {}", provider))
207                }
208            }
209            None => Err(anyhow::anyhow!(
210                "TTS type not found: {:?}",
211                tts_option.provider
212            )),
213        }
214    }
215
216    pub async fn create_processors(
217        engine: Arc<StreamEngine>,
218        track: &dyn Track,
219        cancel_token: CancellationToken,
220        event_sender: EventSender,
221        option: &CallOption,
222    ) -> Result<Vec<Box<dyn Processor>>> {
223        (engine.clone().create_processors_hook)(
224            engine,
225            track,
226            cancel_token,
227            event_sender,
228            option.clone(),
229        )
230        .await
231    }
232
233    pub async fn create_tts_track(
234        engine: Arc<StreamEngine>,
235        cancel_token: CancellationToken,
236        session_id: String,
237        track_id: TrackId,
238        ssrc: u32,
239        play_id: Option<String>,
240        streaming: bool,
241        tts_option: &SynthesisOption,
242    ) -> Result<(SynthesisHandle, Box<dyn Track>)> {
243        let (tx, rx) = mpsc::unbounded_channel();
244        let new_handle = SynthesisHandle::new(tx, play_id.clone());
245        let tts_client = engine.create_tts_client(streaming, tts_option).await?;
246        let tts_track = TtsTrack::new(track_id, session_id, streaming, play_id, rx, tts_client)
247            .with_ssrc(ssrc)
248            .with_cancel_token(cancel_token);
249        Ok((new_handle, Box::new(tts_track) as Box<dyn Track>))
250    }
251
252    pub fn with_processor_hook(&mut self, hook_fn: CreateProcessorsHook) -> &mut Self {
253        self.create_processors_hook = Arc::new(Box::new(hook_fn));
254        self
255    }
256
257    fn default_create_procesors_hook(
258        engine: Arc<StreamEngine>,
259        track: &dyn Track,
260        cancel_token: CancellationToken,
261        event_sender: EventSender,
262        option: CallOption,
263    ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>> {
264        let track_id = track.id().clone();
265        let samplerate = track.config().samplerate as usize;
266        Box::pin(async move {
267            let mut processors = vec![];
268            match option.denoise {
269                Some(true) => {
270                    let noise_reducer = NoiseReducer::new(samplerate);
271                    processors.push(Box::new(noise_reducer) as Box<dyn Processor>);
272                }
273                _ => {}
274            }
275            match option.vad {
276                Some(ref option) => {
277                    let vad_processor: Box<dyn Processor + 'static> = engine.create_vad_processor(
278                        cancel_token.child_token(),
279                        event_sender.clone(),
280                        option.to_owned(),
281                    )?;
282                    processors.push(vad_processor);
283                }
284                None => {}
285            }
286            match option.asr {
287                Some(ref option) => {
288                    let asr_processor = engine
289                        .create_asr_processor(
290                            track_id.clone(),
291                            cancel_token.child_token(),
292                            option.to_owned(),
293                            event_sender.clone(),
294                        )
295                        .await?;
296                    processors.push(asr_processor);
297                }
298                None => {}
299            }
300            match option.eou {
301                Some(ref option) => {
302                    let eou_processor = engine.create_eou_processor(
303                        cancel_token.child_token(),
304                        event_sender.clone(),
305                        option.to_owned(),
306                    )?;
307                    processors.push(eou_processor);
308                }
309                None => {}
310            }
311            match option.inactivity_timeout {
312                Some(timeout_secs) if timeout_secs > 0 => {
313                    let inactivity_processor =
314                        crate::media::inactivity::InactivityProcessor::new(
315                            track_id.clone(),
316                            std::time::Duration::from_secs(timeout_secs),
317                            event_sender.clone(),
318                            cancel_token.child_token(),
319                        );
320                    processors.push(Box::new(inactivity_processor) as Box<dyn Processor>);
321                }
322                _ => {}
323            }
324
325            Ok(processors)
326        })
327    }
328}