1use super::{
2 INTERNAL_SAMPLERATE,
3 asr_processor::AsrProcessor,
4 denoiser::NoiseReducer,
5 processor::Processor,
6 track::{
7 Track, TrackPacketSender,
8 tts::{SynthesisHandle, TtsTrack},
9 },
10 vad::{VADOption, VadProcessor, VadType},
11};
12use crate::{
13 CallOption, EouOption,
14 event::EventSender,
15 media::TrackId,
16 synthesis::{
17 AliyunTtsClient, DeepegramTtsClient, SynthesisClient, SynthesisOption, SynthesisType,
18 TencentCloudTtsBasicClient, TencentCloudTtsClient,
19 },
20 transcription::{
21 AliyunAsrClientBuilder, TencentCloudAsrClientBuilder, TranscriptionClient,
22 TranscriptionOption, TranscriptionType,
23 },
24};
25
26#[cfg(feature = "offline")]
27use crate::{
28 synthesis::SupertonicTtsClient,
29 transcription::SensevoiceAsrClientBuilder,
30};
31
32use anyhow::Result;
33use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
34use tokio::sync::mpsc;
35use tokio_util::sync::CancellationToken;
36use tracing::debug;
37
38pub type FnCreateVadProcessor = fn(
39 token: CancellationToken,
40 event_sender: EventSender,
41 option: VADOption,
42) -> Result<Box<dyn Processor>>;
43
44pub type FnCreateEouProcessor = fn(
45 token: CancellationToken,
46 event_sender: EventSender,
47 option: EouOption,
48) -> Result<Box<dyn Processor>>;
49
50pub type FnCreateAsrClient = Box<
51 dyn Fn(
52 TrackId,
53 CancellationToken,
54 TranscriptionOption,
55 EventSender,
56 ) -> Pin<Box<dyn Future<Output = Result<Box<dyn TranscriptionClient>>> + Send>>
57 + Send
58 + Sync,
59>;
60pub type FnCreateTtsClient =
61 fn(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>>;
62
63pub type CreateProcessorsHook = Box<
65 dyn Fn(
66 Arc<StreamEngine>,
67 &dyn Track,
68 CancellationToken,
69 EventSender,
70 TrackPacketSender,
71 CallOption,
72 ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>>
73 + Send
74 + Sync,
75>;
76
77pub struct StreamEngine {
78 vad_creators: HashMap<VadType, FnCreateVadProcessor>,
79 eou_creators: HashMap<String, FnCreateEouProcessor>,
80 asr_creators: HashMap<TranscriptionType, FnCreateAsrClient>,
81 tts_creators: HashMap<SynthesisType, FnCreateTtsClient>,
82 create_processors_hook: Arc<CreateProcessorsHook>,
83}
84
85impl Default for StreamEngine {
86 fn default() -> Self {
87 let mut engine = Self::new();
88 engine.register_vad(VadType::Silero, VadProcessor::create);
89 engine.register_vad(VadType::Other("nop".to_string()), VadProcessor::create_nop);
90
91 engine.register_asr(
92 TranscriptionType::TencentCloud,
93 Box::new(TencentCloudAsrClientBuilder::create),
94 );
95 engine.register_asr(
96 TranscriptionType::Aliyun,
97 Box::new(AliyunAsrClientBuilder::create),
98 );
99
100 #[cfg(feature = "offline")]
101 engine.register_asr(
102 TranscriptionType::Sensevoice,
103 Box::new(SensevoiceAsrClientBuilder::create),
104 );
105
106 engine.register_tts(SynthesisType::Aliyun, AliyunTtsClient::create);
107 engine.register_tts(SynthesisType::TencentCloud, TencentCloudTtsClient::create);
108 engine.register_tts(
109 SynthesisType::Other("tencent_basic".to_string()),
110 TencentCloudTtsBasicClient::create,
111 );
112 engine.register_tts(SynthesisType::Deepgram, DeepegramTtsClient::create);
113
114 #[cfg(feature = "offline")]
115 engine.register_tts(SynthesisType::Supertonic, SupertonicTtsClient::create);
116
117 engine
118 }
119}
120
121impl StreamEngine {
122 pub fn new() -> Self {
123 Self {
124 vad_creators: HashMap::new(),
125 asr_creators: HashMap::new(),
126 tts_creators: HashMap::new(),
127 eou_creators: HashMap::new(),
128 create_processors_hook: Arc::new(Box::new(Self::default_create_procesors_hook)),
129 }
130 }
131
132 pub fn register_vad(&mut self, vad_type: VadType, creator: FnCreateVadProcessor) -> &mut Self {
133 self.vad_creators.insert(vad_type, creator);
134 self
135 }
136
137 pub fn register_eou(&mut self, name: String, creator: FnCreateEouProcessor) -> &mut Self {
138 self.eou_creators.insert(name, creator);
139 self
140 }
141
142 pub fn register_asr(
143 &mut self,
144 asr_type: TranscriptionType,
145 creator: FnCreateAsrClient,
146 ) -> &mut Self {
147 self.asr_creators.insert(asr_type, creator);
148 self
149 }
150
151 pub fn register_tts(
152 &mut self,
153 tts_type: SynthesisType,
154 creator: FnCreateTtsClient,
155 ) -> &mut Self {
156 self.tts_creators.insert(tts_type, creator);
157 self
158 }
159
160 pub fn create_vad_processor(
161 &self,
162 token: CancellationToken,
163 event_sender: EventSender,
164 option: VADOption,
165 ) -> Result<Box<dyn Processor>> {
166 let creator = self.vad_creators.get(&option.r#type);
167 if let Some(creator) = creator {
168 creator(token, event_sender, option)
169 } else {
170 Err(anyhow::anyhow!("VAD type not found: {}", option.r#type))
171 }
172 }
173 pub fn create_eou_processor(
174 &self,
175 token: CancellationToken,
176 event_sender: EventSender,
177 option: EouOption,
178 ) -> Result<Box<dyn Processor>> {
179 let creator = self
180 .eou_creators
181 .get(&option.r#type.clone().unwrap_or_default());
182 if let Some(creator) = creator {
183 creator(token, event_sender, option)
184 } else {
185 Err(anyhow::anyhow!("EOU type not found: {:?}", option.r#type))
186 }
187 }
188
189 pub async fn create_asr_processor(
190 &self,
191 track_id: TrackId,
192 cancel_token: CancellationToken,
193 option: TranscriptionOption,
194 event_sender: EventSender,
195 ) -> Result<Box<dyn Processor>> {
196 let asr_client = match option.provider {
197 Some(ref provider) => {
198 let creator = self.asr_creators.get(&provider);
199 if let Some(creator) = creator {
200 creator(track_id, cancel_token, option, event_sender).await?
201 } else {
202 return Err(anyhow::anyhow!("ASR type not found: {}", provider));
203 }
204 }
205 None => return Err(anyhow::anyhow!("ASR type not found: {:?}", option.provider)),
206 };
207 Ok(Box::new(AsrProcessor { asr_client }))
208 }
209
210 pub async fn create_tts_client(
211 &self,
212 streaming: bool,
213 tts_option: &SynthesisOption,
214 ) -> Result<Box<dyn SynthesisClient>> {
215 match tts_option.provider {
216 Some(ref provider) => {
217 let creator = self.tts_creators.get(&provider);
218 if let Some(creator) = creator {
219 creator(streaming, tts_option)
220 } else {
221 Err(anyhow::anyhow!("TTS type not found: {}", provider))
222 }
223 }
224 None => Err(anyhow::anyhow!(
225 "TTS type not found: {:?}",
226 tts_option.provider
227 )),
228 }
229 }
230
231 pub async fn create_processors(
232 engine: Arc<StreamEngine>,
233 track: &dyn Track,
234 cancel_token: CancellationToken,
235 event_sender: EventSender,
236 packet_sender: TrackPacketSender,
237 option: &CallOption,
238 ) -> Result<Vec<Box<dyn Processor>>> {
239 (engine.clone().create_processors_hook)(
240 engine,
241 track,
242 cancel_token,
243 event_sender,
244 packet_sender,
245 option.clone(),
246 )
247 .await
248 }
249
250 pub async fn create_tts_track(
251 engine: Arc<StreamEngine>,
252 cancel_token: CancellationToken,
253 session_id: String,
254 track_id: TrackId,
255 ssrc: u32,
256 play_id: Option<String>,
257 streaming: bool,
258 tts_option: &SynthesisOption,
259 ) -> Result<(SynthesisHandle, Box<dyn Track>)> {
260 let (tx, rx) = mpsc::unbounded_channel();
261 let new_handle = SynthesisHandle::new(tx, play_id.clone());
262 let tts_client = engine.create_tts_client(streaming, tts_option).await?;
263 let tts_track = TtsTrack::new(track_id, session_id, streaming, play_id, rx, tts_client)
264 .with_ssrc(ssrc)
265 .with_cancel_token(cancel_token);
266 Ok((new_handle, Box::new(tts_track) as Box<dyn Track>))
267 }
268
269 pub fn with_processor_hook(&mut self, hook_fn: CreateProcessorsHook) -> &mut Self {
270 self.create_processors_hook = Arc::new(Box::new(hook_fn));
271 self
272 }
273
274 fn default_create_procesors_hook(
275 engine: Arc<StreamEngine>,
276 track: &dyn Track,
277 cancel_token: CancellationToken,
278 event_sender: EventSender,
279 packet_sender: TrackPacketSender,
280 option: CallOption,
281 ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>> {
282 let track_id = track.id().clone();
283 Box::pin(async move {
284 let mut processors = vec![];
285 debug!(track_id = %track_id, "Creating processors for track");
286
287 if let Some(realtime_option) = option.realtime {
288 debug!(track_id = %track_id, "Adding RealtimeProcessor");
289 let realtime_processor = crate::media::realtime_processor::RealtimeProcessor::new(
290 track_id.clone(),
291 cancel_token.child_token(),
292 event_sender.clone(),
293 packet_sender.clone(),
294 realtime_option,
295 )?;
296 processors.push(Box::new(realtime_processor) as Box<dyn Processor>);
297 return Ok(processors);
300 }
301
302 match option.denoise {
303 Some(true) => {
304 debug!(track_id = %track_id, "Adding NoiseReducer processor");
305 let noise_reducer = NoiseReducer::new(INTERNAL_SAMPLERATE as usize);
306 processors.push(Box::new(noise_reducer) as Box<dyn Processor>);
307 }
308 _ => {}
309 }
310 match option.vad {
311 Some(mut option) => {
312 debug!(track_id = %track_id, "Adding VadProcessor processor type={:?}", option.r#type);
313 option.samplerate = INTERNAL_SAMPLERATE;
314 let vad_processor: Box<dyn Processor + 'static> = engine.create_vad_processor(
315 cancel_token.child_token(),
316 event_sender.clone(),
317 option.to_owned(),
318 )?;
319 processors.push(vad_processor);
320 }
321 None => {}
322 }
323 match option.asr {
324 Some(mut option) => {
325 debug!(track_id = %track_id, "Adding AsrProcessor processor provider={:?}", option.provider);
326 option.samplerate = Some(INTERNAL_SAMPLERATE);
327 let asr_processor = engine
328 .create_asr_processor(
329 track_id.clone(),
330 cancel_token.child_token(),
331 option.to_owned(),
332 event_sender.clone(),
333 )
334 .await?;
335 processors.push(asr_processor);
336 }
337 None => {}
338 }
339 match option.eou {
340 Some(ref option) => {
341 let eou_processor = engine.create_eou_processor(
342 cancel_token.child_token(),
343 event_sender.clone(),
344 option.to_owned(),
345 )?;
346 processors.push(eou_processor);
347 }
348 None => {}
349 }
350 match option.inactivity_timeout {
351 Some(timeout_secs) if timeout_secs > 0 => {
352 let inactivity_processor = crate::media::inactivity::InactivityProcessor::new(
353 track_id.clone(),
354 std::time::Duration::from_secs(timeout_secs),
355 event_sender.clone(),
356 cancel_token.child_token(),
357 );
358 processors.push(Box::new(inactivity_processor) as Box<dyn Processor>);
359 }
360 _ => {}
361 }
362
363 Ok(processors)
364 })
365 }
366}