use std::fmt;
use std::net::SocketAddr;
use thiserror::Error;
use crate::transport::TransportAddr;
/// 16-bit connection identifier.
///
/// The byte representation produced by [`ConnectionId::to_bytes`] and
/// consumed by [`ConnectionId::from_bytes`] is big-endian.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ConnectionId(pub u16);

impl ConnectionId {
    /// Wraps a raw 16-bit value.
    pub const fn new(value: u16) -> Self {
        Self(value)
    }

    /// Returns the raw 16-bit value.
    pub const fn value(self) -> u16 {
        self.0
    }

    /// Encodes the identifier as two bytes, most significant byte first.
    pub const fn to_bytes(self) -> [u8; 2] {
        self.0.to_be_bytes()
    }

    /// Decodes an identifier from two bytes, most significant byte first.
    pub const fn from_bytes(bytes: [u8; 2]) -> Self {
        Self(u16::from_be_bytes(bytes))
    }

    /// Generates a pseudo-random identifier using only the standard library.
    ///
    /// The current wall-clock time is fed through a `SipHash` hasher built
    /// from a fresh [`std::collections::hash_map::RandomState`], which is
    /// initialised with random keys. Truncating the clock alone (the previous
    /// approach) yields predictable values and repeats whenever two calls
    /// land on the same clock tick or the platform clock is coarse.
    ///
    /// The result is **not** cryptographically secure, and any `u16`
    /// (including `0`) may be returned.
    pub fn random() -> Self {
        use std::collections::hash_map::RandomState;
        use std::hash::{BuildHasher, Hasher};
        use std::time::{SystemTime, UNIX_EPOCH};

        // A clock set before the epoch degrades to zero instead of panicking;
        // the randomly keyed hasher still provides the variation in that case.
        let nanos = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos();

        let mut hasher = RandomState::new().build_hasher();
        hasher.write_u128(nanos);
        let mixed = hasher.finish();

        // Fold all four 16-bit lanes together so every hash bit contributes.
        let folded = mixed ^ (mixed >> 16) ^ (mixed >> 32) ^ (mixed >> 48);
        Self(folded as u16)
    }
}
impl fmt::Display for ConnectionId {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "CID:{:04X}", self.0)
}
}
/// 8-bit sequence number that wraps around from 255 back to 0.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct SequenceNumber(pub u8);

impl SequenceNumber {
    /// Wraps a raw 8-bit value.
    pub const fn new(value: u8) -> Self {
        Self(value)
    }

    /// Returns the raw 8-bit value.
    pub const fn value(self) -> u8 {
        self.0
    }

    /// Returns the following sequence number; 255 is followed by 0.
    pub const fn next(self) -> Self {
        let Self(current) = self;
        Self(current.wrapping_add(1))
    }

    /// Signed wrap-around distance from `self` to `other`.
    ///
    /// The modulo-256 difference is reinterpreted as an `i8`, so the result
    /// always lies in `-128..=127`: positive when `other` is ahead of `self`,
    /// negative when it is behind.
    pub fn distance_to(self, other: Self) -> i16 {
        let forward_steps = other.0.wrapping_sub(self.0);
        i16::from(forward_steps as i8)
    }

    /// Returns `true` when `other` is between 0 and `window_size` steps ahead
    /// of `self` (both ends inclusive).
    ///
    /// Because the distance never exceeds 127, any `window_size` above 127
    /// behaves the same as 127.
    pub fn is_in_window(self, other: Self, window_size: u8) -> bool {
        let furthest = i16::from(window_size);
        (0..=furthest).contains(&self.distance_to(other))
    }
}
impl fmt::Display for SequenceNumber {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "SEQ:{}", self.0)
}
}
/// Packet kinds, each assigned its own single bit of a `u8`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum PacketType {
    /// Bit 0 (`0x01`).
    Syn = 1 << 0,
    /// Bit 1 (`0x02`).
    Ack = 1 << 1,
    /// Bit 2 (`0x04`).
    Fin = 1 << 2,
    /// Bit 3 (`0x08`).
    Reset = 1 << 3,
    /// Bit 4 (`0x10`).
    Data = 1 << 4,
    /// Bit 5 (`0x20`).
    Ping = 1 << 5,
    /// Bit 6 (`0x40`).
    Pong = 1 << 6,
}

impl PacketType {
    /// Returns the bit mask (the enum discriminant) for this packet type.
    pub const fn flag(self) -> u8 {
        let mask = self as u8;
        mask
    }
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct PacketFlags(pub u8);
impl PacketFlags {
pub const NONE: Self = Self(0);
pub const SYN: Self = Self(0x01);
pub const ACK: Self = Self(0x02);
pub const FIN: Self = Self(0x04);
pub const RST: Self = Self(0x08);
pub const DATA: Self = Self(0x10);
pub const PING: Self = Self(0x20);
pub const PONG: Self = Self(0x40);
pub const SYN_ACK: Self = Self(0x03);
pub const fn new(value: u8) -> Self {
Self(value)
}
pub const fn value(self) -> u8 {
self.0
}
pub const fn has(self, flag: PacketType) -> bool {
self.0 & (flag as u8) != 0
}
pub const fn is_syn(self) -> bool {
self.0 & 0x01 != 0
}
pub const fn is_ack(self) -> bool {
self.0 & 0x02 != 0
}
pub const fn is_fin(self) -> bool {
self.0 & 0x04 != 0
}
pub const fn is_rst(self) -> bool {
self.0 & 0x08 != 0
}
pub const fn is_data(self) -> bool {
self.0 & 0x10 != 0
}
pub const fn is_ping(self) -> bool {
self.0 & 0x20 != 0
}
pub const fn is_pong(self) -> bool {
self.0 & 0x40 != 0
}
pub const fn with(self, flag: PacketType) -> Self {
Self(self.0 | flag as u8)
}
pub const fn union(self, other: Self) -> Self {
Self(self.0 | other.0)
}
}
impl fmt::Display for PacketFlags {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut flags = Vec::new();
if self.is_syn() {
flags.push("SYN");
}
if self.is_ack() {
flags.push("ACK");
}
if self.is_fin() {
flags.push("FIN");
}
if self.is_rst() {
flags.push("RST");
}
if self.is_data() {
flags.push("DATA");
}
if self.is_ping() {
flags.push("PING");
}
if self.is_pong() {
flags.push("PONG");
}
if flags.is_empty() {
write!(f, "NONE")
} else {
write!(f, "{}", flags.join("|"))
}
}
}
/// Newtype wrapper around a [`TransportAddr`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ConstrainedAddr(TransportAddr);

impl ConstrainedAddr {
    /// Wraps `addr` without inspecting it.
    pub fn new(addr: TransportAddr) -> Self {
        Self(addr)
    }

    /// Borrows the wrapped transport address.
    pub fn transport_addr(&self) -> &TransportAddr {
        let Self(inner) = self;
        inner
    }

    /// Consumes the wrapper and returns the wrapped transport address.
    pub fn into_transport_addr(self) -> TransportAddr {
        let Self(inner) = self;
        inner
    }

    /// Returns `true` when the wrapped address is a `Ble`, `LoRa`, `Serial`
    /// or `Ax25` variant, and `false` for every other variant.
    pub fn is_constrained_transport(&self) -> bool {
        let inner = &self.0;
        matches!(
            inner,
            TransportAddr::Ble { .. }
                | TransportAddr::LoRa { .. }
                | TransportAddr::Serial { .. }
                | TransportAddr::Ax25 { .. }
        )
    }
}
impl From<TransportAddr> for ConstrainedAddr {
fn from(addr: TransportAddr) -> Self {
Self(addr)
}
}
impl From<ConstrainedAddr> for TransportAddr {
fn from(addr: ConstrainedAddr) -> Self {
addr.0
}
}
/// Wraps a socket address as the `TransportAddr::Udp` variant.
impl From<SocketAddr> for ConstrainedAddr {
    fn from(addr: SocketAddr) -> Self {
        let udp = TransportAddr::Udp(addr);
        Self(udp)
    }
}
impl fmt::Display for ConstrainedAddr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
/// Error values whose `Display` text is generated by `thiserror` from the
/// `#[error(...)]` attribute on each variant.
// NOTE(review): the code that produces each variant is not visible in this
// file; the per-variant notes below restate only what the messages say.
#[derive(Debug, Clone, Error)]
pub enum ConstrainedError {
/// A packet was shorter than required.
#[error("packet too small: expected at least {expected} bytes, got {actual}")]
PacketTooSmall {
/// Minimum number of bytes that was required.
expected: usize,
/// Number of bytes that was actually available.
actual: usize,
},
/// A header was rejected; the string carries the detail shown after
/// `invalid header: `.
#[error("invalid header: {0}")]
InvalidHeader(String),
/// No connection exists for the given identifier.
#[error("connection not found: {0}")]
ConnectionNotFound(ConnectionId),
/// A connection with the given identifier already exists.
#[error("connection already exists: {0}")]
ConnectionExists(ConnectionId),
/// A state change was not allowed.
#[error("invalid state transition from {from} to {to}")]
InvalidStateTransition {
/// Text naming the state being left.
from: String,
/// Text naming the state that was requested.
to: String,
},
/// The peer reset the connection.
#[error("connection reset by peer")]
ConnectionReset,
/// The connection timed out.
#[error("connection timed out")]
Timeout,
/// The retransmission limit was exceeded.
#[error("maximum retransmissions exceeded ({count})")]
MaxRetransmissions {
/// Count shown in the message.
// NOTE(review): presumably the number of retransmissions attempted — confirm.
count: u32,
},
/// The send buffer has no room left.
#[error("send buffer full")]
SendBufferFull,
/// The receive buffer has no room left.
#[error("receive buffer full")]
ReceiveBufferFull,
/// A transport-level failure; the string carries the detail shown after
/// `transport error: `.
#[error("transport error: {0}")]
Transport(String),
/// A sequence number fell outside the expected range.
#[error("sequence number {seq} out of window (expected {expected_min}-{expected_max})")]
SequenceOutOfWindow {
/// The offending sequence number.
seq: u8,
/// Lower bound of the expected range, as shown in the message.
expected_min: u8,
/// Upper bound of the expected range, as shown in the message.
expected_max: u8,
},
/// The connection is closed.
#[error("connection closed")]
ConnectionClosed,
}
#[cfg(test)]
mod tests {
    use super::*;

    // ---- ConstrainedAddr -------------------------------------------------

    /// Wrapping a BLE address keeps it intact and reports it as constrained.
    #[test]
    fn test_constrained_addr_from_transport() {
        let original = TransportAddr::Ble {
            device_id: [0xAA, 0xBB, 0xCC, 0xDD, 0xEE, 0xFF],
            service_uuid: None,
        };
        let wrapped = ConstrainedAddr::from(original.clone());

        assert!(wrapped.is_constrained_transport());
        assert_eq!(wrapped.transport_addr(), &original);
    }

    /// A socket address becomes the UDP variant and is not constrained.
    #[test]
    fn test_constrained_addr_from_socket() {
        let socket: SocketAddr = "127.0.0.1:8080".parse().unwrap();
        let wrapped = ConstrainedAddr::from(socket);
        let expected = TransportAddr::Udp("127.0.0.1:8080".parse().unwrap());

        assert!(!wrapped.is_constrained_transport());
        assert_eq!(wrapped.transport_addr(), &expected);
    }

    /// Converting back yields the address that was wrapped.
    #[test]
    fn test_constrained_addr_into_transport() {
        let original = TransportAddr::Ble {
            device_id: [0x11, 0x22, 0x33, 0x44, 0x55, 0x66],
            service_uuid: None,
        };
        let wrapped = ConstrainedAddr::new(original.clone());

        let unwrapped = TransportAddr::from(wrapped);
        assert_eq!(unwrapped, original);
    }

    /// BLE and LoRa count as constrained transports, UDP does not.
    #[test]
    fn test_constrained_addr_transport_detection() {
        let ble_inner = TransportAddr::Ble {
            device_id: [0; 6],
            service_uuid: None,
        };
        assert!(ConstrainedAddr::new(ble_inner).is_constrained_transport());

        let lora_inner = TransportAddr::LoRa {
            device_addr: [0; 4],
            params: crate::transport::LoRaParams::default(),
        };
        assert!(ConstrainedAddr::new(lora_inner).is_constrained_transport());

        let udp_inner = TransportAddr::Udp("0.0.0.0:0".parse().unwrap());
        assert!(!ConstrainedAddr::new(udp_inner).is_constrained_transport());
    }

    // ---- ConnectionId ----------------------------------------------------

    /// Raw value and big-endian byte form round-trip.
    #[test]
    fn test_connection_id() {
        let id = ConnectionId::new(0x1234);
        let encoded = [0x12, 0x34];

        assert_eq!(id.value(), 0x1234);
        assert_eq!(id.to_bytes(), encoded);
        assert_eq!(ConnectionId::from_bytes(encoded), id);
    }

    /// Display uses the `CID:` prefix and uppercase hex.
    #[test]
    fn test_connection_id_display() {
        let rendered = ConnectionId::new(0xABCD).to_string();
        assert_eq!(rendered, "CID:ABCD");
    }

    /// Two freshly generated identifiers are not both zero.
    #[test]
    fn test_connection_id_random() {
        let first = ConnectionId::random();
        let second = ConnectionId::random();
        assert!(!(first.value() == 0 && second.value() == 0));
    }

    // ---- SequenceNumber --------------------------------------------------

    /// `next` increments and wraps from 255 to 0.
    #[test]
    fn test_sequence_number_next() {
        let steps: [(u8, u8); 3] = [(0, 1), (254, 255), (255, 0)];
        for &(before, after) in steps.iter() {
            assert_eq!(SequenceNumber::new(before).next(), SequenceNumber::new(after));
        }
    }

    /// Distance is signed and accounts for wrap-around.
    #[test]
    fn test_sequence_number_distance() {
        let low = SequenceNumber::new(10);
        let high = SequenceNumber::new(15);
        assert_eq!(low.distance_to(high), 5);
        assert_eq!(high.distance_to(low), -5);

        // 250 -> 255 -> 0 -> 5 is eleven steps forward.
        let before_wrap = SequenceNumber::new(250);
        let after_wrap = SequenceNumber::new(5);
        assert_eq!(before_wrap.distance_to(after_wrap), 11);
    }

    /// Window membership is inclusive on both ends and excludes the past.
    #[test]
    fn test_sequence_number_in_window() {
        let base = SequenceNumber::new(100);
        let cases: [(u8, bool); 5] = [
            (100, true),
            (110, true),
            (116, true),
            (117, false),
            (99, false),
        ];
        for &(candidate, expected) in cases.iter() {
            assert_eq!(
                base.is_in_window(SequenceNumber::new(candidate), 16),
                expected
            );
        }
    }

    // ---- PacketFlags -----------------------------------------------------

    /// Adding ACK to SYN yields the predefined SYN_ACK constant.
    #[test]
    fn test_packet_flags() {
        let syn_only = PacketFlags::SYN;
        assert!(syn_only.is_syn());
        assert!(!syn_only.is_ack());

        let both = syn_only.with(PacketType::Ack);
        assert!(both.is_syn());
        assert!(both.is_ack());
        assert_eq!(both, PacketFlags::SYN_ACK);
    }

    /// Display joins flag names with `|` in a fixed order.
    #[test]
    fn test_packet_flags_display() {
        assert_eq!(PacketFlags::NONE.to_string(), "NONE");
        assert_eq!(PacketFlags::SYN.to_string(), "SYN");
        assert_eq!(PacketFlags::SYN_ACK.to_string(), "SYN|ACK");

        let data_ack = PacketFlags::DATA.with(PacketType::Ack);
        assert_eq!(data_ack.to_string(), "ACK|DATA");
    }

    /// Union sets exactly the bits present in either operand.
    #[test]
    fn test_packet_flags_union() {
        let merged = PacketFlags::SYN.union(PacketFlags::DATA);
        assert!(merged.is_syn());
        assert!(merged.is_data());
        assert!(!merged.is_ack());
    }

    // ---- ConstrainedError ------------------------------------------------

    /// Error messages include their field values.
    #[test]
    fn test_constrained_error_display() {
        let too_small = ConstrainedError::PacketTooSmall {
            expected: 5,
            actual: 3,
        };
        assert!(too_small.to_string().contains("expected at least 5 bytes"));

        let missing = ConstrainedError::ConnectionNotFound(ConnectionId::new(0x1234));
        assert!(missing.to_string().contains("CID:1234"));
    }
}