1use borsh::{BorshDeserialize, BorshSerialize};
10use std::fmt;
11
12#[derive(Clone, Copy, PartialEq, Eq, Hash)]
17pub struct SessionId(pub [u8; 32]);
18
19impl SessionId {
20 pub fn random() -> Self {
22 let mut bytes = [0u8; 32];
23 if getrandom::getrandom(&mut bytes).is_err() {
24 rand::RngCore::fill_bytes(&mut rand::thread_rng(), &mut bytes);
26 }
27 Self(bytes)
28 }
29
30 pub fn from_bytes(bytes: [u8; 32]) -> Self {
32 Self(bytes)
33 }
34
35 pub fn as_bytes(&self) -> &[u8; 32] {
37 &self.0
38 }
39}
40
41impl fmt::Debug for SessionId {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 write!(f, "SessionId({}...)", hex::encode(&self.0[..8]))
44 }
45}
46
47impl fmt::Display for SessionId {
48 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49 write!(f, "{}...", hex::encode(&self.0[..8]))
50 }
51}
52
/// Stream identifier; occupies a 16-bit big-endian field in the packet header.
pub type StreamId = u16;

/// Packet sequence number; occupies a 32-bit big-endian field in the packet header.
pub type SequenceNumber = u32;

/// Wire-format version that `PacketHeader::new` stamps into every header it builds.
pub const WIRE_VERSION: u8 = 2;
67
/// Errors produced while decoding packets from raw bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WireError {
    /// The input ended before a complete header or length-prefixed section
    /// could be read.
    Truncated,
}

impl fmt::Display for WireError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive match: a new variant must be given a message here.
        let message = match self {
            Self::Truncated => "truncated packet",
        };
        f.write_str(message)
    }
}

impl std::error::Error for WireError {}
89
/// Set of per-packet flag bits stored in a single `u16`.
///
/// The raw bits are public. `Default` is the empty set (all bits clear).
#[derive(Clone, Copy, PartialEq, Eq, Default)]
pub struct PacketFlags(pub u16);

impl PacketFlags {
    /// Bit 0 (`0x0001`).
    pub const RELIABLE: u16 = 1 << 0;
    /// Bit 1 (`0x0002`).
    pub const ACK: u16 = 1 << 1;
    /// Bit 2 (`0x0004`).
    pub const FIN: u16 = 1 << 2;
    /// Bit 3 (`0x0008`).
    pub const UNRELIABLE: u16 = 1 << 3;
    /// Bit 4 (`0x0010`).
    pub const PRIORITY: u16 = 1 << 4;
    /// Bit 5 (`0x0020`).
    pub const ENCRYPTED: u16 = 1 << 5;
    /// Bit 6 (`0x0040`).
    pub const COMPRESSED: u16 = 1 << 6;
    /// Bit 7 (`0x0080`).
    pub const CONTROL: u16 = 1 << 7;
    /// Bit 8 (`0x0100`).
    pub const REKEY: u16 = 1 << 8;
    /// Bit 9 (`0x0200`).
    pub const PATH_VALIDATION: u16 = 1 << 9;
    /// Bit 10 (`0x0400`).
    pub const COALESCED: u16 = 1 << 10;
    /// Bit 11 (`0x0800`).
    pub const WINDOW_UPDATE: u16 = 1 << 11;

    /// Returns a flag set with no bits set.
    pub const fn empty() -> Self {
        Self::new(0)
    }

    /// Builds a flag set from raw bits; any bit pattern is accepted.
    pub const fn new(bits: u16) -> Self {
        PacketFlags(bits)
    }

    /// Returns `true` when every bit of `flag` is set in `self`.
    ///
    /// `flag` may combine several bits; `contains(0)` is always `true`.
    #[inline]
    pub const fn contains(&self, flag: u16) -> bool {
        // No requested bit may be missing from the stored bits.
        (flag & !self.0) == 0
    }

    /// Turns on every bit of `flag`.
    #[inline]
    pub fn set(&mut self, flag: u16) {
        let Self(bits) = self;
        *bits |= flag;
    }

    /// Turns off every bit of `flag`.
    #[inline]
    pub fn clear(&mut self, flag: u16) {
        let Self(bits) = self;
        *bits &= !flag;
    }

    /// Returns `true` when the `RELIABLE` bit is set.
    #[inline]
    pub const fn is_reliable(&self) -> bool {
        (self.0 & Self::RELIABLE) != 0
    }

    /// Returns `true` when the `ACK` bit is set.
    #[inline]
    pub const fn is_ack(&self) -> bool {
        (self.0 & Self::ACK) != 0
    }

    /// Returns `true` when the `FIN` bit is set.
    #[inline]
    pub const fn is_fin(&self) -> bool {
        (self.0 & Self::FIN) != 0
    }

    /// Returns `true` when the `CONTROL` bit is set.
    #[inline]
    pub const fn is_control(&self) -> bool {
        (self.0 & Self::CONTROL) != 0
    }

    /// Returns `true` when the `REKEY` bit is set.
    #[inline]
    pub const fn is_rekey(&self) -> bool {
        (self.0 & Self::REKEY) != 0
    }
}
185
186impl fmt::Debug for PacketFlags {
187 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
188 let mut flags = Vec::new();
189 if self.contains(Self::RELIABLE) {
190 flags.push("RELIABLE");
191 }
192 if self.contains(Self::ACK) {
193 flags.push("ACK");
194 }
195 if self.contains(Self::FIN) {
196 flags.push("FIN");
197 }
198 if self.contains(Self::UNRELIABLE) {
199 flags.push("UNRELIABLE");
200 }
201 if self.contains(Self::PRIORITY) {
202 flags.push("PRIORITY");
203 }
204 if self.contains(Self::ENCRYPTED) {
205 flags.push("ENCRYPTED");
206 }
207 if self.contains(Self::COMPRESSED) {
208 flags.push("COMPRESSED");
209 }
210 if self.contains(Self::CONTROL) {
211 flags.push("CONTROL");
212 }
213 if self.contains(Self::REKEY) {
214 flags.push("REKEY");
215 }
216 if self.contains(Self::PATH_VALIDATION) {
217 flags.push("PATH_VALIDATION");
218 }
219 if self.contains(Self::COALESCED) {
220 flags.push("COALESCED");
221 }
222 if self.contains(Self::WINDOW_UPDATE) {
223 flags.push("WINDOW_UPDATE");
224 }
225 write!(f, "PacketFlags({})", flags.join("|"))
226 }
227}
228
229#[derive(Clone, Copy, PartialEq, Eq)]
251#[repr(C)]
252pub struct PacketHeader {
253 pub version: u8,
256 pub session_id: SessionId,
258 pub stream_id: StreamId,
260 pub sequence: SequenceNumber,
262 pub flags: PacketFlags,
264 pub ack_delay: u16,
266 pub epoch: u8,
269 pub path_id: u8,
271}
272
273impl PacketHeader {
274 pub const SIZE: usize = 45;
276
277 pub fn new(
280 session_id: SessionId,
281 stream_id: StreamId,
282 sequence: SequenceNumber,
283 flags: PacketFlags,
284 ) -> Self {
285 Self {
286 version: WIRE_VERSION,
287 session_id,
288 stream_id,
289 sequence,
290 flags,
291 ack_delay: 0,
292 epoch: 0,
293 path_id: 0,
294 }
295 }
296
297 pub fn with_epoch(mut self, epoch: u8) -> Self {
299 self.epoch = epoch;
300 self
301 }
302
303 pub fn with_path_id(mut self, path_id: u8) -> Self {
305 self.path_id = path_id;
306 self
307 }
308
309 pub fn to_wire(&self) -> [u8; Self::SIZE] {
312 let mut b = [0u8; Self::SIZE];
313 b[0] = self.version;
314 b[1..33].copy_from_slice(&self.session_id.0);
315 b[33..35].copy_from_slice(&self.stream_id.to_be_bytes());
316 b[35..39].copy_from_slice(&self.sequence.to_be_bytes());
317 b[39..41].copy_from_slice(&self.flags.0.to_be_bytes());
318 b[41..43].copy_from_slice(&self.ack_delay.to_be_bytes());
319 b[43] = self.epoch;
320 b[44] = self.path_id;
321 b
322 }
323
324 pub fn from_wire(bytes: &[u8]) -> Result<Self, WireError> {
328 if bytes.len() < Self::SIZE {
329 return Err(WireError::Truncated);
330 }
331 let mut session_id = [0u8; 32];
332 session_id.copy_from_slice(&bytes[1..33]);
333 Ok(Self {
334 version: bytes[0],
335 session_id: SessionId(session_id),
336 stream_id: u16::from_be_bytes([bytes[33], bytes[34]]),
337 sequence: u32::from_be_bytes([bytes[35], bytes[36], bytes[37], bytes[38]]),
338 flags: PacketFlags(u16::from_be_bytes([bytes[39], bytes[40]])),
339 ack_delay: u16::from_be_bytes([bytes[41], bytes[42]]),
340 epoch: bytes[43],
341 path_id: bytes[44],
342 })
343 }
344}
345
346impl fmt::Debug for PacketHeader {
347 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
348 f.debug_struct("PacketHeader")
349 .field("version", &self.version)
350 .field("session", &self.session_id)
351 .field("stream", &self.stream_id)
352 .field("seq", &self.sequence)
353 .field("flags", &self.flags)
354 .field("epoch", &self.epoch)
355 .field("path_id", &self.path_id)
356 .finish()
357 }
358}
359
360fn read_length_prefixed(bytes: &[u8], pos: &mut usize) -> Result<Vec<u8>, WireError> {
365 let start = *pos;
366 let len_end = start.checked_add(4).ok_or(WireError::Truncated)?;
367 if len_end > bytes.len() {
368 return Err(WireError::Truncated);
369 }
370 let len = u32::from_be_bytes([
371 bytes[start],
372 bytes[start + 1],
373 bytes[start + 2],
374 bytes[start + 3],
375 ]) as usize;
376 let data_end = len_end.checked_add(len).ok_or(WireError::Truncated)?;
377 if data_end > bytes.len() {
378 return Err(WireError::Truncated);
379 }
380 *pos = data_end;
381 Ok(bytes[len_end..data_end].to_vec())
382}
383
384#[derive(Clone, PartialEq, Eq)]
386pub struct PhantomPacket {
387 pub header: PacketHeader,
389 pub payload: Vec<u8>,
391 pub extensions: Vec<u8>,
395}
396
397impl PhantomPacket {
398 pub fn new(header: PacketHeader, payload: Vec<u8>) -> Self {
400 Self {
401 header,
402 payload,
403 extensions: Vec::new(),
404 }
405 }
406
407 pub fn ack(session_id: SessionId, stream_id: StreamId, ack_sequence: SequenceNumber) -> Self {
409 Self {
410 header: PacketHeader::new(
411 session_id,
412 stream_id,
413 ack_sequence,
414 PacketFlags::new(PacketFlags::ACK),
415 ),
416 payload: Vec::new(),
417 extensions: Vec::new(),
418 }
419 }
420
421 pub fn wire_size(&self) -> usize {
423 PacketHeader::SIZE + 8 + self.payload.len() + self.extensions.len()
424 }
425
426 pub fn to_wire(&self) -> Vec<u8> {
429 let mut b = Vec::with_capacity(self.wire_size());
430 b.extend_from_slice(&self.header.to_wire());
431 b.extend_from_slice(&(self.payload.len() as u32).to_be_bytes());
432 b.extend_from_slice(&self.payload);
433 b.extend_from_slice(&(self.extensions.len() as u32).to_be_bytes());
434 b.extend_from_slice(&self.extensions);
435 b
436 }
437
438 pub fn from_wire(bytes: &[u8]) -> Result<Self, WireError> {
441 let header = PacketHeader::from_wire(bytes)?;
442 let mut pos = PacketHeader::SIZE;
443 let payload = read_length_prefixed(bytes, &mut pos)?;
444 let extensions = read_length_prefixed(bytes, &mut pos)?;
445 Ok(Self {
446 header,
447 payload,
448 extensions,
449 })
450 }
451}
452
453impl fmt::Debug for PhantomPacket {
454 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
455 f.debug_struct("PhantomPacket")
456 .field("header", &self.header)
457 .field("payload_len", &self.payload.len())
458 .field("extensions_len", &self.extensions.len())
459 .finish()
460 }
461}
462
/// Kinds of control message.
///
/// Each variant has an explicit discriminant and the enum is `repr(u8)`.
/// `#[borsh(use_discriminant = true)]` asks borsh to encode a value using
/// that explicit discriminant, so the numbers below are part of the
/// serialized format: do not renumber or reuse them.
///
/// NOTE(review): the variants look like request/acknowledgement pairs
/// (`Hello`/`HelloAck`, `Resume`/`ResumeAck`, ...); their exact protocol
/// meaning is defined by the code that handles them — confirm there.
#[derive(Clone, Copy, Debug, PartialEq, Eq, BorshSerialize, BorshDeserialize)]
#[borsh(use_discriminant = true)]
#[repr(u8)]
pub enum ControlMessage {
    /// Discriminant 0.
    Hello = 0,
    /// Discriminant 1.
    HelloAck = 1,
    /// Discriminant 2.
    Resume = 2,
    /// Discriminant 3.
    ResumeAck = 3,
    /// Discriminant 4.
    Migrate = 4,
    /// Discriminant 5.
    MigrateAck = 5,
    /// Discriminant 6.
    Close = 6,
    /// Discriminant 7.
    CloseAck = 7,
    /// Discriminant 8.
    Ping = 8,
    /// Discriminant 9.
    Pong = 9,
}
489
/// Kind of transport leg.
///
/// The enum is borsh-serializable and has no explicit discriminants, so
/// reordering the variants is likely to change the encoding — verify before
/// doing so.
///
/// NOTE(review): the variant names suggest KCP, plain TCP, and a transport
/// disguised as TLS; confirm against the transport implementations.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, BorshSerialize, BorshDeserialize)]
pub enum LegType {
    /// Reported as reliable, not obfuscated.
    Kcp,
    /// Reported as reliable, not obfuscated.
    Tcp,
    /// Reported as reliable and obfuscated.
    FakeTls,
}
500
501impl LegType {
502 pub fn is_reliable(&self) -> bool {
504 matches!(self, LegType::Kcp | LegType::Tcp | LegType::FakeTls)
505 }
506
507 pub fn is_obfuscated(&self) -> bool {
509 matches!(self, LegType::FakeTls)
510 }
511}
512
/// Scheduling mode selector.
///
/// The enum is borsh-serializable and has no explicit discriminants, so
/// reordering the variants is likely to change the encoding — verify before
/// doing so.
///
/// NOTE(review): what each mode optimises for is decided by the scheduler
/// that consumes this value; the per-variant behaviour is not visible in
/// this file — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, BorshSerialize, BorshDeserialize)]
pub enum SchedulerMode {
    LowLatency,
    HighThroughput,
    Reliability,
    Stealth,
}
525
#[cfg(test)]
mod tests {
    //! Unit tests for the wire types: flag bit arithmetic, header and packet
    //! round-trips, the exact big-endian header layout, and the decode error
    //! paths for truncated or malformed input.

    use super::*;

    /// Two independently generated identifiers must differ.
    #[test]
    fn test_session_id_random() {
        let id1 = SessionId::random();
        let id2 = SessionId::random();
        assert_ne!(id1, id2);
    }

    /// Setting and clearing one flag must not disturb the others.
    #[test]
    fn test_packet_flags() {
        let mut flags = PacketFlags::empty();
        assert!(!flags.is_reliable());

        flags.set(PacketFlags::RELIABLE);
        assert!(flags.is_reliable());

        flags.set(PacketFlags::ENCRYPTED);
        assert!(flags.contains(PacketFlags::RELIABLE));
        assert!(flags.contains(PacketFlags::ENCRYPTED));

        flags.clear(PacketFlags::RELIABLE);
        assert!(!flags.is_reliable());
        assert!(flags.contains(PacketFlags::ENCRYPTED));
    }

    /// The flag bit values are part of the wire format and must not drift.
    #[test]
    fn flags_bit_assignments() {
        assert_eq!(PacketFlags::RELIABLE, 0x0001);
        assert_eq!(PacketFlags::ENCRYPTED, 0x0020);
        assert_eq!(PacketFlags::CONTROL, 0x0080);
        assert_eq!(PacketFlags::REKEY, 0x0100);
        assert_eq!(PacketFlags::PATH_VALIDATION, 0x0200);
        assert_eq!(PacketFlags::COALESCED, 0x0400);
        assert_eq!(PacketFlags::WINDOW_UPDATE, 0x0800);
    }

    /// `set` and `clear` accept combined masks.
    #[test]
    fn flags_contains_set_clear() {
        let mut f = PacketFlags::empty();
        assert!(!f.is_reliable());
        assert!(!f.is_rekey());
        f.set(PacketFlags::RELIABLE | PacketFlags::REKEY);
        assert!(f.is_reliable());
        assert!(f.is_rekey());
        f.clear(PacketFlags::REKEY);
        assert!(f.is_reliable());
        assert!(!f.is_rekey());
    }

    #[test]
    fn packet_header_serializes_to_45_bytes() {
        assert_eq!(PacketHeader::SIZE, 45);
        let header = PacketHeader::new(
            SessionId::from_bytes([0u8; 32]),
            1,
            1,
            PacketFlags::new(PacketFlags::ENCRYPTED),
        );
        let bytes = header.to_wire();
        assert_eq!(
            bytes.len(),
            PacketHeader::SIZE,
            "the serialised header (= AEAD AAD) must be exactly 45 bytes"
        );
        assert_eq!(bytes[0], WIRE_VERSION);
        assert_eq!(PacketHeader::from_wire(&bytes).expect("roundtrip"), header);
    }

    /// Pins the byte offset and big-endian encoding of every header field.
    #[test]
    fn header_wire_layout_is_big_endian() {
        let mut header = PacketHeader::new(
            SessionId::from_bytes([0xAB; 32]),
            0x0102,
            0x0304_0506,
            PacketFlags::new(0x0708),
        )
        .with_epoch(0x0B)
        .with_path_id(0x0C);
        header.ack_delay = 0x090A;

        let bytes = header.to_wire();
        assert_eq!(bytes[0], WIRE_VERSION);
        assert!(bytes[1..33].iter().all(|&b| b == 0xAB));
        assert_eq!(
            bytes[33..].to_vec(),
            vec![0x01u8, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A, 0x0B, 0x0C]
        );
        assert_eq!(PacketHeader::from_wire(&bytes).expect("roundtrip"), header);
    }

    /// Anything shorter than a full header is rejected, including empty input.
    #[test]
    fn header_from_wire_rejects_short_input() {
        let short = [0u8; PacketHeader::SIZE - 1];
        assert_eq!(PacketHeader::from_wire(&short), Err(WireError::Truncated));
        assert_eq!(PacketHeader::from_wire(&short[..0]), Err(WireError::Truncated));
    }

    #[test]
    fn test_phantom_packet_ack() {
        let session_id = SessionId::random();
        let ack = PhantomPacket::ack(session_id, 5, 100);

        assert!(ack.header.flags.is_ack());
        assert_eq!(ack.header.stream_id, 5);
        assert_eq!(ack.header.sequence, 100);
        assert!(ack.payload.is_empty());
        assert!(ack.extensions.is_empty());
    }

    #[test]
    fn packet_roundtrip_preserves_fields() {
        let session_id = SessionId::random();
        let header = PacketHeader::new(
            session_id,
            7,
            42,
            PacketFlags::new(PacketFlags::ENCRYPTED | PacketFlags::RELIABLE),
        )
        .with_epoch(3)
        .with_path_id(1);
        let packet = PhantomPacket::new(header, vec![0xCA, 0xFE, 0xBA, 0xBE]);

        let bytes = packet.to_wire();
        let decoded = PhantomPacket::from_wire(&bytes).expect("roundtrip");
        assert_eq!(decoded, packet);
        assert_eq!(decoded.header.version, WIRE_VERSION);
        assert_eq!(decoded.header.stream_id, 7);
        assert_eq!(decoded.header.sequence, 42);
        assert_eq!(decoded.header.epoch, 3);
        assert_eq!(decoded.header.path_id, 1);
        assert!(decoded.header.flags.is_reliable());
        assert!(decoded.header.flags.contains(PacketFlags::ENCRYPTED));
        assert_eq!(decoded.payload, vec![0xCA, 0xFE, 0xBA, 0xBE]);
    }

    #[test]
    fn extensions_preserved_on_roundtrip() {
        let session_id = SessionId::random();
        let mut packet = PhantomPacket::new(
            PacketHeader::new(
                session_id,
                1,
                1,
                PacketFlags::new(PacketFlags::CONTROL | PacketFlags::RELIABLE),
            ),
            vec![1, 2, 3],
        );
        packet.extensions = vec![0xFF, 0x01, 0x00, 0x04, b't', b'e', b's', b't'];

        let bytes = packet.to_wire();
        let deser = PhantomPacket::from_wire(&bytes).expect("deserialize failed");
        assert_eq!(
            deser.extensions,
            vec![0xFF, 0x01, 0x00, 0x04, b't', b'e', b's', b't']
        );
    }

    /// `wire_size` must agree with what `to_wire` actually emits.
    #[test]
    fn wire_size_matches_encoded_length() {
        let mut packet = PhantomPacket::new(
            PacketHeader::new(
                SessionId::from_bytes([1u8; 32]),
                2,
                3,
                PacketFlags::empty(),
            ),
            vec![0u8; 10],
        );
        packet.extensions = vec![0u8; 5];
        assert_eq!(packet.to_wire().len(), packet.wire_size());
        assert_eq!(packet.wire_size(), PacketHeader::SIZE + 8 + 10 + 5);
    }

    /// Every strict prefix of a valid encoding must be rejected as truncated,
    /// whether the cut lands in the header, a length prefix, or section data.
    #[test]
    fn packet_from_wire_rejects_every_truncation() {
        let packet = PhantomPacket::new(
            PacketHeader::new(
                SessionId::from_bytes([7u8; 32]),
                1,
                1,
                PacketFlags::empty(),
            ),
            vec![1, 2, 3, 4],
        );
        let bytes = packet.to_wire();

        for cut in 0..bytes.len() {
            assert_eq!(
                PhantomPacket::from_wire(&bytes[..cut]),
                Err(WireError::Truncated),
                "prefix of {} bytes must not decode",
                cut
            );
        }
        assert_eq!(PhantomPacket::from_wire(&bytes), Ok(packet));
    }

    /// A length prefix that claims more data than is present must fail
    /// cleanly rather than panic or over-read.
    #[test]
    fn packet_from_wire_rejects_oversized_length_prefix() {
        let header = PacketHeader::new(
            SessionId::from_bytes([0u8; 32]),
            1,
            1,
            PacketFlags::empty(),
        );
        let mut bytes = header.to_wire().to_vec();
        bytes.extend_from_slice(&u32::MAX.to_be_bytes());
        bytes.extend_from_slice(&[0xEE; 16]);

        assert_eq!(PhantomPacket::from_wire(&bytes), Err(WireError::Truncated));
    }
}