active_call/media/
engine.rs

1use super::{
2    INTERNAL_SAMPLERATE,
3    asr_processor::AsrProcessor,
4    denoiser::NoiseReducer,
5    processor::Processor,
6    track::{
7        Track, TrackPacketSender,
8        tts::{SynthesisHandle, TtsTrack},
9    },
10    vad::{VADOption, VadProcessor, VadType},
11};
12use crate::{
13    CallOption, EouOption,
14    event::EventSender,
15    media::TrackId,
16    synthesis::{
17        AliyunTtsClient, DeepegramTtsClient, MsEdgeTtsClient, SynthesisClient, SynthesisOption,
18        SynthesisType, TencentCloudTtsBasicClient, TencentCloudTtsClient,
19    },
20    transcription::{
21        AliyunAsrClientBuilder, TencentCloudAsrClientBuilder, TranscriptionClient,
22        TranscriptionOption, TranscriptionType,
23    },
24};
25
26#[cfg(feature = "offline")]
27use crate::{
28    synthesis::SupertonicTtsClient,
29    transcription::SensevoiceAsrClientBuilder,
30};
31
32use anyhow::Result;
33use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
34use tokio::sync::mpsc;
35use tokio_util::sync::CancellationToken;
36use tracing::debug;
37
38pub type FnCreateVadProcessor = fn(
39    token: CancellationToken,
40    event_sender: EventSender,
41    option: VADOption,
42) -> Result<Box<dyn Processor>>;
43
44pub type FnCreateEouProcessor = fn(
45    token: CancellationToken,
46    event_sender: EventSender,
47    option: EouOption,
48) -> Result<Box<dyn Processor>>;
49
50pub type FnCreateAsrClient = Box<
51    dyn Fn(
52            TrackId,
53            CancellationToken,
54            TranscriptionOption,
55            EventSender,
56        ) -> Pin<Box<dyn Future<Output = Result<Box<dyn TranscriptionClient>>> + Send>>
57        + Send
58        + Sync,
59>;
60pub type FnCreateTtsClient =
61    fn(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>>;
62
63// Define hook types
64pub type CreateProcessorsHook = Box<
65    dyn Fn(
66            Arc<StreamEngine>,
67            &dyn Track,
68            CancellationToken,
69            EventSender,
70            TrackPacketSender,
71            CallOption,
72        ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>>
73        + Send
74        + Sync,
75>;
76
77pub struct StreamEngine {
78    vad_creators: HashMap<VadType, FnCreateVadProcessor>,
79    eou_creators: HashMap<String, FnCreateEouProcessor>,
80    asr_creators: HashMap<TranscriptionType, FnCreateAsrClient>,
81    tts_creators: HashMap<SynthesisType, FnCreateTtsClient>,
82    create_processors_hook: Arc<CreateProcessorsHook>,
83}
84
85impl Default for StreamEngine {
86    fn default() -> Self {
87        let mut engine = Self::new();
88        engine.register_vad(VadType::Silero, VadProcessor::create);
89        engine.register_vad(VadType::Other("nop".to_string()), VadProcessor::create_nop);
90
91        engine.register_asr(
92            TranscriptionType::TencentCloud,
93            Box::new(TencentCloudAsrClientBuilder::create),
94        );
95        engine.register_asr(
96            TranscriptionType::Aliyun,
97            Box::new(AliyunAsrClientBuilder::create),
98        );
99        
100        #[cfg(feature = "offline")]
101        engine.register_asr(
102            TranscriptionType::Sensevoice,
103            Box::new(SensevoiceAsrClientBuilder::create),
104        );
105        
106        engine.register_tts(SynthesisType::Aliyun, AliyunTtsClient::create);
107        engine.register_tts(SynthesisType::TencentCloud, TencentCloudTtsClient::create);
108        engine.register_tts(
109            SynthesisType::Other("tencent_basic".to_string()),
110            TencentCloudTtsBasicClient::create,
111        );
112        engine.register_tts(SynthesisType::Deepgram, DeepegramTtsClient::create);
113        engine.register_tts(SynthesisType::MsEdge, MsEdgeTtsClient::create);
114        
115        #[cfg(feature = "offline")]
116        engine.register_tts(SynthesisType::Supertonic, SupertonicTtsClient::create);
117        
118        engine
119    }
120}
121
122impl StreamEngine {
123    pub fn new() -> Self {
124        Self {
125            vad_creators: HashMap::new(),
126            asr_creators: HashMap::new(),
127            tts_creators: HashMap::new(),
128            eou_creators: HashMap::new(),
129            create_processors_hook: Arc::new(Box::new(Self::default_create_procesors_hook)),
130        }
131    }
132
133    pub fn register_vad(&mut self, vad_type: VadType, creator: FnCreateVadProcessor) -> &mut Self {
134        self.vad_creators.insert(vad_type, creator);
135        self
136    }
137
138    pub fn register_eou(&mut self, name: String, creator: FnCreateEouProcessor) -> &mut Self {
139        self.eou_creators.insert(name, creator);
140        self
141    }
142
143    pub fn register_asr(
144        &mut self,
145        asr_type: TranscriptionType,
146        creator: FnCreateAsrClient,
147    ) -> &mut Self {
148        self.asr_creators.insert(asr_type, creator);
149        self
150    }
151
152    pub fn register_tts(
153        &mut self,
154        tts_type: SynthesisType,
155        creator: FnCreateTtsClient,
156    ) -> &mut Self {
157        self.tts_creators.insert(tts_type, creator);
158        self
159    }
160
161    pub fn create_vad_processor(
162        &self,
163        token: CancellationToken,
164        event_sender: EventSender,
165        option: VADOption,
166    ) -> Result<Box<dyn Processor>> {
167        let creator = self.vad_creators.get(&option.r#type);
168        if let Some(creator) = creator {
169            creator(token, event_sender, option)
170        } else {
171            Err(anyhow::anyhow!("VAD type not found: {}", option.r#type))
172        }
173    }
174    pub fn create_eou_processor(
175        &self,
176        token: CancellationToken,
177        event_sender: EventSender,
178        option: EouOption,
179    ) -> Result<Box<dyn Processor>> {
180        let creator = self
181            .eou_creators
182            .get(&option.r#type.clone().unwrap_or_default());
183        if let Some(creator) = creator {
184            creator(token, event_sender, option)
185        } else {
186            Err(anyhow::anyhow!("EOU type not found: {:?}", option.r#type))
187        }
188    }
189
190    pub async fn create_asr_processor(
191        &self,
192        track_id: TrackId,
193        cancel_token: CancellationToken,
194        option: TranscriptionOption,
195        event_sender: EventSender,
196    ) -> Result<Box<dyn Processor>> {
197        let asr_client = match option.provider {
198            Some(ref provider) => {
199                let creator = self.asr_creators.get(&provider);
200                if let Some(creator) = creator {
201                    creator(track_id, cancel_token, option, event_sender).await?
202                } else {
203                    return Err(anyhow::anyhow!("ASR type not found: {}", provider));
204                }
205            }
206            None => return Err(anyhow::anyhow!("ASR type not found: {:?}", option.provider)),
207        };
208        Ok(Box::new(AsrProcessor { asr_client }))
209    }
210
211    pub async fn create_tts_client(
212        &self,
213        streaming: bool,
214        tts_option: &SynthesisOption,
215    ) -> Result<Box<dyn SynthesisClient>> {
216        match tts_option.provider {
217            Some(ref provider) => {
218                let creator = self.tts_creators.get(&provider);
219                if let Some(creator) = creator {
220                    creator(streaming, tts_option)
221                } else {
222                    Err(anyhow::anyhow!("TTS type not found: {}", provider))
223                }
224            }
225            None => Err(anyhow::anyhow!(
226                "TTS type not found: {:?}",
227                tts_option.provider
228            )),
229        }
230    }
231
232    pub async fn create_processors(
233        engine: Arc<StreamEngine>,
234        track: &dyn Track,
235        cancel_token: CancellationToken,
236        event_sender: EventSender,
237        packet_sender: TrackPacketSender,
238        option: &CallOption,
239    ) -> Result<Vec<Box<dyn Processor>>> {
240        (engine.clone().create_processors_hook)(
241            engine,
242            track,
243            cancel_token,
244            event_sender,
245            packet_sender,
246            option.clone(),
247        )
248        .await
249    }
250
251    pub async fn create_tts_track(
252        engine: Arc<StreamEngine>,
253        cancel_token: CancellationToken,
254        session_id: String,
255        track_id: TrackId,
256        ssrc: u32,
257        play_id: Option<String>,
258        streaming: bool,
259        tts_option: &SynthesisOption,
260    ) -> Result<(SynthesisHandle, Box<dyn Track>)> {
261        let (tx, rx) = mpsc::unbounded_channel();
262        let new_handle = SynthesisHandle::new(tx, play_id.clone());
263        let tts_client = engine.create_tts_client(streaming, tts_option).await?;
264        let tts_track = TtsTrack::new(track_id, session_id, streaming, play_id, rx, tts_client)
265            .with_ssrc(ssrc)
266            .with_cancel_token(cancel_token);
267        Ok((new_handle, Box::new(tts_track) as Box<dyn Track>))
268    }
269
270    pub fn with_processor_hook(&mut self, hook_fn: CreateProcessorsHook) -> &mut Self {
271        self.create_processors_hook = Arc::new(Box::new(hook_fn));
272        self
273    }
274
275    fn default_create_procesors_hook(
276        engine: Arc<StreamEngine>,
277        track: &dyn Track,
278        cancel_token: CancellationToken,
279        event_sender: EventSender,
280        packet_sender: TrackPacketSender,
281        option: CallOption,
282    ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>> {
283        let track_id = track.id().clone();
284        Box::pin(async move {
285            let mut processors = vec![];
286            debug!(track_id = %track_id, "Creating processors for track");
287
288            if let Some(realtime_option) = option.realtime {
289                debug!(track_id = %track_id, "Adding RealtimeProcessor");
290                let realtime_processor = crate::media::realtime_processor::RealtimeProcessor::new(
291                    track_id.clone(),
292                    cancel_token.child_token(),
293                    event_sender.clone(),
294                    packet_sender.clone(),
295                    realtime_option,
296                )?;
297                processors.push(Box::new(realtime_processor) as Box<dyn Processor>);
298                // In realtime mode, we usually don't need separate VAD or ASR processors
299                // as they are handled by the realtime service (OpenAI/Azure)
300                return Ok(processors);
301            }
302
303            match option.denoise {
304                Some(true) => {
305                    debug!(track_id = %track_id, "Adding NoiseReducer processor");
306                    let noise_reducer = NoiseReducer::new(INTERNAL_SAMPLERATE as usize);
307                    processors.push(Box::new(noise_reducer) as Box<dyn Processor>);
308                }
309                _ => {}
310            }
311            match option.vad {
312                Some(mut option) => {
313                    debug!(track_id = %track_id, "Adding VadProcessor processor type={:?}", option.r#type);
314                    option.samplerate = INTERNAL_SAMPLERATE;
315                    let vad_processor: Box<dyn Processor + 'static> = engine.create_vad_processor(
316                        cancel_token.child_token(),
317                        event_sender.clone(),
318                        option.to_owned(),
319                    )?;
320                    processors.push(vad_processor);
321                }
322                None => {}
323            }
324            match option.asr {
325                Some(mut option) => {
326                    debug!(track_id = %track_id, "Adding AsrProcessor processor provider={:?}", option.provider);
327                    option.samplerate = Some(INTERNAL_SAMPLERATE);
328                    let asr_processor = engine
329                        .create_asr_processor(
330                            track_id.clone(),
331                            cancel_token.child_token(),
332                            option.to_owned(),
333                            event_sender.clone(),
334                        )
335                        .await?;
336                    processors.push(asr_processor);
337                }
338                None => {}
339            }
340            match option.eou {
341                Some(ref option) => {
342                    let eou_processor = engine.create_eou_processor(
343                        cancel_token.child_token(),
344                        event_sender.clone(),
345                        option.to_owned(),
346                    )?;
347                    processors.push(eou_processor);
348                }
349                None => {}
350            }
351            match option.inactivity_timeout {
352                Some(timeout_secs) if timeout_secs > 0 => {
353                    let inactivity_processor = crate::media::inactivity::InactivityProcessor::new(
354                        track_id.clone(),
355                        std::time::Duration::from_secs(timeout_secs),
356                        event_sender.clone(),
357                        cancel_token.child_token(),
358                    );
359                    processors.push(Box::new(inactivity_processor) as Box<dyn Processor>);
360                }
361                _ => {}
362            }
363
364            Ok(processors)
365        })
366    }
367}