use bytes::BytesMut;
use mqttbytes::v4::*;
use mqttbytes::*;
use futures::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use crate::{Incoming, MqttState, StateError};
use std::io;
pub struct Network {
socket: Box<dyn N>,
read: BytesMut,
max_incoming_size: usize,
max_readb_count: usize,
}
impl Network {
pub fn new(socket: impl N + 'static, max_incoming_size: usize) -> Network {
let socket = Box::new(socket) as Box<dyn N>;
Network {
socket,
read: BytesMut::with_capacity(10 * 1024),
max_incoming_size,
max_readb_count: 10,
}
}
async fn read_bytes(&mut self, required: usize) -> io::Result<usize> {
let mut total_read = 0;
loop {
let read = self.socket.read(&mut self.read).await?;
if 0 == read {
return if self.read.is_empty() {
Err(io::Error::new(
io::ErrorKind::ConnectionAborted,
"connection closed by peer",
))
} else {
Err(io::Error::new(
io::ErrorKind::ConnectionReset,
"connection reset by peer",
))
};
}
total_read += read;
if total_read >= required {
return Ok(total_read);
}
}
}
pub async fn read(&mut self) -> Result<Incoming, io::Error> {
loop {
let required = match read(&mut self.read, self.max_incoming_size) {
Ok(packet) => return Ok(packet),
Err(mqttbytes::Error::InsufficientBytes(required)) => required,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())),
};
self.read_bytes(required).await?;
}
}
pub async fn readb(&mut self, state: &mut MqttState) -> Result<(), StateError> {
let mut count = 0;
loop {
match read(&mut self.read, self.max_incoming_size) {
Ok(packet) => {
state.handle_incoming_packet(packet)?;
count += 1;
if count >= self.max_readb_count {
return Ok(());
}
}
Err(Error::InsufficientBytes(_)) if count > 0 => return Ok(()),
Err(Error::InsufficientBytes(required)) => {
self.read_bytes(required).await?;
}
Err(e) => return Err(StateError::Serialization(e)),
};
}
}
pub async fn connect(&mut self, connect: Connect) -> Result<usize, io::Error> {
let mut write = BytesMut::new();
let len = match connect.write(&mut write) {
Ok(size) => size,
Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())),
};
self.socket.write_all(&write[..]).await?;
Ok(len)
}
pub async fn flush(&mut self, write: &mut BytesMut) -> Result<(), io::Error> {
if write.is_empty() {
return Ok(());
}
self.socket.write_all(&write[..]).await?;
write.clear();
Ok(())
}
}
pub trait N: AsyncRead + AsyncWrite + Send + Unpin {}
impl<T> N for T where T: AsyncRead + AsyncWrite + Send + Unpin {}