use crate::common::error::Result;
use crate::common::protocol::Frame;
use crate::server::connection::ConnectionManagerTrait;
use async_trait::async_trait;
use std::sync::Arc;
/// Operations that application code can perform against the running server:
/// addressing frames to connections or users, broadcasting, disconnecting,
/// and reading current connection/user counts.
#[async_trait]
pub trait ServerHandle: Send + Sync {
    /// Sends `frame` to the connection identified by `connection_id`.
    async fn send_to(&self, connection_id: &str, frame: &Frame) -> Result<()>;

    /// Sends `frame` to the user identified by `user_id`.
    async fn send_to_user(&self, user_id: &str, frame: &Frame) -> Result<()>;

    /// Sends `frame` to every connection.
    async fn broadcast(&self, frame: &Frame) -> Result<()>;

    /// Sends `frame` to every connection except `exclude_connection_id`.
    async fn broadcast_except(&self, frame: &Frame, exclude_connection_id: &str) -> Result<()>;

    /// Disconnects the connection identified by `connection_id`.
    async fn disconnect(&self, connection_id: &str) -> Result<()>;

    /// Returns the current number of connections (synchronous).
    fn connection_count(&self) -> usize;

    /// Returns the current number of users (synchronous).
    fn user_count(&self) -> usize;
}
/// Default [`ServerHandle`] implementation backed by a shared connection manager.
pub struct DefaultServerHandle {
    /// Manager that every handle operation is forwarded to.
    connection_manager: Arc<dyn ConnectionManagerTrait>,
}
impl DefaultServerHandle {
    /// Creates a handle that forwards every operation to `connection_manager`.
    pub fn new(connection_manager: Arc<dyn ConnectionManagerTrait>) -> Self {
        Self { connection_manager }
    }

    /// Returns the wrapped manager as the concrete `ConnectionManager`, if it is one.
    ///
    /// The concrete type offers counters that can be read without awaiting,
    /// which lets the count methods avoid blocking on an async call.
    fn concrete_manager(&self) -> Option<&crate::server::connection::ConnectionManager> {
        self.connection_manager
            .as_any()
            .downcast_ref::<crate::server::connection::ConnectionManager>()
    }

    /// Drives `future` to completion from synchronous code.
    ///
    /// If a Tokio runtime handle is available, the current worker is moved off
    /// the scheduler with `block_in_place` and the future is run on that handle.
    /// Otherwise a throwaway single-threaded runtime is built just for this call.
    /// A full multi-threaded pool would spawn worker threads only to drive one
    /// small future.
    ///
    /// # Panics
    ///
    /// - Panics if called from inside a current-thread Tokio runtime, because
    ///   `block_in_place` is only permitted on the multi-threaded runtime.
    /// - Panics if the fallback runtime cannot be built.
    fn block_on_sync<F: std::future::Future>(future: F) -> F::Output {
        tokio::task::block_in_place(move || match tokio::runtime::Handle::try_current() {
            Ok(handle) => handle.block_on(future),
            Err(_) => tokio::runtime::Builder::new_current_thread()
                .enable_all()
                .build()
                .expect("failed to build fallback Tokio runtime for a synchronous count")
                .block_on(future),
        })
    }

    /// Counts distinct user ids across all listed connections, using only the
    /// `ConnectionManagerTrait` API.
    ///
    /// An id is skipped if `get_connection` returns `None` for it (e.g. it was
    /// removed after listing) or if its info carries no user id.
    async fn distinct_user_count(&self) -> usize {
        let mut users = std::collections::HashSet::new();
        for conn_id in self.connection_manager.list_connections().await {
            if let Some((_, info)) = self.connection_manager.get_connection(&conn_id).await {
                if let Some(user_id) = info.user_id {
                    users.insert(user_id);
                }
            }
        }
        users.len()
    }

    /// Returns the number of connections.
    ///
    /// Reads the counter directly when the manager is the concrete
    /// `ConnectionManager`. Otherwise it blocks on the async trait method.
    ///
    /// # Panics
    ///
    /// On the blocking path, panics under the same conditions as `block_on_sync`
    /// (notably inside a current-thread Tokio runtime).
    fn get_connection_count(&self) -> usize {
        match self.concrete_manager() {
            Some(manager) => manager.connection_count(),
            None => Self::block_on_sync(self.connection_manager.connection_count()),
        }
    }

    /// Returns the number of distinct users.
    ///
    /// Uses `stats().total_users` when the manager is the concrete
    /// `ConnectionManager`. Otherwise it blocks on `distinct_user_count`, which
    /// performs one lookup per listed connection.
    ///
    /// # Panics
    ///
    /// On the blocking path, panics under the same conditions as `block_on_sync`
    /// (notably inside a current-thread Tokio runtime).
    fn get_user_count(&self) -> usize {
        match self.concrete_manager() {
            Some(manager) => manager.stats().total_users,
            None => Self::block_on_sync(self.distinct_user_count()),
        }
    }
}
#[async_trait]
impl ServerHandle for DefaultServerHandle {
async fn send_to(&self, connection_id: &str, frame: &Frame) -> Result<()> {
self.connection_manager
.send_frame_to(connection_id, frame, None)
.await
}
async fn send_to_user(&self, user_id: &str, frame: &Frame) -> Result<()> {
self.connection_manager
.send_frame_to_user(user_id, frame, None)
.await
}
async fn broadcast(&self, frame: &Frame) -> Result<()> {
self.connection_manager
.broadcast_frame(frame, None)
.await
}
async fn broadcast_except(&self, frame: &Frame, exclude_connection_id: &str) -> Result<()> {
self.connection_manager
.broadcast_frame_except(frame, exclude_connection_id, None)
.await
}
async fn disconnect(&self, connection_id: &str) -> Result<()> {
self.connection_manager.remove_connection(connection_id).await
}
fn connection_count(&self) -> usize {
self.get_connection_count()
}
fn user_count(&self) -> usize {
self.get_user_count()
}
}