1use crate::connector::Client as BaseClient;
2use crate::recognizer::audio_format::AudioFormat;
3use crate::recognizer::session::Session;
4use crate::recognizer::utils::{
5 create_audio_header_message, create_audio_message, create_speech_config_message,
6 create_speech_context_message,
7};
8use crate::recognizer::{
9 AudioDevice, Confidence, Config, Event, OutputFormat, PrimaryLanguage, Recognized,
10};
11use crate::utils::get_azure_hostname_from_region;
12use crate::{stream_ext::StreamExt, Auth, Data, Message};
13use std::cmp::min;
14use tokio::io::AsyncReadExt;
15use tokio_stream::wrappers::ReceiverStream;
16use tokio_stream::{Stream, StreamExt as _};
17use tracing::{debug, warn};
18use url::Url;
19
20const BUFFER_SIZE: usize = 4096;
21
/// Speech-to-text recognizer client: pairs the underlying websocket
/// connection (`BaseClient`) with the recognition `Config` applied to
/// every request started through it.
#[derive(Clone)]
pub struct Client {
    // Underlying websocket connection to the Azure speech service.
    pub client: BaseClient,
    // Recognition settings (languages, output format, profanity, audio storage, ...).
    pub config: Config,
}
27
28impl Client {
29 pub fn new(client: BaseClient, config: Config) -> Self {
30 Self { client, config }
31 }
32
33 pub async fn connect(auth: Auth, config: Config) -> crate::Result<Self> {
34 let base_url = format!(
35 "wss://{}.stt.speech{}/speech/recognition/{}/cognitiveservices/v1",
36 auth.region,
37 get_azure_hostname_from_region(&auth.region),
38 config.mode.as_str()
39 );
40 let mut url = Url::parse(&base_url)?;
41
42 let language = config
43 .languages
44 .first()
45 .ok_or_else(|| crate::Error::IOError("No language specified.".to_string()))?;
46 url.query_pairs_mut()
47 .append_pair("language", language.to_string().as_str())
48 .append_pair("format", config.output_format.as_str())
49 .append_pair("profanity", config.profanity.as_str())
50 .append_pair("storeAudio", &config.store_audio.to_string());
51 if config.output_format == OutputFormat::Detailed {
52 url.query_pairs_mut()
53 .append_pair("wordLevelTimestamps", "true");
54 }
55 if config.languages.len() > 1 {
56 url.query_pairs_mut().append_pair("lidEnabled", "true");
57 }
58 if let Some(ref connection_id) = config.connection_id {
59 url.query_pairs_mut()
60 .append_pair("X-ConnectionId", connection_id);
61 }
62
63 let ws_client = tokio_websockets::ClientBuilder::new()
64 .uri(url.as_str())
65 .unwrap()
66 .add_header(
67 "Ocp-Apim-Subscription-Key".try_into().unwrap(),
68 auth.subscription.to_string().as_str().try_into().unwrap(),
69 )?
70 .add_header(
71 "X-ConnectionId".try_into().unwrap(),
72 uuid::Uuid::new_v4().to_string().try_into().unwrap(),
73 )?;
74
75 let client = BaseClient::connect(ws_client).await?;
76 Ok(Self::new(client, config))
77 }
78
79 pub async fn disconnect(&self) -> crate::Result<()> {
80 self.client.disconnect().await
81 }
82
83 pub async fn recognize_file(
84 &self,
85 path: impl Into<std::path::PathBuf>,
86 ) -> crate::Result<impl Stream<Item = crate::Result<Event>>> {
87 let path = path.into();
88 let file = tokio::fs::File::open(&path).await?;
89 let ext = path
90 .extension()
91 .ok_or_else(|| crate::Error::IOError("Missing file extension.".to_string()))?;
92
93 let (tx, rx) = tokio::sync::mpsc::channel(1024);
94
95 tokio::spawn(async move {
96 let mut reader = tokio::io::BufReader::new(file);
97 loop {
98 let mut chunk = vec![0; BUFFER_SIZE];
99 match reader.read(&mut chunk).await {
100 Ok(0) => break,
101 Ok(n) => {
102 chunk.truncate(n);
103 if let Err(e) = tx.send(chunk).await {
104 warn!("Failed to send chunk: {}", e);
105 break;
106 }
107 }
108 Err(e) => {
109 warn!("Failed to read chunk: {}", e);
110 break;
111 }
112 }
113 }
114 });
115
116 self.recognize(
117 ReceiverStream::new(rx),
118 ext.try_into()?,
119 AudioDevice::file(),
120 )
121 .await
122 }
123 pub async fn recognize<A>(
124 &self,
125 mut audio: A,
126 audio_format: AudioFormat,
127 audio_device: AudioDevice,
128 ) -> crate::Result<impl Stream<Item = crate::Result<Event>>>
129 where
130 A: Stream<Item = Vec<u8>> + Sync + Send + Unpin + 'static,
131 {
132 let messages = self.client.stream().await?;
133 let session = Session::new();
134 let config = self.config.clone();
135 let client = self.client.clone();
136 let (restart_tx, mut restart_rx) = tokio::sync::mpsc::channel(1);
137
138 client
140 .send(create_speech_config_message(
141 session.request_id().to_string(),
142 &config,
143 &audio_device,
144 ))
145 .await?;
146
147 client
149 .send(create_speech_context_message(
150 session.request_id().to_string(),
151 &config,
152 ))
153 .await?;
154
155 let (audio_header, extra) = match audio_format {
157 AudioFormat::Wav => {
158 let (header, extra) = extract_header_from_wav(&mut audio).await?;
159 debug!(
160 "Audio WAV header({}): {:?}",
161 header.len(),
162 header[..44].to_vec()
163 );
164 (Some(header), extra)
165 }
166 _ => (None, vec![]),
167 };
168
169 let mut buffer = Vec::with_capacity(BUFFER_SIZE);
171 buffer.extend(extra);
172
173 client
174 .send(create_audio_header_message(
175 session.request_id().to_string(),
176 audio_format.clone(),
177 audio_header.as_deref(),
178 ))
179 .await?;
180
181 let _session = session.clone();
182 tokio::spawn(async move {
183 loop {
184 tokio::select! {
185 _ = restart_rx.recv() => {
187 tracing::info!("Refreshing audio header");
188 _session.refresh();
189
190 if client.send(create_audio_header_message(
191 _session.request_id().to_string(),
192 audio_format.clone(),
193 audio_header.as_deref(),
194 )).await.is_err() {
195 warn!("Failed to refresh audio header");
196 break;
197 }
198 },
199 maybe_chunk = audio.next() => {
201 match maybe_chunk {
202 Some(chunk) => {
203 buffer.extend(chunk);
205 while buffer.len() >= BUFFER_SIZE {
207 let data: Vec<u8> = buffer.drain(..BUFFER_SIZE).collect();
208 if client.send(create_audio_message(_session.request_id().to_string(), Some(&data))).await.is_err() {
209 warn!("Failed to send audio message");
210 break;
211 }
212 }
213 }
214 None => {
215 while !buffer.is_empty() {
217 let data: Vec<u8> = buffer.drain(..min(buffer.len(), BUFFER_SIZE)).collect();
218 if client.send(create_audio_message(_session.request_id().to_string(), Some(&data))).await.is_err() {
219 warn!("Failed to send final audio chunk");
220 break;
221 }
222 }
223 let _ = client.send(create_audio_message(_session.request_id().to_string(), None)).await;
225 _session.set_audio_completed(true);
226 break;
227 }
228 }
229 }
230 }
231 }
232 });
233
234 let session_clone = session.clone();
236 let output_stream = messages
237 .filter(move |msg| match msg {
238 Ok(m) => m.id == session.request_id().to_string(),
239 Err(_) => true,
240 })
241 .filter_map(move |msg| {
242 let session_ref = session_clone.clone();
243 match msg {
244 Ok(m) => convert_message_to_event(m, &session_ref),
245 Err(e) => Some(Err(e)),
246 }
247 })
248 .map(move |event| {
249 if let Ok(Event::SessionEnded(_)) = event {
250 let _ = restart_tx.try_send(());
251 }
252 event
253 })
254 .stop_after(|event| event.is_err());
255
256 Ok(output_stream)
257 }
258}
259
/// Maps a raw service `Message` onto a recognizer `Event` for the session.
///
/// Returns `None` for messages that should be swallowed (unknown paths,
/// end-of-dictation phrases, unparsable start-detected payloads) and
/// `Some(Err(..))` when a payload that must parse fails to do so.
fn convert_message_to_event(message: Message, session: &Session) -> Option<crate::Result<Event>> {
    match (message.path.as_str(), message.data, message.headers) {
        ("turn.start", _, _) => Some(Ok(Event::SessionStarted(session.request_id()))),
        ("speech.startdetected", Data::Text(Some(data)), _) => {
            // NOTE(review): parse failures here are silently dropped via
            // `.ok()`, unlike hypothesis/phrase below which surface a
            // ParseError — presumably deliberate since start-detected is
            // advisory; confirm before unifying the error policy.
            serde_json::from_str::<crate::recognizer::message::SpeechStartDetected>(&data)
                .map(|v| Event::StartDetected(session.request_id(), v.offset))
                .map(Ok)
                .ok()
        }
        ("speech.enddetected", Data::Text(Some(data)), _) => {
            // A malformed payload falls back to the Default offset rather
            // than failing the stream.
            let value =
                serde_json::from_str::<crate::recognizer::message::SpeechEndDetected>(&data)
                    .unwrap_or_default();
            Some(Ok(Event::EndDetected(session.request_id(), value.offset)))
        }
        // Interim results: hypotheses and fragments share one payload shape.
        ("speech.hypothesis", Data::Text(Some(data)), _)
        | ("speech.fragment", Data::Text(Some(data)), _) => {
            match serde_json::from_str::<crate::recognizer::message::SpeechHypothesis>(&data) {
                Ok(value) => {
                    // Service offsets are relative to the current audio turn;
                    // shift by the session's accumulated audio offset.
                    let offset = value.offset + session.audio_offset();
                    session.on_hypothesis_received(offset);
                    Some(Ok(Event::Recognizing(
                        session.request_id(),
                        Recognized {
                            text: value.text,
                            primary_language: value.primary_language.map(|l| {
                                PrimaryLanguage::new(
                                    l.language.into(),
                                    // Missing confidence is reported as Unknown.
                                    l.confidence.map_or(Confidence::Unknown, |c| c.into()),
                                )
                            }),
                            speaker_id: value.speaker_id,
                        },
                        offset,
                        value.duration,
                        data,
                    )))
                }
                Err(e) => Some(Err(crate::Error::ParseError(e.to_string()))),
            }
        }
        // Final result for a turn.
        ("speech.phrase", Data::Text(Some(data)), _) => {
            match serde_json::from_str::<crate::recognizer::message::SpeechPhrase>(&data) {
                Ok(value) => {
                    let offset = value.offset.unwrap_or(0) + session.audio_offset();
                    let duration = value.duration.unwrap_or(0);
                    // End-of-dictation is a control marker, not a result.
                    if value.recognition_status.is_end_of_dictation() {
                        return None;
                    }
                    if value.recognition_status.is_no_match() {
                        return Some(Ok(Event::UnMatch(
                            session.request_id(),
                            offset,
                            duration,
                            data,
                        )));
                    }
                    // Re-parse the same payload as the simple shape to pull
                    // the display text and speaker id.
                    match serde_json::from_str::<crate::recognizer::message::SimpleSpeechPhrase>(
                        &data,
                    ) {
                        Ok(simple) => Some(Ok(Event::Recognized(
                            session.request_id(),
                            Recognized {
                                text: simple.display_text,
                                primary_language: simple.primary_language.map(|l| {
                                    PrimaryLanguage::new(
                                        l.language.into(),
                                        l.confidence.map_or(Confidence::Unknown, |c| c.into()),
                                    )
                                }),
                                speaker_id: simple.speaker_id,
                            },
                            offset,
                            duration,
                            data,
                        ))),
                        Err(e) => Some(Err(crate::Error::ParseError(e.to_string()))),
                    }
                }
                Err(e) => Some(Err(crate::Error::ParseError(e.to_string()))),
            }
        }
        ("turn.end", _, _) => Some(Ok(Event::SessionEnded(session.request_id()))),
        // Everything else (audio acks, unknown paths, non-text payloads) is ignored.
        _ => None,
    }
}
346
347async fn extract_header_from_wav(
348 reader: &mut (impl Stream<Item = Vec<u8>> + Unpin + Send + Sync + 'static),
349) -> Result<(Vec<u8>, Vec<u8>), crate::Error> {
350 let mut header = Vec::new();
351
352 while let Some(chunk) = reader.next().await {
354 header.extend(chunk);
355
356 if header.len() < 12 {
358 continue;
359 }
360
361 if &header[0..4] != b"RIFF" || &header[8..12] != b"WAVE" {
363 return Err(crate::Error::ParseError("Invalid wav header".to_string()));
364 }
365
366 if let Some(pos) = header.windows(4).position(|w| w == b"data") {
368 if header.len() < pos + 8 {
370 continue;
372 }
373 let header_end = pos + 8;
375 let remainder = header.split_off(header_end);
376 return Ok((header, remainder));
377 }
378 }
379
380 Err(crate::Error::ParseError(
381 "Reached end of stream without finding 'data' chunk".to_string(),
382 ))
383}