1use anyhow::{anyhow, Result};
13use async_trait::async_trait;
14use futures::{SinkExt, Stream, StreamExt};
15use std::collections::VecDeque;
16use std::pin::Pin;
17use std::sync::atomic::{AtomicBool, Ordering};
18use std::sync::Arc;
19use std::task::{Context, Poll};
20use std::time::Duration;
21use tokio::sync::{mpsc, Mutex};
22use tokio_tungstenite::{
23 connect_async, tungstenite::Message as WsMessage, MaybeTlsStream, WebSocketStream,
24};
25
/// Tunables for a WebSocket adapter connection.
///
/// Every duration field is expressed in milliseconds.
#[derive(Debug, Clone)]
pub struct WebSocketConfig {
    /// Upper bound on the connect handshake before it is abandoned as a timeout.
    pub connect_timeout_ms: u64,
    /// Heartbeat period.
    /// NOTE(review): not read by the adapter code in this module — confirm a consumer exists.
    pub heartbeat_interval_ms: u64,
    /// Maximum number of reconnect attempts.
    /// NOTE(review): not read by the adapter code in this module — confirm a consumer exists.
    pub max_reconnect_attempts: u32,
    /// Delay between reconnect attempts.
    /// NOTE(review): not read by the adapter code in this module — confirm a consumer exists.
    pub reconnect_interval_ms: u64,
    /// Capacity (in messages) of the adapter's internal message channel.
    pub receive_buffer_size: usize,
}

impl Default for WebSocketConfig {
    /// A 10 s connect timeout, a 30 s heartbeat, 3 reconnect attempts spaced 1 s apart,
    /// and a 100-message buffer.
    fn default() -> Self {
        let connect_timeout_ms = 10_000;
        let heartbeat_interval_ms = 30_000;
        let max_reconnect_attempts = 3;
        let reconnect_interval_ms = 1_000;
        let receive_buffer_size = 100;

        Self {
            connect_timeout_ms,
            heartbeat_interval_ms,
            max_reconnect_attempts,
            reconnect_interval_ms,
            receive_buffer_size,
        }
    }
}
52
/// A WebSocket message in this module's own, library-independent representation.
#[derive(Debug, Clone)]
pub enum WebSocketMessage {
    /// A text message.
    Text(String),
    /// A binary message.
    Binary(Vec<u8>),
    /// A Ping control frame and its payload.
    Ping(Vec<u8>),
    /// A Pong control frame and its payload.
    Pong(Vec<u8>),
    /// A Close control frame with an optional human-readable reason.
    Close(Option<String>),
}
67
68impl From<WsMessage> for WebSocketMessage {
69 fn from(msg: WsMessage) -> Self {
70 match msg {
71 WsMessage::Text(t) => WebSocketMessage::Text(t.to_string()),
72 WsMessage::Binary(b) => WebSocketMessage::Binary(b.to_vec()),
73 WsMessage::Ping(p) => WebSocketMessage::Ping(p.to_vec()),
74 WsMessage::Pong(p) => WebSocketMessage::Pong(p.to_vec()),
75 WsMessage::Close(_) => WebSocketMessage::Close(None),
76 WsMessage::Frame(_) => WebSocketMessage::Text(String::new()),
77 }
78 }
79}
80
81impl From<WebSocketMessage> for WsMessage {
82 fn from(msg: WebSocketMessage) -> Self {
83 match msg {
84 WebSocketMessage::Text(t) => WsMessage::Text(t.into()),
85 WebSocketMessage::Binary(b) => WsMessage::Binary(b.into()),
86 WebSocketMessage::Ping(p) => WsMessage::Ping(p.into()),
87 WebSocketMessage::Pong(p) => WsMessage::Pong(p.into()),
88 WebSocketMessage::Close(_) => WsMessage::Close(None),
89 }
90 }
91}
92
/// Lifecycle state tracked by the WebSocket adapter.
///
/// Derives `Eq` as well as `PartialEq`: the enum is fieldless, so equality is total.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionState {
    /// No live connection: the initial state, and the state after a failed or timed-out connect.
    Disconnected,
    /// A connect handshake is in progress.
    Connecting,
    /// The most recent connect handshake succeeded.
    Connected,
    /// Reconnect in progress.
    /// NOTE(review): not assigned by the adapter code in this module.
    Reconnecting,
    /// The adapter has been explicitly closed.
    Closed,
}
107
108pub struct WebSocketAdapter {
115 config: WebSocketConfig,
117 url: String,
119 state: Arc<Mutex<ConnectionState>>,
121 sender: mpsc::Sender<WebSocketMessage>,
123 abort_flag: Arc<AtomicBool>,
125}
126
127impl WebSocketAdapter {
128 pub fn new(url: impl Into<String>) -> Self {
132 Self::with_config(url, WebSocketConfig::default())
133 }
134
135 pub fn with_config(url: impl Into<String>, config: WebSocketConfig) -> Self {
137 let (sender, _) = mpsc::channel(config.receive_buffer_size);
138 Self {
139 config,
140 url: url.into(),
141 state: Arc::new(Mutex::new(ConnectionState::Disconnected)),
142 sender,
143 abort_flag: Arc::new(AtomicBool::new(false)),
144 }
145 }
146
147 pub async fn state(&self) -> ConnectionState {
149 *self.state.lock().await
150 }
151
152 pub fn abort_flag(&self) -> Arc<AtomicBool> {
154 Arc::clone(&self.abort_flag)
155 }
156
157 pub fn abort(&self) {
159 self.abort_flag.store(true, Ordering::Relaxed);
160 }
161
162 pub async fn connect(&self) -> Result<WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>> {
166 {
167 let mut state = self.state.lock().await;
168 if *state == ConnectionState::Connected {
169 return Err(anyhow!("Already connected"));
170 }
171 *state = ConnectionState::Connecting;
172 }
173
174 let url = self.url.clone();
175 let timeout = Duration::from_millis(self.config.connect_timeout_ms);
176
177 let connect_future = async { connect_async(&url).await };
178
179 let result = tokio::time::timeout(timeout, connect_future).await;
180
181 match result {
182 Ok(Ok((stream, _))) => {
183 let mut state = self.state.lock().await;
184 *state = ConnectionState::Connected;
185 tracing::info!("WebSocket connected to {}", self.url);
186 Ok(stream)
187 }
188 Ok(Err(e)) => {
189 let mut state = self.state.lock().await;
190 *state = ConnectionState::Disconnected;
191 Err(anyhow!("WebSocket connection failed: {}", e))
192 }
193 Err(_) => {
194 let mut state = self.state.lock().await;
195 *state = ConnectionState::Disconnected;
196 Err(anyhow!("WebSocket connection timeout"))
197 }
198 }
199 }
200
201 pub async fn send(&self, message: WebSocketMessage) -> Result<()> {
203 self.sender.send(message).await?;
204 Ok(())
205 }
206
207 pub async fn create_stream(&self) -> Result<WebSocketMessageStream> {
211 let stream = self.connect().await?;
212 Ok(WebSocketMessageStream::new(stream, self.abort_flag.clone()))
213 }
214
215 pub async fn close(&self) -> Result<()> {
217 let mut state = self.state.lock().await;
218 *state = ConnectionState::Closed;
219 tracing::info!("WebSocket closed");
220 Ok(())
221 }
222}
223
224pub struct WebSocketMessageStream {
228 inner: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
229 abort_flag: Arc<AtomicBool>,
230 pending: VecDeque<WebSocketMessage>,
231}
232
233impl WebSocketMessageStream {
234 fn new(
235 inner: WebSocketStream<MaybeTlsStream<tokio::net::TcpStream>>,
236 abort_flag: Arc<AtomicBool>,
237 ) -> Self {
238 Self {
239 inner,
240 abort_flag,
241 pending: VecDeque::new(),
242 }
243 }
244
245 pub async fn next_message(&mut self) -> Result<Option<WebSocketMessage>> {
247 if self.abort_flag.load(Ordering::Relaxed) {
248 return Ok(None);
249 }
250
251 loop {
252 if let Some(msg) = self.pending.pop_front() {
253 return Ok(Some(msg));
254 }
255
256 match self.inner.next().await {
257 Some(Ok(ws_msg)) => {
258 let msg: WebSocketMessage = ws_msg.into();
259 match msg {
260 WebSocketMessage::Ping(p) => {
261 let _ = self.inner.send(WsMessage::Pong(p.into())).await;
263 }
264 WebSocketMessage::Close(_) => {
265 return Ok(None);
266 }
267 other => {
268 self.pending.push_back(other);
269 }
270 }
271 }
272 Some(Err(e)) => {
273 tracing::error!("WebSocket error: {}", e);
274 return Err(anyhow!("WebSocket error: {}", e));
275 }
276 None => return Ok(None),
277 }
278 }
279 }
280
281 pub async fn send(&mut self, message: WebSocketMessage) -> Result<()> {
283 let ws_msg: WsMessage = message.into();
284 self.inner.send(ws_msg).await?;
285 Ok(())
286 }
287
288 pub async fn collect_text(&mut self) -> Result<String> {
290 let mut result = String::new();
291 while let Some(msg) = self.next_message().await? {
292 if let WebSocketMessage::Text(t) = msg {
293 result.push_str(&t);
294 }
295 }
296 Ok(result)
297 }
298}
299
300pub struct WebSocketReceiver {
304 stream: WebSocketMessageStream,
305}
306
307impl WebSocketReceiver {
308 pub fn new(stream: WebSocketMessageStream) -> Self {
310 Self { stream }
311 }
312}
313
314impl Stream for WebSocketReceiver {
315 type Item = Result<WebSocketMessage>;
316
317 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
318 let abort_flag = self.stream.abort_flag.clone();
320 if abort_flag.load(Ordering::Relaxed) {
321 return Poll::Ready(None);
322 }
323
324 Pin::new(&mut self.stream.inner).poll_next(cx).map(|opt| {
326 opt.map(|result| {
327 result
328 .map(WebSocketMessage::from)
329 .map_err(|e| anyhow::anyhow!("WebSocket error: {}", e))
330 })
331 })
332 }
333}
334
335#[async_trait]
339pub trait WebSocketAdapterTrait: Send + Sync {
340 async fn connect(&self) -> Result<()>;
342
343 async fn send(&self, message: &str) -> Result<()>;
345
346 async fn receive(&self) -> Result<Option<String>>;
348
349 async fn close(&self) -> Result<()>;
351
352 async fn is_connected(&self) -> bool;
354}
355
#[cfg(test)]
mod tests {
    use super::*;

    /// Every default value, not just the first three fields, is pinned.
    #[test]
    fn test_websocket_config_default() {
        let config = WebSocketConfig::default();
        assert_eq!(config.connect_timeout_ms, 10000);
        assert_eq!(config.heartbeat_interval_ms, 30000);
        assert_eq!(config.max_reconnect_attempts, 3);
        assert_eq!(config.reconnect_interval_ms, 1000);
        assert_eq!(config.receive_buffer_size, 100);
    }

    #[test]
    fn test_websocket_message_conversion() {
        let ws_msg = WsMessage::Text("hello".into());
        let msg: WebSocketMessage = ws_msg.into();
        assert!(matches!(msg, WebSocketMessage::Text(t) if t == "hello"));
    }

    #[test]
    fn test_websocket_message_to_ws_message() {
        let msg = WebSocketMessage::Binary(vec![1, 2, 3]);
        let ws_msg: WsMessage = msg.into();
        assert!(matches!(ws_msg, WsMessage::Binary(b) if b == vec![1, 2, 3]));
    }

    /// Control frames survive conversion in both directions.
    #[test]
    fn test_control_message_conversion() {
        let ping: WebSocketMessage = WsMessage::Ping(vec![9u8].into()).into();
        assert!(matches!(ping, WebSocketMessage::Ping(p) if p == vec![9]));

        let close: WebSocketMessage = WsMessage::Close(None).into();
        assert!(matches!(close, WebSocketMessage::Close(None)));

        let back: WsMessage = WebSocketMessage::Close(None).into();
        assert!(matches!(back, WsMessage::Close(None)));
    }

    #[tokio::test]
    async fn test_websocket_adapter_creation() {
        let adapter = WebSocketAdapter::new("ws://localhost:8080");
        assert_eq!(adapter.state().await, ConnectionState::Disconnected);
    }

    #[tokio::test]
    async fn test_websocket_adapter_abort() {
        let adapter = WebSocketAdapter::new("ws://localhost:8080");
        assert!(!adapter.abort_flag().load(Ordering::Relaxed));
        adapter.abort();
        assert!(adapter.abort_flag().load(Ordering::Relaxed));
    }

    /// The adapter's outgoing channel has no consumer, so `send` reports an error
    /// instead of silently succeeding.
    #[tokio::test]
    async fn test_websocket_adapter_send_without_consumer_fails() {
        let adapter = WebSocketAdapter::new("ws://localhost:8080");
        let result = adapter.send(WebSocketMessage::Text("x".into())).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_connection_state() {
        assert_eq!(ConnectionState::Disconnected, ConnectionState::Disconnected);
        assert_ne!(ConnectionState::Disconnected, ConnectionState::Connected);
    }
}