use std::{net::SocketAddr, pin::Pin};
use crate::{
async_channel::{Receiver, Sender},
async_trait,
error::NetworkError,
managers::NetworkProvider,
NetworkPacket,
};
use async_net::{TcpListener, TcpStream};
use bevy::{
log::{debug, error, info, trace},
prelude::Resource,
};
use futures_lite::{AsyncReadExt, AsyncWriteExt, FutureExt, Stream};
use std::future::Future;
/// Zero-sized marker type selecting the plain-TCP [`NetworkProvider`] implementation.
#[derive(Debug, Default)]
pub struct TcpProvider;
/// Length-prefixed TCP transport.
///
/// Each packet on the wire is an 8-byte little-endian `u64` length header
/// followed by that many bytes of `bincode`-encoded [`NetworkPacket`].
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl NetworkProvider for TcpProvider {
    type NetworkSettings = NetworkSettings;
    type Socket = TcpStream;
    type ReadHalf = TcpStream;
    type WriteHalf = TcpStream;
    type ConnectInfo = SocketAddr;
    type AcceptInfo = SocketAddr;
    type AcceptStream = OwnedIncoming;

    /// Binds a listener on `accept_info` and returns a stream of accepted connections.
    ///
    /// # Errors
    /// Returns [`NetworkError::Listen`] if the address cannot be bound.
    async fn accept_loop(
        accept_info: Self::AcceptInfo,
        _: Self::NetworkSettings,
    ) -> Result<Self::AcceptStream, NetworkError> {
        let listener = TcpListener::bind(accept_info)
            .await
            .map_err(NetworkError::Listen)?;
        Ok(OwnedIncoming::new(listener))
    }

    /// Opens a TCP connection to `connect_info`.
    ///
    /// # Errors
    /// Returns [`NetworkError::Connection`] if connecting fails, or if the peer
    /// address of the freshly connected socket cannot be queried.
    async fn connect_task(
        connect_info: Self::ConnectInfo,
        _: Self::NetworkSettings,
    ) -> Result<Self::Socket, NetworkError> {
        info!("Beginning connection");
        let stream = TcpStream::connect(connect_info)
            .await
            .map_err(NetworkError::Connection)?;
        info!("Connected!");
        // If the peer address can't be read, the socket has already gone bad;
        // report that as a connection error rather than panicking the task.
        let addr = stream.peer_addr().map_err(NetworkError::Connection)?;
        debug!("Connected to: {:?}", addr);
        Ok(stream)
    }

    /// Reads length-prefixed packets from `read_half` and forwards them to `messages`.
    ///
    /// The loop ends on disconnect, on an I/O error, on a length header larger than
    /// `settings.max_packet_length`, on a decode failure, or when `messages` is closed.
    async fn recv_loop(
        mut read_half: Self::ReadHalf,
        messages: Sender<NetworkPacket>,
        settings: Self::NetworkSettings,
    ) {
        // Grown on demand (never beyond `max_packet_length`) so an idle connection
        // does not pin a full-size allocation, and a tiny limit cannot cause a panic.
        let mut buffer: Vec<u8> = Vec::new();
        loop {
            trace!("Reading message length");
            let mut header = [0u8; 8];
            // `read` may legally return fewer bytes than requested on a healthy
            // stream, so the fixed-size header has to be read with `read_exact`.
            match read_half.read_exact(&mut header).await {
                Ok(()) => {}
                Err(err) if err.kind() == std::io::ErrorKind::UnexpectedEof => {
                    info!("Client disconnected");
                    break;
                }
                Err(err) => {
                    error!("Encountered error while fetching length: {}", err);
                    break;
                }
            }
            let raw_length = u64::from_le_bytes(header);
            trace!("Message length: {}", raw_length);
            // Check before narrowing: a plain `as usize` cast truncates on 32-bit
            // targets and could let an oversized length pass the limit check.
            let length = match usize::try_from(raw_length) {
                Ok(length) if length <= settings.max_packet_length => length,
                _ => {
                    error!(
                        "Received too large packet: {} > {}",
                        raw_length, settings.max_packet_length
                    );
                    break;
                }
            };
            if buffer.len() < length {
                buffer.resize(length, 0);
            }
            trace!("Reading message into buffer");
            if let Err(err) = read_half.read_exact(&mut buffer[..length]).await {
                error!(
                    "Encountered error while fetching stream of length {}: {}",
                    length, err
                );
                break;
            }
            trace!("Message read");
            let packet: NetworkPacket = match bincode::deserialize(&buffer[..length]) {
                Ok(packet) => packet,
                Err(err) => {
                    error!("Failed to decode network packet from: {}", err);
                    break;
                }
            };
            if messages.send(packet).await.is_err() {
                error!("Failed to send decoded message to eventwork");
                break;
            }
            trace!("Message deserialized and sent to eventwork");
        }
    }

    /// Encodes every packet received on `messages` and writes it, length-prefixed,
    /// to `write_half`.
    ///
    /// A packet that fails to encode is logged and skipped. A write failure ends the loop.
    async fn send_loop(
        mut write_half: Self::WriteHalf,
        messages: Receiver<NetworkPacket>,
        _settings: Self::NetworkSettings,
    ) {
        while let Ok(message) = messages.recv().await {
            let encoded = match bincode::serialize(&message) {
                Ok(encoded) => encoded,
                Err(err) => {
                    error!("Could not encode packet {:?}: {}", message, err);
                    continue;
                }
            };
            let len = encoded.len() as u64;
            debug!("Sending a new message of size: {}", len);
            // `write` may perform a short write; `write_all` guarantees the whole
            // header is sent, so the receiver's framing stays in sync.
            if let Err(err) = write_half.write_all(&len.to_le_bytes()).await {
                error!("Could not send packet length: {:?}: {}", len, err);
                break;
            }
            trace!("Sending the content of the message!");
            if let Err(err) = write_half.write_all(&encoded).await {
                error!("Could not send packet: {:?}: {}", message, err);
                break;
            }
            trace!("Succesfully written all!");
        }
    }

    /// Splits a socket into read and write handles.
    ///
    /// An async-net `TcpStream` clone is another handle to the same socket, so
    /// both halves here share one connection.
    fn split(combined: Self::Socket) -> (Self::ReadHalf, Self::WriteHalf) {
        (combined.clone(), combined)
    }
}
/// Settings handed to the [`TcpProvider`] loops.
///
/// Only `max_packet_length` is read, and only by the receive loop.
// NOTE(review): `Copy` is presumably left off on purpose so non-`Copy` fields
// can be added later without a breaking change — TODO confirm.
#[allow(missing_copy_implementations)]
#[derive(Debug, Clone, Resource)]
pub struct NetworkSettings {
    /// Upper bound, in bytes, on the payload length an incoming header may announce.
    pub max_packet_length: usize,
}
impl Default for NetworkSettings {
fn default() -> Self {
Self {
max_packet_length: 10 * 1024 * 1024,
}
}
}
/// Boxed future for one pending `accept` call. It is `Send` so that
/// [`OwnedIncoming`] is automatically `Send` without any `unsafe impl`.
type AcceptFuture = Pin<Box<dyn Future<Output = Option<TcpStream>> + Send>>;

/// An owned stream of connections accepted from a [`TcpListener`].
///
/// An accept error is yielded as `None`, which `Stream` consumers
/// conventionally treat as end-of-stream.
pub struct OwnedIncoming {
    inner: TcpListener,
    /// The in-flight accept, if one has been started and has not yet completed.
    stream: Option<AcceptFuture>,
}

impl OwnedIncoming {
    /// Wraps `listener`. No accept is started until the stream is first polled.
    fn new(listener: TcpListener) -> Self {
        Self {
            inner: listener,
            stream: None,
        }
    }
}

impl Stream for OwnedIncoming {
    type Item = TcpStream;

    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        let incoming = self.get_mut();
        let listener = &incoming.inner;
        let accept = incoming.stream.get_or_insert_with(|| -> AcceptFuture {
            // The future owns its own listener handle (a clone shares the same
            // socket). It therefore never points back into `self`, which may be
            // moved between polls because `OwnedIncoming` is `Unpin`.
            let listener = listener.clone();
            Box::pin(async move { listener.accept().await.map(|(s, _)| s).ok() })
        });
        match Future::poll(accept.as_mut(), cx) {
            std::task::Poll::Ready(res) => {
                // Clear the slot so the next poll starts a fresh accept.
                incoming.stream = None;
                std::task::Poll::Ready(res)
            }
            std::task::Poll::Pending => std::task::Poll::Pending,
        }
    }
}