use bytes::{BufMut, Bytes, BytesMut};
use thiserror::Error;
use zerocopy::byteorder::network_endian::U16;
use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout, Unaligned};
use crate::headers::{decode_headers, Header};
/// Fixed 3-byte prefix of every frame: the opcode/response byte followed by
/// the frame length as a big-endian (network order) `u16`.
///
/// `#[repr(C)]` with only `Unaligned` fields means there is no padding, so the
/// struct is exactly the 3 wire bytes. The length counts the whole frame,
/// these 3 header bytes included (`framed` subtracts 3 to get the body length;
/// `Packet::encode` adds 3).
#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)]
#[repr(C)]
struct FrameHeader {
    /// Raw opcode / response-code byte; interpreted via `OpCode::from_byte`.
    opcode: u8,
    /// Total frame length in bytes, header included, in network byte order.
    length: U16,
}
/// 4-byte fixed-field block that follows the frame header in `Connect`
/// requests and in connect responses, before any headers.
///
/// Layout is byte-exact (`#[repr(C)]`, all fields `Unaligned`); do not reorder.
#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)]
#[repr(C)]
struct ConnectWire {
    /// Version byte, copied verbatim to/from `PacketExtra::Connect::version`.
    version: u8,
    /// Flags byte, copied verbatim to/from `PacketExtra::Connect::flags`.
    flags: u8,
    /// Big-endian 16-bit field; presumably the sender's maximum acceptable
    /// packet length (OBEX semantics) — not enforced anywhere in this file.
    max_packet: U16,
}
/// 2-byte fixed-field block that follows the frame header in `SetPath`
/// requests, before any headers.
///
/// Layout is byte-exact (`#[repr(C)]`, all fields `Unaligned`); do not reorder.
#[derive(FromBytes, IntoBytes, KnownLayout, Immutable, Unaligned)]
#[repr(C)]
struct SetPathWire {
    /// Flags byte, copied verbatim to/from `PacketExtra::SetPath::flags`.
    flags: u8,
    /// Constants byte, copied verbatim to/from `PacketExtra::SetPath::constants`.
    constants: u8,
}
/// Request opcode or response code carried in the first byte of a frame.
///
/// The variant set mirrors the OBEX opcode / response-code tables
/// (NOTE(review): confirm against the spec revision this crate targets).
/// Use `OpCode::from_byte` / `OpCode::to_byte` for the byte mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OpCode {
    // Request opcodes.
    Connect,
    Disconnect,
    /// Non-final `Put`; `PutFinal` is the same operation with the high bit set.
    Put,
    PutFinal,
    /// Non-final `Get`; `GetFinal` is the same operation with the high bit set.
    Get,
    GetFinal,
    SetPath,
    Abort,
    // Response codes.
    Continue,
    Ok,
    Created,
    BadRequest,
    Unauthorized,
    Forbidden,
    NotFound,
    InternalServerError,
    NotImplemented,
    /// Any byte without a named variant. The raw value is kept so that
    /// `to_byte` returns it unchanged.
    Other(u8),
}
impl OpCode {
    /// Decodes an opcode / response-code byte.
    ///
    /// Byte values follow the OBEX opcode and response-code tables, in which
    /// the high bit (`0x80`) is the "final" bit. Bytes without a named variant
    /// decode to [`OpCode::Other`], so decoding never fails and
    /// `OpCode::from_byte(b).to_byte() == b` holds for every `b`.
    #[must_use]
    pub const fn from_byte(b: u8) -> Self {
        match b {
            // Requests.
            0x80 => Self::Connect,
            0x81 => Self::Disconnect,
            0x02 => Self::Put,
            0x82 => Self::PutFinal,
            0x03 => Self::Get,
            0x83 => Self::GetFinal,
            0x85 => Self::SetPath,
            0xFF => Self::Abort,
            // Responses (HTTP-style status | final bit).
            0x90 => Self::Continue,
            0xA0 => Self::Ok,
            0xA1 => Self::Created,
            0xC0 => Self::BadRequest,
            0xC1 => Self::Unauthorized,
            0xC3 => Self::Forbidden,
            0xC4 => Self::NotFound,
            0xD0 => Self::InternalServerError,
            // "Not Implemented" is 0x51 (HTTP 501) | final bit = 0xD1.
            // 0xD4 is "Gateway Timeout" (504) and must fall through to `Other`.
            0xD1 => Self::NotImplemented,
            other => Self::Other(other),
        }
    }

    /// Encodes this opcode as its wire byte.
    ///
    /// Inverse of [`OpCode::from_byte`] for every named variant. `Other(b)`
    /// returns `b` verbatim, even if `b` happens to be a named code. Keep this
    /// table in sync with `from_byte`.
    #[must_use]
    pub const fn to_byte(self) -> u8 {
        match self {
            Self::Connect => 0x80,
            Self::Disconnect => 0x81,
            Self::Put => 0x02,
            Self::PutFinal => 0x82,
            Self::Get => 0x03,
            Self::GetFinal => 0x83,
            Self::SetPath => 0x85,
            Self::Abort => 0xFF,
            Self::Continue => 0x90,
            Self::Ok => 0xA0,
            Self::Created => 0xA1,
            Self::BadRequest => 0xC0,
            Self::Unauthorized => 0xC1,
            Self::Forbidden => 0xC3,
            Self::NotFound => 0xC4,
            Self::InternalServerError => 0xD0,
            Self::NotImplemented => 0xD1,
            Self::Other(b) => b,
        }
    }

    /// Returns `true` for the success responses `Ok` and `Created`.
    #[must_use]
    pub const fn is_ok(self) -> bool {
        matches!(self, Self::Ok | Self::Created)
    }

    /// Returns `true` for the `Continue` response.
    #[must_use]
    pub const fn is_continue(self) -> bool {
        matches!(self, Self::Continue)
    }
}
/// Opcode-specific fixed fields that sit between the frame header and the
/// headers of a packet.
///
/// `parse_packet` produces `Connect` only for the `Connect` opcode and
/// `SetPath` only for the `SetPath` opcode; `parse_connect_response` always
/// produces `Connect`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PacketExtra {
    /// No fixed fields; headers start right after the frame header.
    None,
    /// Connect request / connect response fields (4 bytes on the wire).
    Connect {
        /// Version byte, copied verbatim from/to the wire.
        version: u8,
        /// Flags byte, copied verbatim from/to the wire.
        flags: u8,
        /// Big-endian `u16` on the wire; presumably the sender's maximum
        /// acceptable packet length (OBEX semantics) — not enforced here.
        max_packet: u16,
    },
    /// SetPath request fields (2 bytes on the wire).
    SetPath {
        /// Flags byte, copied verbatim from/to the wire.
        flags: u8,
        /// Constants byte, copied verbatim from/to the wire.
        constants: u8,
    },
}
/// A decoded, or to-be-encoded, packet: opcode, fixed fields and headers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Packet {
    /// Opcode or response code from the first byte of the frame.
    pub opcode: OpCode,
    /// Opcode-specific fixed fields. `Packet::encode` writes whatever is here
    /// and does not check that it matches `opcode`.
    pub extra: PacketExtra,
    /// Headers; `Packet::encode` writes them in this order.
    pub headers: Vec<Header>,
}
/// Errors from encoding or decoding a packet.
#[derive(Debug, Error)]
pub enum PacketError {
    /// Input ended before a complete frame header, the declared body, or a
    /// fixed-field block. Also returned when a frame declares a length below 3.
    #[error("packet too short")]
    TooShort,
    /// Not constructed in this file; presumably returned by header decoding in
    /// `crate::headers` — confirm there.
    #[error("invalid header encoding")]
    InvalidHeader,
    /// Not constructed in this file; presumably returned when a `Name` header's
    /// UTF-16 payload fails to decode — confirm in `crate::headers`.
    #[error("invalid UTF-16 in name header")]
    InvalidName,
    /// The encoded frame (3-byte frame header included) would not fit in the
    /// 16-bit length field.
    #[error("packet too large to encode")]
    PacketTooLarge,
    /// Not constructed in this file; presumably returned by `Header::encode_into`
    /// when a single header is too large — confirm in `crate::headers`.
    #[error("header payload too large to encode")]
    HeaderTooLarge,
}
impl Packet {
    /// Decodes one framed packet. Only `Connect` and `SetPath` opcodes are
    /// expected to carry a fixed-field block before the headers.
    ///
    /// # Errors
    ///
    /// [`PacketError::TooShort`] for truncated input; errors from header
    /// decoding are propagated.
    pub fn decode(data: &[u8]) -> Result<Self, PacketError> {
        parse_packet(data)
    }

    /// Decodes the reply to a `Connect` request. The 4-byte connect field
    /// block is expected before the headers regardless of the response code.
    ///
    /// # Errors
    ///
    /// Same as [`Packet::decode`].
    pub fn decode_connect_response(data: &[u8]) -> Result<Self, PacketError> {
        parse_connect_response(data)
    }

    /// Serializes the packet as: 3-byte frame header (opcode byte, big-endian
    /// total length), then the fixed fields from `extra`, then every header
    /// in order.
    ///
    /// # Errors
    ///
    /// Errors from encoding an individual header are propagated.
    /// [`PacketError::PacketTooLarge`] is returned if the whole frame would
    /// exceed `u16::MAX` bytes.
    #[must_use = "encoded bytes must be sent"]
    pub fn encode(&self) -> Result<Bytes, PacketError> {
        // Fixed fields and headers are staged first: the frame header needs
        // the final length before it can be written.
        let mut payload = BytesMut::with_capacity(256);
        match self.extra {
            PacketExtra::None => {}
            PacketExtra::Connect { version, flags, max_packet } => {
                let wire = ConnectWire { version, flags, max_packet: U16::new(max_packet) };
                payload.put_slice(wire.as_bytes());
            }
            PacketExtra::SetPath { flags, constants } => {
                let wire = SetPathWire { flags, constants };
                payload.put_slice(wire.as_bytes());
            }
        }
        for header in &self.headers {
            header.encode_into(&mut payload)?;
        }

        // The length field also counts the 3-byte frame header itself.
        let total = match payload.len().checked_add(3).map(u16::try_from) {
            Some(Ok(len)) => len,
            _ => return Err(PacketError::PacketTooLarge),
        };

        let frame = FrameHeader { opcode: self.opcode.to_byte(), length: U16::new(total) };
        let mut packet = BytesMut::with_capacity(usize::from(total));
        packet.put_slice(frame.as_bytes());
        packet.put_slice(&payload);
        Ok(packet.freeze())
    }

    /// Returns the first connection id reported by `Header::connection_id`
    /// across the headers, if any.
    #[must_use]
    pub fn header_connection_id(&self) -> Option<u32> {
        self.headers.iter().find_map(Header::connection_id)
    }

    /// Returns the payload of the first `Target` header, if any.
    #[must_use]
    pub fn header_target(&self) -> Option<&[u8]> {
        self.headers.iter().find_map(|header| match header {
            Header::Target(bytes) => Some(bytes.as_ref()),
            _ => None,
        })
    }

    /// Returns a copy of the first non-empty `Name` header. Empty `Name`
    /// headers are skipped.
    #[must_use]
    pub fn header_name(&self) -> Option<String> {
        self.headers.iter().find_map(|header| match header {
            Header::Name(name) => (!name.is_empty()).then(|| name.clone()),
            _ => None,
        })
    }

    /// Returns the payload of whichever `Body` or `EndOfBody` header appears
    /// first, if any.
    #[must_use]
    pub fn body_payload(&self) -> Option<&[u8]> {
        self.headers.iter().find_map(|header| {
            if let Header::Body(bytes) | Header::EndOfBody(bytes) = header {
                Some(bytes.as_ref())
            } else {
                None
            }
        })
    }
}
/// Splits `data` into the frame's opcode and its body.
///
/// The frame header's 16-bit length counts the 3 header bytes, so the body is
/// the next `length - 3` bytes after the header. Bytes past the declared
/// length are ignored.
///
/// # Errors
///
/// [`PacketError::TooShort`] if `data` is shorter than the frame header, if
/// the declared length is below 3, or if fewer body bytes are present than
/// declared.
fn framed(data: &[u8]) -> Result<(OpCode, &[u8]), PacketError> {
    let (frame, after_frame) = match FrameHeader::ref_from_prefix(data) {
        Ok(split) => split,
        Err(_) => return Err(PacketError::TooShort),
    };
    let opcode = OpCode::from_byte(frame.opcode);
    usize::from(frame.length.get())
        .checked_sub(3)
        .and_then(|body_len| after_frame.get(..body_len))
        .map(|body| (opcode, body))
        .ok_or(PacketError::TooShort)
}
/// Decodes a framed packet, choosing the fixed-field layout from its opcode.
///
/// A `Connect` opcode is followed by the 4-byte connect block and a `SetPath`
/// opcode by the 2-byte setpath block. Every other opcode goes straight to
/// headers.
///
/// # Errors
///
/// [`PacketError::TooShort`] if the frame or its fixed-field block is
/// truncated; errors from `decode_headers` are propagated.
fn parse_packet(data: &[u8]) -> Result<Packet, PacketError> {
    let (opcode, body) = framed(data)?;
    let (extra, mut remaining) = if opcode == OpCode::Connect {
        let (wire, tail) =
            ConnectWire::ref_from_prefix(body).map_err(|_| PacketError::TooShort)?;
        let extra = PacketExtra::Connect {
            version: wire.version,
            flags: wire.flags,
            max_packet: wire.max_packet.get(),
        };
        (extra, tail)
    } else if opcode == OpCode::SetPath {
        let (wire, tail) =
            SetPathWire::ref_from_prefix(body).map_err(|_| PacketError::TooShort)?;
        (PacketExtra::SetPath { flags: wire.flags, constants: wire.constants }, tail)
    } else {
        (PacketExtra::None, body)
    };
    let headers = decode_headers(&mut remaining)?;
    Ok(Packet { opcode, extra, headers })
}
/// Decodes the reply to a `Connect` request.
///
/// The 4-byte connect block is required after the frame header whatever the
/// response code is. Headers follow it.
///
/// # Errors
///
/// [`PacketError::TooShort`] if the frame or the connect block is truncated;
/// errors from `decode_headers` are propagated.
fn parse_connect_response(data: &[u8]) -> Result<Packet, PacketError> {
    let (opcode, body) = framed(data)?;
    let (wire, mut remaining) = match ConnectWire::ref_from_prefix(body) {
        Ok(split) => split,
        Err(_) => return Err(PacketError::TooShort),
    };
    let headers = decode_headers(&mut remaining)?;
    Ok(Packet {
        opcode,
        extra: PacketExtra::Connect {
            version: wire.version,
            flags: wire.flags,
            max_packet: wire.max_packet.get(),
        },
        headers,
    })
}