Skip to main content

active_call/media/
engine.rs

1use super::{
2    INTERNAL_SAMPLERATE,
3    asr_processor::AsrProcessor,
4    denoiser::NoiseReducer,
5    processor::Processor,
6    track::{
7        Track, TrackPacketSender,
8        tts::{SynthesisHandle, TtsTrack},
9    },
10    vad::{VADOption, VadProcessor, VadType},
11};
12use crate::{
13    CallOption, EouOption,
14    event::EventSender,
15    media::TrackId,
16    synthesis::{
17        AliyunTtsClient, DeepegramTtsClient, SynthesisClient, SynthesisOption, SynthesisType,
18        TencentCloudTtsBasicClient, TencentCloudTtsClient,
19    },
20    transcription::{
21        AliyunAsrClientBuilder, TencentCloudAsrClientBuilder, TranscriptionClient,
22        TranscriptionOption, TranscriptionType,
23    },
24};
25
26#[cfg(feature = "offline")]
27use crate::{synthesis::SupertonicTtsClient, transcription::SensevoiceAsrClientBuilder};
28
29use anyhow::Result;
30use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
31use tokio::sync::mpsc;
32use tokio_util::sync::CancellationToken;
33use tracing::debug;
34
35pub type FnCreateVadProcessor = fn(
36    token: CancellationToken,
37    event_sender: EventSender,
38    option: VADOption,
39) -> Result<Box<dyn Processor>>;
40
41pub type FnCreateEouProcessor = fn(
42    token: CancellationToken,
43    event_sender: EventSender,
44    option: EouOption,
45) -> Result<Box<dyn Processor>>;
46
47pub type FnCreateAsrClient = Box<
48    dyn Fn(
49            TrackId,
50            CancellationToken,
51            TranscriptionOption,
52            EventSender,
53        ) -> Pin<Box<dyn Future<Output = Result<Box<dyn TranscriptionClient>>> + Send>>
54        + Send
55        + Sync,
56>;
57pub type FnCreateTtsClient =
58    fn(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>>;
59
60// Define hook types
61pub type CreateProcessorsHook = Box<
62    dyn Fn(
63            Arc<StreamEngine>,
64            TrackId,
65            CancellationToken,
66            EventSender,
67            TrackPacketSender,
68            CallOption,
69        ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>>
70        + Send
71        + Sync,
72>;
73
74pub struct StreamEngine {
75    vad_creators: HashMap<VadType, FnCreateVadProcessor>,
76    eou_creators: HashMap<String, FnCreateEouProcessor>,
77    asr_creators: HashMap<TranscriptionType, FnCreateAsrClient>,
78    tts_creators: HashMap<SynthesisType, FnCreateTtsClient>,
79    create_processors_hook: Arc<CreateProcessorsHook>,
80}
81
82impl Default for StreamEngine {
83    fn default() -> Self {
84        let mut engine = Self::new();
85        engine.register_vad(VadType::Silero, VadProcessor::create);
86        engine.register_vad(VadType::Other("nop".to_string()), VadProcessor::create_nop);
87
88        engine.register_asr(
89            TranscriptionType::TencentCloud,
90            Box::new(TencentCloudAsrClientBuilder::create),
91        );
92        engine.register_asr(
93            TranscriptionType::Aliyun,
94            Box::new(AliyunAsrClientBuilder::create),
95        );
96
97        #[cfg(feature = "offline")]
98        engine.register_asr(
99            TranscriptionType::Sensevoice,
100            Box::new(SensevoiceAsrClientBuilder::create),
101        );
102
103        engine.register_tts(SynthesisType::Aliyun, AliyunTtsClient::create);
104        engine.register_tts(SynthesisType::TencentCloud, TencentCloudTtsClient::create);
105        engine.register_tts(
106            SynthesisType::Other("tencent_basic".to_string()),
107            TencentCloudTtsBasicClient::create,
108        );
109        engine.register_tts(SynthesisType::Deepgram, DeepegramTtsClient::create);
110
111        #[cfg(feature = "offline")]
112        engine.register_tts(SynthesisType::Supertonic, SupertonicTtsClient::create);
113
114        engine
115    }
116}
117
118impl StreamEngine {
119    pub fn new() -> Self {
120        Self {
121            vad_creators: HashMap::new(),
122            asr_creators: HashMap::new(),
123            tts_creators: HashMap::new(),
124            eou_creators: HashMap::new(),
125            create_processors_hook: Arc::new(Box::new(Self::default_create_procesors_hook)),
126        }
127    }
128
129    pub fn register_vad(&mut self, vad_type: VadType, creator: FnCreateVadProcessor) -> &mut Self {
130        self.vad_creators.insert(vad_type, creator);
131        self
132    }
133
134    pub fn register_eou(&mut self, name: String, creator: FnCreateEouProcessor) -> &mut Self {
135        self.eou_creators.insert(name, creator);
136        self
137    }
138
139    pub fn register_asr(
140        &mut self,
141        asr_type: TranscriptionType,
142        creator: FnCreateAsrClient,
143    ) -> &mut Self {
144        self.asr_creators.insert(asr_type, creator);
145        self
146    }
147
148    pub fn register_tts(
149        &mut self,
150        tts_type: SynthesisType,
151        creator: FnCreateTtsClient,
152    ) -> &mut Self {
153        self.tts_creators.insert(tts_type, creator);
154        self
155    }
156
157    pub fn create_vad_processor(
158        &self,
159        token: CancellationToken,
160        event_sender: EventSender,
161        option: VADOption,
162    ) -> Result<Box<dyn Processor>> {
163        let creator = self.vad_creators.get(&option.r#type);
164        if let Some(creator) = creator {
165            creator(token, event_sender, option)
166        } else {
167            Err(anyhow::anyhow!("VAD type not found: {}", option.r#type))
168        }
169    }
170    pub fn create_eou_processor(
171        &self,
172        token: CancellationToken,
173        event_sender: EventSender,
174        option: EouOption,
175    ) -> Result<Box<dyn Processor>> {
176        let creator = self
177            .eou_creators
178            .get(&option.r#type.clone().unwrap_or_default());
179        if let Some(creator) = creator {
180            creator(token, event_sender, option)
181        } else {
182            Err(anyhow::anyhow!("EOU type not found: {:?}", option.r#type))
183        }
184    }
185
186    pub async fn create_asr_processor(
187        &self,
188        track_id: TrackId,
189        cancel_token: CancellationToken,
190        option: TranscriptionOption,
191        event_sender: EventSender,
192    ) -> Result<Box<dyn Processor>> {
193        let asr_client = match option.provider {
194            Some(ref provider) => {
195                let creator = self.asr_creators.get(&provider);
196                if let Some(creator) = creator {
197                    creator(track_id, cancel_token, option, event_sender).await?
198                } else {
199                    return Err(anyhow::anyhow!("ASR type not found: {}", provider));
200                }
201            }
202            None => return Err(anyhow::anyhow!("ASR type not found: {:?}", option.provider)),
203        };
204        Ok(Box::new(AsrProcessor { asr_client }))
205    }
206
207    pub async fn create_tts_client(
208        &self,
209        streaming: bool,
210        tts_option: &SynthesisOption,
211    ) -> Result<Box<dyn SynthesisClient>> {
212        match tts_option.provider {
213            Some(ref provider) => {
214                let creator = self.tts_creators.get(&provider);
215                if let Some(creator) = creator {
216                    creator(streaming, tts_option)
217                } else {
218                    Err(anyhow::anyhow!("TTS type not found: {}", provider))
219                }
220            }
221            None => Err(anyhow::anyhow!(
222                "TTS type not found: {:?}",
223                tts_option.provider
224            )),
225        }
226    }
227
228    pub async fn create_processors(
229        engine: Arc<StreamEngine>,
230        track: &dyn Track,
231        cancel_token: CancellationToken,
232        event_sender: EventSender,
233        packet_sender: TrackPacketSender,
234        option: &CallOption,
235    ) -> Result<Vec<Box<dyn Processor>>> {
236        (engine.clone().create_processors_hook)(
237            engine,
238            track.id().clone(),
239            cancel_token,
240            event_sender,
241            packet_sender,
242            option.clone(),
243        )
244        .await
245    }
246
247    pub async fn create_tts_track(
248        engine: Arc<StreamEngine>,
249        cancel_token: CancellationToken,
250        session_id: String,
251        track_id: TrackId,
252        ssrc: u32,
253        play_id: Option<String>,
254        streaming: bool,
255        tts_option: &SynthesisOption,
256    ) -> Result<(SynthesisHandle, Box<dyn Track>)> {
257        let (tx, rx) = mpsc::unbounded_channel();
258        let new_handle = SynthesisHandle::new(tx, play_id.clone(), ssrc);
259        let tts_client = engine.create_tts_client(streaming, tts_option).await?;
260        let sample_rate = tts_option.samplerate.unwrap_or(16000) as u32;
261        let tts_track = TtsTrack::new(track_id, session_id, streaming, play_id, rx, tts_client)
262            .with_ssrc(ssrc)
263            .with_sample_rate(sample_rate)
264            .with_cancel_token(cancel_token);
265        Ok((new_handle, Box::new(tts_track) as Box<dyn Track>))
266    }
267
268    pub fn with_processor_hook(&mut self, hook_fn: CreateProcessorsHook) -> &mut Self {
269        self.create_processors_hook = Arc::new(Box::new(hook_fn));
270        self
271    }
272
273    pub fn default_create_procesors_hook(
274        engine: Arc<StreamEngine>,
275        track_id: TrackId,
276        cancel_token: CancellationToken,
277        event_sender: EventSender,
278        packet_sender: TrackPacketSender,
279        option: CallOption,
280    ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>> {
281        Box::pin(async move {
282            let mut processors = vec![];
283            debug!(%track_id, "Creating processors for track");
284
285            if let Some(realtime_option) = option.realtime {
286                debug!(%track_id, "Adding RealtimeProcessor");
287                let realtime_processor = crate::media::realtime_processor::RealtimeProcessor::new(
288                    track_id.clone(),
289                    cancel_token.child_token(),
290                    event_sender.clone(),
291                    packet_sender.clone(),
292                    realtime_option,
293                )?;
294                processors.push(Box::new(realtime_processor) as Box<dyn Processor>);
295                // In realtime mode, we usually don't need separate VAD or ASR processors
296                // as they are handled by the realtime service (OpenAI/Azure)
297                return Ok(processors);
298            }
299
300            match option.denoise {
301                Some(true) => {
302                    debug!(%track_id, "Adding NoiseReducer processor");
303                    let noise_reducer = NoiseReducer::new(INTERNAL_SAMPLERATE as usize);
304                    processors.push(Box::new(noise_reducer) as Box<dyn Processor>);
305                }
306                _ => {}
307            }
308            match option.vad {
309                Some(mut option) => {
310                    debug!(%track_id, "Adding VadProcessor processor type={:?}", option.r#type);
311                    option.samplerate = INTERNAL_SAMPLERATE;
312                    let vad_processor: Box<dyn Processor + 'static> = engine.create_vad_processor(
313                        cancel_token.child_token(),
314                        event_sender.clone(),
315                        option.to_owned(),
316                    )?;
317                    processors.push(vad_processor);
318                }
319                None => {}
320            }
321            match option.asr {
322                Some(mut option) => {
323                    debug!(%track_id, "Adding AsrProcessor processor provider={:?}", option.provider);
324                    option.samplerate = Some(INTERNAL_SAMPLERATE);
325                    let asr_processor = engine
326                        .create_asr_processor(
327                            track_id.clone(),
328                            cancel_token.child_token(),
329                            option.to_owned(),
330                            event_sender.clone(),
331                        )
332                        .await?;
333                    processors.push(asr_processor);
334                }
335                None => {}
336            }
337            match option.eou {
338                Some(ref option) => {
339                    let eou_processor = engine.create_eou_processor(
340                        cancel_token.child_token(),
341                        event_sender.clone(),
342                        option.to_owned(),
343                    )?;
344                    processors.push(eou_processor);
345                }
346                None => {}
347            }
348            match option.inactivity_timeout {
349                Some(timeout_secs) if timeout_secs > 0 => {
350                    let inactivity_processor = crate::media::inactivity::InactivityProcessor::new(
351                        track_id.clone(),
352                        std::time::Duration::from_secs(timeout_secs),
353                        event_sender.clone(),
354                        cancel_token.child_token(),
355                    );
356                    processors.push(Box::new(inactivity_processor) as Box<dyn Processor>);
357                }
358                _ => {}
359            }
360
361            Ok(processors)
362        })
363    }
364}