use std::error::Error;
use std::fmt::{Debug, Formatter};
use std::future::Future;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use bevy::log;
use bevy::prelude::{DetectChanges, Res, Resource};
use futures::task::AtomicWaker;
use tokio::sync::mpsc::error::SendError;
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use crate::packet_length_serializer::PacketLengthSerializer;
use crate::protocol::NetworkStream;
use crate::serializer::Serializer;
/// ECS-side handle to one network connection.
///
/// Bundles the sending half of the connection's outgoing-packet channel with the
/// connection's id, its local/peer socket addresses and a shared [`DisconnectTask`]
/// through which [`EcsConnection::disconnect`] signals a disconnect.
#[derive(Resource)]
pub struct EcsConnection<SendingPacket: Send + Sync + Debug + 'static> {
    /// Shared disconnect signal, triggered by [`EcsConnection::disconnect`].
    pub(crate) disconnect_task: DisconnectTask,
    /// Identifier of this connection.
    pub(crate) id: ConnectionId,
    /// Sender used by [`EcsConnection::send`] to queue outgoing packets.
    pub(crate) packet_tx: UnboundedSender<SendingPacket>,
    /// Local socket address of the connection.
    pub(crate) local_addr: SocketAddr,
    /// Remote (peer) socket address of the connection.
    pub(crate) peer_addr: SocketAddr,
}
impl<SendingPacket> Clone for EcsConnection<SendingPacket>
where
SendingPacket: Send + Sync + Debug + 'static,
{
fn clone(&self) -> Self {
EcsConnection {
disconnect_task: self.disconnect_task.clone(),
id: self.id,
packet_tx: self.packet_tx.clone(),
local_addr: self.local_addr,
peer_addr: self.peer_addr,
}
}
}
impl<SendingPacket> Debug for EcsConnection<SendingPacket>
where
SendingPacket: Send + Sync + Debug + 'static,
{
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "Connection #{}", self.id().0)
}
}
impl<SendingPacket> EcsConnection<SendingPacket>
where
SendingPacket: Send + Sync + Debug + 'static,
{
pub fn id(&self) -> ConnectionId {
self.id
}
pub fn peer_addr(&self) -> SocketAddr {
self.peer_addr
}
pub fn local_addr(&self) -> SocketAddr {
self.local_addr
}
pub fn send(&self, packet: SendingPacket) -> Result<(), SendError<SendingPacket>> {
self.packet_tx.send(packet)
}
pub fn disconnect(&self) {
self.disconnect_task.disconnect();
}
}
/// Transport-side state of one connection: the network stream plus everything needed to
/// encode/decode packets on it.
///
/// Field declaration order is deliberate: it is also the order in which fields are dropped.
pub struct RawConnection<ReceivingPacket, SendingPacket, NS, EncErr, DecErr, LS>
where
    ReceivingPacket: Send + Sync + Debug + 'static,
    SendingPacket: Send + Sync + Debug + 'static,
    NS: NetworkStream,
    EncErr: Error + Send + Sync,
    DecErr: Error + Send + Sync,
    LS: PacketLengthSerializer,
{
    /// Shared disconnect signal for this connection.
    pub disconnect_task: DisconnectTask,
    /// Underlying network stream.
    pub stream: NS,
    /// Packet serializer, shared behind an `Arc` as a trait object.
    pub serializer: Arc<
        dyn Serializer<ReceivingPacket, SendingPacket, EncodeError = EncErr, DecodeError = DecErr>,
    >,
    /// Packet-length serializer (see `PacketLengthSerializer`), shared behind an `Arc`.
    pub packet_length_serializer: Arc<LS>,
    /// Receiver of `SendingPacket`s queued for this connection
    /// (presumably fed by an `EcsConnection`'s sender — confirm at the call site).
    pub packets_rx: UnboundedReceiver<SendingPacket>,
    /// Identifier of this connection.
    pub id: ConnectionId,
}
impl<ReceivingPacket, SendingPacket, NS, EncErr, DecErr, LS> Debug
for RawConnection<ReceivingPacket, SendingPacket, NS, EncErr, DecErr, LS>
where
ReceivingPacket: Send + Sync + Debug + 'static,
SendingPacket: Send + Sync + Debug + 'static,
NS: NetworkStream,
EncErr: Error + Send + Sync,
DecErr: Error + Send + Sync,
LS: PacketLengthSerializer,
{
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "RawConnection #{}", self.id().0)
}
}
/// Identifier of a connection, unique within the process (until the counter wraps).
///
/// Also usable as a bevy `Component`.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, bevy::ecs::component::Component)]
pub struct ConnectionId(usize);

impl ConnectionId {
    /// Allocates the next identifier from a process-wide counter that starts at 0.
    pub fn next() -> ConnectionId {
        static CONNECTION_ID: AtomicUsize = AtomicUsize::new(0);
        // Every `fetch_add` on the same atomic sees a distinct previous value, so
        // `Relaxed` is enough: only uniqueness is needed, not ordering with other memory.
        let raw = CONNECTION_ID.fetch_add(1, Ordering::Relaxed);
        ConnectionId(raw)
    }

    /// Returns the raw numeric value of this identifier.
    pub fn read(&self) -> usize {
        let ConnectionId(raw) = *self;
        raw
    }
}

/// Formats the identifier as `#<id>`.
impl Debug for ConnectionId {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("#")?;
        write!(f, "{}", self.0)
    }
}
/// Bevy resource configuring the packet-size limit.
///
/// `set_max_packet_size_system` copies its value into `MAX_PACKET_SIZE` whenever the
/// resource is inserted or changed. How the limit is enforced is not shown here
/// (presumably checked against packet lengths elsewhere in the crate).
#[derive(Resource, Copy, Clone)]
pub struct MaxPacketSize(pub usize);

/// Process-wide copy of the configured packet-size limit.
///
/// Starts at `usize::MAX` (effectively unlimited) until a `MaxPacketSize` resource is applied.
/// Being a `static`, it is shared by every bevy `App` in the process.
pub(crate) static MAX_PACKET_SIZE: AtomicUsize = AtomicUsize::new(usize::MAX);
impl<ReceivingPacket, SendingPacket, NS, EncErr, DecErr, LS>
    RawConnection<ReceivingPacket, SendingPacket, NS, EncErr, DecErr, LS>
where
    ReceivingPacket: Send + Sync + Debug + 'static,
    SendingPacket: Send + Sync + Debug + 'static,
    NS: NetworkStream,
    EncErr: Error + Send + Sync,
    DecErr: Error + Send + Sync,
    LS: PacketLengthSerializer,
{
    /// Builds a raw connection over `stream`.
    ///
    /// The connection gets a fresh [`DisconnectTask`] and a newly allocated
    /// [`ConnectionId`]; `packet_length_serializer` is moved into an `Arc`, and
    /// `serializer` and `packets_rx` are stored as given.
    #[cfg(feature = "client")]
    pub fn new(
        stream: NS,
        serializer: Arc<
            dyn Serializer<
                ReceivingPacket,
                SendingPacket,
                EncodeError = EncErr,
                DecodeError = DecErr,
            >,
        >,
        packet_length_serializer: LS,
        packets_rx: UnboundedReceiver<SendingPacket>,
    ) -> Self {
        let disconnect_task = DisconnectTask::default();
        let packet_length_serializer = Arc::new(packet_length_serializer);
        let id = ConnectionId::next();
        Self {
            id,
            disconnect_task,
            stream,
            serializer,
            packet_length_serializer,
            packets_rx,
        }
    }

    /// Returns this connection's identifier.
    pub fn id(&self) -> ConnectionId {
        self.id
    }

    /// Returns the local socket address, as reported by the underlying stream.
    pub fn local_addr(&self) -> SocketAddr {
        let stream = &self.stream;
        stream.local_addr()
    }

    /// Returns the remote (peer) socket address, as reported by the underlying stream.
    pub fn peer_addr(&self) -> SocketAddr {
        let stream = &self.stream;
        stream.peer_addr()
    }
}
/// Cloneable, awaitable disconnect signal.
///
/// All clones share one inner flag/waker pair through an `Arc`. Once `disconnect` has
/// been called on any clone, awaiting any clone completes.
#[derive(Default, Clone)]
pub struct DisconnectTask(Arc<DisconnectTaskInner>);

/// Shared state behind a `DisconnectTask`.
#[derive(Default)]
struct DisconnectTaskInner {
    /// Becomes `true` when `DisconnectTask::disconnect` is called.
    disconnect: AtomicBool,
    /// Waker registered by the most recent poll of a `DisconnectTask`.
    waker: AtomicWaker,
}

impl DisconnectTask {
    /// Sets the shared disconnect flag and wakes the registered waker, if any.
    fn disconnect(&self) {
        let inner: &DisconnectTaskInner = &self.0;
        // The flag must be stored before waking, so that a task woken here finds it set
        // when it polls again.
        inner.disconnect.store(true, Ordering::Relaxed);
        inner.waker.wake();
    }
}
/// Completes once `DisconnectTask::disconnect` has been called on this task or any clone of it.
///
/// NOTE(review): per the `futures` `AtomicWaker` docs, `register` replaces any previously
/// registered waker. If several clones are awaited from different tasks at the same time,
/// `disconnect` wakes only the most recently registered one. The others stay pending until
/// something else polls them. Confirm that each `DisconnectTask` is awaited by a single task
/// at a time.
impl Future for DisconnectTask {
type Output = ();
/// Returns `Poll::Ready(())` once the shared flag is set. Otherwise it registers `cx`'s
/// waker and returns `Poll::Pending`.
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Fast path: already disconnected, so there is no need to register a waker.
if self.0.disconnect.load(Ordering::Relaxed) {
return Poll::Ready(());
}
self.0.waker.register(cx.waker());
// Re-check after registering. A concurrent `disconnect` may have set the flag and
// called `wake` before our waker was registered; without this second load that
// notification would be lost and the task would never be woken.
if self.0.disconnect.load(Ordering::Relaxed) {
Poll::Ready(())
} else {
Poll::Pending
}
}
}
/// Bevy system that copies the `MaxPacketSize` resource into `MAX_PACKET_SIZE`.
///
/// The global is written only when the resource exists and has been inserted or changed
/// since this system last ran. If the resource is later removed, the last written limit
/// stays in effect.
pub(crate) fn set_max_packet_size_system(max_packet_size: Option<Res<MaxPacketSize>>) {
    if let Some(limit) = max_packet_size.filter(|res| res.is_changed()) {
        MAX_PACKET_SIZE.store(limit.0, Ordering::Relaxed);
    }
}
/// Bevy system that logs a warning when no `MaxPacketSize` resource is present.
///
/// It warns on every run in which the resource is missing.
pub(crate) fn max_packet_size_warning_system(max_packet_size: Option<Res<MaxPacketSize>>) {
    if max_packet_size.is_some() {
        return;
    }
    log::warn!("You haven't set \"MaxPacketSize\" resource! This is a security risk, please insert it before using this in production.");
}