use alloc::string::String;
use bytes::{Buf, BufMut};
use core::{convert::TryFrom, fmt, num::NonZeroU16};
#[cfg(feature = "derive")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "std")]
use alloc::format;
#[cfg(feature = "std")]
use std::{
error::Error as ErrorTrait,
io::{Error as IoError, ErrorKind},
};
/// Error type shared by the conversions in this file.
///
/// Several variants are never constructed in this file; their producers
/// presumably live in the crate's encoder/decoder modules (confirm there).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
/// Counterpart of `std::io::ErrorKind::WriteZero`: the `From` conversions
/// between this type and `std::io::Error` map the two onto each other.
WriteZero,
/// A `Pid` was requested for the value `0`; returned by `Pid::try_from`.
InvalidPid,
/// `QoS::from_u8` received a byte other than `0`, `1` or `2`; carries that byte.
InvalidQos(u8),
/// NOTE(review): presumably an unknown connect return code together with the
/// offending byte — the producer is not visible in this file, confirm.
InvalidConnectReturnCode(u8),
/// NOTE(review): presumably an unsupported protocol name and level — the
/// producer is not visible in this file, confirm.
InvalidProtocol(String, u8),
/// NOTE(review): presumably a malformed packet header — the producer is not
/// visible in this file, confirm.
InvalidHeader,
/// NOTE(review): presumably a length field or buffer size that does not match
/// the data — the producer is not visible in this file, confirm.
InvalidLength,
/// Wraps a UTF-8 validation error (`core::str::Utf8Error`).
InvalidString(core::str::Utf8Error),
/// Any other `std::io::Error`, stored as its kind plus its `Display` text
/// (see `From<std::io::Error>`). Only available with the `std` feature.
#[cfg(feature = "std")]
IoError(ErrorKind, String),
}
// Makes `Error` usable as a `std::error::Error` when the `std` feature is on.
// No method is overridden, so the trait's provided defaults apply; the
// `Debug` and `Display` impls the trait requires are the derive on the enum
// and the `fmt::Display` impl in this file.
#[cfg(feature = "std")]
impl ErrorTrait for Error {}
impl fmt::Display for Error {
    /// Prints the error using its derived `Debug` representation.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!("{:?}", self))
    }
}
#[cfg(feature = "std")]
impl From<Error> for IoError {
    /// Wraps an [`Error`] into a `std::io::Error`, keeping `err` itself as the
    /// payload of the returned error.
    ///
    /// The kind of the returned error is:
    /// * `ErrorKind::WriteZero` for [`Error::WriteZero`];
    /// * the recorded kind for [`Error::IoError`], so that an I/O error that
    ///   was converted into [`Error`] and back keeps its original kind
    ///   instead of being flattened to `InvalidData`;
    /// * `ErrorKind::InvalidData` for every other variant.
    fn from(err: Error) -> IoError {
        // Decide the kind from a borrow first so `err` can still be moved
        // into the payload afterwards.
        let kind = match &err {
            Error::WriteZero => ErrorKind::WriteZero,
            Error::IoError(kind, _) => *kind,
            _ => ErrorKind::InvalidData,
        };
        IoError::new(kind, err)
    }
}
#[cfg(feature = "std")]
impl From<IoError> for Error {
    /// Converts a `std::io::Error` into an [`Error`].
    ///
    /// A `WriteZero` error becomes [`Error::WriteZero`]; anything else is
    /// stored as [`Error::IoError`] with its kind and its `Display` text.
    fn from(err: IoError) -> Error {
        let kind = err.kind();
        if kind == ErrorKind::WriteZero {
            Error::WriteZero
        } else {
            Error::IoError(kind, format!("{}", err))
        }
    }
}
/// Identifier stored as a `NonZeroU16`, so the value `0` cannot be represented.
///
/// Build one with `Pid::new()` (value `1`) or `Pid::try_from(u16)`, which
/// rejects `0` with `Error::InvalidPid`. Adding or subtracting a `u16` wraps
/// around within `1..=u16::MAX`.
///
/// NOTE(review): presumably the MQTT packet identifier — the protocol context
/// is inferred from the surrounding type names, confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "derive", derive(Serialize, Deserialize))]
pub struct Pid(NonZeroU16);
impl Pid {
pub fn new() -> Self {
Pid(NonZeroU16::new(1).unwrap())
}
pub fn get(self) -> u16 {
self.0.get()
}
pub(crate) fn from_buffer(buf: &mut impl Buf) -> Result<Self, Error> {
Self::try_from(buf.get_u16())
}
pub(crate) fn to_buffer(self, buf: &mut impl BufMut) -> Result<(), Error> {
Ok(buf.put_u16(self.get()))
}
}
impl Default for Pid {
fn default() -> Pid {
Pid::new()
}
}
impl core::ops::Add<u16> for Pid {
    type Output = Pid;

    /// Advances the identifier by `u`, wrapping past `u16::MAX` back to `1`
    /// so that the result is never `0`.
    fn add(self, u: u16) -> Pid {
        let (sum, wrapped) = self.get().overflowing_add(u);
        // After a wrap-around the raw sum has counted the forbidden `0` as a
        // step, so move one further to skip it.
        let value = if wrapped { sum + 1 } else { sum };
        Pid(NonZeroU16::new(value).unwrap())
    }
}
impl core::ops::Sub<u16> for Pid {
    type Output = Pid;

    /// Moves the identifier back by `u`, wrapping below `1` to `u16::MAX`
    /// so that the result is never `0`.
    fn sub(self, u: u16) -> Pid {
        let (diff, wrapped) = self.get().overflowing_sub(u);
        let value = if diff == 0 {
            // Landing exactly on the forbidden `0` steps back to the top.
            core::u16::MAX
        } else if wrapped {
            // The raw difference counted `0` as a step; skip it.
            diff - 1
        } else {
            diff
        };
        Pid(NonZeroU16::new(value).unwrap())
    }
}
impl From<Pid> for u16 {
fn from(p: Pid) -> Self {
p.0.get()
}
}
impl TryFrom<u16> for Pid {
type Error = Error;
fn try_from(u: u16) -> Result<Self, Error> {
match NonZeroU16::new(u) {
Some(nz) => Ok(Pid(nz)),
None => Err(Error::InvalidPid),
}
}
}
/// Quality-of-service level.
///
/// `QoS::to_u8` / `QoS::from_u8` map the variants to the bytes `0`, `1` and
/// `2`, in declaration order.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "derive", derive(Serialize, Deserialize))]
pub enum QoS {
/// Encoded as `0`.
AtMostOnce,
/// Encoded as `1`.
AtLeastOnce,
/// Encoded as `2`.
ExactlyOnce,
}
impl QoS {
pub(crate) fn to_u8(&self) -> u8 {
match *self {
QoS::AtMostOnce => 0,
QoS::AtLeastOnce => 1,
QoS::ExactlyOnce => 2,
}
}
pub(crate) fn from_u8(byte: u8) -> Result<QoS, Error> {
match byte {
0 => Ok(QoS::AtMostOnce),
1 => Ok(QoS::AtLeastOnce),
2 => Ok(QoS::ExactlyOnce),
n => Err(Error::InvalidQos(n)),
}
}
}
/// A quality-of-service level combined with the `Pid` that goes with it.
///
/// Only the two higher levels carry a `Pid`; `QosPid::pid` and `QosPid::qos`
/// split the value back into its two parts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "derive", derive(Serialize, Deserialize))]
pub enum QosPid {
/// Level `AtMostOnce`; carries no `Pid`.
AtMostOnce,
/// Level `AtLeastOnce` with its `Pid`.
AtLeastOnce(Pid),
/// Level `ExactlyOnce` with its `Pid`.
ExactlyOnce(Pid),
}
impl QosPid {
    /// Test-only constructor from a raw level and a raw identifier.
    ///
    /// # Panics
    ///
    /// Panics if `qos` is greater than `2`, or if `qos` is `1` or `2` and
    /// `pid` is `0`. `pid` is ignored when `qos` is `0`.
    #[cfg(test)]
    pub(crate) fn from_u8u16(qos: u8, pid: u16) -> Self {
        if qos == 0 {
            return QosPid::AtMostOnce;
        }
        if qos > 2 {
            panic!("Qos > 2");
        }
        let id = Pid::try_from(pid).expect("pid == 0");
        if qos == 1 {
            QosPid::AtLeastOnce(id)
        } else {
            QosPid::ExactlyOnce(id)
        }
    }

    /// Returns the carried `Pid`, or `None` for `AtMostOnce`.
    pub fn pid(self) -> Option<Pid> {
        match self {
            QosPid::AtLeastOnce(id) | QosPid::ExactlyOnce(id) => Some(id),
            QosPid::AtMostOnce => None,
        }
    }

    /// Returns the level with the `Pid` stripped off.
    pub fn qos(self) -> QoS {
        match self {
            QosPid::ExactlyOnce(..) => QoS::ExactlyOnce,
            QosPid::AtLeastOnce(..) => QoS::AtLeastOnce,
            QosPid::AtMostOnce => QoS::AtMostOnce,
        }
    }
}
#[cfg(test)]
mod test {
    use crate::Pid;
    use alloc::vec;
    use alloc::vec::Vec;
    use core::convert::TryFrom;

    /// Checks that `Pid` addition and subtraction wrap around within
    /// `1..=u16::MAX` without ever producing `0`.
    #[test]
    fn pid_add_sub() {
        // Each row is (current, delta, current - delta, current + delta).
        let cases: Vec<(u16, u16, u16, u16)> = vec![
            (2, 1, 1, 3),
            (100, 1, 99, 101),
            (1, 1, core::u16::MAX, 2),
            (1, 2, core::u16::MAX - 1, 3),
            (1, 3, core::u16::MAX - 2, 4),
            (core::u16::MAX, 1, core::u16::MAX - 1, 1),
            (core::u16::MAX, 2, core::u16::MAX - 2, 2),
            (10, core::u16::MAX, 10, 10),
            (10, 0, 10, 10),
            (1, 0, 1, 1),
            (core::u16::MAX, 0, core::u16::MAX, core::u16::MAX),
        ];
        for &(cur, d, prev, next) in cases.iter() {
            // `Pid` is `Copy`, so one value serves both operations.
            let pid = Pid::try_from(cur).unwrap();
            let sub = pid - d;
            let add = pid + d;
            assert_eq!(prev, sub.get(), "{} - {} should be {}", cur, d, prev);
            assert_eq!(next, add.get(), "{} + {} should be {}", cur, d, next);
        }
    }
}