use futures_util::{FutureExt, StreamExt};
use salvo::Error;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use once_cell::sync::Lazy;
use tokio::sync::{mpsc, RwLock};
use tokio_stream::wrappers::UnboundedReceiverStream;
use salvo::extra::ws::{Message, WebSocket};
use salvo::prelude::*;
use tokio::sync::mpsc::UnboundedSender;
/// Registry of per-connection outbound channels, grouped by a group name.
///
/// Each group maps to the senders of the connections that joined it; sending on
/// one of those senders queues a frame for that connection.
pub struct WebSocketController {
    /// Group name -> senders of every connection that joined the group.
    pub caller_book: HashMap<String, Vec<UnboundedSender<Result<Message, Error>>>>,
}

impl WebSocketController {
    /// Creates a controller with no groups.
    pub fn new() -> Self {
        WebSocketController {
            caller_book: HashMap::new(),
        }
    }

    /// Sends a clone of `message` to every sender registered under `group`.
    ///
    /// Senders whose send fails (their receiving side has been dropped) are
    /// removed from the group in the same pass.
    ///
    /// # Errors
    /// Returns an error if no connection has ever joined `group`.
    pub fn send_group(&mut self, group: String, message: Message) -> Result<(), anyhow::Error> {
        let senders = self
            .caller_book
            .get_mut(group.as_str())
            .ok_or_else(|| anyhow::anyhow!("群组不存在"))?;
        // `retain` keeps exactly the senders whose send succeeded. Removing
        // collected indices one at a time would shift later elements, deleting
        // the wrong senders or panicking on an out-of-bounds index.
        senders.retain(|sender| sender.send(Ok(message.clone())).is_ok());
        Ok(())
    }

    /// Registers `sender` as a member of `group`, creating the group if needed.
    ///
    /// The newest member is placed at the front of the group's list.
    /// Always returns `Ok(())`.
    pub fn join_group(
        &mut self,
        group: String,
        sender: UnboundedSender<Result<Message, Error>>,
    ) -> Result<(), Error> {
        // Single lookup via the entry API instead of get_mut + insert.
        self.caller_book.entry(group).or_default().insert(0, sender);
        Ok(())
    }
}
impl Default for WebSocketController {
fn default() -> Self {
Self::new()
}
}
type Controller = RwLock<WebSocketController>;
static NEXT_WS_ID: AtomicUsize = AtomicUsize::new(1);
pub static WS_CONTROLLER: Lazy<Controller> = Lazy::new(Controller::default);
pub async fn handle_socket<T: WebSocketHandler + Send + Sync + 'static>(ws: WebSocket, _self: T) {
let ws_id = NEXT_WS_ID.fetch_add(1, Ordering::Relaxed);
tracing::info!("new ws connected: {}", ws_id);
let (ws_sender, mut ws_reader) = ws.split();
let (sender, reader) = mpsc::unbounded_channel();
let reader = UnboundedReceiverStream::new(reader);
let fut = reader.forward(ws_sender).map(|result| {
eprintln!("{:?}", result);
if let Err(e) = result {
tracing::error!(error = ?e, "websocket send error");
}
});
tokio::task::spawn(fut);
let fut = async move {
_self.on_connected(ws_id, sender).await;
while let Some(result) = ws_reader.next().await {
let msg = match result {
Ok(msg) => msg,
Err(e) => {
eprintln!("websocket error(uid={}): {}", ws_id, e);
break;
}
};
eprintln!("on_receive_message message(uid={}): {:?}", ws_id, msg);
_self.on_receive_message(msg).await;
}
_self.on_disconnected(ws_id).await;
};
tokio::task::spawn(fut);
}
#[async_trait]
pub trait WebSocketHandler {
async fn on_connected(&self, ws_id: usize, sender: UnboundedSender<Result<Message, Error>>);
async fn on_disconnected(&self, ws_id: usize);
async fn on_receive_message(&self, msg: Message);
async fn on_send_message(&self, msg: Message) -> Result<Message, Error>;
}
#[cfg(test)]
mod test {
    use super::WebSocketController;
    use salvo::extra::ws::Message;
    use tokio::sync::mpsc;

    /// A group broadcast reaches every live member.
    #[tokio::test]
    async fn send_group_delivers_to_all_members() {
        let mut controller = WebSocketController::new();
        let (tx1, mut rx1) = mpsc::unbounded_channel();
        let (tx2, mut rx2) = mpsc::unbounded_channel();
        controller.join_group("g".to_string(), tx1).unwrap();
        controller.join_group("g".to_string(), tx2).unwrap();

        controller
            .send_group("g".to_string(), Message::text("hi"))
            .unwrap();

        assert!(matches!(rx1.recv().await, Some(Ok(_))));
        assert!(matches!(rx2.recv().await, Some(Ok(_))));
        assert_eq!(controller.caller_book["g"].len(), 2);
    }

    /// Sending to a group nobody joined is an error.
    #[tokio::test]
    async fn send_group_unknown_group_is_error() {
        let mut controller = WebSocketController::new();
        assert!(controller
            .send_group("missing".to_string(), Message::text("hi"))
            .is_err());
    }

    /// A member whose receiver was dropped is pruned; the live member remains.
    #[tokio::test]
    async fn send_group_prunes_closed_sender() {
        let mut controller = WebSocketController::new();
        let (live_tx, mut live_rx) = mpsc::unbounded_channel();
        let (dead_tx, dead_rx) = mpsc::unbounded_channel();
        drop(dead_rx);
        controller.join_group("g".to_string(), live_tx).unwrap();
        controller.join_group("g".to_string(), dead_tx).unwrap();

        controller
            .send_group("g".to_string(), Message::text("hi"))
            .unwrap();

        assert_eq!(controller.caller_book["g"].len(), 1);
        assert!(matches!(live_rx.recv().await, Some(Ok(_))));
    }
}