1use super::{
2 INTERNAL_SAMPLERATE,
3 asr_processor::AsrProcessor,
4 denoiser::NoiseReducer,
5 processor::Processor,
6 track::{
7 Track, TrackPacketSender,
8 tts::{SynthesisHandle, TtsTrack},
9 },
10 vad::{VADOption, VadProcessor, VadType},
11};
12use crate::{
13 CallOption, EouOption,
14 event::EventSender,
15 media::TrackId,
16 synthesis::{
17 AliyunTtsClient, DeepegramTtsClient, SynthesisClient, SynthesisOption, SynthesisType,
18 TencentCloudTtsBasicClient, TencentCloudTtsClient,
19 },
20 transcription::{
21 AliyunAsrClientBuilder, TencentCloudAsrClientBuilder, TranscriptionClient,
22 TranscriptionOption, TranscriptionType,
23 },
24};
25
26#[cfg(feature = "offline")]
27use crate::{synthesis::SupertonicTtsClient, transcription::SensevoiceAsrClientBuilder};
28
29use anyhow::Result;
30use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
31use tokio::sync::mpsc;
32use tokio_util::sync::CancellationToken;
33use tracing::debug;
34
35pub type FnCreateVadProcessor = fn(
36 token: CancellationToken,
37 event_sender: EventSender,
38 option: VADOption,
39) -> Result<Box<dyn Processor>>;
40
41pub type FnCreateEouProcessor = fn(
42 token: CancellationToken,
43 event_sender: EventSender,
44 option: EouOption,
45) -> Result<Box<dyn Processor>>;
46
47pub type FnCreateAsrClient = Box<
48 dyn Fn(
49 TrackId,
50 CancellationToken,
51 TranscriptionOption,
52 EventSender,
53 ) -> Pin<Box<dyn Future<Output = Result<Box<dyn TranscriptionClient>>> + Send>>
54 + Send
55 + Sync,
56>;
57pub type FnCreateTtsClient =
58 fn(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>>;
59
60pub type CreateProcessorsHook = Box<
62 dyn Fn(
63 Arc<StreamEngine>,
64 TrackId,
65 CancellationToken,
66 EventSender,
67 TrackPacketSender,
68 CallOption,
69 ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>>
70 + Send
71 + Sync,
72>;
73
74pub struct StreamEngine {
75 vad_creators: HashMap<VadType, FnCreateVadProcessor>,
76 eou_creators: HashMap<String, FnCreateEouProcessor>,
77 asr_creators: HashMap<TranscriptionType, FnCreateAsrClient>,
78 tts_creators: HashMap<SynthesisType, FnCreateTtsClient>,
79 create_processors_hook: Arc<CreateProcessorsHook>,
80}
81
82impl Default for StreamEngine {
83 fn default() -> Self {
84 let mut engine = Self::new();
85 engine.register_vad(VadType::Silero, VadProcessor::create);
86 engine.register_vad(VadType::Other("nop".to_string()), VadProcessor::create_nop);
87
88 engine.register_asr(
89 TranscriptionType::TencentCloud,
90 Box::new(TencentCloudAsrClientBuilder::create),
91 );
92 engine.register_asr(
93 TranscriptionType::Aliyun,
94 Box::new(AliyunAsrClientBuilder::create),
95 );
96
97 #[cfg(feature = "offline")]
98 engine.register_asr(
99 TranscriptionType::Sensevoice,
100 Box::new(SensevoiceAsrClientBuilder::create),
101 );
102
103 engine.register_tts(SynthesisType::Aliyun, AliyunTtsClient::create);
104 engine.register_tts(SynthesisType::TencentCloud, TencentCloudTtsClient::create);
105 engine.register_tts(
106 SynthesisType::Other("tencent_basic".to_string()),
107 TencentCloudTtsBasicClient::create,
108 );
109 engine.register_tts(SynthesisType::Deepgram, DeepegramTtsClient::create);
110
111 #[cfg(feature = "offline")]
112 engine.register_tts(SynthesisType::Supertonic, SupertonicTtsClient::create);
113
114 engine
115 }
116}
117
118impl StreamEngine {
119 pub fn new() -> Self {
120 Self {
121 vad_creators: HashMap::new(),
122 asr_creators: HashMap::new(),
123 tts_creators: HashMap::new(),
124 eou_creators: HashMap::new(),
125 create_processors_hook: Arc::new(Box::new(Self::default_create_procesors_hook)),
126 }
127 }
128
129 pub fn register_vad(&mut self, vad_type: VadType, creator: FnCreateVadProcessor) -> &mut Self {
130 self.vad_creators.insert(vad_type, creator);
131 self
132 }
133
134 pub fn register_eou(&mut self, name: String, creator: FnCreateEouProcessor) -> &mut Self {
135 self.eou_creators.insert(name, creator);
136 self
137 }
138
139 pub fn register_asr(
140 &mut self,
141 asr_type: TranscriptionType,
142 creator: FnCreateAsrClient,
143 ) -> &mut Self {
144 self.asr_creators.insert(asr_type, creator);
145 self
146 }
147
148 pub fn register_tts(
149 &mut self,
150 tts_type: SynthesisType,
151 creator: FnCreateTtsClient,
152 ) -> &mut Self {
153 self.tts_creators.insert(tts_type, creator);
154 self
155 }
156
157 pub fn create_vad_processor(
158 &self,
159 token: CancellationToken,
160 event_sender: EventSender,
161 option: VADOption,
162 ) -> Result<Box<dyn Processor>> {
163 let creator = self.vad_creators.get(&option.r#type);
164 if let Some(creator) = creator {
165 creator(token, event_sender, option)
166 } else {
167 Err(anyhow::anyhow!("VAD type not found: {}", option.r#type))
168 }
169 }
170 pub fn create_eou_processor(
171 &self,
172 token: CancellationToken,
173 event_sender: EventSender,
174 option: EouOption,
175 ) -> Result<Box<dyn Processor>> {
176 let creator = self
177 .eou_creators
178 .get(&option.r#type.clone().unwrap_or_default());
179 if let Some(creator) = creator {
180 creator(token, event_sender, option)
181 } else {
182 Err(anyhow::anyhow!("EOU type not found: {:?}", option.r#type))
183 }
184 }
185
186 pub async fn create_asr_processor(
187 &self,
188 track_id: TrackId,
189 cancel_token: CancellationToken,
190 option: TranscriptionOption,
191 event_sender: EventSender,
192 ) -> Result<Box<dyn Processor>> {
193 let asr_client = match option.provider {
194 Some(ref provider) => {
195 let creator = self.asr_creators.get(&provider);
196 if let Some(creator) = creator {
197 creator(track_id, cancel_token, option, event_sender).await?
198 } else {
199 return Err(anyhow::anyhow!("ASR type not found: {}", provider));
200 }
201 }
202 None => return Err(anyhow::anyhow!("ASR type not found: {:?}", option.provider)),
203 };
204 Ok(Box::new(AsrProcessor { asr_client }))
205 }
206
207 pub async fn create_tts_client(
208 &self,
209 streaming: bool,
210 tts_option: &SynthesisOption,
211 ) -> Result<Box<dyn SynthesisClient>> {
212 match tts_option.provider {
213 Some(ref provider) => {
214 let creator = self.tts_creators.get(&provider);
215 if let Some(creator) = creator {
216 creator(streaming, tts_option)
217 } else {
218 Err(anyhow::anyhow!("TTS type not found: {}", provider))
219 }
220 }
221 None => Err(anyhow::anyhow!(
222 "TTS type not found: {:?}",
223 tts_option.provider
224 )),
225 }
226 }
227
228 pub async fn create_processors(
229 engine: Arc<StreamEngine>,
230 track: &dyn Track,
231 cancel_token: CancellationToken,
232 event_sender: EventSender,
233 packet_sender: TrackPacketSender,
234 option: &CallOption,
235 ) -> Result<Vec<Box<dyn Processor>>> {
236 (engine.clone().create_processors_hook)(
237 engine,
238 track.id().clone(),
239 cancel_token,
240 event_sender,
241 packet_sender,
242 option.clone(),
243 )
244 .await
245 }
246
247 pub async fn create_tts_track(
248 engine: Arc<StreamEngine>,
249 cancel_token: CancellationToken,
250 session_id: String,
251 track_id: TrackId,
252 ssrc: u32,
253 play_id: Option<String>,
254 streaming: bool,
255 tts_option: &SynthesisOption,
256 ) -> Result<(SynthesisHandle, Box<dyn Track>)> {
257 let (tx, rx) = mpsc::unbounded_channel();
258 let new_handle = SynthesisHandle::new(tx, play_id.clone(), ssrc);
259 let tts_client = engine.create_tts_client(streaming, tts_option).await?;
260 let sample_rate = tts_option.samplerate.unwrap_or(16000) as u32;
261 let tts_track = TtsTrack::new(track_id, session_id, streaming, play_id, rx, tts_client)
262 .with_ssrc(ssrc)
263 .with_sample_rate(sample_rate)
264 .with_cancel_token(cancel_token);
265 Ok((new_handle, Box::new(tts_track) as Box<dyn Track>))
266 }
267
268 pub fn with_processor_hook(&mut self, hook_fn: CreateProcessorsHook) -> &mut Self {
269 self.create_processors_hook = Arc::new(Box::new(hook_fn));
270 self
271 }
272
273 pub fn default_create_procesors_hook(
274 engine: Arc<StreamEngine>,
275 track_id: TrackId,
276 cancel_token: CancellationToken,
277 event_sender: EventSender,
278 packet_sender: TrackPacketSender,
279 option: CallOption,
280 ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>> {
281 Box::pin(async move {
282 let mut processors = vec![];
283 debug!(%track_id, "Creating processors for track");
284
285 if let Some(realtime_option) = option.realtime {
286 debug!(%track_id, "Adding RealtimeProcessor");
287 let realtime_processor = crate::media::realtime_processor::RealtimeProcessor::new(
288 track_id.clone(),
289 cancel_token.child_token(),
290 event_sender.clone(),
291 packet_sender.clone(),
292 realtime_option,
293 )?;
294 processors.push(Box::new(realtime_processor) as Box<dyn Processor>);
295 return Ok(processors);
298 }
299
300 match option.denoise {
301 Some(true) => {
302 debug!(%track_id, "Adding NoiseReducer processor");
303 let noise_reducer = NoiseReducer::new(INTERNAL_SAMPLERATE as usize);
304 processors.push(Box::new(noise_reducer) as Box<dyn Processor>);
305 }
306 _ => {}
307 }
308 match option.vad {
309 Some(mut option) => {
310 debug!(%track_id, "Adding VadProcessor processor type={:?}", option.r#type);
311 option.samplerate = INTERNAL_SAMPLERATE;
312 let vad_processor: Box<dyn Processor + 'static> = engine.create_vad_processor(
313 cancel_token.child_token(),
314 event_sender.clone(),
315 option.to_owned(),
316 )?;
317 processors.push(vad_processor);
318 }
319 None => {}
320 }
321 match option.asr {
322 Some(mut option) => {
323 debug!(%track_id, "Adding AsrProcessor processor provider={:?}", option.provider);
324 option.samplerate = Some(INTERNAL_SAMPLERATE);
325 let asr_processor = engine
326 .create_asr_processor(
327 track_id.clone(),
328 cancel_token.child_token(),
329 option.to_owned(),
330 event_sender.clone(),
331 )
332 .await?;
333 processors.push(asr_processor);
334 }
335 None => {}
336 }
337 match option.eou {
338 Some(ref option) => {
339 let eou_processor = engine.create_eou_processor(
340 cancel_token.child_token(),
341 event_sender.clone(),
342 option.to_owned(),
343 )?;
344 processors.push(eou_processor);
345 }
346 None => {}
347 }
348 match option.inactivity_timeout {
349 Some(timeout_secs) if timeout_secs > 0 => {
350 let inactivity_processor = crate::media::inactivity::InactivityProcessor::new(
351 track_id.clone(),
352 std::time::Duration::from_secs(timeout_secs),
353 event_sender.clone(),
354 cancel_token.child_token(),
355 );
356 processors.push(Box::new(inactivity_processor) as Box<dyn Processor>);
357 }
358 _ => {}
359 }
360
361 Ok(processors)
362 })
363 }
364}