azure_speech/connector/client.rs

1use futures_util::SinkExt;
2use std::time::Duration;
3use tokio::sync::{broadcast, mpsc, oneshot};
4use tokio_stream::wrappers::BroadcastStream;
5use tokio_stream::{Stream, StreamExt};
6use tokio_websockets::{self, ClientBuilder, MaybeTlsStream, WebSocketStream};
7
8#[async_trait::async_trait]
9trait Connector {
10    async fn connect_stream(&self) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, tokio_websockets::Error>;
11}
12
13#[async_trait::async_trait]
14impl Connector for ClientBuilder<'static> {
15    async fn connect_stream(&self) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, tokio_websockets::Error> {
16        Ok(self.connect().await?.0)
17    }
18}
19
20async fn reconnect_with_attempts<C: Connector>(
21    client: &C,
22    attempts: usize,
23) -> crate::Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>> {
24    let mut last_error = None;
25    for i in 0..attempts {
26        tracing::debug!("Reconnecting ({}/{})", i + 1, attempts);
27        match client.connect_stream().await {
28            Ok(stream) => return Ok(stream),
29            Err(e) => {
30                tracing::error!("Failed to reconnect ({}/{}): {}", i + 1, attempts, e);
31                last_error.replace(e);
32            }
33        }
34    }
35
36    Err(crate::Error::ConnectionError(
37        last_error
38            .map(|e| e.to_string())
39            .unwrap_or_else(|| "reconnect failed".to_string()),
40    ))
41}
42
43enum InternalMessage {
44    SendMessage(tokio_websockets::Message),
45    Subscribe(
46        oneshot::Sender<
47            crate::Result<broadcast::Receiver<crate::Result<tokio_websockets::Message>>>,
48        >,
49    ),
50    Disconnect,
51}
52
53#[derive(Clone)]
54pub struct Client {
55    channel: mpsc::Sender<InternalMessage>,
56}
57
58impl Client {
59    /// Create a new client.
60    fn new(channel: mpsc::Sender<InternalMessage>) -> Self {
61        Self { channel }
62    }
63}
64
65impl Client {
66    pub async fn send(&self, message: tokio_websockets::Message) -> crate::Result<()> {
67        self.channel
68            .send(InternalMessage::SendMessage(message))
69            .await?;
70        Ok(())
71    }
72
73    /// Send a text message to the server.
74    pub async fn send_text(&self, text: impl Into<String>) -> crate::Result<()> {
75        self.channel
76            .send(InternalMessage::SendMessage(
77                tokio_websockets::Message::text(text.into()),
78            ))
79            .await?;
80        Ok(())
81    }
82
83    /// Send a binary message to the server.
84    pub async fn send_binary(&self, bytes: impl Into<Vec<u8>>) -> crate::Result<()> {
85        self.channel
86            .send(InternalMessage::SendMessage(
87                tokio_websockets::Message::binary(bytes.into()),
88            ))
89            .await?;
90        Ok(())
91    }
92
93    /// Stream messages from the server.
94    pub async fn stream(&self) -> crate::Result<impl Stream<Item = crate::Result<crate::Message>>> {
95        let (sender, receiver) = oneshot::channel();
96        self.channel
97            .send(InternalMessage::Subscribe(sender))
98            .await?;
99
100        let br = BroadcastStream::new(receiver.await.map_err(|_| {
101            crate::Error::InternalError("Failed to subscribe to messages".to_string())
102        })??)
103        .timeout(Duration::from_secs(30));
104
105        let br = Box::pin(br);
106
107        let br = br
108            .map(move |m| {
109                tracing::trace!("Downstream message: {:?}", m);
110                m
111            })
112            .filter_map(move |message| match message {
113                Ok(message) => message.ok(),
114                Err(_e) => Some(Err(crate::Error::Timeout)),
115            })
116            .map(move |message| {
117                message.and_then(|msg| {
118                    crate::Message::try_from(msg)
119                        .map_err(|e| crate::Error::InternalError(e.to_string()))
120                })
121            })
122            .map(move |m| m);
123
124        Ok(br)
125    }
126}
127
128impl Client {
129    pub async fn connect(client: ClientBuilder<'static>) -> crate::Result<Self> {
130        let (mut stream, _res) = client.connect().await?;
131        let (sender, mut receiver) = mpsc::channel(16);
132        tokio::spawn(async move {
133            let (broadcaster, _) = broadcast::channel(32);
134            let mut connected = true;
135            loop {
136                tokio::select! {
137                    msg = receiver.recv() => {
138                        let Some(msg) = msg else {
139                            // Receiving `None` here means the client has been dropped, so the task should stop as well.
140                            break;
141                        };
142                        match msg {
143                            InternalMessage::SendMessage(msg) => {
144                                tracing::trace!("Upstream message: {:?}", msg.as_text());
145                                let _ = stream.send(msg).await;
146                            },
147                            InternalMessage::Subscribe(c) => {
148                                if !connected {
149                                    match reconnect_with_attempts(&client, 3).await {
150                                        Ok(new_stream) => {
151                                            connected = true;
152                                            stream = new_stream;
153                                        }
154                                        Err(err) => {
155                                            let _ = c.send(Err(err));
156                                            continue;
157                                        }
158                                    }
159                                }
160
161                                let _ = c.send(Ok(broadcaster.subscribe()));
162                            },
163                            InternalMessage::Disconnect => {
164                                let _ = stream.close().await;
165                                break;
166                            }
167                        }
168                    }
169                    msg = stream.next(), if connected => {
170                        let Some(msg) = msg else {
171                            // Receiving `None` here means the socket has been disconnected and can no longer receive messages.
172                            // We set `connected` to false just to make sure that the stream isn't polled again until we're reconnected.
173                            connected = false;
174                            continue;
175                        };
176                        match msg {
177                            Ok(msg) => {
178
179                                if msg.is_text() || msg.is_binary() {
180                                    let _ = broadcaster.send(Ok(msg.clone()));
181                                } else if msg.is_close() {
182                                    connected = false;
183
184                                    let close = msg.as_close().unwrap();
185                                    let _ = broadcaster.send(Err(crate::Error::ServerDisconnect(format!("{:?}", close))));
186                                    tracing::warn!(reason = ?close.0, msg = close.1, "disconnected from server");
187                                }
188                            },
189                            Err(e) => {
190                                tracing::warn!(?e, "connection errored");
191                                let _ = broadcaster.send(Err(e.into()));
192                                connected = false;
193                            }
194                        }
195                    }
196                }
197            }
198        });
199        Ok(Client::new(sender))
200    }
201
202    /// Disconnect the client.
203pub(crate) async fn disconnect(&self) -> crate::Result<()> {
204        self.channel.send(InternalMessage::Disconnect).await?;
205        // await the client to disconnect.
206        self.channel.closed().await;
207        Ok(())
208    }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Connector that fails its first `fail_times` calls with an I/O error and afterwards
    /// returns a websocket wrapped around a loopback TCP connection. `calls` counts attempts.
    struct MockConnector {
        fail_times: usize,
        calls: AtomicUsize,
    }

    impl MockConnector {
        /// A connector whose first `fail_times` attempts fail.
        fn failing(fail_times: usize) -> Self {
            Self {
                fail_times,
                calls: AtomicUsize::new(0),
            }
        }

        /// Number of `connect_stream` calls made so far.
        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl Connector for MockConnector {
        async fn connect_stream(
            &self,
        ) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, tokio_websockets::Error>
        {
            let attempt = self.calls.fetch_add(1, Ordering::SeqCst);
            if attempt < self.fail_times {
                Err(tokio_websockets::Error::Io(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "fail",
                )))
            } else {
                let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();
                // Accept in the background so the connect below completes; the server side
                // is dropped immediately.
                tokio::spawn(async move {
                    let _ = listener.accept().await;
                });
                let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
                Ok(ClientBuilder::new().take_over(MaybeTlsStream::Plain(stream)))
            }
        }
    }

    #[tokio::test]
    async fn reconnect_helper_succeeds_after_retries() {
        let connector = MockConnector::failing(2);
        let _ = reconnect_with_attempts(&connector, 3)
            .await
            .expect("should connect");
        assert_eq!(connector.calls(), 3);
    }

    #[tokio::test]
    async fn reconnect_helper_fails_after_max_attempts() {
        let connector = MockConnector::failing(5);
        let res = reconnect_with_attempts(&connector, 3).await;
        assert!(res.is_err());
        assert_eq!(connector.calls(), 3);
    }

    #[tokio::test]
    async fn reconnect_helper_with_zero_attempts_never_connects() {
        let connector = MockConnector::failing(0);
        let res = reconnect_with_attempts(&connector, 0).await;
        assert!(matches!(
            &res,
            Err(crate::Error::ConnectionError(reason)) if reason == "reconnect failed"
        ));
        assert_eq!(connector.calls(), 0);
    }
}