1use super::{
2 asr_processor::AsrProcessor,
3 denoiser::NoiseReducer,
4 processor::Processor,
5 track::{
6 Track,
7 tts::{SynthesisHandle, TtsTrack},
8 },
9 vad::{VADOption, VadProcessor, VadType},
10};
11use crate::{
12 CallOption, EouOption,
13 event::EventSender,
14 media::TrackId,
15 synthesis::{
16 AliyunTtsClient, DeepegramTtsClient, SynthesisClient, SynthesisOption, SynthesisType,
17 TencentCloudTtsBasicClient, TencentCloudTtsClient, VoiceApiTtsClient,
18 },
19 transcription::{
20 AliyunAsrClientBuilder, TencentCloudAsrClientBuilder, TranscriptionClient,
21 TranscriptionOption, TranscriptionType, VoiceApiAsrClientBuilder,
22 },
23};
24use anyhow::Result;
25use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc};
26use tokio::sync::mpsc;
27use tokio_util::sync::CancellationToken;
28
29pub type FnCreateVadProcessor = fn(
30 token: CancellationToken,
31 event_sender: EventSender,
32 option: VADOption,
33) -> Result<Box<dyn Processor>>;
34
35pub type FnCreateEouProcessor = fn(
36 token: CancellationToken,
37 event_sender: EventSender,
38 option: EouOption,
39) -> Result<Box<dyn Processor>>;
40
41pub type FnCreateAsrClient = Box<
42 dyn Fn(
43 TrackId,
44 CancellationToken,
45 TranscriptionOption,
46 EventSender,
47 ) -> Pin<Box<dyn Future<Output = Result<Box<dyn TranscriptionClient>>> + Send>>
48 + Send
49 + Sync,
50>;
51pub type FnCreateTtsClient =
52 fn(streaming: bool, option: &SynthesisOption) -> Result<Box<dyn SynthesisClient>>;
53
54pub type CreateProcessorsHook = Box<
56 dyn Fn(
57 Arc<StreamEngine>,
58 &dyn Track,
59 CancellationToken,
60 EventSender,
61 CallOption,
62 ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>>
63 + Send
64 + Sync,
65>;
66
67pub struct StreamEngine {
68 vad_creators: HashMap<VadType, FnCreateVadProcessor>,
69 eou_creators: HashMap<String, FnCreateEouProcessor>,
70 asr_creators: HashMap<TranscriptionType, FnCreateAsrClient>,
71 tts_creators: HashMap<SynthesisType, FnCreateTtsClient>,
72 create_processors_hook: Arc<CreateProcessorsHook>,
73}
74
75impl Default for StreamEngine {
76 fn default() -> Self {
77 let mut engine = Self::new();
78 engine.register_vad(VadType::Silero, VadProcessor::create);
79 engine.register_vad(VadType::Ten, VadProcessor::create);
80 engine.register_vad(VadType::Other("nop".to_string()), VadProcessor::create_nop);
81
82 engine.register_asr(
83 TranscriptionType::TencentCloud,
84 Box::new(TencentCloudAsrClientBuilder::create),
85 );
86 engine.register_asr(
87 TranscriptionType::VoiceApi,
88 Box::new(VoiceApiAsrClientBuilder::create),
89 );
90 engine.register_asr(
91 TranscriptionType::Aliyun,
92 Box::new(AliyunAsrClientBuilder::create),
93 );
94 engine.register_tts(SynthesisType::Aliyun, AliyunTtsClient::create);
95 engine.register_tts(SynthesisType::TencentCloud, TencentCloudTtsClient::create);
96 engine.register_tts(SynthesisType::VoiceApi, VoiceApiTtsClient::create);
97 engine.register_tts(
98 SynthesisType::Other("tencent_basic".to_string()),
99 TencentCloudTtsBasicClient::create,
100 );
101 engine.register_tts(SynthesisType::Deepgram, DeepegramTtsClient::create);
102 engine
103 }
104}
105
106impl StreamEngine {
107 pub fn new() -> Self {
108 Self {
109 vad_creators: HashMap::new(),
110 asr_creators: HashMap::new(),
111 tts_creators: HashMap::new(),
112 eou_creators: HashMap::new(),
113 create_processors_hook: Arc::new(Box::new(Self::default_create_procesors_hook)),
114 }
115 }
116
117 pub fn register_vad(&mut self, vad_type: VadType, creator: FnCreateVadProcessor) -> &mut Self {
118 self.vad_creators.insert(vad_type, creator);
119 self
120 }
121
122 pub fn register_eou(&mut self, name: String, creator: FnCreateEouProcessor) -> &mut Self {
123 self.eou_creators.insert(name, creator);
124 self
125 }
126
127 pub fn register_asr(
128 &mut self,
129 asr_type: TranscriptionType,
130 creator: FnCreateAsrClient,
131 ) -> &mut Self {
132 self.asr_creators.insert(asr_type, creator);
133 self
134 }
135
136 pub fn register_tts(
137 &mut self,
138 tts_type: SynthesisType,
139 creator: FnCreateTtsClient,
140 ) -> &mut Self {
141 self.tts_creators.insert(tts_type, creator);
142 self
143 }
144
145 pub fn create_vad_processor(
146 &self,
147 token: CancellationToken,
148 event_sender: EventSender,
149 option: VADOption,
150 ) -> Result<Box<dyn Processor>> {
151 let creator = self.vad_creators.get(&option.r#type);
152 if let Some(creator) = creator {
153 creator(token, event_sender, option)
154 } else {
155 Err(anyhow::anyhow!("VAD type not found: {}", option.r#type))
156 }
157 }
158 pub fn create_eou_processor(
159 &self,
160 token: CancellationToken,
161 event_sender: EventSender,
162 option: EouOption,
163 ) -> Result<Box<dyn Processor>> {
164 let creator = self
165 .eou_creators
166 .get(&option.r#type.clone().unwrap_or_default());
167 if let Some(creator) = creator {
168 creator(token, event_sender, option)
169 } else {
170 Err(anyhow::anyhow!("EOU type not found: {:?}", option.r#type))
171 }
172 }
173
174 pub async fn create_asr_processor(
175 &self,
176 track_id: TrackId,
177 cancel_token: CancellationToken,
178 option: TranscriptionOption,
179 event_sender: EventSender,
180 ) -> Result<Box<dyn Processor>> {
181 let asr_client = match option.provider {
182 Some(ref provider) => {
183 let creator = self.asr_creators.get(&provider);
184 if let Some(creator) = creator {
185 creator(track_id, cancel_token, option, event_sender).await?
186 } else {
187 return Err(anyhow::anyhow!("ASR type not found: {}", provider));
188 }
189 }
190 None => return Err(anyhow::anyhow!("ASR type not found: {:?}", option.provider)),
191 };
192 Ok(Box::new(AsrProcessor { asr_client }))
193 }
194
195 pub async fn create_tts_client(
196 &self,
197 streaming: bool,
198 tts_option: &SynthesisOption,
199 ) -> Result<Box<dyn SynthesisClient>> {
200 match tts_option.provider {
201 Some(ref provider) => {
202 let creator = self.tts_creators.get(&provider);
203 if let Some(creator) = creator {
204 creator(streaming, tts_option)
205 } else {
206 Err(anyhow::anyhow!("TTS type not found: {}", provider))
207 }
208 }
209 None => Err(anyhow::anyhow!(
210 "TTS type not found: {:?}",
211 tts_option.provider
212 )),
213 }
214 }
215
216 pub async fn create_processors(
217 engine: Arc<StreamEngine>,
218 track: &dyn Track,
219 cancel_token: CancellationToken,
220 event_sender: EventSender,
221 option: &CallOption,
222 ) -> Result<Vec<Box<dyn Processor>>> {
223 (engine.clone().create_processors_hook)(
224 engine,
225 track,
226 cancel_token,
227 event_sender,
228 option.clone(),
229 )
230 .await
231 }
232
233 pub async fn create_tts_track(
234 engine: Arc<StreamEngine>,
235 cancel_token: CancellationToken,
236 session_id: String,
237 track_id: TrackId,
238 ssrc: u32,
239 play_id: Option<String>,
240 streaming: bool,
241 tts_option: &SynthesisOption,
242 ) -> Result<(SynthesisHandle, Box<dyn Track>)> {
243 let (tx, rx) = mpsc::unbounded_channel();
244 let new_handle = SynthesisHandle::new(tx, play_id.clone());
245 let tts_client = engine.create_tts_client(streaming, tts_option).await?;
246 let tts_track = TtsTrack::new(track_id, session_id, streaming, play_id, rx, tts_client)
247 .with_ssrc(ssrc)
248 .with_cancel_token(cancel_token);
249 Ok((new_handle, Box::new(tts_track) as Box<dyn Track>))
250 }
251
252 pub fn with_processor_hook(&mut self, hook_fn: CreateProcessorsHook) -> &mut Self {
253 self.create_processors_hook = Arc::new(Box::new(hook_fn));
254 self
255 }
256
257 fn default_create_procesors_hook(
258 engine: Arc<StreamEngine>,
259 track: &dyn Track,
260 cancel_token: CancellationToken,
261 event_sender: EventSender,
262 option: CallOption,
263 ) -> Pin<Box<dyn Future<Output = Result<Vec<Box<dyn Processor>>>> + Send>> {
264 let track_id = track.id().clone();
265 let samplerate = track.config().samplerate as usize;
266 Box::pin(async move {
267 let mut processors = vec![];
268 match option.denoise {
269 Some(true) => {
270 let noise_reducer = NoiseReducer::new(samplerate);
271 processors.push(Box::new(noise_reducer) as Box<dyn Processor>);
272 }
273 _ => {}
274 }
275 match option.vad {
276 Some(ref option) => {
277 let vad_processor: Box<dyn Processor + 'static> = engine.create_vad_processor(
278 cancel_token.child_token(),
279 event_sender.clone(),
280 option.to_owned(),
281 )?;
282 processors.push(vad_processor);
283 }
284 None => {}
285 }
286 match option.asr {
287 Some(ref option) => {
288 let asr_processor = engine
289 .create_asr_processor(
290 track_id.clone(),
291 cancel_token.child_token(),
292 option.to_owned(),
293 event_sender.clone(),
294 )
295 .await?;
296 processors.push(asr_processor);
297 }
298 None => {}
299 }
300 match option.eou {
301 Some(ref option) => {
302 let eou_processor = engine.create_eou_processor(
303 cancel_token.child_token(),
304 event_sender.clone(),
305 option.to_owned(),
306 )?;
307 processors.push(eou_processor);
308 }
309 None => {}
310 }
311 match option.inactivity_timeout {
312 Some(timeout_secs) if timeout_secs > 0 => {
313 let inactivity_processor =
314 crate::media::inactivity::InactivityProcessor::new(
315 track_id.clone(),
316 std::time::Duration::from_secs(timeout_secs),
317 event_sender.clone(),
318 cancel_token.child_token(),
319 );
320 processors.push(Box::new(inactivity_processor) as Box<dyn Processor>);
321 }
322 _ => {}
323 }
324
325 Ok(processors)
326 })
327 }
328}