use crate::adapter::Adapter;
use crate::packet::{Packet, PacketData, PacketType};
use crate::server::ServerOptions;
use crate::transport::{
PollingResponder, PollingTransportOptions, RequestReply, Transport, TransportError,
TransportEvent, TransportImpl, TransportKind, WebsocketTransportOptions,
};
use crate::util::{RequestContext, ServerError};
use serde_json::json;
use std::mem;
use std::sync::Arc;
use tokio::sync::mpsc;
use tokio::sync::Mutex as AsyncMutex;
/// Completion callback attached to an outgoing packet; invoked from the drain handler once
/// the batch containing its packet has been drained by the transport.
type Callback = Box<dyn Fn() + Send>;
/// Notifications a [`Socket`] publishes on its `event_sender` channel.
///
/// Every variant carries the id of the socket it concerns.
#[derive(Display, Debug, Clone, PartialEq)]
pub enum SocketEvent {
    /// The socket was opened and its open packet was queued.
    Open { socket_id: String },
    /// The socket transitioned to the closed state.
    Close { socket_id: String },
    /// Buffered packets are about to be handed to the transport.
    Flush { socket_id: String },
    /// A flush finished handing packets over and recorded their callbacks.
    Drain { socket_id: String },
    /// An upgrade packet was accepted while an upgrade was in progress.
    Upgrade { socket_id: String },
    /// Heartbeat notification for `socket_id`.
    Heartbeat { socket_id: String },
    /// Payload of a message packet received from the client.
    Message { socket_id: String, data: PacketData },
}
/// Server-side state of one client connection: its current transport, the outgoing packet
/// buffer, and the callbacks attached to sent packets.
///
/// Lifecycle notifications are published as [`SocketEvent`]s on `event_sender`.
pub struct Socket<A: 'static + Adapter> {
    /// Session id; sent to the client as `sid` in the open packet.
    pub id: String,
    /// Progress of a transport upgrade; an `Upgrade` packet is only acted on while `Upgrading`.
    upgrade_state: UpgradeState,
    /// Connection lifecycle state; gates sending, flushing and packet handling.
    ready_state: ReadyState,
    /// Peer address given at construction.
    remote_address: String,
    /// Packets queued until the next successful flush.
    write_buffer: Vec<Packet>,
    /// Destination for [`SocketEvent`] notifications.
    event_sender: mpsc::Sender<SocketEvent>,
    /// Current transport plus the channel it reports its events on.
    transport_holder: TransportHolder<A>,
    /// Callbacks for buffered packets that have not been flushed yet.
    pending_callbacks: Vec<Callback>,
    /// Callback batches for flushed packets; one batch is run per drain, oldest first.
    flushed_callbacks: Vec<CallbackBatch>,
}
/// A transport bundled with both ends of the channel it reports [`TransportEvent`]s on.
struct TransportHolder<A: 'static + Adapter> {
    /// The active transport.
    transport: Transport<A>,
    /// Sender half of the event channel; the transport was given its own clone at creation.
    transport_event_sender: mpsc::Sender<TransportEvent>,
    /// Receiver half, shared behind an async mutex so a subscriber task can hold a handle to it.
    transport_event_receiver: Arc<AsyncMutex<mpsc::Receiver<TransportEvent>>>,
}
impl<A: 'static + Adapter> TransportHolder<A> {
    /// Bundles `transport` with both ends of its event channel.
    ///
    /// The receiver is wrapped in `Arc<AsyncMutex<_>>` so that it can be cloned out to a
    /// subscriber task while the holder keeps its own handle.
    pub(crate) fn new(
        transport: Transport<A>,
        transport_event_sender: mpsc::Sender<TransportEvent>,
        transport_event_receiver: mpsc::Receiver<TransportEvent>,
    ) -> Self {
        TransportHolder {
            transport,
            transport_event_sender,
            transport_event_receiver: Arc::new(AsyncMutex::new(transport_event_receiver)),
        }
    }
}
/// Callbacks recorded by a flush and run when the transport reports a drain.
///
/// Exactly one batch is consumed per drain, oldest first.
enum CallbackBatch {
    /// Every callback from one flush; used when the transport does not support framing.
    NonFramed { callbacks: Vec<Callback> },
    /// A single callback; used when the transport supports framing (one batch per callback).
    Framed { callback: Callback },
}
/// Reason passed to a socket's internal close handler.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketError {
    /// The transport reported a (non-parse) error or closed.
    TransportError,
    /// The transport failed to parse an incoming packet.
    ParseError,
}
impl<A: 'static + Adapter> Socket<A> {
    /// Creates a socket in [`ReadyState::Opening`] backed by a new transport of
    /// `transport_kind`.
    ///
    /// Lifecycle notifications are sent on `event_sender`; the transport's own events are
    /// consumed by [`subscribe_socket_to_transport_events`].
    pub fn new(
        id: String,
        remote_address: String,
        event_sender: mpsc::Sender<SocketEvent>,
        transport_kind: TransportKind,
        supports_binary: bool,
    ) -> Self {
        let (transport, transport_event_sender, transport_event_receiver) =
            Self::create_transport(&id, transport_kind, supports_binary);
        Socket {
            id,
            remote_address,
            upgrade_state: UpgradeState::Initial,
            ready_state: ReadyState::Opening,
            transport_holder: TransportHolder::new(
                transport,
                transport_event_sender,
                transport_event_receiver,
            ),
            write_buffer: Vec::new(),
            event_sender,
            pending_callbacks: Vec::new(),
            flushed_callbacks: Vec::new(),
        }
    }

    /// Builds a transport of `transport_kind` for session `sid`, plus the bounded
    /// (capacity 128) channel on which it reports [`TransportEvent`]s.
    ///
    /// The transport receives its own clone of the sender; both channel ends are returned.
    fn create_transport(
        sid: &str,
        transport_kind: TransportKind,
        supports_binary: bool,
    ) -> (
        Transport<A>,
        mpsc::Sender<TransportEvent>,
        mpsc::Receiver<TransportEvent>,
    ) {
        let (transport_event_sender, transport_event_receiver) = mpsc::channel(128);
        // NOTE(review): these options are hard-coded rather than read from `ServerOptions`;
        // a 1024-byte `max_http_buffer_size` looks very small — confirm.
        let transport: Transport<A> = match transport_kind {
            TransportKind::WebSocket => Transport::WebSocket(A::create_websocket_transport(
                sid,
                WebsocketTransportOptions {
                    per_message_deflate: true,
                },
                transport_event_sender.clone(),
            )),
            TransportKind::Polling => Transport::Polling(A::create_polling_transport(
                sid,
                PollingTransportOptions {
                    max_http_buffer_size: 1024,
                    http_compression: None,
                    supports_binary,
                },
                transport_event_sender.clone(),
            )),
        };
        (transport, transport_event_sender, transport_event_receiver)
    }

    /// Replaces the current transport and its event channel with a new transport of
    /// `transport_kind`; the previous holder is dropped.
    ///
    /// NOTE(review): a task started by `subscribe_socket_to_transport_events` keeps reading
    /// the *previous* receiver, so the new transport's events are not handled until the
    /// socket is subscribed again — confirm callers do that.
    fn set_transport(&mut self, transport_kind: TransportKind, supports_binary: bool) {
        let (transport, transport_event_sender, transport_event_receiver) =
            Self::create_transport(&self.id, transport_kind, supports_binary);
        self.transport_holder =
            TransportHolder::new(transport, transport_event_sender, transport_event_receiver);
    }

    /// Borrows the current transport.
    pub fn get_transport(&self) -> &Transport<A> {
        &self.transport_holder.transport
    }

    /// Mutably borrows the current transport.
    pub fn get_transport_mut(&mut self) -> &mut Transport<A> {
        &mut self.transport_holder.transport
    }

    /// Mutably borrows the current transport as a polling transport.
    ///
    /// # Errors
    /// Returns [`ServerError::Unknown`] when the current transport is not polling.
    pub fn get_transport_mut_as_polling(&mut self) -> Result<&mut A::Polling, ServerError> {
        match self.get_transport_mut() {
            Transport::Polling(transport) => Ok(transport),
            _ => Err(ServerError::Unknown),
        }
    }

    /// Kind of the current transport.
    pub fn get_transport_kind(&self) -> TransportKind {
        self.transport_holder.transport.get_transport_kind()
    }

    /// Opens the socket: marks it [`ReadyState::Open`], queues the open (handshake) packet,
    /// emits [`SocketEvent::Open`] and arms the ping timeout.
    ///
    /// # Panics
    /// Currently always panics: `get_available_upgrades` (used for the open packet) and
    /// `set_ping_timeout` are still `unimplemented!()`.
    pub async fn open(&mut self, server_options: &ServerOptions) {
        self.ready_state = ReadyState::Open;
        self.send_open_packet(server_options).await;
        // A dropped event receiver must not prevent the socket from opening.
        let _ = self
            .event_sender
            .send(SocketEvent::Open {
                socket_id: self.id.clone(),
            })
            .await;
        self.set_ping_timeout();
    }

    /// Starts closing the socket; does nothing unless it is [`ReadyState::Open`].
    ///
    /// The socket moves to [`ReadyState::Closing`]. If packets are still buffered, the
    /// transport is closed by the next flush that hands them over. Otherwise it is closed
    /// immediately; with `discard` set, `Transport::discard` is called first. That is the
    /// same sequence as `close_transport`.
    pub fn close(&mut self, discard: bool) {
        if self.ready_state != ReadyState::Open {
            return;
        }
        self.ready_state = ReadyState::Closing {
            with_discard: discard,
        };
        if self.write_buffer.is_empty() {
            let transport = &mut self.transport_holder.transport;
            if discard {
                transport.discard();
            }
            // NOTE(review): `Transport::close` is awaited everywhere else in this file, so it
            // returns a future; dropping it here without `.await` most likely means the
            // transport is never closed. Fixing that needs this method to become `async`,
            // which changes its public signature — confirm with callers.
            transport.close();
        }
    }

    /// Queues `packet` (and its optional `callback`) and tries to flush.
    ///
    /// Ignored unless the socket is [`ReadyState::Opening`] or [`ReadyState::Open`].
    pub(crate) async fn send_packet(&mut self, packet: Packet, callback: Option<Callback>) {
        if !matches!(self.ready_state, ReadyState::Opening | ReadyState::Open) {
            return;
        }
        self.write_buffer.push(packet);
        if let Some(callback) = callback {
            self.pending_callbacks.push(callback);
        }
        self.flush().await;
    }

    /// Queues a message packet carrying `packet_data` and tries to flush it.
    ///
    /// `callback`, if given, runs when the transport reports a drain for the batch that
    /// contains this packet.
    pub async fn send(&mut self, packet_data: PacketData, callback: Option<Callback>) {
        self.send_packet(
            Packet {
                packet_type: PacketType::Message,
                data: packet_data,
            },
            callback,
        )
        .await;
    }

    /// Alias of [`Self::send`].
    pub async fn write(&mut self, packet_data: PacketData, callback: Option<Callback>) {
        self.send(packet_data, callback).await;
    }

    /// Serves one HTTP request on the polling transport.
    ///
    /// A direct response from the transport is returned as-is. When the transport asks for a
    /// drain, the write buffer is flushed and the packets are written into the response.
    ///
    /// # Errors
    /// * [`ServerError::Unknown`] if the current transport is not polling, or the transport
    ///   requested an action other than a drain.
    /// * [`ServerError::BadRequest`] if the flush produced no packets (socket closed,
    ///   transport not writable, or nothing buffered).
    pub async fn handle_polling_request(
        &mut self,
        request_context: RequestContext,
    ) -> Result<A::Response, ServerError> {
        let reply = self
            .get_transport_mut_as_polling()?
            .handle_request(&request_context)
            .await;
        match reply {
            RequestReply::Response(response) => Ok(response),
            RequestReply::Action(TransportEvent::Drain) => {
                let packets = self.flush().await.ok_or(ServerError::BadRequest)?;
                let transport = self.get_transport_mut_as_polling()?;
                Ok(transport.respond_with_packets(&request_context, packets))
            }
            RequestReply::Action(_) => Err(ServerError::Unknown),
        }
    }

    /// Replaces the current transport with a new one of `transport_kind`.
    ///
    /// This only swaps the transport; it does not change `upgrade_state`.
    pub fn maybe_upgrade(&mut self, transport_kind: TransportKind, supports_binary: bool) {
        self.set_transport(transport_kind, supports_binary);
    }

    /// Closes the current transport, calling `Transport::discard` first when `discard` is set.
    async fn close_transport(&mut self, discard: bool) {
        if discard {
            self.transport_holder.transport.discard();
        }
        self.transport_holder.transport.close().await;
    }

    /// Hands the buffered packets over, if the socket and transport allow it.
    ///
    /// Does nothing and returns `None` when the socket is [`ReadyState::Closed`], the
    /// transport is not writable, or the write buffer is empty. Otherwise it:
    /// 1. emits [`SocketEvent::Flush`] and empties the buffer;
    /// 2. moves pending callbacks to `flushed_callbacks`: one batch per callback if the
    ///    transport supports framing, a single batch otherwise;
    /// 3. emits [`SocketEvent::Drain`];
    /// 4. closes the transport if the socket is [`ReadyState::Closing`].
    ///
    /// For a polling transport the packets are returned to the caller, which writes them into
    /// the HTTP response (see [`Self::handle_polling_request`]). Any other transport gets them
    /// through `Transport::send`, and `None` is returned.
    async fn flush(&mut self) -> Option<Vec<Packet>> {
        if self.ready_state == ReadyState::Closed
            || !self.transport_holder.transport.is_writable()
            || self.write_buffer.is_empty()
        {
            return None;
        }
        let _ = self
            .event_sender
            .send(SocketEvent::Flush {
                socket_id: self.id.clone(),
            })
            .await;
        let packets = mem::take(&mut self.write_buffer);
        // NOTE(review): this assumes a polling transport only reports `is_writable()` while
        // `handle_polling_request` is serving a poll. Otherwise packets flushed from
        // `send_packet` are returned to a caller that drops them — confirm.
        let unsent = if matches!(self.transport_holder.transport, Transport::Polling(_)) {
            Some(packets)
        } else {
            self.transport_holder.transport.send(packets).await;
            None
        };
        let callbacks = mem::take(&mut self.pending_callbacks);
        if self.transport_holder.transport.supports_framing() {
            self.flushed_callbacks.extend(
                callbacks
                    .into_iter()
                    .map(|callback| CallbackBatch::Framed { callback }),
            );
        } else {
            // Pushed even when empty so that each drain consumes exactly one batch.
            self.flushed_callbacks
                .push(CallbackBatch::NonFramed { callbacks });
        }
        let _ = self
            .event_sender
            .send(SocketEvent::Drain {
                socket_id: self.id.clone(),
            })
            .await;
        if let ReadyState::Closing { with_discard } = self.ready_state {
            // Discard (if requested) *and* close, as `close_transport` does.
            self.close_transport(with_discard).await;
        }
        unsent
    }

    /// Transport names the client may upgrade to, advertised in the open packet.
    ///
    /// # Panics
    /// Not implemented yet.
    pub fn get_available_upgrades(&self) -> Vec<&str> {
        unimplemented!();
    }

    /// (Re)arms the ping timeout.
    ///
    /// # Panics
    /// Not implemented yet.
    fn set_ping_timeout(&self) {
        unimplemented!();
    }

    /// Queues the open (handshake) packet: session id, available upgrades and ping settings,
    /// serialized as JSON.
    async fn send_open_packet(&mut self, server_options: &ServerOptions) {
        // Handshake keys follow the Engine.IO protocol, which spells them in camelCase.
        let handshake = json!({
            "sid": self.id,
            "upgrades": self.get_available_upgrades(),
            "pingInterval": server_options.ping_interval,
            "pingTimeout": server_options.ping_timeout
        });
        self.send_packet(
            Packet {
                packet_type: PacketType::Open,
                data: PacketData::from(handshake.to_string()),
            },
            None,
        )
        .await;
    }

    /// Marks the socket [`ReadyState::Closed`] (once), drops buffered packets and every
    /// outstanding callback, closes the transport and emits [`SocketEvent::Close`].
    ///
    /// The reason and description are not forwarded anywhere yet.
    async fn on_close(&mut self, _reason: SocketError, _description: &str) {
        if self.ready_state == ReadyState::Closed {
            return;
        }
        self.ready_state = ReadyState::Closed;
        self.write_buffer.clear();
        self.pending_callbacks.clear();
        // Batches for already-flushed packets would otherwise linger on a dead socket.
        self.flushed_callbacks.clear();
        self.close_transport(false).await;
        let _ = self
            .event_sender
            .send(SocketEvent::Close {
                socket_id: self.id.clone(),
            })
            .await;
    }

    /// Closes the socket after a transport error, unless it is already closing or closed.
    async fn on_transport_error(&mut self, error: TransportError) {
        if !matches!(self.ready_state, ReadyState::Opening | ReadyState::Open) {
            return;
        }
        let reason = match error {
            TransportError::PacketParseError => SocketError::ParseError,
            _ => SocketError::TransportError,
        };
        self.on_close(reason, "FIXME").await;
    }

    /// Handles a packet from the client; ignored unless the socket is [`ReadyState::Open`].
    ///
    /// Any packet re-arms the ping timeout. How each type is handled:
    /// * Ping: answered with a pong, then [`SocketEvent::Heartbeat`] is emitted.
    /// * Upgrade: completes an in-progress upgrade.
    /// * Message: forwarded as [`SocketEvent::Message`].
    /// * Anything else: ignored.
    async fn on_packet(&mut self, packet: Packet) {
        if self.ready_state != ReadyState::Open {
            return;
        }
        self.set_ping_timeout();
        match packet.packet_type {
            PacketType::Ping => {
                self.send_packet(
                    Packet {
                        packet_type: PacketType::Pong,
                        data: PacketData::Empty,
                    },
                    None,
                )
                .await;
                let _ = self
                    .event_sender
                    .send(SocketEvent::Heartbeat {
                        socket_id: self.id.clone(),
                    })
                    .await;
            }
            PacketType::Upgrade => {
                if self.upgrade_state == UpgradeState::Upgrading {
                    self.upgrade_state = UpgradeState::Upgraded;
                    // NOTE(review): `maybe_upgrade` already swapped in the new transport, so
                    // this closes the *current* (new) transport, not the old one — confirm.
                    self.close_transport(false).await;
                    let _ = self
                        .event_sender
                        .send(SocketEvent::Upgrade {
                            socket_id: self.id.clone(),
                        })
                        .await;
                    self.set_ping_timeout();
                    self.flush().await;
                }
            }
            PacketType::Message => {
                let _ = self
                    .event_sender
                    .send(SocketEvent::Message {
                        socket_id: self.id.clone(),
                        data: packet.data,
                    })
                    .await;
            }
            _ => {}
        }
    }

    /// Transport-drain hook: tries to flush whatever has been buffered since.
    async fn on_flush(&mut self) {
        self.flush().await;
    }

    /// Runs the oldest batch of flushed-packet callbacks, if any.
    async fn on_drain(&mut self) {
        if self.flushed_callbacks.is_empty() {
            return;
        }
        match self.flushed_callbacks.remove(0) {
            CallbackBatch::Framed { callback } => callback(),
            CallbackBatch::NonFramed { callbacks } => {
                for callback in &callbacks {
                    callback();
                }
            }
        }
    }
}
/// Spawns a task that feeds the socket's transport events into the socket's handlers.
///
/// The task reads the event receiver of the transport installed *at call time*. If the
/// transport is later replaced, its events are not handled by this task.
///
/// Each event is handled while holding the socket lock; the receiver's own lock is held for
/// the task's whole lifetime. The task ends when the channel yields `None`.
pub async fn subscribe_socket_to_transport_events<A: 'static + Adapter>(
    socket: Arc<AsyncMutex<Socket<A>>>,
) {
    // Clone the receiver handle in its own scope so the socket lock is released before the
    // task starts competing for it.
    let receiver = {
        let socket = socket.lock().await;
        socket.transport_holder.transport_event_receiver.clone()
    };
    tokio::spawn(async move {
        let mut receiver = receiver.lock().await;
        while let Some(event) = receiver.recv().await {
            match event {
                TransportEvent::Error { error } => {
                    println!("transport error");
                    socket.lock().await.on_transport_error(error).await;
                }
                TransportEvent::Packet { packet } => {
                    println!("on packet!");
                    socket.lock().await.on_packet(packet).await;
                }
                TransportEvent::Drain => {
                    println!("on drain");
                    // Flush first so batches recorded by this flush are queued before the
                    // oldest batch is run.
                    let mut socket = socket.lock().await;
                    socket.on_flush().await;
                    socket.on_drain().await;
                }
                TransportEvent::Close => {
                    println!("on close");
                    socket
                        .lock()
                        .await
                        .on_close(SocketError::TransportError, "FIXME")
                        .await;
                }
            }
        }
    });
}
/// Progress of a transport upgrade on a socket.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpgradeState {
    /// No upgrade has been started; the state a socket is created in.
    Initial,
    /// An upgrade is in progress; an incoming `Upgrade` packet is only acted on in this state.
    Upgrading,
    /// The upgrade has completed.
    Upgraded,
}
/// Lifecycle state of a socket.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReadyState {
    /// Created but not yet opened; packets may already be queued.
    Opening,
    /// Open; packets are sent and incoming packets are handled.
    Open,
    /// A close was requested. `with_discard` records whether the transport should be
    /// discarded before it is closed.
    Closing { with_discard: bool },
    /// Closed; no further packets are queued or flushed.
    Closed,
}