mod err;
mod header;
mod net;
pub(crate) use err::Error as RawError;
pub(crate) use net::Error as NetStateError;
#[cfg(debug_assertions)]
use crate::fmt::unreachable;
use crate::{
buffer::BufferProvider,
client::raw::{header::HeaderState, net::NetState},
fmt::{debug, debug_assert, error, warn},
header::FixedHeader,
io::{Transport, err::WriteError, read::BodyReader},
packet::{RxError, RxPacket, TxError, TxPacket},
types::ReasonCode,
v5::packet::DisconnectPacket,
};
/// Low-level connection handle: bundles the network state, the caller-supplied
/// buffer provider and the state of the fixed header currently being received.
#[derive(Debug)]
pub(crate) struct Raw<'b, N, B>
where
    N: Transport,
    B: BufferProvider<'b>,
{
    /// State of the network connection (`Ok`, `Faulted` or `Terminated`).
    n: NetState<N>,
    /// Exclusive borrow of the buffer provider; handed to the body reader
    /// whenever a packet body is received.
    buf: &'b mut B,
    /// Parser state used while receiving a fixed header.
    header: HeaderState,
}
impl<'b, N: Transport, B: BufferProvider<'b>> Raw<'b, N, B> {
pub fn new_disconnected(buf: &'b mut B) -> Self {
Self {
n: NetState::Terminated,
buf,
header: HeaderState::new(),
}
}
pub fn set_net(&mut self, net: N) {
debug_assert!(
!self.n.is_ok(),
"Network must not be in Ok() state to replace it."
);
self.n.replace(net);
}
pub fn buffer(&self) -> &B {
self.buf
}
pub fn buffer_mut(&mut self) -> &mut B {
self.buf
}
pub fn close_with(&mut self, reason_code: Option<ReasonCode>) {
match reason_code {
Some(r) => self.n.fail(r),
None => {
drop(self.n.terminate());
debug!("closed network connection");
}
}
}
pub async fn abort(&mut self) -> Result<(), RawError<B::ProvisionError>> {
debug_assert!(
!self.n.is_ok(),
"Network must not be in Ok() state to disconnect due to an error."
);
match &mut self.n {
NetState::Faulted(n, r) => {
let packet = DisconnectPacket::new(*r);
debug!("sending DISCONNECT packet with reason code: {:?}", r);
let r = packet
.send(n)
.await
.map_err(Into::into)
.inspect_err(|e| error!("I/O error during send: {:?}", e));
self.close_with(None);
r
}
NetState::Ok(_) => {
self.close_with(None);
Ok(())
}
NetState::Terminated => Ok(()),
}
}
fn handle_rx<E: Into<(RawError<B::ProvisionError>, Option<ReasonCode>)>>(
&mut self,
e: E,
) -> RawError<B::ProvisionError> {
let (e, r) = e.into();
match e {
RawError::Network(ref e) => error!("I/O error during receive: {:?}", e),
RawError::Disconnected => {
#[cfg(debug_assertions)]
unreachable!(
"only instantiated from `NetStateError` which is not handled with `handle_rx` and logged separately"
);
#[cfg(not(debug_assertions))]
error!(
"unreachable: only instantiated from `NetStateError` which is not handled with `handle_rx` and logged separately"
);
}
RawError::Alloc(ref e) => error!("buffer provision failed: {:?}", e),
RawError::Server => error!("server protocol violation"),
}
self.close_with(r);
e
}
fn handle_tx<E: Into<RawError<B::ProvisionError>>>(
&mut self,
e: E,
) -> RawError<B::ProvisionError> {
let e = e.into();
match e {
RawError::Network(ref e) => error!("I/O error during send: {:?}", e),
RawError::Disconnected => {
#[cfg(debug_assertions)]
unreachable!(
"only instantiated from `NetStateError` which is not handled with `handle_tx` and logged separately"
);
#[cfg(not(debug_assertions))]
error!(
"unreachable: only instantiated from `NetStateError` which is not handled with `handle_tx` and logged separately"
);
}
RawError::Alloc(_) => {
#[cfg(debug_assertions)]
unreachable!("writing cannot trigger allocation");
#[cfg(not(debug_assertions))]
error!("unreachable: writing cannot trigger allocation");
}
RawError::Server => {
#[cfg(debug_assertions)]
unreachable!("server error cannot be caused by sending");
#[cfg(not(debug_assertions))]
error!("unreachable: server error cannot be caused by sending");
}
}
drop(self.n.terminate());
debug!("closed network connection");
e
}
pub async fn recv_header(&mut self) -> Result<FixedHeader, RawError<B::ProvisionError>> {
let net = self.n.get().inspect_err(|e| match e {
NetStateError::Faulted => {
warn!("attempted to receive from a faulted mqtt connection")
}
NetStateError::Terminated => {
warn!("attempted to receive from a closed network connection")
}
})?;
loop {
match self.header.update(net).await {
Ok(None) => {}
Ok(Some(h)) => return Ok(h),
Err(e) => {
let e: RxError<_, _> = e.into();
return Err(self.handle_rx(e));
}
}
}
}
pub async fn recv_body<P: RxPacket<'b>>(
&mut self,
header: &FixedHeader,
) -> Result<P, RawError<B::ProvisionError>> {
let net = self.n.get().inspect_err(|e| match e {
NetStateError::Faulted => warn!("attempted to receive from a faulted mqtt connection"),
NetStateError::Terminated => {
warn!("attempted to receive from a closed network connection")
}
})?;
let reader = BodyReader::new(net, self.buf, header.remaining_len.size());
P::receive(header, reader)
.await
.map_err(|e| self.handle_rx(e))
}
pub async fn send<P: TxPacket>(
&mut self,
packet: &P,
) -> Result<(), RawError<B::ProvisionError>> {
let net = self.n.get().inspect_err(|e| match e {
NetStateError::Faulted => warn!("attempted to send on a faulted mqtt connection"),
NetStateError::Terminated => warn!("attempted to send on a closed network connection"),
})?;
packet.send(net).await.map_err(|e| self.handle_tx(e))
}
pub async fn flush(&mut self) -> Result<(), RawError<B::ProvisionError>> {
let net = self.n.get().inspect_err(|e| match e {
NetStateError::Faulted => warn!("attempted to flush a faulted mqtt connection"),
NetStateError::Terminated => warn!("attempted to flush a closed network connection"),
})?;
net.flush().await.map_err(|e| {
let e: WriteError<_> = e.into();
let e: TxError<_> = e.into();
self.handle_tx(e)
})
}
}
#[cfg(test)]
mod unit {
    use core::time::Duration;
    use embedded_io_adapters::tokio_1::FromTokio;
    use tokio::{
        io::{AsyncWriteExt, duplex},
        join,
        sync::oneshot::channel,
        time::{sleep, timeout},
    };
    use tokio_test::{assert_err, assert_ok};
    #[cfg(feature = "alloc")]
    use crate::buffer::AllocBuffer;
    #[cfg(feature = "bump")]
    use crate::buffer::BumpBuffer;
    use crate::{
        client::raw::Raw,
        header::{FixedHeader, PacketType},
        types::VarByteInt,
    };

    // Every test follows the same recipe: one end of an in-memory duplex pipe
    // is wrapped as the client's transport, the other end plays the server
    // and writes raw header bytes.

    // A complete header written in one go is decoded by a single call.
    #[tokio::test]
    #[test_log::test]
    async fn recv_header_simple() {
        #[cfg(feature = "alloc")]
        let mut provider = AllocBuffer;
        #[cfg(feature = "bump")]
        let mut backing = [0; 64];
        #[cfg(feature = "bump")]
        let mut provider = BumpBuffer::new(&mut backing);

        let (client_io, mut server_io) = duplex(64);
        let transport = FromTokio::new(client_io);
        let mut raw = Raw::new_disconnected(&mut provider);
        raw.set_net(transport);

        let writer = async {
            // Three bytes are written; the expected header below only accounts
            // for the first two (type/flags byte and a zero remaining length).
            assert_ok!(server_io.write_all(&[0x10, 0x00, 0x24]).await);
        };
        let reader = async {
            let decoded = assert_ok!(raw.recv_header().await);
            assert_eq!(
                decoded,
                FixedHeader::new(PacketType::Connect, 0x00, VarByteInt::from(0u8))
            );
        };
        join!(reader, writer);
    }

    // Header bytes trickling in with pauses are assembled within one call.
    #[tokio::test]
    #[test_log::test]
    async fn recv_header_with_pause() {
        #[cfg(feature = "alloc")]
        let mut provider = AllocBuffer;
        #[cfg(feature = "bump")]
        let mut backing = [0; 64];
        #[cfg(feature = "bump")]
        let mut provider = BumpBuffer::new(&mut backing);

        let (client_io, mut server_io) = duplex(64);
        let transport = FromTokio::new(client_io);
        let mut raw = Raw::new_disconnected(&mut provider);
        raw.set_net(transport);

        let writer = async {
            assert_ok!(server_io.write_u8(0xE0).await);
            sleep(Duration::from_millis(100)).await;
            assert_ok!(server_io.write_u8(0x80).await);
            sleep(Duration::from_millis(100)).await;
            assert_ok!(server_io.write_u8(0x80).await);
            sleep(Duration::from_millis(100)).await;
            assert_ok!(server_io.write_u8(0x01).await);
        };
        let reader = async {
            let decoded = assert_ok!(raw.recv_header().await);
            assert_eq!(
                decoded,
                FixedHeader::new(PacketType::Disconnect, 0x00, VarByteInt::from(16_384u16))
            );
        };
        join!(reader, writer);
    }

    // A call that times out before any byte arrived can simply be retried.
    #[tokio::test]
    #[test_log::test]
    async fn recv_header_cancel_no_progres() {
        #[cfg(feature = "alloc")]
        let mut provider = AllocBuffer;
        #[cfg(feature = "bump")]
        let mut backing = [0; 64];
        #[cfg(feature = "bump")]
        let mut provider = BumpBuffer::new(&mut backing);

        let (client_io, mut server_io) = duplex(64);
        let transport = FromTokio::new(client_io);
        // The reader signals through this channel once its first attempt timed out.
        let (resume_tx, resume_rx) = channel();
        let mut raw = Raw::new_disconnected(&mut provider);
        raw.set_net(transport);

        let writer = async {
            assert_ok!(resume_rx.await);
            assert_ok!(server_io.write_all(&[0xE0, 0x00]).await);
        };
        let reader = async {
            assert_err!(timeout(Duration::from_millis(100), raw.recv_header()).await);
            assert_ok!(resume_tx.send(()));
            let decoded = assert_ok!(raw.recv_header().await);
            assert_eq!(
                decoded,
                FixedHeader::new(PacketType::Disconnect, 0x00, VarByteInt::from(0u8))
            );
        };
        join!(reader, writer);
    }

    // Cancelled after only the type/flags byte arrived; the next call finishes
    // the same header from the remaining-length bytes.
    #[tokio::test]
    #[test_log::test]
    async fn recv_header_cancel_type_and_flags_byte() {
        #[cfg(feature = "alloc")]
        let mut provider = AllocBuffer;
        #[cfg(feature = "bump")]
        let mut backing = [0; 64];
        #[cfg(feature = "bump")]
        let mut provider = BumpBuffer::new(&mut backing);

        let (client_io, mut server_io) = duplex(64);
        let transport = FromTokio::new(client_io);
        let (resume_tx, resume_rx) = channel();
        let mut raw = Raw::new_disconnected(&mut provider);
        raw.set_net(transport);

        let writer = async {
            assert_ok!(server_io.write_u8(0xA0).await);
            assert_ok!(resume_rx.await);
            assert_ok!(server_io.write_all(&[0x80, 0x80, 0x80, 0x01]).await);
        };
        let reader = async {
            assert_err!(timeout(Duration::from_millis(100), raw.recv_header()).await);
            assert_ok!(resume_tx.send(()));
            let decoded = assert_ok!(raw.recv_header().await);
            assert_eq!(
                decoded,
                FixedHeader::new(
                    PacketType::Unsubscribe,
                    0x00,
                    VarByteInt::new(2_097_152u32).unwrap()
                )
            );
        };
        join!(reader, writer);
    }

    // Cancelled after the type/flags byte and one remaining-length byte arrived.
    #[tokio::test]
    #[test_log::test]
    async fn recv_header_cancel_single_length_byte() {
        #[cfg(feature = "alloc")]
        let mut provider = AllocBuffer;
        #[cfg(feature = "bump")]
        let mut backing = [0; 64];
        #[cfg(feature = "bump")]
        let mut provider = BumpBuffer::new(&mut backing);

        let (client_io, mut server_io) = duplex(64);
        let transport = FromTokio::new(client_io);
        let (resume_tx, resume_rx) = channel();
        let mut raw = Raw::new_disconnected(&mut provider);
        raw.set_net(transport);

        let writer = async {
            assert_ok!(server_io.write_all(&[0xD7, 0xFF]).await);
            assert_ok!(resume_rx.await);
            assert_ok!(server_io.write_all(&[0xFF, 0xFF, 0x7F]).await);
        };
        let reader = async {
            assert_err!(timeout(Duration::from_millis(100), raw.recv_header()).await);
            assert_ok!(resume_tx.send(()));
            let decoded = assert_ok!(raw.recv_header().await);
            assert_eq!(
                decoded,
                FixedHeader::new(
                    PacketType::Pingresp,
                    0x07,
                    VarByteInt::new(VarByteInt::MAX_ENCODABLE).unwrap()
                )
            );
        };
        join!(reader, writer);
    }

    // Cancelled three times, once after each of the first three header bytes;
    // the fourth call completes the header.
    #[tokio::test]
    #[test_log::test]
    async fn recv_header_cancel_multi() {
        #[cfg(feature = "alloc")]
        let mut provider = AllocBuffer;
        #[cfg(feature = "bump")]
        let mut backing = [0; 64];
        #[cfg(feature = "bump")]
        let mut provider = BumpBuffer::new(&mut backing);

        let (client_io, mut server_io) = duplex(64);
        let transport = FromTokio::new(client_io);
        // One hand-shake channel per cancelled attempt.
        let (first_tx, first_rx) = channel();
        let (second_tx, second_rx) = channel();
        let (third_tx, third_rx) = channel();
        let mut raw = Raw::new_disconnected(&mut provider);
        raw.set_net(transport);

        let writer = async {
            assert_ok!(server_io.write_u8(0x68).await);
            assert_ok!(first_rx.await);
            assert_ok!(server_io.write_u8(0xFF).await);
            assert_ok!(second_rx.await);
            assert_ok!(server_io.write_u8(0xFF).await);
            assert_ok!(third_rx.await);
            assert_ok!(server_io.write_u8(0x7F).await);
        };
        let reader = async {
            assert_err!(timeout(Duration::from_millis(50), raw.recv_header()).await);
            assert_ok!(first_tx.send(()));
            assert_err!(timeout(Duration::from_millis(50), raw.recv_header()).await);
            assert_ok!(second_tx.send(()));
            assert_err!(timeout(Duration::from_millis(50), raw.recv_header()).await);
            assert_ok!(third_tx.send(()));
            let decoded = assert_ok!(raw.recv_header().await);
            assert_eq!(
                decoded,
                FixedHeader::new(
                    PacketType::Pubrel,
                    0x08,
                    VarByteInt::new(2_097_151u32).unwrap()
                )
            );
        };
        join!(reader, writer);
    }
}