active_call/media/
engine.rs

1use super::{
2    INTERNAL_SAMPLERATE,
3    asr_processor::AsrProcessor,
4    denoiser::NoiseReducer,
5    processor::Processor,
6    track::{
7        Track, TrackPacketSender,
8        tts::{SynthesisHandle, TtsTrack},
9    },
10    vad::{VADOption, VadProcessor, VadType},
11};
12use crate::{
13    CallOption, EouOption,
14    event::EventSender,
15    media::TrackId,
16    synthesis::{
17        AliyunTtsClient, DeepegramTtsClient, MsEdgeTtsClient, SynthesisClient, SynthesisOption,
18        SynthesisType, TencentCloudTtsBasicClient, TencentCloudTtsClient,
19    },
20    transcription::{
21        AliyunAsrClientBuilder, TencentCloudAsrClientBuilder, TranscriptionClient,
22        TranscriptionOption, TranscriptionType,
23    },
24};
25
26#[cfg(feature = "offline")]
27use crate::{synthesis::SupertonicTtsClient, transcription::SensevoiceAsrClientBuilder};
28
29use anyhow::Result;
30use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
31use tokio::sync::mpsc;
32use tokio_util::sync::CancellationToken;
33use tracing::debug;
34
35pub type FnCreateVadProcessor = fn(
36    token: CancellationToken,
37    event_sender: EventSender,
38    option: VADOption,
39) -> Result<Box<dyn Processor>>;
40
41pub type FnCreateEouProcessor = fn(
42    token: CancellationToken,
43    event_sender: EventSender,
44    option: EouOption,
45) -> Result<Box<dyn Processor>>;
46
47pub type FnCreateAsrClient = Box<
48    dyn Fn(
49            TrackId,
50            CancellationToken,
51            TranscriptionOption,
52            EventSender,
53        ) -> Pin<Box<dyn Future<Output = Result<Box<dyn TranscriptionClient>>> + Send>>
54        + Send
55        + Sync,
56>;
57pub type FnCreateTtsClient =
58    fn(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>>;
59
60// Define hook types
61pub type CreateProcessorsHook = Box<
62    dyn Fn(
63            Arc<StreamEngine>,
64            &dyn Track,
65            CancellationToken,
66            EventSender,
67            TrackPacketSender,
68            CallOption,
69        ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>>
70        + Send
71        + Sync,
72>;
73
74pub struct StreamEngine {
75    vad_creators: HashMap<VadType, FnCreateVadProcessor>,
76    eou_creators: HashMap<String, FnCreateEouProcessor>,
77    asr_creators: HashMap<TranscriptionType, FnCreateAsrClient>,
78    tts_creators: HashMap<SynthesisType, FnCreateTtsClient>,
79    create_processors_hook: Arc<CreateProcessorsHook>,
80}
81
82impl Default for StreamEngine {
83    fn default() -> Self {
84        let mut engine = Self::new();
85        engine.register_vad(VadType::Silero, VadProcessor::create);
86        engine.register_vad(VadType::Other("nop".to_string()), VadProcessor::create_nop);
87
88        engine.register_asr(
89            TranscriptionType::TencentCloud,
90            Box::new(TencentCloudAsrClientBuilder::create),
91        );
92        engine.register_asr(
93            TranscriptionType::Aliyun,
94            Box::new(AliyunAsrClientBuilder::create),
95        );
96
97        #[cfg(feature = "offline")]
98        engine.register_asr(
99            TranscriptionType::Sensevoice,
100            Box::new(SensevoiceAsrClientBuilder::create),
101        );
102
103        engine.register_tts(SynthesisType::Aliyun, AliyunTtsClient::create);
104        engine.register_tts(SynthesisType::TencentCloud, TencentCloudTtsClient::create);
105        engine.register_tts(
106            SynthesisType::Other("tencent_basic".to_string()),
107            TencentCloudTtsBasicClient::create,
108        );
109        engine.register_tts(SynthesisType::Deepgram, DeepegramTtsClient::create);
110        engine.register_tts(SynthesisType::MsEdge, MsEdgeTtsClient::create);
111
112        #[cfg(feature = "offline")]
113        engine.register_tts(SynthesisType::Supertonic, SupertonicTtsClient::create);
114
115        engine
116    }
117}
118
119impl StreamEngine {
120    pub fn new() -> Self {
121        Self {
122            vad_creators: HashMap::new(),
123            asr_creators: HashMap::new(),
124            tts_creators: HashMap::new(),
125            eou_creators: HashMap::new(),
126            create_processors_hook: Arc::new(Box::new(Self::default_create_procesors_hook)),
127        }
128    }
129
130    pub fn register_vad(&mut self, vad_type: VadType, creator: FnCreateVadProcessor) -> &mut Self {
131        self.vad_creators.insert(vad_type, creator);
132        self
133    }
134
135    pub fn register_eou(&mut self, name: String, creator: FnCreateEouProcessor) -> &mut Self {
136        self.eou_creators.insert(name, creator);
137        self
138    }
139
140    pub fn register_asr(
141        &mut self,
142        asr_type: TranscriptionType,
143        creator: FnCreateAsrClient,
144    ) -> &mut Self {
145        self.asr_creators.insert(asr_type, creator);
146        self
147    }
148
149    pub fn register_tts(
150        &mut self,
151        tts_type: SynthesisType,
152        creator: FnCreateTtsClient,
153    ) -> &mut Self {
154        self.tts_creators.insert(tts_type, creator);
155        self
156    }
157
158    pub fn create_vad_processor(
159        &self,
160        token: CancellationToken,
161        event_sender: EventSender,
162        option: VADOption,
163    ) -> Result<Box<dyn Processor>> {
164        let creator = self.vad_creators.get(&option.r#type);
165        if let Some(creator) = creator {
166            creator(token, event_sender, option)
167        } else {
168            Err(anyhow::anyhow!("VAD type not found: {}", option.r#type))
169        }
170    }
171    pub fn create_eou_processor(
172        &self,
173        token: CancellationToken,
174        event_sender: EventSender,
175        option: EouOption,
176    ) -> Result<Box<dyn Processor>> {
177        let creator = self
178            .eou_creators
179            .get(&option.r#type.clone().unwrap_or_default());
180        if let Some(creator) = creator {
181            creator(token, event_sender, option)
182        } else {
183            Err(anyhow::anyhow!("EOU type not found: {:?}", option.r#type))
184        }
185    }
186
187    pub async fn create_asr_processor(
188        &self,
189        track_id: TrackId,
190        cancel_token: CancellationToken,
191        option: TranscriptionOption,
192        event_sender: EventSender,
193    ) -> Result<Box<dyn Processor>> {
194        let asr_client = match option.provider {
195            Some(ref provider) => {
196                let creator = self.asr_creators.get(&provider);
197                if let Some(creator) = creator {
198                    creator(track_id, cancel_token, option, event_sender).await?
199                } else {
200                    return Err(anyhow::anyhow!("ASR type not found: {}", provider));
201                }
202            }
203            None => return Err(anyhow::anyhow!("ASR type not found: {:?}", option.provider)),
204        };
205        Ok(Box::new(AsrProcessor { asr_client }))
206    }
207
208    pub async fn create_tts_client(
209        &self,
210        streaming: bool,
211        tts_option: &SynthesisOption,
212    ) -> Result<Box<dyn SynthesisClient>> {
213        match tts_option.provider {
214            Some(ref provider) => {
215                let creator = self.tts_creators.get(&provider);
216                if let Some(creator) = creator {
217                    creator(streaming, tts_option)
218                } else {
219                    Err(anyhow::anyhow!("TTS type not found: {}", provider))
220                }
221            }
222            None => Err(anyhow::anyhow!(
223                "TTS type not found: {:?}",
224                tts_option.provider
225            )),
226        }
227    }
228
229    pub async fn create_processors(
230        engine: Arc<StreamEngine>,
231        track: &dyn Track,
232        cancel_token: CancellationToken,
233        event_sender: EventSender,
234        packet_sender: TrackPacketSender,
235        option: &CallOption,
236    ) -> Result<Vec<Box<dyn Processor>>> {
237        (engine.clone().create_processors_hook)(
238            engine,
239            track,
240            cancel_token,
241            event_sender,
242            packet_sender,
243            option.clone(),
244        )
245        .await
246    }
247
248    pub async fn create_tts_track(
249        engine: Arc<StreamEngine>,
250        cancel_token: CancellationToken,
251        session_id: String,
252        track_id: TrackId,
253        ssrc: u32,
254        play_id: Option<String>,
255        streaming: bool,
256        tts_option: &SynthesisOption,
257    ) -> Result<(SynthesisHandle, Box<dyn Track>)> {
258        let (tx, rx) = mpsc::unbounded_channel();
259        let new_handle = SynthesisHandle::new(tx, play_id.clone(), ssrc);
260        let tts_client = engine.create_tts_client(streaming, tts_option).await?;
261        let tts_track = TtsTrack::new(track_id, session_id, streaming, play_id, rx, tts_client)
262            .with_ssrc(ssrc)
263            .with_cancel_token(cancel_token);
264        Ok((new_handle, Box::new(tts_track) as Box<dyn Track>))
265    }
266
267    pub fn with_processor_hook(&mut self, hook_fn: CreateProcessorsHook) -> &mut Self {
268        self.create_processors_hook = Arc::new(Box::new(hook_fn));
269        self
270    }
271
272    fn default_create_procesors_hook(
273        engine: Arc<StreamEngine>,
274        track: &dyn Track,
275        cancel_token: CancellationToken,
276        event_sender: EventSender,
277        packet_sender: TrackPacketSender,
278        option: CallOption,
279    ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>> {
280        let track_id = track.id().clone();
281        Box::pin(async move {
282            let mut processors = vec![];
283            debug!(track_id = %track_id, "Creating processors for track");
284
285            if let Some(realtime_option) = option.realtime {
286                debug!(track_id = %track_id, "Adding RealtimeProcessor");
287                let realtime_processor = crate::media::realtime_processor::RealtimeProcessor::new(
288                    track_id.clone(),
289                    cancel_token.child_token(),
290                    event_sender.clone(),
291                    packet_sender.clone(),
292                    realtime_option,
293                )?;
294                processors.push(Box::new(realtime_processor) as Box<dyn Processor>);
295                // In realtime mode, we usually don't need separate VAD or ASR processors
296                // as they are handled by the realtime service (OpenAI/Azure)
297                return Ok(processors);
298            }
299
300            match option.denoise {
301                Some(true) => {
302                    debug!(track_id = %track_id, "Adding NoiseReducer processor");
303                    let noise_reducer = NoiseReducer::new(INTERNAL_SAMPLERATE as usize);
304                    processors.push(Box::new(noise_reducer) as Box<dyn Processor>);
305                }
306                _ => {}
307            }
308            match option.vad {
309                Some(mut option) => {
310                    debug!(track_id = %track_id, "Adding VadProcessor processor type={:?}", option.r#type);
311                    option.samplerate = INTERNAL_SAMPLERATE;
312                    let vad_processor: Box<dyn Processor + 'static> = engine.create_vad_processor(
313                        cancel_token.child_token(),
314                        event_sender.clone(),
315                        option.to_owned(),
316                    )?;
317                    processors.push(vad_processor);
318                }
319                None => {}
320            }
321            match option.asr {
322                Some(mut option) => {
323                    debug!(track_id = %track_id, "Adding AsrProcessor processor provider={:?}", option.provider);
324                    option.samplerate = Some(INTERNAL_SAMPLERATE);
325                    let asr_processor = engine
326                        .create_asr_processor(
327                            track_id.clone(),
328                            cancel_token.child_token(),
329                            option.to_owned(),
330                            event_sender.clone(),
331                        )
332                        .await?;
333                    processors.push(asr_processor);
334                }
335                None => {}
336            }
337            match option.eou {
338                Some(ref option) => {
339                    let eou_processor = engine.create_eou_processor(
340                        cancel_token.child_token(),
341                        event_sender.clone(),
342                        option.to_owned(),
343                    )?;
344                    processors.push(eou_processor);
345                }
346                None => {}
347            }
348            match option.inactivity_timeout {
349                Some(timeout_secs) if timeout_secs > 0 => {
350                    let inactivity_processor = crate::media::inactivity::InactivityProcessor::new(
351                        track_id.clone(),
352                        std::time::Duration::from_secs(timeout_secs),
353                        event_sender.clone(),
354                        cancel_token.child_token(),
355                    );
356                    processors.push(Box::new(inactivity_processor) as Box<dyn Processor>);
357                }
358                _ => {}
359            }
360
361            Ok(processors)
362        })
363    }
364}