use borsh::{BorshDeserialize, BorshSerialize};
use std::fmt;
/// Opaque 32-byte session identifier.
///
/// Both `Debug` and `Display` print only the first 8 bytes, hex-encoded, followed by
/// `...`, so formatted output never contains the full identifier.
#[derive(Clone, Copy, PartialEq, Eq, Hash)]
pub struct SessionId(pub [u8; 32]);

impl SessionId {
    /// Generates a new identifier from the OS random source.
    ///
    /// If `getrandom` reports an error, the bytes are filled from `rand::thread_rng()`
    /// instead.
    pub fn random() -> Self {
        let mut raw = [0u8; 32];
        let os_failed = getrandom::getrandom(&mut raw).is_err();
        if os_failed {
            let mut rng = rand::thread_rng();
            rand::RngCore::fill_bytes(&mut rng, &mut raw);
        }
        Self::from_bytes(raw)
    }

    /// Wraps raw identifier bytes without any validation.
    pub fn from_bytes(bytes: [u8; 32]) -> Self {
        SessionId(bytes)
    }

    /// Borrows the raw identifier bytes.
    pub fn as_bytes(&self) -> &[u8; 32] {
        let SessionId(inner) = self;
        inner
    }

    /// Hex encoding of the first 8 bytes; shared by the `Debug` and `Display` impls.
    fn short_hex(&self) -> String {
        hex::encode(&self.0[..8])
    }
}

impl fmt::Debug for SessionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let prefix = self.short_hex();
        write!(f, "SessionId({}...)", prefix)
    }
}

impl fmt::Display for SessionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let prefix = self.short_hex();
        write!(f, "{}...", prefix)
    }
}
/// Identifier of a stream within a session (serialized as 2 big-endian bytes).
pub type StreamId = u16;
/// Packet sequence number (serialized as 4 big-endian bytes).
pub type SequenceNumber = u32;
/// Wire-format version that `PacketHeader::new` stamps into the header's `version` byte.
pub const WIRE_VERSION: u8 = 2;

/// Errors produced while decoding wire bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireError {
    /// The input ended before a complete field could be read.
    Truncated,
}

impl fmt::Display for WireError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match *self {
            WireError::Truncated => "truncated packet",
        };
        f.write_str(text)
    }
}

impl std::error::Error for WireError {}
/// Bit set stored in the 16-bit `flags` field of a packet header.
///
/// Individual flags are `u16` associated constants so they can be combined with `|`,
/// e.g. `PacketFlags::new(PacketFlags::RELIABLE | PacketFlags::ACK)`.
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub struct PacketFlags(pub u16);

impl PacketFlags {
    // Bit assignments. These values are written verbatim (big-endian) into serialized
    // packet headers, so they must never be renumbered.
    pub const RELIABLE: u16 = 0x0001;
    pub const ACK: u16 = 0x0002;
    pub const FIN: u16 = 0x0004;
    pub const UNRELIABLE: u16 = 0x0008;
    pub const PRIORITY: u16 = 0x0010;
    pub const ENCRYPTED: u16 = 0x0020;
    pub const COMPRESSED: u16 = 0x0040;
    pub const CONTROL: u16 = 0x0080;
    pub const REKEY: u16 = 0x0100;
    pub const PATH_VALIDATION: u16 = 0x0200;
    pub const COALESCED: u16 = 0x0400;
    pub const WINDOW_UPDATE: u16 = 0x0800;

    /// A flag set with no bits set.
    pub const fn empty() -> Self {
        Self(0)
    }

    /// Wraps raw flag bits, including bits that have no named constant.
    pub const fn new(bits: u16) -> Self {
        Self(bits)
    }

    /// Returns `true` if *every* bit of `flag` is set (so `contains(0)` is always `true`).
    #[inline]
    pub const fn contains(&self, flag: u16) -> bool {
        (self.0 & flag) == flag
    }

    /// Sets every bit of `flag`, leaving other bits untouched.
    #[inline]
    pub fn set(&mut self, flag: u16) {
        self.0 |= flag;
    }

    /// Clears every bit of `flag`, leaving other bits untouched.
    #[inline]
    pub fn clear(&mut self, flag: u16) {
        self.0 &= !flag;
    }

    /// Shorthand for `self.contains(Self::RELIABLE)`.
    #[inline]
    pub const fn is_reliable(&self) -> bool {
        self.contains(Self::RELIABLE)
    }

    /// Shorthand for `self.contains(Self::ACK)`.
    #[inline]
    pub const fn is_ack(&self) -> bool {
        self.contains(Self::ACK)
    }

    /// Shorthand for `self.contains(Self::FIN)`.
    #[inline]
    pub const fn is_fin(&self) -> bool {
        self.contains(Self::FIN)
    }

    /// Shorthand for `self.contains(Self::CONTROL)`.
    #[inline]
    pub const fn is_control(&self) -> bool {
        self.contains(Self::CONTROL)
    }

    /// Shorthand for `self.contains(Self::REKEY)`.
    #[inline]
    pub const fn is_rekey(&self) -> bool {
        self.contains(Self::REKEY)
    }
}

impl fmt::Debug for PacketFlags {
    /// Formats as `PacketFlags(A|B|...)` in bit order. Bits that have no named constant
    /// are appended as a single hex value (e.g. `PacketFlags(FIN|0x1000)`) so that
    /// unexpected bits received off the wire are never hidden.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const NAMED: [(u16, &str); 12] = [
            (PacketFlags::RELIABLE, "RELIABLE"),
            (PacketFlags::ACK, "ACK"),
            (PacketFlags::FIN, "FIN"),
            (PacketFlags::UNRELIABLE, "UNRELIABLE"),
            (PacketFlags::PRIORITY, "PRIORITY"),
            (PacketFlags::ENCRYPTED, "ENCRYPTED"),
            (PacketFlags::COMPRESSED, "COMPRESSED"),
            (PacketFlags::CONTROL, "CONTROL"),
            (PacketFlags::REKEY, "REKEY"),
            (PacketFlags::PATH_VALIDATION, "PATH_VALIDATION"),
            (PacketFlags::COALESCED, "COALESCED"),
            (PacketFlags::WINDOW_UPDATE, "WINDOW_UPDATE"),
        ];

        f.write_str("PacketFlags(")?;
        // Bits not yet accounted for by a named flag.
        let mut remaining = self.0;
        let mut sep = "";
        for &(bit, name) in NAMED.iter() {
            if self.contains(bit) {
                write!(f, "{}{}", sep, name)?;
                sep = "|";
                remaining &= !bit;
            }
        }
        if remaining != 0 {
            write!(f, "{}{:#06x}", sep, remaining)?;
        }
        f.write_str(")")
    }
}
/// Fixed-size header that precedes every packet on the wire.
///
/// Serialized layout ([`PacketHeader::SIZE`] = 45 bytes, multi-byte integers big-endian):
///
/// | offset | size | field        |
/// |-------:|-----:|--------------|
/// | 0      | 1    | `version`    |
/// | 1      | 32   | `session_id` |
/// | 33     | 2    | `stream_id`  |
/// | 35     | 4    | `sequence`   |
/// | 39     | 2    | `flags`      |
/// | 41     | 2    | `ack_delay`  |
/// | 43     | 1    | `epoch`      |
/// | 44     | 1    | `path_id`    |
///
/// `#[repr(C)]` does not make the in-memory layout match this. For example, `stream_id`
/// is padded to a 2-byte boundary. Always go through [`to_wire`](Self::to_wire) and
/// [`from_wire`](Self::from_wire). The test suite notes that the serialized header is
/// used as AEAD associated data, so this layout must stay byte-stable.
#[derive(Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct PacketHeader {
    /// Wire-format version; [`PacketHeader::new`] sets it to [`WIRE_VERSION`].
    pub version: u8,
    pub session_id: SessionId,
    pub stream_id: StreamId,
    pub sequence: SequenceNumber,
    pub flags: PacketFlags,
    /// Carried verbatim on the wire. Units are not established here (TODO confirm).
    pub ack_delay: u16,
    pub epoch: u8,
    pub path_id: u8,
}

impl PacketHeader {
    /// Length in bytes of the serialized header.
    pub const SIZE: usize = 45;

    /// Creates a header stamped with [`WIRE_VERSION`]. `ack_delay`, `epoch` and
    /// `path_id` are all zero.
    pub fn new(
        session_id: SessionId,
        stream_id: StreamId,
        sequence: SequenceNumber,
        flags: PacketFlags,
    ) -> Self {
        Self {
            version: WIRE_VERSION,
            session_id,
            stream_id,
            sequence,
            flags,
            ack_delay: 0,
            epoch: 0,
            path_id: 0,
        }
    }

    /// Returns `self` with `epoch` replaced.
    pub fn with_epoch(mut self, epoch: u8) -> Self {
        self.epoch = epoch;
        self
    }

    /// Returns `self` with `path_id` replaced.
    pub fn with_path_id(mut self, path_id: u8) -> Self {
        self.path_id = path_id;
        self
    }

    /// Serializes the header into its fixed 45-byte wire form (see the type-level table).
    pub fn to_wire(&self) -> [u8; Self::SIZE] {
        let mut b = [0u8; Self::SIZE];
        b[0] = self.version;
        b[1..33].copy_from_slice(&self.session_id.0);
        b[33..35].copy_from_slice(&self.stream_id.to_be_bytes());
        b[35..39].copy_from_slice(&self.sequence.to_be_bytes());
        b[39..41].copy_from_slice(&self.flags.0.to_be_bytes());
        b[41..43].copy_from_slice(&self.ack_delay.to_be_bytes());
        b[43] = self.epoch;
        b[44] = self.path_id;
        b
    }

    /// Parses a header from the first [`SIZE`](Self::SIZE) bytes of `bytes`. Any
    /// bytes beyond that are ignored.
    ///
    /// The `version` byte is returned as-is and is not checked against
    /// [`WIRE_VERSION`]. Callers that care must compare it themselves.
    ///
    /// # Errors
    /// [`WireError::Truncated`] if `bytes` is shorter than [`SIZE`](Self::SIZE).
    pub fn from_wire(bytes: &[u8]) -> Result<Self, WireError> {
        let b = bytes.get(..Self::SIZE).ok_or(WireError::Truncated)?;
        let mut session_id = [0u8; 32];
        session_id.copy_from_slice(&b[1..33]);
        Ok(Self {
            version: b[0],
            session_id: SessionId(session_id),
            stream_id: u16::from_be_bytes([b[33], b[34]]),
            sequence: u32::from_be_bytes([b[35], b[36], b[37], b[38]]),
            flags: PacketFlags(u16::from_be_bytes([b[39], b[40]])),
            ack_delay: u16::from_be_bytes([b[41], b[42]]),
            epoch: b[43],
            path_id: b[44],
        })
    }
}

impl fmt::Debug for PacketHeader {
    /// Debug view listing every header field (including `ack_delay`, which was
    /// previously omitted).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("PacketHeader")
            .field("version", &self.version)
            .field("session", &self.session_id)
            .field("stream", &self.stream_id)
            .field("seq", &self.sequence)
            .field("flags", &self.flags)
            .field("ack_delay", &self.ack_delay)
            .field("epoch", &self.epoch)
            .field("path_id", &self.path_id)
            .finish()
    }
}
/// Reads one length-prefixed byte string (4-byte big-endian `u32` length, then data)
/// starting at `*pos`.
///
/// On success, returns a copy of the data bytes and advances `*pos` past them. On
/// failure, `*pos` is left unchanged.
///
/// # Errors
/// [`WireError::Truncated`] if the 4-byte prefix, or the data it announces, does not
/// fit in `bytes`. This includes the case where the end offset would overflow `usize`.
fn read_length_prefixed(bytes: &[u8], pos: &mut usize) -> Result<Vec<u8>, WireError> {
    let prefix_end = pos.checked_add(4).ok_or(WireError::Truncated)?;
    let prefix = bytes.get(*pos..prefix_end).ok_or(WireError::Truncated)?;
    let data_len = u32::from_be_bytes([prefix[0], prefix[1], prefix[2], prefix[3]]) as usize;
    let data_end = prefix_end.checked_add(data_len).ok_or(WireError::Truncated)?;
    let data = bytes.get(prefix_end..data_end).ok_or(WireError::Truncated)?;
    *pos = data_end;
    Ok(data.to_vec())
}
/// A complete packet: a fixed header followed by two length-prefixed byte fields.
///
/// Wire layout:
/// `header (PacketHeader::SIZE bytes) | payload_len: u32 BE | payload |
/// extensions_len: u32 BE | extensions`.
#[derive(Clone, PartialEq, Eq)]
pub struct PhantomPacket {
    pub header: PacketHeader,
    /// Payload bytes, carried after the header.
    pub payload: Vec<u8>,
    /// Extension bytes; this type carries them opaquely and does not interpret them.
    pub extensions: Vec<u8>,
}

impl PhantomPacket {
    /// Builds a packet with the given header and payload and no extensions.
    pub fn new(header: PacketHeader, payload: Vec<u8>) -> Self {
        Self {
            header,
            payload,
            extensions: Vec::new(),
        }
    }

    /// Builds an empty acknowledgement packet. Its header carries `ack_sequence` as
    /// the sequence number and only the `ACK` flag. Payload and extensions are empty.
    pub fn ack(session_id: SessionId, stream_id: StreamId, ack_sequence: SequenceNumber) -> Self {
        let header = PacketHeader::new(
            session_id,
            stream_id,
            ack_sequence,
            PacketFlags::new(PacketFlags::ACK),
        );
        Self::new(header, Vec::new())
    }

    /// Exact number of bytes that [`to_wire`](Self::to_wire) produces.
    pub fn wire_size(&self) -> usize {
        PacketHeader::SIZE + 4 + self.payload.len() + 4 + self.extensions.len()
    }

    /// Serializes the packet into a freshly allocated buffer.
    ///
    /// # Panics
    /// Panics if `payload` or `extensions` is longer than `u32::MAX` bytes. Such a
    /// length cannot be represented in its 4-byte prefix, and truncating it would emit
    /// a frame that decodes to different contents.
    pub fn to_wire(&self) -> Vec<u8> {
        let mut b = Vec::with_capacity(self.wire_size());
        b.extend_from_slice(&self.header.to_wire());
        Self::put_length_prefixed(&mut b, &self.payload, "payload");
        Self::put_length_prefixed(&mut b, &self.extensions, "extensions");
        b
    }

    /// Appends `data` to `out`, preceded by its length as a big-endian `u32`.
    /// `field` is used only in the panic message.
    fn put_length_prefixed(out: &mut Vec<u8>, data: &[u8], field: &str) {
        use std::convert::TryFrom;
        let len = u32::try_from(data.len()).unwrap_or_else(|_| {
            panic!(
                "PhantomPacket {} is {} bytes, which exceeds the u32 length prefix",
                field,
                data.len()
            )
        });
        out.extend_from_slice(&len.to_be_bytes());
        out.extend_from_slice(data);
    }

    /// Parses a packet in the layout produced by [`to_wire`](Self::to_wire).
    ///
    /// Bytes after the extensions field are ignored.
    ///
    /// # Errors
    /// [`WireError::Truncated`] in either of these cases:
    /// - the header does not fit;
    /// - a length prefix, or the data it announces, runs past the end of `bytes`.
    pub fn from_wire(bytes: &[u8]) -> Result<Self, WireError> {
        let header = PacketHeader::from_wire(bytes)?;
        let mut pos = PacketHeader::SIZE;
        let payload = read_length_prefixed(bytes, &mut pos)?;
        let extensions = read_length_prefixed(bytes, &mut pos)?;
        Ok(Self {
            header,
            payload,
            extensions,
        })
    }
}

impl fmt::Debug for PhantomPacket {
    /// Shows the header plus the lengths (not the contents) of payload and extensions.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("PhantomPacket")
            .field("header", &self.header)
            .field("payload_len", &self.payload.len())
            .field("extensions_len", &self.extensions.len())
            .finish()
    }
}
/// Kinds of control message, serialized with borsh.
///
/// `#[repr(u8)]` together with `#[borsh(use_discriminant = true)]` makes borsh encode each
/// variant as its explicit discriminant byte listed below, rather than as its
/// declaration index. The numeric values are therefore part of the wire format and must
/// never be reused or renumbered; add new variants with fresh values.
// NOTE(review): variant meanings are inferred from names only (e.g. `XAck` presumably
// acknowledges `X`). Confirm against the code that sends and handles these messages.
#[derive(Clone, Copy, Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize)]
#[borsh(use_discriminant = true)]
#[repr(u8)]
pub enum ControlMessage {
Hello = 0,
HelloAck = 1,
Resume = 2,
ResumeAck = 3,
Migrate = 4,
MigrateAck = 5,
Close = 6,
CloseAck = 7,
Ping = 8,
Pong = 9,
}
/// Transport kind used for a "leg". Presumably this is one underlying
/// connection/path of a session (TODO confirm against the scheduler code).
///
/// The borsh derives encode the variant by its declaration index, so reordering or
/// inserting variants changes the serialized form. Only append new variants at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, BorshSerialize, BorshDeserialize)]
pub enum LegType {
Kcp,
Tcp,
FakeTls,
}
impl LegType {
    /// Whether this leg type is treated as reliable. Every current variant is.
    ///
    /// The `match` is deliberately exhaustive, so adding a variant forces a decision here.
    pub fn is_reliable(&self) -> bool {
        match self {
            LegType::Kcp | LegType::Tcp | LegType::FakeTls => true,
        }
    }

    /// Whether this leg type is treated as obfuscated. Only `FakeTls` is.
    pub fn is_obfuscated(&self) -> bool {
        match self {
            LegType::FakeTls => true,
            LegType::Kcp | LegType::Tcp => false,
        }
    }
}
/// Selectable scheduling strategy. The variant names suggest a latency / throughput /
/// reliability / stealth trade-off; confirm the exact semantics with the scheduler
/// implementation.
///
/// The borsh derives encode the variant by its declaration index, so only append new
/// variants at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, BorshSerialize, BorshDeserialize)]
pub enum SchedulerMode {
LowLatency,
HighThroughput,
Reliability,
Stealth,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_session_id_random() {
        // Two independently generated 256-bit ids colliding would indicate a broken RNG path.
        let id1 = SessionId::random();
        let id2 = SessionId::random();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_packet_flags() {
        let mut flags = PacketFlags::empty();
        assert!(!flags.is_reliable());
        flags.set(PacketFlags::RELIABLE);
        assert!(flags.is_reliable());
        flags.set(PacketFlags::ENCRYPTED);
        assert!(flags.contains(PacketFlags::RELIABLE));
        assert!(flags.contains(PacketFlags::ENCRYPTED));
        flags.clear(PacketFlags::RELIABLE);
        assert!(!flags.is_reliable());
        assert!(flags.contains(PacketFlags::ENCRYPTED));
    }

    // Flag values are part of the wire format; pin them so a renumbering is caught.
    #[test]
    fn flags_bit_assignments() {
        assert_eq!(PacketFlags::RELIABLE, 0x0001);
        assert_eq!(PacketFlags::ENCRYPTED, 0x0020);
        assert_eq!(PacketFlags::CONTROL, 0x0080);
        assert_eq!(PacketFlags::REKEY, 0x0100);
        assert_eq!(PacketFlags::PATH_VALIDATION, 0x0200);
        assert_eq!(PacketFlags::COALESCED, 0x0400);
        assert_eq!(PacketFlags::WINDOW_UPDATE, 0x0800);
    }

    #[test]
    fn flags_contains_set_clear() {
        let mut f = PacketFlags::empty();
        assert!(!f.is_reliable());
        assert!(!f.is_rekey());
        f.set(PacketFlags::RELIABLE | PacketFlags::REKEY);
        assert!(f.is_reliable());
        assert!(f.is_rekey());
        f.clear(PacketFlags::REKEY);
        assert!(f.is_reliable());
        assert!(!f.is_rekey());
    }

    #[test]
    fn packet_header_serializes_to_45_bytes() {
        assert_eq!(PacketHeader::SIZE, 45);
        let header = PacketHeader::new(
            SessionId::from_bytes([0u8; 32]),
            1,
            1,
            PacketFlags::new(PacketFlags::ENCRYPTED),
        );
        let bytes = header.to_wire();
        assert_eq!(
            bytes.len(),
            PacketHeader::SIZE,
            "the serialised header (= AEAD AAD) must be exactly 45 bytes"
        );
        assert_eq!(bytes[0], WIRE_VERSION);
        assert_eq!(PacketHeader::from_wire(&bytes).expect("roundtrip"), header);
    }

    #[test]
    fn test_phantom_packet_ack() {
        let session_id = SessionId::random();
        let ack = PhantomPacket::ack(session_id, 5, 100);
        assert!(ack.header.flags.is_ack());
        assert_eq!(ack.header.stream_id, 5);
        assert_eq!(ack.header.sequence, 100);
        assert!(ack.payload.is_empty());
        assert!(ack.extensions.is_empty());
    }

    #[test]
    fn packet_roundtrip_preserves_fields() {
        let session_id = SessionId::random();
        let header = PacketHeader::new(
            session_id,
            7,
            42,
            PacketFlags::new(PacketFlags::ENCRYPTED | PacketFlags::RELIABLE),
        )
        .with_epoch(3)
        .with_path_id(1);
        let packet = PhantomPacket::new(header, vec![0xCA, 0xFE, 0xBA, 0xBE]);
        let bytes = packet.to_wire();
        let decoded = PhantomPacket::from_wire(&bytes).expect("roundtrip");
        assert_eq!(decoded, packet);
        assert_eq!(decoded.header.version, WIRE_VERSION);
        assert_eq!(decoded.header.stream_id, 7);
        assert_eq!(decoded.header.sequence, 42);
        assert_eq!(decoded.header.epoch, 3);
        assert_eq!(decoded.header.path_id, 1);
        assert!(decoded.header.flags.is_reliable());
        assert!(decoded.header.flags.contains(PacketFlags::ENCRYPTED));
        assert_eq!(decoded.payload, vec![0xCA, 0xFE, 0xBA, 0xBE]);
    }

    #[test]
    fn extensions_preserved_on_roundtrip() {
        let session_id = SessionId::random();
        let mut packet = PhantomPacket::new(
            PacketHeader::new(
                session_id,
                1,
                1,
                PacketFlags::new(PacketFlags::CONTROL | PacketFlags::RELIABLE),
            ),
            vec![1, 2, 3],
        );
        packet.extensions = vec![0xFF, 0x01, 0x00, 0x04, b't', b'e', b's', b't'];
        let bytes = packet.to_wire();
        let deser = PhantomPacket::from_wire(&bytes).expect("deserialize failed");
        assert_eq!(
            deser.extensions,
            vec![0xFF, 0x01, 0x00, 0x04, b't', b'e', b's', b't']
        );
    }

    // Every input shorter than a full header must be rejected rather than misparsed.
    #[test]
    fn header_from_wire_rejects_short_input() {
        let bytes = PacketHeader::new(
            SessionId::from_bytes([7u8; 32]),
            1,
            1,
            PacketFlags::empty(),
        )
        .to_wire();
        for len in 0..PacketHeader::SIZE {
            assert_eq!(
                PacketHeader::from_wire(&bytes[..len]),
                Err(WireError::Truncated),
                "header prefix of {} bytes",
                len
            );
        }
    }

    // The extensions field ends exactly at the end of the frame, so every strict prefix
    // of a serialized packet is incomplete.
    #[test]
    fn packet_from_wire_rejects_every_strict_prefix() {
        let mut packet = PhantomPacket::new(
            PacketHeader::new(
                SessionId::from_bytes([1u8; 32]),
                2,
                3,
                PacketFlags::new(PacketFlags::RELIABLE),
            ),
            vec![1, 2, 3],
        );
        packet.extensions = vec![4, 5];
        let bytes = packet.to_wire();
        assert_eq!(bytes.len(), packet.wire_size());
        for len in 0..bytes.len() {
            assert_eq!(
                PhantomPacket::from_wire(&bytes[..len]),
                Err(WireError::Truncated),
                "packet prefix of {} bytes",
                len
            );
        }
    }

    // A hostile length prefix must be rejected without allocating or overflowing.
    #[test]
    fn packet_from_wire_rejects_oversized_length_prefix() {
        let header = PacketHeader::new(
            SessionId::from_bytes([0u8; 32]),
            1,
            1,
            PacketFlags::empty(),
        );
        let mut bytes = header.to_wire().to_vec();
        bytes.extend_from_slice(&u32::MAX.to_be_bytes());
        bytes.extend_from_slice(&[0u8; 16]);
        assert_eq!(PhantomPacket::from_wire(&bytes), Err(WireError::Truncated));
    }

    // `ack_delay` is not exercised by any other test; make sure it survives the wire.
    #[test]
    fn ack_delay_and_empty_packet_roundtrip() {
        let mut ack = PhantomPacket::ack(SessionId::from_bytes([9u8; 32]), 3, 77);
        ack.header.ack_delay = 0xBEEF;
        let bytes = ack.to_wire();
        assert_eq!(bytes.len(), PacketHeader::SIZE + 8);
        assert_eq!(bytes.len(), ack.wire_size());
        assert_eq!(PhantomPacket::from_wire(&bytes), Ok(ack));
    }
}