use anyhow::Result;
use base64::{engine::general_purpose, Engine as _};
use futures::{
pin_mut,
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use http::Request;
use rand::distributions::Alphanumeric;
use rand::{thread_rng, Rng};
use serde_json::from_slice;
use std::boxed::Box;
use tokio::{
io::AsyncReadExt,
join,
net::TcpStream,
sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
};
use tokio_tungstenite::MaybeTlsStream;
use tokio_tungstenite::{connect_async, tungstenite::Message, WebSocketStream};
use url::Url;
// In test builds the logging macros are aliased to `println!`, so `debug!`,
// `info!`, `warn!` and `error!` calls in this module print to stdout instead
// of going through the `log` crate.
#[cfg(test)]
use std::{println as error, println as debug, println as info, println as warn};
#[cfg(not(test))]
use log::{debug, error, info, warn};
/// Message and configuration types used by the realtime client in this module.
#[allow(missing_docs)]
pub mod models;
/// WebSocket endpoint used by [`RealtimeSession::new`] when the caller does not supply a URL.
pub const DEFAULT_RT_URL: &str = "wss://neu.rt.speechmatics.com/v2/en";
/// Language code that [`SessionConfig::default`] assigns to the transcription config.
pub const DEFAULT_LANGUAGE: &str = "en";
// Crate version captured at compile time; `RealtimeSession::new` appends it to
// the connection URL as the `sm-sdk=rust-<version>` query parameter.
const VERSION: &str = env!("CARGO_PKG_VERSION");
/// A message read from the server and forwarded to the caller over the
/// session's channel.
///
/// The enum is deserialised with `#[serde(untagged)]`: serde tries the
/// variants in declaration order and keeps the first one whose payload type
/// accepts the JSON, so the order of the variants below is significant.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum ReadMessage {
/// Forwarded by the session once the server's `RecognitionStarted` payload has been parsed.
RecognitionStarted(models::RecognitionStarted),
/// Wraps a [`models::Info`] payload.
Info(models::Info),
/// Wraps a [`models::Warning`] payload.
Warning(models::Warning),
/// Wraps a [`models::Error`] payload; the session forwards it and then stops with an error.
Error(models::Error),
/// Wraps a [`models::AddPartialTranscript`] payload.
AddPartialTranscript(models::AddPartialTranscript),
/// Wraps a [`models::AddTranscript`] payload.
AddTranscript(models::AddTranscript),
/// Wraps a [`models::AddPartialTranslation`] payload.
AddPartialTranslation(models::AddPartialTranslation),
/// Wraps a [`models::AddTranslation`] payload.
AddTranslation(models::AddTranslation),
/// Wraps a [`models::AudioAdded`] payload.
AudioAdded(models::AudioAdded),
/// Wraps a [`models::EndOfTranscript`] payload; receiving it ends message processing.
EndOfTranscript(models::EndOfTranscript),
}
/// Settings used to build the `StartRecognition` message that opens a session.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SessionConfig {
/// Transcription settings; always sent.
pub transcription_config: models::TranscriptionConfig,
/// Optional translation settings; when `None` the `StartRecognition` default is left untouched.
pub translation_config: Option<models::TranslationConfig>,
/// Optional audio format; when `None` the `StartRecognition` default is left untouched.
pub audio_format: Option<models::AudioFormat>,
}
impl SessionConfig {
    /// Builds a session configuration from optional parts.
    ///
    /// When `transcription_config` is `None`, a default transcription config
    /// with its language set to `"en"` is used. The translation config and the
    /// audio format are stored exactly as given.
    pub fn new(
        transcription_config: Option<models::TranscriptionConfig>,
        translation_config: Option<models::TranslationConfig>,
        audio_format: Option<models::AudioFormat>,
    ) -> Self {
        let transcription_config = match transcription_config {
            Some(provided) => provided,
            None => {
                let mut fallback = models::TranscriptionConfig::default();
                fallback.language = "en".to_string();
                fallback
            }
        };

        Self {
            transcription_config,
            translation_config,
            audio_format,
        }
    }
}
impl Default for SessionConfig {
    /// Returns a configuration in which every section is populated: a default
    /// transcription config using [`DEFAULT_LANGUAGE`], plus default
    /// translation and audio-format settings.
    fn default() -> Self {
        let mut transcription_config = models::TranscriptionConfig::default();
        transcription_config.language = DEFAULT_LANGUAGE.to_owned();

        Self {
            transcription_config,
            translation_config: Some(models::TranslationConfig::default()),
            audio_format: Some(models::AudioFormat::default()),
        }
    }
}
// Read half of the WebSocket connection, as produced by `StreamExt::split`.
type SplitStreamAlias = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;
/// A realtime transcription session: connects to the server, streams audio
/// and forwards the server's messages over a channel.
pub struct RealtimeSession {
// Token sent as `Authorization: Bearer <token>` during the WebSocket handshake.
auth_token: String,
// Connection URL, including the `sm-sdk` query parameter added in `new`.
rt_url: String,
// Sending half of the channel whose receiver is handed to the caller by `new`.
internal_message_sender: UnboundedSender<ReadMessage>,
}
impl RealtimeSession {
pub fn new(
auth_token: String,
rt_url: Option<String>,
) -> Result<(Self, UnboundedReceiver<ReadMessage>)> {
let (channel_sender, channel_receiver) = unbounded_channel::<ReadMessage>();
let mut url = DEFAULT_RT_URL.to_owned();
if let Some(temp_url) = rt_url {
url = temp_url
}
let formatted_url = format!("{}?sm-sdk=rust-{}", url, VERSION);
let sesh = Self {
auth_token,
rt_url: formatted_url,
internal_message_sender: channel_sender,
};
Ok((sesh, channel_receiver))
}
async fn connect(&mut self) -> Result<(SenderWrapper, SplitStreamAlias)> {
let sec_key: String = thread_rng()
.sample_iter(&Alphanumeric)
.take(16)
.map(char::from)
.collect();
let b64 = general_purpose::STANDARD.encode(sec_key);
let uri = Url::parse(&self.rt_url)?;
let authority = uri.authority();
let host = authority
.find('@')
.map(|idx| authority.split_at(idx + 1).1)
.unwrap_or_else(|| authority);
if host.is_empty() {
return Err(anyhow::Error::from(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"uri host was empty",
)));
}
let auth_header = format!("Bearer {}", self.auth_token.clone());
let req = Request::builder()
.method("GET")
.header("Host", host)
.header("Connection", "keep-alive, Upgrade")
.header("Upgrade", "websocket")
.header("Sec-WebSocket-Version", "13")
.header("Sec-WebSocket-Key", b64)
.header("Authorization", auth_header)
.uri(&self.rt_url)
.body(())?;
let (stream, res) = connect_async(req).await?;
if let Some(resp) = res.body() {
error!("failed to connect {:?}", resp);
}
let (writer, reader) = stream.split();
let sender = SenderWrapper::new(writer);
Ok((sender, reader))
}
async fn wait_for_start(
&mut self,
receiver: &mut SplitStreamAlias,
channel_sender: &tokio::sync::mpsc::UnboundedSender<ReadMessage>,
) -> Result<()> {
let mut retries = 0;
let max_retries = 5;
let mut success = false;
while !success {
let value = receiver.next().await;
if let Some(val) = value {
let message = match val {
Ok(v) => v,
Err(err) => {
warn!("Failed to get data from stream, {:?}", err);
retries += 1;
if retries > max_retries {
return Err(Into::into(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
"Recognition failed to start on the server",
)));
}
continue;
}
};
debug!("{:?}", message);
let bin_data = message.into_data();
match serde_json::from_slice::<models::RecognitionStarted>(&bin_data) {
Ok(mess) => {
success = true;
channel_sender.send(ReadMessage::RecognitionStarted(mess))?;
}
Err(err) => {
warn!(
"Could not read value of message into RecognitionStarted struct, {:?}",
err
);
match serde_json::from_slice::<models::Error>(&bin_data) {
Ok(mess) => {
return Err(Into::into(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
format!("Received error from server {}", mess.reason),
)));
}
Err(_) => {
retries += 1;
if retries > max_retries {
return Err(Into::into(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
"Recognition failed to start on the server",
)));
}
continue;
}
}
}
};
} else {
return Err(Into::into(std::io::Error::new(
std::io::ErrorKind::TimedOut,
"Failed to receive message from the server",
)));
}
}
Ok(())
}
pub async fn run<R: AsyncReadExt + std::marker::Send + std::marker::Unpin + 'static>(
&mut self,
config: SessionConfig,
reader: R,
) -> Result<(), anyhow::Error> {
let (mut sock_sender, mut sock_receiver) = self.connect().await?;
sock_sender.start_recognition(config).await?;
self.wait_for_start(&mut sock_receiver, &self.internal_message_sender.clone())
.await?;
let sender = &self.internal_message_sender.clone();
let process_messages = { RealtimeSession::process_messages(sock_receiver, sender) };
let send_audio = { sock_sender.send_audio(reader) };
pin_mut!(process_messages, send_audio);
let (messages_res, audio_res) = join!(process_messages, send_audio);
match audio_res {
Ok(_) => debug!("No issues in audio processing task"),
Err(err) => return Err(err),
};
match messages_res {
Ok(_) => debug!("No issues detected whilst processing server-sent messages"),
Err(err) => {
error!("{:?}", err);
return Err(err);
}
};
Ok(())
}
async fn process_messages(
mut receiver: SplitStreamAlias,
channel_sender: &tokio::sync::mpsc::UnboundedSender<ReadMessage>,
) -> Result<()> {
let mut running = true;
while running {
let result = receiver.next().await;
if let Some(val) = result {
let mess = val?;
debug!("{}", mess);
let data = mess.into_data();
let value = from_slice::<ReadMessage>(&data)?;
match value {
ReadMessage::EndOfTranscript(mess) => {
debug!("detected EndOfTranscript message, quitting");
running = false;
channel_sender.send(ReadMessage::EndOfTranscript(mess))?;
}
ReadMessage::Error(mess) => {
channel_sender.send(ReadMessage::Error(mess.clone()))?;
error!("Received error from server {}", mess.reason);
return Err(Into::into(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
format!("Received error from server {}", mess.reason),
)));
}
mess => channel_sender.send(mess)?,
}
} else {
return Err(Into::into(std::io::Error::new(
std::io::ErrorKind::ConnectionAborted,
"Did not receive a message".to_string(),
)));
}
}
debug!("Exited message processing loop");
Ok(())
}
}
// Write half of the WebSocket connection plus the bookkeeping needed to close
// the audio stream.
struct SenderWrapper {
// Sink used to send messages to the server.
pub socket: SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
// Number of audio messages sent so far; reported in the `EndOfStream` message.
last_seq_no: i32,
}
impl SenderWrapper {
fn new(
socket: SplitSink<
WebSocketStream<MaybeTlsStream<TcpStream>>,
tokio_tungstenite::tungstenite::Message,
>,
) -> Self {
Self {
socket,
last_seq_no: 0,
}
}
async fn send_audio<R: AsyncReadExt + std::marker::Send + std::marker::Unpin + 'static>(
&mut self,
mut reader: R,
) -> Result<()> {
let mut buffer = vec![0u8; 8192];
loop {
debug!("reading audio data");
match reader.read(&mut buffer).await {
Ok(no) => {
if no == 0 {
info!("Reader was empty, closing stream");
self.send_close(self.last_seq_no).await?;
return Ok(());
} else {
debug!("Sending audio length {no}");
let tu_message = Message::from(&buffer[..no]);
self.send_message(tu_message).await?;
self.last_seq_no += 1;
}
}
Err(_) => {
info!("encountered an error reading audio data, closing the stream");
self.send_close(self.last_seq_no).await?;
}
};
}
}
async fn send_message(&mut self, message: Message) -> Result<()> {
let mut retries = 0;
let max_retries = 5;
let mut success = false;
while !success {
match self.socket.send(message.clone()).await {
Ok(()) => (),
Err(err) => {
retries += 1;
if retries >= max_retries {
error!("{:?}", err);
self.socket.send(message).await?;
panic!("arg too many attempts to send")
}
std::thread::sleep(std::time::Duration::from_millis(100));
continue;
}
};
success = true
}
Ok(())
}
async fn start_recognition(&mut self, config: SessionConfig) -> Result<()> {
let mut message: models::StartRecognition = Default::default();
if let Some(aud) = config.audio_format {
message.audio_format = Box::new(aud);
}
message.transcription_config = Box::new(config.transcription_config);
if let Some(transl) = config.translation_config {
message.translation_config = Some(Box::new(transl));
}
let serialised_msg = serde_json::to_string(&message)?;
let ws_message = Message::from(serialised_msg);
debug!("sending StartRecognition message {:?}", ws_message);
self.send_message(ws_message).await
}
async fn send_close(&mut self, last_seq_no: i32) -> Result<()> {
let message =
models::EndOfStream::new(last_seq_no, models::end_of_stream::Message::EndOfStream);
let serialised_msg = serde_json::to_string(&message)?;
let tungstenite_msg = Message::from(serialised_msg);
self.send_message(tungstenite_msg).await
}
}
#[cfg(test)]
mod tests {
    use crate::realtime::*;
    use std::{
        path::PathBuf,
        sync::{Arc, Mutex},
    };
    use tokio::{self, fs::File, try_join};

    /// Collects the transcript fragments received during a test run.
    struct MockStore {
        transcript: String,
    }

    impl MockStore {
        /// Starts with an empty transcript.
        pub fn new() -> Self {
            Self {
                transcript: String::new(),
            }
        }

        /// Appends a fragment, preceded by a single space.
        pub fn append(&mut self, transcript: String) {
            self.transcript.push(' ');
            self.transcript.push_str(&transcript);
        }

        /// Prints everything collected so far.
        pub fn print(&self) {
            print!("{}", self.transcript)
        }
    }

    /// End-to-end run against the live service: needs the `API_KEY`
    /// environment variable and `./tests/data/example.wav`.
    #[tokio::test]
    async fn test_basic_flow() {
        let api_key: String = std::env::var("API_KEY").unwrap();
        let (mut rt_session, mut receive_channel) = RealtimeSession::new(api_key, None).unwrap();

        let test_file_path: PathBuf = [".", "tests", "data", "example.wav"].iter().collect();
        let file = File::open(test_file_path).await.unwrap();

        let config = SessionConfig {
            audio_format: Some(models::AudioFormat::new(models::audio_format::Type::File)),
            ..Default::default()
        };

        let mock_store = Arc::new(Mutex::new(MockStore::new()));
        let store_for_task = Arc::clone(&mock_store);

        // Gather final transcripts until the server signals the end (or the
        // channel closes).
        let message_task = tokio::spawn(async move {
            loop {
                match receive_channel.recv().await {
                    Some(ReadMessage::AddTranscript(mess)) => {
                        store_for_task
                            .lock()
                            .unwrap()
                            .append(mess.metadata.transcript);
                    }
                    Some(ReadMessage::EndOfTranscript(_)) | None => return,
                    Some(_) => {}
                }
            }
        });

        let collector = async move { message_task.await.map_err(anyhow::Error::from) };
        let run_task = rt_session.run(config, file);
        try_join!(collector, run_task).unwrap();

        mock_store.lock().unwrap().print();
    }
}