#![allow(clippy::similar_names)]
use bytes::{Buf, BufMut, Bytes, BytesMut};
use std::{fmt::Debug, num::TryFromIntError};
use thiserror::Error;
use tracing::warn;
use tungstenite::Message;
/// Errors produced while encoding or decoding frames.
#[derive(Debug, Error)]
pub enum Error {
/// The leading frame-type octet was neither `1` (stream) nor `3` (datagram).
/// Carries the offending octet.
#[error("invalid frame type: {0}")]
InvalidFrameType(u8),
/// The flag octet of a stream frame did not match any [`StreamFlag`] value.
/// Carries the offending octet.
#[error("invalid stream flag: {0}")]
InvalidStreamFlag(u8),
/// A host length did not fit into the single `u8` length octet used on the wire.
/// Wraps the failed `usize` -> `u8` conversion.
#[error("host longer than 255 octets")]
HostLength(#[from] TryFromIntError),
}
/// Kind of a stream frame, stored on the wire as a single octet (`#[repr(u8)]`).
///
/// The discriminants are the on-wire values. Value `1` is not assigned to any
/// variant, so decoding it yields [`Error::InvalidStreamFlag`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u8)]
pub enum StreamFlag {
/// Wire value `0`. `StreamFrame::new_syn` attaches a host/port payload to it.
Syn = 0,
/// Wire value `2`. Built with an empty payload by `StreamFrame::new_ack`.
Ack = 2,
/// Wire value `3`. Built with an empty payload by `StreamFrame::new_rst`.
Rst = 3,
/// Wire value `4`. Built with an empty payload by `StreamFrame::new_fin`.
Fin = 4,
/// Wire value `5`. Carries caller-supplied data (`StreamFrame::new_psh`).
Psh = 5,
}
/// A stream frame.
///
/// On the wire (after the frame-type octet `1`) it is encoded as `sport` and
/// `dport` as big-endian `u16`s, the flag as one octet, and then the payload
/// bytes running to the end of the message.
#[derive(Clone, PartialEq, Eq)]
#[repr(C)]
#[allow(clippy::module_name_repetitions)]
pub struct StreamFrame {
/// First port field on the wire (presumably the sender's port — confirm with callers).
pub sport: u16,
/// Second port field on the wire (presumably the receiver's port — confirm with callers).
/// `new_syn` always sets it to `0`.
pub dport: u16,
/// Kind of this frame.
pub flag: StreamFlag,
/// Payload; empty for frames built by `new_ack`, `new_rst` and `new_fin`.
pub data: Bytes,
}
impl Debug for StreamFrame {
    /// Formats the header fields verbatim but reports only the payload's
    /// length (under the name `data.len`) instead of dumping its bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let payload_len = self.data.len();
        let mut builder = f.debug_struct("StreamFrame");
        builder.field("sport", &self.sport);
        builder.field("dport", &self.dport);
        builder.field("flag", &self.flag);
        builder.field("data.len", &payload_len);
        builder.finish()
    }
}
impl StreamFrame {
    /// Builds a [`StreamFlag::Syn`] frame whose payload names a destination.
    ///
    /// The payload is one length octet, followed by the `dest_host` bytes,
    /// followed by `dest_port` as a big-endian `u16`. The frame's `dport` is
    /// set to `0`.
    ///
    /// # Errors
    /// Returns [`Error::HostLength`] when `dest_host` is longer than 255 octets.
    pub fn new_syn(dest_host: &[u8], dest_port: u16, sport: u16) -> Result<Self, Error> {
        // length octet + host bytes + port
        let capacity =
            std::mem::size_of::<u8>() + dest_host.len() + std::mem::size_of::<u16>();
        let mut payload = BytesMut::with_capacity(capacity);
        let host_octets = u8::try_from(dest_host.len())?;
        payload.put_u8(host_octets);
        payload.extend_from_slice(dest_host);
        payload.put_u16(dest_port);
        let data = payload.freeze();
        Ok(Self {
            sport,
            dport: 0,
            flag: StreamFlag::Syn,
            data,
        })
    }

    /// Builds a [`StreamFlag::Ack`] frame with an empty payload.
    #[must_use]
    pub const fn new_ack(sport: u16, dport: u16) -> Self {
        Self {
            flag: StreamFlag::Ack,
            sport,
            dport,
            data: Bytes::new(),
        }
    }

    /// Builds a [`StreamFlag::Rst`] frame with an empty payload.
    #[must_use]
    pub const fn new_rst(sport: u16, dport: u16) -> Self {
        Self {
            flag: StreamFlag::Rst,
            sport,
            dport,
            data: Bytes::new(),
        }
    }

    /// Builds a [`StreamFlag::Fin`] frame with an empty payload.
    #[must_use]
    pub const fn new_fin(sport: u16, dport: u16) -> Self {
        Self {
            flag: StreamFlag::Fin,
            sport,
            dport,
            data: Bytes::new(),
        }
    }

    /// Builds a [`StreamFlag::Psh`] frame carrying `data` as its payload.
    #[must_use]
    pub const fn new_psh(sport: u16, dport: u16, data: Bytes) -> Self {
        Self {
            flag: StreamFlag::Psh,
            sport,
            dport,
            data,
        }
    }
}
/// A datagram frame.
///
/// On the wire (after the frame-type octet `3`) it is encoded as one host-length
/// octet, the host bytes, `port` as a big-endian `u16`, `sid` as a big-endian
/// `u32`, and then the payload bytes running to the end of the message.
#[derive(Clone, PartialEq, Eq)]
#[repr(C)]
#[allow(clippy::module_name_repetitions)]
pub struct DatagramFrame {
/// Host bytes; must be at most 255 octets long to be encodable.
pub host: Bytes,
/// Port associated with `host`.
pub port: u16,
/// 32-bit identifier (presumably a session/source id — confirm with callers).
pub sid: u32,
/// Payload.
pub data: Bytes,
}
impl Debug for DatagramFrame {
    /// Formats `host`, `port` and `sid` verbatim but reports only the
    /// payload's length (under the name `data.len`) instead of its bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let payload_len = self.data.len();
        let mut builder = f.debug_struct("DatagramFrame");
        builder.field("host", &self.host);
        builder.field("port", &self.port);
        builder.field("sid", &self.sid);
        builder.field("data.len", &payload_len);
        builder.finish()
    }
}
/// A decoded frame of either kind.
///
/// The first octet of an encoded frame selects the variant: `1` for
/// [`Frame::Stream`] and `3` for [`Frame::Datagram`].
#[derive(Clone, Debug, PartialEq, Eq)]
#[repr(C)]
pub enum Frame {
/// A stream frame (type octet `1`).
Stream(StreamFrame),
/// A datagram frame (type octet `3`).
Datagram(DatagramFrame),
}
impl TryFrom<Frame> for Vec<u8> {
    type Error = Error;

    /// Serializes a frame into its wire representation.
    ///
    /// * Stream frame: `0x01`, `sport` (big-endian `u16`), `dport` (big-endian
    ///   `u16`), flag octet, payload.
    /// * Datagram frame: `0x03`, host-length octet, host bytes, `port`
    ///   (big-endian `u16`), `sid` (big-endian `u32`), payload.
    ///
    /// The output buffer is allocated once with exactly the encoded size.
    ///
    /// # Errors
    /// Returns [`Error::HostLength`] when a datagram frame's host is longer
    /// than 255 octets. Stream frames never fail to encode.
    #[tracing::instrument(level = "trace")]
    fn try_from(frame: Frame) -> Result<Self, Self::Error> {
        match frame {
            Frame::Stream(stream) => {
                // type octet + sport + dport + flag octet + payload
                let size = std::mem::size_of::<u8>()
                    + std::mem::size_of::<u16>()
                    + std::mem::size_of::<u16>()
                    + std::mem::size_of::<StreamFlag>()
                    + stream.data.len();
                let mut encoded = Self::with_capacity(size);
                encoded.put_u8(1);
                encoded.put_u16(stream.sport);
                encoded.put_u16(stream.dport);
                encoded.put_u8(stream.flag as u8);
                encoded.extend_from_slice(&stream.data);
                Ok(encoded)
            }
            Frame::Datagram(datagram) => {
                // Reject over-long hosts before allocating anything.
                let host_len = u8::try_from(datagram.host.len())?;
                // type octet + host-length octet + host + port + sid + payload
                let size = std::mem::size_of::<u8>()
                    + std::mem::size_of::<u8>()
                    + datagram.host.len()
                    + std::mem::size_of::<u16>()
                    + std::mem::size_of::<u32>()
                    + datagram.data.len();
                let mut encoded = Self::with_capacity(size);
                encoded.put_u8(3);
                encoded.put_u8(host_len);
                encoded.extend_from_slice(&datagram.host);
                encoded.put_u16(datagram.port);
                encoded.put_u32(datagram.sid);
                encoded.extend_from_slice(&datagram.data);
                Ok(encoded)
            }
        }
    }
}
impl TryFrom<Frame> for Message {
    type Error = <Vec<u8> as TryFrom<Frame>>::Error;

    /// Encodes `frame` to bytes and wraps them in a binary WebSocket message.
    ///
    /// # Errors
    /// Propagates any error from the `Vec<u8>` encoding of the frame.
    fn try_from(frame: Frame) -> Result<Self, Self::Error> {
        Vec::<u8>::try_from(frame).map(Self::Binary)
    }
}
impl TryFrom<Bytes> for StreamFrame {
type Error = Error;
fn try_from(mut data: Bytes) -> Result<Self, Self::Error> {
let sport = data.get_u16();
let dport = data.get_u16();
let flag = match data.get_u8() {
0 => StreamFlag::Syn,
2 => StreamFlag::Ack,
3 => StreamFlag::Rst,
4 => StreamFlag::Fin,
5 => StreamFlag::Psh,
other => return Err(Error::InvalidStreamFlag(other)),
};
Ok(Self {
sport,
dport,
flag,
data,
})
}
}
impl From<Bytes> for DatagramFrame {
    /// Decodes a datagram frame from `data`, which must start right after the
    /// frame-type octet: one host-length octet, the host bytes, `port` as a
    /// big-endian `u16`, `sid` as a big-endian `u32`, and the remaining bytes
    /// as the payload.
    ///
    /// # Panics
    /// Panics if `data` is truncated (shorter than the length octet, the
    /// announced host, and the 6 bytes of `port` + `sid`), because the
    /// `Buf::get_*` readers and `Bytes::split_to` panic on out-of-range reads.
    fn from(mut data: Bytes) -> Self {
        let host_len = usize::from(data.get_u8());
        let host = data.split_to(host_len);
        let port = data.get_u16();
        let sid = data.get_u32();
        // Whatever is left in the buffer is the payload.
        let payload = data;
        Self {
            host,
            port,
            sid,
            data: payload,
        }
    }
}
impl TryFrom<Vec<u8>> for Frame {
type Error = Error;
#[tracing::instrument(skip_all, level = "trace")]
fn try_from(data: Vec<u8>) -> Result<Self, Self::Error> {
let mut data = Bytes::from(data);
let frame_type = data.get_u8();
match frame_type {
1 => Ok(Self::Stream(StreamFrame::try_from(data)?)),
3 => Ok(Self::Datagram(DatagramFrame::from(data))),
other => Err(Error::InvalidFrameType(other)),
}
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Encodes `frame` to bytes and decodes it back.
    fn round_trip(frame: &Frame) -> Frame {
        let encoded = Vec::<u8>::try_from(frame.clone()).unwrap();
        Frame::try_from(encoded).unwrap()
    }

    /// Datagram frame shared by the datagram tests.
    fn sample_datagram() -> Frame {
        Frame::Datagram(DatagramFrame {
            host: Bytes::from_static(&[1, 2, 3, 4]),
            port: 1234,
            sid: 5678,
            data: Bytes::from_static(&[1, 2, 3, 4]),
        })
    }

    /// `new_syn` builds the expected payload and the frame survives a round trip.
    #[test]
    fn test_stream_frame() {
        let syn = Frame::Stream(StreamFrame::new_syn(&[], 5678, 1234).unwrap());
        let expected = Frame::Stream(StreamFrame {
            sport: 1234,
            dport: 0,
            flag: StreamFlag::Syn,
            // zero-length host, then port 5678 = 0x162e
            data: Bytes::from_static(&[0x00, 0x16, 0x2e]),
        });
        assert_eq!(syn, expected);
        assert_eq!(syn, round_trip(&syn));
    }

    /// A datagram frame survives an encode/decode round trip.
    #[test]
    fn test_datagram_frame() {
        let datagram = sample_datagram();
        assert_eq!(datagram, round_trip(&datagram));
    }

    /// Pins the exact wire bytes of several frames (1234 = 0x04d2, 5678 = 0x162e).
    #[test]
    fn test_frame_repr() {
        let cases: Vec<(Frame, Vec<u8>)> = vec![
            (
                Frame::Stream(StreamFrame::new_rst(1234, 5678)),
                // type, sport, dport, flag
                vec![0x01, 0x04, 0xd2, 0x16, 0x2e, 0x03],
            ),
            (
                Frame::Stream(StreamFrame::new_psh(
                    1234,
                    5678,
                    Bytes::from_static(&[1, 2, 3, 4]),
                )),
                // type, sport, dport, flag, payload
                vec![0x01, 0x04, 0xd2, 0x16, 0x2e, 0x05, 0x01, 0x02, 0x03, 0x04],
            ),
            (
                Frame::Stream(StreamFrame::new_fin(5678, 1234)),
                // type, sport, dport, flag
                vec![0x01, 0x16, 0x2e, 0x04, 0xd2, 0x04],
            ),
            (
                sample_datagram(),
                // type, host length, host, port, sid, payload
                vec![
                    0x03, 0x04, 0x01, 0x02, 0x03, 0x04, 0x04, 0xd2, 0x00, 0x00, 0x16, 0x2e,
                    0x01, 0x02, 0x03, 0x04,
                ],
            ),
        ];
        for (frame, expected) in cases {
            let encoded = Vec::<u8>::try_from(frame).unwrap();
            assert_eq!(encoded, expected);
        }
    }
}