1use thiserror::Error;
8
/// Fixed byte sizes of the wire-format components defined in this file.
pub mod sizes {
    /// Width of the frame-type byte at the start of the frame header.
    const FRAME_TYPE_SIZE: usize = 1;
    /// Width of the flags byte that follows the frame type.
    const FLAGS_SIZE: usize = 1;
    /// Width of each little-endian `u32` timestamp in the payload header.
    const TIMESTAMP_SIZE: usize = 4;
    /// Width of the little-endian `u16` payload-length field in the payload header.
    const LENGTH_FIELD_SIZE: usize = 2;

    /// Length of the AEAD authentication tag.
    pub const AEAD_TAG_SIZE: usize = 16;
    /// Length of a session identifier.
    pub const SESSION_ID_SIZE: usize = 6;
    /// Length of the little-endian nonce counter in the frame header.
    pub const NONCE_COUNTER_SIZE: usize = 8;
    /// Length of a serialized frame header: type, flags, session id, nonce counter.
    pub const DATA_FRAME_HEADER_SIZE: usize =
        FRAME_TYPE_SIZE + FLAGS_SIZE + SESSION_ID_SIZE + NONCE_COUNTER_SIZE;
    /// Smallest acceptable frame: a header followed by at least an AEAD tag.
    pub const MIN_FRAME_SIZE: usize = DATA_FRAME_HEADER_SIZE + AEAD_TAG_SIZE;
    /// Length of a serialized payload header: timestamp, timestamp echo, payload length.
    pub const PAYLOAD_HEADER_SIZE: usize = 2 * TIMESTAMP_SIZE + LENGTH_FIELD_SIZE;
    /// Default maximum payload size in bytes.
    ///
    /// NOTE(review): not referenced elsewhere in this file; presumably the
    /// per-datagram payload budget used by callers — confirm.
    pub const DEFAULT_MAX_PAYLOAD: usize = 1200;
}
26
/// Discriminator carried in the first byte of every frame header.
///
/// Each variant's discriminant is its wire value, so `as_byte` is a plain cast and
/// `from_byte` inverts it for the five defined values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum FrameType {
    /// Handshake initiation (wire byte `0x01`).
    HandshakeInit = 0x01,
    /// Handshake response (wire byte `0x02`).
    HandshakeResp = 0x02,
    /// Data frame (wire byte `0x03`).
    Data = 0x03,
    /// Rekey frame (wire byte `0x04`).
    Rekey = 0x04,
    /// Close frame (wire byte `0x05`).
    Close = 0x05,
}

impl FrameType {
    /// Decodes a wire byte, returning `None` for any value that is not a known frame type.
    pub fn from_byte(byte: u8) -> Option<Self> {
        // Patterns are derived from the discriminants so the decoder cannot drift from
        // the enum definition.
        const HANDSHAKE_INIT: u8 = FrameType::HandshakeInit as u8;
        const HANDSHAKE_RESP: u8 = FrameType::HandshakeResp as u8;
        const DATA: u8 = FrameType::Data as u8;
        const REKEY: u8 = FrameType::Rekey as u8;
        const CLOSE: u8 = FrameType::Close as u8;

        let decoded = match byte {
            HANDSHAKE_INIT => Self::HandshakeInit,
            HANDSHAKE_RESP => Self::HandshakeResp,
            DATA => Self::Data,
            REKEY => Self::Rekey,
            CLOSE => Self::Close,
            _ => return None,
        };
        Some(decoded)
    }

    /// Encodes the frame type as its wire byte.
    pub fn as_byte(self) -> u8 {
        self as u8
    }
}
61
/// Flag bits carried in the second byte of the frame header.
///
/// Only two bits are defined ([`FrameFlags::ACK_ONLY`] and [`FrameFlags::HAS_EXTENSION`]);
/// every other bit is reserved and must be zero, which [`FrameFlags::is_valid`] checks.
/// `from_byte` itself performs no validation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct FrameFlags(u8);

impl FrameFlags {
    /// No flags set.
    pub const NONE: Self = Self(0);
    /// Marks an acknowledgement-only frame (`DataFrame::ack_only` sets it on a frame
    /// with an empty sync message).
    pub const ACK_ONLY: Self = Self(0x01);
    /// Marks a frame as carrying an extension (its format is not defined in this file).
    pub const HAS_EXTENSION: Self = Self(0x02);

    /// Union of every defined flag bit; any bit outside this mask is reserved.
    const KNOWN_BITS: u8 = Self::ACK_ONLY.0 | Self::HAS_EXTENSION.0;

    /// Wraps a raw flags byte without validating reserved bits.
    pub fn from_byte(byte: u8) -> Self {
        Self(byte)
    }

    /// Returns the raw flags byte.
    pub fn as_byte(self) -> u8 {
        self.0
    }

    /// Returns `true` if the [`FrameFlags::ACK_ONLY`] bit is set.
    pub fn is_ack_only(self) -> bool {
        self.intersects(Self::ACK_ONLY)
    }

    /// Returns `true` if the [`FrameFlags::HAS_EXTENSION`] bit is set.
    pub fn has_extension(self) -> bool {
        self.intersects(Self::HAS_EXTENSION)
    }

    /// Returns a copy with the [`FrameFlags::ACK_ONLY`] bit set.
    #[must_use]
    pub fn with_ack_only(self) -> Self {
        Self(self.0 | Self::ACK_ONLY.0)
    }

    /// Returns a copy with the [`FrameFlags::HAS_EXTENSION`] bit set.
    #[must_use]
    pub fn with_extension(self) -> Self {
        Self(self.0 | Self::HAS_EXTENSION.0)
    }

    /// Returns `true` when no reserved (undefined) bit is set.
    pub fn is_valid(self) -> bool {
        self.0 & !Self::KNOWN_BITS == 0
    }

    /// Returns `true` if any bit of `mask` is also set in `self`.
    fn intersects(self, mask: Self) -> bool {
        self.0 & mask.0 != 0
    }
}
109
110#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
112pub struct SessionId([u8; sizes::SESSION_ID_SIZE]);
113
114impl SessionId {
115 pub fn from_bytes(bytes: [u8; sizes::SESSION_ID_SIZE]) -> Self {
117 Self(bytes)
118 }
119
120 pub fn as_bytes(&self) -> &[u8; sizes::SESSION_ID_SIZE] {
122 &self.0
123 }
124
125 pub fn zero() -> Self {
127 Self([0u8; sizes::SESSION_ID_SIZE])
128 }
129}
130
131impl AsRef<[u8]> for SessionId {
132 fn as_ref(&self) -> &[u8] {
133 &self.0
134 }
135}
136
137#[derive(Debug, Clone, Copy, PartialEq, Eq)]
147pub struct DataFrameHeader {
148 pub frame_type: FrameType,
150 pub flags: FrameFlags,
152 pub session_id: SessionId,
154 pub nonce_counter: u64,
156}
157
158impl DataFrameHeader {
159 pub fn new(session_id: SessionId, nonce_counter: u64) -> Self {
161 Self {
162 frame_type: FrameType::Data,
163 flags: FrameFlags::NONE,
164 session_id,
165 nonce_counter,
166 }
167 }
168
169 pub fn close(session_id: SessionId, nonce_counter: u64) -> Self {
171 Self {
172 frame_type: FrameType::Close,
173 flags: FrameFlags::NONE,
174 session_id,
175 nonce_counter,
176 }
177 }
178
179 pub fn to_bytes(&self) -> [u8; sizes::DATA_FRAME_HEADER_SIZE] {
181 let mut buf = [0u8; sizes::DATA_FRAME_HEADER_SIZE];
182 buf[0] = self.frame_type.as_byte();
183 buf[1] = self.flags.as_byte();
184 buf[2..8].copy_from_slice(self.session_id.as_bytes());
185 buf[8..16].copy_from_slice(&self.nonce_counter.to_le_bytes());
186 buf
187 }
188
189 pub fn from_bytes(bytes: &[u8]) -> Result<Self, FrameError> {
191 if bytes.len() < sizes::DATA_FRAME_HEADER_SIZE {
192 return Err(FrameError::TooShort {
193 expected: sizes::DATA_FRAME_HEADER_SIZE,
194 actual: bytes.len(),
195 });
196 }
197
198 let frame_type = FrameType::from_byte(bytes[0]).ok_or(FrameError::InvalidType(bytes[0]))?;
199
200 let flags = FrameFlags::from_byte(bytes[1]);
201 if !flags.is_valid() {
202 return Err(FrameError::InvalidFlags(bytes[1]));
203 }
204
205 let mut session_id_bytes = [0u8; sizes::SESSION_ID_SIZE];
206 session_id_bytes.copy_from_slice(&bytes[2..8]);
207 let session_id = SessionId::from_bytes(session_id_bytes);
208
209 let mut nonce_bytes = [0u8; 8];
210 nonce_bytes.copy_from_slice(&bytes[8..16]);
211 let nonce_counter = u64::from_le_bytes(nonce_bytes);
212
213 Ok(Self {
214 frame_type,
215 flags,
216 session_id,
217 nonce_counter,
218 })
219 }
220}
221
222#[derive(Debug, Clone, Copy, PartialEq, Eq)]
232pub struct PayloadHeader {
233 pub timestamp: u32,
235 pub timestamp_echo: u32,
237 pub payload_length: u16,
239}
240
241impl PayloadHeader {
242 pub fn new(timestamp: u32, timestamp_echo: u32, payload_length: u16) -> Self {
244 Self {
245 timestamp,
246 timestamp_echo,
247 payload_length,
248 }
249 }
250
251 pub fn to_bytes(&self) -> [u8; sizes::PAYLOAD_HEADER_SIZE] {
253 let mut buf = [0u8; sizes::PAYLOAD_HEADER_SIZE];
254 buf[0..4].copy_from_slice(&self.timestamp.to_le_bytes());
255 buf[4..8].copy_from_slice(&self.timestamp_echo.to_le_bytes());
256 buf[8..10].copy_from_slice(&self.payload_length.to_le_bytes());
257 buf
258 }
259
260 pub fn from_bytes(bytes: &[u8]) -> Result<Self, FrameError> {
262 if bytes.len() < sizes::PAYLOAD_HEADER_SIZE {
263 return Err(FrameError::TooShort {
264 expected: sizes::PAYLOAD_HEADER_SIZE,
265 actual: bytes.len(),
266 });
267 }
268
269 let timestamp = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
270 let timestamp_echo = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
271 let payload_length = u16::from_le_bytes([bytes[8], bytes[9]]);
272
273 Ok(Self {
274 timestamp,
275 timestamp_echo,
276 payload_length,
277 })
278 }
279}
280
281#[derive(Debug, Clone)]
283pub struct DataFrame {
284 pub header: DataFrameHeader,
286 pub payload_header: PayloadHeader,
288 pub sync_message: Vec<u8>,
290}
291
292impl DataFrame {
293 pub fn new(
295 session_id: SessionId,
296 nonce_counter: u64,
297 timestamp: u32,
298 timestamp_echo: u32,
299 sync_message: Vec<u8>,
300 ) -> Self {
301 let payload_length = sync_message.len() as u16;
302 Self {
303 header: DataFrameHeader::new(session_id, nonce_counter),
304 payload_header: PayloadHeader::new(timestamp, timestamp_echo, payload_length),
305 sync_message,
306 }
307 }
308
309 pub fn ack_only(
311 session_id: SessionId,
312 nonce_counter: u64,
313 timestamp: u32,
314 timestamp_echo: u32,
315 ) -> Self {
316 let mut frame = Self::new(session_id, nonce_counter, timestamp, timestamp_echo, vec![]);
317 frame.header.flags = FrameFlags::ACK_ONLY;
318 frame
319 }
320
321 pub fn plaintext(&self) -> Vec<u8> {
323 let mut plaintext =
324 Vec::with_capacity(sizes::PAYLOAD_HEADER_SIZE + self.sync_message.len());
325 plaintext.extend_from_slice(&self.payload_header.to_bytes());
326 plaintext.extend_from_slice(&self.sync_message);
327 plaintext
328 }
329
330 pub fn aad(&self) -> [u8; sizes::DATA_FRAME_HEADER_SIZE] {
332 self.header.to_bytes()
333 }
334}
335
336#[derive(Debug, Clone, Copy)]
338pub struct CloseFrame {
339 pub header: DataFrameHeader,
341 pub final_ack: u64,
343}
344
345impl CloseFrame {
346 pub fn new(session_id: SessionId, nonce_counter: u64, final_ack: u64) -> Self {
348 Self {
349 header: DataFrameHeader::close(session_id, nonce_counter),
350 final_ack,
351 }
352 }
353
354 pub fn plaintext(&self) -> [u8; 8] {
356 self.final_ack.to_le_bytes()
357 }
358
359 pub fn aad(&self) -> [u8; sizes::DATA_FRAME_HEADER_SIZE] {
361 self.header.to_bytes()
362 }
363}
364
365#[derive(Debug, Error)]
367pub enum FrameError {
368 #[error("frame too short: expected at least {expected} bytes, got {actual}")]
370 TooShort {
371 expected: usize,
373 actual: usize,
375 },
376
377 #[error("invalid frame type: 0x{0:02x}")]
379 InvalidType(u8),
380
381 #[error("invalid flags: 0x{0:02x} (reserved bits must be 0)")]
383 InvalidFlags(u8),
384
385 #[error("payload length mismatch: header says {expected}, but {actual} bytes available")]
387 PayloadLengthMismatch {
388 expected: usize,
390 actual: usize,
392 },
393}
394
395pub fn parse_frame_header(data: &[u8]) -> Result<DataFrameHeader, FrameError> {
400 if data.len() < sizes::MIN_FRAME_SIZE {
401 return Err(FrameError::TooShort {
402 expected: sizes::MIN_FRAME_SIZE,
403 actual: data.len(),
404 });
405 }
406 DataFrameHeader::from_bytes(data)
407}
408
409pub fn parse_payload(data: &[u8]) -> Result<(PayloadHeader, &[u8]), FrameError> {
411 let header = PayloadHeader::from_bytes(data)?;
412 let sync_start = sizes::PAYLOAD_HEADER_SIZE;
413 let sync_end = sync_start + header.payload_length as usize;
414
415 if data.len() < sync_end {
416 return Err(FrameError::PayloadLengthMismatch {
417 expected: header.payload_length as usize,
418 actual: data.len() - sizes::PAYLOAD_HEADER_SIZE,
419 });
420 }
421
422 Ok((header, &data[sync_start..sync_end]))
423}
424
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_frame_type_roundtrip() {
        for t in [
            FrameType::HandshakeInit,
            FrameType::HandshakeResp,
            FrameType::Data,
            FrameType::Rekey,
            FrameType::Close,
        ] {
            assert_eq!(FrameType::from_byte(t.as_byte()), Some(t));
        }
        assert_eq!(FrameType::from_byte(0x00), None);
        assert_eq!(FrameType::from_byte(0xFF), None);
    }

    #[test]
    fn test_frame_flags() {
        let flags = FrameFlags::NONE;
        assert!(!flags.is_ack_only());
        assert!(!flags.has_extension());
        assert!(flags.is_valid());

        let flags = FrameFlags::ACK_ONLY;
        assert!(flags.is_ack_only());
        assert!(!flags.has_extension());
        assert!(flags.is_valid());

        let flags = FrameFlags::NONE.with_ack_only().with_extension();
        assert!(flags.is_ack_only());
        assert!(flags.has_extension());
        assert!(flags.is_valid());

        // Bit 2 is reserved.
        let invalid = FrameFlags::from_byte(0x04);
        assert!(!invalid.is_valid());
    }

    #[test]
    fn test_session_id() {
        let id = SessionId::from_bytes([0x01, 0x02, 0x03, 0x04, 0x05, 0x06]);
        assert_eq!(id.as_bytes(), &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06]);
    }

    #[test]
    fn test_data_frame_header_roundtrip() {
        let header = DataFrameHeader {
            frame_type: FrameType::Data,
            flags: FrameFlags::ACK_ONLY,
            session_id: SessionId::from_bytes([0x11, 0x22, 0x33, 0x44, 0x55, 0x66]),
            nonce_counter: 0x123456789ABCDEF0,
        };

        let bytes = header.to_bytes();
        assert_eq!(bytes.len(), sizes::DATA_FRAME_HEADER_SIZE);

        let parsed = DataFrameHeader::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.frame_type, header.frame_type);
        assert_eq!(parsed.flags, header.flags);
        assert_eq!(parsed.session_id, header.session_id);
        assert_eq!(parsed.nonce_counter, header.nonce_counter);
    }

    #[test]
    fn test_data_frame_header_wire_layout() {
        let header = DataFrameHeader::new(
            SessionId::from_bytes([0x01, 0x02, 0x03, 0x04, 0x05, 0x06]),
            0x0102_0304_0506_0708,
        );
        // type, flags, session id, then the nonce counter little-endian.
        assert_eq!(
            header.to_bytes(),
            [
                0x03, 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x08, 0x07, 0x06, 0x05, 0x04,
                0x03, 0x02, 0x01,
            ]
        );
    }

    #[test]
    fn test_payload_header_roundtrip() {
        let header = PayloadHeader {
            timestamp: 0x12345678,
            timestamp_echo: 0xABCDEF01,
            payload_length: 256,
        };

        let bytes = header.to_bytes();
        assert_eq!(bytes.len(), sizes::PAYLOAD_HEADER_SIZE);

        let parsed = PayloadHeader::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.timestamp, header.timestamp);
        assert_eq!(parsed.timestamp_echo, header.timestamp_echo);
        assert_eq!(parsed.payload_length, header.payload_length);
    }

    #[test]
    fn test_data_frame_plaintext() {
        let frame = DataFrame::new(
            SessionId::zero(),
            1,
            1000,
            500,
            vec![0x01, 0x02, 0x03, 0x04],
        );

        let plaintext = frame.plaintext();
        assert_eq!(plaintext.len(), sizes::PAYLOAD_HEADER_SIZE + 4);
        assert_eq!(
            &plaintext[..sizes::PAYLOAD_HEADER_SIZE],
            &frame.payload_header.to_bytes()[..]
        );
        assert_eq!(&plaintext[sizes::PAYLOAD_HEADER_SIZE..], &[0x01, 0x02, 0x03, 0x04][..]);
    }

    #[test]
    fn test_ack_only_frame() {
        let frame = DataFrame::ack_only(SessionId::zero(), 1, 1000, 500);

        assert!(frame.header.flags.is_ack_only());
        assert!(frame.sync_message.is_empty());
        assert_eq!(frame.payload_header.payload_length, 0);
    }

    #[test]
    fn test_close_frame() {
        let frame = CloseFrame::new(SessionId::zero(), 100, 12345);

        assert_eq!(frame.header.frame_type, FrameType::Close);
        assert_eq!(frame.final_ack, 12345);

        let plaintext = frame.plaintext();
        assert_eq!(plaintext, 12345u64.to_le_bytes());
    }

    #[test]
    fn test_close_frame_aad() {
        let frame = CloseFrame::new(SessionId::zero(), 100, 12345);
        let aad = frame.aad();
        assert_eq!(aad[0], FrameType::Close.as_byte());
        assert_eq!(aad, frame.header.to_bytes());
    }

    #[test]
    fn test_parse_too_short() {
        let data = [0u8; 10];
        assert!(matches!(
            parse_frame_header(&data),
            Err(FrameError::TooShort { .. })
        ));
    }

    #[test]
    fn test_parse_header_without_tag_is_too_short() {
        // A complete header is still rejected when no room is left for the AEAD tag.
        let mut data = [0u8; sizes::DATA_FRAME_HEADER_SIZE];
        data[0] = FrameType::Data.as_byte();
        assert!(matches!(
            parse_frame_header(&data),
            Err(FrameError::TooShort {
                expected: sizes::MIN_FRAME_SIZE,
                actual: sizes::DATA_FRAME_HEADER_SIZE,
            })
        ));
    }

    #[test]
    fn test_parse_invalid_type() {
        let mut data = [0u8; sizes::MIN_FRAME_SIZE];
        data[0] = 0xFF;
        assert!(matches!(
            parse_frame_header(&data),
            Err(FrameError::InvalidType(0xFF))
        ));
    }

    #[test]
    fn test_parse_invalid_flags() {
        let mut data = [0u8; sizes::MIN_FRAME_SIZE];
        data[0] = FrameType::Data.as_byte();
        data[1] = 0x80;
        assert!(matches!(
            parse_frame_header(&data),
            Err(FrameError::InvalidFlags(0x80))
        ));
    }

    #[test]
    fn test_parse_payload_roundtrip() {
        let frame = DataFrame::new(SessionId::zero(), 7, 1000, 500, vec![0xAA, 0xBB, 0xCC]);
        let plaintext = frame.plaintext();

        let (header, sync) = parse_payload(&plaintext).unwrap();
        assert_eq!(header, frame.payload_header);
        assert_eq!(sync, frame.sync_message.as_slice());
    }

    #[test]
    fn test_parse_payload_length_mismatch() {
        let mut data = PayloadHeader::new(0, 0, 5).to_bytes().to_vec();
        data.extend_from_slice(&[0x01, 0x02]);
        assert!(matches!(
            parse_payload(&data),
            Err(FrameError::PayloadLengthMismatch {
                expected: 5,
                actual: 2,
            })
        ));
    }

    #[test]
    fn test_parse_payload_too_short() {
        let data = [0u8; 3];
        assert!(matches!(
            parse_payload(&data),
            Err(FrameError::TooShort {
                expected: sizes::PAYLOAD_HEADER_SIZE,
                actual: 3,
            })
        ));
    }
}