1use super::{
2 INTERNAL_SAMPLERATE,
3 asr_processor::AsrProcessor,
4 denoiser::NoiseReducer,
5 processor::Processor,
6 track::{
7 Track, TrackPacketSender,
8 tts::{SynthesisHandle, TtsTrack},
9 },
10 vad::{VADOption, VadProcessor, VadType},
11};
12use crate::{
13 CallOption, EouOption,
14 event::EventSender,
15 media::TrackId,
16 synthesis::{
17 AliyunTtsClient, DeepegramTtsClient, MsEdgeTtsClient, SynthesisClient, SynthesisOption,
18 SynthesisType, TencentCloudTtsBasicClient, TencentCloudTtsClient,
19 },
20 transcription::{
21 AliyunAsrClientBuilder, TencentCloudAsrClientBuilder, TranscriptionClient,
22 TranscriptionOption, TranscriptionType,
23 },
24};
25
26#[cfg(feature = "offline")]
27use crate::{synthesis::SupertonicTtsClient, transcription::SensevoiceAsrClientBuilder};
28
29use anyhow::Result;
30use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
31use tokio::sync::mpsc;
32use tokio_util::sync::CancellationToken;
33use tracing::debug;
34
35pub type FnCreateVadProcessor = fn(
36 token: CancellationToken,
37 event_sender: EventSender,
38 option: VADOption,
39) -> Result<Box<dyn Processor>>;
40
41pub type FnCreateEouProcessor = fn(
42 token: CancellationToken,
43 event_sender: EventSender,
44 option: EouOption,
45) -> Result<Box<dyn Processor>>;
46
47pub type FnCreateAsrClient = Box<
48 dyn Fn(
49 TrackId,
50 CancellationToken,
51 TranscriptionOption,
52 EventSender,
53 ) -> Pin<Box<dyn Future<Output = Result<Box<dyn TranscriptionClient>>> + Send>>
54 + Send
55 + Sync,
56>;
57pub type FnCreateTtsClient =
58 fn(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>>;
59
60pub type CreateProcessorsHook = Box<
62 dyn Fn(
63 Arc<StreamEngine>,
64 &dyn Track,
65 CancellationToken,
66 EventSender,
67 TrackPacketSender,
68 CallOption,
69 ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>>
70 + Send
71 + Sync,
72>;
73
74pub struct StreamEngine {
75 vad_creators: HashMap<VadType, FnCreateVadProcessor>,
76 eou_creators: HashMap<String, FnCreateEouProcessor>,
77 asr_creators: HashMap<TranscriptionType, FnCreateAsrClient>,
78 tts_creators: HashMap<SynthesisType, FnCreateTtsClient>,
79 create_processors_hook: Arc<CreateProcessorsHook>,
80}
81
82impl Default for StreamEngine {
83 fn default() -> Self {
84 let mut engine = Self::new();
85 engine.register_vad(VadType::Silero, VadProcessor::create);
86 engine.register_vad(VadType::Other("nop".to_string()), VadProcessor::create_nop);
87
88 engine.register_asr(
89 TranscriptionType::TencentCloud,
90 Box::new(TencentCloudAsrClientBuilder::create),
91 );
92 engine.register_asr(
93 TranscriptionType::Aliyun,
94 Box::new(AliyunAsrClientBuilder::create),
95 );
96
97 #[cfg(feature = "offline")]
98 engine.register_asr(
99 TranscriptionType::Sensevoice,
100 Box::new(SensevoiceAsrClientBuilder::create),
101 );
102
103 engine.register_tts(SynthesisType::Aliyun, AliyunTtsClient::create);
104 engine.register_tts(SynthesisType::TencentCloud, TencentCloudTtsClient::create);
105 engine.register_tts(
106 SynthesisType::Other("tencent_basic".to_string()),
107 TencentCloudTtsBasicClient::create,
108 );
109 engine.register_tts(SynthesisType::Deepgram, DeepegramTtsClient::create);
110 engine.register_tts(SynthesisType::MsEdge, MsEdgeTtsClient::create);
111
112 #[cfg(feature = "offline")]
113 engine.register_tts(SynthesisType::Supertonic, SupertonicTtsClient::create);
114
115 engine
116 }
117}
118
119impl StreamEngine {
120 pub fn new() -> Self {
121 Self {
122 vad_creators: HashMap::new(),
123 asr_creators: HashMap::new(),
124 tts_creators: HashMap::new(),
125 eou_creators: HashMap::new(),
126 create_processors_hook: Arc::new(Box::new(Self::default_create_procesors_hook)),
127 }
128 }
129
130 pub fn register_vad(&mut self, vad_type: VadType, creator: FnCreateVadProcessor) -> &mut Self {
131 self.vad_creators.insert(vad_type, creator);
132 self
133 }
134
135 pub fn register_eou(&mut self, name: String, creator: FnCreateEouProcessor) -> &mut Self {
136 self.eou_creators.insert(name, creator);
137 self
138 }
139
140 pub fn register_asr(
141 &mut self,
142 asr_type: TranscriptionType,
143 creator: FnCreateAsrClient,
144 ) -> &mut Self {
145 self.asr_creators.insert(asr_type, creator);
146 self
147 }
148
149 pub fn register_tts(
150 &mut self,
151 tts_type: SynthesisType,
152 creator: FnCreateTtsClient,
153 ) -> &mut Self {
154 self.tts_creators.insert(tts_type, creator);
155 self
156 }
157
158 pub fn create_vad_processor(
159 &self,
160 token: CancellationToken,
161 event_sender: EventSender,
162 option: VADOption,
163 ) -> Result<Box<dyn Processor>> {
164 let creator = self.vad_creators.get(&option.r#type);
165 if let Some(creator) = creator {
166 creator(token, event_sender, option)
167 } else {
168 Err(anyhow::anyhow!("VAD type not found: {}", option.r#type))
169 }
170 }
171 pub fn create_eou_processor(
172 &self,
173 token: CancellationToken,
174 event_sender: EventSender,
175 option: EouOption,
176 ) -> Result<Box<dyn Processor>> {
177 let creator = self
178 .eou_creators
179 .get(&option.r#type.clone().unwrap_or_default());
180 if let Some(creator) = creator {
181 creator(token, event_sender, option)
182 } else {
183 Err(anyhow::anyhow!("EOU type not found: {:?}", option.r#type))
184 }
185 }
186
187 pub async fn create_asr_processor(
188 &self,
189 track_id: TrackId,
190 cancel_token: CancellationToken,
191 option: TranscriptionOption,
192 event_sender: EventSender,
193 ) -> Result<Box<dyn Processor>> {
194 let asr_client = match option.provider {
195 Some(ref provider) => {
196 let creator = self.asr_creators.get(&provider);
197 if let Some(creator) = creator {
198 creator(track_id, cancel_token, option, event_sender).await?
199 } else {
200 return Err(anyhow::anyhow!("ASR type not found: {}", provider));
201 }
202 }
203 None => return Err(anyhow::anyhow!("ASR type not found: {:?}", option.provider)),
204 };
205 Ok(Box::new(AsrProcessor { asr_client }))
206 }
207
208 pub async fn create_tts_client(
209 &self,
210 streaming: bool,
211 tts_option: &SynthesisOption,
212 ) -> Result<Box<dyn SynthesisClient>> {
213 match tts_option.provider {
214 Some(ref provider) => {
215 let creator = self.tts_creators.get(&provider);
216 if let Some(creator) = creator {
217 creator(streaming, tts_option)
218 } else {
219 Err(anyhow::anyhow!("TTS type not found: {}", provider))
220 }
221 }
222 None => Err(anyhow::anyhow!(
223 "TTS type not found: {:?}",
224 tts_option.provider
225 )),
226 }
227 }
228
229 pub async fn create_processors(
230 engine: Arc<StreamEngine>,
231 track: &dyn Track,
232 cancel_token: CancellationToken,
233 event_sender: EventSender,
234 packet_sender: TrackPacketSender,
235 option: &CallOption,
236 ) -> Result<Vec<Box<dyn Processor>>> {
237 (engine.clone().create_processors_hook)(
238 engine,
239 track,
240 cancel_token,
241 event_sender,
242 packet_sender,
243 option.clone(),
244 )
245 .await
246 }
247
248 pub async fn create_tts_track(
249 engine: Arc<StreamEngine>,
250 cancel_token: CancellationToken,
251 session_id: String,
252 track_id: TrackId,
253 ssrc: u32,
254 play_id: Option<String>,
255 streaming: bool,
256 tts_option: &SynthesisOption,
257 ) -> Result<(SynthesisHandle, Box<dyn Track>)> {
258 let (tx, rx) = mpsc::unbounded_channel();
259 let new_handle = SynthesisHandle::new(tx, play_id.clone(), ssrc);
260 let tts_client = engine.create_tts_client(streaming, tts_option).await?;
261 let tts_track = TtsTrack::new(track_id, session_id, streaming, play_id, rx, tts_client)
262 .with_ssrc(ssrc)
263 .with_cancel_token(cancel_token);
264 Ok((new_handle, Box::new(tts_track) as Box<dyn Track>))
265 }
266
267 pub fn with_processor_hook(&mut self, hook_fn: CreateProcessorsHook) -> &mut Self {
268 self.create_processors_hook = Arc::new(Box::new(hook_fn));
269 self
270 }
271
272 fn default_create_procesors_hook(
273 engine: Arc<StreamEngine>,
274 track: &dyn Track,
275 cancel_token: CancellationToken,
276 event_sender: EventSender,
277 packet_sender: TrackPacketSender,
278 option: CallOption,
279 ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>> {
280 let track_id = track.id().clone();
281 Box::pin(async move {
282 let mut processors = vec![];
283 debug!(track_id = %track_id, "Creating processors for track");
284
285 if let Some(realtime_option) = option.realtime {
286 debug!(track_id = %track_id, "Adding RealtimeProcessor");
287 let realtime_processor = crate::media::realtime_processor::RealtimeProcessor::new(
288 track_id.clone(),
289 cancel_token.child_token(),
290 event_sender.clone(),
291 packet_sender.clone(),
292 realtime_option,
293 )?;
294 processors.push(Box::new(realtime_processor) as Box<dyn Processor>);
295 return Ok(processors);
298 }
299
300 match option.denoise {
301 Some(true) => {
302 debug!(track_id = %track_id, "Adding NoiseReducer processor");
303 let noise_reducer = NoiseReducer::new(INTERNAL_SAMPLERATE as usize);
304 processors.push(Box::new(noise_reducer) as Box<dyn Processor>);
305 }
306 _ => {}
307 }
308 match option.vad {
309 Some(mut option) => {
310 debug!(track_id = %track_id, "Adding VadProcessor processor type={:?}", option.r#type);
311 option.samplerate = INTERNAL_SAMPLERATE;
312 let vad_processor: Box<dyn Processor + 'static> = engine.create_vad_processor(
313 cancel_token.child_token(),
314 event_sender.clone(),
315 option.to_owned(),
316 )?;
317 processors.push(vad_processor);
318 }
319 None => {}
320 }
321 match option.asr {
322 Some(mut option) => {
323 debug!(track_id = %track_id, "Adding AsrProcessor processor provider={:?}", option.provider);
324 option.samplerate = Some(INTERNAL_SAMPLERATE);
325 let asr_processor = engine
326 .create_asr_processor(
327 track_id.clone(),
328 cancel_token.child_token(),
329 option.to_owned(),
330 event_sender.clone(),
331 )
332 .await?;
333 processors.push(asr_processor);
334 }
335 None => {}
336 }
337 match option.eou {
338 Some(ref option) => {
339 let eou_processor = engine.create_eou_processor(
340 cancel_token.child_token(),
341 event_sender.clone(),
342 option.to_owned(),
343 )?;
344 processors.push(eou_processor);
345 }
346 None => {}
347 }
348 match option.inactivity_timeout {
349 Some(timeout_secs) if timeout_secs > 0 => {
350 let inactivity_processor = crate::media::inactivity::InactivityProcessor::new(
351 track_id.clone(),
352 std::time::Duration::from_secs(timeout_secs),
353 event_sender.clone(),
354 cancel_token.child_token(),
355 );
356 processors.push(Box::new(inactivity_processor) as Box<dyn Processor>);
357 }
358 _ => {}
359 }
360
361 Ok(processors)
362 })
363 }
364}