1use bytes::Bytes;
2use futures::lock::Mutex;
3use futures::{FutureExt, SinkExt, StreamExt, TryStreamExt};
4use std::marker::PhantomData;
5use std::sync::atomic::{AtomicU8, Ordering};
6use std::sync::Arc;
7use std::time::Duration;
8use tokio_tungstenite_wasm::Error as WSError;
9use tungstenite::Utf8Bytes;
10
11#[cfg(not(target_family = "wasm"))]
12use std::time::{Instant, SystemTime, UNIX_EPOCH};
13
14#[cfg(target_family = "wasm")]
15use wasmtimer::std::{Instant, SystemTime, UNIX_EPOCH};
16
17pub trait SocketHeartbeatPingFn: Fn(Duration) -> RawMessage + Sync + Send {}
19impl<F> SocketHeartbeatPingFn for F where F: Fn(Duration) -> RawMessage + Sync + Send {}
20pub type SocketHeartbeatPingFnT = dyn SocketHeartbeatPingFn<Output = RawMessage>;
21
22impl std::fmt::Debug for SocketHeartbeatPingFnT {
23 fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
24 Ok(())
25 }
26}
27
28#[derive(Debug, Clone)]
30pub struct SocketConfig {
31 pub heartbeat: Duration,
33 pub timeout: Duration,
35 pub heartbeat_ping_msg_fn: Arc<dyn SocketHeartbeatPingFn>,
41}
42
43impl Default for SocketConfig {
44 fn default() -> Self {
45 Self {
46 heartbeat: Duration::from_secs(5),
47 timeout: Duration::from_secs(10),
48 heartbeat_ping_msg_fn: Arc::new(|timestamp: Duration| {
49 let timestamp = timestamp.as_millis();
50 let bytes = timestamp.to_be_bytes();
51 RawMessage::Ping(bytes.to_vec().into())
52 }),
53 }
54 }
55}
56
/// WebSocket close status codes (RFC 6455 §7.4).
///
/// Codes without a dedicated variant are carried in `Reserved`, `Iana`, `Library` or `Bad`
/// according to their numeric range (see the `From<u16>` impl).
#[derive(Debug, Clone)]
pub enum CloseCode {
    /// 1000 — normal closure.
    Normal,
    /// 1001 — endpoint is going away.
    Away,
    /// 1002 — protocol error.
    Protocol,
    /// 1003 — received a data type that cannot be accepted.
    Unsupported,
    /// 1005 — no status code was present (not to be sent in a close frame).
    Status,
    /// 1006 — connection closed abnormally (not to be sent in a close frame).
    Abnormal,
    /// 1007 — received data inconsistent with the message type.
    Invalid,
    /// 1008 — policy violation.
    Policy,
    /// 1009 — message too big.
    Size,
    /// 1010 — expected extension was not negotiated.
    Extension,
    /// 1011 — unexpected internal condition.
    Error,
    /// 1012 — service restarting.
    Restart,
    /// 1013 — try again later.
    Again,
    /// 1015 — TLS handshake failure.
    #[doc(hidden)]
    Tls,
    /// 1016..=2999.
    #[doc(hidden)]
    Reserved(u16),
    /// 3000..=3999.
    #[doc(hidden)]
    Iana(u16),
    /// 4000..=4999.
    #[doc(hidden)]
    Library(u16),
    /// Any other value (0..=999, 1004, 1014, 5000 and above).
    #[doc(hidden)]
    Bad(u16),
}

impl From<CloseCode> for u16 {
    fn from(code: CloseCode) -> u16 {
        match code {
            CloseCode::Normal => 1000,
            CloseCode::Away => 1001,
            CloseCode::Protocol => 1002,
            CloseCode::Unsupported => 1003,
            CloseCode::Status => 1005,
            CloseCode::Abnormal => 1006,
            CloseCode::Invalid => 1007,
            CloseCode::Policy => 1008,
            CloseCode::Size => 1009,
            CloseCode::Extension => 1010,
            CloseCode::Error => 1011,
            CloseCode::Restart => 1012,
            CloseCode::Again => 1013,
            CloseCode::Tls => 1015,
            // Range-carrying variants hold the raw code verbatim.
            CloseCode::Reserved(raw)
            | CloseCode::Iana(raw)
            | CloseCode::Library(raw)
            | CloseCode::Bad(raw) => raw,
        }
    }
}

impl From<u16> for CloseCode {
    fn from(code: u16) -> Self {
        match code {
            1000 => CloseCode::Normal,
            1001 => CloseCode::Away,
            1002 => CloseCode::Protocol,
            1003 => CloseCode::Unsupported,
            1005 => CloseCode::Status,
            1006 => CloseCode::Abnormal,
            1007 => CloseCode::Invalid,
            1008 => CloseCode::Policy,
            1009 => CloseCode::Size,
            1010 => CloseCode::Extension,
            1011 => CloseCode::Error,
            1012 => CloseCode::Restart,
            1013 => CloseCode::Again,
            1015 => CloseCode::Tls,
            1016..=2999 => CloseCode::Reserved(code),
            3000..=3999 => CloseCode::Iana(code),
            4000..=4999 => CloseCode::Library(code),
            // 0..=999, 1004, 1014 and 5000.. have no named meaning here.
            _ => CloseCode::Bad(code),
        }
    }
}
182
183#[derive(Debug, Clone)]
184pub struct CloseFrame {
185 pub code: CloseCode,
186 pub reason: Utf8Bytes,
187}
188
189#[derive(Debug, Clone)]
190pub enum Message {
191 Text(Utf8Bytes),
192 Binary(Bytes),
193 Close(Option<CloseFrame>),
194}
195
196#[derive(Debug, Clone)]
197pub enum RawMessage {
198 Text(Utf8Bytes),
199 Binary(Bytes),
200 Ping(Bytes),
201 Pong(Bytes),
202 Close(Option<CloseFrame>),
203}
204
205impl From<Message> for RawMessage {
206 fn from(message: Message) -> Self {
207 match message {
208 Message::Text(text) => Self::Text(text),
209 Message::Binary(bytes) => Self::Binary(bytes),
210 Message::Close(frame) => Self::Close(frame.map(CloseFrame::from)),
211 }
212 }
213}
214
/// Delivery state of an outgoing message.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum MessageStatus {
    /// Created and not yet resolved.
    Sending,
    /// Successfully written to the underlying sink.
    Sent,
    /// Dropped before being sent, or the sink write failed.
    Failed,
}

/// Cloneable handle for observing a message's [`MessageStatus`].
///
/// All clones share one atomic cell, so a status set through any clone is seen by all of them.
#[derive(Debug, Clone)]
pub struct MessageSignal {
    signal: Arc<AtomicU8>,
}

impl MessageSignal {
    /// Creates a signal that starts out in `status`.
    pub fn new(status: MessageStatus) -> Self {
        Self {
            signal: Arc::new(AtomicU8::new(Self::encode(status))),
        }
    }

    /// Reads the current status (acquire load, paired with the release store in `set`).
    pub fn status(&self) -> MessageStatus {
        let raw = self.signal.load(Ordering::Acquire);
        if raw == 0 {
            MessageStatus::Sending
        } else if raw == 1 {
            MessageStatus::Sent
        } else {
            // Only 0, 1 and 2 are ever stored; anything else is treated as a failure.
            MessageStatus::Failed
        }
    }

    /// Overwrites the status, making it visible through every clone of this signal.
    pub(crate) fn set(&self, status: MessageStatus) {
        self.signal.store(Self::encode(status), Ordering::Release);
    }

    /// Numeric representation kept in the shared atomic.
    fn encode(status: MessageStatus) -> u8 {
        match status {
            MessageStatus::Sending => 0,
            MessageStatus::Sent => 1,
            MessageStatus::Failed => 2,
        }
    }
}

impl Default for MessageSignal {
    /// A fresh signal in the `Sending` state.
    fn default() -> Self {
        Self::new(MessageStatus::Sending)
    }
}
270
271#[derive(Debug, Clone)]
273pub struct InRawMessage {
274 message: Option<RawMessage>,
276 signal: Option<MessageSignal>,
277}
278
279impl InRawMessage {
280 pub fn new(message: RawMessage) -> Self {
281 Self {
282 message: Some(message),
283 signal: Some(MessageSignal::default()),
284 }
285 }
286
287 pub(crate) fn take_message(&mut self) -> Option<RawMessage> {
288 self.message.take()
289 }
290
291 pub(crate) fn set_signal(&mut self, state: MessageStatus) {
292 let Some(signal) = &self.signal else {
293 return;
294 };
295 signal.set(state);
296 self.signal = None;
297 }
298}
299
300impl Drop for InRawMessage {
301 fn drop(&mut self) {
302 self.set_signal(MessageStatus::Failed);
304 }
305}
306
307#[derive(Debug, Clone)]
309pub struct InMessage {
310 pub(crate) message: Option<Message>,
312 signal: Option<MessageSignal>,
313}
314
315impl InMessage {
316 pub fn new(message: Message) -> Self {
317 Self {
318 message: Some(message),
319 signal: Some(MessageSignal::default()),
320 }
321 }
322
323 pub fn clone_signal(&self) -> Option<MessageSignal> {
324 self.signal.clone()
325 }
326}
327
328impl From<InMessage> for InRawMessage {
329 fn from(mut inmessage: InMessage) -> Self {
330 Self {
331 message: inmessage.message.take().map(|msg| msg.into()),
332 signal: inmessage.signal.take(),
333 }
334 }
335}
336
337impl Drop for InMessage {
338 fn drop(&mut self) {
339 let Some(signal) = self.signal.take() else {
341 return;
342 };
343 signal.set(MessageStatus::Failed);
344 }
345}
346
347#[derive(Debug)]
348struct SinkActor<M, S>
349where
350 M: From<RawMessage>,
351 S: SinkExt<M, Error = WSError> + Unpin,
352{
353 receiver: async_channel::Receiver<InRawMessage>,
354 abort_receiver: async_channel::Receiver<()>,
355 sink: S,
356 phantom: PhantomData<M>,
357}
358
359impl<M, S> SinkActor<M, S>
360where
361 M: From<RawMessage>,
362 S: SinkExt<M, Error = WSError> + Unpin,
363{
364 async fn run(&mut self) -> Result<(), WSError> {
365 loop {
366 futures::select! {
367 res = self.receiver.recv().fuse() => {
368 let Ok(mut inmessage) = res else {
369 break;
370 };
371 let Some(message) = inmessage.take_message() else {
372 continue;
373 };
374 tracing::trace!("sending message: {:?}", message);
375 match self.sink.send(M::from(message)).await {
376 Ok(()) => inmessage.set_signal(MessageStatus::Sent),
377 Err(err) => {
378 inmessage.set_signal(MessageStatus::Failed);
379 tracing::warn!(?err, "sink send failed");
380 return Err(err);
381 }
382 }
383 },
384 _ = &mut self.abort_receiver.recv().fuse() => {
385 break;
386 },
387 }
388 }
389 Ok(())
390 }
391}
392
393#[derive(Debug, Clone)]
394pub struct Sink {
395 sender: async_channel::Sender<InRawMessage>,
396}
397
398impl Sink {
399 fn new<M, S>(
400 sink: S,
401 abort_receiver: async_channel::Receiver<()>,
402 handle: impl enfync::Handle,
403 ) -> (enfync::PendingResult<Result<(), WSError>>, Self)
404 where
405 M: From<RawMessage> + Send + 'static,
406 S: SinkExt<M, Error = WSError> + Unpin + Send + 'static,
407 {
408 let (sender, receiver) = async_channel::unbounded();
409 let mut actor = SinkActor {
410 receiver,
411 abort_receiver,
412 sink,
413 phantom: Default::default(),
414 };
415 let future = handle.spawn(async move { actor.run().await });
416 (future, Self { sender })
417 }
418
419 pub fn is_closed(&self) -> bool {
420 self.sender.is_closed()
421 }
422
423 pub async fn send(
424 &self,
425 inmessage: InMessage,
426 ) -> Result<(), async_channel::SendError<InRawMessage>> {
427 self.sender.send(inmessage.into()).await
428 }
429
430 pub(crate) async fn send_raw(
431 &self,
432 inmessage: InRawMessage,
433 ) -> Result<(), async_channel::SendError<InRawMessage>> {
434 self.sender.send(inmessage).await
435 }
436}
437
438#[derive(Debug)]
439struct StreamActor<M, S>
440where
441 M: Into<RawMessage>,
442 S: StreamExt<Item = Result<M, WSError>> + Unpin,
443{
444 sender: async_channel::Sender<Result<Message, WSError>>,
445 stream: S,
446 last_alive: Arc<Mutex<Instant>>,
447}
448
449impl<M, S> StreamActor<M, S>
450where
451 M: Into<RawMessage>,
452 S: StreamExt<Item = Result<M, WSError>> + Unpin,
453{
454 async fn run(mut self) {
455 while let Some(result) = self.stream.next().await {
456 let result = result.map(M::into);
457 tracing::trace!("received message: {:?}", result);
458 *self.last_alive.lock().await = Instant::now();
459
460 let mut closing = false;
461 let message = match result {
462 Ok(message) => Ok(match message {
463 RawMessage::Text(text) => Message::Text(text),
464 RawMessage::Binary(bytes) => Message::Binary(bytes),
465 RawMessage::Ping(_bytes) => continue,
466 RawMessage::Pong(bytes) => {
467 if let Ok(bytes) = (*bytes).try_into() {
468 let bytes: [u8; 16] = bytes;
469 let timestamp = u128::from_be_bytes(bytes);
470 let timestamp = Duration::from_millis(timestamp as u64); let latency = SystemTime::now()
472 .duration_since(UNIX_EPOCH + timestamp)
473 .unwrap_or_default();
474 tracing::trace!("latency: {}ms", latency.as_millis());
476 }
477
478 continue;
479 }
480 RawMessage::Close(frame) => {
481 closing = true;
482 Message::Close(frame)
483 }
484 }),
485 Err(err) => Err(err), };
487 if self.sender.send(message).await.is_err() {
488 if closing {
493 tracing::trace!("stream is closed");
494 } else {
495 tracing::warn!("failed to forward message, stream is disconnected");
496 }
497 break;
498 };
499 }
500 }
501}
502
503#[derive(Debug)]
504pub struct Stream {
505 receiver: async_channel::Receiver<Result<Message, WSError>>,
506}
507
508impl Stream {
509 fn new<M, S>(
510 stream: S,
511 last_alive: Arc<Mutex<Instant>>,
512 handle: impl enfync::Handle,
513 ) -> (enfync::PendingResult<()>, Self)
514 where
515 M: Into<RawMessage> + std::fmt::Debug + Send + 'static,
516 S: StreamExt<Item = Result<M, WSError>> + Unpin + Send + 'static,
517 {
518 let (sender, receiver) = async_channel::unbounded();
519 let actor = StreamActor {
520 sender,
521 stream,
522 last_alive,
523 };
524 let future = handle.spawn(actor.run());
525
526 (future, Self { receiver })
527 }
528
529 pub async fn recv(&mut self) -> Option<Result<Message, WSError>> {
530 self.receiver.recv().await.ok()
531 }
532}
533
534#[derive(Debug)]
535pub struct Socket {
536 pub sink: Sink,
537 pub stream: Stream,
538 sink_result_receiver: Option<async_channel::Receiver<Result<(), WSError>>>,
539}
540
541impl Socket {
542 pub fn new<M, E, S>(socket: S, config: SocketConfig, handle: impl enfync::Handle) -> Self
543 where
544 M: Into<RawMessage> + From<RawMessage> + std::fmt::Debug + Send + 'static,
545 E: Into<WSError> + std::error::Error,
546 S: SinkExt<M, Error = E> + Unpin + StreamExt<Item = Result<M, E>> + Unpin + Send + 'static,
547 {
548 let last_alive = Instant::now();
549 let last_alive = Arc::new(Mutex::new(last_alive));
550 let (sink, stream) = socket.sink_err_into().err_into().split();
551 let (sink_abort_sender, sink_abort_receiver) = async_channel::bounded(1usize);
552 let ((mut sink_future, sink), (mut stream_future, stream)) = (
553 Sink::new(sink, sink_abort_receiver, handle.clone()),
554 Stream::new(stream, last_alive.clone(), handle.clone()),
555 );
556 let (hearbeat_abort_sender, hearbeat_abort_receiver) = async_channel::bounded(1usize);
557 let sink_clone = sink.clone();
558 handle.spawn(async move {
559 socket_heartbeat(sink_clone, config, hearbeat_abort_receiver, last_alive).await
560 });
561
562 let (sink_result_sender, sink_result_receiver) = async_channel::bounded(1usize);
563 handle.spawn(async move {
564 let _ = stream_future.extract().await;
565 let _ = sink_abort_sender.send_blocking(());
566 let _ = hearbeat_abort_sender.send_blocking(());
567 let _ = sink_result_sender.send_blocking(
568 sink_future
569 .extract()
570 .await
571 .unwrap_or(Err(WSError::AlreadyClosed)),
572 );
573 });
574
575 Self {
576 sink,
577 stream,
578 sink_result_receiver: Some(sink_result_receiver),
579 }
580 }
581
582 pub async fn send(
583 &self,
584 message: InMessage,
585 ) -> Result<(), async_channel::SendError<InRawMessage>> {
586 self.sink.send(message).await
587 }
588
589 pub async fn send_raw(
590 &self,
591 message: InRawMessage,
592 ) -> Result<(), async_channel::SendError<InRawMessage>> {
593 self.sink.send_raw(message).await
594 }
595
596 pub async fn recv(&mut self) -> Option<Result<Message, WSError>> {
597 self.stream.recv().await
598 }
599
600 pub(crate) async fn await_sink_close(&mut self) -> Result<(), WSError> {
601 let Some(sink_result_receiver) = self.sink_result_receiver.take() else {
602 return Err(WSError::AlreadyClosed);
603 };
604 sink_result_receiver
605 .recv()
606 .await
607 .unwrap_or(Err(WSError::AlreadyClosed))
608 }
609}
610
611#[cfg(not(target_family = "wasm"))]
612async fn socket_heartbeat(
613 sink: Sink,
614 config: SocketConfig,
615 abort_receiver: async_channel::Receiver<()>,
616 last_alive: Arc<Mutex<Instant>>,
617) {
618 let sleep = tokio::time::sleep(config.heartbeat);
619 tokio::pin!(sleep);
620
621 loop {
622 tokio::select! {
623 _ = &mut sleep => {
624 let Some(next_sleep_duration) = handle_heartbeat_sleep_elapsed(&sink, &config, &last_alive).await else {
625 break;
626 };
627 sleep.as_mut().reset(tokio::time::Instant::now() + next_sleep_duration);
628 }
629 _ = abort_receiver.recv() => break,
630 }
631 }
632}
633
/// Heartbeat loop for wasm targets.
///
/// A fresh timer is armed on every iteration. When it fires, `handle_heartbeat_sleep_elapsed`
/// returns the next delay, or `None` to stop. The loop also stops as soon as
/// `abort_receiver` yields.
#[cfg(target_family = "wasm")]
async fn socket_heartbeat(
    sink: Sink,
    config: SocketConfig,
    abort_receiver: async_channel::Receiver<()>,
    last_alive: Arc<Mutex<Instant>>,
) {
    let mut delay = config.heartbeat;

    loop {
        let timer = wasmtimer::tokio::sleep(delay).fuse();
        futures::pin_mut!(timer);
        let next_delay = futures::select! {
            _ = timer => handle_heartbeat_sleep_elapsed(&sink, &config, &last_alive).await,
            _ = &mut abort_receiver.recv().fuse() => None,
        };
        let Some(next) = next_delay else {
            break;
        };
        delay = next;
    }
}
659
660async fn handle_heartbeat_sleep_elapsed(
661 sink: &Sink,
662 config: &SocketConfig,
663 last_alive: &Arc<Mutex<Instant>>,
664) -> Option<Duration> {
665 let elapsed_since_last_alive = last_alive.lock().await.elapsed();
667 if elapsed_since_last_alive > config.timeout {
668 tracing::info!("closing connection due to timeout");
669 let _ = sink
670 .send_raw(InRawMessage::new(RawMessage::Close(Some(CloseFrame {
671 code: CloseCode::Abnormal,
672 reason: "remote partner is inactive".into(),
673 }))))
674 .await;
675 return None;
676 } else if elapsed_since_last_alive < config.heartbeat {
677 return Some(config.heartbeat.saturating_sub(elapsed_since_last_alive));
680 }
681
682 let timestamp = SystemTime::now()
684 .duration_since(UNIX_EPOCH)
685 .unwrap_or_default();
686 if sink
687 .send_raw(InRawMessage::new((config.heartbeat_ping_msg_fn)(timestamp)))
688 .await
689 .is_err()
690 {
691 return None;
692 }
693
694 Some(config.heartbeat)
695}