use crate::Result;
use crate::log::{debug, error};
use crate::ws::message::Message;
use crate::ws::upgrade::WebSocketParts;
use crate::ws::websocket_handler::WebSocketHandler;
use anyhow::anyhow;
use async_channel::{Sender as UnboundedSender, unbounded as unbounded_channel};
use async_lock::RwLock;
use async_trait::async_trait;
use async_tungstenite::tungstenite::protocol;
use async_tungstenite::{WebSocketReceiver, WebSocketSender, WebSocketStream};
use futures::io::{AsyncRead, AsyncWrite};
use futures_util::ready;
use futures_util::stream::{Stream, StreamExt};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
/// An established WebSocket connection over the transport `S`.
///
/// Besides the framed tungstenite stream, the connection carries the
/// [`WebSocketParts`] captured during the upgrade, shared behind an async
/// `RwLock` so callbacks and tasks can inspect or update them concurrently.
pub struct WebSocket<S: AsyncRead + AsyncWrite + Unpin + Send + 'static> {
    /// Upgrade metadata, shared (and clonable) via `Arc`.
    parts: Arc<RwLock<WebSocketParts>>,
    /// The message-level stream wrapping the raw transport.
    upgrade: WebSocketStream<S>,
}
impl<S> WebSocket<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
#[inline]
pub(crate) async fn from_raw_socket(
upgraded: crate::ws::upgrade::Upgraded<S>,
role: protocol::Role,
config: Option<protocol::WebSocketConfig>,
) -> Self {
let (parts, upgraded) = upgraded.into_parts();
Self {
parts: Arc::new(RwLock::new(parts)),
upgrade: WebSocketStream::from_raw_socket(upgraded, role, config).await,
}
}
#[inline]
pub fn into_parts(self) -> (Arc<RwLock<WebSocketParts>>, Self) {
(self.parts.clone(), self)
}
pub async fn recv(&mut self) -> Option<Result<Message>> {
self.next().await
}
pub async fn send(&mut self, msg: Message) -> Result<()> {
self.upgrade
.send(msg.inner)
.await
.map_err(|e| anyhow!("send error: {}", e).into())
}
#[inline]
pub async fn close(mut self) -> Result<()> {
self.upgrade
.close(None)
.await
.map_err(|e| anyhow!("close error: {}", e).into())
}
}
impl<S> WebSocket<S>
where
    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    /// Consumes the socket and splits it into independent sender and
    /// receiver halves.
    ///
    /// This socket's handle to the shared parts is released; call
    /// `into_parts` first if the parts are still needed.
    #[inline]
    pub fn split(self) -> (WebSocketSender<S>, WebSocketReceiver<S>) {
        self.upgrade.split()
    }
}
/// Drives a connection using the callbacks stored in a [`WebSocketHandler`].
///
/// The type parameters come in callback/future pairs:
/// - `FnOnConnect` / `FnOnConnectFut`: receives the shared parts and an
///   outbound message sender.
/// - `FnOnSend` / `FnOnSendFut`: receives an outgoing message and the shared
///   parts, and resolves to the message to send.
/// - `FnOnReceive` / `FnOnReceiveFut`: receives an incoming message and the
///   shared parts.
/// - `FnOnClose` / `FnOnCloseFut`: receives the shared parts.
#[async_trait]
pub trait WebSocketHandlerTrait<
    FnOnConnect: Fn(Arc<RwLock<WebSocketParts>>, UnboundedSender<Message>) -> FnOnConnectFut
        + Send
        + Sync
        + 'static,
    FnOnConnectFut: Future<Output = Result<()>> + Send + 'static,
    FnOnSend: Fn(Message, Arc<RwLock<WebSocketParts>>) -> FnOnSendFut + Send + Sync + 'static,
    FnOnSendFut: Future<Output = Result<Message>> + Send + 'static,
    FnOnReceive: Fn(Message, Arc<RwLock<WebSocketParts>>) -> FnOnReceiveFut + Send + Sync + 'static,
    FnOnReceiveFut: Future<Output = Result<()>> + Send + 'static,
    FnOnClose: Fn(Arc<RwLock<WebSocketParts>>) -> FnOnCloseFut + Send + Sync + 'static,
    FnOnCloseFut: Future<Output = ()> + Send + 'static,
>
{
    /// Consumes the connection and runs it with the callbacks in `handler`.
    async fn handle(
        self,
        handler: Arc<
            WebSocketHandler<
                FnOnConnect,
                FnOnConnectFut,
                FnOnSend,
                FnOnSendFut,
                FnOnReceive,
                FnOnReceiveFut,
                FnOnClose,
                FnOnCloseFut,
            >,
        >,
    ) -> Result<()>;
}
#[async_trait]
impl<
FnOnConnect,
FnOnConnectFut,
FnOnSend,
FnOnSendFut,
FnOnReceive,
FnOnReceiveFut,
FnOnClose,
FnOnCloseFut,
S,
>
WebSocketHandlerTrait<
FnOnConnect,
FnOnConnectFut,
FnOnSend,
FnOnSendFut,
FnOnReceive,
FnOnReceiveFut,
FnOnClose,
FnOnCloseFut,
> for WebSocket<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
FnOnConnect: Fn(Arc<RwLock<WebSocketParts>>, UnboundedSender<Message>) -> FnOnConnectFut
+ Send
+ Sync
+ 'static,
FnOnConnectFut: Future<Output = Result<()>> + Send + 'static,
FnOnSend: Fn(Message, Arc<RwLock<WebSocketParts>>) -> FnOnSendFut + Send + Sync + 'static,
FnOnSendFut: Future<Output = Result<Message>> + Send + 'static,
FnOnReceive: Fn(Message, Arc<RwLock<WebSocketParts>>) -> FnOnReceiveFut + Send + Sync + 'static,
FnOnReceiveFut: Future<Output = Result<()>> + Send + 'static,
FnOnClose: Fn(Arc<RwLock<WebSocketParts>>) -> FnOnCloseFut + Send + Sync + 'static,
FnOnCloseFut: Future<Output = ()> + Send + 'static,
{
async fn handle(
self,
handler: Arc<
WebSocketHandler<
FnOnConnect,
FnOnConnectFut,
FnOnSend,
FnOnSendFut,
FnOnReceive,
FnOnReceiveFut,
FnOnClose,
FnOnCloseFut,
>,
>,
) -> Result<()> {
let on_connect = handler.on_connect.clone();
let on_send = handler.on_send.clone();
let on_receive = handler.on_receive.clone();
let on_close = handler.on_close.clone();
let (parts, ws) = self.into_parts();
let (mut ws_tx, mut ws_rx) = ws.split();
let (tx, rx) = unbounded_channel();
debug!("on_connect: {:?}", parts);
if let Some(on_connect) = on_connect {
on_connect(parts.clone(), tx.clone()).await?;
}
let sender_parts = parts.clone();
let receiver_parts = parts;
let fut = async move {
while let Ok(message) = rx.recv().await {
let message = if let Some(on_send) = on_send.clone() {
match on_send(message.clone(), sender_parts.clone()).await {
Ok(message) => message,
Err(e) => {
error!("websocket on_send error: {}", e);
continue;
}
}
} else {
message
};
debug!("send message: {:?}", message);
if let Err(e) = ws_tx.send(message.inner).await {
error!("websocket send error: {}", e);
break;
}
}
};
async_global_executor::spawn(fut).detach();
let fut = async move {
while let Some(message) = ws_rx.next().await {
if let Ok(message) = message {
if message.is_close() {
break;
}
debug!("receive message: {:?}", message);
if let Some(on_receive) = on_receive.clone()
&& on_receive(Message { inner: message }, receiver_parts.clone())
.await
.is_err()
{
break;
}
}
}
if let Some(on_close) = on_close {
on_close(receiver_parts).await;
}
};
async_global_executor::spawn(fut).detach();
Ok(())
}
}
impl<S> Stream for WebSocket<S>
where
S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
type Item = Result<Message>;
#[inline]
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match ready!(Pin::new(&mut self.upgrade).poll_next(cx)) {
Some(Ok(item)) => Poll::Ready(Some(Ok(Message { inner: item }))),
Some(Err(e)) => {
debug!("websocket poll error: {}", e);
Poll::Ready(Some(Err(anyhow!("websocket poll error: {}", e).into())))
}
None => {
debug!("websocket closed");
Poll::Ready(None)
}
}
}
}
#[cfg(test)]
mod tests {
    // Unit tests for the WebSocket wrapper's message helpers, channel plumbing,
    // and the type-level guarantees (Send/Sync/Stream) the handler relies on.
    use super::*;
    use async_channel::unbounded as unbounded_channel;
    use async_lock::RwLock;
    use async_tungstenite::tungstenite::protocol;
    use futures::FutureExt;
    use std::sync::Arc;

    #[test]
    fn test_message_creation() {
        let text_msg = Message::text("hello");
        let binary_msg = Message::binary(vec![1, 2, 3]);
        let close_msg = Message::close();
        assert!(text_msg.is_text());
        assert!(binary_msg.is_binary());
        assert!(close_msg.is_close());
    }

    #[test]
    fn test_message_cloning() {
        let msg = Message::text("test");
        let msg2 = msg.clone();
        assert_eq!(msg.to_str().unwrap(), msg2.to_str().unwrap());
    }

    #[test]
    fn test_channel_creation_and_clone() {
        let (tx, _rx) = unbounded_channel::<Message>();
        let _tx2 = tx.clone();
    }

    #[test]
    fn test_channel_send() {
        let (tx, rx) = unbounded_channel::<Message>();
        let msg = Message::text("test message");
        // An unbounded send completes immediately and queues the message.
        assert!(matches!(tx.send(msg).now_or_never(), Some(Ok(()))));
        assert_eq!(rx.len(), 1);
    }

    #[test]
    fn test_channel_close() {
        let (tx, rx) = unbounded_channel::<Message>();
        drop(tx);
        // Dropping the last sender closes the channel for the receiver.
        assert!(rx.is_closed());
    }

    #[test]
    fn test_empty_message() {
        let msg = Message::text("");
        assert_eq!(msg.to_str().unwrap(), "");
    }

    #[test]
    fn test_large_binary_message() {
        let large_data = vec![0u8; 1024 * 1024];
        let msg = Message::binary(large_data);
        assert!(msg.is_binary());
    }

    #[test]
    fn test_unicode_message() {
        let unicode_str = "你好世界 🌍";
        let msg = Message::text(unicode_str);
        assert_eq!(msg.to_str().unwrap(), unicode_str);
    }

    #[test]
    fn test_message_inner_field() {
        let msg = Message::text("test");
        let _inner = msg.inner;
    }

    #[test]
    fn test_websocket_send_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<Message>();
        assert_sync::<Message>();
        assert_send::<UnboundedSender<Message>>();
        assert_sync::<UnboundedSender<Message>>();
    }

    #[test]
    fn test_websocket_arc_rwlock() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<Arc<RwLock<Message>>>();
        assert_sync::<Arc<RwLock<Message>>>();
    }

    #[test]
    fn test_message_type_conversions() {
        let text = Message::text("hello");
        let binary = Message::binary(vec![1, 2, 3]);
        let ping = Message::ping(vec![1, 2, 3]);
        let pong = Message::pong(vec![1, 2, 3]);
        let close = Message::close();
        assert!(text.is_text());
        assert!(binary.is_binary());
        assert!(ping.is_ping());
        assert!(pong.is_pong());
        assert!(close.is_close());
    }

    #[test]
    fn test_message_serialization() {
        let msg = Message::text("test");
        let text = msg.to_str();
        assert!(text.is_ok());
        assert_eq!(text.unwrap(), "test");
        let binary_msg = Message::binary(vec![0xFF, 0xFE]);
        assert!(binary_msg.to_str().is_err());
    }

    #[test]
    fn test_protocol_role() {
        let server_role = protocol::Role::Server;
        let client_role = protocol::Role::Client;
        assert!(matches!(server_role, protocol::Role::Server));
        assert!(matches!(client_role, protocol::Role::Client));
    }

    #[test]
    fn test_websocket_config() {
        let config = protocol::WebSocketConfig::default();
        assert!(config.max_message_size.is_some());
        let mut custom_config = protocol::WebSocketConfig::default();
        custom_config.max_frame_size = Some(1024);
        custom_config.max_message_size = Some(1024 * 1024);
        custom_config.accept_unmasked_frames = false;
        assert_eq!(custom_config.max_frame_size, Some(1024));
        assert_eq!(custom_config.max_message_size, Some(1024 * 1024));
    }

    #[test]
    fn test_message_type_validation() {
        let text_msg = Message::text("hello");
        let binary_msg = Message::binary(vec![1, 2, 3]);
        assert!(text_msg.is_text() && !text_msg.is_binary());
        assert!(binary_msg.is_binary() && !binary_msg.is_text());
        assert!(!text_msg.is_close());
        assert!(!binary_msg.is_close());
    }

    #[test]
    fn test_message_size_operations() {
        let small_data = vec![1u8; 10];
        let msg = Message::binary(small_data);
        assert!(msg.is_binary());
        let binary_data = msg.into_bytes();
        assert_eq!(binary_data.len(), 10);
    }

    #[tokio::test]
    async fn test_async_channel_with_websocket() {
        let (tx, rx) = unbounded_channel::<Message>();
        let msg = Message::text("test message");
        tx.send(msg).await.unwrap();
        let received = rx.recv().await.unwrap();
        assert_eq!(received.to_str().unwrap(), "test message");
    }

    #[tokio::test]
    async fn test_multiple_senders() {
        let (tx, rx) = unbounded_channel::<Message>();
        let tx2 = tx.clone();
        tx.send(Message::text("from sender 1")).await.unwrap();
        tx2.send(Message::text("from sender 2")).await.unwrap();
        let msg1 = rx.recv().await.unwrap();
        let msg2 = rx.recv().await.unwrap();
        assert!(msg1.to_str().unwrap().contains("sender 1"));
        assert!(msg2.to_str().unwrap().contains("sender 2"));
    }

    #[test]
    fn test_message_from_bytes() {
        let data = b"hello world".to_vec();
        let msg = Message::binary(data);
        assert!(msg.is_binary());
        let bytes = msg.into_bytes();
        assert_eq!(bytes, b"hello world".to_vec());
    }

    #[test]
    fn test_message_ping_pong() {
        let ping_data = vec![1, 2, 3, 4];
        let ping_msg = Message::ping(ping_data.clone());
        assert!(ping_msg.is_ping());
        assert_eq!(ping_msg.into_bytes(), ping_data);
        let pong_data = vec![5, 6, 7, 8];
        let pong_msg = Message::pong(pong_data.clone());
        assert!(pong_msg.is_pong());
        assert_eq!(pong_msg.into_bytes(), pong_data);
    }

    #[test]
    fn test_message_close_with_code() {
        let close_msg = Message::close();
        assert!(close_msg.is_close());
    }

    #[tokio::test]
    async fn test_message_rwlock() {
        let msg = Arc::new(RwLock::new(Message::text("test message")));
        let handles: Vec<_> = (0..5)
            .map(|_| {
                let msg = msg.clone();
                tokio::spawn(async move {
                    let reader = msg.read().await;
                    let _ = reader.is_text();
                })
            })
            .collect();
        for handle in handles {
            handle.await.unwrap();
        }
        let writer = msg.write().await;
        assert!(writer.is_text());
    }

    #[tokio::test]
    async fn test_websocket_parts_type_validation() {
        use crate::ws::upgrade::WebSocketParts;
        use std::sync::Arc;
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<Arc<async_lock::RwLock<WebSocketParts>>>();
        assert_sync::<Arc<async_lock::RwLock<WebSocketParts>>>();
    }

    #[test]
    fn test_websocket_struct_size() {
        use std::mem::size_of;
        // The struct always holds at least the `Arc` to the shared parts.
        assert!(size_of::<WebSocket<futures::io::Cursor<Vec<u8>>>>() > 0);
    }

    #[tokio::test]
    async fn test_websocket_handler_components() {
        use crate::ws::WebSocketHandler;
        use std::future::Ready;
        type Handler = WebSocketHandler<
            fn(
                Arc<async_lock::RwLock<crate::ws::upgrade::WebSocketParts>>,
                UnboundedSender<Message>,
            ) -> Ready<Result<()>>,
            Ready<Result<()>>,
            fn(
                Message,
                Arc<async_lock::RwLock<crate::ws::upgrade::WebSocketParts>>,
            ) -> Ready<Result<Message>>,
            Ready<Result<Message>>,
            fn(
                Message,
                Arc<async_lock::RwLock<crate::ws::upgrade::WebSocketParts>>,
            ) -> Ready<Result<()>>,
            Ready<Result<()>>,
            fn(Arc<async_lock::RwLock<crate::ws::upgrade::WebSocketParts>>) -> Ready<()>,
            Ready<()>,
        >;
        let handler: Handler = WebSocketHandler::new();
        assert!(handler.on_connect.is_none());
        assert!(handler.on_send.is_none());
        assert!(handler.on_receive.is_none());
        assert!(handler.on_close.is_none());
    }

    #[tokio::test]
    async fn test_websocket_handler_arc_cloning() {
        use crate::ws::WebSocketHandler;
        use std::future::Ready;
        use std::sync::Arc;
        type Handler = WebSocketHandler<
            fn(
                Arc<async_lock::RwLock<crate::ws::upgrade::WebSocketParts>>,
                UnboundedSender<Message>,
            ) -> Ready<Result<()>>,
            Ready<Result<()>>,
            fn(
                Message,
                Arc<async_lock::RwLock<crate::ws::upgrade::WebSocketParts>>,
            ) -> Ready<Result<Message>>,
            Ready<Result<Message>>,
            fn(
                Message,
                Arc<async_lock::RwLock<crate::ws::upgrade::WebSocketParts>>,
            ) -> Ready<Result<()>>,
            Ready<Result<()>>,
            fn(Arc<async_lock::RwLock<crate::ws::upgrade::WebSocketParts>>) -> Ready<()>,
            Ready<()>,
        >;
        let handler: Arc<Handler> = Arc::new(WebSocketHandler::new());
        let _handler2 = handler.clone();
        assert_eq!(Arc::strong_count(&handler), 2);
    }

    #[test]
    fn test_websocket_stream_trait() {
        // Compile-time check: `WebSocket` really is a `Stream` of crate results.
        fn assert_stream<T: Stream<Item = Result<Message>>>() {}
        assert_stream::<WebSocket<futures::io::Cursor<Vec<u8>>>>();
    }

    #[test]
    fn test_websocket_message_close_detection() {
        let close_msg = Message::close();
        assert!(close_msg.is_close());
        assert!(!close_msg.is_text());
        assert!(!close_msg.is_binary());
        assert!(!close_msg.is_ping());
        assert!(!close_msg.is_pong());
    }

    #[test]
    fn test_websocket_message_ping_pong_detection() {
        let ping_msg = Message::ping(vec![1, 2, 3]);
        let pong_msg = Message::pong(vec![4, 5, 6]);
        assert!(ping_msg.is_ping());
        assert!(!ping_msg.is_pong());
        assert!(!ping_msg.is_text());
        assert!(!ping_msg.is_binary());
        assert!(pong_msg.is_pong());
        assert!(!pong_msg.is_ping());
        assert!(!pong_msg.is_text());
        assert!(!pong_msg.is_binary());
    }

    #[test]
    fn test_websocket_message_text_binary_distinction() {
        let text_msg = Message::text("hello");
        let binary_msg = Message::binary(vec![1, 2, 3]);
        assert!(text_msg.is_text());
        assert!(!text_msg.is_binary());
        assert_eq!(text_msg.to_str().unwrap(), "hello");
        assert!(binary_msg.is_binary());
        assert!(!binary_msg.is_text());
        assert_eq!(binary_msg.into_bytes(), vec![1, 2, 3]);
    }

    #[tokio::test]
    async fn test_websocket_channel_error_handling() {
        let (tx, rx) = unbounded_channel::<Message>();
        drop(rx);
        let send_result = tx.send(Message::text("test")).await;
        assert!(send_result.is_err());
    }

    #[tokio::test]
    async fn test_websocket_multiple_channel_receivers() {
        let (tx, rx) = unbounded_channel::<Message>();
        for i in 0..10 {
            tx.send(Message::text(format!("message {}", i)))
                .await
                .unwrap();
        }
        let mut count = 0;
        for _ in 0..10 {
            if rx.recv().await.is_ok() {
                count += 1;
            }
        }
        assert_eq!(count, 10);
    }

    #[test]
    fn test_websocket_empty_close_message() {
        let close_msg = Message::close();
        assert!(close_msg.is_close());
        let _bytes = close_msg.into_bytes();
    }

    #[test]
    fn test_websocket_large_ping_pong() {
        let large_data = vec![0u8; 1024];
        let ping_msg = Message::ping(large_data.clone());
        let pong_msg = Message::pong(large_data.clone());
        assert!(ping_msg.is_ping());
        assert_eq!(ping_msg.into_bytes(), large_data);
        assert!(pong_msg.is_pong());
        assert_eq!(pong_msg.into_bytes(), large_data);
    }

    #[tokio::test]
    async fn test_websocket_parts_with_rwlock() {
        use crate::ws::upgrade::WebSocketParts;
        use std::sync::Arc;
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}
        assert_send::<WebSocketParts>();
        assert_sync::<WebSocketParts>();
        assert_send::<Arc<async_lock::RwLock<WebSocketParts>>>();
        assert_sync::<Arc<async_lock::RwLock<WebSocketParts>>>();
    }

    #[test]
    fn test_websocket_message_content_validation() {
        let text_msg = Message::text("valid utf-8 你好");
        let binary_msg = Message::binary(vec![0x00, 0xFF, 0x7F]);
        assert!(text_msg.is_text());
        assert_eq!(text_msg.to_str().unwrap(), "valid utf-8 你好");
        assert!(binary_msg.is_binary());
        assert_eq!(binary_msg.into_bytes(), vec![0x00, 0xFF, 0x7F]);
    }

    #[tokio::test]
    async fn test_websocket_concurrent_message_access() {
        use std::sync::Arc;
        let msg = Arc::new(Message::text("concurrent test"));
        let mut handles = vec![];
        for _ in 0..10 {
            let msg = msg.clone();
            let handle = tokio::spawn(async move {
                let _ = msg.is_text();
                let _ = msg.to_str();
            });
            handles.push(handle);
        }
        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[test]
    fn test_websocket_config_combinations() {
        let mut config1 = protocol::WebSocketConfig::default();
        config1.max_frame_size = Some(512);
        config1.max_message_size = Some(512 * 1024);
        config1.accept_unmasked_frames = true;
        assert_eq!(config1.max_frame_size, Some(512));
        assert_eq!(config1.max_message_size, Some(524288));
        assert!(config1.accept_unmasked_frames);
        let mut config2 = protocol::WebSocketConfig::default();
        config2.max_frame_size = None;
        config2.max_message_size = None;
        config2.accept_unmasked_frames = false;
        assert_eq!(config2.max_frame_size, None);
        assert_eq!(config2.max_message_size, None);
        assert!(!config2.accept_unmasked_frames);
    }
}