1use std::{
2 sync::{Arc, atomic::{AtomicBool, Ordering}},
3 collections::hash_map::{HashMap, Entry},
4 time::Duration,
5 mem,
6};
7use tokio::{
8 sync::{mpsc as tokio_mpsc, Mutex as AsyncMutex, Notify},
9 task::JoinHandle,
10 net::TcpStream,
11 time::{MissedTickBehavior, timeout},
12};
13use tokio_tungstenite::{
14 tungstenite,
15 MaybeTlsStream,
16};
17pub use tungstenite::Error as TungsteniteError;
18use futures_util::{
19 sink::SinkExt,
20 stream::{StreamExt, SplitSink},
21};
22use parking_lot::Mutex as SyncMutex;
23
/// Stream type returned by `tokio_tungstenite::connect_async` in `start_connection`.
type WebSocketStream = tokio_tungstenite::WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Write half of a [`WebSocketStream`] obtained from `split()`; sends `tungstenite::Message`s.
type WebSocketSplitSink = SplitSink<WebSocketStream, tungstenite::Message>;
26
/// Handle to a WebSocket connection that is served by background tasks.
///
/// A spawned feed task passes incoming messages to the [`WebSocketHandler`] and writes its
/// replies, while a separate reconnection task replaces the underlying connection when a
/// reconnection is requested or, if configured, periodically. Dropping the handle stops the
/// reconnection task and asks the feed task to close the connection.
#[derive(Debug)]
#[must_use = "dropping WebSocketConnection closes the connection"]
pub struct WebSocketConnection<H: WebSocketHandler> {
    /// Background task running the reconnection loop; aborted on drop.
    task_reconnect: JoinHandle<()>,
    /// Write half of the current connection. Shared with the background tasks and swapped
    /// for the new connection's write half on every successful reconnection.
    sink: Arc<AsyncMutex<WebSocketSplitSink>>,
    /// State shared with the background tasks (URL, handler, message channel, connection id).
    inner: Arc<ConnectionInner<H>>,
    /// Handle returned (cloned) by `reconnect_state()`.
    reconnect_state: ReconnectState,
}
48
/// State shared between a [`WebSocketConnection`] and its background tasks.
#[derive(Debug)]
struct ConnectionInner<H: WebSocketHandler> {
    /// Full URL that is connected to: `WebSocketConfig::url_prefix` followed by the URL
    /// passed to `WebSocketConnection::new`.
    url: String,
    /// The user's handler; each callback locks it only for the duration of the call.
    handler: Arc<SyncMutex<H>>,
    /// Sender for the feed task's channel. The `bool` is the id of the connection the item
    /// belongs to.
    message_tx: tokio_mpsc::UnboundedSender<(bool, FeederMessage)>,
    /// Id that the next connection will receive. Each new connection takes the current
    /// value and flips it (`fetch_xor(true)`), so ids alternate between `false` and `true`
    /// and the most recent connection's id is the negation of this value.
    next_connection_id: AtomicBool,
}
71
/// Item delivered to the feed task; the channel pairs it with a connection id.
enum FeederMessage {
    /// A message, or a receive error, read from a connection's stream.
    Message(tungstenite::Result<tungstenite::Message>),
    /// A connection's stream has ended.
    ConnectionClosed,
    /// Sent when the `WebSocketConnection` is dropped: the feed task closes the current
    /// sink and stops.
    DropConnectionRequest,
}
77
78impl<H: WebSocketHandler> WebSocketConnection<H> {
79 pub async fn new(url: &str, handler: H) -> Result<Self, TungsteniteError> {
81 let config = handler.websocket_config();
82 let handler = Arc::new(SyncMutex::new(handler));
83 let url = config.url_prefix.clone() + url;
84
85 let (message_tx, message_rx) = tokio_mpsc::unbounded_channel();
86 let reconnect_manager = ReconnectState::new();
87
88 let connection = Arc::new(ConnectionInner {
89 url,
90 handler: Arc::clone(&handler),
91 message_tx,
92 next_connection_id: AtomicBool::new(false),
93 });
94
95 async fn feed_handler(
96 connection: Arc<ConnectionInner<impl WebSocketHandler>>,
97 mut message_rx: tokio_mpsc::UnboundedReceiver<(bool, FeederMessage)>,
98 reconnect_manager: ReconnectState,
99 config: WebSocketConfig,
100 sink: Arc<AsyncMutex<WebSocketSplitSink>>,
101 ) {
102 let mut messages: HashMap<WebSocketMessage, isize> = HashMap::new();
103
104 let timeout_duration = if config.message_timeout.is_zero() {
105 Duration::MAX
106 } else {
107 config.message_timeout
108 };
109
110 loop {
111 match timeout(timeout_duration, message_rx.recv()).await {
112 Ok(Some((id, FeederMessage::Message(Ok(message))))) => {
114 if let Some(message) = WebSocketMessage::from_message(message) {
116 if reconnect_manager.is_reconnecting() {
117 let id_sign: isize = if id {
119 1
120 } else {
121 -1
122 };
123 let entry = messages.entry(message.clone());
124 match entry {
125 Entry::Occupied(mut occupied) => {
126 if config.ignore_duplicate_during_reconnection {
127 log::debug!("Skipping duplicate message.");
128 continue;
129 }
130
131 *occupied.get_mut() += id_sign;
132 if id_sign != occupied.get().signum() {
133 log::debug!("Skipping duplicate message.");
135 continue;
136 }
137 },
139 Entry::Vacant(vacant) => {
140 vacant.insert(id_sign);
142 }
143 }
144 } else {
145 messages.clear();
146 }
147 let messages = connection.handler.lock().handle_message(message);
148 let mut sink_lock = sink.lock().await;
149 for message in messages {
150 if let Err(error) = sink_lock.send(message.into_message()).await {
151 log::error!("Failed to send message because of an error: {}", error);
152 };
153 }
154 if let Err(error) = sink_lock.flush().await {
155 log::error!("An error occurred while flushing WebSocket sink: {error:?}");
156 }
157 }
158 },
159 Ok(Some((_, FeederMessage::Message(Err(error))))) => {
161 log::error!("Failed to receive message because of an error: {error:?}");
162 if reconnect_manager.request_reconnect() {
163 log::info!("Reconnecting WebSocket because there was an error while receiving a message");
164 }
165 },
166 Err(_) => {
168 log::debug!("WebSocket message timeout");
169 if reconnect_manager.request_reconnect() {
170 log::info!("Reconnecting WebSocket because of timeout");
171 }
172 },
173 Ok(Some((id, FeederMessage::ConnectionClosed))) => {
175 let current_id = !connection.next_connection_id.load(Ordering::SeqCst);
176 if id != current_id {
177 continue;
179 }
180 log::debug!("WebSocket connection closed by server");
181 if reconnect_manager.request_reconnect() {
182 log::info!("Reconnecting WebSocket because it was disconnected by the server");
183 }
184 },
185 Ok(Some((_, FeederMessage::DropConnectionRequest))) => {
187 if let Err(error) = sink.lock().await.close().await {
188 log::debug!("Failed to close WebSocket connection: {error:?}");
189 }
190 break;
191 }
192 Ok(None) => unreachable!("message_rx should never be closed"),
194 }
195 }
196 connection.handler.lock().handle_close(false);
197 }
198
199 async fn reconnect<H: WebSocketHandler>(
200 interval: Duration,
201 cooldown: Duration,
202 connection: Arc<ConnectionInner<H>>,
203 sink: Arc<AsyncMutex<WebSocketSplitSink>>,
204 reconnect_manager: ReconnectState,
205 no_duplicate: bool,
206 wait: Duration,
207 ) {
208 let mut cooldown = tokio::time::interval(cooldown);
209 cooldown.set_missed_tick_behavior(MissedTickBehavior::Delay);
210 loop {
211 let timer = if interval.is_zero() {
212 tokio::time::sleep(Duration::MAX)
214 } else {
215 tokio::time::sleep(interval)
216 };
217 tokio::select! {
218 _ = reconnect_manager.inner.reconnect_notify.notified() => {},
219 _ = timer => {},
220 }
221 log::debug!("Reconnection requested");
222 cooldown.tick().await;
223 reconnect_manager.inner.reconnecting.store(true, Ordering::SeqCst);
224
225 reconnect_manager.inner.reconnect_notify.notify_one();
228 reconnect_manager.inner.reconnect_notify.notified().await;
230
231 log::debug!("Starting reconnection process ...");
232 if no_duplicate {
233 tokio::time::sleep(wait).await;
234 }
235
236 match WebSocketConnection::<H>::start_connection(Arc::clone(&connection)).await {
238 Ok(new_sink) => {
239 let mut old_sink = mem::replace(&mut *sink.lock().await, new_sink);
241 log::debug!("New connection established");
242
243 if no_duplicate {
244 tokio::time::sleep(wait).await;
245 }
246
247 if let Err(error) = old_sink.close().await {
248 log::debug!("An error occurred while closing old connection: {}", error);
249 }
250 connection.handler.lock().handle_close(true);
251 log::debug!("Old connection closed");
252 },
253 Err(error) => {
254 log::error!("Failed to reconnect because of an error: {}, trying again ...", error);
256 reconnect_manager.inner.reconnect_notify.notify_one();
257 },
258 }
259
260 if no_duplicate {
261 tokio::time::sleep(wait).await;
262 }
263
264 reconnect_manager.inner.reconnecting.store(false, Ordering::SeqCst);
265 log::debug!("Reconnection process complete");
266 }
267 }
268
269 let sink_inner = Self::start_connection(Arc::clone(&connection)).await?;
270 let sink = Arc::new(AsyncMutex::new(sink_inner));
271
272 tokio::spawn(
273 feed_handler(
274 Arc::clone(&connection),
275 message_rx,
276 reconnect_manager.clone(),
277 config.clone(),
278 Arc::clone(&sink),
279 )
280 );
281
282 let task_reconnect = tokio::spawn(reconnect(
283 config.refresh_after,
284 config.connect_cooldown,
285 Arc::clone(&connection),
286 Arc::clone(&sink),
287 reconnect_manager.clone(),
288 config.ignore_duplicate_during_reconnection,
289 config.reconnection_wait,
290 ));
291
292 Ok(Self {
293 task_reconnect,
294 sink,
295 inner: connection,
296 reconnect_state: reconnect_manager,
297 })
298 }
299
300 async fn start_connection(connection: Arc<ConnectionInner<impl WebSocketHandler>>) -> Result<WebSocketSplitSink, TungsteniteError> {
301 let (websocket_stream, _) = tokio_tungstenite::connect_async(connection.url.clone()).await?;
302 let (mut sink, mut stream) = websocket_stream.split();
303
304 let messages = connection.handler.lock().handle_start();
305 for message in messages {
306 sink.send(message.into_message()).await?;
307 }
308 sink.flush().await?;
309
310 let id = connection.next_connection_id.fetch_xor(true, Ordering::SeqCst);
312
313 tokio::spawn(async move {
315 while let Some(message) = stream.next().await {
316 if connection.message_tx.send((id, FeederMessage::Message(message))).is_err() {
318 log::debug!("WebSocket message receiver is closed; abandon connection");
320 return;
321 }
322 }
323 drop(connection.message_tx.send((id, FeederMessage::ConnectionClosed))); log::debug!("WebSocket stream closed");
327 });
328 Ok(sink)
329 }
330
331 pub async fn send_message(&self, message: WebSocketMessage) -> Result<(), TungsteniteError> {
333 let mut sink_lock = self.sink.lock().await;
334 sink_lock.send(message.into_message()).await?;
335 sink_lock.flush().await
336 }
337
338 pub fn reconnect_state(&self) -> ReconnectState {
342 self.reconnect_state.clone()
343 }
344}
345
346impl<H: WebSocketHandler> Drop for WebSocketConnection<H> {
347 fn drop(&mut self) {
348 self.task_reconnect.abort();
349 let current_id = !self.inner.next_connection_id.load(Ordering::SeqCst);
351 self.inner.message_tx.send((current_id, FeederMessage::DropConnectionRequest)).ok();
352 }
353}
354
/// Cloneable handle for observing and requesting reconnections of a [`WebSocketConnection`].
///
/// Obtained from [`WebSocketConnection::reconnect_state`]; all clones share the same state.
#[derive(Debug, Clone)]
pub struct ReconnectState {
    inner: Arc<ReconnectMangerInner>,
}
363
/// Shared state behind [`ReconnectState`].
// NOTE(review): "Manger" looks like a typo for "Manager"; renaming requires updating every use.
#[derive(Debug)]
struct ReconnectMangerInner {
    /// Wakes the reconnection task; a stored permit represents a pending request.
    reconnect_notify: Notify,
    /// `true` while the reconnection task is replacing the connection.
    reconnecting: AtomicBool,
}
369
370impl ReconnectState {
371 fn new() -> Self {
372 Self {
373 inner: Arc::new(ReconnectMangerInner {
374 reconnect_notify: Notify::new(),
375 reconnecting: AtomicBool::new(false),
376 })
377 }
378 }
379
380 pub fn is_reconnecting(&self) -> bool {
382 self.inner.reconnecting.load(Ordering::SeqCst)
383 }
384
385 pub fn request_reconnect(&self) -> bool {
389 if self.is_reconnecting() {
390 false
391 } else {
392 self.inner.reconnect_notify.notify_one();
393 true
394 }
395 }
396}
397
/// A WebSocket message exchanged with a [`WebSocketHandler`].
///
/// Mirrors the text, binary, ping and pong variants of `tungstenite::Message`; close frames
/// and raw frames are not represented. `Eq`, `Hash` and `Clone` are needed because messages
/// are used as map keys when filtering duplicates during a reconnection.
#[derive(Debug, Eq, PartialEq, Clone, Hash)]
pub enum WebSocketMessage {
    /// A text message.
    Text(String),
    /// A binary message.
    Binary(Vec<u8>),
    /// A ping control message and its payload.
    Ping(Vec<u8>),
    /// A pong control message and its payload.
    Pong(Vec<u8>),
}
412
413impl WebSocketMessage {
414 fn from_message(message: tungstenite::Message) -> Option<Self> {
415 match message {
416 tungstenite::Message::Text(text) => Some(Self::Text(text)),
417 tungstenite::Message::Binary(data) => Some(Self::Binary(data)),
418 tungstenite::Message::Ping(data) => Some(Self::Ping(data)),
419 tungstenite::Message::Pong(data) => Some(Self::Pong(data)),
420 tungstenite::Message::Close(_) | tungstenite::Message::Frame(_) => None,
421 }
422 }
423
424 fn into_message(self) -> tungstenite::Message {
425 match self {
426 WebSocketMessage::Text(text) => tungstenite::Message::Text(text),
427 WebSocketMessage::Binary(data) => tungstenite::Message::Binary(data),
428 WebSocketMessage::Ping(data) => tungstenite::Message::Ping(data),
429 WebSocketMessage::Pong(data) => tungstenite::Message::Pong(data),
430 }
431 }
432}
433
434pub trait WebSocketHandler: Send + 'static {
439 fn websocket_config(&self) -> WebSocketConfig {
441 WebSocketConfig::default()
442 }
443
444 fn handle_start(&mut self) -> Vec<WebSocketMessage> {
448 log::debug!("WebSocket connection started");
449 vec![]
450 }
451
452 fn handle_message(&mut self, message: WebSocketMessage) -> Vec<WebSocketMessage>;
454
455 #[allow(unused_variables)]
461 fn handle_close(&mut self, reconnect: bool) {
462 log::debug!("WebSocket connection closed; reconnect: {}", reconnect);
463 }
464}
465
/// Settings for a [`WebSocketConnection`], returned by [`WebSocketHandler::websocket_config`].
///
/// The type is `#[non_exhaustive]`: outside this crate, create it with
/// [`WebSocketConfig::new`] or `Default::default()` and then assign the public fields.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct WebSocketConfig {
    /// Period of the cooldown interval that spaces out reconnection attempts.
    /// Default: 3000 ms.
    pub connect_cooldown: Duration,
    /// If non-zero, the connection is also replaced when this much time passes without a
    /// reconnection; the timer restarts after each reconnection. Zero disables periodic
    /// refreshes. Default: zero.
    pub refresh_after: Duration,
    /// Prefix prepended to the URL passed to [`WebSocketConnection::new`]. Default: empty.
    pub url_prefix: String,
    /// Controls duplicate filtering while a reconnection is in progress.
    ///
    /// When `false`, a message is skipped if the other connection has already delivered it
    /// at least as often (tracked with a per-message tally). When `true`, every message
    /// already seen during the reconnection is skipped, whichever connection it came from,
    /// and the reconnection task additionally sleeps for
    /// [`reconnection_wait`](Self::reconnection_wait) at several points while reconnecting.
    /// Default: `false`.
    pub ignore_duplicate_during_reconnection: bool,
    /// Sleep used during reconnection when `ignore_duplicate_during_reconnection` is `true`;
    /// unused otherwise. Default: 300 ms.
    pub reconnection_wait: Duration,
    /// If non-zero, a reconnection is requested whenever nothing is received for this long.
    /// Zero disables the timeout. Default: zero.
    pub message_timeout: Duration,
}
499
500impl WebSocketConfig {
501 pub fn new() -> Self {
503 Self::default()
504 }
505}
506
507impl Default for WebSocketConfig {
508 fn default() -> Self {
509 Self {
510 connect_cooldown: Duration::from_millis(3000),
511 refresh_after: Duration::ZERO,
512 url_prefix: String::new(),
513 ignore_duplicate_during_reconnection: false,
514 reconnection_wait: Duration::from_millis(300),
515 message_timeout: Duration::ZERO,
516 }
517 }
518}