use std::time::{Duration, SystemTime, UNIX_EPOCH};
use byteorder::{BigEndian, ByteOrder};
use serde::{Deserialize, Serialize};
use crate::error::{ProtocolError, Result};
use crate::types::{SequenceNumber, SessionId};
use crate::PROTOCOL_VERSION;
use super::{HEADER_SIZE, MAX_PAYLOAD_SIZE};
/// Kind of packet, stored as a single byte in the encoded header.
///
/// The explicit discriminants are the values written at byte 1 of the header
/// (`packet_type as u8` on encode) and mapped back by `PacketType::from_u8`
/// on decode, so renumbering a variant changes the wire encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[repr(u8)]
pub enum PacketType {
// Built by `Packet::data`; the only variant for which `is_control` is false.
Data = 0,
Control = 1,
// Built by `Packet::ack` with a bincode-serialized list of sequence numbers.
Ack = 2,
Nack = 3,
// Built by `Packet::ping` with an empty payload.
Ping = 4,
// Built by `Packet::pong`; payload is the ping timestamp as 8 big-endian bytes.
Pong = 5,
Handshake = 6,
Close = 7,
Error = 8,
}
impl PacketType {
pub fn from_u8(v: u8) -> Option<Self> {
match v {
0 => Some(Self::Data),
1 => Some(Self::Control),
2 => Some(Self::Ack),
3 => Some(Self::Nack),
4 => Some(Self::Ping),
5 => Some(Self::Pong),
6 => Some(Self::Handshake),
7 => Some(Self::Close),
8 => Some(Self::Error),
_ => None,
}
}
pub fn is_reliable(self) -> bool {
matches!(self, Self::Data | Self::Control | Self::Handshake | Self::Close)
}
pub fn is_control(self) -> bool {
!matches!(self, Self::Data)
}
}
/// Set of per-packet flag bits held as a raw `u16`.
///
/// The header encoder writes these bits big-endian at bytes 2..4; the
/// individual bit masks are the `u16` constants defined on this type.
/// The derived `Default` is the all-zero bit pattern.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct PacketFlags(u16);
impl PacketFlags {
pub const NONE: Self = Self(0);
pub const NEED_ACK: u16 = 1 << 0;
pub const RETRANSMIT: u16 = 1 << 1;
pub const FRAGMENT: u16 = 1 << 2;
pub const LAST_FRAGMENT: u16 = 1 << 3;
pub const ENCRYPTED: u16 = 1 << 4;
pub const COMPRESSED: u16 = 1 << 5;
pub const PRIORITY: u16 = 1 << 6;
pub const PROBE: u16 = 1 << 7;
pub fn new(bits: u16) -> Self {
Self(bits)
}
pub fn has(self, flag: u16) -> bool {
self.0 & flag != 0
}
pub fn set(&mut self, flag: u16) {
self.0 |= flag;
}
pub fn clear(&mut self, flag: u16) {
self.0 &= !flag;
}
pub fn bits(self) -> u16 {
self.0
}
}
impl Serialize for PacketFlags {
    /// Serializes the flag set as its raw `u16` bit pattern.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Delegate to `u16`'s own impl, which emits a single `serialize_u16` call.
        let raw: u16 = self.0;
        raw.serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for PacketFlags {
fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(Self(u16::deserialize(deserializer)?))
}
}
/// Fixed-size header that precedes every packet payload.
///
/// Byte offsets below are those used by `PacketHeader::encode` /
/// `PacketHeader::decode`; multi-byte integers are written big-endian.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PacketHeader {
// Byte 0. `decode` rejects any value other than `PROTOCOL_VERSION`.
pub version: u8,
// Byte 1, stored as the enum discriminant.
pub packet_type: PacketType,
// Bytes 2..4.
pub flags: PacketFlags,
// Bytes 4..12.
pub sequence: SequenceNumber,
// Bytes 12..20. `new` fills this with microseconds since the UNIX epoch.
pub timestamp: u64,
// Bytes 20..52 (32 raw bytes).
pub session_id: SessionId,
// Bytes 52..54.
pub uplink_id: u16,
// Bytes 54..56. Length of the payload that follows the header.
pub payload_len: u16,
// Bytes 56..60. Checksum over the preceding header bytes; verified by `decode`.
pub checksum: u32,
}
impl PacketHeader {
    /// Builds a header for the current protocol version with no flags set,
    /// stamped with the current wall-clock time and a freshly computed checksum.
    ///
    /// `payload_len` is narrowed with a plain `as u16` cast, so values above
    /// `u16::MAX` wrap; callers are expected to have bounded it beforehand.
    pub fn new(
        packet_type: PacketType,
        sequence: SequenceNumber,
        session_id: SessionId,
        uplink_id: u16,
        payload_len: usize,
    ) -> Self {
        let unsealed = Self {
            version: PROTOCOL_VERSION,
            packet_type,
            flags: PacketFlags::NONE,
            sequence,
            // A clock set before the epoch is stamped as 0 instead of failing.
            timestamp: Self::unix_micros().unwrap_or(0),
            session_id,
            uplink_id,
            payload_len: payload_len as u16,
            checksum: 0,
        };
        // The checksum never covers its own field, so it can be computed last.
        let checksum = unsealed.compute_checksum();
        Self { checksum, ..unsealed }
    }

    /// Microseconds since the UNIX epoch, or `None` when the system clock
    /// reports a time before the epoch.
    fn unix_micros() -> Option<u64> {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .ok()
            .map(|since_epoch| since_epoch.as_micros() as u64)
    }

    /// Checksum of every header byte except the trailing checksum field.
    fn compute_checksum(&self) -> u32 {
        let mut covered = [0u8; HEADER_SIZE - 4];
        self.encode_without_checksum(&mut covered);
        super::checksum(&covered)
    }

    /// Writes all fields except the checksum into `buf[0..56]`.
    ///
    /// Panics if `buf` is shorter than 56 bytes; both callers pass a slice of
    /// `HEADER_SIZE - 4` bytes.
    fn encode_without_checksum(&self, buf: &mut [u8]) {
        let Self {
            version,
            packet_type,
            flags,
            sequence,
            timestamp,
            session_id,
            uplink_id,
            payload_len,
            ..
        } = self;

        // Single-byte fields.
        buf[0] = *version;
        buf[1] = *packet_type as u8;
        // Multi-byte integers are written big-endian.
        BigEndian::write_u16(&mut buf[2..4], flags.bits());
        BigEndian::write_u64(&mut buf[4..12], sequence.0);
        BigEndian::write_u64(&mut buf[12..20], *timestamp);
        buf[20..52].copy_from_slice(session_id.as_bytes());
        BigEndian::write_u16(&mut buf[52..54], *uplink_id);
        BigEndian::write_u16(&mut buf[54..56], *payload_len);
    }

    /// Serializes the header, including the stored checksum, into the start of `buf`.
    ///
    /// The stored `checksum` field is written as-is; it is not recomputed here.
    ///
    /// # Errors
    /// `ProtocolError::MalformedHeader` when `buf` is shorter than `HEADER_SIZE`.
    pub fn encode(&self, buf: &mut [u8]) -> Result<()> {
        let header_bytes = buf
            .get_mut(..HEADER_SIZE)
            .ok_or(ProtocolError::MalformedHeader)?;
        self.encode_without_checksum(&mut header_bytes[..HEADER_SIZE - 4]);
        BigEndian::write_u32(&mut header_bytes[56..60], self.checksum);
        Ok(())
    }

    /// Parses and validates a header from the start of `buf`.
    ///
    /// # Errors
    /// Checked in this order:
    /// - `ProtocolError::MalformedHeader` when `buf` is shorter than `HEADER_SIZE`;
    /// - `ProtocolError::InvalidVersion` when byte 0 is not `PROTOCOL_VERSION`;
    /// - `ProtocolError::InvalidMessageType` when byte 1 is not a known packet type;
    /// - `ProtocolError::ChecksumMismatch` when the stored checksum does not match
    ///   the one recomputed from the parsed fields.
    pub fn decode(buf: &[u8]) -> Result<Self> {
        if buf.len() < HEADER_SIZE {
            return Err(ProtocolError::MalformedHeader.into());
        }

        let version = buf[0];
        if version != PROTOCOL_VERSION {
            return Err(ProtocolError::InvalidVersion {
                expected: PROTOCOL_VERSION,
                got: version,
            }
            .into());
        }

        let packet_type = match PacketType::from_u8(buf[1]) {
            Some(kind) => kind,
            None => return Err(ProtocolError::InvalidMessageType(buf[1]).into()),
        };

        let mut session_raw = [0u8; 32];
        session_raw.copy_from_slice(&buf[20..52]);

        let parsed = Self {
            version,
            packet_type,
            flags: PacketFlags::new(BigEndian::read_u16(&buf[2..4])),
            sequence: SequenceNumber(BigEndian::read_u64(&buf[4..12])),
            timestamp: BigEndian::read_u64(&buf[12..20]),
            session_id: SessionId::new(session_raw),
            uplink_id: BigEndian::read_u16(&buf[52..54]),
            payload_len: BigEndian::read_u16(&buf[54..56]),
            checksum: BigEndian::read_u32(&buf[56..60]),
        };

        // Re-derive the checksum from the parsed fields and compare with the wire value.
        if parsed.compute_checksum() == parsed.checksum {
            Ok(parsed)
        } else {
            Err(ProtocolError::ChecksumMismatch.into())
        }
    }

    /// The header timestamp (microseconds) expressed as a `Duration`.
    pub fn timestamp_duration(&self) -> Duration {
        Duration::from_micros(self.timestamp)
    }

    /// Difference between the local wall clock and the header timestamp.
    ///
    /// Returns `None` when the local clock is before the UNIX epoch or when
    /// the header timestamp is later than the local clock.
    pub fn one_way_delay(&self) -> Option<Duration> {
        let now = Self::unix_micros()?;
        now.checked_sub(self.timestamp).map(Duration::from_micros)
    }
}
/// A header together with the payload bytes that follow it on the wire.
///
/// Both fields are public; `Packet::encode` writes the header as stored and
/// then `payload`, so code that mutates `payload` directly is responsible for
/// keeping `header.payload_len` (and the checksum) in step.
#[derive(Debug, Clone)]
pub struct Packet {
pub header: PacketHeader,
pub payload: Vec<u8>,
}
impl Packet {
    /// Builds a packet of `packet_type` around `payload`, generating a matching header.
    ///
    /// # Errors
    /// `ProtocolError::PayloadTooLarge` when the payload is longer than
    /// `MAX_PAYLOAD_SIZE` or longer than `u16::MAX` bytes, whichever is smaller.
    pub fn new(
        packet_type: PacketType,
        sequence: SequenceNumber,
        session_id: SessionId,
        uplink_id: u16,
        payload: Vec<u8>,
    ) -> Result<Self> {
        // The header stores the payload length in a `u16`, so the effective
        // limit is the smaller of the protocol maximum and what that field can
        // hold. Without the cap, a longer payload would wrap in the header.
        let max = MAX_PAYLOAD_SIZE.min(usize::from(u16::MAX));
        let size = payload.len();
        if size > max {
            return Err(ProtocolError::PayloadTooLarge { size, max }.into());
        }
        let header = PacketHeader::new(packet_type, sequence, session_id, uplink_id, size);
        Ok(Self { header, payload })
    }

    /// Builds a `PacketType::Data` packet carrying `payload`.
    ///
    /// # Errors
    /// Same as [`Packet::new`].
    pub fn data(
        sequence: SequenceNumber,
        session_id: SessionId,
        uplink_id: u16,
        payload: Vec<u8>,
    ) -> Result<Self> {
        Self::new(PacketType::Data, sequence, session_id, uplink_id, payload)
    }

    /// Builds a `PacketType::Ack` packet whose payload is the bincode
    /// serialization of `acked_sequences`.
    ///
    /// # Errors
    /// `ProtocolError::Serialization` when bincode fails, or any error from
    /// [`Packet::new`] (for example, too many sequence numbers to fit).
    pub fn ack(
        sequence: SequenceNumber,
        session_id: SessionId,
        uplink_id: u16,
        acked_sequences: &[u64],
    ) -> Result<Self> {
        let payload = bincode::serialize(acked_sequences)
            .map_err(|e| ProtocolError::Serialization(e.to_string()))?;
        Self::new(PacketType::Ack, sequence, session_id, uplink_id, payload)
    }

    /// Builds a `PacketType::Ping` packet with an empty payload.
    pub fn ping(sequence: SequenceNumber, session_id: SessionId, uplink_id: u16) -> Result<Self> {
        Self::new(PacketType::Ping, sequence, session_id, uplink_id, Vec::new())
    }

    /// Builds a `PacketType::Pong` packet whose payload is `ping_timestamp`
    /// encoded as 8 big-endian bytes.
    pub fn pong(
        sequence: SequenceNumber,
        session_id: SessionId,
        uplink_id: u16,
        ping_timestamp: u64,
    ) -> Result<Self> {
        let payload = ping_timestamp.to_be_bytes().to_vec();
        Self::new(PacketType::Pong, sequence, session_id, uplink_id, payload)
    }

    /// Serializes the packet as the header followed by the payload bytes.
    ///
    /// The buffer is sized from `self.payload.len()`; the header is written
    /// exactly as stored.
    ///
    /// # Errors
    /// Propagates any error from `PacketHeader::encode`.
    pub fn encode(&self) -> Result<Vec<u8>> {
        let mut buf = vec![0u8; HEADER_SIZE + self.payload.len()];
        self.header.encode(&mut buf)?;
        buf[HEADER_SIZE..].copy_from_slice(&self.payload);
        Ok(buf)
    }

    /// Parses a packet from `buf`.
    ///
    /// The payload is the `header.payload_len` bytes that follow the header;
    /// any bytes after that are ignored.
    ///
    /// # Errors
    /// `ProtocolError::MalformedHeader` when `buf` is shorter than the header
    /// or shorter than the header plus the declared payload length, plus any
    /// error from `PacketHeader::decode`.
    pub fn decode(buf: &[u8]) -> Result<Self> {
        if buf.len() < HEADER_SIZE {
            return Err(ProtocolError::MalformedHeader.into());
        }
        let header = PacketHeader::decode(buf)?;
        let expected_len = HEADER_SIZE + header.payload_len as usize;
        if buf.len() < expected_len {
            return Err(ProtocolError::MalformedHeader.into());
        }
        let payload = buf[HEADER_SIZE..expected_len].to_vec();
        Ok(Self { header, payload })
    }

    /// Encoded size in bytes: header plus current payload length.
    pub fn size(&self) -> usize {
        HEADER_SIZE + self.payload.len()
    }

    /// Sets `flag` on the header and recomputes the header checksum so the
    /// packet still decodes successfully.
    pub fn set_flag(&mut self, flag: u16) {
        self.header.flags.set(flag);
        self.header.checksum = self.header.compute_checksum();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A header survives an encode/decode round trip with its fields intact.
    #[test]
    fn test_header_encode_decode() {
        let original = PacketHeader::new(
            PacketType::Data,
            SequenceNumber(42),
            SessionId::generate(),
            1,
            100,
        );

        let mut wire = [0u8; HEADER_SIZE];
        original.encode(&mut wire).unwrap();
        let parsed = PacketHeader::decode(&wire).unwrap();

        assert_eq!(parsed.version, original.version);
        assert_eq!(parsed.packet_type, original.packet_type);
        assert_eq!(parsed.sequence.0, original.sequence.0);
        assert_eq!(parsed.uplink_id, original.uplink_id);
        assert_eq!(parsed.payload_len, original.payload_len);
    }

    /// A data packet round-trips with its payload and sequence number preserved.
    #[test]
    fn test_packet_encode_decode() {
        let body = b"hello world".to_vec();
        let sent = Packet::data(SequenceNumber(1), SessionId::generate(), 0, body.clone()).unwrap();

        let wire = sent.encode().unwrap();
        let received = Packet::decode(&wire).unwrap();

        assert_eq!(received.payload, body);
        assert_eq!(received.header.sequence.0, 1);
    }

    /// Corrupting a header byte makes decoding fail.
    #[test]
    fn test_checksum_validation() {
        let sent = Packet::data(SequenceNumber(1), SessionId::generate(), 0, b"test".to_vec())
            .unwrap();

        let mut wire = sent.encode().unwrap();
        // Byte 10 lies inside the sequence-number field covered by the checksum.
        wire[10] ^= 0xff;

        assert!(Packet::decode(&wire).is_err());
    }
}