use crate::adapter::Adapter;
use crate::socket::{subscribe_socket_to_transport_events, Socket, SocketEvent};
use crate::transport::{Transport, TransportImpl, TransportKind};
use crate::util::{HttpMethod, RequestContext, ServerError, SetCookie};
use engine_io_parser::packet::{Packet, PacketData};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex as AsyncMutex;
use tokio::sync::RwLock as AsyncRwLock;
use tokio::sync::{broadcast, mpsc};
use uuid::Uuid;
/// Base capacity for the server's internal channels. `Server::new` multiplies it by
/// `ServerOptions::buffer_factor` to size both the socket-event (mpsc) channel and the
/// server-event (broadcast) channel.
pub const BUFFER_CONST: usize = 32;
/// Configuration for a [`Server`].
///
/// The whole struct is passed to each new socket in `Server::handshake` (via `Socket::open`),
/// so most timing and upgrade fields are presumably consumed by `Socket` rather than by the
/// server itself — confirm there.
#[derive(Debug, Clone, PartialEq)]
pub struct ServerOptions {
/// Ping timeout; default 5000. NOTE(review): unit presumably milliseconds — confirm in `Socket`.
pub ping_timeout: u32,
/// Ping interval; default 25000. NOTE(review): unit presumably milliseconds — confirm in `Socket`.
pub ping_interval: u32,
/// Upgrade timeout; default 10000. NOTE(review): unit presumably milliseconds — confirm in `Socket`.
pub upgrade_timeout: u32,
/// Transports the server is configured with; default `[WebSocket, Polling]`.
/// NOTE(review): `Server::verify_request` does not check incoming requests against this list.
pub transports: Vec<TransportKind>,
/// Whether transport upgrades are allowed; default `true`. Not read directly by `Server`;
/// presumably consumed through `Socket::open` — confirm.
pub allow_upgrades: bool,
/// If set, a clone of this packet is sent to every new client right after its socket is
/// opened during the handshake. Default `None`.
pub initial_packet: Option<Packet>,
/// Cookie settings handed to `SetCookie::from_cookie_options` when answering a handshake.
/// Default `Some(CookieOptions::default())`; `None` presumably means no cookie — confirm in
/// `SetCookie`.
pub cookie: Option<CookieOptions>,
/// Multiplier applied to [`BUFFER_CONST`] to size the server's internal channels; default 2.
pub buffer_factor: usize,
}
/// Settings for the cookie produced when a handshake request is answered.
#[derive(Debug, Clone, PartialEq)]
pub struct CookieOptions {
/// Cookie name; default `"io"`.
pub name: String,
/// Cookie path; default `"/"`.
pub path: String,
/// HttpOnly setting; default `true`. Presumably applied by `SetCookie` — confirm there.
pub http_only: bool,
}
/// Sending halves of the server's two internal channels.
#[derive(Debug, Clone)]
pub struct EventSenders {
/// Broadcast channel of [`ServerEvent`]s; each `Server::subscribe` call creates its own receiver.
server: broadcast::Sender<ServerEvent>,
/// Channel through which sockets report [`SocketEvent`]s back to the server; a clone is handed
/// to every socket created in `Server::handshake`.
client: mpsc::Sender<SocketEvent>,
}
/// Mutable server state, shared behind an async `RwLock` between the `Server` and its
/// socket-event forwarding task.
pub struct ServerState<A: 'static + Adapter> {
/// Receiving end of the socket-event channel, parked here until the first `Server::subscribe`
/// call takes it and spawns the forwarding task.
socket_receiver_temp: Option<mpsc::Receiver<SocketEvent>>,
/// Registered clients, keyed by connection id (sid).
clients: HashMap<String, Arc<AsyncMutex<Socket<A>>>>,
}
/// Server that performs handshakes, tracks connected client sockets, and republishes their
/// events as [`ServerEvent`]s.
pub struct Server<A: 'static + Adapter> {
/// Shared mutable state (client map and the parked socket-event receiver).
state: Arc<AsyncRwLock<ServerState<A>>>,
/// Senders for the server-event broadcast and the socket-event channel.
event_senders: EventSenders,
/// Configuration supplied to `Server::new`.
options: ServerOptions,
}
impl Default for ServerOptions {
    /// Default configuration:
    /// - timings of 5000 (ping timeout), 25000 (ping interval) and 10000 (upgrade timeout);
    /// - both WebSocket and polling transports, with upgrades allowed;
    /// - no initial packet and the default cookie;
    /// - a channel buffer factor of 2.
    fn default() -> Self {
        let transports = vec![TransportKind::WebSocket, TransportKind::Polling];
        Self {
            ping_interval: 25_000,
            ping_timeout: 5_000,
            upgrade_timeout: 10_000,
            allow_upgrades: true,
            transports,
            initial_packet: None,
            cookie: Some(CookieOptions::default()),
            buffer_factor: 2,
        }
    }
}
impl Default for CookieOptions {
fn default() -> Self {
CookieOptions {
name: "io".to_owned(),
path: "/".to_owned(),
http_only: true,
}
}
}
/// Events published to receivers obtained from `Server::subscribe`.
#[derive(Display, Debug, Clone, PartialEq)]
pub enum ServerEvent {
/// Emitted at the end of `Server::handshake` for a newly connected client.
Connection {
connection_id: String,
},
/// Republished from a socket's `SocketEvent::Flush`.
Flush {
connection_id: String,
},
/// Republished from a socket's `SocketEvent::Drain`.
Drain {
connection_id: String,
},
/// Republished from a socket's `SocketEvent::Message`, carrying its payload.
Message {
connection_id: String,
data: PacketData,
},
}
/// Outcome of a successfully handled request (`Server::handle_request`).
pub struct HandleRequestResult {
/// Connection id (sid): the one supplied in the query, or the newly generated one for a handshake.
pub connection_id: String,
/// Cookie to set on the response. Always `None` for requests on an existing session; for
/// handshakes it is whatever `SetCookie::from_cookie_options` returns.
pub set_cookie: Option<SetCookie>,
}
impl<A: 'static + Adapter> Server<A> {
    /// Creates a server with the given options.
    ///
    /// Both internal channels get a capacity of `options.buffer_factor * BUFFER_CONST`
    /// (at least 1). The receiving end of the socket-event channel is parked in the server
    /// state until the first call to [`Server::subscribe`], which starts forwarding socket
    /// events. NOTE(review): until then, socket events accumulate in the bounded channel.
    pub fn new(options: ServerOptions) -> Self {
        // tokio's mpsc::channel and broadcast::channel both panic on a capacity of 0.
        let capacity = (options.buffer_factor * BUFFER_CONST).max(1);
        let (client_event_sender, client_event_receiver) = mpsc::channel(capacity);
        let (server_event_sender, _) = broadcast::channel(capacity);
        Server {
            state: Arc::new(AsyncRwLock::new(ServerState {
                // The receiver must be kept alive: dropping it closes the channel, so sockets
                // could never report Close/Message/... events and `subscribe` would never start
                // the forwarding task (closed clients would stay in `clients` forever).
                socket_receiver_temp: Some(client_event_receiver),
                clients: HashMap::new(),
            })),
            event_senders: EventSenders {
                server: server_event_sender,
                client: client_event_sender,
            },
            options,
        }
    }

    /// Returns a new receiver of [`ServerEvent`]s.
    ///
    /// The first call also spawns the task that forwards socket events into the broadcast
    /// channel. The receiver is created before that task starts, so this subscriber cannot
    /// miss any event the task forwards.
    pub async fn subscribe(&self) -> broadcast::Receiver<ServerEvent> {
        let receiver = self.event_senders.server.subscribe();
        let socket_receiver = self.state.write().await.socket_receiver_temp.take();
        if let Some(socket_receiver) = socket_receiver {
            self.subscribe_to_socket_events(socket_receiver);
        }
        receiver
    }

    /// Calls `close(true)` on every currently registered client socket.
    ///
    /// The sockets are snapshotted under a read lock, and the lock is released before each
    /// socket's mutex is awaited. A socket may be busy (e.g. serving a poll), and holding the
    /// state lock meanwhile would stall every other server operation. Clients are not removed
    /// here; removal happens when a `SocketEvent::Close` is processed (presumably emitted by
    /// `Socket::close` — confirm).
    pub async fn close(&self) {
        let sockets: Vec<Arc<AsyncMutex<Socket<A>>>> =
            self.state.read().await.clients.values().cloned().collect();
        for socket in sockets {
            socket.lock().await.close(true);
        }
    }

    /// Handles an HTTP request addressed to the server.
    ///
    /// Requests carrying a `sid` query parameter are passed to that client's polling handler.
    /// Requests without one start a new session via [`Server::handshake`], and the result then
    /// includes the cookie built from `options.cookie`.
    ///
    /// # Errors
    /// - Any error from [`Server::verify_request`].
    /// - `ServerError::UnknownSid` if the client disappears before it can be handled.
    /// - Errors from the client's polling handler or from the handshake.
    pub async fn handle_request(
        &self,
        context: RequestContext,
    ) -> Result<HandleRequestResult, ServerError> {
        let sid = context.query.get("sid").cloned();
        self.verify_request(sid.as_ref(), false, context.transport_kind, context.http_method)
            .await?;
        match sid {
            Some(sid) => {
                let client = self.get_client_or_error(&sid).await?;
                let mut client = client.lock().await;
                client.handle_polling_request(context).await?;
                Ok(HandleRequestResult {
                    connection_id: sid,
                    set_cookie: None,
                })
            }
            None => {
                let sid = self.handshake(context).await?;
                let set_cookie = SetCookie::from_cookie_options(&self.options.cookie, sid.clone());
                Ok(HandleRequestResult {
                    connection_id: sid,
                    set_cookie,
                })
            }
        }
    }

    /// Handles a transport upgrade request.
    ///
    /// # Panics
    /// Always: upgrades are not implemented yet.
    pub fn handle_upgrade(&self) {
        unimplemented!()
    }

    /// Validates a request before it is handled.
    ///
    /// With a `sid`, the client must exist. Unless `upgrade` is set, it must also currently use
    /// `transport_kind`. Without a `sid`, the request is a handshake and must use `GET`.
    ///
    /// # Errors
    /// - `ServerError::UnknownSid` if `sid` names no registered client.
    /// - `ServerError::BadRequest` on a transport mismatch (non-upgrade requests only).
    /// - `ServerError::BadHandshakeMethod` for a non-`GET` handshake.
    pub async fn verify_request(
        &self,
        sid: Option<&String>,
        upgrade: bool,
        transport_kind: TransportKind,
        http_method: HttpMethod,
    ) -> Result<(), ServerError> {
        match sid {
            Some(sid) => {
                // Look the client up first so the state lock is released before the
                // (possibly contended) socket mutex is awaited.
                let client = self.get_client_or_error(sid).await?;
                if !upgrade {
                    let client_transport_kind = client.lock().await.get_transport_kind();
                    if transport_kind != client_transport_kind {
                        return Err(ServerError::BadRequest);
                    }
                }
            }
            None => {
                if http_method != HttpMethod::Get {
                    return Err(ServerError::BadHandshakeMethod);
                }
            }
        }
        Ok(())
    }

    /// Generates a fresh connection id: a random (v4) UUID in hyphenated form.
    pub fn generate_id(&self) -> String {
        Uuid::new_v4().to_hyphenated().to_string()
    }

    /// Performs the handshake for a new client and returns its connection id.
    ///
    /// Steps, in order:
    /// 1. Create the socket and register it in `clients`.
    /// 2. Open the socket and send `options.initial_packet` if one is configured.
    /// 3. Call `subscribe_socket_to_transport_events`.
    /// 4. If the transport is polling, pass the request to it.
    /// 5. Publish `ServerEvent::Connection` (ignored when nobody is subscribed).
    ///
    /// `supports_binary` is `false` when the query contains a `b64` key.
    ///
    /// # Errors
    /// `ServerError::UnknownSid` if the new client has already been removed from `clients`
    /// by the time the request is handed to its transport.
    pub async fn handshake(&self, context: RequestContext) -> Result<String, ServerError> {
        let id = self.generate_id();
        let supports_binary = !context.query.contains_key("b64");
        let socket: Arc<AsyncMutex<Socket<A>>> = Arc::new(AsyncMutex::new(Socket::new(
            id.clone(),
            context.remote_address.clone(),
            self.event_senders.client.clone(),
            context.transport_kind,
            supports_binary,
        )));
        self.state
            .write()
            .await
            .clients
            .insert(id.clone(), Arc::clone(&socket));
        {
            let mut socket = socket.lock().await;
            socket.open(&self.options).await;
            if let Some(initial_packet) = self.options.initial_packet.clone() {
                socket.send_packet(initial_packet, None).await;
            }
        }
        subscribe_socket_to_transport_events(socket).await;
        {
            // Re-resolve through the map so a client that was already removed is reported
            // as `UnknownSid` instead of being served.
            let client = self.get_client_or_error(&id).await?;
            let mut client = client.lock().await;
            if let Transport::Polling(transport) = client.get_transport_mut() {
                transport.handle_request(&context).await;
            }
        }
        // A send error only means there are no subscribers right now.
        let _ = self.event_senders.server.send(ServerEvent::Connection {
            connection_id: id.clone(),
        });
        Ok(id)
    }

    /// Number of clients currently registered.
    pub async fn clients_count(&self) -> usize {
        self.state.read().await.clients.len()
    }

    /// Returns a handle to the client registered under `id`.
    ///
    /// # Errors
    /// `ServerError::UnknownSid` if no such client is registered.
    pub async fn get_client_or_error(
        &self,
        id: &str,
    ) -> Result<Arc<AsyncMutex<Socket<A>>>, ServerError> {
        let state = self.state.read().await;
        let client = state.clients.get(id).cloned();
        client.ok_or(ServerError::UnknownSid)
    }

    /// Spawns the task that drains socket events until every sender is dropped.
    ///
    /// - `Close` removes the client from the map.
    /// - `Flush`, `Drain` and `Message` are republished as the matching [`ServerEvent`].
    /// - Any other event is ignored.
    ///
    /// Broadcast send errors (no subscribers) are ignored.
    fn subscribe_to_socket_events(&self, mut client_event_receiver: mpsc::Receiver<SocketEvent>) {
        let server_event_sender = self.event_senders.server.clone();
        let state = Arc::clone(&self.state);
        tokio::spawn(async move {
            while let Some(message) = client_event_receiver.recv().await {
                match message {
                    SocketEvent::Close { socket_id } => {
                        state.write().await.clients.remove(&socket_id);
                    }
                    SocketEvent::Flush { socket_id } => {
                        let _ = server_event_sender.send(ServerEvent::Flush {
                            connection_id: socket_id,
                        });
                    }
                    SocketEvent::Drain { socket_id } => {
                        let _ = server_event_sender.send(ServerEvent::Drain {
                            connection_id: socket_id,
                        });
                    }
                    SocketEvent::Message { socket_id, data } => {
                        let _ = server_event_sender.send(ServerEvent::Message {
                            connection_id: socket_id,
                            data,
                        });
                    }
                    _ => {}
                }
            }
        });
    }
}