use futures_util::{FutureExt, SinkExt, StreamExt};
pub use rumqttc_core::AsyncReadWrite;
use tokio_util::codec::Framed;
use crate::mqttbytes::{
self,
v4::{Codec, Packet},
};
use crate::{Incoming, MqttState, StateError};
/// The concrete transport type: a type-erased byte stream framed by the MQTT v4 [`Codec`].
type FramedSocket = Framed<Box<dyn AsyncReadWrite>, Codec>;

/// An MQTT v4 connection that reads and writes whole [`Packet`]s over a boxed socket.
pub struct Network {
    /// Socket plus codec; every read and write in this type goes through it.
    framed: FramedSocket,
}
impl Network {
pub fn new(
socket: impl AsyncReadWrite + 'static,
max_incoming_size: usize,
max_outgoing_size: usize,
) -> Self {
let socket = Box::new(socket) as Box<dyn AsyncReadWrite>;
let codec = Codec {
max_incoming_size,
max_outgoing_size,
};
let framed = Framed::new(socket, codec);
Self { framed }
}
pub async fn read(&mut self) -> Result<Incoming, StateError> {
match self.framed.next().await {
Some(Ok(packet)) => Ok(packet),
Some(Err(mqttbytes::Error::InsufficientBytes(_))) => unreachable!(),
Some(Err(e)) => Err(StateError::Deserialization(e)),
None => Err(StateError::ConnectionAborted),
}
}
pub async fn readb(
&mut self,
state: &mut MqttState,
read_batch_limit: usize,
) -> Result<(), StateError> {
let mut res = self.framed.next().await;
let read_batch_limit = read_batch_limit.max(1);
let mut count = 0;
loop {
match res {
Some(Ok(packet)) => {
if let Some(outgoing) = state.handle_incoming_packet(packet)? {
self.write(outgoing).await?;
}
count += 1;
if count >= read_batch_limit {
break;
}
}
Some(Err(mqttbytes::Error::InsufficientBytes(_))) => unreachable!(),
Some(Err(e)) => return Err(StateError::Deserialization(e)),
None => return Err(StateError::ConnectionAborted),
}
match self.framed.next().now_or_never() {
Some(r) => res = r,
_ => break,
}
}
Ok(())
}
pub async fn write(&mut self, packet: Packet) -> Result<(), StateError> {
self.framed
.feed(packet)
.await
.map_err(StateError::Deserialization)
}
pub async fn flush(&mut self) -> Result<(), crate::state::StateError> {
self.framed
.flush()
.await
.map_err(StateError::Deserialization)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncWriteExt, duplex};

    /// Two back-to-back MQTT PINGRESP packets (header byte 0xD0, remaining length 0).
    const TWO_PINGRESPS: [u8; 4] = [0xD0, 0x00, 0xD0, 0x00];

    /// Sends `TWO_PINGRESPS` into a fresh `Network` and runs `readb` with `limit`.
    /// Returns how many events the state recorded.
    async fn events_after_readb(limit: usize) -> usize {
        let (client, mut peer) = duplex(64);
        let mut network = Network::new(client, 1024, 1024);
        let mut state = MqttState::builder(10).build();
        peer.write_all(&TWO_PINGRESPS).await.unwrap();
        network.readb(&mut state, limit).await.unwrap();
        state.events.len()
    }

    #[tokio::test]
    async fn readb_processes_exactly_two_packets_when_limit_is_two() {
        assert_eq!(events_after_readb(2).await, 2);
    }

    #[tokio::test]
    async fn readb_processes_one_packet_when_limit_is_one() {
        assert_eq!(events_after_readb(1).await, 1);
    }
}