// armature_websocket/client.rs
use crate::error::{WebSocketError, WebSocketResult};
use crate::message::Message;
use futures_util::{SinkExt, StreamExt};
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_tungstenite::{
    connect_async, connect_async_with_config,
    tungstenite::protocol::{Message as TungsteniteMessage, WebSocketConfig},
};
use url::Url;
11
/// Settings gathered before opening a [`WebSocketClient`].
///
/// Obtain one from [`WebSocketClient::builder`] or [`WebSocketClientBuilder::new`],
/// chain the setters, then call `connect`.
#[derive(Debug, Clone)]
pub struct WebSocketClientBuilder {
    /// Address to connect to; connecting without one is an error.
    url: Option<String>,
    /// Upper bound on how long the connection handshake may take.
    connect_timeout: Duration,
    /// Maximum message size requested through `max_message_size`, if any.
    max_message_size: Option<usize>,
}
19
20impl Default for WebSocketClientBuilder {
21 fn default() -> Self {
22 Self {
23 url: None,
24 connect_timeout: Duration::from_secs(30),
25 max_message_size: None,
26 }
27 }
28}
29
30impl WebSocketClientBuilder {
31 pub fn new() -> Self {
33 Self::default()
34 }
35
36 pub fn url<S: Into<String>>(mut self, url: S) -> Self {
38 self.url = Some(url.into());
39 self
40 }
41
42 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
44 self.connect_timeout = timeout;
45 self
46 }
47
48 pub fn max_message_size(mut self, size: usize) -> Self {
50 self.max_message_size = Some(size);
51 self
52 }
53
54 pub async fn connect(self) -> WebSocketResult<WebSocketClient> {
56 let url = self
57 .url
58 .ok_or_else(|| WebSocketError::InvalidUrl("URL not provided".to_string()))?;
59
60 WebSocketClient::connect_with_timeout(&url, self.connect_timeout).await
61 }
62}
63
64pub struct WebSocketClient {
66 tx: mpsc::UnboundedSender<Message>,
67 rx: mpsc::UnboundedReceiver<Message>,
68 closed: AtomicBool,
71}
72
73impl WebSocketClient {
74 pub fn builder() -> WebSocketClientBuilder {
76 WebSocketClientBuilder::new()
77 }
78
79 pub async fn connect(url: &str) -> WebSocketResult<Self> {
81 Self::connect_with_timeout(url, Duration::from_secs(30)).await
82 }
83
84 pub async fn connect_with_timeout(url: &str, timeout: Duration) -> WebSocketResult<Self> {
86 let url = Url::parse(url).map_err(|e| WebSocketError::InvalidUrl(e.to_string()))?;
87
88 let connect_future = connect_async(url.as_str());
89
90 let (ws_stream, _response) = tokio::time::timeout(timeout, connect_future)
91 .await
92 .map_err(|_| WebSocketError::Timeout)?
93 .map_err(WebSocketError::Protocol)?;
94
95 let (write, read) = ws_stream.split();
96
97 let (outgoing_tx, outgoing_rx) = mpsc::unbounded_channel::<Message>();
99 let (incoming_tx, incoming_rx) = mpsc::unbounded_channel::<Message>();
100
101 tokio::spawn(Self::writer_task(write, outgoing_rx));
103
104 tokio::spawn(Self::reader_task(read, incoming_tx));
106
107 Ok(Self {
108 tx: outgoing_tx,
109 rx: incoming_rx,
110 closed: AtomicBool::new(false),
111 })
112 }
113
114 async fn writer_task(
116 mut write: futures_util::stream::SplitSink<
117 tokio_tungstenite::WebSocketStream<
118 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
119 >,
120 TungsteniteMessage,
121 >,
122 mut rx: mpsc::UnboundedReceiver<Message>,
123 ) {
124 while let Some(message) = rx.recv().await {
125 let is_close = message.is_close();
126 let raw_message: TungsteniteMessage = message.into();
127
128 if write.send(raw_message).await.is_err() {
129 break;
130 }
131
132 if is_close {
133 break;
134 }
135 }
136
137 let _ = write.close().await;
138 }
139
140 async fn reader_task(
142 mut read: futures_util::stream::SplitStream<
143 tokio_tungstenite::WebSocketStream<
144 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
145 >,
146 >,
147 tx: mpsc::UnboundedSender<Message>,
148 ) {
149 while let Some(result) = read.next().await {
150 match result {
151 Ok(msg) => {
152 if msg.is_close() {
153 let _ = tx.send(Message::close());
154 break;
155 }
156
157 let message: Message = msg.into();
158 if tx.send(message).is_err() {
159 break;
160 }
161 }
162 Err(_) => {
163 break;
164 }
165 }
166 }
167 }
168
169 pub fn send(&self, message: Message) -> WebSocketResult<()> {
171 if self.closed.load(Ordering::Acquire) {
172 return Err(WebSocketError::ConnectionClosed);
173 }
174 self.tx
175 .send(message)
176 .map_err(|e| WebSocketError::Send(e.to_string()))
177 }
178
179 pub fn send_text<S: Into<String>>(&self, text: S) -> WebSocketResult<()> {
181 self.send(Message::text(text))
182 }
183
184 pub fn send_binary<B: Into<bytes::Bytes>>(&self, data: B) -> WebSocketResult<()> {
186 self.send(Message::binary(data))
187 }
188
189 pub fn send_json<T: serde::Serialize>(&self, value: &T) -> WebSocketResult<()> {
191 let message = Message::json(value)?;
192 self.send(message)
193 }
194
195 pub async fn recv(&mut self) -> Option<Message> {
197 self.rx.recv().await
198 }
199
200 pub fn try_recv(&mut self) -> Option<Message> {
202 self.rx.try_recv().ok()
203 }
204
205 pub fn close(&self) {
210 if self
212 .closed
213 .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
214 .is_ok()
215 {
216 let _ = self.tx.send(Message::close());
217 }
218 }
219
220 pub fn is_closed(&self) -> bool {
222 self.closed.load(Ordering::Acquire)
223 }
224}
225
226impl Drop for WebSocketClient {
227 fn drop(&mut self) {
228 self.close();
230 }
231}