use super::ProtocolError;
use crate::NodeAddr;
use std::fmt;
/// Wire identifiers for handshake messages (`NoiseIKMsg1` / `NoiseIKMsg2`).
///
/// The discriminant of each variant is its on-wire byte value (see [`HandshakeMessageType::to_byte`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum HandshakeMessageType {
    /// First handshake message, byte `0x01`.
    NoiseIKMsg1 = 0x01,
    /// Second handshake message, byte `0x02`.
    NoiseIKMsg2 = 0x02,
}

impl HandshakeMessageType {
    /// Decodes a handshake message type from its wire byte.
    ///
    /// Returns `None` for any byte that is not a known handshake message type.
    pub fn from_byte(b: u8) -> Option<Self> {
        match b {
            0x01 => Some(HandshakeMessageType::NoiseIKMsg1),
            0x02 => Some(HandshakeMessageType::NoiseIKMsg2),
            _ => None,
        }
    }

    /// Returns the wire byte for this message type.
    pub fn to_byte(self) -> u8 {
        self as u8
    }

    /// Returns `true` if `b` is the wire byte of a known handshake message type.
    ///
    /// Delegates to [`HandshakeMessageType::from_byte`] so the accepted byte set can never
    /// drift from the decoder when variants are added.
    pub fn is_handshake(b: u8) -> bool {
        Self::from_byte(b).is_some()
    }
}

impl fmt::Display for HandshakeMessageType {
    /// Writes the variant name (e.g. `NoiseIKMsg1`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            HandshakeMessageType::NoiseIKMsg1 => "NoiseIKMsg1",
            HandshakeMessageType::NoiseIKMsg2 => "NoiseIKMsg2",
        };
        f.write_str(name)
    }
}
/// Link-layer message types.
///
/// Each variant's discriminant is the byte that identifies it on the wire
/// (see [`LinkMessageType::to_byte`] / [`LinkMessageType::from_byte`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum LinkMessageType {
    SessionDatagram = 0x00,
    SenderReport = 0x01,
    ReceiverReport = 0x02,
    TreeAnnounce = 0x10,
    FilterAnnounce = 0x20,
    LookupRequest = 0x30,
    LookupResponse = 0x31,
    Disconnect = 0x50,
    Heartbeat = 0x51,
}

impl LinkMessageType {
    /// Maps a wire byte back to its message type.
    ///
    /// Returns `None` when the byte is not assigned to any link message type.
    pub fn from_byte(b: u8) -> Option<Self> {
        let ty = match b {
            0x00 => Self::SessionDatagram,
            0x01 => Self::SenderReport,
            0x02 => Self::ReceiverReport,
            0x10 => Self::TreeAnnounce,
            0x20 => Self::FilterAnnounce,
            0x30 => Self::LookupRequest,
            0x31 => Self::LookupResponse,
            0x50 => Self::Disconnect,
            0x51 => Self::Heartbeat,
            _ => return None,
        };
        Some(ty)
    }

    /// Returns the wire byte that identifies this message type.
    pub fn to_byte(self) -> u8 {
        self as u8
    }
}

impl fmt::Display for LinkMessageType {
    /// Writes the variant name (e.g. `TreeAnnounce`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::SessionDatagram => "SessionDatagram",
            Self::SenderReport => "SenderReport",
            Self::ReceiverReport => "ReceiverReport",
            Self::TreeAnnounce => "TreeAnnounce",
            Self::FilterAnnounce => "FilterAnnounce",
            Self::LookupRequest => "LookupRequest",
            Self::LookupResponse => "LookupResponse",
            Self::Disconnect => "Disconnect",
            Self::Heartbeat => "Heartbeat",
        })
    }
}
/// Reason code carried by a `Disconnect` message.
///
/// Each variant's discriminant is its on-wire byte (see [`DisconnectReason::to_byte`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum DisconnectReason {
    Shutdown = 0x00,
    Restart = 0x01,
    ProtocolError = 0x02,
    TransportFailure = 0x03,
    ResourceExhaustion = 0x04,
    SecurityViolation = 0x05,
    ConfigurationChange = 0x06,
    Timeout = 0x07,
    /// Catch-all reason, byte `0xFF`.
    Other = 0xFF,
}

impl DisconnectReason {
    /// Decodes a reason from its wire byte.
    ///
    /// Returns `None` for bytes `0x08..=0xFE`, which have no assigned reason.
    pub fn from_byte(b: u8) -> Option<Self> {
        match b {
            0x00 => Some(Self::Shutdown),
            0x01 => Some(Self::Restart),
            0x02 => Some(Self::ProtocolError),
            0x03 => Some(Self::TransportFailure),
            0x04 => Some(Self::ResourceExhaustion),
            0x05 => Some(Self::SecurityViolation),
            0x06 => Some(Self::ConfigurationChange),
            0x07 => Some(Self::Timeout),
            0xFF => Some(Self::Other),
            _ => None,
        }
    }

    /// Returns the wire byte for this reason.
    pub fn to_byte(self) -> u8 {
        self as u8
    }
}

impl fmt::Display for DisconnectReason {
    /// Writes the variant name (e.g. `Shutdown`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::Shutdown => "Shutdown",
            Self::Restart => "Restart",
            Self::ProtocolError => "ProtocolError",
            Self::TransportFailure => "TransportFailure",
            Self::ResourceExhaustion => "ResourceExhaustion",
            Self::SecurityViolation => "SecurityViolation",
            Self::ConfigurationChange => "ConfigurationChange",
            Self::Timeout => "Timeout",
            Self::Other => "Other",
        })
    }
}
/// A link-level disconnect message carrying a [`DisconnectReason`].
#[derive(Debug, Clone)]
pub struct Disconnect {
    /// Reason code sent with the message.
    pub reason: DisconnectReason,
}

impl Disconnect {
    /// Creates a disconnect message with the given reason.
    pub fn new(reason: DisconnectReason) -> Self {
        Disconnect { reason }
    }

    /// Serializes the message as `[Disconnect message-type byte, reason byte]`.
    pub fn encode(&self) -> [u8; 2] {
        let msg_type = LinkMessageType::Disconnect.to_byte();
        let reason_byte = self.reason.to_byte();
        [msg_type, reason_byte]
    }

    /// Parses a disconnect message body, i.e. the bytes following the message-type byte.
    ///
    /// Only the first byte is read; any bytes after it are ignored. A reason byte that does
    /// not map to a known [`DisconnectReason`] decodes as [`DisconnectReason::Other`].
    ///
    /// # Errors
    ///
    /// Returns `ProtocolError::MessageTooShort` when `payload` is empty.
    pub fn decode(payload: &[u8]) -> Result<Self, ProtocolError> {
        match payload.first() {
            Some(&byte) => Ok(Disconnect {
                reason: DisconnectReason::from_byte(byte).unwrap_or(DisconnectReason::Other),
            }),
            None => Err(ProtocolError::MessageTooShort {
                expected: 1,
                got: 0,
            }),
        }
    }
}
#[derive(Clone, Debug)]
pub struct SessionDatagram {
pub src_addr: NodeAddr,
pub dest_addr: NodeAddr,
pub ttl: u8,
pub path_mtu: u16,
pub payload: Vec<u8>,
}
pub const SESSION_DATAGRAM_HEADER_SIZE: usize = 36;
impl SessionDatagram {
pub fn new(src_addr: NodeAddr, dest_addr: NodeAddr, payload: Vec<u8>) -> Self {
Self {
src_addr,
dest_addr,
ttl: 64,
path_mtu: u16::MAX,
payload,
}
}
pub fn with_ttl(mut self, ttl: u8) -> Self {
self.ttl = ttl;
self
}
pub fn with_path_mtu(mut self, path_mtu: u16) -> Self {
self.path_mtu = path_mtu;
self
}
pub fn decrement_ttl(&mut self) -> bool {
if self.ttl > 0 {
self.ttl -= 1;
true
} else {
false
}
}
pub fn can_forward(&self) -> bool {
self.ttl > 0
}
pub fn encode(&self) -> Vec<u8> {
let mut buf = Vec::with_capacity(SESSION_DATAGRAM_HEADER_SIZE + self.payload.len());
buf.push(LinkMessageType::SessionDatagram.to_byte());
buf.push(self.ttl);
buf.extend_from_slice(&self.path_mtu.to_le_bytes());
buf.extend_from_slice(self.src_addr.as_bytes());
buf.extend_from_slice(self.dest_addr.as_bytes());
buf.extend_from_slice(&self.payload);
buf
}
pub fn decode(payload: &[u8]) -> Result<Self, ProtocolError> {
let r = SessionDatagramRef::decode(payload)?;
Ok(Self {
src_addr: r.src_addr,
dest_addr: r.dest_addr,
ttl: r.ttl,
path_mtu: r.path_mtu,
payload: r.payload.to_vec(),
})
}
}
/// Borrowed view of a decoded session datagram.
///
/// `payload` borrows from the buffer passed to [`SessionDatagramRef::decode`]; no copy is made.
#[derive(Debug, Clone, Copy)]
pub struct SessionDatagramRef<'a> {
    pub src_addr: NodeAddr,
    pub dest_addr: NodeAddr,
    pub ttl: u8,
    pub path_mtu: u16,
    pub payload: &'a [u8],
}

impl<'a> SessionDatagramRef<'a> {
    /// Length of the fixed header parsed by [`SessionDatagramRef::decode`].
    ///
    /// TTL (1) + path MTU (2) + source address (16) + destination address (16). It does
    /// NOT include the leading message-type byte.
    pub const HEADER_LEN: usize = 1 + 2 + 16 + 16;

    /// Parses a datagram body, i.e. the bytes following the message-type byte.
    ///
    /// Layout: `[ttl: u8][path_mtu: u16 LE][src_addr: 16 bytes][dest_addr: 16 bytes][payload...]`.
    /// Everything after the header is returned as `payload` without copying.
    ///
    /// # Errors
    ///
    /// Returns `ProtocolError::MessageTooShort` if `buf` is shorter than [`Self::HEADER_LEN`].
    pub fn decode(buf: &'a [u8]) -> Result<Self, ProtocolError> {
        // Field offsets within the header; all derived from the single HEADER_LEN definition.
        const ADDR_LEN: usize = 16;
        const SRC_START: usize = 3;
        const DEST_START: usize = SRC_START + ADDR_LEN;

        if buf.len() < Self::HEADER_LEN {
            return Err(ProtocolError::MessageTooShort {
                expected: Self::HEADER_LEN,
                got: buf.len(),
            });
        }
        let (header, payload) = buf.split_at(Self::HEADER_LEN);

        let mut src_bytes = [0u8; ADDR_LEN];
        src_bytes.copy_from_slice(&header[SRC_START..DEST_START]);
        let mut dest_bytes = [0u8; ADDR_LEN];
        // `header` is exactly HEADER_LEN bytes, so this tail is exactly ADDR_LEN bytes.
        dest_bytes.copy_from_slice(&header[DEST_START..]);

        Ok(Self {
            src_addr: NodeAddr::from_bytes(src_bytes),
            dest_addr: NodeAddr::from_bytes(dest_bytes),
            ttl: header[0],
            path_mtu: u16::from_le_bytes([header[1], header[2]]),
            payload,
        })
    }
}
#[deprecated(note = "Use LinkMessageType or SessionMessageType instead")]
pub type MessageType = LinkMessageType;
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `NodeAddr` whose first byte is `val` and whose remaining bytes are zero.
    fn make_node_addr(val: u8) -> NodeAddr {
        let mut bytes = [0u8; 16];
        bytes[0] = val;
        NodeAddr::from_bytes(bytes)
    }

    #[test]
    fn test_handshake_message_type_roundtrip() {
        let types = [
            HandshakeMessageType::NoiseIKMsg1,
            HandshakeMessageType::NoiseIKMsg2,
        ];
        for ty in types {
            let byte = ty.to_byte();
            let restored = HandshakeMessageType::from_byte(byte);
            assert_eq!(restored, Some(ty));
        }
    }

    #[test]
    fn test_handshake_message_type_invalid() {
        assert!(HandshakeMessageType::from_byte(0x00).is_none());
        assert!(HandshakeMessageType::from_byte(0x03).is_none());
        assert!(HandshakeMessageType::from_byte(0x10).is_none());
    }

    #[test]
    fn test_handshake_message_type_is_handshake() {
        assert!(HandshakeMessageType::is_handshake(0x01));
        assert!(HandshakeMessageType::is_handshake(0x02));
        assert!(!HandshakeMessageType::is_handshake(0x00));
        assert!(!HandshakeMessageType::is_handshake(0x10));
    }

    #[test]
    fn test_handshake_is_handshake_agrees_with_from_byte() {
        // `is_handshake` must accept exactly the bytes `from_byte` decodes.
        for b in 0..=u8::MAX {
            assert_eq!(
                HandshakeMessageType::is_handshake(b),
                HandshakeMessageType::from_byte(b).is_some(),
                "mismatch for byte {:#04x}",
                b
            );
        }
    }

    #[test]
    fn test_link_message_type_roundtrip() {
        // Every variant, including the report types.
        let types = [
            LinkMessageType::SessionDatagram,
            LinkMessageType::SenderReport,
            LinkMessageType::ReceiverReport,
            LinkMessageType::TreeAnnounce,
            LinkMessageType::FilterAnnounce,
            LinkMessageType::LookupRequest,
            LinkMessageType::LookupResponse,
            LinkMessageType::Disconnect,
            LinkMessageType::Heartbeat,
        ];
        for ty in types {
            let byte = ty.to_byte();
            let restored = LinkMessageType::from_byte(byte);
            assert_eq!(restored, Some(ty));
        }
    }

    #[test]
    fn test_link_message_type_invalid() {
        assert!(LinkMessageType::from_byte(0xFF).is_none());
        assert!(LinkMessageType::from_byte(0x03).is_none());
        assert!(LinkMessageType::from_byte(0x40).is_none());
    }

    #[test]
    fn test_message_type_display() {
        assert_eq!(HandshakeMessageType::NoiseIKMsg1.to_string(), "NoiseIKMsg1");
        assert_eq!(LinkMessageType::SenderReport.to_string(), "SenderReport");
        assert_eq!(LinkMessageType::Heartbeat.to_string(), "Heartbeat");
        assert_eq!(
            DisconnectReason::SecurityViolation.to_string(),
            "SecurityViolation"
        );
    }

    #[test]
    fn test_disconnect_reason_roundtrip() {
        let reasons = [
            DisconnectReason::Shutdown,
            DisconnectReason::Restart,
            DisconnectReason::ProtocolError,
            DisconnectReason::TransportFailure,
            DisconnectReason::ResourceExhaustion,
            DisconnectReason::SecurityViolation,
            DisconnectReason::ConfigurationChange,
            DisconnectReason::Timeout,
            DisconnectReason::Other,
        ];
        for reason in reasons {
            let byte = reason.to_byte();
            let restored = DisconnectReason::from_byte(byte);
            assert_eq!(restored, Some(reason));
        }
    }

    #[test]
    fn test_disconnect_reason_unknown_byte() {
        assert!(DisconnectReason::from_byte(0x08).is_none());
        assert!(DisconnectReason::from_byte(0x80).is_none());
        assert!(DisconnectReason::from_byte(0xFE).is_none());
    }

    #[test]
    fn test_disconnect_encode_decode() {
        let msg = Disconnect::new(DisconnectReason::Shutdown);
        let encoded = msg.encode();
        assert_eq!(encoded.len(), 2);
        // Message-type byte.
        assert_eq!(encoded[0], 0x50);
        // Reason byte.
        assert_eq!(encoded[1], 0x00);
        let decoded = Disconnect::decode(&encoded[1..]).unwrap();
        assert_eq!(decoded.reason, DisconnectReason::Shutdown);
    }

    #[test]
    fn test_disconnect_all_reasons() {
        let reasons = [
            DisconnectReason::Shutdown,
            DisconnectReason::Restart,
            DisconnectReason::ProtocolError,
            DisconnectReason::TransportFailure,
            DisconnectReason::ResourceExhaustion,
            DisconnectReason::SecurityViolation,
            DisconnectReason::ConfigurationChange,
            DisconnectReason::Timeout,
            DisconnectReason::Other,
        ];
        for reason in reasons {
            let msg = Disconnect::new(reason);
            let encoded = msg.encode();
            let decoded = Disconnect::decode(&encoded[1..]).unwrap();
            assert_eq!(decoded.reason, reason);
        }
    }

    #[test]
    fn test_disconnect_decode_empty_payload() {
        let result = Disconnect::decode(&[]);
        assert!(result.is_err());
        assert!(matches!(
            result,
            Err(ProtocolError::MessageTooShort {
                expected: 1,
                got: 0
            })
        ));
    }

    #[test]
    fn test_disconnect_decode_unknown_reason() {
        let decoded = Disconnect::decode(&[0x80]).unwrap();
        assert_eq!(decoded.reason, DisconnectReason::Other);
    }

    #[test]
    fn test_session_datagram_encode_decode() {
        let src = make_node_addr(0xAA);
        let dest = make_node_addr(0xBB);
        let payload = vec![0x10, 0x00, 0x05, 0x00, 1, 2, 3, 4, 5];
        let dg = SessionDatagram::new(src, dest, payload.clone())
            .with_ttl(32)
            .with_path_mtu(1280);
        let encoded = dg.encode();
        // Message-type byte for SessionDatagram.
        assert_eq!(encoded[0], 0x00);
        assert_eq!(encoded.len(), SESSION_DATAGRAM_HEADER_SIZE + payload.len());
        let decoded = SessionDatagram::decode(&encoded[1..]).unwrap();
        assert_eq!(decoded.src_addr, src);
        assert_eq!(decoded.dest_addr, dest);
        assert_eq!(decoded.ttl, 32);
        assert_eq!(decoded.path_mtu, 1280);
        assert_eq!(decoded.payload, payload);
    }

    #[test]
    fn test_session_datagram_wire_layout() {
        let src = make_node_addr(0xAA);
        let dest = make_node_addr(0xBB);
        let dg = SessionDatagram::new(src, dest, vec![0xEE])
            .with_ttl(7)
            .with_path_mtu(0x1234);
        let encoded = dg.encode();
        assert_eq!(encoded[0], LinkMessageType::SessionDatagram.to_byte());
        assert_eq!(encoded[1], 7);
        // path_mtu is encoded little-endian.
        assert_eq!(&encoded[2..4], &[0x34, 0x12]);
        assert_eq!(&encoded[4..20], &src.as_bytes()[..]);
        assert_eq!(&encoded[20..36], &dest.as_bytes()[..]);
        assert_eq!(&encoded[36..], &[0xEE]);
    }

    #[test]
    fn test_session_datagram_header_constants_agree() {
        // The encoded header is the message-type byte plus the header parsed by the ref view.
        assert_eq!(
            SESSION_DATAGRAM_HEADER_SIZE,
            SessionDatagramRef::HEADER_LEN + 1
        );
    }

    #[test]
    fn test_session_datagram_ref_decode_length_boundary() {
        let short = vec![0u8; SessionDatagramRef::HEADER_LEN - 1];
        assert!(matches!(
            SessionDatagramRef::decode(&short),
            Err(ProtocolError::MessageTooShort {
                expected: 35,
                got: 34
            })
        ));

        let exact = vec![0u8; SessionDatagramRef::HEADER_LEN];
        let r = SessionDatagramRef::decode(&exact).unwrap();
        assert_eq!(r.ttl, 0);
        assert!(r.payload.is_empty());
    }

    #[test]
    fn test_session_datagram_empty_payload() {
        let dg = SessionDatagram::new(make_node_addr(1), make_node_addr(2), Vec::new());
        let encoded = dg.encode();
        assert_eq!(encoded.len(), SESSION_DATAGRAM_HEADER_SIZE);
        let decoded = SessionDatagram::decode(&encoded[1..]).unwrap();
        assert!(decoded.payload.is_empty());
    }

    #[test]
    fn test_session_datagram_decode_too_short() {
        assert!(SessionDatagram::decode(&[]).is_err());
        assert!(SessionDatagram::decode(&[0x00; 20]).is_err());
    }

    #[test]
    fn test_session_datagram_defaults() {
        let dg = SessionDatagram::new(make_node_addr(1), make_node_addr(2), Vec::new());
        assert_eq!(dg.ttl, 64);
        assert_eq!(dg.path_mtu, u16::MAX);
        assert!(dg.can_forward());
    }

    #[test]
    fn test_session_datagram_decrement_ttl() {
        let mut dg =
            SessionDatagram::new(make_node_addr(1), make_node_addr(2), Vec::new()).with_ttl(1);
        assert!(dg.can_forward());
        assert!(dg.decrement_ttl());
        assert_eq!(dg.ttl, 0);
        assert!(!dg.can_forward());
        // Decrementing at zero must not wrap around.
        assert!(!dg.decrement_ttl());
        assert_eq!(dg.ttl, 0);
    }

    #[test]
    fn test_session_datagram_ttl_roundtrip() {
        for hop in [0u8, 1, 64, 128, 255] {
            let dg = SessionDatagram::new(make_node_addr(1), make_node_addr(2), vec![0x42])
                .with_ttl(hop);
            let encoded = dg.encode();
            let decoded = SessionDatagram::decode(&encoded[1..]).unwrap();
            assert_eq!(decoded.ttl, hop);
        }
    }
}