speechmatics 0.4.0

An async rust SDK for the Speechmatics API
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
//! This module is the main entrypoint for all realtime-related code, including the creation of session structs

use anyhow::Result;
use base64::{engine::general_purpose, Engine as _};
use futures::{
    pin_mut,
    stream::{SplitSink, SplitStream},
    SinkExt, StreamExt,
};
use http::Request;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use serde_json::from_slice;
use std::boxed::Box;
use tokio::{
    io::AsyncReadExt,
    join,
    net::TcpStream,
    sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
};
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::{connect_async, tungstenite::Message, WebSocketStream};
use url::Url;

#[cfg(test)]
use std::{println as error, println as debug, println as info, println as warn};

#[cfg(not(test))]
use log::{debug, error, info, warn};

/// Types for interfacing with the realtime API, autogenerated from the spec with a few extra processing steps
// NOTE(review): generated code is presumably not individually documented, hence the lint allowance.
#[allow(missing_docs)]
pub mod models;

/// The default URL for the realtime runtime
///
/// This is the standard URL for self-service customers, and some enterprise customers.
/// Some customers may wish instead to access other European, American or Australian environments.
/// A full list of URLs can be found in our [docs](https://docs.speechmatics.com/introduction/authentication#supported-endpoints).
///
/// `RealtimeSession::new` uses this URL when no override is supplied, and adds an `sm-sdk`
/// query parameter to it.
pub const DEFAULT_RT_URL: &str = "wss://neu.rt.speechmatics.com/v2/en";

/// The default ISO language code, which sets it to English.
pub const DEFAULT_LANGUAGE: &str = "en";
/// Version of this crate, read from Cargo at compile time. It is reported to the server as the
/// `sm-sdk=rust-<VERSION>` query parameter on the websocket URL.
const VERSION: &str = env!("CARGO_PKG_VERSION");

/// Enum of all messages that can be read by an end user. `RealtimeSession` forwards values of this
/// enum on the channel whose receiver is returned by `RealtimeSession::new`.
///
/// The enum is `#[serde(untagged)]`. Incoming JSON is tried against the variants in declaration
/// order, and the first variant that deserialises successfully is used. Do not reorder variants
/// without checking that each model only matches its own message type.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ReadMessage {
    /// Wraps `models::RecognitionStarted`. Forwarded once the server acknowledges StartRecognition.
    RecognitionStarted(models::RecognitionStarted),
    /// Wraps `models::Info`.
    Info(models::Info),
    /// Wraps `models::Warning`.
    Warning(models::Warning),
    /// Wraps `models::Error`. When received during message processing it is forwarded, and `run`
    /// then returns an error.
    Error(models::Error),
    /// Wraps `models::AddPartialTranscript`.
    AddPartialTranscript(models::AddPartialTranscript),
    /// Wraps `models::AddTranscript`.
    AddTranscript(models::AddTranscript),
    /// Wraps `models::AddPartialTranslation`.
    AddPartialTranslation(models::AddPartialTranslation),
    /// Wraps `models::AddTranslation`.
    AddTranslation(models::AddTranslation),
    /// Wraps `models::AudioAdded`.
    AudioAdded(models::AudioAdded),
    /// Wraps `models::EndOfTranscript`. Message processing stops after this is forwarded.
    EndOfTranscript(models::EndOfTranscript),
}

/// Configuration for a realtime session, passed into `RealtimeSession::run`. `run` forwards it to
/// the server in the StartRecognition message.
///
/// Its `Default` implementation sets the transcription language to English (`DEFAULT_LANGUAGE`).
/// It also sets `translation_config` and `audio_format` to `Some` of their respective model
/// defaults.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SessionConfig {
    /// Config for the transcription part of the service. This field is always present.
    /// `SessionConfig::new(None, ..)` and `Default` both set it to English transcription, with all
    /// other settings taken from `TranscriptionConfig::default()`.
    pub transcription_config: models::TranscriptionConfig,
    /// Config for the translation part of the service. When `None`, no translation config is put
    /// into the StartRecognition message. `SessionConfig::new` stores whatever it is given, while
    /// `Default` sets it to `Some(TranslationConfig::default())`.
    pub translation_config: Option<models::TranslationConfig>,
    /// Config telling the server what kind of audio to expect. When `None`, the StartRecognition
    /// message keeps its own default audio format. `Default` sets it to
    /// `Some(AudioFormat::default())`.
    pub audio_format: Option<models::AudioFormat>,
}

impl SessionConfig {
    /// Create a new instance of session config.
    ///
    /// If no transcription_config is provided, then it will set it to run a default English Standard session.
    /// By default, no audio format or translation config will be set, as the server will infer this.
    ///
    pub fn new(
        transcription_config: Option<models::TranscriptionConfig>,
        translation_config: Option<models::TranslationConfig>,
        audio_format: Option<models::AudioFormat>,
    ) -> Self {
        let mut transc_conf = models::TranscriptionConfig::default();
        transc_conf.language = "en".to_string();
        if let Some(t_conf) = transcription_config {
            transc_conf = t_conf
        };
        Self {
            transcription_config: transc_conf,
            translation_config,
            audio_format,
        }
    }
}

impl Default for SessionConfig {
    /// English transcription (`DEFAULT_LANGUAGE`). Translation and audio-format configs are set to
    /// `Some` of their model defaults.
    fn default() -> Self {
        let transcription_config = {
            let mut config = models::TranscriptionConfig::default();
            config.language = String::from(DEFAULT_LANGUAGE);
            config
        };
        Self {
            transcription_config,
            translation_config: Some(models::TranslationConfig::default()),
            audio_format: Some(models::AudioFormat::default()),
        }
    }
}

/// Receiving half of the split realtime websocket connection.
type SplitStreamAlias = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;

/// Holds everything needed to run a realtime session. Its two main functions are:
/// - `new`, which instantiates the session and returns the receiver on which server messages arrive.
/// - `run`, which connects and streams audio. `run` is an async function that can be joined or
///   selected with other futures.
pub struct RealtimeSession {
    /// API key, sent as a `Bearer` token in the `Authorization` header.
    auth_token: String,
    /// Websocket URL to connect to, including the `sm-sdk` query parameter added by `new`.
    rt_url: String,
    /// Sending half of the channel whose receiver `new` returns. Server messages are forwarded here.
    internal_message_sender: UnboundedSender<ReadMessage>,
}

impl RealtimeSession {
    /// Instantiates a RealtimeSession struct. The run method can then be called to start transcribing a stream of data.
    ///
    /// # Example
    ///
    /// ```
    /// use speechmatics::realtime::RealtimeSession;
    ///
    /// let rt_session = RealtimeSession::new("YOUR_API_KEY", None).unwrap();
    /// ```
    ///
    /// # Errors
    ///
    /// This function can error if the URL provided is not a valid websocket URL
    pub fn new(
        auth_token: String,
        rt_url: Option<String>,
    ) -> Result<(Self, UnboundedReceiver<ReadMessage>)> {
        let (channel_sender, channel_receiver) = unbounded_channel::<ReadMessage>();
        let mut url = DEFAULT_RT_URL.to_owned();
        if let Some(temp_url) = rt_url {
            url = temp_url
        }
        let formatted_url = format!("{}?sm-sdk=rust-{}", url, VERSION);
        let sesh = Self {
            auth_token,
            rt_url: formatted_url,
            internal_message_sender: channel_sender,
        };
        Ok((sesh, channel_receiver))
    }

    /// connect is an internal function that handles the TCP handshake, TLS handshake and websocket handshake
    /// It ultimately returns the send and receive parts of the websocket.
    async fn connect(&mut self) -> Result<(SenderWrapper, SplitStreamAlias)> {
        let sec_key: String = thread_rng()
            .sample_iter(&Alphanumeric)
            .take(16)
            .map(char::from)
            .collect();
        let b64 = general_purpose::STANDARD.encode(sec_key);

        let uri = Url::parse(&self.rt_url)?;
        let authority = uri.authority();
        let host = authority
            .find('@')
            .map(|idx| authority.split_at(idx + 1).1)
            .unwrap_or_else(|| authority);

        if host.is_empty() {
            return Err(anyhow::Error::from(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "uri host was empty",
            )));
        }
        let auth_header = format!("Bearer {}", self.auth_token.clone());

        let req = Request::builder()
            .method("GET")
            .header("Host", host)
            .header("Connection", "keep-alive, Upgrade")
            .header("Upgrade", "websocket")
            .header("Sec-WebSocket-Version", "13")
            .header("Sec-WebSocket-Key", b64)
            .header("Authorization", auth_header)
            .uri(&self.rt_url)
            .body(())?;

        let (stream, res) = connect_async(req).await?;
        if let Some(resp) = res.body() {
            error!("failed to connect {:?}", resp);
        }

        let (writer, reader) = stream.split();
        let sender = SenderWrapper::new(writer);
        Ok((sender, reader))
    }

    /// Wait for start reads messages in a loop until one of a set of coniditions is met:
    /// 1. We receive RecognitionStarted, at which point the rt session begins in earnest
    /// 2. We receive an error, in which case we exit
    /// 3. We receive some other message, in which case we retry the function a set number of times
    async fn wait_for_start(
        &mut self,
        receiver: &mut SplitStreamAlias,
        channel_sender: &tokio::sync::mpsc::UnboundedSender<ReadMessage>,
    ) -> Result<()> {
        let mut retries = 0;
        let max_retries = 5;
        let mut success = false;
        while !success {
            let value = receiver.next().await;
            if let Some(val) = value {
                let message = match val {
                    Ok(v) => v,
                    Err(err) => {
                        warn!("Failed to get data from stream, {:?}", err);
                        retries += 1;
                        if retries > max_retries {
                            return Err(Into::into(std::io::Error::new(
                                std::io::ErrorKind::ConnectionAborted,
                                "Recognition failed to start on the server",
                            )));
                        }
                        continue;
                    }
                };
                debug!("{:?}", message);
                let bin_data = message.into_data();
                // this deserialise will fail if not the right message type
                match serde_json::from_slice::<models::RecognitionStarted>(&bin_data) {
                    Ok(mess) => {
                        success = true;
                        channel_sender.send(ReadMessage::RecognitionStarted(mess))?;
                    }
                    Err(err) => {
                        warn!(
                            "Could not read value of message into RecognitionStarted struct, {:?}",
                            err
                        );
                        match serde_json::from_slice::<models::Error>(&bin_data) {
                            Ok(mess) => {
                                return Err(Into::into(std::io::Error::new(
                                    std::io::ErrorKind::ConnectionAborted,
                                    format!("Received error from server {}", mess.reason),
                                )));
                            }
                            Err(_) => {
                                retries += 1;
                                if retries > max_retries {
                                    return Err(Into::into(std::io::Error::new(
                                        std::io::ErrorKind::ConnectionAborted,
                                        "Recognition failed to start on the server",
                                    )));
                                }
                                continue;
                            }
                        }
                    }
                };
            } else {
                return Err(Into::into(std::io::Error::new(
                    std::io::ErrorKind::TimedOut,
                    "Failed to receive message from the server",
                )));
            }
        }
        Ok(())
    }

    /// The main function of the RealtimeSession struct. It connects to the WebSocket,
    /// handles auth and sends a StartRecognition message to the websocket. It then waits until the server acknowledges
    /// the start of the session and then concurrently sends audio data and calls the user-registered handler functions.
    ///
    /// The config parameter sets the SessionConfig for the transcriber, including transcription, translation and audio source config.
    /// Although it is not yet, implemented, it is possible to update transcriber config on the fly.
    ///
    /// The reader parameter accepts anything that satisfies Read and Send e.g. a File, a BufReader, a Cursor.
    /// This allows the user to flexibly provide any audio source of their choice.
    ///
    /// # Example
    ///
    /// ```
    /// let api_key: String = std::env::var("API_KEY").unwrap();
    /// let mut rt_session =
    ///     RealtimeSession::new("YOUR_KEY", None).unwrap();
    ///
    /// let file = File::open("SOME_FILE_PATH").unwrap();
    ///
    /// let mut config: SessionConfig = Default::default();
    /// let audio_config = models::AudioFormat::new(models::audio_format::Type::File);
    /// config.audio_format = Some(audio_config);
    ///
    /// rt_session.run(config, file).await.unwrap();
    /// ```
    ///
    /// # Errors
    ///
    /// This function can fail in a number of ways:
    ///     - If the audio  read loop fails, the connection will be closed and the audio failure will be returned
    ///     - If the server sends an error message, this will be returned as an error and the audio read loop will stop
    ///     - If something goes wrong deserialising json or handling the local websocket, the error will be returned
    pub async fn run<R: AsyncReadExt + std::marker::Send + std::marker::Unpin + 'static>(
        &mut self,
        config: SessionConfig,
        reader: R,
    ) -> Result<(), anyhow::Error> {
        let (mut sock_sender, mut sock_receiver) = self.connect().await?;
        sock_sender.start_recognition(config).await?;
        self.wait_for_start(&mut sock_receiver, &self.internal_message_sender.clone())
            .await?;

        let sender = &self.internal_message_sender.clone();
        let process_messages = { RealtimeSession::process_messages(sock_receiver, sender) };
        let send_audio = { sock_sender.send_audio(reader) };

        pin_mut!(process_messages, send_audio);
        let (messages_res, audio_res) = join!(process_messages, send_audio);
        match audio_res {
            Ok(_) => debug!("No issues in audio processing task"),
            Err(err) => return Err(err),
        };
        match messages_res {
            Ok(_) => debug!("No issues detected whilst processing server-sent messages"),
            Err(err) => {
                error!("{:?}", err);
                return Err(err);
            }
        };
        Ok(())
    }

    async fn process_messages(
        mut receiver: SplitStreamAlias,
        channel_sender: &tokio::sync::mpsc::UnboundedSender<ReadMessage>,
    ) -> Result<()> {
        let mut running = true;
        while running {
            let result = receiver.next().await;
            if let Some(val) = result {
                let mess = val?;
                debug!("{}", mess);
                let data = mess.into_data();
                // Parse the string of data into serde_json::Value.
                let value = from_slice::<ReadMessage>(&data)?;
                match value {
                    ReadMessage::EndOfTranscript(mess) => {
                        debug!("detected EndOfTranscript message, quitting");
                        running = false;
                        channel_sender.send(ReadMessage::EndOfTranscript(mess))?;
                    }
                    ReadMessage::Error(mess) => {
                        channel_sender.send(ReadMessage::Error(mess.clone()))?;
                        error!("Received error from server {}", mess.reason);
                        return Err(Into::into(std::io::Error::new(
                            std::io::ErrorKind::ConnectionAborted,
                            format!("Received error from server {}", mess.reason),
                        )));
                    }
                    mess => channel_sender.send(mess)?,
                }
            } else {
                return Err(Into::into(std::io::Error::new(
                    std::io::ErrorKind::ConnectionAborted,
                    "Did not receive a message".to_string(),
                )));
            }
        }
        debug!("Exited message processing loop");
        Ok(())
    }
}

/// Owns the sending half of the realtime websocket. It also counts the audio chunks sent, so that
/// the final EndOfStream message can report `last_seq_no`.
struct SenderWrapper {
    /// Sink half of the split websocket stream.
    pub socket: SplitSink<
        WebSocketStream<MaybeTlsStream<TcpStream>>,
        tokio_tungstenite::tungstenite::Message,
    >,
    /// Number of binary audio chunks sent so far. Incremented after each successful audio send.
    last_seq_no: i32,
}

impl SenderWrapper {
    fn new(
        socket: SplitSink<
            WebSocketStream<MaybeTlsStream<TcpStream>>,
            tokio_tungstenite::tungstenite::Message,
        >,
    ) -> Self {
        Self {
            socket,
            last_seq_no: 0,
        }
    }

    async fn send_audio<R: AsyncReadExt + std::marker::Send + std::marker::Unpin + 'static>(
        &mut self,
        mut reader: R,
    ) -> Result<()> {
        let mut buffer = vec![0u8; 8192];
        loop {
            debug!("reading audio data");
            match reader.read(&mut buffer).await {
                Ok(no) => {
                    if no == 0 {
                        info!("Reader was empty, closing stream");
                        self.send_close(self.last_seq_no).await?;
                        return Ok(());
                    } else {
                        debug!("Sending audio length {no}");
                        let tu_message = Message::from(&buffer[..no]);
                        self.send_message(tu_message).await?;
                        self.last_seq_no += 1;
                    }
                }
                Err(_) => {
                    info!("encountered an error reading audio data, closing the stream");
                    self.send_close(self.last_seq_no).await?;
                }
            };
        }
    }

    async fn send_message(&mut self, message: Message) -> Result<()> {
        let mut retries = 0;
        let max_retries = 5;
        let mut success = false;
        while !success {
            match self.socket.send(message.clone()).await {
                Ok(()) => (),
                Err(err) => {
                    retries += 1;
                    if retries >= max_retries {
                        error!("{:?}", err);
                        self.socket.send(message).await?;
                        panic!("arg too many attempts to send")
                    }
                    std::thread::sleep(std::time::Duration::from_millis(100));
                    continue;
                }
            };
            success = true
        }
        Ok(())
    }

    async fn start_recognition(&mut self, config: SessionConfig) -> Result<()> {
        let mut message: models::StartRecognition = Default::default();
        if let Some(aud) = config.audio_format {
            message.audio_format = Box::new(aud);
        }
        message.transcription_config = Box::new(config.transcription_config);
        if let Some(transl) = config.translation_config {
            message.translation_config = Some(Box::new(transl));
        }
        let serialised_msg = serde_json::to_string(&message)?;
        let ws_message = Message::from(serialised_msg);
        debug!("sending StartRecognition message {:?}", ws_message);
        self.send_message(ws_message).await
    }

    async fn send_close(&mut self, last_seq_no: i32) -> Result<()> {
        let message =
            models::EndOfStream::new(last_seq_no, models::end_of_stream::Message::EndOfStream);
        let serialised_msg = serde_json::to_string(&message)?;
        let tungstenite_msg = Message::from(serialised_msg);
        self.send_message(tungstenite_msg).await
    }
}

#[cfg(test)]
mod tests {
    use crate::realtime::*;
    use std::{
        path::PathBuf,
        sync::{Arc, Mutex},
    };
    use tokio::{self, fs::File, try_join};

    /// In-memory sink that concatenates every final transcript it is given.
    struct MockStore {
        transcript: String,
    }

    impl MockStore {
        /// Creates an empty store.
        pub fn new() -> Self {
            Self {
                transcript: String::new(),
            }
        }

        /// Appends `transcript`, preceded by a single space.
        pub fn append(&mut self, transcript: String) {
            self.transcript.push(' ');
            self.transcript.push_str(&transcript);
        }

        /// Writes the accumulated transcript to stdout, without a trailing newline.
        pub fn print(&self) {
            print!("{}", self.transcript);
        }
    }

    /// End-to-end check against the live service. It streams `tests/data/example.wav` and collects
    /// the final transcripts. Requires the `API_KEY` environment variable and network access.
    #[tokio::test]
    async fn test_basic_flow() {
        let api_key = std::env::var("API_KEY").unwrap();
        let (mut session, mut receiver) = RealtimeSession::new(api_key, None).unwrap();

        let audio_path: PathBuf = [".", "tests", "data", "example.wav"].iter().collect();
        let audio_file = File::open(audio_path).await.unwrap();

        let mut config = SessionConfig::default();
        config.audio_format = Some(models::AudioFormat::new(
            models::audio_format::Type::File,
        ));

        let store = Arc::new(Mutex::new(MockStore::new()));
        let task_store = Arc::clone(&store);

        // Drain the channel until EndOfTranscript arrives or the channel closes.
        let collector = tokio::spawn(async move {
            while let Some(message) = receiver.recv().await {
                match message {
                    ReadMessage::EndOfTranscript(_) => return,
                    ReadMessage::AddTranscript(added) => {
                        task_store
                            .lock()
                            .unwrap()
                            .append(added.metadata.transcript);
                    }
                    _ => (),
                }
            }
        });

        let session_future = session.run(config, audio_file);
        let collector_future = async move { collector.await.map_err(anyhow::Error::from) };

        try_join!(collector_future, session_future).unwrap();

        store.lock().unwrap().print();
    }
}