armature_websocket/
client.rs

1//! WebSocket client implementation.
2
3use crate::error::{WebSocketError, WebSocketResult};
4use crate::message::Message;
5use futures_util::{SinkExt, StreamExt};
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::time::Duration;
8use tokio::sync::mpsc;
9use tokio_tungstenite::{connect_async, tungstenite::protocol::Message as TungsteniteMessage};
10use url::Url;
11
/// Collects connection settings before opening a [`WebSocketClient`].
///
/// Obtain one via [`WebSocketClient::builder`] or `WebSocketClientBuilder::new`,
/// chain the setter methods, then call `connect`.
#[derive(Clone, Debug)]
pub struct WebSocketClientBuilder {
    /// Server URL; must be set before connecting.
    url: Option<String>,
    /// Maximum time allowed for establishing the connection.
    connect_timeout: Duration,
    /// Requested maximum message size in bytes, if any.
    max_message_size: Option<usize>,
}
19
20impl Default for WebSocketClientBuilder {
21    fn default() -> Self {
22        Self {
23            url: None,
24            connect_timeout: Duration::from_secs(30),
25            max_message_size: None,
26        }
27    }
28}
29
30impl WebSocketClientBuilder {
31    /// Create a new client builder.
32    pub fn new() -> Self {
33        Self::default()
34    }
35
36    /// Set the WebSocket URL.
37    pub fn url<S: Into<String>>(mut self, url: S) -> Self {
38        self.url = Some(url.into());
39        self
40    }
41
42    /// Set the connection timeout.
43    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
44        self.connect_timeout = timeout;
45        self
46    }
47
48    /// Set the maximum message size.
49    pub fn max_message_size(mut self, size: usize) -> Self {
50        self.max_message_size = Some(size);
51        self
52    }
53
54    /// Connect to the WebSocket server.
55    pub async fn connect(self) -> WebSocketResult<WebSocketClient> {
56        let url = self
57            .url
58            .ok_or_else(|| WebSocketError::InvalidUrl("URL not provided".to_string()))?;
59
60        WebSocketClient::connect_with_timeout(&url, self.connect_timeout).await
61    }
62}
63
64/// WebSocket client for connecting to WebSocket servers.
65pub struct WebSocketClient {
66    tx: mpsc::UnboundedSender<Message>,
67    rx: mpsc::UnboundedReceiver<Message>,
68    /// Thread-safe closed flag using AtomicBool to prevent data races
69    /// between send() and close() when client is shared across tasks.
70    closed: AtomicBool,
71}
72
73impl WebSocketClient {
74    /// Create a new client builder.
75    pub fn builder() -> WebSocketClientBuilder {
76        WebSocketClientBuilder::new()
77    }
78
79    /// Connect to a WebSocket server.
80    pub async fn connect(url: &str) -> WebSocketResult<Self> {
81        Self::connect_with_timeout(url, Duration::from_secs(30)).await
82    }
83
84    /// Connect to a WebSocket server with a timeout.
85    pub async fn connect_with_timeout(url: &str, timeout: Duration) -> WebSocketResult<Self> {
86        let url = Url::parse(url).map_err(|e| WebSocketError::InvalidUrl(e.to_string()))?;
87
88        let connect_future = connect_async(url.as_str());
89
90        let (ws_stream, _response) = tokio::time::timeout(timeout, connect_future)
91            .await
92            .map_err(|_| WebSocketError::Timeout)?
93            .map_err(WebSocketError::Protocol)?;
94
95        let (write, read) = ws_stream.split();
96
97        // Create channels for sending and receiving messages
98        let (outgoing_tx, outgoing_rx) = mpsc::unbounded_channel::<Message>();
99        let (incoming_tx, incoming_rx) = mpsc::unbounded_channel::<Message>();
100
101        // Spawn writer task
102        tokio::spawn(Self::writer_task(write, outgoing_rx));
103
104        // Spawn reader task
105        tokio::spawn(Self::reader_task(read, incoming_tx));
106
107        Ok(Self {
108            tx: outgoing_tx,
109            rx: incoming_rx,
110            closed: AtomicBool::new(false),
111        })
112    }
113
114    /// Writer task that sends messages to the WebSocket.
115    async fn writer_task(
116        mut write: futures_util::stream::SplitSink<
117            tokio_tungstenite::WebSocketStream<
118                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
119            >,
120            TungsteniteMessage,
121        >,
122        mut rx: mpsc::UnboundedReceiver<Message>,
123    ) {
124        while let Some(message) = rx.recv().await {
125            let is_close = message.is_close();
126            let raw_message: TungsteniteMessage = message.into();
127
128            if write.send(raw_message).await.is_err() {
129                break;
130            }
131
132            if is_close {
133                break;
134            }
135        }
136
137        let _ = write.close().await;
138    }
139
140    /// Reader task that receives messages from the WebSocket.
141    async fn reader_task(
142        mut read: futures_util::stream::SplitStream<
143            tokio_tungstenite::WebSocketStream<
144                tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
145            >,
146        >,
147        tx: mpsc::UnboundedSender<Message>,
148    ) {
149        while let Some(result) = read.next().await {
150            match result {
151                Ok(msg) => {
152                    if msg.is_close() {
153                        let _ = tx.send(Message::close());
154                        break;
155                    }
156
157                    let message: Message = msg.into();
158                    if tx.send(message).is_err() {
159                        break;
160                    }
161                }
162                Err(_) => {
163                    break;
164                }
165            }
166        }
167    }
168
169    /// Send a message to the server.
170    pub fn send(&self, message: Message) -> WebSocketResult<()> {
171        if self.closed.load(Ordering::Acquire) {
172            return Err(WebSocketError::ConnectionClosed);
173        }
174        self.tx
175            .send(message)
176            .map_err(|e| WebSocketError::Send(e.to_string()))
177    }
178
179    /// Send a text message.
180    pub fn send_text<S: Into<String>>(&self, text: S) -> WebSocketResult<()> {
181        self.send(Message::text(text))
182    }
183
184    /// Send a binary message.
185    pub fn send_binary<B: Into<bytes::Bytes>>(&self, data: B) -> WebSocketResult<()> {
186        self.send(Message::binary(data))
187    }
188
189    /// Send a JSON message.
190    pub fn send_json<T: serde::Serialize>(&self, value: &T) -> WebSocketResult<()> {
191        let message = Message::json(value)?;
192        self.send(message)
193    }
194
195    /// Receive the next message from the server.
196    pub async fn recv(&mut self) -> Option<Message> {
197        self.rx.recv().await
198    }
199
200    /// Try to receive a message without blocking.
201    pub fn try_recv(&mut self) -> Option<Message> {
202        self.rx.try_recv().ok()
203    }
204
205    /// Close the connection.
206    ///
207    /// This method uses atomic compare-and-exchange to ensure only one task
208    /// sends the close message, even when called concurrently.
209    pub fn close(&self) {
210        // Atomically set closed from false to true; only proceed if we won the race
211        if self
212            .closed
213            .compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
214            .is_ok()
215        {
216            let _ = self.tx.send(Message::close());
217        }
218    }
219
220    /// Check if the connection is closed.
221    pub fn is_closed(&self) -> bool {
222        self.closed.load(Ordering::Acquire)
223    }
224}
225
226impl Drop for WebSocketClient {
227    fn drop(&mut self) {
228        // close() now takes &self, but we have &mut self which coerces
229        self.close();
230    }
231}