// azure_speech/recognizer/client.rs

1use crate::connector::Client as BaseClient;
2use crate::recognizer::audio_format::AudioFormat;
3use crate::recognizer::session::Session;
4use crate::recognizer::utils::{
5    create_audio_header_message, create_audio_message, create_speech_config_message,
6    create_speech_context_message,
7};
8use crate::recognizer::{
9    AudioDevice, Confidence, Config, Event, OutputFormat, PrimaryLanguage, Recognized,
10};
11use crate::utils::get_azure_hostname_from_region;
12use crate::{stream_ext::StreamExt, Auth, Data, Message};
13use std::cmp::min;
14use tokio::io::AsyncReadExt;
15use tokio_stream::wrappers::ReceiverStream;
16use tokio_stream::{Stream, StreamExt as _};
17use tracing::{debug, warn};
18use url::Url;
19
/// Fixed size (in bytes) of each audio chunk sent to the service.
const BUFFER_SIZE: usize = 4096;

/// Speech-to-text recognizer client: wraps the base websocket client
/// together with the recognition configuration used for each session.
#[derive(Clone)]
pub struct Client {
    // Underlying websocket connection to the Azure Speech service.
    pub client: BaseClient,
    // Recognition options (languages, output format, profanity, ...).
    pub config: Config,
}
27
28impl Client {
29    pub fn new(client: BaseClient, config: Config) -> Self {
30        Self { client, config }
31    }
32
33    pub async fn connect(auth: Auth, config: Config) -> crate::Result<Self> {
34        let base_url = format!(
35            "wss://{}.stt.speech{}/speech/recognition/{}/cognitiveservices/v1",
36            auth.region,
37            get_azure_hostname_from_region(&auth.region),
38            config.mode.as_str()
39        );
40        let mut url = Url::parse(&base_url)?;
41
42        let language = config
43            .languages
44            .first()
45            .ok_or_else(|| crate::Error::IOError("No language specified.".to_string()))?;
46        url.query_pairs_mut()
47            .append_pair("language", language.to_string().as_str())
48            .append_pair("format", config.output_format.as_str())
49            .append_pair("profanity", config.profanity.as_str())
50            .append_pair("storeAudio", &config.store_audio.to_string());
51        if config.output_format == OutputFormat::Detailed {
52            url.query_pairs_mut()
53                .append_pair("wordLevelTimestamps", "true");
54        }
55        if config.languages.len() > 1 {
56            url.query_pairs_mut().append_pair("lidEnabled", "true");
57        }
58        if let Some(ref connection_id) = config.connection_id {
59            url.query_pairs_mut()
60                .append_pair("X-ConnectionId", connection_id);
61        }
62
63        let ws_client = tokio_websockets::ClientBuilder::new()
64            .uri(url.as_str())
65            .unwrap()
66            .add_header(
67                "Ocp-Apim-Subscription-Key".try_into().unwrap(),
68                auth.subscription.to_string().as_str().try_into().unwrap(),
69            )?
70            .add_header(
71                "X-ConnectionId".try_into().unwrap(),
72                uuid::Uuid::new_v4().to_string().try_into().unwrap(),
73            )?;
74
75        let client = BaseClient::connect(ws_client).await?;
76        Ok(Self::new(client, config))
77    }
78
79    pub async fn disconnect(&self) -> crate::Result<()> {
80        self.client.disconnect().await
81    }
82
83    pub async fn recognize_file(
84        &self,
85        path: impl Into<std::path::PathBuf>,
86    ) -> crate::Result<impl Stream<Item = crate::Result<Event>>> {
87        let path = path.into();
88        let file = tokio::fs::File::open(&path).await?;
89        let ext = path
90            .extension()
91            .ok_or_else(|| crate::Error::IOError("Missing file extension.".to_string()))?;
92
93        let (tx, rx) = tokio::sync::mpsc::channel(1024);
94
95        tokio::spawn(async move {
96            let mut reader = tokio::io::BufReader::new(file);
97            loop {
98                let mut chunk = vec![0; BUFFER_SIZE];
99                match reader.read(&mut chunk).await {
100                    Ok(0) => break,
101                    Ok(n) => {
102                        chunk.truncate(n);
103                        if let Err(e) = tx.send(chunk).await {
104                            warn!("Failed to send chunk: {}", e);
105                            break;
106                        }
107                    }
108                    Err(e) => {
109                        warn!("Failed to read chunk: {}", e);
110                        break;
111                    }
112                }
113            }
114        });
115
116        self.recognize(
117            ReceiverStream::new(rx),
118            ext.try_into()?,
119            AudioDevice::file(),
120        )
121        .await
122    }
123    pub async fn recognize<A>(
124        &self,
125        mut audio: A,
126        audio_format: AudioFormat,
127        audio_device: AudioDevice,
128    ) -> crate::Result<impl Stream<Item = crate::Result<Event>>>
129    where
130        A: Stream<Item = Vec<u8>> + Sync + Send + Unpin + 'static,
131    {
132        let messages = self.client.stream().await?;
133        let session = Session::new();
134        let config = self.config.clone();
135        let client = self.client.clone();
136        let (restart_tx, mut restart_rx) = tokio::sync::mpsc::channel(1);
137
138        // Send the initial speech configuration.
139        client
140            .send(create_speech_config_message(
141                session.request_id().to_string(),
142                &config,
143                &audio_device,
144            ))
145            .await?;
146
147        // Send the initial context and audio header messages.
148        client
149            .send(create_speech_context_message(
150                session.request_id().to_string(),
151                &config,
152            ))
153            .await?;
154
155        // For WAV audio, extract the header and extra data.
156        let (audio_header, extra) = match audio_format {
157            AudioFormat::Wav => {
158                let (header, extra) = extract_header_from_wav(&mut audio).await?;
159                debug!(
160                    "Audio WAV header({}): {:?}",
161                    header.len(),
162                    header[..44].to_vec()
163                );
164                (Some(header), extra)
165            }
166            _ => (None, vec![]),
167        };
168
169        // Create the audio data buffer and seed it with any extra bytes.
170        let mut buffer = Vec::with_capacity(BUFFER_SIZE);
171        buffer.extend(extra);
172
173        client
174            .send(create_audio_header_message(
175                session.request_id().to_string(),
176                audio_format.clone(),
177                audio_header.as_deref(),
178            ))
179            .await?;
180
181        let _session = session.clone();
182        tokio::spawn(async move {
183            loop {
184                tokio::select! {
185                    // Handle any restart signal.
186                    _ = restart_rx.recv() => {
187                        tracing::info!("Refreshing audio header");
188                        _session.refresh();
189
190                        if client.send(create_audio_header_message(
191                            _session.request_id().to_string(),
192                            audio_format.clone(),
193                            audio_header.as_deref(),
194                        )).await.is_err() {
195                            warn!("Failed to refresh audio header");
196                            break;
197                        }
198                    },
199                    // Process the next chunk from the audio stream.
200                    maybe_chunk = audio.next() => {
201                        match maybe_chunk {
202                            Some(chunk) => {
203                                // Append the new data to the buffer.
204                                buffer.extend(chunk);
205                                // While there is enough data, send it in fixed-size chunks.
206                                while buffer.len() >= BUFFER_SIZE {
207                                    let data: Vec<u8> = buffer.drain(..BUFFER_SIZE).collect();
208                                    if client.send(create_audio_message(_session.request_id().to_string(), Some(&data))).await.is_err() {
209                                        warn!("Failed to send audio message");
210                                        break;
211                                    }
212                                }
213                            }
214                            None => {
215                                // No more audio: flush remaining bytes in the buffer.
216                                while !buffer.is_empty() {
217                                    let data: Vec<u8> = buffer.drain(..min(buffer.len(), BUFFER_SIZE)).collect();
218                                    if client.send(create_audio_message(_session.request_id().to_string(), Some(&data))).await.is_err() {
219                                        warn!("Failed to send final audio chunk");
220                                        break;
221                                    }
222                                }
223                                // Signal the end of audio.
224                                let _ = client.send(create_audio_message(_session.request_id().to_string(), None)).await;
225                                _session.set_audio_completed(true);
226                                break;
227                            }
228                        }
229                    }
230                }
231            }
232        });
233
234        // Build the output stream that filters and converts messages into events.
235        let session_clone = session.clone();
236        let output_stream = messages
237            .filter(move |msg| match msg {
238                Ok(m) => m.id == session.request_id().to_string(),
239                Err(_) => true,
240            })
241            .filter_map(move |msg| {
242                let session_ref = session_clone.clone();
243                match msg {
244                    Ok(m) => convert_message_to_event(m, &session_ref),
245                    Err(e) => Some(Err(e)),
246                }
247            })
248            .map(move |event| {
249                if let Ok(Event::SessionEnded(_)) = event {
250                    let _ = restart_tx.try_send(());
251                }
252                event
253            })
254            .stop_after(|event| event.is_err());
255
256        Ok(output_stream)
257    }
258}
259
/// Maps a raw service message onto a recognizer [`Event`] for the given
/// session. Returns `None` for messages that produce no event (unknown
/// paths, end-of-dictation phrases, unparsable start-detected payloads).
fn convert_message_to_event(message: Message, session: &Session) -> Option<crate::Result<Event>> {
    match (message.path.as_str(), message.data, message.headers) {
        // A new turn has begun on the service side.
        ("turn.start", _, _) => Some(Ok(Event::SessionStarted(session.request_id()))),
        ("speech.startdetected", Data::Text(Some(data)), _) => {
            // NOTE: parse failures are silently dropped here (`.ok()`),
            // unlike the hypothesis/phrase branches below which surface
            // them as `ParseError` — confirm this asymmetry is intended.
            serde_json::from_str::<crate::recognizer::message::SpeechStartDetected>(&data)
                .map(|v| Event::StartDetected(session.request_id(), v.offset))
                .map(Ok)
                .ok()
        }
        ("speech.enddetected", Data::Text(Some(data)), _) => {
            // A malformed payload falls back to the default offset (0).
            let value =
                serde_json::from_str::<crate::recognizer::message::SpeechEndDetected>(&data)
                    .unwrap_or_default();
            Some(Ok(Event::EndDetected(session.request_id(), value.offset)))
        }
        // Intermediate (non-final) recognition results.
        ("speech.hypothesis", Data::Text(Some(data)), _)
        | ("speech.fragment", Data::Text(Some(data)), _) => {
            match serde_json::from_str::<crate::recognizer::message::SpeechHypothesis>(&data) {
                Ok(value) => {
                    // Offsets from the service are relative to the current
                    // turn; shift by the session's accumulated audio offset.
                    let offset = value.offset + session.audio_offset();
                    session.on_hypothesis_received(offset);
                    Some(Ok(Event::Recognizing(
                        session.request_id(),
                        Recognized {
                            text: value.text,
                            primary_language: value.primary_language.map(|l| {
                                PrimaryLanguage::new(
                                    l.language.into(),
                                    l.confidence.map_or(Confidence::Unknown, |c| c.into()),
                                )
                            }),
                            speaker_id: value.speaker_id,
                        },
                        offset,
                        value.duration,
                        data,
                    )))
                }
                Err(e) => Some(Err(crate::Error::ParseError(e.to_string()))),
            }
        }
        // Final recognition result for a phrase.
        ("speech.phrase", Data::Text(Some(data)), _) => {
            match serde_json::from_str::<crate::recognizer::message::SpeechPhrase>(&data) {
                Ok(value) => {
                    let offset = value.offset.unwrap_or(0) + session.audio_offset();
                    let duration = value.duration.unwrap_or(0);
                    // End-of-dictation phrases carry no result: emit nothing.
                    if value.recognition_status.is_end_of_dictation() {
                        return None;
                    }
                    // Speech was detected but nothing could be recognized.
                    if value.recognition_status.is_no_match() {
                        return Some(Ok(Event::UnMatch(
                            session.request_id(),
                            offset,
                            duration,
                            data,
                        )));
                    }
                    // Re-parse the same payload as a simple phrase to pull
                    // out the display text and speaker/language details.
                    match serde_json::from_str::<crate::recognizer::message::SimpleSpeechPhrase>(
                        &data,
                    ) {
                        Ok(simple) => Some(Ok(Event::Recognized(
                            session.request_id(),
                            Recognized {
                                text: simple.display_text,
                                primary_language: simple.primary_language.map(|l| {
                                    PrimaryLanguage::new(
                                        l.language.into(),
                                        l.confidence.map_or(Confidence::Unknown, |c| c.into()),
                                    )
                                }),
                                speaker_id: simple.speaker_id,
                            },
                            offset,
                            duration,
                            data,
                        ))),
                        Err(e) => Some(Err(crate::Error::ParseError(e.to_string()))),
                    }
                }
                Err(e) => Some(Err(crate::Error::ParseError(e.to_string()))),
            }
        }
        ("turn.end", _, _) => Some(Ok(Event::SessionEnded(session.request_id()))),
        // Unknown paths and payload shapes are ignored.
        _ => None,
    }
}
346
347async fn extract_header_from_wav(
348    reader: &mut (impl Stream<Item = Vec<u8>> + Unpin + Send + Sync + 'static),
349) -> Result<(Vec<u8>, Vec<u8>), crate::Error> {
350    let mut header = Vec::new();
351
352    // Loop until the stream is exhausted.
353    while let Some(chunk) = reader.next().await {
354        header.extend(chunk);
355
356        // We need at least 12 bytes to check the RIFF and WAVE identifiers.
357        if header.len() < 12 {
358            continue;
359        }
360
361        // Check for a valid WAV header: bytes 0..4 must be "RIFF" and 8..12 must be "WAVE".
362        if &header[0..4] != b"RIFF" || &header[8..12] != b"WAVE" {
363            return Err(crate::Error::ParseError("Invalid wav header".to_string()));
364        }
365
366        // Look for the "data" descriptor.
367        if let Some(pos) = header.windows(4).position(|w| w == b"data") {
368            // Ensure we have read the 4 bytes following "data" (i.e. the length field).
369            if header.len() < pos + 8 {
370                // Not enough bytes yet; continue reading.
371                continue;
372            }
373            // Split the header at the end of the "data" chunk descriptor and its length field.
374            let header_end = pos + 8;
375            let remainder = header.split_off(header_end);
376            return Ok((header, remainder));
377        }
378    }
379
380    Err(crate::Error::ParseError(
381        "Reached end of stream without finding 'data' chunk".to_string(),
382    ))
383}