use std::convert::{From, TryFrom, TryInto};
use std::net::SocketAddr;
use std::ops::{Add, AddAssign, Sub, SubAssign};
use crate::connection::ConnectionID;
/// Signed number of packets.
///
/// It is also the result type of subtracting two `PacketNumber`s, and the step type
/// accepted when offsetting a `PacketNumber`.
#[derive(Copy, Clone, Debug, Hash, PartialEq, PartialOrd)]
pub struct PacketCount(pub i64);

impl Sub for PacketCount {
    type Output = Self;

    /// Returns the difference of the two wrapped values (may be negative).
    fn sub(self, other: Self) -> Self {
        let PacketCount(lhs) = self;
        let PacketCount(rhs) = other;
        PacketCount(lhs - rhs)
    }
}

impl Add for PacketCount {
    type Output = Self;

    /// Returns the sum of the two wrapped values.
    fn add(self, other: Self) -> Self {
        let PacketCount(lhs) = self;
        let PacketCount(rhs) = other;
        PacketCount(lhs + rhs)
    }
}

impl AddAssign for PacketCount {
    /// Adds `other` to this count in place.
    fn add_assign(&mut self, other: Self) {
        self.0 += other.0;
    }
}

impl SubAssign for PacketCount {
    /// Subtracts `other` from this count in place.
    fn sub_assign(&mut self, other: Self) {
        self.0 -= other.0;
    }
}
/// Unit tests for `PacketCount` construction, comparison and arithmetic.
#[cfg(test)]
mod packet_count_tests {
    use crate::packet;

    /// The wrapped integer is readable through the public tuple field.
    #[test]
    fn packet_count_value() {
        let count = packet::PacketCount(2);
        assert_eq!(count.0, 2);
    }

    /// Two counts wrapping the same value compare equal.
    #[test]
    fn packet_count_equal() {
        assert_eq!(packet::PacketCount(2), packet::PacketCount(2));
    }

    #[test]
    fn sub_packet_count() {
        let difference = packet::PacketCount(15) - packet::PacketCount(2);
        assert_eq!(difference.0, 13);
    }

    /// Subtracting a larger count yields a negative count.
    #[test]
    fn sub_packet_count_negative_value() {
        let difference = packet::PacketCount(2) - packet::PacketCount(15);
        assert_eq!(difference.0, -13);
    }

    #[test]
    fn add_packet_count() {
        let sum = packet::PacketCount(15) + packet::PacketCount(2);
        assert_eq!(sum.0, 17);
    }

    /// A negative operand is added like any other signed value.
    #[test]
    fn add_packet_count_negative_value() {
        let sum = packet::PacketCount(-2) + packet::PacketCount(15);
        assert_eq!(sum.0, 13);
    }

    #[test]
    fn add_assign_packet_count() {
        let mut total = packet::PacketCount(15);
        total += packet::PacketCount(2);
        assert_eq!(total.0, 17);
    }

    #[test]
    fn sub_assign_packet_count() {
        let mut total = packet::PacketCount(15);
        total -= packet::PacketCount(2);
        assert_eq!(total.0, 13);
    }

    /// Ordering follows the wrapped integer.
    #[test]
    fn partial_ord_packet_count() {
        let (larger, smaller) = (packet::PacketCount(15), packet::PacketCount(2));
        assert_ne!(larger, smaller);
        assert!(smaller < larger);
        assert!(larger > smaller);
    }
}
/// Packet sequence number stored in 16 bits; all arithmetic wraps around modulo 2^16.
///
/// Note that the derived `PartialOrd` compares the raw values and is not wrap-aware.
#[derive(Copy, Clone, Debug, Hash, PartialEq, PartialOrd)]
pub struct PacketNumber(pub u16);

impl PacketNumber {
    /// Size of the sequence space: a `u16` holds 2^16 distinct packet numbers.
    const SEQUENCE_SPACE: i64 = 1 << 16;

    /// Returns the packet number following this one, wrapping from `u16::MAX` to 0.
    pub fn next(self) -> Self {
        PacketNumber(self.0.wrapping_add(1))
    }

    /// Returns the packet number preceding this one, wrapping from 0 to `u16::MAX`.
    pub fn prev(self) -> Self {
        PacketNumber(self.0.wrapping_sub(1))
    }

    /// Reduces a signed packet count to the equivalent forward step in `0..=u16::MAX`.
    ///
    /// The reduction must be modulo 2^16 (the number of values a `u16` can take), not
    /// modulo `u16::MAX`: only then does it agree with the wrapping `u16` arithmetic, so
    /// that a count of -1 steps back exactly one packet and a count of 65536 is a no-op.
    fn wrapped_step(count: PacketCount) -> u16 {
        u16::try_from(count.0.rem_euclid(Self::SEQUENCE_SPACE))
            .expect("a euclidean remainder modulo 2^16 always fits in u16")
    }
}

impl Add<PacketCount> for PacketNumber {
    type Output = PacketNumber;

    /// Advances the packet number by `rhs` packets (backwards when `rhs` is negative),
    /// wrapping around the 16-bit sequence space.
    fn add(self, rhs: PacketCount) -> Self {
        PacketNumber(self.0.wrapping_add(PacketNumber::wrapped_step(rhs)))
    }
}

impl AddAssign<PacketCount> for PacketNumber {
    /// In-place form of `PacketNumber + PacketCount`.
    fn add_assign(&mut self, rhs: PacketCount) {
        *self = *self + rhs;
    }
}

impl Sub for PacketNumber {
    type Output = PacketCount;

    /// Returns the forward distance from `rhs` to `self` modulo 2^16, so the result is
    /// always in `0..=65535` even when `self` is numerically smaller than `rhs`.
    fn sub(self, rhs: Self) -> PacketCount {
        PacketCount(i64::from(self.0.wrapping_sub(rhs.0)))
    }
}

impl Sub<PacketCount> for PacketNumber {
    type Output = PacketNumber;

    /// Moves the packet number back by `rhs` packets (forwards when `rhs` is negative),
    /// wrapping around the 16-bit sequence space.
    fn sub(self, rhs: PacketCount) -> Self {
        PacketNumber(self.0.wrapping_sub(PacketNumber::wrapped_step(rhs)))
    }
}

/// Unit tests for `PacketNumber` stepping, offsetting and distance computation.
#[cfg(test)]
mod packet_number_tests {
    use crate::packet;

    #[test]
    fn packet_number_value() {
        let packet_number_1 = packet::PacketNumber(2);
        assert_eq!(packet_number_1.0, 2);
    }

    #[test]
    fn packet_number_equal() {
        let packet_number_1 = packet::PacketNumber(2);
        let packet_number_2 = packet::PacketNumber(2);
        assert_eq!(packet_number_1, packet_number_2);
    }

    #[test]
    fn packet_number_next() {
        let packet_number_1 = packet::PacketNumber(2);
        let packet_number_2 = packet::PacketNumber(3);
        assert_eq!(packet_number_1.next(), packet_number_2);
    }

    #[test]
    fn packet_number_next_overflow() {
        let packet_number_1 = packet::PacketNumber(std::u16::MAX);
        let packet_number_2 = packet::PacketNumber(std::u16::MIN);
        assert_eq!(packet_number_1.next(), packet_number_2);
    }

    #[test]
    fn packet_number_prev() {
        let packet_number_1 = packet::PacketNumber(2);
        let packet_number_2 = packet::PacketNumber(1);
        assert_eq!(packet_number_1.prev(), packet_number_2);
    }

    #[test]
    fn packet_number_prev_overflow() {
        let packet_number_1 = packet::PacketNumber(std::u16::MIN);
        let packet_number_2 = packet::PacketNumber(std::u16::MAX);
        assert_eq!(packet_number_1.prev(), packet_number_2);
    }

    #[test]
    fn sub_packet_number_and_packet_count() {
        let packet_number_1 = packet::PacketNumber(15);
        let packet_count_1 = packet::PacketCount(4);
        let packet_number_2: packet::PacketNumber = packet_number_1 - packet_count_1;
        assert_eq!(packet_number_2.0, 11);
    }

    #[test]
    fn sub_packet_number_and_packet_count_overflow() {
        // i64::MAX is congruent to 65535 (i.e. -1) modulo 2^16, so this steps forward by one.
        let packet_number_1 = packet::PacketNumber(15);
        let packet_count_1 = packet::PacketCount(std::i64::MAX);
        let packet_number_2 = packet_number_1 - packet_count_1;
        assert_eq!(packet_number_2.0, 16);
    }

    #[test]
    fn sub_packet_number_and_negative_packet_count() {
        let packet_number_1 = packet::PacketNumber(15);
        let packet_count_1 = packet::PacketCount(-4);
        let packet_number_2 = packet_number_1 - packet_count_1;
        assert_eq!(packet_number_2.0, 19);
    }

    #[test]
    fn add_packet_number_and_packet_count() {
        let packet_number_1 = packet::PacketNumber(15);
        let packet_count_1 = packet::PacketCount(4);
        let packet_number_2 = packet_number_1 + packet_count_1;
        assert_eq!(packet_number_2.0, 19);
    }

    #[test]
    fn add_packet_number_and_packet_count_overflow() {
        // i64::MAX is congruent to 65535 (i.e. -1) modulo 2^16, so this steps back by one.
        let packet_number_1 = packet::PacketNumber(15);
        let packet_count_1 = packet::PacketCount(std::i64::MAX);
        let packet_number_2 = packet_number_1 + packet_count_1;
        assert_eq!(packet_number_2.0, 14);
    }

    #[test]
    fn add_packet_number_and_negative_packet_count() {
        let packet_number_1 = packet::PacketNumber(0);
        let packet_count_1 = packet::PacketCount(-1);
        let packet_number_2 = packet_number_1 + packet_count_1;
        assert_eq!(packet_number_2, packet_number_1.prev());
        assert_eq!(packet_number_2.0, std::u16::MAX);
    }

    #[test]
    fn add_packet_number_and_full_cycle_packet_count() {
        let packet_number_1 = packet::PacketNumber(15);
        let packet_count_1 = packet::PacketCount(65536);
        let packet_number_2 = packet_number_1 + packet_count_1;
        assert_eq!(packet_number_2, packet_number_1);
    }

    #[test]
    fn add_assign_packet_number_and_packet_count() {
        let mut packet_number_1 = packet::PacketNumber(15);
        let packet_count_1 = packet::PacketCount(4);
        packet_number_1 += packet_count_1;
        assert_eq!(packet_number_1.0, 19);
    }

    #[test]
    fn sub_packet_number_and_packet_number() {
        let packet_number_1 = packet::PacketNumber(15);
        let packet_number_2 = packet::PacketNumber(4);
        let packet_count_1: packet::PacketCount = packet_number_1 - packet_number_2;
        assert_eq!(packet_count_1.0, 11);
    }

    #[test]
    fn sub_packet_number_and_packet_number_overflow() {
        let packet_number_1 = packet::PacketNumber(15);
        let packet_number_2 = packet::PacketNumber(17);
        let packet_count_1 = packet_number_1 - packet_number_2;
        assert_eq!(packet_count_1.0, 65534);
    }

    #[test]
    fn add_distance_between_packet_numbers_round_trips() {
        // Adding the wrapped distance back must land on the original packet number.
        let packet_number_1 = packet::PacketNumber(15);
        let packet_number_2 = packet::PacketNumber(17);
        let packet_count_1 = packet_number_1 - packet_number_2;
        assert_eq!(packet_number_2 + packet_count_1, packet_number_1);
    }

    #[test]
    fn partial_ord_packet_number() {
        let packet_number_1 = packet::PacketNumber(15);
        let packet_number_2 = packet::PacketNumber(2);
        assert_ne!(packet_number_1, packet_number_2);
        assert!(packet_number_2 < packet_number_1);
        assert!(packet_number_1 > packet_number_2);
    }
}
/// Packet type code, as stored in the high nibble of `PacketHeader::version_type`.
///
/// Codes 0 through 4 have named variants; any other code is preserved in `Unknown`.
#[derive(Copy, Clone, Debug)]
pub enum PacketType {
    /// Code 0.
    Data,
    /// Code 1.
    Fin,
    /// Code 2.
    State,
    /// Code 3.
    Reset,
    /// Code 4.
    Syn,
    /// Any code without a named variant, carried verbatim.
    Unknown(u8),
}

impl PacketType {
    /// Returns the numeric code of this packet type; `Unknown` yields the value it wraps.
    pub fn raw_value(self) -> u8 {
        match self {
            Self::Unknown(raw) => raw,
            Self::Syn => 4,
            Self::Reset => 3,
            Self::State => 2,
            Self::Fin => 1,
            Self::Data => 0,
        }
    }

    /// Maps a numeric code to its packet type, falling back to `Unknown(value)` for
    /// codes above 4.
    pub fn new(value: u8) -> PacketType {
        // Named variants, positioned at the index equal to their code.
        const KNOWN: [PacketType; 5] = [
            PacketType::Data,
            PacketType::Fin,
            PacketType::State,
            PacketType::Reset,
            PacketType::Syn,
        ];
        match KNOWN.get(usize::from(value)) {
            Some(&known) => known,
            None => PacketType::Unknown(value),
        }
    }
}
/// Extension type code, as stored in `PacketHeader::extension_type`.
///
/// Codes 0 and 1 have named variants; any other code is preserved in `Unknown`.
#[derive(Copy, Clone, Debug)]
pub enum PacketExtensionType {
    /// Code 0.
    None,
    /// Code 1.
    SelectiveAck,
    /// Any code without a named variant, carried verbatim.
    Unknown(u8),
}

impl PacketExtensionType {
    /// Returns the numeric code of this extension type; `Unknown` yields the value it wraps.
    pub fn raw_value(self) -> u8 {
        match self {
            Self::Unknown(raw) => raw,
            Self::SelectiveAck => 1,
            Self::None => 0,
        }
    }

    /// Maps a numeric code to its extension type, falling back to `Unknown(value)` for
    /// codes above 1.
    pub fn new(value: u8) -> Self {
        if value == 0 {
            Self::None
        } else if value == 1 {
            Self::SelectiveAck
        } else {
            Self::Unknown(value)
        }
    }
}
/// Fixed-size packet header.
///
/// `into_bytes` encodes it as 20 bytes, fields in declaration order and multi-byte fields
/// in big-endian byte order.
#[derive(PartialEq)]
pub struct PacketHeader {
    /// Packed byte: version in the low nibble, packet type code in the high nibble.
    pub version_type: u8,
    /// Raw extension type code (see `PacketExtensionType`).
    pub extension_type: u8,
    /// Encoded in bytes 2..4.
    pub connection_id: ConnectionID,
    /// Encoded in bytes 4..8. Presumably a timestamp in microseconds — TODO confirm.
    pub tv_usec: u32,
    /// Encoded in bytes 8..12. Presumably a delay measurement in microseconds — TODO confirm.
    pub reply_micro: u32,
    /// Encoded in bytes 12..16.
    pub window_size: u32,
    /// Encoded in bytes 16..18.
    pub seq_nr: PacketNumber,
    /// Encoded in bytes 18..20.
    pub ack_nr: PacketNumber,
}

impl PacketHeader {
    /// Returns the version stored in the low nibble of `version_type`.
    pub fn version(&self) -> u8 {
        self.version_type & 0x0f
    }

    /// Stores the low nibble of `v` as the version, leaving the packet type nibble untouched.
    pub fn set_version(&mut self, v: u8) {
        let type_bits = self.version_type & 0xf0;
        self.version_type = type_bits | (v & 0x0f);
    }

    /// Returns the packet type decoded from the high nibble of `version_type`.
    pub fn packet_type(&self) -> PacketType {
        let code = self.version_type >> 4;
        PacketType::new(code)
    }

    /// Stores the packet type code in the high nibble, leaving the version nibble untouched.
    ///
    /// Only the low four bits of the code survive the shift into the high nibble.
    pub fn set_packet_type(&mut self, v: PacketType) {
        let version_bits = self.version_type & 0x0f;
        self.version_type = version_bits | (v.raw_value() << 4);
    }

    /// Returns the extension type decoded from the `extension_type` byte.
    pub fn extension_type(&self) -> PacketExtensionType {
        PacketExtensionType::new(self.extension_type)
    }

    /// Replaces the `extension_type` byte with the code of `v`.
    pub fn set_extension_type(&mut self, v: PacketExtensionType) {
        self.extension_type = v.raw_value();
    }

    /// Consumes the header and encodes it into its 20-byte representation.
    pub fn into_bytes(self) -> [u8; 20] {
        let mut encoded = [0u8; 20];
        encoded[0] = self.version_type;
        encoded[1] = self.extension_type;
        // Every multi-byte field is written most significant byte first.
        encoded[2..4].copy_from_slice(&self.connection_id.0.to_be_bytes());
        encoded[4..8].copy_from_slice(&self.tv_usec.to_be_bytes());
        encoded[8..12].copy_from_slice(&self.reply_micro.to_be_bytes());
        encoded[12..16].copy_from_slice(&self.window_size.to_be_bytes());
        encoded[16..18].copy_from_slice(&self.seq_nr.0.to_be_bytes());
        encoded[18..20].copy_from_slice(&self.ack_nr.0.to_be_bytes());
        encoded
    }
}
impl TryFrom<&[u8]> for PacketHeader {
    type Error = &'static str;

    /// Decodes a header from the first 20 bytes of `bytes`; any further bytes are ignored.
    ///
    /// Multi-byte fields are read as big-endian, mirroring `PacketHeader::into_bytes`.
    ///
    /// # Errors
    ///
    /// Returns `"not enough bytes."` when fewer than 20 bytes are supplied.
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        // Reported if a field slice does not have the expected width; the fixed ranges
        // used below always do, so this is a formality.
        const SPLIT_ERROR: &str = "could not split bytes.";

        /// Reads a big-endian `u16` from a two-byte slice.
        fn be_u16(field: &[u8]) -> Result<u16, &'static str> {
            let raw: [u8; 2] = field.try_into().map_err(|_| SPLIT_ERROR)?;
            Ok(u16::from_be_bytes(raw))
        }

        /// Reads a big-endian `u32` from a four-byte slice.
        fn be_u32(field: &[u8]) -> Result<u32, &'static str> {
            let raw: [u8; 4] = field.try_into().map_err(|_| SPLIT_ERROR)?;
            Ok(u32::from_be_bytes(raw))
        }

        // The length check makes every fixed range below in-bounds.
        if bytes.len() < 20 {
            return Err("not enough bytes.");
        }

        Ok(PacketHeader {
            version_type: bytes[0],
            extension_type: bytes[1],
            connection_id: ConnectionID(be_u16(&bytes[2..4])?),
            tv_usec: be_u32(&bytes[4..8])?,
            reply_micro: be_u32(&bytes[8..12])?,
            window_size: be_u32(&bytes[12..16])?,
            seq_nr: PacketNumber(be_u16(&bytes[16..18])?),
            ack_nr: PacketNumber(be_u16(&bytes[18..20])?),
        })
    }
}
impl Into<[u8; 20]> for PacketHeader {
fn into(self) -> [u8; 20] {
self.into_bytes()
}
}
/// A packet: its decoded header, the bytes that followed the header, and the remote address.
pub struct Packet {
    pub header: PacketHeader,
    /// Bytes after the header: zero or more extension records followed by the payload.
    pub raw_content: Option<Vec<u8>>,
    pub remote_address: SocketAddr,
}

/// Upper bound applied to the selective-ack data length in `Packet::selective_ack_data`.
const MAX_SELECTIVE_ACK_DATA_BYTE_LEN: usize = 512;

impl Packet {
    /// Returns the size of `PacketHeader` plus the length of `raw_content` (0 when absent).
    ///
    /// NOTE(review): the header part is `size_of::<PacketHeader>()`, i.e. the in-memory
    /// size of the struct, which presumably coincides with the 20-byte encoded header —
    /// confirm before changing the struct's fields.
    pub fn len(&self) -> usize {
        let body_len = self.raw_content.as_ref().map_or(0, |body| body.len());
        std::mem::size_of::<PacketHeader>() + body_len
    }

    /// Returns `true` when the packet carries no payload, i.e. `content()` is `None` or empty.
    pub fn is_empty(&self) -> bool {
        self.content().map_or(true, |payload| payload.len() == 0)
    }

    /// Returns the data of the first extension record of type 1 (selective ack), if any.
    ///
    /// Each extension record is laid out as `[next extension type, data length, data...]`;
    /// the type of the first record comes from the header. Returns `None` when there is
    /// no raw content, when the chain ends without a type 1 record, or when a record
    /// would run past the end of the raw content.
    pub fn selective_ack_data(&self) -> Option<&[u8]> {
        let raw = self.raw_content.as_ref()?;
        let total = raw.len();
        let mut current_type = self.header.extension_type;
        let mut record_start = 0;

        while current_type != 0 {
            // The two-byte record prefix must be present.
            let data_start = record_start + 2;
            if data_start > total {
                return None;
            }
            let next_type = raw[record_start];
            let data_len = usize::from(raw[record_start + 1]);
            let data_end = data_start + data_len;
            if data_end > total {
                return None;
            }
            if current_type == 1 {
                // `data_len` comes from a single byte, so it cannot actually exceed the
                // limit; the guard is kept as written.
                if data_len > MAX_SELECTIVE_ACK_DATA_BYTE_LEN {
                    return None;
                }
                return Some(&raw[data_start..data_end]);
            }
            current_type = next_type;
            record_start = data_end;
        }
        None
    }

    /// Returns the payload that follows all extension records.
    ///
    /// Returns `None` when there is no raw content, when it is empty, when nothing is
    /// left after the extension records, or when a record runs past the end of the raw
    /// content. A returned slice is therefore never empty.
    pub fn content(&self) -> Option<&[u8]> {
        let raw = self.raw_content.as_ref().filter(|body| !body.is_empty())?;
        let total = raw.len();
        let mut current_type = self.header.extension_type;
        let mut cursor = 0;

        while current_type != 0 {
            // The two-byte record prefix must be present.
            if cursor + 2 > total {
                return None;
            }
            current_type = raw[cursor];
            let data_len = usize::from(raw[cursor + 1]);
            cursor += 2 + data_len;
            // Landing exactly on the end means the packet has extensions but no payload.
            if cursor >= total {
                return None;
            }
        }
        Some(&raw[cursor..])
    }
}