use futures_util::StreamExt;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
use tokio_util::sync::CancellationToken;
/// Errors returned by [`WebSocketClient`] operations.
#[derive(Debug, thiserror::Error)]
pub enum WebSocketError {
    /// Establishing the WebSocket connection failed; carries the underlying error text.
    #[error("连接失败: {0}")]
    ConnectionFailed(String),
    /// Any other failure, carrying a description of what went wrong.
    #[error("其他错误: {0}")]
    Other(String),
}
/// Lifecycle event delivered to registered listeners.
#[derive(Debug, Clone)]
pub enum WsBaseEvent {
    /// The background reader started processing the connection.
    Open,
    /// The connection closed; carries the close reason when one was provided.
    Close(Option<String>),
    /// Reading from the connection failed; carries the error text.
    Error(String),
}
/// Key under which a listener is registered.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum WsEventType {
    /// Matches [`WsBaseEvent::Open`].
    Open,
    /// Matches [`WsBaseEvent::Close`].
    Close,
    /// Matches [`WsBaseEvent::Error`].
    Error,
    /// Catch-all: listeners registered here receive every emitted event.
    All,
}
/// Shared, thread-safe callback invoked with each emitted [`WsBaseEvent`].
///
/// Stored behind an `Arc` so the listener list can be snapshotted cheaply and each
/// callback moved into its own task.
pub type EventListener =
    Arc<dyn Fn(WsBaseEvent) + Send + Sync + 'static>;
/// Receives the payload of every incoming text frame.
///
/// The client calls this inline on its background reader task, so a slow
/// implementation delays processing of subsequent frames.
pub trait MessageHandler: Send + Sync {
    /// Handles one text message.
    fn handle_message(&self, msg: String);
}
/// A WebSocket client whose incoming frames are read on a background Tokio task.
pub struct WebSocketClient {
    /// Registered callbacks keyed by the event kind they subscribe to; shared with
    /// the background task, which reads it whenever an event is emitted.
    listeners: Arc<Mutex<HashMap<WsEventType, Vec<EventListener>>>>,
    /// Cancelling this token stops the background task.
    cancel_token: CancellationToken,
    /// Handle to the background task. It is not awaited or aborted directly; the
    /// task is stopped through `cancel_token`.
    _handle: tokio::task::JoinHandle<()>,
}
impl WebSocketClient {
pub async fn connect<H>(url: &str, message_handler: H) -> Result<Self, WebSocketError>
where
H: MessageHandler + 'static,
{
let listeners = Arc::new(Mutex::new(HashMap::<WsEventType, Vec<EventListener>>::new()));
let cancel_token = CancellationToken::new();
let (ws_stream, _) = connect_async(url)
.await
.map_err(|e| WebSocketError::ConnectionFailed(e.to_string()))?;
let (_write, mut read) = ws_stream.split();
let listeners_clone = listeners.clone();
let cancel = cancel_token.clone();
let handle = tokio::spawn(async move {
tokio::select! {
_ = cancel.cancelled() => {}
_ = async {
Self::emit_event(&listeners_clone, &WsEventType::Open, WsBaseEvent::Open).await;
while let Some(msg) = read.next().await {
match msg {
Ok(Message::Text(text)) => {
message_handler.handle_message(text.to_string());
}
Ok(Message::Close(frame)) => {
let reason = frame.map(|f| f.reason.to_string());
Self::emit_event(&listeners_clone, &WsEventType::Close, WsBaseEvent::Close(reason)).await;
break;
}
Err(e) => {
Self::emit_event(&listeners_clone, &WsEventType::Error, WsBaseEvent::Error(e.to_string())).await;
break;
}
_ => {}
}
}
} => {}
}
});
Ok(Self {
listeners,
cancel_token,
_handle: handle,
})
}
pub async fn add_listener<F>(&self, event: WsEventType, listener: F)
where
F: Fn(WsBaseEvent) + Send + Sync + 'static,
{
let mut listeners = self.listeners.lock().await;
listeners
.entry(event)
.or_insert_with(Vec::new)
.push(Arc::new(listener));
}
pub async fn on_open<F>(&self, listener: F)
where
F: Fn() + Send + Sync + 'static,
{
self.add_listener(WsEventType::Open, move |_| listener())
.await;
}
pub async fn on_close<F>(&self, listener: F)
where
F: Fn(Option<String>) + Send + Sync + 'static,
{
self.add_listener(WsEventType::Close, move |event| {
if let WsBaseEvent::Close(reason) = event {
listener(reason);
}
})
.await;
}
pub async fn on_error<F>(&self, listener: F)
where
F: Fn(String) + Send + Sync + 'static,
{
self.add_listener(WsEventType::Error, move |event| {
if let WsBaseEvent::Error(error) = event {
listener(error);
}
})
.await;
}
pub async fn remove_listener(&self, event: Option<WsEventType>) {
let mut listeners = self.listeners.lock().await;
match event {
Some(e) => {
listeners.remove(&e);
}
None => {
listeners.clear();
}
}
}
pub fn disconnect(&self) {
self.cancel_token.cancel();
}
async fn emit_event(
listeners: &Arc<Mutex<HashMap<WsEventType, Vec<EventListener>>>>,
event: &WsEventType,
data: WsBaseEvent,
) {
let event_listeners: Vec<EventListener> = {
let listeners_guard = listeners.lock().await;
listeners_guard.get(event).cloned().unwrap_or_default()
};
for listener in event_listeners {
let data = data.clone();
tokio::spawn(async move { listener(data) });
}
if *event != WsEventType::All {
let all_listeners: Vec<EventListener> = {
let listeners_guard = listeners.lock().await;
listeners_guard
.get(&WsEventType::All)
.cloned()
.unwrap_or_default()
};
for listener in all_listeners {
let data = data.clone();
tokio::spawn(async move { listener(data) });
}
}
}
}
impl Drop for WebSocketClient {
    /// Signals the background reader task to stop once the client is dropped.
    fn drop(&mut self) {
        // Same effect as an explicit `disconnect()`: cancel the shared token.
        self.disconnect();
    }
}