use crate::{Dupe, proto_version};
use bytes::{Buf, BufMut, Bytes};
use std::{fmt::Debug, mem::size_of};
use thiserror::Error;
/// Errors produced while decoding a [`Frame`] or reading the opcode of an
/// already-encoded frame.
#[derive(Clone, Copy, Debug, Error, PartialEq, Eq)]
pub enum Error {
    /// The buffer holds fewer bytes than the frame layout for its opcode
    /// requires.
    #[error("frame is invalid or incomplete")]
    FrameTooShort,
    /// The high nibble of the header byte does not match
    /// `proto_version::PROTOCOL_VERSION_NUMBER`; carries the version found.
    #[error("unsupported frame version `{0}`")]
    FrameVersion(u8),
    /// The low nibble of the header byte is not a known [`OpCode`]; carries
    /// the raw value.
    #[error("invalid opcode `{0}`")]
    InvalidOpCode(u8),
    /// The first payload byte of a `Bind` frame is not a known [`BindType`];
    /// carries the raw value.
    #[error("invalid `Bind` type `{0}`")]
    InvalidBindType(u8),
}
/// Byte string that is either borrowed from the caller or held as [`Bytes`].
///
/// Frames built for encoding can borrow their host/data slices, while frames
/// produced by decoding hold `Owned` slices of the input buffer.
#[derive(Clone, Debug)]
pub(crate) enum CowBytes<'data> {
    /// Bytes borrowed for the lifetime `'data`.
    Borrowed(&'data [u8]),
    /// Bytes held in a [`Bytes`] buffer.
    Owned(Bytes),
}
impl PartialEq for CowBytes<'_> {
    /// Compares byte contents only, so a borrowed value and an owned value
    /// holding the same bytes are equal.
    fn eq(&self, other: &Self) -> bool {
        let lhs: &[u8] = self.as_ref();
        let rhs: &[u8] = other.as_ref();
        lhs == rhs
    }
}
// Content equality on byte strings is a full equivalence relation.
impl Eq for CowBytes<'_> {}
impl Dupe for CowBytes<'_> {
    /// Cheap duplicate: re-borrows the slice, or dupes the `Bytes` handle.
    #[inline]
    fn dupe(&self) -> Self {
        match *self {
            Self::Owned(ref bytes) => Self::Owned(bytes.dupe()),
            Self::Borrowed(slice) => Self::Borrowed(slice),
        }
    }
}
impl AsRef<[u8]> for CowBytes<'_> {
    /// Views the contents as a byte slice, whichever variant holds them.
    #[inline]
    fn as_ref(&self) -> &[u8] {
        match self {
            Self::Owned(bytes) => &bytes[..],
            Self::Borrowed(slice) => slice,
        }
    }
}
impl Default for CowBytes<'_> {
    /// An empty byte string in the `Borrowed` variant.
    #[inline]
    fn default() -> Self {
        const EMPTY: &[u8] = &[];
        Self::Borrowed(EMPTY)
    }
}
impl CowBytes<'_> {
    /// Converts into an owned [`Bytes`].
    ///
    /// An `Owned` value is returned as-is; a `Borrowed` slice is copied into
    /// a new buffer.
    #[inline]
    pub fn into_owned(self) -> Bytes {
        match self {
            Self::Owned(bytes) => bytes,
            Self::Borrowed(slice) => Bytes::copy_from_slice(slice),
        }
    }

    /// Number of bytes held, regardless of variant.
    #[inline]
    pub const fn len(&self) -> usize {
        match self {
            Self::Owned(bytes) => bytes.len(),
            Self::Borrowed(slice) => slice.len(),
        }
    }
}
/// Type byte carried as the first payload byte of a `Bind` frame.
///
/// The `repr(u8)` discriminant is the value written on the wire.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum BindType {
    /// Encoded as `1`.
    Stream = 1,
    /// Encoded as `3`.
    Datagram = 3,
}
impl TryFrom<u8> for BindType {
type Error = Error;
#[inline]
fn try_from(value: u8) -> Result<Self, Self::Error> {
match value {
1 => Ok(Self::Stream),
3 => Ok(Self::Datagram),
other => Err(Error::InvalidBindType(other)),
}
}
}
/// Frame operation code, carried in the low nibble of the header byte.
///
/// The `repr(u8)` discriminant is the value written on the wire. Payload
/// layouts below are as produced by the encoder; multi-byte integers are
/// big-endian.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum OpCode {
    /// Payload: `rwnd` (u32), target port (u16), target host (rest of frame).
    Connect = 0,
    /// Payload: a single u32 (`psh_recvd_since`).
    Acknowledge = 1,
    /// Empty payload.
    Reset = 2,
    /// Empty payload.
    Finish = 3,
    /// Payload: raw data (rest of frame).
    Push = 4,
    /// Payload: bind type (u8), target port (u16), target host (rest of frame).
    Bind = 5,
    /// Payload: host length (u8), target port (u16), target host, data (rest).
    Datagram = 6,
}
impl TryFrom<u8> for OpCode {
type Error = Error;
#[inline]
fn try_from(value: u8) -> Result<Self, Error> {
match value {
0 => Ok(Self::Connect),
1 => Ok(Self::Acknowledge),
2 => Ok(Self::Reset),
3 => Ok(Self::Finish),
4 => Ok(Self::Push),
5 => Ok(Self::Bind),
6 => Ok(Self::Datagram),
other => Err(Error::InvalidOpCode(other)),
}
}
}
/// Payload of a `Connect` frame.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct ConnectPayload<'data> {
    /// Encoded as a big-endian u32. Presumably the sender's receive window
    /// (name `rwnd`); its unit is not visible here — TODO confirm.
    pub rwnd: u32,
    /// Target port, encoded as a big-endian u16.
    pub target_port: u16,
    /// Target host bytes; fill the remainder of the frame.
    pub target_host: CowBytes<'data>,
}
/// Payload of a `Bind` frame.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct BindPayload<'data> {
    /// Encoded as a single byte (the enum's discriminant).
    pub bind_type: BindType,
    /// Target port, encoded as a big-endian u16.
    pub target_port: u16,
    /// Target host bytes; fill the remainder of the frame.
    pub target_host: CowBytes<'data>,
}
/// Payload of a `Datagram` frame.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct DatagramPayload<'data> {
    /// Target port, encoded as a big-endian u16.
    pub target_port: u16,
    /// Target host bytes, preceded on the wire by a one-byte length. Encoding
    /// a host longer than 255 bytes panics.
    pub target_host: CowBytes<'data>,
    /// Datagram contents; fill the remainder of the frame.
    pub data: CowBytes<'data>,
}
/// Opcode-specific body of a [`Frame`]; the variant determines the
/// [`OpCode`] written in the header.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) enum Payload<'data> {
    Connect(ConnectPayload<'data>),
    /// Single big-endian u32. The constructor names it `psh_recvd_since`
    /// (presumably Push frames received since the last ack — confirm).
    Acknowledge(u32),
    /// No payload bytes.
    Reset,
    /// No payload bytes.
    Finish,
    /// Raw bytes filling the rest of the frame.
    Push(CowBytes<'data>),
    Bind(BindPayload<'data>),
    Datagram(DatagramPayload<'data>),
}
impl Payload<'_> {
#[inline]
const fn len(&self) -> usize {
match self {
Self::Connect(ConnectPayload { target_host, .. }) => {
size_of::<u32>() + size_of::<u16>() + target_host.len()
}
Self::Acknowledge(_) => size_of::<u32>(),
Self::Reset | Self::Finish => 0,
Self::Push(data) => data.len(),
Self::Bind(BindPayload { target_host, .. }) => {
size_of::<u8>() + size_of::<u16>() + target_host.len()
}
Self::Datagram(DatagramPayload {
target_host, data, ..
}) => size_of::<u8>() + size_of::<u16>() + target_host.len() + data.len(),
}
}
}
impl<'data> From<&Payload<'data>> for OpCode {
    /// Maps a payload variant to the opcode that tags it on the wire.
    #[inline]
    fn from(payload: &Payload<'data>) -> Self {
        match payload {
            Payload::Reset => Self::Reset,
            Payload::Finish => Self::Finish,
            Payload::Acknowledge(..) => Self::Acknowledge,
            Payload::Push(..) => Self::Push,
            Payload::Connect(..) => Self::Connect,
            Payload::Bind(..) => Self::Bind,
            Payload::Datagram(..) => Self::Datagram,
        }
    }
}
/// A single protocol frame: an id plus an opcode-specific payload.
///
/// Frames built from borrowed slices live for `'data`; frames decoded from
/// [`Bytes`] own their data and are `'static`.
#[derive(Clone, PartialEq, Eq)]
pub struct Frame<'data> {
    /// Frame identifier, sent as a big-endian u32 right after the header byte.
    pub id: u32,
    /// Opcode-specific payload; its variant determines the header opcode.
    pub(crate) payload: Payload<'data>,
}
impl Debug for Frame<'_> {
    /// Shows the opcode, the id in hex, and the payload size; payload
    /// contents are not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let opcode = OpCode::from(&self.payload);
        let payload_len = self.payload.len();
        f.debug_struct("Frame")
            .field("opcode", &opcode)
            .field("id", &format_args!("{:08x}", self.id))
            .field("payload.len", &payload_len)
            .finish()
    }
}
impl<'data> Frame<'data> {
    /// Builds a `Connect` frame that borrows `target_host`.
    ///
    /// Note the argument order: target first, then `id` and `rwnd`.
    #[must_use]
    #[inline]
    pub const fn new_connect(
        target_host: &'data [u8],
        target_port: u16,
        id: u32,
        rwnd: u32,
    ) -> Self {
        Self {
            id,
            payload: Payload::Connect(ConnectPayload {
                rwnd,
                target_port,
                target_host: CowBytes::Borrowed(target_host),
            }),
        }
    }

    /// Builds an `Acknowledge` frame carrying `psh_recvd_since`.
    #[must_use]
    #[inline]
    pub const fn new_acknowledge(id: u32, psh_recvd_since: u32) -> Self {
        Self {
            id,
            payload: Payload::Acknowledge(psh_recvd_since),
        }
    }

    /// Builds a `Reset` frame (empty payload).
    #[must_use]
    #[inline]
    pub const fn new_reset(id: u32) -> Self {
        let payload = Payload::Reset;
        Self { id, payload }
    }

    /// Builds a `Finish` frame (empty payload).
    #[must_use]
    #[inline]
    pub const fn new_finish(id: u32) -> Self {
        let payload = Payload::Finish;
        Self { id, payload }
    }

    /// Builds a `Push` frame that borrows `data`.
    #[must_use]
    #[inline]
    pub const fn new_push(id: u32, data: &'data [u8]) -> Self {
        let payload = Payload::Push(CowBytes::Borrowed(data));
        Self { id, payload }
    }

    /// Builds a `Push` frame that takes ownership of `data`.
    #[must_use]
    #[inline]
    pub const fn new_push_owned(id: u32, data: Bytes) -> Self {
        let payload = Payload::Push(CowBytes::Owned(data));
        Self { id, payload }
    }

    /// Builds a `Bind` frame of `bind_type` that borrows `target_host`.
    #[must_use]
    #[inline]
    pub const fn new_bind(
        id: u32,
        bind_type: BindType,
        target_host: &'data [u8],
        target_port: u16,
    ) -> Self {
        Self {
            id,
            payload: Payload::Bind(BindPayload {
                bind_type,
                target_port,
                target_host: CowBytes::Borrowed(target_host),
            }),
        }
    }

    /// Builds a `Datagram` frame that borrows both `target_host` and `data`.
    ///
    /// Encoding panics if `target_host` is longer than 255 bytes.
    #[must_use]
    #[inline]
    pub const fn new_datagram(
        id: u32,
        target_host: &'data [u8],
        target_port: u16,
        data: &'data [u8],
    ) -> Self {
        Self {
            id,
            payload: Payload::Datagram(DatagramPayload {
                target_port,
                target_host: CowBytes::Borrowed(target_host),
                data: CowBytes::Borrowed(data),
            }),
        }
    }

    /// Builds a `Datagram` frame that takes ownership of `target_host` and
    /// `data`.
    ///
    /// Encoding panics if `target_host` is longer than 255 bytes.
    #[must_use]
    #[inline]
    pub const fn new_datagram_owned(
        id: u32,
        target_host: Bytes,
        target_port: u16,
        data: Bytes,
    ) -> Self {
        Self {
            id,
            payload: Payload::Datagram(DatagramPayload {
                target_port,
                target_host: CowBytes::Owned(target_host),
                data: CowBytes::Owned(data),
            }),
        }
    }

    /// Serializes this frame into its wire representation.
    ///
    /// # Panics
    ///
    /// Panics if this is a `Datagram` frame whose target host exceeds 255
    /// bytes.
    #[must_use]
    #[inline]
    pub(crate) fn finalize(&self) -> FinalizedFrame {
        let encoded = Bytes::from(self);
        FinalizedFrame(encoded)
    }
}
// Returns `Err(Error::FrameTooShort)` from the enclosing function when
// `$data.remaining()` is less than `$len`.
//
// Outside fuzzing builds a `debug_assert!` fires first, so debug builds panic
// with the file/line of the failing check instead of returning the error.
macro_rules! check_remaining {
    ($data:expr, $len:expr) => {
        let (have, need) = ($data.remaining(), $len);
        if have < need {
            #[cfg(not(fuzzing))]
            debug_assert!(
                false,
                "`FrameTooShort` at {}:{}, have {}/{}",
                file!(),
                line!(),
                have,
                need
            );
            return Err(Error::FrameTooShort);
        }
    };
}
impl TryFrom<Bytes> for Frame<'static> {
    type Error = Error;
    /// Decodes a frame from its wire representation.
    ///
    /// Variable-length fields become `CowBytes::Owned` slices of `data`, so
    /// the resulting frame is `'static`. Trailing bytes after fixed-size
    /// payloads (`Acknowledge`, `Reset`, `Finish`) are ignored.
    ///
    /// # Errors
    ///
    /// - `Error::FrameTooShort` if `data` is shorter than its opcode's layout
    ///   requires (debug, non-fuzzing builds panic via `debug_assert!` first).
    /// - `Error::FrameVersion` if the header's high nibble is not
    ///   `proto_version::PROTOCOL_VERSION_NUMBER`.
    /// - `Error::InvalidOpCode` / `Error::InvalidBindType` for unknown
    ///   discriminants.
    #[inline]
    fn try_from(mut data: Bytes) -> Result<Self, Self::Error> {
        // Header byte (version << 4 | opcode) followed by the u32 frame id.
        check_remaining!(data, size_of::<u8>() + size_of::<u32>());
        let firstbyte = data.get_u8();
        let ver = firstbyte >> 4;
        if ver != proto_version::PROTOCOL_VERSION_NUMBER {
            return Err(Error::FrameVersion(ver));
        }
        let opcode = OpCode::try_from(firstbyte & 0x0F)?;
        let id = data.get_u32();
        let payload = match opcode {
            OpCode::Connect => {
                // rwnd (u32), target port (u16), then the host fills the rest.
                check_remaining!(data, size_of::<u32>() + size_of::<u16>());
                let rwnd = data.get_u32();
                let target_port = data.get_u16();
                Payload::Connect(ConnectPayload {
                    rwnd,
                    target_port,
                    target_host: CowBytes::Owned(data),
                })
            }
            OpCode::Acknowledge => {
                check_remaining!(data, size_of::<u32>());
                let psh_recvd_since = data.get_u32();
                Payload::Acknowledge(psh_recvd_since)
            }
            OpCode::Reset => Payload::Reset,
            OpCode::Finish => Payload::Finish,
            OpCode::Push => Payload::Push(CowBytes::Owned(data)),
            OpCode::Bind => {
                // Bind type (u8), target port (u16), then the host fills the rest.
                check_remaining!(data, size_of::<u8>() + size_of::<u16>());
                let bind_type = BindType::try_from(data.get_u8())?;
                let target_port = data.get_u16();
                Payload::Bind(BindPayload {
                    bind_type,
                    target_port,
                    target_host: CowBytes::Owned(data),
                })
            }
            OpCode::Datagram => {
                // Host length (u8), target port (u16), host (host-length
                // bytes), then the data fills the rest. This must mirror the
                // encoder exactly: there are no other fixed fields, so only
                // 3 bytes are needed before the host length is known, and an
                // empty host and/or empty data are both valid.
                check_remaining!(data, size_of::<u8>() + size_of::<u16>());
                let host_len = usize::from(data.get_u8());
                check_remaining!(data, size_of::<u16>() + host_len);
                let target_port = data.get_u16();
                let target_host = data.split_to(host_len);
                Payload::Datagram(DatagramPayload {
                    target_port,
                    target_host: CowBytes::Owned(target_host),
                    data: CowBytes::Owned(data),
                })
            }
        };
        Ok(Self { id, payload })
    }
}
impl From<&Frame<'_>> for Vec<u8> {
#[tracing::instrument(level = "trace")]
#[inline]
fn from(frame: &Frame<'_>) -> Self {
let size = size_of::<u8>() + size_of::<u32>() + frame.payload.len();
let opcode = OpCode::from(&frame.payload) as u8;
let firstbyte = opcode | (proto_version::PROTOCOL_VERSION_NUMBER << 4);
let mut encoded = Self::with_capacity(size);
encoded.put_u8(firstbyte);
encoded.put_u32(frame.id);
match &frame.payload {
Payload::Connect(ConnectPayload {
rwnd,
target_port,
target_host,
}) => {
encoded.put_u32(*rwnd);
encoded.put_u16(*target_port);
encoded.extend(target_host.as_ref());
}
Payload::Acknowledge(psh_recvd_since) => {
encoded.put_u32(*psh_recvd_since);
}
Payload::Reset | Payload::Finish => {}
Payload::Push(data) => {
encoded.extend(data.as_ref());
}
Payload::Bind(BindPayload {
bind_type,
target_port,
target_host,
}) => {
encoded.put_u8(*bind_type as u8);
encoded.put_u16(*target_port);
encoded.extend(target_host.as_ref());
}
Payload::Datagram(DatagramPayload {
target_port,
target_host,
data,
}) => {
let len_u8 =
u8::try_from(target_host.len()).expect("Datagram target host too long");
encoded.put_u8(len_u8);
encoded.put_u16(*target_port);
encoded.extend(target_host.as_ref());
encoded.extend(data.as_ref());
}
}
debug_assert_eq!(size, encoded.len());
encoded
}
}
impl From<&Frame<'_>> for Bytes {
    /// Encodes `frame` and wraps the resulting buffer in `Bytes`.
    #[inline]
    fn from(frame: &Frame<'_>) -> Self {
        let encoded: Vec<u8> = frame.into();
        Self::from(encoded)
    }
}
/// Encoded bytes of a frame.
///
/// Built either by encoding a `Frame` or by wrapping raw `Bytes`; the latter
/// path performs no validation, so decoding may still fail.
#[derive(Clone, PartialEq, Eq)]
pub(crate) struct FinalizedFrame(Bytes);
impl FinalizedFrame {
#[inline]
pub fn opcode(&self) -> Result<OpCode, Error> {
let firstbyte = self.0.first().ok_or({
Error::FrameTooShort
})?;
let raw_opcode = firstbyte & 0x0F;
OpCode::try_from(raw_opcode)
}
}
impl Debug for FinalizedFrame {
    /// Shows the decoded opcode (or its error) and the encoded length.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let opcode = self.opcode();
        let encoded_len = self.0.len();
        f.debug_struct("FinalizedFrame")
            .field("opcode", &opcode)
            .field("encoded_len", &encoded_len)
            .finish()
    }
}
impl<'data> From<&Frame<'data>> for FinalizedFrame {
    /// Encodes `frame` into its wire representation.
    #[inline]
    fn from(frame: &Frame<'data>) -> Self {
        let encoded = Bytes::from(frame);
        Self(encoded)
    }
}
impl TryFrom<FinalizedFrame> for Frame<'_> {
type Error = Error;
#[inline]
fn try_from(frame: FinalizedFrame) -> Result<Self, Self::Error> {
Frame::try_from(frame.0)
}
}
impl From<Bytes> for FinalizedFrame {
#[inline]
fn from(bytes: Bytes) -> Self {
Self(bytes)
}
}
impl From<FinalizedFrame> for Bytes {
#[inline]
fn from(frame: FinalizedFrame) -> Self {
frame.0
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // `CowBytes` equality compares contents, so borrowed and owned values
    // holding the same bytes are equal.
    #[test]
    fn test_cow_bytes_eq() {
        crate::tests::setup_logging();
        let cow1 = CowBytes::Borrowed(b"1234");
        let cow2 = CowBytes::Owned(Bytes::from_static(b"1234"));
        assert_eq!(cow1, cow2);
        let cow3 = CowBytes::Borrowed(b"12345");
        assert_ne!(cow1, cow3);
    }
    // `dupe`, `len`, `into_owned` and `as_ref` on both variants.
    #[test]
    fn test_cow_bytes() {
        crate::tests::setup_logging();
        let cow1 = CowBytes::Borrowed(&[1, 2, 3]);
        let cow2 = cow1.dupe();
        assert_eq!(cow1, cow2);
        assert_eq!(cow1.len(), 3);
        assert_eq!(cow2.len(), 3);
        let cow3 = cow1.into_owned();
        assert_eq!(cow3.as_ref(), cow2.as_ref());
        let bytes = Bytes::from(vec![4, 5, 6]);
        let cow4 = CowBytes::Owned(bytes.clone());
        assert_eq!(cow4.as_ref(), bytes.as_ref());
        assert_eq!(cow4.len(), 3);
    }
    // Encode/decode round trips: a `Connect` with an empty host, and a
    // `Datagram` built directly from payload structs.
    #[test]
    fn test_frames() {
        crate::tests::setup_logging();
        let frame = Frame::new_connect(&[], 5678, 1234, 128);
        assert_eq!(
            frame,
            Frame {
                id: 1234,
                payload: Payload::Connect(ConnectPayload {
                    rwnd: 128,
                    target_port: 5678,
                    target_host: CowBytes::default(),
                })
            }
        );
        let bytes = Bytes::from(&frame);
        let decoded = Frame::try_from(bytes).unwrap();
        assert_eq!(frame, decoded);
        let frame = Frame {
            id: 5678,
            payload: Payload::Datagram(DatagramPayload {
                target_host: CowBytes::Borrowed(&[1, 2, 3, 4]),
                target_port: 1234,
                data: CowBytes::Borrowed(&[1, 2, 3, 4]),
            }),
        };
        let bytes = Bytes::from(&frame);
        let decoded = Frame::try_from(bytes).unwrap();
        assert_eq!(frame, decoded);
    }
    // The expected header bytes below (0x7N) mean these tests expect protocol
    // version 7 in the high nibble.
    #[test]
    fn test_frame_repr_connect() {
        crate::tests::setup_logging();
        let frame = Frame::new_connect(&[0x01, 0x02, 0x03], 5678, 1234, 512);
        let bytes = Vec::from(&frame);
        assert_eq!(
            bytes,
            vec![
                0x70, // header: version 7, opcode Connect (0)
                0x00, 0x00, 0x04, 0xd2, // id 1234
                0x00, 0x00, 0x02, 0x00, // rwnd 512
                0x16, 0x2e, // target port 5678
                0x01, 0x02, 0x03, // target host
            ]
        );
        let frame_back = Frame::try_from(Bytes::from(bytes)).unwrap();
        assert_eq!(frame, frame_back);
    }
    #[test]
    fn test_frame_repr_acknowledge() {
        crate::tests::setup_logging();
        let frame = Frame::new_acknowledge(5678, 128);
        let bytes = Vec::from(&frame);
        assert_eq!(
            bytes,
            vec![
                0x71, // header: version 7, opcode Acknowledge (1)
                0x00, 0x00, 0x16, 0x2e, // id 5678
                0x00, 0x00, 0x00, 0x80, // psh_recvd_since 128
            ]
        );
        let frame_back = Frame::try_from(Bytes::from(bytes)).unwrap();
        assert_eq!(frame, frame_back);
    }
    #[test]
    fn test_frame_repr_reset() {
        crate::tests::setup_logging();
        let frame = Frame::new_reset(1291);
        let bytes = Vec::from(&frame);
        assert_eq!(
            bytes,
            vec![
                0x72, // header: version 7, opcode Reset (2)
                0x00, 0x00, 0x05, 0x0b, // id 1291; no payload
            ]
        );
        let frame_back = Frame::try_from(Bytes::from(bytes)).unwrap();
        assert_eq!(frame, frame_back);
    }
    #[test]
    fn test_frame_repr_finish() {
        crate::tests::setup_logging();
        let frame = Frame::new_finish(0x534c);
        let bytes = Vec::from(&frame);
        assert_eq!(
            bytes,
            vec![
                0x73, // header: version 7, opcode Finish (3)
                0x00, 0x00, 0x53, 0x4c, // id 0x534c; no payload
            ]
        );
        let frame_back = Frame::try_from(Bytes::from(bytes)).unwrap();
        assert_eq!(frame, frame_back);
    }
    #[test]
    fn test_frame_repr_push() {
        crate::tests::setup_logging();
        let frame = Frame::new_push(0x75b_97bb, &[1, 2, 3, 4]);
        let bytes = Vec::from(&frame);
        assert_eq!(
            bytes,
            vec![
                0x74, // header: version 7, opcode Push (4)
                0x07, 0x5b, 0x97, 0xbb, // id 0x075b97bb
                0x01, 0x02, 0x03, 0x04, // data
            ]
        );
        let frame_back = Frame::try_from(Bytes::from(bytes)).unwrap();
        assert_eq!(frame, frame_back);
    }
    // Both `BindType` variants.
    #[test]
    fn test_frame_repr_bind() {
        crate::tests::setup_logging();
        let frame = Frame::new_bind(42132, BindType::Datagram, &[1, 2, 3, 4], 1234);
        let bytes = Vec::from(&frame);
        assert_eq!(
            bytes,
            vec![
                0x75, // header: version 7, opcode Bind (5)
                0x00, 0x00, 0xa4, 0x94, // id 42132
                0x03, // bind type Datagram
                0x04, 0xd2, // target port 1234
                0x01, 0x02, 0x03, 0x04 // target host
            ]
        );
        let frame_back = Frame::try_from(Bytes::from(bytes)).unwrap();
        assert_eq!(frame, frame_back);
        let frame = Frame::new_bind(0x282_ea5f, BindType::Stream, &[4, 2, 3, 4], 1234);
        let bytes = Vec::from(&frame);
        assert_eq!(
            bytes,
            vec![
                0x75, // header: version 7, opcode Bind (5)
                0x02, 0x82, 0xea, 0x5f, // id 0x0282ea5f
                0x01, // bind type Stream
                0x04, 0xd2, // target port 1234
                0x04, 0x02, 0x03, 0x04 // target host
            ]
        );
        let frame_back = Frame::try_from(Bytes::from(bytes)).unwrap();
        assert_eq!(frame, frame_back);
    }
    #[test]
    fn test_frame_repr_datagram() {
        crate::tests::setup_logging();
        let frame = Frame::new_datagram(2134, &[1, 2, 3, 4], 1234, &[1, 2, 3, 4]);
        let bytes = Vec::from(&frame);
        assert_eq!(
            bytes,
            vec![
                0x76, // header: version 7, opcode Datagram (6)
                0x00, 0x00, 0x08, 0x56, // id 2134
                0x04, // target host length
                0x04, 0xd2, // target port 1234
                0x01, 0x02, 0x03, 0x04, // target host
                0x01, 0x02, 0x03, 0x04 // data
            ]
        );
        let frame_back = Frame::try_from(Bytes::from(bytes)).unwrap();
        assert_eq!(frame, frame_back);
    }
    // A 256-byte host cannot be described by the one-byte length field, so
    // encoding panics.
    #[test]
    #[should_panic(expected = "Datagram target host too long")]
    fn test_finalized_frame_too_long() {
        crate::tests::setup_logging();
        let long_hostname = vec![0; 256];
        let frame = Frame::new_datagram(2134, &long_hostname, 1234, &[1, 2, 3, 4]);
        let _ = frame.finalize();
    }
    // `finalize` output length is the 5-byte header+id overhead plus
    // `Payload::len`, `opcode()` reads back the right opcode, and decoding
    // round-trips.
    #[test]
    fn test_finalized_frame() {
        const COMMON_OVERHEAD_SIZE: usize = size_of::<u8>() + size_of::<u32>();
        crate::tests::setup_logging();
        let frame = Frame::new_connect(&[0x01, 0x02, 0x03], 5678, 1234, 128);
        let payload_len = frame.payload.len();
        let finalized = frame.finalize();
        assert_eq!(finalized.0.len(), COMMON_OVERHEAD_SIZE + payload_len);
        assert_eq!(finalized.opcode().unwrap(), OpCode::Connect);
        let decoded = Frame::try_from(finalized).unwrap();
        assert_eq!(frame, decoded);
        let frame = Frame::new_datagram(2134, &[1, 2, 3, 4], 1234, &[1, 2, 3, 4]);
        let payload_len = frame.payload.len();
        let finalized = frame.finalize();
        assert_eq!(finalized.0.len(), COMMON_OVERHEAD_SIZE + payload_len);
        assert_eq!(finalized.opcode().unwrap(), OpCode::Datagram);
        let decoded = Frame::try_from(finalized).unwrap();
        assert_eq!(frame, decoded);
    }
}