use crate::common::error::Result;
use crate::common::protocol::Frame;
use crate::server::{ServerConfig, ConnectionHandler, HybridServer, Server};
use crate::server::handle::ServerHandle;
use std::sync::Arc;
use async_trait::async_trait;
use tokio::sync::Mutex;
/// Context passed by reference to the user callbacks registered on the builder.
///
/// It carries the id of the connection the callback was invoked for and forwards
/// send, broadcast, disconnect and count requests to the server through a shared handle.
pub struct MessageContext {
/// Id of the connection that triggered the current callback.
pub connection_id: String,
// Shared server handle; every public method on the context delegates to it.
handle: Arc<dyn ServerHandle>,
}
impl MessageContext {
fn new(connection_id: String, handle: Arc<dyn ServerHandle>) -> Self {
Self {
connection_id,
handle,
}
}
pub async fn send_to(&self, connection_id: &str, frame: &Frame) -> Result<()> {
self.handle.send_to(connection_id, frame).await
}
pub async fn send_to_user(&self, user_id: &str, frame: &Frame) -> Result<()> {
self.handle.send_to_user(user_id, frame).await
}
pub async fn broadcast(&self, frame: &Frame) -> Result<()> {
self.handle.broadcast(frame).await
}
pub async fn broadcast_except(&self, frame: &Frame, exclude_connection_id: &str) -> Result<()> {
self.handle.broadcast_except(frame, exclude_connection_id).await
}
pub async fn disconnect(&self, connection_id: &str) -> Result<()> {
self.handle.disconnect(connection_id).await
}
pub fn connection_count(&self) -> usize {
self.handle.connection_count()
}
pub fn user_count(&self) -> usize {
self.handle.user_count()
}
}
/// Boxed, `Send` future that may borrow for `'a` and resolves to `Result<T>`.
/// This is the return type shared by all user callbacks below.
type HandlerFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = Result<T>> + Send + 'a>>;

/// Boxed callback invoked with an incoming frame and its [`MessageContext`].
/// It may resolve to an optional frame.
pub type MessageHandlerFn =
    Box<dyn for<'a> Fn(&'a Frame, &'a MessageContext) -> HandlerFuture<'a, Option<Frame>> + Send + Sync>;

/// Boxed callback invoked with a connection id and its [`MessageContext`] from
/// `ConnectionHandler::on_connect`.
pub type OnConnectFn =
    Box<dyn for<'a> Fn(&'a str, &'a MessageContext) -> HandlerFuture<'a, ()> + Send + Sync>;

/// Boxed callback invoked with a connection id and its [`MessageContext`] from
/// `ConnectionHandler::on_disconnect`.
pub type OnDisconnectFn =
    Box<dyn for<'a> Fn(&'a str, &'a MessageContext) -> HandlerFuture<'a, ()> + Send + Sync>;
/// `ConnectionHandler` implementation that dispatches server events to the
/// optional user callbacks collected by the builder.
struct SimpleConnectionHandler {
// Callback for incoming frames; no callback is run when this is `None`.
message_handler: Option<MessageHandlerFn>,
// Callback run from `on_connect`; skipped when `None`.
on_connect: Option<OnConnectFn>,
// Callback run from `on_disconnect`; skipped when `None`.
on_disconnect: Option<OnDisconnectFn>,
// Server handle used to build each `MessageContext`. Starts as `None` and is
// filled in by `set_handle`; callbacks fail with an error while it is `None`.
handle: Arc<Mutex<Option<Arc<dyn ServerHandle>>>>,
}
impl SimpleConnectionHandler {
    /// Stores `handle` as the server handle used for subsequent callbacks,
    /// replacing any handle stored earlier.
    async fn set_handle(&self, handle: Arc<dyn ServerHandle>) {
        let mut slot = self.handle.lock().await;
        *slot = Some(handle);
    }
}
#[async_trait]
impl ConnectionHandler for SimpleConnectionHandler {
async fn handle_frame(&self, frame: &Frame, connection_id: &str) -> Result<Option<Frame>> {
let handle = {
let handle_guard = self.handle.lock().await;
handle_guard.clone()
};
let context = if let Some(ref handle) = handle {
MessageContext::new(connection_id.to_string(), Arc::clone(handle))
} else {
return Err(crate::common::error::FlareError::general_error("Server handle is not available"));
};
if let Some(ref handler) = self.message_handler {
handler(frame, &context).await
} else {
Ok(None)
}
}
async fn on_connect(&self, connection_id: &str) -> Result<()> {
let handle = {
let handle_guard = self.handle.lock().await;
handle_guard.clone()
};
let context = if let Some(ref handle) = handle {
MessageContext::new(connection_id.to_string(), Arc::clone(handle))
} else {
return Err(crate::common::error::FlareError::general_error("Server handle is not available"));
};
if let Some(ref handler) = self.on_connect {
handler(connection_id, &context).await
} else {
Ok(())
}
}
async fn on_disconnect(&self, connection_id: &str) -> Result<()> {
let handle = {
let handle_guard = self.handle.lock().await;
handle_guard.clone()
};
let context = if let Some(ref handle) = handle {
MessageContext::new(connection_id.to_string(), Arc::clone(handle))
} else {
return Err(crate::common::error::FlareError::general_error("Server handle is not available"));
};
if let Some(ref handler) = self.on_disconnect {
handler(connection_id, &context).await
} else {
Ok(())
}
}
}
/// Adapter that implements the `Server` trait on top of a mutex-guarded, shared `HybridServer`.
struct ServerWrapper {
// Shared server instance; locked for every forwarded call.
server: Arc<Mutex<HybridServer>>,
}
/// Adapter that implements the `ServerHandle` trait on top of a mutex-guarded, shared `HybridServer`.
struct ServerWrapperHandle {
// Shared server instance; locked for every forwarded call.
server: Arc<Mutex<HybridServer>>,
}
#[async_trait]
impl crate::server::handle::ServerHandle for ServerWrapperHandle {
    /// Locks the shared server and forwards the call to its own `ServerHandle`
    /// implementation. The lock is held until the forwarded call completes.
    async fn send_to(&self, connection_id: &str, frame: &Frame) -> Result<()> {
        crate::server::handle::ServerHandle::send_to(&*self.server.lock().await, connection_id, frame).await
    }

    /// Locks the shared server and forwards `send_to_user`; the lock is held
    /// until the forwarded call completes.
    async fn send_to_user(&self, user_id: &str, frame: &Frame) -> Result<()> {
        crate::server::handle::ServerHandle::send_to_user(&*self.server.lock().await, user_id, frame).await
    }

    /// Locks the shared server and forwards `broadcast`; the lock is held until
    /// the forwarded call completes.
    async fn broadcast(&self, frame: &Frame) -> Result<()> {
        crate::server::handle::ServerHandle::broadcast(&*self.server.lock().await, frame).await
    }

    /// Locks the shared server and forwards `broadcast_except`; the lock is
    /// held until the forwarded call completes.
    async fn broadcast_except(&self, frame: &Frame, exclude_connection_id: &str) -> Result<()> {
        crate::server::handle::ServerHandle::broadcast_except(
            &*self.server.lock().await,
            frame,
            exclude_connection_id,
        )
        .await
    }

    /// Locks the shared server and forwards `disconnect`; the lock is held
    /// until the forwarded call completes.
    async fn disconnect(&self, connection_id: &str) -> Result<()> {
        crate::server::handle::ServerHandle::disconnect(&*self.server.lock().await, connection_id).await
    }

    /// Returns the server's connection count, blocking the current thread until
    /// the server mutex is acquired.
    ///
    /// # Panics
    /// `tokio::task::block_in_place` panics when called from a `current_thread` runtime.
    fn connection_count(&self) -> usize {
        tokio::task::block_in_place(|| {
            let guard = self.server.blocking_lock();
            crate::server::handle::ServerHandle::connection_count(&*guard)
        })
    }

    /// Returns the server's user count, blocking the current thread until the
    /// server mutex is acquired.
    ///
    /// # Panics
    /// `tokio::task::block_in_place` panics when called from a `current_thread` runtime.
    fn user_count(&self) -> usize {
        tokio::task::block_in_place(|| {
            let guard = self.server.blocking_lock();
            crate::server::handle::ServerHandle::user_count(&*guard)
        })
    }
}
#[async_trait]
impl Server for ServerWrapper {
    /// Locks the shared server and starts it; the lock is held until `start` returns.
    async fn start(&mut self) -> Result<()> {
        self.server.lock().await.start().await
    }

    /// Locks the shared server and stops it; the lock is held until `stop` returns.
    async fn stop(&mut self) -> Result<()> {
        self.server.lock().await.stop().await
    }

    /// Reports whether the shared server is running, blocking the current
    /// thread until the server mutex is acquired.
    ///
    /// # Panics
    /// `tokio::task::block_in_place` panics when called from a `current_thread` runtime.
    fn is_running(&self) -> bool {
        tokio::task::block_in_place(|| self.server.blocking_lock().is_running())
    }
}
/// Server facade produced by `ServerBuilder::build`.
///
/// All fields refer to the same underlying `HybridServer` instance.
pub struct SimpleServer {
// Shared server instance; locked by `start` and `stop`.
server: Arc<Mutex<HybridServer>>,
// Callback dispatcher; receives the server handle in `start`.
handler: Arc<SimpleConnectionHandler>,
// `Server`-trait view of the same instance, used by `is_running`.
server_wrapper: Arc<ServerWrapper>,
// `ServerHandle`-trait view of the same instance, used for sends, disconnects and counts.
handle: Arc<dyn ServerHandle>,
}
impl SimpleServer {
    /// Hands the server handle to the callback dispatcher, then locks and
    /// starts the underlying server. The lock is held until `start` returns.
    ///
    /// NOTE(review): after this call the dispatcher owns the handle, which owns
    /// the server `Arc`. If `HybridServer` retains the dispatcher it was built
    /// with, that is a reference cycle; confirm the drop behaviour.
    ///
    /// # Errors
    /// Propagates the error returned by the underlying server's `start`.
    pub async fn start(&mut self) -> Result<()> {
        let handle = self.handle.clone();
        self.handler.set_handle(handle).await;
        self.server.lock().await.start().await
    }

    /// Locks and stops the underlying server.
    ///
    /// # Errors
    /// Propagates the error returned by the underlying server's `stop`.
    pub async fn stop(&mut self) -> Result<()> {
        self.server.lock().await.stop().await
    }

    /// Reports whether the underlying server is running.
    ///
    /// This goes through a blocking lock wrapped in `tokio::task::block_in_place`,
    /// which panics when called from a `current_thread` runtime.
    pub fn is_running(&self) -> bool {
        let wrapper = &self.server_wrapper;
        wrapper.is_running()
    }

    /// Returns the connection count reported by the server handle.
    pub fn connection_count(&self) -> usize {
        let handle = &self.handle;
        handle.connection_count()
    }

    /// Returns the user count reported by the server handle.
    pub fn user_count(&self) -> usize {
        let handle = &self.handle;
        handle.user_count()
    }

    /// Returns another shared reference to the server handle.
    pub fn handle(&self) -> Arc<dyn ServerHandle> {
        self.handle.clone()
    }

    /// Forwards `frame` to the connection identified by `connection_id`.
    pub async fn send_to(&self, connection_id: &str, frame: &Frame) -> Result<()> {
        self.handle.send_to(connection_id, frame).await
    }

    /// Forwards `frame` to the user identified by `user_id`.
    pub async fn send_to_user(&self, user_id: &str, frame: &Frame) -> Result<()> {
        self.handle.send_to_user(user_id, frame).await
    }

    /// Asks the server handle to broadcast `frame`.
    pub async fn broadcast(&self, frame: &Frame) -> Result<()> {
        self.handle.broadcast(frame).await
    }

    /// Asks the server handle to broadcast `frame`, passing
    /// `exclude_connection_id` as the connection to leave out.
    pub async fn broadcast_except(&self, frame: &Frame, exclude_connection_id: &str) -> Result<()> {
        self.handle.broadcast_except(frame, exclude_connection_id).await
    }

    /// Asks the server handle to disconnect `connection_id`.
    pub async fn disconnect(&self, connection_id: &str) -> Result<()> {
        self.handle.disconnect(connection_id).await
    }
}
/// Builder that collects configuration, user callbacks and an optional
/// authenticator, then assembles a `SimpleServer` in `build`.
pub struct ServerBuilder {
// Configuration handed to `HybridServer::with_connection_manager` by `build`.
config: ServerConfig,
// Callback for incoming frames, set by `on_message`.
message_handler: Option<MessageHandlerFn>,
// Callback for new connections, set by `on_connect`.
on_connect: Option<OnConnectFn>,
// Callback for closed connections, set by `on_disconnect`.
on_disconnect: Option<OnDisconnectFn>,
// Authenticator forwarded to the server constructor, set by `with_authenticator`.
authenticator: Option<Arc<dyn crate::server::auth::Authenticator>>,
}
impl ServerBuilder {
    /// Creates a builder whose configuration is `ServerConfig::new(bind_address)`,
    /// with no callbacks and no authenticator registered.
    pub fn new(bind_address: impl Into<String>) -> Self {
        Self {
            config: ServerConfig::new(bind_address.into()),
            message_handler: None,
            on_connect: None,
            on_disconnect: None,
            authenticator: None,
        }
    }

    /// Sets the authenticator that `build` forwards to the server constructor.
    pub fn with_authenticator(mut self, authenticator: Arc<dyn crate::server::auth::Authenticator>) -> Self {
        self.authenticator = Some(authenticator);
        self
    }

    /// Applies `ServerConfig::enable_auth` to the configuration.
    pub fn enable_auth(mut self) -> Self {
        self.config = self.config.enable_auth();
        self
    }

    /// Applies `ServerConfig::with_auth_timeout` to the configuration.
    pub fn with_auth_timeout(mut self, timeout: std::time::Duration) -> Self {
        self.config = self.config.with_auth_timeout(timeout);
        self
    }

    /// Registers the callback run for each incoming frame, replacing any
    /// previously registered one.
    ///
    /// The callback receives the frame and a [`MessageContext`]; its result is
    /// returned unchanged from the connection handler's `handle_frame`.
    pub fn on_message<F>(mut self, handler: F) -> Self
    where
        F: for<'a> Fn(&'a Frame, &'a MessageContext) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Option<Frame>>> + Send + 'a>> + Send + Sync + 'static,
    {
        // `F` already has the callable shape of `MessageHandlerFn`, so box it
        // directly instead of wrapping it in another forwarding closure.
        self.message_handler = Some(Box::new(handler));
        self
    }

    /// Registers the callback run from the connection handler's `on_connect`,
    /// replacing any previously registered one.
    pub fn on_connect<F>(mut self, handler: F) -> Self
    where
        F: for<'a> Fn(&'a str, &'a MessageContext) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> + Send + Sync + 'static,
    {
        self.on_connect = Some(Box::new(handler));
        self
    }

    /// Registers the callback run from the connection handler's `on_disconnect`,
    /// replacing any previously registered one.
    pub fn on_disconnect<F>(mut self, handler: F) -> Self
    where
        F: for<'a> Fn(&'a str, &'a MessageContext) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<()>> + Send + 'a>> + Send + Sync + 'static,
    {
        self.on_disconnect = Some(Box::new(handler));
        self
    }

    /// Sets the configuration's `transport` field to `protocol`.
    pub fn with_protocol(mut self, protocol: crate::common::config_types::TransportProtocol) -> Self {
        self.config.transport = protocol;
        self
    }

    /// Applies `ServerConfig::with_protocols` to the configuration.
    pub fn with_protocols(mut self, protocols: Vec<crate::common::config_types::TransportProtocol>) -> Self {
        self.config = self.config.with_protocols(protocols);
        self
    }

    /// Applies `ServerConfig::with_max_connections` to the configuration.
    pub fn with_max_connections(mut self, max: usize) -> Self {
        self.config = self.config.with_max_connections(max);
        self
    }

    /// Applies `ServerConfig::with_heartbeat` to the configuration.
    pub fn with_heartbeat(mut self, heartbeat: crate::common::config_types::HeartbeatConfig) -> Self {
        self.config = self.config.with_heartbeat(heartbeat);
        self
    }

    /// Applies `ServerConfig::with_tls` to the configuration.
    pub fn with_tls(mut self, tls: crate::common::config_types::TlsConfig) -> Self {
        self.config = self.config.with_tls(tls);
        self
    }

    /// Applies `ServerConfig::with_format` to the configuration.
    pub fn with_default_format(mut self, format: crate::common::protocol::SerializationFormat) -> Self {
        self.config = self.config.with_format(format);
        self
    }

    /// Applies `ServerConfig::with_compression` to the configuration.
    pub fn with_default_compression(mut self, compression: crate::common::compression::CompressionAlgorithm) -> Self {
        self.config = self.config.with_compression(compression);
        self
    }

    /// Consumes the builder and assembles the [`SimpleServer`].
    ///
    /// The callback dispatcher is created without a server handle; the handle
    /// is supplied later by `SimpleServer::start`.
    ///
    /// # Errors
    /// Propagates the error returned by `HybridServer::with_connection_manager`.
    pub fn build(self) -> Result<SimpleServer> {
        let handler = Arc::new(SimpleConnectionHandler {
            message_handler: self.message_handler,
            on_connect: self.on_connect,
            on_disconnect: self.on_disconnect,
            handle: Arc::new(Mutex::new(None)),
        });

        // The three `None` arguments are optional constructor parameters whose
        // meaning is defined by `HybridServer::with_connection_manager`.
        let server = HybridServer::with_connection_manager(
            self.config,
            handler.clone() as Arc<dyn ConnectionHandler>,
            None,
            None,
            None,
            self.authenticator,
        )?;

        // One shared server instance backs the wrapper, the handle and the facade.
        let server_arc = Arc::new(Mutex::new(server));
        let server_wrapper = Arc::new(ServerWrapper {
            server: Arc::clone(&server_arc),
        });
        let handle: Arc<dyn ServerHandle> = Arc::new(ServerWrapperHandle {
            server: Arc::clone(&server_arc),
        });

        Ok(SimpleServer {
            server: server_arc,
            handler,
            server_wrapper,
            handle,
        })
    }
}