1use super::ProtocolError;
4use crate::NodeAddr;
5use std::fmt;
6
/// Message-type byte identifying a Noise IK handshake message.
///
/// Only `0x01` and `0x02` are assigned; every other byte is rejected by
/// [`HandshakeMessageType::from_byte`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum HandshakeMessageType {
    /// First Noise IK handshake message (wire byte `0x01`).
    NoiseIKMsg1 = 0x01,

    /// Second Noise IK handshake message (wire byte `0x02`).
    NoiseIKMsg2 = 0x02,
}

impl HandshakeMessageType {
    /// Decodes a wire byte, returning `None` if it is not a handshake message type.
    pub fn from_byte(b: u8) -> Option<Self> {
        match b {
            0x01 => Some(HandshakeMessageType::NoiseIKMsg1),
            0x02 => Some(HandshakeMessageType::NoiseIKMsg2),
            _ => None,
        }
    }

    /// Returns the wire byte for this message type (its `#[repr(u8)]` discriminant).
    pub fn to_byte(self) -> u8 {
        self as u8
    }

    /// Returns `true` if `b` decodes to a handshake message type.
    ///
    /// Defined in terms of [`Self::from_byte`] so the set of recognised bytes lives
    /// in exactly one place and the two functions can never disagree.
    pub fn is_handshake(b: u8) -> bool {
        Self::from_byte(b).is_some()
    }
}

impl fmt::Display for HandshakeMessageType {
    /// Writes the variant name, honouring width/fill/alignment flags (e.g. `{:<12}`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            HandshakeMessageType::NoiseIKMsg1 => "NoiseIKMsg1",
            HandshakeMessageType::NoiseIKMsg2 => "NoiseIKMsg2",
        };
        // `Formatter::pad` applies the caller's padding flags; `write!(f, "{}", name)`
        // would start a fresh format with no flags and silently ignore them.
        f.pad(name)
    }
}
58
/// Message-type byte for link-layer messages.
///
/// Assigned values are grouped by high nibble: `0x0_` session datagrams and
/// reports, `0x1_` tree announcements, `0x2_` filter announcements, `0x3_`
/// lookups, and `0x5_` disconnect/heartbeat. Unassigned bytes are rejected by
/// [`LinkMessageType::from_byte`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum LinkMessageType {
    /// Carries a [`SessionDatagram`] (see [`SessionDatagram::encode`]).
    SessionDatagram = 0x00,

    /// Sender report (`0x01`).
    SenderReport = 0x01,
    /// Receiver report (`0x02`).
    ReceiverReport = 0x02,

    /// Tree announcement (`0x10`).
    TreeAnnounce = 0x10,

    /// Filter announcement (`0x20`).
    FilterAnnounce = 0x20,

    /// Lookup request (`0x30`).
    LookupRequest = 0x30,
    /// Lookup response (`0x31`).
    LookupResponse = 0x31,

    /// Carries a [`Disconnect`] message (see [`Disconnect::encode`]).
    Disconnect = 0x50,
    /// Heartbeat (`0x51`).
    Heartbeat = 0x51,
}

impl LinkMessageType {
    /// Decodes a wire byte, returning `None` for unassigned values.
    pub fn from_byte(b: u8) -> Option<Self> {
        match b {
            0x00 => Some(LinkMessageType::SessionDatagram),
            0x01 => Some(LinkMessageType::SenderReport),
            0x02 => Some(LinkMessageType::ReceiverReport),
            0x10 => Some(LinkMessageType::TreeAnnounce),
            0x20 => Some(LinkMessageType::FilterAnnounce),
            0x30 => Some(LinkMessageType::LookupRequest),
            0x31 => Some(LinkMessageType::LookupResponse),
            0x50 => Some(LinkMessageType::Disconnect),
            0x51 => Some(LinkMessageType::Heartbeat),
            _ => None,
        }
    }

    /// Returns the wire byte for this message type (its `#[repr(u8)]` discriminant).
    pub fn to_byte(self) -> u8 {
        self as u8
    }
}

impl fmt::Display for LinkMessageType {
    /// Writes the variant name, honouring width/fill/alignment flags (e.g. `{:>16}`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            LinkMessageType::SessionDatagram => "SessionDatagram",
            LinkMessageType::SenderReport => "SenderReport",
            LinkMessageType::ReceiverReport => "ReceiverReport",
            LinkMessageType::TreeAnnounce => "TreeAnnounce",
            LinkMessageType::FilterAnnounce => "FilterAnnounce",
            LinkMessageType::LookupRequest => "LookupRequest",
            LinkMessageType::LookupResponse => "LookupResponse",
            LinkMessageType::Disconnect => "Disconnect",
            LinkMessageType::Heartbeat => "Heartbeat",
        };
        // `Formatter::pad` applies the caller's padding flags; `write!(f, "{}", name)`
        // would start a fresh format with no flags and silently ignore them.
        f.pad(name)
    }
}
143
/// Reason byte carried in a [`Disconnect`] message.
///
/// `0x00`–`0x07` are specific reasons; `0xFF` is the catch-all [`DisconnectReason::Other`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(u8)]
pub enum DisconnectReason {
    /// Shutdown (`0x00`).
    Shutdown = 0x00,
    /// Restart (`0x01`).
    Restart = 0x01,
    /// Protocol error (`0x02`).
    ProtocolError = 0x02,
    /// Transport failure (`0x03`).
    TransportFailure = 0x03,
    /// Resource exhaustion (`0x04`).
    ResourceExhaustion = 0x04,
    /// Security violation (`0x05`).
    SecurityViolation = 0x05,
    /// Configuration change (`0x06`).
    ConfigurationChange = 0x06,
    /// Timeout (`0x07`).
    Timeout = 0x07,
    /// Catch-all reason (`0xFF`); [`Disconnect::decode`] also maps unrecognised
    /// reason bytes to this variant.
    Other = 0xFF,
}

impl DisconnectReason {
    /// Decodes a reason byte, returning `None` for unassigned values.
    pub fn from_byte(b: u8) -> Option<Self> {
        match b {
            0x00 => Some(DisconnectReason::Shutdown),
            0x01 => Some(DisconnectReason::Restart),
            0x02 => Some(DisconnectReason::ProtocolError),
            0x03 => Some(DisconnectReason::TransportFailure),
            0x04 => Some(DisconnectReason::ResourceExhaustion),
            0x05 => Some(DisconnectReason::SecurityViolation),
            0x06 => Some(DisconnectReason::ConfigurationChange),
            0x07 => Some(DisconnectReason::Timeout),
            0xFF => Some(DisconnectReason::Other),
            _ => None,
        }
    }

    /// Returns the wire byte for this reason (its `#[repr(u8)]` discriminant).
    pub fn to_byte(self) -> u8 {
        self as u8
    }
}

impl fmt::Display for DisconnectReason {
    /// Writes the variant name, honouring width/fill/alignment flags (e.g. `{:^20}`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            DisconnectReason::Shutdown => "Shutdown",
            DisconnectReason::Restart => "Restart",
            DisconnectReason::ProtocolError => "ProtocolError",
            DisconnectReason::TransportFailure => "TransportFailure",
            DisconnectReason::ResourceExhaustion => "ResourceExhaustion",
            DisconnectReason::SecurityViolation => "SecurityViolation",
            DisconnectReason::ConfigurationChange => "ConfigurationChange",
            DisconnectReason::Timeout => "Timeout",
            DisconnectReason::Other => "Other",
        };
        // `Formatter::pad` applies the caller's padding flags; `write!(f, "{}", name)`
        // would start a fresh format with no flags and silently ignore them.
        f.pad(name)
    }
}
211
212#[derive(Clone, Debug)]
229pub struct Disconnect {
230 pub reason: DisconnectReason,
232}
233
234impl Disconnect {
235 pub fn new(reason: DisconnectReason) -> Self {
237 Self { reason }
238 }
239
240 pub fn encode(&self) -> [u8; 2] {
242 [LinkMessageType::Disconnect.to_byte(), self.reason.to_byte()]
243 }
244
245 pub fn decode(payload: &[u8]) -> Result<Self, ProtocolError> {
247 if payload.is_empty() {
248 return Err(ProtocolError::MessageTooShort {
249 expected: 1,
250 got: 0,
251 });
252 }
253 let reason = DisconnectReason::from_byte(payload[0]).unwrap_or(DisconnectReason::Other);
254 Ok(Self { reason })
255 }
256}
257
258#[derive(Clone, Debug)]
284pub struct SessionDatagram {
285 pub src_addr: NodeAddr,
289 pub dest_addr: NodeAddr,
291 pub ttl: u8,
293 pub path_mtu: u16,
296 pub payload: Vec<u8>,
298}
299
300pub const SESSION_DATAGRAM_HEADER_SIZE: usize = 36;
302
303impl SessionDatagram {
304 pub fn new(src_addr: NodeAddr, dest_addr: NodeAddr, payload: Vec<u8>) -> Self {
306 Self {
307 src_addr,
308 dest_addr,
309 ttl: 64,
310 path_mtu: u16::MAX,
311 payload,
312 }
313 }
314
315 pub fn with_ttl(mut self, ttl: u8) -> Self {
317 self.ttl = ttl;
318 self
319 }
320
321 pub fn with_path_mtu(mut self, path_mtu: u16) -> Self {
323 self.path_mtu = path_mtu;
324 self
325 }
326
327 pub fn decrement_ttl(&mut self) -> bool {
329 if self.ttl > 0 {
330 self.ttl -= 1;
331 true
332 } else {
333 false
334 }
335 }
336
337 pub fn can_forward(&self) -> bool {
339 self.ttl > 0
340 }
341
342 pub fn encode(&self) -> Vec<u8> {
344 let mut buf = Vec::with_capacity(SESSION_DATAGRAM_HEADER_SIZE + self.payload.len());
345 buf.push(LinkMessageType::SessionDatagram.to_byte());
346 buf.push(self.ttl);
347 buf.extend_from_slice(&self.path_mtu.to_le_bytes());
348 buf.extend_from_slice(self.src_addr.as_bytes());
349 buf.extend_from_slice(self.dest_addr.as_bytes());
350 buf.extend_from_slice(&self.payload);
351 buf
352 }
353
354 pub fn decode(payload: &[u8]) -> Result<Self, ProtocolError> {
360 let r = SessionDatagramRef::decode(payload)?;
361 Ok(Self {
362 src_addr: r.src_addr,
363 dest_addr: r.dest_addr,
364 ttl: r.ttl,
365 path_mtu: r.path_mtu,
366 payload: r.payload.to_vec(),
367 })
368 }
369}
370
371#[derive(Debug, Clone, Copy)]
377pub struct SessionDatagramRef<'a> {
378 pub src_addr: NodeAddr,
379 pub dest_addr: NodeAddr,
380 pub ttl: u8,
381 pub path_mtu: u16,
382 pub payload: &'a [u8],
383}
384
385impl<'a> SessionDatagramRef<'a> {
386 pub fn decode(buf: &'a [u8]) -> Result<Self, ProtocolError> {
389 if buf.len() < 35 {
391 return Err(ProtocolError::MessageTooShort {
392 expected: 35,
393 got: buf.len(),
394 });
395 }
396 let ttl = buf[0];
397 let path_mtu = u16::from_le_bytes([buf[1], buf[2]]);
398 let mut src_bytes = [0u8; 16];
399 src_bytes.copy_from_slice(&buf[3..19]);
400 let mut dest_bytes = [0u8; 16];
401 dest_bytes.copy_from_slice(&buf[19..35]);
402 Ok(Self {
403 src_addr: NodeAddr::from_bytes(src_bytes),
404 dest_addr: NodeAddr::from_bytes(dest_bytes),
405 ttl,
406 path_mtu,
407 payload: &buf[35..],
408 })
409 }
410
411 pub const HEADER_LEN: usize = 35;
414}
415
416#[deprecated(note = "Use LinkMessageType or SessionMessageType instead")]
418pub type MessageType = LinkMessageType;
419
#[cfg(test)]
mod tests {
    use super::*;

    // --- HandshakeMessageType --------------------------------------------

    #[test]
    fn test_handshake_message_type_roundtrip() {
        let types = [
            HandshakeMessageType::NoiseIKMsg1,
            HandshakeMessageType::NoiseIKMsg2,
        ];

        for ty in types {
            let byte = ty.to_byte();
            let restored = HandshakeMessageType::from_byte(byte);
            assert_eq!(restored, Some(ty));
        }
    }

    #[test]
    fn test_handshake_message_type_invalid() {
        assert!(HandshakeMessageType::from_byte(0x00).is_none());
        assert!(HandshakeMessageType::from_byte(0x03).is_none());
        assert!(HandshakeMessageType::from_byte(0x10).is_none());
    }

    #[test]
    fn test_handshake_message_type_is_handshake() {
        assert!(HandshakeMessageType::is_handshake(0x01));
        assert!(HandshakeMessageType::is_handshake(0x02));
        assert!(!HandshakeMessageType::is_handshake(0x00));
        assert!(!HandshakeMessageType::is_handshake(0x10));
    }

    /// `is_handshake` must accept exactly the bytes that `from_byte` decodes.
    #[test]
    fn test_handshake_is_handshake_agrees_with_from_byte() {
        for b in 0..=u8::MAX {
            assert_eq!(
                HandshakeMessageType::is_handshake(b),
                HandshakeMessageType::from_byte(b).is_some(),
                "disagreement on byte {:#04x}",
                b
            );
        }
    }

    #[test]
    fn test_handshake_message_type_display() {
        assert_eq!(HandshakeMessageType::NoiseIKMsg1.to_string(), "NoiseIKMsg1");
        assert_eq!(HandshakeMessageType::NoiseIKMsg2.to_string(), "NoiseIKMsg2");
    }

    // --- LinkMessageType -------------------------------------------------

    #[test]
    fn test_link_message_type_roundtrip() {
        // Every variant, including the report types.
        let types = [
            LinkMessageType::SessionDatagram,
            LinkMessageType::SenderReport,
            LinkMessageType::ReceiverReport,
            LinkMessageType::TreeAnnounce,
            LinkMessageType::FilterAnnounce,
            LinkMessageType::LookupRequest,
            LinkMessageType::LookupResponse,
            LinkMessageType::Disconnect,
            LinkMessageType::Heartbeat,
        ];

        for ty in types {
            let byte = ty.to_byte();
            let restored = LinkMessageType::from_byte(byte);
            assert_eq!(restored, Some(ty));
        }
    }

    #[test]
    fn test_link_message_type_invalid() {
        assert!(LinkMessageType::from_byte(0xFF).is_none());
        assert!(LinkMessageType::from_byte(0x03).is_none());
        assert!(LinkMessageType::from_byte(0x40).is_none());
    }

    #[test]
    fn test_link_message_type_display() {
        assert_eq!(LinkMessageType::SessionDatagram.to_string(), "SessionDatagram");
        assert_eq!(LinkMessageType::ReceiverReport.to_string(), "ReceiverReport");
        assert_eq!(LinkMessageType::Heartbeat.to_string(), "Heartbeat");
    }

    // --- DisconnectReason ------------------------------------------------

    #[test]
    fn test_disconnect_reason_roundtrip() {
        let reasons = [
            DisconnectReason::Shutdown,
            DisconnectReason::Restart,
            DisconnectReason::ProtocolError,
            DisconnectReason::TransportFailure,
            DisconnectReason::ResourceExhaustion,
            DisconnectReason::SecurityViolation,
            DisconnectReason::ConfigurationChange,
            DisconnectReason::Timeout,
            DisconnectReason::Other,
        ];

        for reason in reasons {
            let byte = reason.to_byte();
            let restored = DisconnectReason::from_byte(byte);
            assert_eq!(restored, Some(reason));
        }
    }

    #[test]
    fn test_disconnect_reason_unknown_byte() {
        assert!(DisconnectReason::from_byte(0x08).is_none());
        assert!(DisconnectReason::from_byte(0x80).is_none());
        assert!(DisconnectReason::from_byte(0xFE).is_none());
    }

    // --- Disconnect ------------------------------------------------------

    #[test]
    fn test_disconnect_encode_decode() {
        let msg = Disconnect::new(DisconnectReason::Shutdown);
        let encoded = msg.encode();

        assert_eq!(encoded.len(), 2);
        assert_eq!(encoded[0], 0x50); // LinkMessageType::Disconnect
        assert_eq!(encoded[1], 0x00); // DisconnectReason::Shutdown

        // decode() takes the payload that follows the message-type byte.
        let decoded = Disconnect::decode(&encoded[1..]).unwrap();
        assert_eq!(decoded.reason, DisconnectReason::Shutdown);
    }

    #[test]
    fn test_disconnect_all_reasons() {
        let reasons = [
            DisconnectReason::Shutdown,
            DisconnectReason::Restart,
            DisconnectReason::ProtocolError,
            DisconnectReason::TransportFailure,
            DisconnectReason::ResourceExhaustion,
            DisconnectReason::SecurityViolation,
            DisconnectReason::ConfigurationChange,
            DisconnectReason::Timeout,
            DisconnectReason::Other,
        ];

        for reason in reasons {
            let msg = Disconnect::new(reason);
            let encoded = msg.encode();
            let decoded = Disconnect::decode(&encoded[1..]).unwrap();
            assert_eq!(decoded.reason, reason);
        }
    }

    #[test]
    fn test_disconnect_decode_empty_payload() {
        let result = Disconnect::decode(&[]);
        assert!(result.is_err());
    }

    #[test]
    fn test_disconnect_decode_unknown_reason() {
        let decoded = Disconnect::decode(&[0x80]).unwrap();
        assert_eq!(decoded.reason, DisconnectReason::Other);
    }

    #[test]
    fn test_disconnect_decode_ignores_trailing_bytes() {
        let decoded = Disconnect::decode(&[0x07, 0xAA, 0xBB]).unwrap();
        assert_eq!(decoded.reason, DisconnectReason::Timeout);
    }

    // --- SessionDatagram / SessionDatagramRef ----------------------------

    /// Builds a node address whose first byte is `val` and the rest zero.
    fn make_node_addr(val: u8) -> NodeAddr {
        let mut bytes = [0u8; 16];
        bytes[0] = val;
        NodeAddr::from_bytes(bytes)
    }

    #[test]
    fn test_session_datagram_header_sizes_agree() {
        // The encoded header includes the message-type byte; HEADER_LEN does not.
        assert_eq!(SESSION_DATAGRAM_HEADER_SIZE, 1 + SessionDatagramRef::HEADER_LEN);
    }

    #[test]
    fn test_session_datagram_new_defaults() {
        let dg = SessionDatagram::new(make_node_addr(1), make_node_addr(2), Vec::new());
        assert_eq!(dg.ttl, 64);
        assert_eq!(dg.path_mtu, u16::MAX);
    }

    #[test]
    fn test_session_datagram_encode_decode() {
        let src = make_node_addr(0xAA);
        let dest = make_node_addr(0xBB);
        let payload = vec![0x10, 0x00, 0x05, 0x00, 1, 2, 3, 4, 5];
        let dg = SessionDatagram::new(src, dest, payload.clone())
            .with_ttl(32)
            .with_path_mtu(1280);

        let encoded = dg.encode();
        assert_eq!(encoded[0], 0x00); // LinkMessageType::SessionDatagram
        assert_eq!(encoded.len(), SESSION_DATAGRAM_HEADER_SIZE + payload.len());

        let decoded = SessionDatagram::decode(&encoded[1..]).unwrap();
        assert_eq!(decoded.src_addr, src);
        assert_eq!(decoded.dest_addr, dest);
        assert_eq!(decoded.ttl, 32);
        assert_eq!(decoded.path_mtu, 1280);
        assert_eq!(decoded.payload, payload);
    }

    #[test]
    fn test_session_datagram_path_mtu_is_little_endian() {
        let dg = SessionDatagram::new(make_node_addr(1), make_node_addr(2), Vec::new())
            .with_path_mtu(0x1234);
        let encoded = dg.encode();
        assert_eq!(&encoded[2..4], &[0x34u8, 0x12][..]);
    }

    #[test]
    fn test_session_datagram_empty_payload() {
        let dg = SessionDatagram::new(make_node_addr(1), make_node_addr(2), Vec::new());

        let encoded = dg.encode();
        assert_eq!(encoded.len(), SESSION_DATAGRAM_HEADER_SIZE);

        let decoded = SessionDatagram::decode(&encoded[1..]).unwrap();
        assert!(decoded.payload.is_empty());
    }

    #[test]
    fn test_session_datagram_decode_too_short() {
        assert!(SessionDatagram::decode(&[]).is_err());
        assert!(SessionDatagram::decode(&[0x00; 20]).is_err());
    }

    #[test]
    fn test_session_datagram_decode_header_boundary() {
        let exact = vec![0u8; SessionDatagramRef::HEADER_LEN];
        let decoded = SessionDatagram::decode(&exact).unwrap();
        assert!(decoded.payload.is_empty());
        assert!(SessionDatagram::decode(&exact[..SessionDatagramRef::HEADER_LEN - 1]).is_err());
    }

    #[test]
    fn test_session_datagram_ref_borrows_payload() {
        let payload = [9u8, 8, 7];
        let encoded =
            SessionDatagram::new(make_node_addr(3), make_node_addr(4), payload.to_vec()).encode();

        let view = SessionDatagramRef::decode(&encoded[1..]).unwrap();
        assert_eq!(view.src_addr, make_node_addr(3));
        assert_eq!(view.dest_addr, make_node_addr(4));
        assert_eq!(view.payload, &payload[..]);
        // The payload is a sub-slice of the input buffer, not a copy.
        assert_eq!(
            view.payload.as_ptr(),
            encoded[SESSION_DATAGRAM_HEADER_SIZE..].as_ptr()
        );
    }

    #[test]
    fn test_session_datagram_ttl_roundtrip() {
        for hop in [0u8, 1, 64, 128, 255] {
            let dg = SessionDatagram::new(make_node_addr(1), make_node_addr(2), vec![0x42])
                .with_ttl(hop);

            let encoded = dg.encode();
            let decoded = SessionDatagram::decode(&encoded[1..]).unwrap();
            assert_eq!(decoded.ttl, hop);
        }
    }

    #[test]
    fn test_session_datagram_decrement_ttl() {
        let mut dg =
            SessionDatagram::new(make_node_addr(1), make_node_addr(2), Vec::new()).with_ttl(1);
        assert!(dg.can_forward());
        assert!(dg.decrement_ttl());
        assert_eq!(dg.ttl, 0);
        assert!(!dg.can_forward());
        // Decrementing at zero is refused and does not wrap.
        assert!(!dg.decrement_ttl());
        assert_eq!(dg.ttl, 0);
    }
}