use mqttbytes::*;
use mqttbytes::v4::*;
use std::collections::VecDeque;
use std::io;
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::select;
use tokio::{task, time};
use async_channel::{bounded, Receiver, Sender};
use bytes::BytesMut;
use rumqttc::{Event, Incoming, Outgoing, Packet};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
/// Single-connection MQTT broker used to drive client tests.
///
/// Holds the framed connection to the one accepted client, a queue of
/// packets decoded from that client but not yet consumed, and a channel
/// through which background tasks hand packets to [`Broker::tick`] for
/// sending.
pub struct Broker {
    /// Framed connection to the accepted client.
    pub(crate) framed: Network,
    /// Packets decoded from the client that have not been consumed yet.
    pub(crate) incoming: VecDeque<Packet>,
    /// Sending half used by spawned tasks to enqueue outgoing packets.
    outgoing_tx: Sender<Packet>,
    /// Receiving half drained by `tick`, which writes each packet to the client.
    outgoing_rx: Receiver<Packet>,
}
impl Broker {
pub async fn new(port: u16, connack: u8) -> Broker {
let addr = format!("127.0.0.1:{}", port);
let listener = TcpListener::bind(&addr).await.unwrap();
let (stream, _) = listener.accept().await.unwrap();
let mut framed = Network::new(stream, 10 * 1024);
let mut incoming = VecDeque::new();
let (outgoing_tx, outgoing_rx) = bounded(10);
framed.readb(&mut incoming).await.unwrap();
match incoming.pop_front().unwrap() {
Packet::Connect(_) => {
let connack = match connack {
0 => ConnAck::new(ConnectReturnCode::Success, false),
1 => ConnAck::new(ConnectReturnCode::BadUserNamePassword, false),
_ => {
return Broker {
framed,
incoming,
outgoing_tx,
outgoing_rx,
}
}
};
framed.connack(connack).await.unwrap();
}
_ => {
panic!("Expecting connect packet");
}
}
Broker {
framed,
incoming: VecDeque::new(),
outgoing_tx,
outgoing_rx,
}
}
pub async fn read_publish(&mut self) -> Option<Publish> {
loop {
let packet = if self.incoming.len() > 0 {
self.incoming.pop_front().unwrap()
} else {
let packet = time::timeout(Duration::from_secs(2), async {
self.framed.readb(&mut self.incoming).await.unwrap();
self.incoming.pop_front().unwrap()
})
.await;
match packet {
Ok(packet) => packet,
Err(_e) => return None,
}
};
match packet {
Packet::Publish(publish) => return Some(publish),
Packet::PingReq => {
self.framed.write(Packet::PingResp).await.unwrap();
continue;
}
packet => panic!("Expecting a publish. Received = {:?}", packet),
}
}
}
pub async fn read_packet(&mut self) -> Packet {
time::timeout(Duration::from_secs(30), async {
let p = self.framed.readb(&mut self.incoming).await;
p.unwrap()
})
.await
.unwrap();
let packet = self.incoming.pop_front().unwrap();
packet
}
pub async fn blackhole(&mut self) -> Packet {
loop {
let _packet = self.framed.readb(&mut self.incoming).await.unwrap();
}
}
pub async fn ack(&mut self, pkid: u16) {
let packet = Packet::PubAck(PubAck::new(pkid));
self.framed.write(packet).await.unwrap();
}
pub async fn pingresp(&mut self) {
let packet = Packet::PingResp;
self.framed.write(packet).await.unwrap();
}
pub async fn spawn_publishes(&mut self, count: u8, qos: QoS, delay: u64) {
let tx = self.outgoing_tx.clone();
task::spawn(async move {
for i in 1..=count {
let topic = "hello/world".to_owned();
let payload = vec![1, 2, 3, i];
let mut publish = Publish::new(topic, qos, payload);
if qos as u8 > 0 {
publish.pkid = i as u16;
}
let packet = Packet::Publish(publish);
tx.send(packet).await.unwrap();
time::sleep(Duration::from_secs(delay)).await;
}
});
}
pub async fn tick(&mut self) -> Event {
select! {
request = self.outgoing_rx.recv() => {
let request = request.unwrap();
let outgoing = self.framed.write(request).await.unwrap();
Event::Outgoing(outgoing)
}
packet = self.framed.readb(&mut self.incoming) => {
packet.unwrap();
let incoming = self.incoming.pop_front().unwrap();
Event::Incoming(incoming)
}
}
}
}
/// Framed MQTT connection over a boxed async socket.
///
/// Bytes read from the socket accumulate in `read` until they can be
/// decoded into packets. Outgoing packets are serialized into `write`
/// before being written to the socket.
pub struct Network {
    /// Underlying transport (a `TcpStream` when created by `Broker::new`).
    socket: Box<dyn N>,
    /// Bytes received from `socket` that have not been decoded yet.
    read: BytesMut,
    /// Scratch buffer for serializing outgoing packets.
    write: BytesMut,
    /// Largest packet size handed to the decoder.
    max_incoming_size: usize,
    /// Upper bound on packets decoded per `readb` call.
    max_readb_count: usize,
}
impl Network {
    /// Wraps `socket` with 10 KiB read and write buffers.
    ///
    /// `max_incoming_size` is passed to the decoder as the maximum accepted
    /// packet size. At most 10 packets are decoded per [`Network::readb`] call.
    pub fn new(socket: impl N + 'static, max_incoming_size: usize) -> Network {
        let socket = Box::new(socket) as Box<dyn N>;
        Network {
            socket,
            read: BytesMut::with_capacity(10 * 1024),
            write: BytesMut::with_capacity(10 * 1024),
            max_incoming_size,
            max_readb_count: 10,
        }
    }

    /// Reads from the socket into the read buffer until at least `required`
    /// new bytes have arrived, returning the number of bytes read.
    ///
    /// # Errors
    ///
    /// Returns one of the following:
    /// - `ConnectionAborted` if the peer closes while the read buffer is empty,
    /// - `ConnectionReset` if the peer closes with undecoded bytes still buffered,
    /// - any I/O error returned by the socket.
    async fn read_bytes(&mut self, required: usize) -> io::Result<usize> {
        let mut total_read = 0;
        loop {
            let read = self.socket.read_buf(&mut self.read).await?;
            if 0 == read {
                return if self.read.is_empty() {
                    Err(io::Error::new(
                        io::ErrorKind::ConnectionAborted,
                        "connection closed by peer",
                    ))
                } else {
                    Err(io::Error::new(
                        io::ErrorKind::ConnectionReset,
                        "connection reset by peer",
                    ))
                };
            }
            total_read += read;
            if total_read >= required {
                return Ok(total_read);
            }
        }
    }

    /// Serializes `connack`, writes it to the socket and returns the
    /// serialized length.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` if serialization fails, or the socket's I/O error.
    pub async fn connack(&mut self, connack: ConnAck) -> Result<usize, io::Error> {
        let mut write = BytesMut::new();
        let len = connack
            .write(&mut write)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e.to_string()))?;
        self.socket.write_all(&write[..]).await?;
        Ok(len)
    }

    /// Decodes packets from the read buffer into `incoming`, reading more
    /// bytes from the socket only while no complete packet has been decoded
    /// in this call.
    ///
    /// Returns once at least one packet has been queued and the remaining
    /// buffered bytes do not form a complete packet, or once
    /// `max_readb_count` packets have been queued.
    ///
    /// # Errors
    ///
    /// Propagates socket errors from `read_bytes`. Any decode error other
    /// than insufficient bytes becomes `InvalidData`.
    pub async fn readb(&mut self, incoming: &mut VecDeque<Incoming>) -> Result<(), io::Error> {
        let mut count = 0;
        loop {
            match read(&mut self.read, self.max_incoming_size) {
                Ok(packet) => {
                    incoming.push_back(packet);
                    count += 1;
                    if count >= self.max_readb_count {
                        return Ok(());
                    }
                }
                // Something was decoded already; don't block waiting for more.
                Err(Error::InsufficientBytes(_)) if count > 0 => return Ok(()),
                Err(Error::InsufficientBytes(required)) => {
                    self.read_bytes(required).await?;
                }
                Err(e) => return Err(io::Error::new(io::ErrorKind::InvalidData, e.to_string())),
            };
        }
    }

    /// Serializes `packet`, writes it to the socket and returns the matching
    /// [`Outgoing`] event.
    ///
    /// # Errors
    ///
    /// Returns the serialization error if encoding fails. Nothing is written
    /// to the socket in that case.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - `outgoing()` has no mapping for the packet (it currently rejects
    ///   SubAck and UnsubAck, so those arms below are never reached),
    /// - the packet type is not serialized here,
    /// - the socket write fails.
    async fn write(&mut self, packet: Packet) -> Result<Outgoing, Error> {
        let outgoing = outgoing(&packet);
        // A serialization error on an earlier call returns before the buffer
        // is cleared. Start fresh so those partial bytes are never sent ahead
        // of this packet.
        self.write.clear();
        match packet {
            Packet::Publish(packet) => packet.write(&mut self.write)?,
            Packet::PubRel(packet) => packet.write(&mut self.write)?,
            Packet::PingReq => PingReq.write(&mut self.write)?,
            Packet::PingResp => PingResp.write(&mut self.write)?,
            Packet::Subscribe(packet) => packet.write(&mut self.write)?,
            Packet::SubAck(packet) => packet.write(&mut self.write)?,
            Packet::Unsubscribe(packet) => packet.write(&mut self.write)?,
            Packet::UnsubAck(packet) => packet.write(&mut self.write)?,
            Packet::Disconnect => Disconnect.write(&mut self.write)?,
            Packet::PubAck(packet) => packet.write(&mut self.write)?,
            Packet::PubRec(packet) => packet.write(&mut self.write)?,
            Packet::PubComp(packet) => packet.write(&mut self.write)?,
            _ => unimplemented!(),
        };
        self.socket.write_all(&self.write[..]).await.unwrap();
        self.write.clear();
        Ok(outgoing)
    }
}
/// Translates a packet about to be written into the matching [`Outgoing`]
/// event. The packet identifier is carried along for the packet types that
/// have one.
///
/// # Panics
///
/// Panics for any packet type without an arm below (for example CONNECT,
/// SUBACK or UNSUBACK).
fn outgoing(packet: &Packet) -> Outgoing {
    match packet {
        // Control packets without an identifier.
        Packet::PingReq => Outgoing::PingReq,
        Packet::PingResp => Outgoing::PingResp,
        Packet::Disconnect => Outgoing::Disconnect,
        // Subscription management.
        Packet::Subscribe(sub) => Outgoing::Subscribe(sub.pkid),
        Packet::Unsubscribe(unsub) => Outgoing::Unsubscribe(unsub.pkid),
        // Publish flow and its acknowledgements.
        Packet::Publish(p) => Outgoing::Publish(p.pkid),
        Packet::PubAck(ack) => Outgoing::PubAck(ack.pkid),
        Packet::PubRec(rec) => Outgoing::PubRec(rec.pkid),
        Packet::PubRel(rel) => Outgoing::PubRel(rel.pkid),
        Packet::PubComp(comp) => Outgoing::PubComp(comp.pkid),
        unexpected => panic!("Invalid outgoing packet = {:?}", unexpected),
    }
}
/// Object-safe bundle of the bounds [`Network`] needs from its socket, so
/// any such stream can be stored as `Box<dyn N>`.
pub trait N: AsyncRead + AsyncWrite + Send + Sync + Unpin {}

/// Every type satisfying the bounds is automatically an `N`.
impl<S> N for S where S: AsyncRead + AsyncWrite + Send + Sync + Unpin {}