use bytes::Bytes;
use futures::lock::Mutex;
use futures::{FutureExt, SinkExt, StreamExt, TryStreamExt};
use std::marker::PhantomData;
use std::sync::atomic::{AtomicU8, Ordering};
use std::sync::Arc;
use std::time::Duration;
use tokio_tungstenite_wasm::Error as WSError;
use tungstenite::Utf8Bytes;
#[cfg(not(target_family = "wasm"))]
use std::time::{Instant, SystemTime, UNIX_EPOCH};
#[cfg(target_family = "wasm")]
use wasmtimer::std::{Instant, SystemTime, UNIX_EPOCH};
/// Callback that builds the message sent as a heartbeat ping.
///
/// It receives a [`Duration`] and returns the [`RawMessage`] to send. The heartbeat
/// code in this file passes the time elapsed since `UNIX_EPOCH`.
pub trait SocketHeartbeatPingFn: Fn(Duration) -> RawMessage + Sync + Send {}

/// Every `Fn(Duration) -> RawMessage + Sync + Send` is a [`SocketHeartbeatPingFn`].
impl<F> SocketHeartbeatPingFn for F where F: Fn(Duration) -> RawMessage + Sync + Send {}

/// Trait-object form of [`SocketHeartbeatPingFn`].
pub type SocketHeartbeatPingFnT = dyn SocketHeartbeatPingFn<Output = RawMessage>;

impl std::fmt::Debug for SocketHeartbeatPingFnT {
    /// Writes a fixed placeholder.
    ///
    /// The callback itself cannot be inspected, but writing nothing at all would leave
    /// an empty value in the derived `Debug` output of any struct holding it.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SocketHeartbeatPingFn")
    }
}
/// Heartbeat and inactivity-timeout settings used by the socket's heartbeat task.
#[derive(Debug, Clone)]
pub struct SocketConfig {
/// Initial delay of the heartbeat timer and the interval used after a ping is sent.
/// A ping is only sent when at least this long has passed since the stream last
/// yielded an item.
pub heartbeat: Duration,
/// When the heartbeat timer fires and more than this has passed since the stream last
/// yielded an item, a close message is queued and the heartbeat task stops.
pub timeout: Duration,
/// Builds the ping message; called with the current time expressed as a duration
/// since `UNIX_EPOCH`.
pub heartbeat_ping_msg_fn: Arc<dyn SocketHeartbeatPingFn>,
}
impl Default for SocketConfig {
    /// Builds the default configuration: a 5 second heartbeat, a 10 second timeout, and
    /// a ping whose payload is the timestamp in milliseconds encoded as the 16
    /// big-endian bytes of a `u128`.
    fn default() -> Self {
        let default_ping = |timestamp: Duration| {
            let millis_be = timestamp.as_millis().to_be_bytes();
            RawMessage::Ping(millis_be.to_vec().into())
        };

        Self {
            heartbeat: Duration::from_secs(5),
            timeout: Duration::from_secs(10),
            heartbeat_ping_msg_fn: Arc::new(default_ping),
        }
    }
}
/// Status code carried by a close message.
///
/// Each variant converts to and from the `u16` noted on it. The variant names appear
/// to follow the close status codes of RFC 6455 section 7.4.1; consult the RFC for
/// the meaning of each code.
#[derive(Debug, Clone)]
pub enum CloseCode {
    /// Code 1000.
    Normal,
    /// Code 1001.
    Away,
    /// Code 1002.
    Protocol,
    /// Code 1003.
    Unsupported,
    /// Code 1005.
    Status,
    /// Code 1006.
    Abnormal,
    /// Code 1007.
    Invalid,
    /// Code 1008.
    Policy,
    /// Code 1009.
    Size,
    /// Code 1010.
    Extension,
    /// Code 1011.
    Error,
    /// Code 1012.
    Restart,
    /// Code 1013.
    Again,
    #[doc(hidden)]
    Tls,
    #[doc(hidden)]
    Reserved(u16),
    #[doc(hidden)]
    Iana(u16),
    #[doc(hidden)]
    Library(u16),
    #[doc(hidden)]
    Bad(u16),
}

impl From<CloseCode> for u16 {
    /// Returns the numeric status code for `code`; variants that carry a raw value
    /// return it unchanged.
    fn from(code: CloseCode) -> u16 {
        match code {
            CloseCode::Normal => 1000,
            CloseCode::Away => 1001,
            CloseCode::Protocol => 1002,
            CloseCode::Unsupported => 1003,
            CloseCode::Status => 1005,
            CloseCode::Abnormal => 1006,
            CloseCode::Invalid => 1007,
            CloseCode::Policy => 1008,
            CloseCode::Size => 1009,
            CloseCode::Extension => 1010,
            CloseCode::Error => 1011,
            CloseCode::Restart => 1012,
            CloseCode::Again => 1013,
            CloseCode::Tls => 1015,
            CloseCode::Reserved(raw)
            | CloseCode::Iana(raw)
            | CloseCode::Library(raw)
            | CloseCode::Bad(raw) => raw,
        }
    }
}

impl From<u16> for CloseCode {
    /// Classifies a numeric status code.
    ///
    /// Named codes map to their variant, `1016..=2999` to `Reserved`, `3000..=3999` to
    /// `Iana`, `4000..=4999` to `Library`, and everything else to `Bad`.
    fn from(code: u16) -> Self {
        match code {
            1000 => Self::Normal,
            1001 => Self::Away,
            1002 => Self::Protocol,
            1003 => Self::Unsupported,
            1005 => Self::Status,
            1006 => Self::Abnormal,
            1007 => Self::Invalid,
            1008 => Self::Policy,
            1009 => Self::Size,
            1010 => Self::Extension,
            1011 => Self::Error,
            1012 => Self::Restart,
            1013 => Self::Again,
            1015 => Self::Tls,
            1016..=2999 => Self::Reserved(code),
            3000..=3999 => Self::Iana(code),
            4000..=4999 => Self::Library(code),
            // Covers 0..=999, the unnamed 1004 and 1014, and 5000 upwards.
            _ => Self::Bad(code),
        }
    }
}
/// Payload of a close message: a status code together with a textual reason.
#[derive(Debug, Clone)]
pub struct CloseFrame {
/// Status code of the close.
pub code: CloseCode,
/// Reason text accompanying the code.
pub reason: Utf8Bytes,
}
/// Message kinds exposed to users of the socket.
///
/// There are no ping/pong variants here: the stream actor in this file consumes
/// those instead of forwarding them.
#[derive(Debug, Clone)]
pub enum Message {
/// Text payload.
Text(Utf8Bytes),
/// Binary payload.
Binary(Bytes),
/// Close message with an optional frame.
Close(Option<CloseFrame>),
}
/// Message kinds exchanged with the underlying socket, including ping and pong.
#[derive(Debug, Clone)]
pub enum RawMessage {
/// Text payload.
Text(Utf8Bytes),
/// Binary payload.
Binary(Bytes),
/// Ping with its payload bytes.
Ping(Bytes),
/// Pong with its payload bytes.
Pong(Bytes),
/// Close message with an optional frame.
Close(Option<CloseFrame>),
}
impl From<Message> for RawMessage {
fn from(message: Message) -> Self {
match message {
Message::Text(text) => Self::Text(text),
Message::Binary(bytes) => Self::Binary(bytes),
Message::Close(frame) => Self::Close(frame),
}
}
}
/// Delivery state of an outgoing message.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum MessageStatus {
    Sending,
    Sent,
    Failed,
}

/// Cloneable handle to a [`MessageStatus`] stored in a shared atomic.
///
/// Clones share the same atomic, so a status set through one handle is visible
/// through all of them.
#[derive(Debug, Clone)]
pub struct MessageSignal {
    signal: Arc<AtomicU8>,
}

impl MessageSignal {
    /// Creates a signal that starts out in `status`.
    pub fn new(status: MessageStatus) -> Self {
        Self {
            signal: Arc::new(AtomicU8::new(Self::encode(status))),
        }
    }

    /// Reads the current status.
    pub fn status(&self) -> MessageStatus {
        Self::decode(self.signal.load(Ordering::Acquire))
    }

    /// Overwrites the current status.
    pub(crate) fn set(&self, status: MessageStatus) {
        self.signal.store(Self::encode(status), Ordering::Release);
    }

    /// Maps a status to the byte kept in the atomic.
    fn encode(status: MessageStatus) -> u8 {
        match status {
            MessageStatus::Sending => 0,
            MessageStatus::Sent => 1,
            MessageStatus::Failed => 2,
        }
    }

    /// Maps a stored byte back to a status; any byte other than 0 or 1 reads as `Failed`.
    fn decode(raw: u8) -> MessageStatus {
        match raw {
            0 => MessageStatus::Sending,
            1 => MessageStatus::Sent,
            _ => MessageStatus::Failed,
        }
    }
}

impl Default for MessageSignal {
    /// A fresh signal in the `Sending` state.
    fn default() -> Self {
        Self::new(MessageStatus::Sending)
    }
}
/// An outgoing [`RawMessage`] paired with the [`MessageSignal`] that reports its
/// delivery status.
///
/// Dropping a value whose signal has not been resolved yet marks the signal `Failed`.
///
/// NOTE(review): `Clone` copies the signal handle, which shares its state with the
/// original; dropping a clone therefore also marks that shared signal as failed.
/// Confirm callers never clone a message that is still pending.
#[derive(Debug, Clone)]
pub struct InRawMessage {
    message: Option<RawMessage>,
    signal: Option<MessageSignal>,
}

impl InRawMessage {
    /// Wraps `message` together with a fresh signal in the `Sending` state.
    pub fn new(message: RawMessage) -> Self {
        let signal = MessageSignal::default();
        Self {
            message: Some(message),
            signal: Some(signal),
        }
    }

    /// Moves the message out, leaving `None` behind.
    pub(crate) fn take_message(&mut self) -> Option<RawMessage> {
        std::mem::take(&mut self.message)
    }

    /// Publishes `state` through the signal and detaches it, so later calls
    /// (including the one made on drop) do nothing.
    pub(crate) fn set_signal(&mut self, state: MessageStatus) {
        if let Some(signal) = self.signal.take() {
            signal.set(state);
        }
    }
}

impl Drop for InRawMessage {
    /// Reports failure if no status was published before the value went away.
    fn drop(&mut self) {
        self.set_signal(MessageStatus::Failed);
    }
}
/// An outgoing [`Message`] paired with the [`MessageSignal`] that reports its
/// delivery status.
///
/// Dropping a value that still owns its signal marks the signal `Failed`.
#[derive(Debug, Clone)]
pub struct InMessage {
    pub(crate) message: Option<Message>,
    signal: Option<MessageSignal>,
}

impl InMessage {
    /// Wraps `message` together with a fresh signal in the `Sending` state.
    pub fn new(message: Message) -> Self {
        let signal = MessageSignal::default();
        Self {
            message: Some(message),
            signal: Some(signal),
        }
    }

    /// Returns another handle to this message's signal, if it is still attached.
    pub fn clone_signal(&self) -> Option<MessageSignal> {
        self.signal.as_ref().cloned()
    }
}

impl From<InMessage> for InRawMessage {
    /// Moves the message (converted to its raw form) and the signal into the new
    /// value. The source is left without a signal, so dropping it reports nothing.
    fn from(mut inmessage: InMessage) -> Self {
        let message = inmessage.message.take().map(Into::into);
        let signal = inmessage.signal.take();
        Self { message, signal }
    }
}

impl Drop for InMessage {
    /// Reports failure through the signal when it is still attached.
    fn drop(&mut self) {
        if let Some(signal) = self.signal.take() {
            signal.set(MessageStatus::Failed);
        }
    }
}
/// State of the task that writes queued messages into the underlying sink.
#[derive(Debug)]
struct SinkActor<M, S>
where
M: From<RawMessage>,
S: SinkExt<M, Error = WSError> + Unpin,
{
/// Queue of messages waiting to be written.
receiver: async_channel::Receiver<InRawMessage>,
/// Abort channel; `run` stops as soon as a receive on it completes.
abort_receiver: async_channel::Receiver<()>,
/// Destination the converted messages are sent into.
sink: S,
/// Records the sink's message type `M` without storing a value of it.
phantom: PhantomData<M>,
}
impl<M, S> SinkActor<M, S>
where
    M: From<RawMessage>,
    S: SinkExt<M, Error = WSError> + Unpin,
{
    /// Forwards queued messages into the sink until told to stop.
    ///
    /// Returns `Ok(())` when the message queue is closed or when a receive on the
    /// abort channel completes. Returns the sink's error on the first failed send.
    /// Each message's signal is set to `Sent` or `Failed` according to the outcome.
    async fn run(&mut self) -> Result<(), WSError> {
        loop {
            futures::select! {
                received = self.receiver.recv().fuse() => {
                    let mut inmessage = match received {
                        Ok(inmessage) => inmessage,
                        Err(_) => break,
                    };
                    // An entry whose message was already taken has nothing to send.
                    let message = match inmessage.take_message() {
                        Some(message) => message,
                        None => continue,
                    };
                    tracing::trace!("sending message: {:?}", message);
                    if let Err(err) = self.sink.send(M::from(message)).await {
                        inmessage.set_signal(MessageStatus::Failed);
                        tracing::warn!(?err, "sink send failed");
                        return Err(err);
                    }
                    inmessage.set_signal(MessageStatus::Sent);
                },
                _ = &mut self.abort_receiver.recv().fuse() => {
                    break;
                },
            }
        }
        Ok(())
    }
}
/// Cloneable handle for queueing outgoing messages; a task spawned in `Sink::new`
/// takes them from the queue and writes them to the underlying sink.
#[derive(Debug, Clone)]
pub struct Sink {
/// Sending half of the queue read by the sink task.
sender: async_channel::Sender<InRawMessage>,
}
impl Sink {
fn new<M, S>(
sink: S,
abort_receiver: async_channel::Receiver<()>,
handle: impl enfync::Handle,
) -> (enfync::PendingResult<Result<(), WSError>>, Self)
where
M: From<RawMessage> + Send + 'static,
S: SinkExt<M, Error = WSError> + Unpin + Send + 'static,
{
let (sender, receiver) = async_channel::unbounded();
let mut actor = SinkActor {
receiver,
abort_receiver,
sink,
phantom: Default::default(),
};
let future = handle.spawn(async move { actor.run().await });
(future, Self { sender })
}
pub fn is_closed(&self) -> bool {
self.sender.is_closed()
}
pub async fn send(
&self,
inmessage: InMessage,
) -> Result<(), async_channel::SendError<InRawMessage>> {
self.sender.send(inmessage.into()).await
}
pub(crate) async fn send_raw(
&self,
inmessage: InRawMessage,
) -> Result<(), async_channel::SendError<InRawMessage>> {
self.sender.send(inmessage).await
}
}
/// State of the task that reads the underlying stream and forwards messages to the
/// `Stream` handle.
#[derive(Debug)]
struct StreamActor<M, S>
where
M: Into<RawMessage>,
S: StreamExt<Item = Result<M, WSError>> + Unpin,
{
/// Channel into which text, binary and close messages, and errors, are forwarded.
sender: async_channel::Sender<Result<Message, WSError>>,
/// Source of incoming items.
stream: S,
/// Set to the current instant every time the stream yields an item.
last_alive: Arc<Mutex<Instant>>,
}
impl<M, S> StreamActor<M, S>
where
    M: Into<RawMessage>,
    S: StreamExt<Item = Result<M, WSError>> + Unpin,
{
    /// Reads the stream until it ends or the receiving side of the channel is gone.
    ///
    /// Every item, including errors, refreshes `last_alive`. Text, binary and close
    /// messages, as well as errors, are forwarded through `sender`. Pings are dropped.
    /// Pongs are only used to trace the round-trip latency.
    async fn run(mut self) {
        while let Some(result) = self.stream.next().await {
            let result = result.map(M::into);
            tracing::trace!("received message: {:?}", result);
            *self.last_alive.lock().await = Instant::now();

            let mut closing = false;
            let message = match result {
                Ok(message) => Ok(match message {
                    RawMessage::Text(text) => Message::Text(text),
                    RawMessage::Binary(bytes) => Message::Binary(bytes),
                    RawMessage::Ping(_bytes) => continue,
                    RawMessage::Pong(bytes) => {
                        if let Some(latency) = Self::pong_latency(&bytes) {
                            tracing::trace!("latency: {}ms", latency.as_millis());
                        }
                        continue;
                    }
                    RawMessage::Close(frame) => {
                        closing = true;
                        Message::Close(frame)
                    }
                }),
                Err(err) => Err(err),
            };

            if self.sender.send(message).await.is_err() {
                if closing {
                    tracing::trace!("stream is closed");
                } else {
                    tracing::warn!("failed to forward message, stream is disconnected");
                }
                break;
            };
        }
    }

    /// Computes the time elapsed since the timestamp carried in a pong payload.
    ///
    /// The payload is expected to hold milliseconds since `UNIX_EPOCH` as the 16
    /// big-endian bytes of a `u128`. Returns `None` when the payload is not 16 bytes
    /// long or the value does not fit in a `u64`.
    ///
    /// The payload comes from the peer, so it is handled defensively: the conversion
    /// is checked rather than truncated with `as`, and the subtraction saturates at
    /// zero instead of adding an arbitrary duration to a `SystemTime`, which may panic
    /// when the sum is not representable.
    fn pong_latency(payload: &[u8]) -> Option<Duration> {
        let millis_be: [u8; 16] = payload.try_into().ok()?;
        let sent_millis = u64::try_from(u128::from_be_bytes(millis_be)).ok()?;
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default();
        Some(now.saturating_sub(Duration::from_millis(sent_millis)))
    }
}
/// Receiving handle for messages forwarded by the stream task spawned in `Stream::new`.
#[derive(Debug)]
pub struct Stream {
/// Receiving half of the channel filled by the stream task.
receiver: async_channel::Receiver<Result<Message, WSError>>,
}
impl Stream {
    /// Spawns the task that reads `stream` and forwards its messages.
    ///
    /// Returns the pending result of that task and the handle used to receive from it.
    /// `last_alive` is handed to the task, which updates it for every item read.
    fn new<M, S>(
        stream: S,
        last_alive: Arc<Mutex<Instant>>,
        handle: impl enfync::Handle,
    ) -> (enfync::PendingResult<()>, Self)
    where
        M: Into<RawMessage> + std::fmt::Debug + Send + 'static,
        S: StreamExt<Item = Result<M, WSError>> + Unpin + Send + 'static,
    {
        let (sender, receiver) = async_channel::unbounded();
        let actor = StreamActor {
            sender,
            stream,
            last_alive,
        };
        let pending = handle.spawn(actor.run());
        (pending, Self { receiver })
    }

    /// Waits for the next forwarded item; `None` once receiving from the channel fails.
    pub async fn recv(&mut self) -> Option<Result<Message, WSError>> {
        match self.receiver.recv().await {
            Ok(item) => Some(item),
            Err(_) => None,
        }
    }
}
/// A socket split into a sending handle and a receiving handle, backed by tasks
/// spawned in `Socket::new`.
#[derive(Debug)]
pub struct Socket {
/// Handle for queueing outgoing messages.
pub sink: Sink,
/// Handle for receiving incoming messages.
pub stream: Stream,
/// Receives the sink task's final result; taken by `await_sink_close`, so it is
/// `None` after that method has been called once.
sink_result_receiver: Option<async_channel::Receiver<Result<(), WSError>>>,
}
impl Socket {
pub fn new<M, E, S>(socket: S, config: SocketConfig, handle: impl enfync::Handle) -> Self
where
M: Into<RawMessage> + From<RawMessage> + std::fmt::Debug + Send + 'static,
E: Into<WSError> + std::error::Error,
S: SinkExt<M, Error = E> + Unpin + StreamExt<Item = Result<M, E>> + Unpin + Send + 'static,
{
let last_alive = Instant::now();
let last_alive = Arc::new(Mutex::new(last_alive));
let (sink, stream) = socket.sink_err_into().err_into().split();
let (sink_abort_sender, sink_abort_receiver) = async_channel::bounded(1usize);
let ((mut sink_future, sink), (mut stream_future, stream)) = (
Sink::new(sink, sink_abort_receiver, handle.clone()),
Stream::new(stream, last_alive.clone(), handle.clone()),
);
let (hearbeat_abort_sender, hearbeat_abort_receiver) = async_channel::bounded(1usize);
let sink_clone = sink.clone();
handle.spawn(async move {
socket_heartbeat(sink_clone, config, hearbeat_abort_receiver, last_alive).await
});
let (sink_result_sender, sink_result_receiver) = async_channel::bounded(1usize);
handle.spawn(async move {
let _ = stream_future.extract().await;
let _ = sink_abort_sender.send_blocking(());
let _ = hearbeat_abort_sender.send_blocking(());
let _ = sink_result_sender.send_blocking(
sink_future
.extract()
.await
.unwrap_or(Err(WSError::AlreadyClosed)),
);
});
Self {
sink,
stream,
sink_result_receiver: Some(sink_result_receiver),
}
}
pub async fn send(
&self,
message: InMessage,
) -> Result<(), async_channel::SendError<InRawMessage>> {
self.sink.send(message).await
}
pub async fn send_raw(
&self,
message: InRawMessage,
) -> Result<(), async_channel::SendError<InRawMessage>> {
self.sink.send_raw(message).await
}
pub async fn recv(&mut self) -> Option<Result<Message, WSError>> {
self.stream.recv().await
}
pub(crate) async fn await_sink_close(&mut self) -> Result<(), WSError> {
let Some(sink_result_receiver) = self.sink_result_receiver.take() else {
return Err(WSError::AlreadyClosed);
};
sink_result_receiver
.recv()
.await
.unwrap_or(Err(WSError::AlreadyClosed))
}
}
/// Heartbeat loop for non-wasm targets.
///
/// Sleeps for `config.heartbeat`, then lets `handle_heartbeat_sleep_elapsed` decide
/// what to do and how long to sleep next. The loop ends when that function returns
/// `None` or when a receive on `abort_receiver` completes.
#[cfg(not(target_family = "wasm"))]
async fn socket_heartbeat(
    sink: Sink,
    config: SocketConfig,
    abort_receiver: async_channel::Receiver<()>,
    last_alive: Arc<Mutex<Instant>>,
) {
    let sleep = tokio::time::sleep(config.heartbeat);
    tokio::pin!(sleep);

    loop {
        tokio::select! {
            _ = &mut sleep => {
                match handle_heartbeat_sleep_elapsed(&sink, &config, &last_alive).await {
                    Some(next_sleep_duration) => {
                        // Re-arm the same pinned timer rather than creating a new one.
                        let deadline = tokio::time::Instant::now() + next_sleep_duration;
                        sleep.as_mut().reset(deadline);
                    }
                    None => break,
                }
            }
            _ = abort_receiver.recv() => break,
        }
    }
}
/// Heartbeat loop for wasm targets.
///
/// Creates a fresh timer on every iteration, starting with `config.heartbeat`, then
/// lets `handle_heartbeat_sleep_elapsed` decide what to do and how long to sleep next.
/// The loop ends when that function returns `None` or when a receive on
/// `abort_receiver` completes.
#[cfg(target_family = "wasm")]
async fn socket_heartbeat(
    sink: Sink,
    config: SocketConfig,
    abort_receiver: async_channel::Receiver<()>,
    last_alive: Arc<Mutex<Instant>>,
) {
    let mut sleep_duration = config.heartbeat;

    loop {
        let sleep = wasmtimer::tokio::sleep(sleep_duration).fuse();
        futures::pin_mut!(sleep);

        futures::select! {
            _ = sleep => {
                match handle_heartbeat_sleep_elapsed(&sink, &config, &last_alive).await {
                    Some(next_sleep_duration) => sleep_duration = next_sleep_duration,
                    None => break,
                }
            }
            _ = &mut abort_receiver.recv().fuse() => break,
        }
    }
}
/// Decides what the heartbeat does once its timer has fired.
///
/// - More than `config.timeout` since `last_alive`: queue a close message and return
///   `None` so the heartbeat stops.
/// - Less than `config.heartbeat` since `last_alive`: send nothing and return the
///   remaining time until a full heartbeat interval has passed.
/// - Otherwise: queue a ping built by `config.heartbeat_ping_msg_fn` and return
///   `config.heartbeat`, or `None` when the ping cannot be queued.
///
/// NOTE(review): the close message uses `CloseCode::Abnormal` (1006). RFC 6455 appears
/// to reserve that code for local use rather than for close frames on the wire;
/// confirm peers accept it.
async fn handle_heartbeat_sleep_elapsed(
    sink: &Sink,
    config: &SocketConfig,
    last_alive: &Arc<Mutex<Instant>>,
) -> Option<Duration> {
    let idle = last_alive.lock().await.elapsed();

    if idle > config.timeout {
        tracing::info!("closing connection due to timeout");
        let close_frame = CloseFrame {
            code: CloseCode::Abnormal,
            reason: "remote partner is inactive".into(),
        };
        // Best effort: the heartbeat stops whether or not the close could be queued.
        let _ = sink
            .send_raw(InRawMessage::new(RawMessage::Close(Some(close_frame))))
            .await;
        return None;
    }

    if idle < config.heartbeat {
        return Some(config.heartbeat.saturating_sub(idle));
    }

    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap_or_default();
    let ping = InRawMessage::new((config.heartbeat_ping_msg_fn)(timestamp));
    match sink.send_raw(ping).await {
        Ok(()) => Some(config.heartbeat),
        Err(_) => None,
    }
}