1use super::{
2 INTERNAL_SAMPLERATE,
3 asr_processor::AsrProcessor,
4 denoiser::NoiseReducer,
5 processor::Processor,
6 track::{
7 Track, TrackPacketSender,
8 tts::{SynthesisHandle, TtsTrack},
9 },
10 vad::{VADOption, VadProcessor, VadType},
11};
12use crate::{
13 CallOption, EouOption,
14 event::EventSender,
15 media::TrackId,
16 synthesis::{
17 AliyunTtsClient, DeepegramTtsClient, MsEdgeTtsClient, SynthesisClient, SynthesisOption,
18 SynthesisType, TencentCloudTtsBasicClient, TencentCloudTtsClient,
19 },
20 transcription::{
21 AliyunAsrClientBuilder, TencentCloudAsrClientBuilder, TranscriptionClient,
22 TranscriptionOption, TranscriptionType,
23 },
24};
25
26#[cfg(feature = "offline")]
27use crate::{
28 synthesis::SupertonicTtsClient,
29 transcription::SensevoiceAsrClientBuilder,
30};
31
32use anyhow::Result;
33use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
34use tokio::sync::mpsc;
35use tokio_util::sync::CancellationToken;
36use tracing::debug;
37
38pub type FnCreateVadProcessor = fn(
39 token: CancellationToken,
40 event_sender: EventSender,
41 option: VADOption,
42) -> Result<Box<dyn Processor>>;
43
44pub type FnCreateEouProcessor = fn(
45 token: CancellationToken,
46 event_sender: EventSender,
47 option: EouOption,
48) -> Result<Box<dyn Processor>>;
49
50pub type FnCreateAsrClient = Box<
51 dyn Fn(
52 TrackId,
53 CancellationToken,
54 TranscriptionOption,
55 EventSender,
56 ) -> Pin<Box<dyn Future<Output = Result<Box<dyn TranscriptionClient>>> + Send>>
57 + Send
58 + Sync,
59>;
60pub type FnCreateTtsClient =
61 fn(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>>;
62
63pub type CreateProcessorsHook = Box<
65 dyn Fn(
66 Arc<StreamEngine>,
67 &dyn Track,
68 CancellationToken,
69 EventSender,
70 TrackPacketSender,
71 CallOption,
72 ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>>
73 + Send
74 + Sync,
75>;
76
77pub struct StreamEngine {
78 vad_creators: HashMap<VadType, FnCreateVadProcessor>,
79 eou_creators: HashMap<String, FnCreateEouProcessor>,
80 asr_creators: HashMap<TranscriptionType, FnCreateAsrClient>,
81 tts_creators: HashMap<SynthesisType, FnCreateTtsClient>,
82 create_processors_hook: Arc<CreateProcessorsHook>,
83}
84
85impl Default for StreamEngine {
86 fn default() -> Self {
87 let mut engine = Self::new();
88 engine.register_vad(VadType::Silero, VadProcessor::create);
89 engine.register_vad(VadType::Other("nop".to_string()), VadProcessor::create_nop);
90
91 engine.register_asr(
92 TranscriptionType::TencentCloud,
93 Box::new(TencentCloudAsrClientBuilder::create),
94 );
95 engine.register_asr(
96 TranscriptionType::Aliyun,
97 Box::new(AliyunAsrClientBuilder::create),
98 );
99
100 #[cfg(feature = "offline")]
101 engine.register_asr(
102 TranscriptionType::Sensevoice,
103 Box::new(SensevoiceAsrClientBuilder::create),
104 );
105
106 engine.register_tts(SynthesisType::Aliyun, AliyunTtsClient::create);
107 engine.register_tts(SynthesisType::TencentCloud, TencentCloudTtsClient::create);
108 engine.register_tts(
109 SynthesisType::Other("tencent_basic".to_string()),
110 TencentCloudTtsBasicClient::create,
111 );
112 engine.register_tts(SynthesisType::Deepgram, DeepegramTtsClient::create);
113 engine.register_tts(SynthesisType::MsEdge, MsEdgeTtsClient::create);
114
115 #[cfg(feature = "offline")]
116 engine.register_tts(SynthesisType::Supertonic, SupertonicTtsClient::create);
117
118 engine
119 }
120}
121
122impl StreamEngine {
123 pub fn new() -> Self {
124 Self {
125 vad_creators: HashMap::new(),
126 asr_creators: HashMap::new(),
127 tts_creators: HashMap::new(),
128 eou_creators: HashMap::new(),
129 create_processors_hook: Arc::new(Box::new(Self::default_create_procesors_hook)),
130 }
131 }
132
133 pub fn register_vad(&mut self, vad_type: VadType, creator: FnCreateVadProcessor) -> &mut Self {
134 self.vad_creators.insert(vad_type, creator);
135 self
136 }
137
138 pub fn register_eou(&mut self, name: String, creator: FnCreateEouProcessor) -> &mut Self {
139 self.eou_creators.insert(name, creator);
140 self
141 }
142
143 pub fn register_asr(
144 &mut self,
145 asr_type: TranscriptionType,
146 creator: FnCreateAsrClient,
147 ) -> &mut Self {
148 self.asr_creators.insert(asr_type, creator);
149 self
150 }
151
152 pub fn register_tts(
153 &mut self,
154 tts_type: SynthesisType,
155 creator: FnCreateTtsClient,
156 ) -> &mut Self {
157 self.tts_creators.insert(tts_type, creator);
158 self
159 }
160
161 pub fn create_vad_processor(
162 &self,
163 token: CancellationToken,
164 event_sender: EventSender,
165 option: VADOption,
166 ) -> Result<Box<dyn Processor>> {
167 let creator = self.vad_creators.get(&option.r#type);
168 if let Some(creator) = creator {
169 creator(token, event_sender, option)
170 } else {
171 Err(anyhow::anyhow!("VAD type not found: {}", option.r#type))
172 }
173 }
174 pub fn create_eou_processor(
175 &self,
176 token: CancellationToken,
177 event_sender: EventSender,
178 option: EouOption,
179 ) -> Result<Box<dyn Processor>> {
180 let creator = self
181 .eou_creators
182 .get(&option.r#type.clone().unwrap_or_default());
183 if let Some(creator) = creator {
184 creator(token, event_sender, option)
185 } else {
186 Err(anyhow::anyhow!("EOU type not found: {:?}", option.r#type))
187 }
188 }
189
190 pub async fn create_asr_processor(
191 &self,
192 track_id: TrackId,
193 cancel_token: CancellationToken,
194 option: TranscriptionOption,
195 event_sender: EventSender,
196 ) -> Result<Box<dyn Processor>> {
197 let asr_client = match option.provider {
198 Some(ref provider) => {
199 let creator = self.asr_creators.get(&provider);
200 if let Some(creator) = creator {
201 creator(track_id, cancel_token, option, event_sender).await?
202 } else {
203 return Err(anyhow::anyhow!("ASR type not found: {}", provider));
204 }
205 }
206 None => return Err(anyhow::anyhow!("ASR type not found: {:?}", option.provider)),
207 };
208 Ok(Box::new(AsrProcessor { asr_client }))
209 }
210
211 pub async fn create_tts_client(
212 &self,
213 streaming: bool,
214 tts_option: &SynthesisOption,
215 ) -> Result<Box<dyn SynthesisClient>> {
216 match tts_option.provider {
217 Some(ref provider) => {
218 let creator = self.tts_creators.get(&provider);
219 if let Some(creator) = creator {
220 creator(streaming, tts_option)
221 } else {
222 Err(anyhow::anyhow!("TTS type not found: {}", provider))
223 }
224 }
225 None => Err(anyhow::anyhow!(
226 "TTS type not found: {:?}",
227 tts_option.provider
228 )),
229 }
230 }
231
232 pub async fn create_processors(
233 engine: Arc<StreamEngine>,
234 track: &dyn Track,
235 cancel_token: CancellationToken,
236 event_sender: EventSender,
237 packet_sender: TrackPacketSender,
238 option: &CallOption,
239 ) -> Result<Vec<Box<dyn Processor>>> {
240 (engine.clone().create_processors_hook)(
241 engine,
242 track,
243 cancel_token,
244 event_sender,
245 packet_sender,
246 option.clone(),
247 )
248 .await
249 }
250
251 pub async fn create_tts_track(
252 engine: Arc<StreamEngine>,
253 cancel_token: CancellationToken,
254 session_id: String,
255 track_id: TrackId,
256 ssrc: u32,
257 play_id: Option<String>,
258 streaming: bool,
259 tts_option: &SynthesisOption,
260 ) -> Result<(SynthesisHandle, Box<dyn Track>)> {
261 let (tx, rx) = mpsc::unbounded_channel();
262 let new_handle = SynthesisHandle::new(tx, play_id.clone());
263 let tts_client = engine.create_tts_client(streaming, tts_option).await?;
264 let tts_track = TtsTrack::new(track_id, session_id, streaming, play_id, rx, tts_client)
265 .with_ssrc(ssrc)
266 .with_cancel_token(cancel_token);
267 Ok((new_handle, Box::new(tts_track) as Box<dyn Track>))
268 }
269
270 pub fn with_processor_hook(&mut self, hook_fn: CreateProcessorsHook) -> &mut Self {
271 self.create_processors_hook = Arc::new(Box::new(hook_fn));
272 self
273 }
274
275 fn default_create_procesors_hook(
276 engine: Arc<StreamEngine>,
277 track: &dyn Track,
278 cancel_token: CancellationToken,
279 event_sender: EventSender,
280 packet_sender: TrackPacketSender,
281 option: CallOption,
282 ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>> {
283 let track_id = track.id().clone();
284 Box::pin(async move {
285 let mut processors = vec![];
286 debug!(track_id = %track_id, "Creating processors for track");
287
288 if let Some(realtime_option) = option.realtime {
289 debug!(track_id = %track_id, "Adding RealtimeProcessor");
290 let realtime_processor = crate::media::realtime_processor::RealtimeProcessor::new(
291 track_id.clone(),
292 cancel_token.child_token(),
293 event_sender.clone(),
294 packet_sender.clone(),
295 realtime_option,
296 )?;
297 processors.push(Box::new(realtime_processor) as Box<dyn Processor>);
298 return Ok(processors);
301 }
302
303 match option.denoise {
304 Some(true) => {
305 debug!(track_id = %track_id, "Adding NoiseReducer processor");
306 let noise_reducer = NoiseReducer::new(INTERNAL_SAMPLERATE as usize);
307 processors.push(Box::new(noise_reducer) as Box<dyn Processor>);
308 }
309 _ => {}
310 }
311 match option.vad {
312 Some(mut option) => {
313 debug!(track_id = %track_id, "Adding VadProcessor processor type={:?}", option.r#type);
314 option.samplerate = INTERNAL_SAMPLERATE;
315 let vad_processor: Box<dyn Processor + 'static> = engine.create_vad_processor(
316 cancel_token.child_token(),
317 event_sender.clone(),
318 option.to_owned(),
319 )?;
320 processors.push(vad_processor);
321 }
322 None => {}
323 }
324 match option.asr {
325 Some(mut option) => {
326 debug!(track_id = %track_id, "Adding AsrProcessor processor provider={:?}", option.provider);
327 option.samplerate = Some(INTERNAL_SAMPLERATE);
328 let asr_processor = engine
329 .create_asr_processor(
330 track_id.clone(),
331 cancel_token.child_token(),
332 option.to_owned(),
333 event_sender.clone(),
334 )
335 .await?;
336 processors.push(asr_processor);
337 }
338 None => {}
339 }
340 match option.eou {
341 Some(ref option) => {
342 let eou_processor = engine.create_eou_processor(
343 cancel_token.child_token(),
344 event_sender.clone(),
345 option.to_owned(),
346 )?;
347 processors.push(eou_processor);
348 }
349 None => {}
350 }
351 match option.inactivity_timeout {
352 Some(timeout_secs) if timeout_secs > 0 => {
353 let inactivity_processor = crate::media::inactivity::InactivityProcessor::new(
354 track_id.clone(),
355 std::time::Duration::from_secs(timeout_secs),
356 event_sender.clone(),
357 cancel_token.child_token(),
358 );
359 processors.push(Box::new(inactivity_processor) as Box<dyn Processor>);
360 }
361 _ => {}
362 }
363
364 Ok(processors)
365 })
366 }
367}