use bytes::BytesMut;
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use std::{
collections::VecDeque,
fmt::Debug,
io::{self, ErrorKind},
};
use tokio::time::{error::Elapsed, Duration};
use crate::protocol::{self, Packet, Protocol};
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("I/O = {0}")]
Io(#[from] io::Error),
#[error("Invalid data = {0}")]
Protocol(#[from] protocol::Error),
#[error["Keep alive timeout"]]
KeepAlive(#[from] Elapsed),
}
/// A connection that pairs a byte transport with a protocol codec, plus the buffers used to
/// decode incoming packets and to batch encoded outgoing packets.
pub struct Network<P> {
    /// The underlying transport that bytes are read from and written to.
    socket: Box<dyn N>,
    /// Receive buffer. It is filled by reads from `socket` and handed to the codec for
    /// decoding.
    read: BytesMut,
    /// Send buffer. The codec encodes packets into it before it is written to `socket`.
    write: BytesMut,
    /// Keep-alive window stored by `set_keepalive`.
    /// NOTE(review): nothing in this file reads this field — confirm it is enforced elsewhere.
    keepalive: Duration,
    /// Size limit passed to the codec's `read_mut` on every decode attempt.
    max_incoming_size: usize,
    /// Packet count at which `readv` stops collecting and returns.
    max_connection_buffer_len: usize,
    /// Codec used to decode incoming and encode outgoing packets.
    pub(crate) protocol: P,
}
impl<P: Protocol> Network<P> {
pub fn new(
socket: Box<dyn N>,
max_incoming_size: usize,
max_connection_buffer_len: usize,
protocol: P,
) -> Network<P> {
Network {
socket,
read: BytesMut::with_capacity(10 * 1024),
write: BytesMut::with_capacity(10 * 1024),
max_incoming_size,
max_connection_buffer_len,
keepalive: Duration::from_secs(0),
protocol,
}
}
pub fn set_keepalive(&mut self, keepalive: u16) {
let keepalive = Duration::from_secs(keepalive as u64);
self.keepalive = keepalive + keepalive.mul_f32(0.5);
}
async fn read_bytes(&mut self, required: usize) -> io::Result<usize> {
let mut total_read = 0;
loop {
let read = self.socket.read_buf(&mut self.read).await?;
if 0 == read {
let error = if self.read.is_empty() {
io::Error::new(ErrorKind::ConnectionAborted, "connection closed by peer")
} else {
io::Error::new(ErrorKind::ConnectionReset, "connection reset by peer")
};
return Err(error);
}
total_read += read;
if total_read >= required {
return Ok(total_read);
}
}
}
pub async fn read(&mut self) -> Result<Packet, io::Error> {
loop {
let required = match Protocol::read_mut(
&mut self.protocol,
&mut self.read,
self.max_incoming_size,
) {
Ok(packet) => return Ok(packet),
Err(protocol::Error::InsufficientBytes(required)) => required,
Err(e) => return Err(io::Error::new(ErrorKind::InvalidData, e.to_string())),
};
self.read_bytes(required).await?;
}
}
pub fn readv(&mut self, packets: &mut VecDeque<Packet>) -> Result<usize, Error> {
loop {
match self
.protocol
.read_mut(&mut self.read, self.max_incoming_size)
{
Ok(packet) => {
packets.push_back(packet);
let connection_buffer_length = packets.len();
if connection_buffer_length >= self.max_connection_buffer_len {
return Ok(connection_buffer_length);
}
}
Err(protocol::Error::InsufficientBytes(_)) => return Ok(packets.len()),
Err(e) => return Err(io::Error::new(ErrorKind::InvalidData, e.to_string()).into()),
}
}
}
pub async fn write<T>(&mut self, notification: T) -> Result<bool, Error>
where
T: Into<Option<Packet>>,
{
let mut unscheduled = false;
let packet_or_unscheduled = notification.into();
if let Some(packet) = packet_or_unscheduled {
Protocol::write(&self.protocol, packet, &mut self.write)?;
} else {
unscheduled = true;
}
self.socket.write_all(&self.write).await?;
self.write.clear();
Ok(unscheduled)
}
pub async fn writev<T>(&mut self, notifications: &mut VecDeque<T>) -> Result<bool, Error>
where
T: Into<Option<Packet>>,
{
let mut o = false;
for notification in notifications.drain(..) {
let packet_or_unscheduled = notification.into();
if let Some(packet) = packet_or_unscheduled {
Protocol::write(&self.protocol, packet, &mut self.write)?;
} else {
o = true
}
}
self.socket.write_all(&self.write).await?;
self.write.clear();
Ok(o)
}
}
/// Transport bound for a connection, usable as `dyn N`: a tokio async reader and writer that
/// is `Unpin`, `Send`, and `Sync`.
pub trait N: AsyncWrite + AsyncRead + Unpin + Sync + Send {}

/// Blanket impl: any sized type meeting these bounds (for example `tokio::net::TcpStream`)
/// is an [`N`].
impl<S: AsyncRead + AsyncWrite + Send + Sync + Unpin> N for S {}