use bevy::log;
use io::Write;
use std::error::Error;
use std::fmt::{Debug, Formatter};
use std::io;
use std::net::SocketAddr;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use async_trait::async_trait;
use crate::connection::MAX_PACKET_SIZE;
use crate::packet_length_serializer::PacketLengthDeserializationError;
use crate::serializer::Serializer;
use crate::PacketLengthSerializer;
/// A transport protocol: ties together the listener used by servers and the stream types
/// used on the server and client side of a connection.
#[async_trait]
pub trait Protocol: Send + Sync + 'static {
    /// Stream type a client obtains from [`Protocol::connect_to_server`].
    type ClientStream: ClientStream;
    /// Stream type produced by [`Self::Listener`] for accepted connections.
    type ServerStream: ServerStream;
    /// Listener returned by [`Protocol::bind`]; its streams must be [`Self::ServerStream`]s.
    type Listener: Listener<Stream = Self::ServerStream>;

    /// Creates a listener bound to `addr`.
    async fn bind(addr: SocketAddr) -> io::Result<Self::Listener>;

    /// Connects to `addr` through [`ClientStream::connect`].
    ///
    /// On success the peer address is logged at debug level before the stream is returned;
    /// connection errors are propagated unchanged.
    async fn connect_to_server(addr: SocketAddr) -> io::Result<Self::ClientStream> {
        Self::ClientStream::connect(addr).await.map(|client| {
            log::debug!("Connected to a server at {:?}", client.peer_addr());
            client
        })
    }
}
/// Server-side source of incoming connections.
#[async_trait]
pub trait Listener {
    /// Stream type yielded for each accepted connection.
    type Stream: ServerStream;

    /// Accepts the next connection, yielding its server-side stream.
    async fn accept(&self) -> io::Result<Self::Stream>;

    /// Returns the socket address associated with this listener.
    fn address(&self) -> SocketAddr;

    /// Hook for reacting to the disconnection of the peer at `peer_addr`.
    ///
    /// The default implementation does nothing; implementors may override it.
    fn handle_disconnection(&self, peer_addr: SocketAddr) {
        // The default hook intentionally ignores the address.
        let _ = peer_addr;
    }
}
/// Client side of a connection, created by dialing a server address.
#[async_trait]
pub trait ClientStream: NetworkStream {
    /// Opens a connection to `addr` and returns the resulting stream.
    async fn connect(addr: SocketAddr) -> io::Result<Self>
    where
        Self: Sized;
}

/// Server side of a connection, as yielded by a [`Listener`].
///
/// This is a marker trait: it adds no items beyond those of [`NetworkStream`].
pub trait ServerStream: NetworkStream {}
/// Operations shared by client and server connection streams.
#[async_trait]
pub trait NetworkStream: Send + Sync + 'static {
    /// Receiving half returned by [`NetworkStream::into_split`].
    type ReadHalf: ReadStream;
    /// Sending half returned by [`NetworkStream::into_split`].
    type WriteHalf: WriteStream;

    /// Address of the remote end of this connection.
    fn peer_addr(&self) -> SocketAddr;

    /// Address of the local end of this connection.
    fn local_addr(&self) -> SocketAddr;

    /// Consumes the stream and splits it into separately owned read and write halves.
    async fn into_split(self) -> io::Result<(Self::ReadHalf, Self::WriteHalf)>;
}
/// Receiving half of a connection.
#[async_trait]
pub trait ReadStream: Send + Sync + 'static {
    /// Fills `buffer` completely with bytes read from the stream.
    async fn read_exact(&mut self, buffer: &mut [u8]) -> io::Result<()>;

    /// Reads one length-prefixed packet and deserializes it with `serializer`.
    ///
    /// The length prefix is decoded incrementally: `LS::SIZE` bytes are read first, and each
    /// time `length_serializer` answers [`PacketLengthDeserializationError::NeedMoreBytes`]
    /// that many further bytes are appended and the whole prefix is decoded again.
    ///
    /// # Errors
    ///
    /// * [`ReceiveError::Io`] if reading from the stream fails.
    /// * [`ReceiveError::LengthDeserialization`] if the length prefix is rejected.
    /// * [`ReceiveError::PacketTooBig`] if the decoded length exceeds `MAX_PACKET_SIZE`.
    /// * [`ReceiveError::Deserialization`] if the payload cannot be deserialized.
    async fn receive<ReceivingPacket, SendingPacket, S, LS>(
        &mut self,
        serializer: Arc<S>,
        length_serializer: &LS,
    ) -> Result<ReceivingPacket, ReceiveError<S::DecodeError, LS>>
    where
        ReceivingPacket: Send + Sync + Debug + 'static,
        SendingPacket: Send + Sync + Debug + 'static,
        S: Serializer<ReceivingPacket, SendingPacket> + ?Sized,
        LS: PacketLengthSerializer,
    {
        // Accumulate the prefix in a single buffer; `wanted` is how many more bytes the
        // length serializer asked for on the previous attempt.
        let mut prefix = Vec::with_capacity(LS::SIZE);
        let mut wanted = LS::SIZE;
        let length = loop {
            let start = prefix.len();
            prefix.resize(start + wanted, 0);
            self.read_exact(&mut prefix[start..])
                .await
                .map_err(ReceiveError::Io)?;
            match length_serializer.deserialize_packet_length(&prefix) {
                Ok(length) => break length,
                Err(PacketLengthDeserializationError::NeedMoreBytes(more)) => wanted = more,
                Err(PacketLengthDeserializationError::Err(err)) => {
                    return Err(ReceiveError::LengthDeserialization(err));
                }
            }
        };

        // Reject oversized packets before allocating a payload buffer for them.
        if length > MAX_PACKET_SIZE.load(Ordering::Relaxed) {
            return Err(ReceiveError::PacketTooBig);
        }

        let mut payload = vec![0; length];
        self.read_exact(&mut payload)
            .await
            .map_err(ReceiveError::Io)?;
        serializer
            .deserialize(&payload)
            .map_err(ReceiveError::Deserialization)
    }
}
pub enum ReceiveError<SerializationError, LS>
where
SerializationError: Error + Send + Sync,
LS: PacketLengthSerializer,
{
Io(io::Error),
Deserialization(SerializationError),
LengthDeserialization(LS::Error),
PacketTooBig,
NoConnection(io::Error),
IntentionalDisconnection,
}
impl<SerializationError, LS> Debug for ReceiveError<SerializationError, LS>
where
SerializationError: Error + Send + Sync,
LS: PacketLengthSerializer,
{
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
ReceiveError::Io(error) => write!(f, "ReceiveError::Io({error:?})"),
ReceiveError::Deserialization(error) => {
write!(f, "ReceiveError::Deserialization({error:?})")
}
ReceiveError::LengthDeserialization(error) => {
write!(f, "ReceiveError::LengthDeserialization({error:?})")
}
ReceiveError::PacketTooBig => write!(f, "ReceiveError::PacketTooBig"),
ReceiveError::NoConnection(error) => write!(f, "ReceiveError::NoConnection({error:?})"),
ReceiveError::IntentionalDisconnection => write!(f, "IntentionalDisconnection"),
}
}
}
/// Sending half of a connection.
#[async_trait]
pub trait WriteStream: Send + Sync + 'static {
    /// Writes the whole of `buffer` to the stream.
    async fn write_all(&mut self, buffer: &[u8]) -> io::Result<()>;

    /// Serializes `packet`, prefixes it with its encoded length and sends the resulting
    /// frame with a single [`WriteStream::write_all`] call.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`io::ErrorKind::InvalidInput`] if either the packet or its
    /// length cannot be serialized, and propagates any error from [`WriteStream::write_all`].
    async fn send<ReceivingPacket, SendingPacket, S, LS>(
        &mut self,
        packet: SendingPacket,
        serializer: Arc<S>,
        length_serializer: &LS,
    ) -> io::Result<()>
    where
        ReceivingPacket: Send + Sync + Debug + 'static,
        SendingPacket: Send + Sync + Debug + 'static,
        S: Serializer<ReceivingPacket, SendingPacket> + ?Sized,
        LS: PacketLengthSerializer,
    {
        // Serialization failures are reported to the caller instead of panicking the sending
        // task. Whether they can happen depends on the `Serializer` / `PacketLengthSerializer`
        // implementations, which live outside this file.
        let serialized = serializer.serialize(packet).map_err(|err| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Error serializing packet: {err:?}"),
            )
        })?;
        let mut buf = length_serializer
            .serialize_packet_length(serialized.len())
            .map_err(|err| {
                io::Error::new(
                    io::ErrorKind::InvalidInput,
                    format!("Error serializing packet length: {err:?}"),
                )
            })?;
        // `buf` starts with the length prefix; append the payload so the frame goes out in one write.
        buf.write_all(&serialized)?;
        self.write_all(&buf).await
    }
}