use std::sync::{Arc, Mutex};
use std::io::{Read, Write};
use polariton::packet::{Packet, Message, StandardMessage, Data};
/// Number of consecutive failures (socket read errors or packet-handling
/// errors) after which a connection's reader loop stops. Any successfully
/// handled packet resets the count.
const MAX_ERRORS: usize = 3;
/// An item queued for a connection's writer.
pub enum ToSend<C = ()> {
    /// A fully built packet, serialized and written as-is.
    Packet(Packet<C>),
    /// A payload that the writer wraps into a standard message (with flags `0`)
    /// and turns into a packet before writing it.
    Data {
        /// Payload of the standard message.
        data: Data<C>,
        /// Whether to encrypt the message. This is ignored (treated as `true`)
        /// when the server was built with `force_encryption`.
        encrypt: bool,
        /// Channel passed to `Packet::from_message`.
        channel: u8,
        /// Reliability flag passed to `Packet::from_message`.
        reliable: bool,
    },
}
impl <C> std::convert::From<Packet<C>> for ToSend<C> {
fn from(value: Packet<C>) -> Self {
Self::Packet(value)
}
}
/// Serves connections by dispatching incoming operation requests and events
/// to shared handlers and queueing replies back to the peer.
///
/// Each handled connection runs as a reader/writer pair: OS threads, or tokio
/// tasks when the `tokio-async` feature is enabled. Their handles are kept
/// until `Server::join` (or `Server::join_async`) collects them.
pub struct Server<U, C = ()>
where
    U: Send + Sync + 'static,
    C: Send + Sync + Clone + core::fmt::Debug + 'static,
{
    /// Handler for incoming operation requests, shared by every connection.
    op_handler: Arc<super::operations::OperationsHandler<U, C>>,
    /// Handler for incoming events, shared by every connection.
    event_handler: Arc<super::events::EventsHandler<U, C>>,
    /// When `false`, unencrypted incoming standard messages are dropped and
    /// outgoing `ToSend::Data` messages are always encrypted.
    allow_unencrypted: bool,
    /// Reader/writer threads spawned by the blocking `handle*` methods.
    join_handles: Mutex<Vec<std::thread::JoinHandle<()>>>,
    /// Reader/writer tasks spawned by the async `handle*` methods.
    #[cfg(feature = "tokio-async")]
    async_join_handles: tokio::sync::Mutex<Vec<tokio::task::JoinHandle<()>>>,
}
impl<U: Send + Sync + 'static, C: Send + Sync + Clone + core::fmt::Debug + 'static> Server<U, C> {
    /// Creates a server that dispatches incoming operation requests to `ops`
    /// and incoming events to `events`.
    ///
    /// Unencrypted messages are accepted unless [`Server::force_encryption`]
    /// is applied.
    pub fn new(ops: super::operations::OperationsHandler<U, C>, events: super::events::EventsHandler<U, C>) -> Self {
        Self {
            op_handler: Arc::new(ops),
            event_handler: Arc::new(events),
            allow_unencrypted: true,
            join_handles: Mutex::new(Vec::default()),
            #[cfg(feature = "tokio-async")]
            async_join_handles: tokio::sync::Mutex::new(Vec::default()),
        }
    }

    /// Requires encryption on every connection this server handles.
    ///
    /// Incoming standard messages that are not encrypted are logged and
    /// dropped. Outgoing [`ToSend::Data`] messages are always encrypted,
    /// whatever their `encrypt` field says. Pings and pre-built
    /// [`ToSend::Packet`] values are not affected.
    #[must_use]
    pub fn force_encryption(mut self) -> Self {
        self.allow_unencrypted = false;
        self
    }

    /// Starts serving one connection on two OS threads.
    ///
    /// One thread reads packets from `socket_r` and dispatches them to the
    /// handlers. The other writes queued messages to `socket_w`.
    ///
    /// Returns a sender for queueing additional outgoing messages. The writer
    /// thread keeps running until every sender for its queue has been dropped,
    /// including the returned one.
    pub fn handle<R: Read + Send + 'static, W: Write + Send + 'static, CSI: polariton::serdes::CustomSerdes<C> + Send + Sync + 'static>(
        &self,
        socket_r: R,
        socket_w: W,
        user_state: U,
        ctx: polariton::packet::SerdesContext<C, CSI>,
    ) -> std::sync::mpsc::Sender<ToSend<C>> {
        let (tx, rx) = std::sync::mpsc::channel();
        self.handle_with_channel(socket_r, socket_w, user_state, ctx, tx.clone(), rx);
        tx
    }

    /// Like [`Server::handle`], but with a caller-provided queue.
    ///
    /// `tx` is handed to the reader thread so it can queue replies, and `rx` is
    /// drained by the writer thread. Both thread handles are recorded for
    /// [`Server::join`].
    ///
    /// # Panics
    /// Panics if the mutex guarding the thread-handle list is poisoned.
    pub fn handle_with_channel<R: Read + Send + 'static, W: Write + Send + 'static, CSI: polariton::serdes::CustomSerdes<C> + Send + Sync + 'static>(
        &self,
        socket_r: R,
        socket_w: W,
        user_state: U,
        ctx: polariton::packet::SerdesContext<C, CSI>,
        tx: std::sync::mpsc::Sender<ToSend<C>>,
        rx: std::sync::mpsc::Receiver<ToSend<C>>,
    ) {
        let op_handler = Arc::clone(&self.op_handler);
        let event_handler = Arc::clone(&self.event_handler);
        let allow_unencrypted = self.allow_unencrypted;
        let reader_ctx = Arc::new(ctx);
        let writer_ctx = Arc::clone(&reader_ctx);
        // `tx` is moved into the reader; no extra clone is needed.
        let reader = std::thread::spawn(move || {
            process_socket(op_handler, event_handler, socket_r, user_state, reader_ctx, tx, allow_unencrypted)
        });
        let writer = std::thread::spawn(move || handle_send_packet(rx, writer_ctx, socket_w, allow_unencrypted));
        let mut handles = self.join_handles.lock().unwrap();
        handles.push(reader);
        handles.push(writer);
    }

    /// Async counterpart of [`Server::handle`]. It spawns a reader task and a
    /// writer task on the tokio runtime.
    ///
    /// Returns a sender for queueing additional outgoing messages. The writer
    /// task keeps running until every sender for its queue has been dropped.
    #[cfg(feature = "tokio-async")]
    pub async fn handle_async<'a, R: tokio::io::AsyncRead + Send + Unpin + 'static, W: tokio::io::AsyncWrite + Send + Unpin + 'static, CSI: polariton::serdes::CustomSerdes<C> + Send + Sync + 'static>(
        &self,
        socket_r: R,
        socket_w: W,
        user_state: U,
        ctx: polariton::packet::SerdesContext<C, CSI>,
    ) -> tokio::sync::mpsc::UnboundedSender<ToSend<C>> {
        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        self.handle_async_with_channel(socket_r, socket_w, user_state, ctx, tx.clone(), rx).await;
        tx
    }

    /// Async counterpart of [`Server::handle_with_channel`].
    ///
    /// Both task handles are recorded for `Server::join_async`.
    #[cfg(feature = "tokio-async")]
    pub async fn handle_async_with_channel<'a, R: tokio::io::AsyncRead + Send + Unpin + 'static, W: tokio::io::AsyncWrite + Send + Unpin + 'static, CSI: polariton::serdes::CustomSerdes<C> + Send + Sync + 'static>(
        &self,
        socket_r: R,
        socket_w: W,
        user_state: U,
        ctx: polariton::packet::SerdesContext<C, CSI>,
        tx: tokio::sync::mpsc::UnboundedSender<ToSend<C>>,
        rx: tokio::sync::mpsc::UnboundedReceiver<ToSend<C>>,
    ) {
        let op_handler = Arc::clone(&self.op_handler);
        let event_handler = Arc::clone(&self.event_handler);
        let ctx = Arc::new(ctx);
        let reader = tokio::spawn(process_socket_async(op_handler, event_handler, socket_r, user_state, Arc::clone(&ctx), tx, self.allow_unencrypted));
        let writer = tokio::spawn(handle_send_packet_async(rx, ctx, socket_w, self.allow_unencrypted));
        let mut handles = self.async_join_handles.lock().await;
        handles.push(reader);
        handles.push(writer);
    }

    /// Like [`Server::handle_async_with_channel`], but waits for the reader
    /// task to finish before returning.
    ///
    /// Only the writer task's handle is recorded for `Server::join_async`. If
    /// the reader task panics or is cancelled, the failure is logged.
    #[cfg(feature = "tokio-async")]
    pub async fn handle_async_with_channel_join<'a, R: tokio::io::AsyncRead + Send + Unpin + 'static, W: tokio::io::AsyncWrite + Send + Unpin + 'static, CSI: polariton::serdes::CustomSerdes<C> + Send + Sync + 'static>(
        &self,
        socket_r: R,
        socket_w: W,
        user_state: U,
        ctx: polariton::packet::SerdesContext<C, CSI>,
        tx: tokio::sync::mpsc::UnboundedSender<ToSend<C>>,
        rx: tokio::sync::mpsc::UnboundedReceiver<ToSend<C>>,
    ) {
        let op_handler = Arc::clone(&self.op_handler);
        let event_handler = Arc::clone(&self.event_handler);
        let ctx = Arc::new(ctx);
        let reader = tokio::spawn(process_socket_async(op_handler, event_handler, socket_r, user_state, Arc::clone(&ctx), tx, self.allow_unencrypted));
        let writer = tokio::spawn(handle_send_packet_async(rx, ctx, socket_w, self.allow_unencrypted));
        // The original code discarded a JoinError here via unwrap_or_default.
        // Surface it instead, so a panicking handler does not vanish silently.
        if let Err(e) = reader.await {
            log::error!("connection reader task failed: {}", e);
        }
        self.async_join_handles.lock().await.push(writer);
    }

    /// Waits for every thread spawned so far by the blocking `handle*` methods.
    ///
    /// The recorded handles are taken out under the lock before any of them is
    /// joined. New connections can therefore still be registered while this
    /// call is blocked; those are not waited for.
    ///
    /// # Panics
    /// Panics if the handle-list mutex is poisoned or if a joined thread panicked.
    pub fn join(&self) {
        let handles = {
            let mut guard = self.join_handles.lock().unwrap();
            std::mem::take(&mut *guard)
        };
        for handle in handles {
            handle.join().unwrap();
        }
    }

    /// Async counterpart of [`Server::join`] for tasks spawned by the async
    /// `handle*` methods.
    ///
    /// # Panics
    /// Panics if an awaited task panicked or was cancelled.
    #[cfg(feature = "tokio-async")]
    pub async fn join_async(&self) {
        let handles = {
            let mut guard = self.async_join_handles.lock().await;
            std::mem::take(&mut *guard)
        };
        for handle in handles {
            handle.await.unwrap();
        }
    }
}
fn process_socket<R: Read + Send + 'static, C: Send + Sync + Clone + core::fmt::Debug + 'static, CSI: polariton::serdes::CustomSerdes<C> + Send + 'static, U: Send + Sync + 'static,>(
op_handler: Arc<super::operations::OperationsHandler<U, C>>,
event_handler: Arc<super::events::EventsHandler<U, C>>,
mut socket: R,
user_state: U,
ctx: Arc<polariton::packet::SerdesContext<C, CSI>>,
chann: std::sync::mpsc::Sender<ToSend<C>>,
allow_unencrypted: bool,
) {
let mut error_count = 0;
loop {
match super::utils::receive_packet(&mut socket, &ctx) {
Ok(packet) => if let Err(e) = handle_packet(packet, &chann, &ctx, &op_handler, &event_handler, &user_state, allow_unencrypted) {
log::error!("socket write error: {}", e);
error_count += 1;
} else {
error_count = 0;
},
Err(e) => {
match e.kind() {
std::io::ErrorKind::ConnectionAborted
| std::io::ErrorKind::UnexpectedEof => {
log::debug!("socket closed: {}", e);
break;
},
_ => {
error_count += 1;
log::error!("socket read error: {}", e);
},
}
}
}
if error_count >= MAX_ERRORS {
break;
}
}
}
/// Async reader loop for one connection.
///
/// Repeatedly decodes packets from `socket` with
/// `super::utils::receive_packet_async` and dispatches them through
/// `handle_packet_async`. Replies are queued on `chann` for the writer task.
///
/// The loop ends in either of two cases:
/// - the peer closes, aborts or resets the connection;
/// - [`MAX_ERRORS`] consecutive failures occur. A successfully handled packet
///   resets the failure count.
///
/// `chann` is dropped on return.
#[cfg(feature = "tokio-async")]
async fn process_socket_async<R: tokio::io::AsyncRead + Unpin + 'static, C: Send + Sync + Clone + core::fmt::Debug + 'static, CSI: polariton::serdes::CustomSerdes<C> + Send + Sync + 'static, U: Send + Sync + 'static>(
    op_handler: Arc<super::operations::OperationsHandler<U, C>>,
    event_handler: Arc<super::events::EventsHandler<U, C>>,
    mut socket: R,
    user_state: U,
    ctx: Arc<polariton::packet::SerdesContext<C, CSI>>,
    chann: tokio::sync::mpsc::UnboundedSender<ToSend<C>>,
    allow_unencrypted: bool,
) {
    let mut error_count = 0;
    loop {
        match super::utils::receive_packet_async(&mut socket, &ctx).await {
            Ok(packet) => match handle_packet_async(packet, &chann, &ctx, &op_handler, &event_handler, &user_state, allow_unencrypted).await {
                Ok(()) => error_count = 0,
                Err(e) => {
                    // handle_packet_async never touches the socket. Its errors come from
                    // the handlers, from packet serialization or from the outgoing queue,
                    // so do not report them as socket write errors.
                    log::error!("packet handling error: {}", e);
                    error_count += 1;
                }
            },
            Err(e) => match e.kind() {
                // A peer reset is a normal disconnect, not a transient read failure.
                std::io::ErrorKind::ConnectionAborted
                | std::io::ErrorKind::ConnectionReset
                | std::io::ErrorKind::UnexpectedEof => {
                    log::debug!("socket closed: {}", e);
                    break;
                }
                _ => {
                    error_count += 1;
                    log::error!("socket read error: {}", e);
                }
            },
        }
        if error_count >= MAX_ERRORS {
            log::warn!("closing connection after {} consecutive errors", error_count);
            break;
        }
    }
}
/// Dispatches one decoded packet received from the peer.
///
/// Pings (top-level, or wrapped in a standard packet) are answered right away
/// by queueing the reply built by `super::utils::process_ping` on `chann`.
///
/// Standard messages are passed to [`handle_message`] together with the
/// packet's channel and reliability. There is one exception: a message that is
/// not encrypted while `allow_unencrypted` is `false` is logged and dropped.
///
/// # Errors
/// Returns `BrokenPipe` if the outgoing queue is closed. Otherwise it passes on
/// any error from [`handle_message`].
fn handle_packet<C: Send + Sync + Clone + core::fmt::Debug + 'static, CSI: polariton::serdes::CustomSerdes<C>, U: Send + Sync>(
    packet: Packet<C>,
    chann: &std::sync::mpsc::Sender<ToSend<C>>,
    ctx: &polariton::packet::SerdesContext<C, CSI>,
    op_handler: &super::operations::OperationsHandler<U, C>,
    event_handler: &super::events::EventsHandler<U, C>,
    user_state: &U,
    allow_unencrypted: bool,
) -> std::io::Result<()> {
    // Shared by both ping arms: queue a reply, reporting a closed queue as BrokenPipe.
    let queue_reply = |reply: ToSend<C>| {
        chann
            .send(reply)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))
    };
    match packet {
        Packet::Ping(ping) => queue_reply(super::utils::process_ping(ping).into()),
        Packet::Packet(std_packet) => match std_packet.message {
            Message::Ping(ping) => queue_reply(super::utils::process_ping(ping).into()),
            Message::Standard(msg) => {
                let info = PacketInfo {
                    channel: std_packet.header.channel,
                    is_reliable: std_packet.header.is_reliable(),
                };
                let accepted = msg.is_encrypted() || allow_unencrypted;
                if !accepted {
                    log::warn!("Ignoring unencrypted packet message {}", msg);
                    return Ok(());
                }
                handle_message(msg, chann, ctx, op_handler, event_handler, user_state, info)
            }
        },
    }
}
/// Async counterpart of `handle_packet`: dispatches one decoded packet.
///
/// Pings (top-level, or wrapped in a standard packet) are answered right away
/// by queueing the reply built by `super::utils::process_ping` on `chann`.
///
/// Standard messages are passed to `handle_message_async` together with the
/// packet's channel and reliability. There is one exception: a message that is
/// not encrypted while `allow_unencrypted` is `false` is logged and dropped.
///
/// # Errors
/// Returns `BrokenPipe` if the outgoing queue is closed. Otherwise it passes on
/// any error from `handle_message_async`.
#[cfg(feature = "tokio-async")]
async fn handle_packet_async<C: Send + Sync + Clone + core::fmt::Debug + 'static, CSI: polariton::serdes::CustomSerdes<C>, U: Send + Sync>(
    packet: Packet<C>,
    chann: &tokio::sync::mpsc::UnboundedSender<ToSend<C>>,
    ctx: &polariton::packet::SerdesContext<C, CSI>,
    op_handler: &super::operations::OperationsHandler<U, C>,
    event_handler: &super::events::EventsHandler<U, C>,
    user_state: &U,
    allow_unencrypted: bool,
) -> std::io::Result<()> {
    // Shared by both ping arms: queue a reply, reporting a closed queue as BrokenPipe.
    let queue_reply = |reply: ToSend<C>| {
        chann
            .send(reply)
            .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))
    };
    match packet {
        Packet::Ping(ping) => queue_reply(super::utils::process_ping(ping).into()),
        Packet::Packet(std_packet) => match std_packet.message {
            Message::Ping(ping) => queue_reply(super::utils::process_ping(ping).into()),
            Message::Standard(msg) => {
                let info = PacketInfo {
                    channel: std_packet.header.channel,
                    is_reliable: std_packet.header.is_reliable(),
                };
                let accepted = msg.is_encrypted() || allow_unencrypted;
                if !accepted {
                    log::warn!("Ignoring unencrypted packet message {}", msg);
                    return Ok(());
                }
                handle_message_async(msg, chann, ctx, op_handler, event_handler, user_state, info).await
            }
        },
    }
}
/// Transport attributes of a received packet. Replies to the message it carried
/// are sent with the same attributes.
#[derive(Debug, Clone, Copy)]
struct PacketInfo {
    /// Channel number taken from the packet header.
    channel: u8,
    /// Whether the packet header marked the packet as reliable.
    is_reliable: bool,
}
/// Handles one standard message that has already passed the encryption check.
///
/// - `Data::OpReq`: runs the operation handler and queues its response on
///   `chann`, using the request's channel and reliability.
///   - The handler receives the request's flags by `&mut`.
///   - The response is encrypted if the request was encrypted, or if bit `0x80`
///     is set in the flags after the handler returns. NOTE(review): `0x80` is
///     presumably the encryption flag bit; confirm.
/// - `Data::Event`: forwarded to the event handler; nothing is sent back.
/// - Any other variant is rejected with `InvalidInput`.
///
/// # Errors
/// - Passes on errors from the operation handler and from packet construction.
/// - Returns `BrokenPipe` if the outgoing queue is closed.
fn handle_message<C: Send + Sync + Clone + core::fmt::Debug + 'static, CSI: polariton::serdes::CustomSerdes<C>, U: Send + Sync>(
    msg: StandardMessage<C>,
    chann: &std::sync::mpsc::Sender<ToSend<C>>,
    ctx: &polariton::packet::SerdesContext<C, CSI>,
    op_handler: &super::operations::OperationsHandler<U, C>,
    event_handler: &super::events::EventsHandler<U, C>,
    user_state: &U,
    packet: PacketInfo,
) -> std::io::Result<()> {
    let request_was_encrypted = msg.is_encrypted();
    let mut flags = msg.flags;
    match msg.data {
        Data::OpReq(request) => {
            let response = op_handler.handle_op(user_state, request, &mut flags)?;
            let encrypt_response = request_was_encrypted || (flags & 0x80) != 0;
            let reply = StandardMessage {
                flags: 0,
                data: Data::OpResp(response),
            }
            .encrypt(encrypt_response);
            let reply_packet = Packet::<C>::from_message(Message::Standard(reply), packet.channel, packet.is_reliable, ctx)?;
            chann
                .send(reply_packet.into())
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))
        }
        Data::Event(event) => {
            event_handler.handle_event(user_state, event);
            Ok(())
        }
        other => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("Unsupported message data variant {}", other.code()),
        )),
    }
}
/// Async counterpart of `handle_message`. Handles one standard message that has
/// already passed the encryption check.
///
/// - `Data::OpReq`: awaits the operation handler and queues its response on
///   `chann`, using the request's channel and reliability.
///   - The handler receives the request's flags by `&mut`.
///   - The response is encrypted if the request was encrypted, or if bit `0x80`
///     is set in the flags after the handler returns. NOTE(review): `0x80` is
///     presumably the encryption flag bit; confirm.
/// - `Data::Event`: forwarded to the async event handler; nothing is sent back.
/// - Any other variant is rejected with `InvalidInput`.
///
/// # Errors
/// - Passes on errors from the operation handler and from packet construction.
/// - Returns `BrokenPipe` if the outgoing queue is closed.
#[cfg(feature = "tokio-async")]
async fn handle_message_async<C: Send + Sync + Clone + core::fmt::Debug + 'static, CSI: polariton::serdes::CustomSerdes<C>, U: Send + Sync>(
    msg: StandardMessage<C>,
    chann: &tokio::sync::mpsc::UnboundedSender<ToSend<C>>,
    ctx: &polariton::packet::SerdesContext<C, CSI>,
    op_handler: &super::operations::OperationsHandler<U, C>,
    event_handler: &super::events::EventsHandler<U, C>,
    user_state: &U,
    packet: PacketInfo,
) -> std::io::Result<()> {
    let request_was_encrypted = msg.is_encrypted();
    let mut flags = msg.flags;
    match msg.data {
        Data::OpReq(request) => {
            let response = op_handler.handle_op_async(user_state, request, &mut flags).await?;
            let encrypt_response = request_was_encrypted || (flags & 0x80) != 0;
            let reply = StandardMessage {
                flags: 0,
                data: Data::OpResp(response),
            }
            .encrypt(encrypt_response);
            let reply_packet = Packet::<C>::from_message(Message::Standard(reply), packet.channel, packet.is_reliable, ctx)?;
            chann
                .send(reply_packet.into())
                .map_err(|e| std::io::Error::new(std::io::ErrorKind::BrokenPipe, e))
        }
        Data::Event(event) => {
            event_handler.handle_event_async(user_state, event).await;
            Ok(())
        }
        other => Err(std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!("Unsupported message data variant {}", other.code()),
        )),
    }
}
fn handle_send_packet<C: Send + Sync + Clone + core::fmt::Debug + 'static, CSI: polariton::serdes::CustomSerdes<C>, W: Write>(
chann: std::sync::mpsc::Receiver<ToSend<C>>,
ctx: Arc<polariton::packet::SerdesContext<C, CSI>>,
mut socket_w: W,
allow_unencrypted: bool,
) {
while let Ok(to_send) = chann.recv() {
let result = match to_send {
ToSend::Packet(packet) => super::utils::send_packet(&packet, &mut socket_w, &ctx),
ToSend::Data { data, encrypt, channel, reliable } => {
let actual_encrypt = if allow_unencrypted { encrypt } else { true };
match Packet::from_message(Message::Standard(StandardMessage { flags: 0, data }.encrypt(actual_encrypt)), channel, reliable, &ctx) {
Ok(packet) => super::utils::send_packet(&packet, &mut socket_w, &ctx),
Err(e) => Err(e),
}
}
};
if let Err(e) = result {
log::error!("Failed sending packet: {}", e);
}
}
}
/// Async writer loop for one connection.
///
/// Drains `chann` and writes each item to `socket_w`:
/// - [`ToSend::Packet`] values are sent as-is.
/// - [`ToSend::Data`] values are first wrapped into a standard message (flags
///   `0`) and a packet. They are always encrypted when `allow_unencrypted` is
///   `false`.
///
/// Failures are logged and the loop moves on to the next item. It ends only
/// when every sender for `chann` has been dropped.
#[cfg(feature = "tokio-async")]
async fn handle_send_packet_async<C: Send + Sync + Clone + core::fmt::Debug + 'static, CSI: polariton::serdes::CustomSerdes<C>, W: tokio::io::AsyncWrite + Unpin>(
    mut chann: tokio::sync::mpsc::UnboundedReceiver<ToSend<C>>,
    ctx: Arc<polariton::packet::SerdesContext<C, CSI>>,
    mut socket_w: W,
    allow_unencrypted: bool,
) {
    let force_encrypt = !allow_unencrypted;
    while let Some(outgoing) = chann.recv().await {
        // Stage 1: turn the queued item into a packet.
        let prepared = match outgoing {
            ToSend::Packet(packet) => Ok(packet),
            ToSend::Data { data, encrypt, channel, reliable } => {
                let message = StandardMessage { flags: 0, data }.encrypt(encrypt || force_encrypt);
                Packet::from_message(Message::Standard(message), channel, reliable, &ctx)
            }
        };
        // Stage 2: write it out.
        let outcome = match prepared {
            Ok(packet) => super::utils::send_packet_async(&packet, &mut socket_w, &ctx).await,
            Err(err) => Err(err),
        };
        if let Err(err) = outcome {
            log::error!("Failed sending packet: {}", err);
        }
    }
}