azure_speech/connector/
client.rs1use futures_util::SinkExt;
2use std::time::Duration;
3use tokio::sync::{broadcast, mpsc, oneshot};
4use tokio_stream::wrappers::BroadcastStream;
5use tokio_stream::{Stream, StreamExt};
6use tokio_websockets::{self, ClientBuilder, MaybeTlsStream, WebSocketStream};
7
8#[async_trait::async_trait]
9trait Connector {
10 async fn connect_stream(&self) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, tokio_websockets::Error>;
11}
12
13#[async_trait::async_trait]
14impl Connector for ClientBuilder<'static> {
15 async fn connect_stream(&self) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, tokio_websockets::Error> {
16 Ok(self.connect().await?.0)
17 }
18}
19
20async fn reconnect_with_attempts<C: Connector>(
21 client: &C,
22 attempts: usize,
23) -> crate::Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>> {
24 let mut last_error = None;
25 for i in 0..attempts {
26 tracing::debug!("Reconnecting ({}/{})", i + 1, attempts);
27 match client.connect_stream().await {
28 Ok(stream) => return Ok(stream),
29 Err(e) => {
30 tracing::error!("Failed to reconnect ({}/{}): {}", i + 1, attempts, e);
31 last_error.replace(e);
32 }
33 }
34 }
35
36 Err(crate::Error::ConnectionError(
37 last_error
38 .map(|e| e.to_string())
39 .unwrap_or_else(|| "reconnect failed".to_string()),
40 ))
41}
42
43enum InternalMessage {
44 SendMessage(tokio_websockets::Message),
45 Subscribe(
46 oneshot::Sender<
47 crate::Result<broadcast::Receiver<crate::Result<tokio_websockets::Message>>>,
48 >,
49 ),
50 Disconnect,
51}
52
53#[derive(Clone)]
54pub struct Client {
55 channel: mpsc::Sender<InternalMessage>,
56}
57
58impl Client {
59 fn new(channel: mpsc::Sender<InternalMessage>) -> Self {
61 Self { channel }
62 }
63}
64
65impl Client {
66 pub async fn send(&self, message: tokio_websockets::Message) -> crate::Result<()> {
67 self.channel
68 .send(InternalMessage::SendMessage(message))
69 .await?;
70 Ok(())
71 }
72
73 pub async fn send_text(&self, text: impl Into<String>) -> crate::Result<()> {
75 self.channel
76 .send(InternalMessage::SendMessage(
77 tokio_websockets::Message::text(text.into()),
78 ))
79 .await?;
80 Ok(())
81 }
82
83 pub async fn send_binary(&self, bytes: impl Into<Vec<u8>>) -> crate::Result<()> {
85 self.channel
86 .send(InternalMessage::SendMessage(
87 tokio_websockets::Message::binary(bytes.into()),
88 ))
89 .await?;
90 Ok(())
91 }
92
93 pub async fn stream(&self) -> crate::Result<impl Stream<Item = crate::Result<crate::Message>>> {
95 let (sender, receiver) = oneshot::channel();
96 self.channel
97 .send(InternalMessage::Subscribe(sender))
98 .await?;
99
100 let br = BroadcastStream::new(receiver.await.map_err(|_| {
101 crate::Error::InternalError("Failed to subscribe to messages".to_string())
102 })??)
103 .timeout(Duration::from_secs(30));
104
105 let br = Box::pin(br);
106
107 let br = br
108 .map(move |m| {
109 tracing::trace!("Downstream message: {:?}", m);
110 m
111 })
112 .filter_map(move |message| match message {
113 Ok(message) => message.ok(),
114 Err(_e) => Some(Err(crate::Error::Timeout)),
115 })
116 .map(move |message| {
117 message.and_then(|msg| {
118 crate::Message::try_from(msg)
119 .map_err(|e| crate::Error::InternalError(e.to_string()))
120 })
121 })
122 .map(move |m| m);
123
124 Ok(br)
125 }
126}
127
128impl Client {
129 pub async fn connect(client: ClientBuilder<'static>) -> crate::Result<Self> {
130 let (mut stream, _res) = client.connect().await?;
131 let (sender, mut receiver) = mpsc::channel(16);
132 tokio::spawn(async move {
133 let (broadcaster, _) = broadcast::channel(32);
134 let mut connected = true;
135 loop {
136 tokio::select! {
137 msg = receiver.recv() => {
138 let Some(msg) = msg else {
139 break;
141 };
142 match msg {
143 InternalMessage::SendMessage(msg) => {
144 tracing::trace!("Upstream message: {:?}", msg.as_text());
145 let _ = stream.send(msg).await;
146 },
147 InternalMessage::Subscribe(c) => {
148 if !connected {
149 match reconnect_with_attempts(&client, 3).await {
150 Ok(new_stream) => {
151 connected = true;
152 stream = new_stream;
153 }
154 Err(err) => {
155 let _ = c.send(Err(err));
156 continue;
157 }
158 }
159 }
160
161 let _ = c.send(Ok(broadcaster.subscribe()));
162 },
163 InternalMessage::Disconnect => {
164 let _ = stream.close().await;
165 break;
166 }
167 }
168 }
169 msg = stream.next(), if connected => {
170 let Some(msg) = msg else {
171 connected = false;
174 continue;
175 };
176 match msg {
177 Ok(msg) => {
178
179 if msg.is_text() || msg.is_binary() {
180 let _ = broadcaster.send(Ok(msg.clone()));
181 } else if msg.is_close() {
182 connected = false;
183
184 let close = msg.as_close().unwrap();
185 let _ = broadcaster.send(Err(crate::Error::ServerDisconnect(format!("{:?}", close))));
186 tracing::warn!(reason = ?close.0, msg = close.1, "disconnected from server");
187 }
188 },
189 Err(e) => {
190 tracing::warn!(?e, "connection errored");
191 let _ = broadcaster.send(Err(e.into()));
192 connected = false;
193 }
194 }
195 }
196 }
197 }
198 });
199 Ok(Client::new(sender))
200 }
201
202 pub(crate) async fn disconnect(&self) -> crate::Result<()> {
204 self.channel.send(InternalMessage::Disconnect).await?;
205 self.channel.closed().await;
207 Ok(())
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double whose first `fail_times` connection attempts fail with an I/O error. Later
    /// attempts return a WebSocket stream over a local loopback TCP connection.
    struct MockConnector {
        fail_times: usize,
        calls: AtomicUsize,
    }

    impl MockConnector {
        /// Creates a connector that fails `fail_times` times before succeeding.
        fn failing(fail_times: usize) -> Self {
            Self {
                fail_times,
                calls: AtomicUsize::new(0),
            }
        }

        /// Number of `connect_stream` calls made so far.
        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl Connector for MockConnector {
        async fn connect_stream(
            &self,
        ) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>, tokio_websockets::Error>
        {
            let attempt = self.calls.fetch_add(1, Ordering::SeqCst);
            if attempt < self.fail_times {
                Err(tokio_websockets::Error::Io(std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "fail",
                )))
            } else {
                let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
                let addr = listener.local_addr().unwrap();
                tokio::spawn(async move {
                    let _ = listener.accept().await;
                });
                let stream = tokio::net::TcpStream::connect(addr).await.unwrap();
                Ok(ClientBuilder::new().take_over(MaybeTlsStream::Plain(stream)))
            }
        }
    }

    #[tokio::test]
    async fn reconnect_helper_succeeds_after_retries() {
        let connector = MockConnector::failing(2);
        let _ = reconnect_with_attempts(&connector, 3)
            .await
            .expect("should connect");
        assert_eq!(connector.calls(), 3);
    }

    #[tokio::test]
    async fn reconnect_helper_fails_after_max_attempts() {
        let connector = MockConnector::failing(5);
        let res = reconnect_with_attempts(&connector, 3).await;
        assert!(matches!(res, Err(crate::Error::ConnectionError(_))));
        assert_eq!(connector.calls(), 3);
    }

    #[tokio::test]
    async fn reconnect_helper_with_zero_attempts_never_connects() {
        let connector = MockConnector::failing(0);
        let res = reconnect_with_attempts(&connector, 0).await;
        assert!(matches!(res, Err(crate::Error::ConnectionError(_))));
        assert_eq!(connector.calls(), 0);
    }
}