1use thiserror::Error;
8
/// Byte sizes of the fixed-width pieces of the frame wire format.
pub mod sizes {
    /// Width of the frame-type byte at the start of a frame header.
    const FRAME_TYPE_FIELD: usize = 1;
    /// Width of the flags byte that follows the frame type.
    const FLAGS_FIELD: usize = 1;
    /// Width of each little-endian `u32` timestamp in the payload header.
    const TIMESTAMP_FIELD: usize = 4;
    /// Width of the little-endian `u16` payload length in the payload header.
    const PAYLOAD_LENGTH_FIELD: usize = 2;

    /// Size of the AEAD authentication tag.
    pub const AEAD_TAG_SIZE: usize = 16;
    /// Size of a [`super::SessionId`].
    pub const SESSION_ID_SIZE: usize = 6;
    /// Size of the little-endian nonce counter in a frame header.
    pub const NONCE_COUNTER_SIZE: usize = 8;
    /// Encoded size of a [`super::DataFrameHeader`]: type, flags, session id, nonce counter.
    pub const DATA_FRAME_HEADER_SIZE: usize =
        FRAME_TYPE_FIELD + FLAGS_FIELD + SESSION_ID_SIZE + NONCE_COUNTER_SIZE;
    /// Smallest acceptable encrypted frame: a header followed by one AEAD tag.
    pub const MIN_FRAME_SIZE: usize = DATA_FRAME_HEADER_SIZE + AEAD_TAG_SIZE;
    /// Encoded size of a [`super::PayloadHeader`]: two timestamps plus a length.
    pub const PAYLOAD_HEADER_SIZE: usize = 2 * TIMESTAMP_FIELD + PAYLOAD_LENGTH_FIELD;
    /// Default maximum payload size in bytes.
    // NOTE(review): not referenced in this file; the origin of 1200 is not visible here.
    pub const DEFAULT_MAX_PAYLOAD: usize = 1200;
}
26
/// Frame discriminator, encoded as byte 0 of every [`DataFrameHeader`].
///
/// The `#[repr(u8)]` discriminants are the on-wire values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum FrameType {
    /// Wire value `0x01`.
    HandshakeInit = 0x01,
    /// Wire value `0x02`.
    HandshakeResp = 0x02,
    /// Wire value `0x03`; used by [`DataFrame`] headers.
    Data = 0x03,
    /// Wire value `0x04`; used by [`RekeyFrame`] headers.
    Rekey = 0x04,
    /// Wire value `0x05`; used by [`CloseFrame`] headers.
    Close = 0x05,
}

impl FrameType {
    /// Every variant, used for reverse lookup from a wire byte.
    const ALL: [Self; 5] = [
        Self::HandshakeInit,
        Self::HandshakeResp,
        Self::Data,
        Self::Rekey,
        Self::Close,
    ];

    /// Decodes a wire byte, yielding `None` when no variant uses that value.
    pub fn from_byte(byte: u8) -> Option<Self> {
        Self::ALL
            .iter()
            .copied()
            .find(|candidate| candidate.as_byte() == byte)
    }

    /// Returns the on-wire value of this variant (its `repr(u8)` discriminant).
    pub fn as_byte(self) -> u8 {
        self as u8
    }
}
61
/// Flag bits carried in byte 1 of every [`DataFrameHeader`].
///
/// Bit 0 is [`FrameFlags::ACK_ONLY`] and bit 1 is [`FrameFlags::HAS_EXTENSION`].
/// Every other bit is reserved and must be zero (see [`FrameFlags::is_valid`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FrameFlags(u8);

impl FrameFlags {
    /// No flags set.
    pub const NONE: Self = Self(0);
    /// Bit 0; set on frames built by [`DataFrame::ack_only`].
    pub const ACK_ONLY: Self = Self(0x01);
    /// Bit 1.
    pub const HAS_EXTENSION: Self = Self(0x02);

    /// Union of every bit that currently has a defined meaning.
    const DEFINED_BITS: u8 = Self::ACK_ONLY.0 | Self::HAS_EXTENSION.0;

    /// Wraps a raw flags byte without validation; use [`FrameFlags::is_valid`] to check it.
    pub fn from_byte(byte: u8) -> Self {
        Self(byte)
    }

    /// Returns the raw flags byte.
    pub fn as_byte(self) -> u8 {
        self.0
    }

    /// True when any bit of `mask` is set in `self`.
    fn intersects(self, mask: Self) -> bool {
        self.0 & mask.0 != 0
    }

    /// Returns `self` with every bit of `mask` also set.
    fn union(self, mask: Self) -> Self {
        Self(self.0 | mask.0)
    }

    /// Whether the ACK-only bit is set.
    pub fn is_ack_only(self) -> bool {
        self.intersects(Self::ACK_ONLY)
    }

    /// Whether the extension bit is set.
    pub fn has_extension(self) -> bool {
        self.intersects(Self::HAS_EXTENSION)
    }

    /// Returns a copy with the ACK-only bit set.
    pub fn with_ack_only(self) -> Self {
        self.union(Self::ACK_ONLY)
    }

    /// Returns a copy with the extension bit set.
    pub fn with_extension(self) -> Self {
        self.union(Self::HAS_EXTENSION)
    }

    /// True when no reserved bit (anything outside `DEFINED_BITS`) is set.
    pub fn is_valid(self) -> bool {
        self.0 & !Self::DEFINED_BITS == 0
    }
}
109
110#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
112pub struct SessionId([u8; sizes::SESSION_ID_SIZE]);
113
114impl SessionId {
115 pub fn from_bytes(bytes: [u8; sizes::SESSION_ID_SIZE]) -> Self {
117 Self(bytes)
118 }
119
120 pub fn as_bytes(&self) -> &[u8; sizes::SESSION_ID_SIZE] {
122 &self.0
123 }
124
125 pub fn zero() -> Self {
127 Self([0u8; sizes::SESSION_ID_SIZE])
128 }
129}
130
131impl AsRef<[u8]> for SessionId {
132 fn as_ref(&self) -> &[u8] {
133 &self.0
134 }
135}
136
137#[derive(Debug, Clone, Copy, PartialEq, Eq)]
147pub struct DataFrameHeader {
148 pub frame_type: FrameType,
150 pub flags: FrameFlags,
152 pub session_id: SessionId,
154 pub nonce_counter: u64,
156}
157
158impl DataFrameHeader {
159 pub fn new(session_id: SessionId, nonce_counter: u64) -> Self {
161 Self {
162 frame_type: FrameType::Data,
163 flags: FrameFlags::NONE,
164 session_id,
165 nonce_counter,
166 }
167 }
168
169 pub fn close(session_id: SessionId, nonce_counter: u64) -> Self {
171 Self {
172 frame_type: FrameType::Close,
173 flags: FrameFlags::NONE,
174 session_id,
175 nonce_counter,
176 }
177 }
178
179 pub fn to_bytes(&self) -> [u8; sizes::DATA_FRAME_HEADER_SIZE] {
181 let mut buf = [0u8; sizes::DATA_FRAME_HEADER_SIZE];
182 buf[0] = self.frame_type.as_byte();
183 buf[1] = self.flags.as_byte();
184 buf[2..8].copy_from_slice(self.session_id.as_bytes());
185 buf[8..16].copy_from_slice(&self.nonce_counter.to_le_bytes());
186 buf
187 }
188
189 pub fn from_bytes(bytes: &[u8]) -> Result<Self, FrameError> {
191 if bytes.len() < sizes::DATA_FRAME_HEADER_SIZE {
192 return Err(FrameError::TooShort {
193 expected: sizes::DATA_FRAME_HEADER_SIZE,
194 actual: bytes.len(),
195 });
196 }
197
198 let frame_type = FrameType::from_byte(bytes[0]).ok_or(FrameError::InvalidType(bytes[0]))?;
199
200 let flags = FrameFlags::from_byte(bytes[1]);
201 if !flags.is_valid() {
202 return Err(FrameError::InvalidFlags(bytes[1]));
203 }
204
205 let mut session_id_bytes = [0u8; sizes::SESSION_ID_SIZE];
206 session_id_bytes.copy_from_slice(&bytes[2..8]);
207 let session_id = SessionId::from_bytes(session_id_bytes);
208
209 let mut nonce_bytes = [0u8; 8];
210 nonce_bytes.copy_from_slice(&bytes[8..16]);
211 let nonce_counter = u64::from_le_bytes(nonce_bytes);
212
213 Ok(Self {
214 frame_type,
215 flags,
216 session_id,
217 nonce_counter,
218 })
219 }
220}
221
222#[derive(Debug, Clone, Copy, PartialEq, Eq)]
232pub struct PayloadHeader {
233 pub timestamp: u32,
235 pub timestamp_echo: u32,
237 pub payload_length: u16,
239}
240
241impl PayloadHeader {
242 pub fn new(timestamp: u32, timestamp_echo: u32, payload_length: u16) -> Self {
244 Self {
245 timestamp,
246 timestamp_echo,
247 payload_length,
248 }
249 }
250
251 pub fn to_bytes(&self) -> [u8; sizes::PAYLOAD_HEADER_SIZE] {
253 let mut buf = [0u8; sizes::PAYLOAD_HEADER_SIZE];
254 buf[0..4].copy_from_slice(&self.timestamp.to_le_bytes());
255 buf[4..8].copy_from_slice(&self.timestamp_echo.to_le_bytes());
256 buf[8..10].copy_from_slice(&self.payload_length.to_le_bytes());
257 buf
258 }
259
260 pub fn from_bytes(bytes: &[u8]) -> Result<Self, FrameError> {
262 if bytes.len() < sizes::PAYLOAD_HEADER_SIZE {
263 return Err(FrameError::TooShort {
264 expected: sizes::PAYLOAD_HEADER_SIZE,
265 actual: bytes.len(),
266 });
267 }
268
269 let timestamp = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
270 let timestamp_echo = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
271 let payload_length = u16::from_le_bytes([bytes[8], bytes[9]]);
272
273 Ok(Self {
274 timestamp,
275 timestamp_echo,
276 payload_length,
277 })
278 }
279}
280
281#[derive(Debug, Clone)]
283pub struct DataFrame {
284 pub header: DataFrameHeader,
286 pub payload_header: PayloadHeader,
288 pub sync_message: Vec<u8>,
290}
291
292impl DataFrame {
293 pub fn new(
295 session_id: SessionId,
296 nonce_counter: u64,
297 timestamp: u32,
298 timestamp_echo: u32,
299 sync_message: Vec<u8>,
300 ) -> Self {
301 let payload_length = sync_message.len() as u16;
302 Self {
303 header: DataFrameHeader::new(session_id, nonce_counter),
304 payload_header: PayloadHeader::new(timestamp, timestamp_echo, payload_length),
305 sync_message,
306 }
307 }
308
309 pub fn ack_only(
311 session_id: SessionId,
312 nonce_counter: u64,
313 timestamp: u32,
314 timestamp_echo: u32,
315 ) -> Self {
316 let mut frame = Self::new(session_id, nonce_counter, timestamp, timestamp_echo, vec![]);
317 frame.header.flags = FrameFlags::ACK_ONLY;
318 frame
319 }
320
321 pub fn plaintext(&self) -> Vec<u8> {
323 let mut plaintext =
324 Vec::with_capacity(sizes::PAYLOAD_HEADER_SIZE + self.sync_message.len());
325 plaintext.extend_from_slice(&self.payload_header.to_bytes());
326 plaintext.extend_from_slice(&self.sync_message);
327 plaintext
328 }
329
330 pub fn aad(&self) -> [u8; sizes::DATA_FRAME_HEADER_SIZE] {
332 self.header.to_bytes()
333 }
334}
335
336#[derive(Debug, Clone, Copy)]
338pub struct CloseFrame {
339 pub header: DataFrameHeader,
341 pub final_ack: u64,
343}
344
/// Byte sizes of the fields in a rekey frame's plaintext.
pub mod rekey_sizes {
    /// Length of the ephemeral public key, the first field of the rekey plaintext.
    pub const EPHEMERAL_KEY_SIZE: usize = 32;
    /// Length of the little-endian `u32` timestamp that follows the key.
    pub const TIMESTAMP_SIZE: usize = std::mem::size_of::<u32>();
    /// Total rekey plaintext length: key followed by timestamp.
    pub const REKEY_PAYLOAD_SIZE: usize = EPHEMERAL_KEY_SIZE + TIMESTAMP_SIZE;
}
354
355#[derive(Debug, Clone, Copy)]
367pub struct RekeyFrame {
368 pub header: DataFrameHeader,
370 pub ephemeral_public: [u8; rekey_sizes::EPHEMERAL_KEY_SIZE],
372 pub timestamp: u32,
374}
375
376impl RekeyFrame {
377 pub fn new(
379 session_id: SessionId,
380 nonce_counter: u64,
381 ephemeral_public: [u8; rekey_sizes::EPHEMERAL_KEY_SIZE],
382 timestamp: u32,
383 ) -> Self {
384 Self {
385 header: DataFrameHeader {
386 frame_type: FrameType::Rekey,
387 flags: FrameFlags::NONE,
388 session_id,
389 nonce_counter,
390 },
391 ephemeral_public,
392 timestamp,
393 }
394 }
395
396 pub fn plaintext(&self) -> [u8; rekey_sizes::REKEY_PAYLOAD_SIZE] {
398 let mut buf = [0u8; rekey_sizes::REKEY_PAYLOAD_SIZE];
399 buf[..32].copy_from_slice(&self.ephemeral_public);
400 buf[32..36].copy_from_slice(&self.timestamp.to_le_bytes());
401 buf
402 }
403
404 pub fn aad(&self) -> [u8; sizes::DATA_FRAME_HEADER_SIZE] {
406 self.header.to_bytes()
407 }
408
409 pub fn from_decrypted(
411 header: DataFrameHeader,
412 payload: &[u8],
413 ) -> Result<Self, FrameError> {
414 if payload.len() < rekey_sizes::REKEY_PAYLOAD_SIZE {
415 return Err(FrameError::TooShort {
416 expected: rekey_sizes::REKEY_PAYLOAD_SIZE,
417 actual: payload.len(),
418 });
419 }
420
421 let mut ephemeral_public = [0u8; rekey_sizes::EPHEMERAL_KEY_SIZE];
422 ephemeral_public.copy_from_slice(&payload[..32]);
423 let timestamp = u32::from_le_bytes([payload[32], payload[33], payload[34], payload[35]]);
424
425 Ok(Self {
426 header,
427 ephemeral_public,
428 timestamp,
429 })
430 }
431}
432
433impl CloseFrame {
434 pub fn new(session_id: SessionId, nonce_counter: u64, final_ack: u64) -> Self {
436 Self {
437 header: DataFrameHeader::close(session_id, nonce_counter),
438 final_ack,
439 }
440 }
441
442 pub fn plaintext(&self) -> [u8; 8] {
444 self.final_ack.to_le_bytes()
445 }
446
447 pub fn aad(&self) -> [u8; sizes::DATA_FRAME_HEADER_SIZE] {
449 self.header.to_bytes()
450 }
451}
452
453#[derive(Debug, Error)]
455pub enum FrameError {
456 #[error("frame too short: expected at least {expected} bytes, got {actual}")]
458 TooShort {
459 expected: usize,
461 actual: usize,
463 },
464
465 #[error("invalid frame type: 0x{0:02x}")]
467 InvalidType(u8),
468
469 #[error("invalid flags: 0x{0:02x} (reserved bits must be 0)")]
471 InvalidFlags(u8),
472
473 #[error("payload length mismatch: header says {expected}, but {actual} bytes available")]
475 PayloadLengthMismatch {
476 expected: usize,
478 actual: usize,
480 },
481}
482
483pub fn parse_frame_header(data: &[u8]) -> Result<DataFrameHeader, FrameError> {
488 if data.len() < sizes::MIN_FRAME_SIZE {
489 return Err(FrameError::TooShort {
490 expected: sizes::MIN_FRAME_SIZE,
491 actual: data.len(),
492 });
493 }
494 DataFrameHeader::from_bytes(data)
495}
496
497pub fn parse_payload(data: &[u8]) -> Result<(PayloadHeader, &[u8]), FrameError> {
499 let header = PayloadHeader::from_bytes(data)?;
500 let sync_start = sizes::PAYLOAD_HEADER_SIZE;
501 let sync_end = sync_start + header.payload_length as usize;
502
503 if data.len() < sync_end {
504 return Err(FrameError::PayloadLengthMismatch {
505 expected: header.payload_length as usize,
506 actual: data.len() - sizes::PAYLOAD_HEADER_SIZE,
507 });
508 }
509
510 Ok((header, &data[sync_start..sync_end]))
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_frame_type_roundtrip() {
        for t in [
            FrameType::HandshakeInit,
            FrameType::HandshakeResp,
            FrameType::Data,
            FrameType::Rekey,
            FrameType::Close,
        ] {
            assert_eq!(FrameType::from_byte(t.as_byte()), Some(t));
        }
        assert_eq!(FrameType::from_byte(0x00), None);
        assert_eq!(FrameType::from_byte(0xFF), None);
    }

    #[test]
    fn test_frame_flags() {
        let flags = FrameFlags::NONE;
        assert!(!flags.is_ack_only());
        assert!(!flags.has_extension());
        assert!(flags.is_valid());

        let flags = FrameFlags::ACK_ONLY;
        assert!(flags.is_ack_only());
        assert!(!flags.has_extension());
        assert!(flags.is_valid());

        let flags = FrameFlags::NONE.with_ack_only().with_extension();
        assert!(flags.is_ack_only());
        assert!(flags.has_extension());
        assert!(flags.is_valid());

        // 0x04 is the lowest reserved bit.
        let invalid = FrameFlags::from_byte(0x04);
        assert!(!invalid.is_valid());
    }

    #[test]
    fn test_session_id() {
        let id = SessionId::from_bytes([0x01, 0x02, 0x03, 0x04, 0x05, 0x06]);
        assert_eq!(id.as_bytes(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06]);
        assert_eq!(id.as_ref(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06][..]);
    }

    #[test]
    fn test_data_frame_header_roundtrip() {
        let header = DataFrameHeader {
            frame_type: FrameType::Data,
            flags: FrameFlags::ACK_ONLY,
            session_id: SessionId::from_bytes([0x11, 0x22, 0x33, 0x44, 0x55, 0x66]),
            nonce_counter: 0x123456789ABCDEF0,
        };

        let bytes = header.to_bytes();
        assert_eq!(bytes.len(), sizes::DATA_FRAME_HEADER_SIZE);

        let parsed = DataFrameHeader::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.frame_type, header.frame_type);
        assert_eq!(parsed.flags, header.flags);
        assert_eq!(parsed.session_id, header.session_id);
        assert_eq!(parsed.nonce_counter, header.nonce_counter);
    }

    #[test]
    fn test_payload_header_roundtrip() {
        let header = PayloadHeader {
            timestamp: 0x12345678,
            timestamp_echo: 0xABCDEF01,
            payload_length: 256,
        };

        let bytes = header.to_bytes();
        assert_eq!(bytes.len(), sizes::PAYLOAD_HEADER_SIZE);

        let parsed = PayloadHeader::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.timestamp, header.timestamp);
        assert_eq!(parsed.timestamp_echo, header.timestamp_echo);
        assert_eq!(parsed.payload_length, header.payload_length);
    }

    #[test]
    fn test_data_frame_plaintext() {
        let frame = DataFrame::new(
            SessionId::zero(),
            1,
            1000,
            500,
            vec![0x01, 0x02, 0x03, 0x04],
        );

        let plaintext = frame.plaintext();
        assert_eq!(plaintext.len(), sizes::PAYLOAD_HEADER_SIZE + 4);
    }

    #[test]
    fn test_data_frame_plaintext_parses_back() {
        let frame = DataFrame::new(SessionId::zero(), 1, 1000, 500, vec![9, 8, 7]);

        let plaintext = frame.plaintext();
        let (payload_header, sync) = parse_payload(&plaintext).unwrap();
        assert_eq!(payload_header, frame.payload_header);
        assert_eq!(sync, &[9u8, 8, 7][..]);
        assert_eq!(frame.aad(), frame.header.to_bytes());
    }

    #[test]
    fn test_ack_only_frame() {
        let frame = DataFrame::ack_only(SessionId::zero(), 1, 1000, 500);

        assert!(frame.header.flags.is_ack_only());
        assert!(frame.sync_message.is_empty());
        assert_eq!(frame.payload_header.payload_length, 0);
    }

    #[test]
    fn test_close_frame() {
        let frame = CloseFrame::new(SessionId::zero(), 100, 12345);

        assert_eq!(frame.header.frame_type, FrameType::Close);
        assert_eq!(frame.final_ack, 12345);

        let plaintext = frame.plaintext();
        assert_eq!(plaintext, 12345u64.to_le_bytes());
    }

    #[test]
    fn test_rekey_roundtrip() {
        let frame = RekeyFrame::new(
            SessionId::zero(),
            3,
            [0x42; rekey_sizes::EPHEMERAL_KEY_SIZE],
            0xDEADBEEF,
        );
        assert_eq!(frame.header.frame_type, FrameType::Rekey);

        let parsed = RekeyFrame::from_decrypted(frame.header, &frame.plaintext()).unwrap();
        assert_eq!(parsed.ephemeral_public, frame.ephemeral_public);
        assert_eq!(parsed.timestamp, frame.timestamp);
        assert_eq!(parsed.aad(), frame.aad());
    }

    #[test]
    fn test_rekey_too_short() {
        let header = DataFrameHeader::new(SessionId::zero(), 0);
        let short = [0u8; rekey_sizes::REKEY_PAYLOAD_SIZE - 1];
        assert!(matches!(
            RekeyFrame::from_decrypted(header, &short),
            Err(FrameError::TooShort { expected, actual })
                if expected == rekey_sizes::REKEY_PAYLOAD_SIZE
                    && actual == rekey_sizes::REKEY_PAYLOAD_SIZE - 1
        ));
    }

    #[test]
    fn test_parse_too_short() {
        let data = [0u8; 10];
        assert!(matches!(
            parse_frame_header(&data),
            Err(FrameError::TooShort { .. })
        ));
    }

    #[test]
    fn test_parse_frame_header_requires_tag_room() {
        let bytes = DataFrameHeader::new(SessionId::zero(), 7).to_bytes();

        // A bare header decodes on its own...
        assert!(DataFrameHeader::from_bytes(&bytes).is_ok());
        // ...but a full frame must also have room for the AEAD tag.
        assert!(matches!(
            parse_frame_header(&bytes),
            Err(FrameError::TooShort { expected, actual })
                if expected == sizes::MIN_FRAME_SIZE && actual == sizes::DATA_FRAME_HEADER_SIZE
        ));
    }

    #[test]
    fn test_parse_invalid_type() {
        let mut data = [0u8; sizes::MIN_FRAME_SIZE];
        data[0] = 0xFF;
        assert!(matches!(
            parse_frame_header(&data),
            Err(FrameError::InvalidType(0xFF))
        ));
    }

    #[test]
    fn test_parse_invalid_flags() {
        let mut data = [0u8; sizes::MIN_FRAME_SIZE];
        data[0] = FrameType::Data.as_byte();
        data[1] = 0x80;
        assert!(matches!(
            parse_frame_header(&data),
            Err(FrameError::InvalidFlags(0x80))
        ));
    }

    #[test]
    fn test_parse_payload_ignores_trailing_bytes() {
        let mut data = PayloadHeader::new(1, 2, 2).to_bytes().to_vec();
        data.extend_from_slice(&[0xAA, 0xBB, 0xCC]);

        let (header, sync) = parse_payload(&data).unwrap();
        assert_eq!(header.payload_length, 2);
        assert_eq!(sync, &[0xAAu8, 0xBB][..]);
    }

    #[test]
    fn test_parse_payload_length_mismatch() {
        let mut data = PayloadHeader::new(0, 0, 5).to_bytes().to_vec();
        data.extend_from_slice(&[1, 2]);

        assert!(matches!(
            parse_payload(&data),
            Err(FrameError::PayloadLengthMismatch { expected, actual })
                if expected == 5 && actual == 2
        ));
    }

    #[test]
    fn test_parse_payload_too_short_for_header() {
        let data = [0u8; sizes::PAYLOAD_HEADER_SIZE - 1];
        assert!(matches!(
            parse_payload(&data),
            Err(FrameError::TooShort { .. })
        ));
    }
}