use std::io::{self, Read, Write};
use byteorder::{ReadBytesExt, WriteBytesExt};
#[cfg(feature = "tokio")]
use tokio::io::{AsyncRead, AsyncReadExt};
use crate::control::packet_type::{PacketType, PacketTypeError};
use crate::{Decodable, Encodable};
/// Fixed header of a control packet.
///
/// It is encoded as one packet-type byte (`PacketType::to_u8`), followed by the
/// remaining length as a 1–4 byte variable-length integer: seven value bits per
/// byte, least-significant group first, with the high bit set when another
/// byte follows.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct FixedHeader {
    /// Packet type carried in the first header byte.
    pub packet_type: PacketType,
    /// Value of the remaining-length field. Presumably this is the number of
    /// bytes that follow the fixed header (TODO confirm against callers).
    /// Only values up to `0x0FFF_FFFF` fit in four bytes; `FixedHeader::new`
    /// debug-asserts this limit, but the field is public and can be set directly.
    pub remaining_length: u32,
}
impl FixedHeader {
    /// Creates a header for `packet_type` with the given remaining length.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `remaining_length` exceeds `0x0FFF_FFFF`, the
    /// largest value the four-byte remaining-length encoding can represent.
    pub fn new(packet_type: PacketType, remaining_length: u32) -> FixedHeader {
        debug_assert!(remaining_length <= 0x0FFF_FFFF);
        FixedHeader {
            packet_type,
            remaining_length,
        }
    }

    /// Asynchronously reads a fixed header from `rdr`.
    ///
    /// The header consists of one packet-type byte, then a remaining length of one
    /// to four bytes. Each length byte carries seven value bits, least-significant
    /// group first, and its high bit marks that another byte follows. The packet
    /// type is validated only after the length has been read, so a reserved type
    /// is reported together with the decoded remaining length.
    ///
    /// # Errors
    ///
    /// * `FixedHeaderError::MalformedRemainingLength` if the fourth length byte
    ///   still has its continuation bit set. Reading stops there, and no fifth
    ///   byte is consumed.
    /// * `FixedHeaderError::ReservedType(ty, remaining_length)` if
    ///   `PacketType::from_u8` reports a reserved type.
    /// * `FixedHeaderError::PacketTypeError` for any other `PacketType::from_u8`
    ///   failure.
    /// * `FixedHeaderError::IoError` if the reader fails, e.g. runs out of input.
    #[cfg(feature = "tokio")]
    pub async fn parse<A: AsyncRead + Unpin>(rdr: &mut A) -> Result<Self, FixedHeaderError> {
        let type_val = rdr.read_u8().await?;

        let mut remaining_len = 0u32;
        let mut shift = 0u32;
        loop {
            let byte = rdr.read_u8().await?;
            remaining_len |= u32::from(byte & 0x7F) << shift;
            if byte & 0x80 == 0 {
                break;
            }
            shift += 7;
            // Four 7-bit groups (shifts 0, 7, 14, 21) are the maximum. A fourth
            // byte that still asks for more is malformed, so fail here rather
            // than consuming another byte from the stream.
            if shift >= 28 {
                return Err(FixedHeaderError::MalformedRemainingLength);
            }
        }

        match PacketType::from_u8(type_val) {
            Ok(packet_type) => Ok(FixedHeader::new(packet_type, remaining_len)),
            Err(PacketTypeError::ReservedType(ty, _)) => {
                Err(FixedHeaderError::ReservedType(ty, remaining_len))
            }
            Err(err) => Err(From::from(err)),
        }
    }
}
impl Encodable for FixedHeader {
    /// Serializes the header: the packet-type byte, then `remaining_length` as a
    /// one-to-four byte variable-length integer. Each byte carries seven value
    /// bits, least-significant group first, and its high bit is set when another
    /// byte follows.
    ///
    /// # Errors
    ///
    /// Returns an `io::ErrorKind::InvalidInput` error, without writing anything,
    /// if `remaining_length` is greater than `0x0FFF_FFFF`, the largest value
    /// four 7-bit groups can hold. Any error from `wr` is propagated.
    fn encode<W: Write>(&self, wr: &mut W) -> Result<(), io::Error> {
        // The fields are public, so `new`'s debug_assert can be bypassed (and it
        // is compiled out in release builds). Without this check a fifth length
        // byte would be emitted, which `decode_with` rejects and
        // `encoded_length` does not account for.
        if self.remaining_length > 0x0FFF_FFFF {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "remaining length exceeds 268435455 (0x0FFF_FFFF)",
            ));
        }

        wr.write_u8(self.packet_type.to_u8())?;

        let mut rest = self.remaining_length;
        loop {
            let low_bits = (rest & 0x7F) as u8;
            rest >>= 7;
            if rest == 0 {
                wr.write_u8(low_bits)?;
                break;
            }
            wr.write_u8(low_bits | 0x80)?;
        }
        Ok(())
    }

    /// Number of bytes `encode` writes: one type byte plus one to four
    /// remaining-length bytes.
    fn encoded_length(&self) -> u32 {
        let length_bytes = match self.remaining_length {
            0..=127 => 1,
            128..=16_383 => 2,
            16_384..=2_097_151 => 3,
            _ => 4,
        };
        1 + length_bytes
    }
}
impl Decodable for FixedHeader {
    type Error = FixedHeaderError;
    type Cond = ();

    /// Reads a fixed header from `rdr`: one packet-type byte, then a remaining
    /// length of one to four bytes. Each length byte carries seven value bits,
    /// least-significant group first, and its high bit marks that another byte
    /// follows.
    ///
    /// The packet type is validated only after the length has been read, so a
    /// reserved type is reported as `FixedHeaderError::ReservedType` together
    /// with the decoded remaining length.
    ///
    /// # Errors
    ///
    /// * `MalformedRemainingLength` if the fourth length byte still has its
    ///   continuation bit set. Decoding stops there, and no fifth byte is read.
    /// * `ReservedType(ty, remaining_length)` if `PacketType::from_u8` reports a
    ///   reserved type.
    /// * `PacketTypeError` for any other `PacketType::from_u8` failure.
    /// * `IoError` if the reader fails, e.g. runs out of input.
    fn decode_with<R: Read>(rdr: &mut R, _rest: ()) -> Result<FixedHeader, FixedHeaderError> {
        let type_val = rdr.read_u8()?;

        let remaining_len = {
            let mut value = 0u32;
            let mut shift = 0u32;
            loop {
                let byte = rdr.read_u8()?;
                value |= u32::from(byte & 0x7F) << shift;
                if byte & 0x80 == 0 {
                    break value;
                }
                shift += 7;
                // Four 7-bit groups (shifts 0, 7, 14, 21) are the maximum. A
                // fourth byte that still asks for more is malformed, so fail now
                // instead of consuming a fifth byte (or reporting EOF instead).
                if shift >= 28 {
                    return Err(FixedHeaderError::MalformedRemainingLength);
                }
            }
        };

        match PacketType::from_u8(type_val) {
            Ok(packet_type) => Ok(FixedHeader::new(packet_type, remaining_len)),
            Err(PacketTypeError::ReservedType(ty, _)) => {
                Err(FixedHeaderError::ReservedType(ty, remaining_len))
            }
            Err(err) => Err(From::from(err)),
        }
    }
}
/// Errors produced while reading a [`FixedHeader`].
#[derive(Debug, thiserror::Error)]
pub enum FixedHeaderError {
    /// The remaining-length field did not terminate within four bytes.
    #[error("malformed remaining length")]
    MalformedRemainingLength,
    /// The packet-type byte names a reserved type. The fields are the type value
    /// taken from `PacketTypeError::ReservedType` and the already-decoded
    /// remaining length.
    #[error("reserved header ({0}, {1})")]
    ReservedType(u8, u32),
    /// Any other failure reported by `PacketType::from_u8`.
    #[error(transparent)]
    PacketTypeError(#[from] PacketTypeError),
    /// Failure of the underlying reader (e.g. running out of input).
    #[error(transparent)]
    IoError(#[from] io::Error),
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::control::packet_type::{ControlType, PacketType};
    use crate::{Decodable, Encodable};
    use std::io::Cursor;

    /// Packet type used throughout these tests.
    fn connect() -> PacketType {
        PacketType::with_default(ControlType::Connect)
    }

    #[test]
    fn test_encode_fixed_header() {
        let header = FixedHeader::new(connect(), 321);
        let mut buf = Vec::new();
        header.encode(&mut buf).unwrap();
        let expected = b"\x10\xc1\x02";
        assert_eq!(&expected[..], &buf[..]);
    }

    #[test]
    fn test_decode_fixed_header() {
        let stream = b"\x10\xc1\x02";
        let mut cursor = Cursor::new(&stream[..]);
        let header = FixedHeader::decode(&mut cursor).unwrap();
        assert_eq!(header.packet_type, connect());
        assert_eq!(header.remaining_length, 321);
    }

    /// More than four remaining-length bytes must be rejected with the specific
    /// error, not just "some panic" (which an I/O failure would also satisfy).
    #[test]
    fn test_decode_too_long_fixed_header() {
        let stream = b"\x10\x80\x80\x80\x80\x02";
        let mut cursor = Cursor::new(&stream[..]);
        let err = FixedHeader::decode(&mut cursor).unwrap_err();
        assert!(matches!(err, FixedHeaderError::MalformedRemainingLength));
    }

    /// Every byte-count boundary of the variable-length encoding must round-trip,
    /// and `encoded_length` must match what `encode` actually writes.
    #[test]
    fn test_round_trip_length_boundaries() {
        for &len in &[
            0,
            127,
            128,
            16_383,
            16_384,
            2_097_151,
            2_097_152,
            0x0FFF_FFFF,
        ] {
            let header = FixedHeader::new(connect(), len);
            let mut buf = Vec::new();
            header.encode(&mut buf).unwrap();
            assert_eq!(buf.len() as u32, header.encoded_length(), "length {}", len);
            let decoded = FixedHeader::decode(&mut Cursor::new(&buf[..])).unwrap();
            assert_eq!(decoded, header, "length {}", len);
        }
    }
}