1use std::num::NonZeroU16;
11
12use bytes::{Buf, BufMut, Bytes, BytesMut};
13use iroh_base::{EndpointId, KeyParsingError};
14use n0_error::{e, ensure, stack_error};
15use n0_future::time::Duration;
16
17use super::common::{FrameType, FrameTypeError};
18use crate::KeyCache;
19
/// Upper bound, in bytes, on the payload of a single frame (everything after the frame
/// type). The `from_bytes` decoders reject larger payloads with `Error::FrameTooLarge`.
pub const MAX_PACKET_SIZE: usize = 64 * 1024;
24
/// Presumably the maximum size, in bytes, of a whole frame on the wire (1 MiB).
/// NOTE(review): not referenced in this part of the file; confirm where it is enforced.
#[cfg(not(wasm_browser))]
pub(crate) const MAX_FRAME_SIZE: usize = 1024 * 1024;
30
/// Interval of 15 seconds; presumably how often the server pings connected clients.
/// NOTE(review): not used in this part of the file — confirm against the server loop.
#[cfg(feature = "server")]
pub(crate) const PING_INTERVAL: Duration = Duration::from_secs(15);
37
/// Queue depth of 512; presumably the capacity of each client's outgoing message queue
/// on the server. NOTE(review): not used in this part of the file — confirm with callers.
#[cfg(feature = "server")]
pub(crate) const PER_CLIENT_SEND_QUEUE_DEPTH: usize = 512;
41
/// Errors returned by the relay protocol message codec.
#[stack_error(derive, add_meta, from_sources)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum Error {
    /// A frame of type `got` arrived where `expected` was required.
    /// NOTE(review): not constructed in the decoders visible here.
    #[error("unexpected frame: got {got:?}, expected {expected:?}")]
    UnexpectedFrame { got: FrameType, expected: FrameType },
    /// The payload after the frame type exceeded `MAX_PACKET_SIZE`.
    #[error("Frame is too large, has {frame_len} bytes")]
    FrameTooLarge { frame_len: usize },
    /// A `postcard` (de)serialization error.
    /// NOTE(review): not constructed in the code visible here.
    #[error(transparent)]
    SerDe {
        #[error(std_err)]
        source: postcard::Error,
    },
    /// The frame type could not be parsed; presumably produced via `?` on
    /// `FrameType::from_bytes` — confirm.
    #[error(transparent)]
    FrameTypeError { source: FrameTypeError },
    /// An endpoint id in the frame was not a valid public key; presumably produced via `?`
    /// on `KeyCache::key_from_slice` — confirm.
    #[error("Invalid public key")]
    InvalidPublicKey { source: KeyParsingError },
    /// The frame payload had the wrong length or shape for its frame type.
    #[error("Invalid frame encoding")]
    InvalidFrame {},
    /// The frame type is not valid for the message direction being decoded.
    #[error("Invalid frame type: {frame_type:?}")]
    InvalidFrameType { frame_type: FrameType },
    /// A text payload (e.g. a health message) was not valid UTF-8.
    #[error("Invalid protocol message encoding")]
    InvalidProtocolMessageEncoding {
        #[error(std_err)]
        source: std::str::Utf8Error,
    },
    /// Not enough bytes were available.
    /// NOTE(review): not constructed in the code visible here.
    #[error("Too few bytes")]
    TooSmall {},
}
72
/// A message sent from the relay server to a client.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RelayToClientMsg {
    /// Datagrams relayed to the client.
    ///
    /// Encoded as a `RelayToClientDatagram` frame, or a `RelayToClientDatagramBatch` frame
    /// when `datagrams.segment_size` is set.
    Datagrams {
        /// Endpoint id written before the datagrams; presumably the sender — TODO confirm.
        remote_endpoint_id: EndpointId,
        /// The datagram payload, possibly several segments when a segment size is set.
        datagrams: Datagrams,
    },
    /// Reports that the given endpoint is gone; presumably it disconnected from the relay —
    /// TODO confirm. Encoded as the raw endpoint id bytes.
    EndpointGone(EndpointId),
    /// A health/status message from the relay.
    Health {
        /// Human-readable description of the problem, sent as raw UTF-8 bytes.
        problem: String,
    },
    /// Announces that the relay is restarting.
    ///
    /// Both durations are sent as big-endian `u32` millisecond counts.
    Restarting {
        /// Presumably how long the client should wait before reconnecting — TODO confirm.
        reconnect_in: Duration,
        /// Presumably how long the client should keep trying to reconnect — TODO confirm.
        try_for: Duration,
    },
    /// A ping carrying 8 bytes of opaque data.
    Ping([u8; 8]),
    /// A pong carrying 8 bytes of data; presumably echoing a ping's payload — TODO confirm.
    Pong([u8; 8]),
}
114
/// A message sent from a client to the relay server.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientToRelayMsg {
    /// A ping carrying 8 bytes of opaque data.
    Ping([u8; 8]),
    /// A pong carrying 8 bytes of data; presumably echoing a ping's payload — TODO confirm.
    Pong([u8; 8]),
    /// Datagrams the client wants relayed.
    ///
    /// Encoded as a `ClientToRelayDatagram` frame, or a `ClientToRelayDatagramBatch` frame
    /// when `datagrams.segment_size` is set.
    Datagrams {
        /// The endpoint the datagrams are addressed to.
        dst_endpoint_id: EndpointId,
        /// The datagram payload, possibly several segments when a segment size is set.
        datagrams: Datagrams,
    },
}
132
/// One datagram, or a batch of datagrams concatenated back to back.
///
/// Wire format: one ECN byte, then (only when `segment_size` is set) a big-endian `u16`
/// segment size, then the raw contents.
#[derive(derive_more::Debug, Clone, PartialEq, Eq)]
pub struct Datagrams {
    /// ECN codepoint; encoded as a single byte, `0` when `None`.
    pub ecn: Option<quinn_proto::EcnCodepoint>,
    /// Size of each segment when `contents` holds a batch; `None` when not batched.
    pub segment_size: Option<NonZeroU16>,
    /// The raw datagram bytes (all segments concatenated when batched).
    /// Omitted from the `Debug` output.
    #[debug(skip)]
    pub contents: Bytes,
}
148
149impl<T: AsRef<[u8]>> From<T> for Datagrams {
150 fn from(bytes: T) -> Self {
151 Self {
152 ecn: None,
153 segment_size: None,
154 contents: Bytes::copy_from_slice(bytes.as_ref()),
155 }
156 }
157}
158
159impl Datagrams {
160 pub fn take_segments(&mut self, num_segments: usize) -> Datagrams {
173 let Some(segment_size) = self.segment_size else {
174 let contents = std::mem::take(&mut self.contents);
175 return Datagrams {
176 ecn: self.ecn,
177 segment_size: None,
178 contents,
179 };
180 };
181
182 let usize_segment_size = usize::from(u16::from(segment_size));
183 let max_content_len = num_segments * usize_segment_size;
184 let contents = self
185 .contents
186 .split_to(std::cmp::min(max_content_len, self.contents.len()));
187
188 let is_datagram_batch = num_segments > 1 && usize_segment_size < contents.len();
189
190 if self.contents.len() <= usize_segment_size {
193 self.segment_size = None;
194 }
195
196 Datagrams {
197 ecn: self.ecn,
198 segment_size: is_datagram_batch.then_some(segment_size),
199 contents,
200 }
201 }
202
203 fn write_to<O: BufMut>(&self, mut dst: O) -> O {
204 let ecn = self.ecn.map_or(0, |ecn| ecn as u8);
205 dst.put_u8(ecn);
206 if let Some(segment_size) = self.segment_size {
207 dst.put_u16(segment_size.into());
208 }
209 dst.put(self.contents.as_ref());
210 dst
211 }
212
213 fn encoded_len(&self) -> usize {
214 1 + self.segment_size.map_or(0, |_| 2) + self.contents.len()
217 }
218
219 #[allow(clippy::len_zero, clippy::result_large_err)]
220 fn from_bytes(mut bytes: Bytes, is_batch: bool) -> Result<Self, Error> {
221 if is_batch {
222 ensure!(bytes.len() >= 3, Error::InvalidFrame);
224 } else {
225 ensure!(bytes.len() >= 1, Error::InvalidFrame);
226 }
227
228 let ecn_byte = bytes.get_u8();
229 let ecn = quinn_proto::EcnCodepoint::from_bits(ecn_byte);
230
231 let segment_size = if is_batch {
232 let segment_size = bytes.get_u16(); NonZeroU16::new(segment_size)
234 } else {
235 None
236 };
237
238 Ok(Self {
239 ecn,
240 segment_size,
241 contents: bytes,
242 })
243 }
244}
245
246impl RelayToClientMsg {
247 pub fn typ(&self) -> FrameType {
249 match self {
250 Self::Datagrams { datagrams, .. } => {
251 if datagrams.segment_size.is_some() {
252 FrameType::RelayToClientDatagramBatch
253 } else {
254 FrameType::RelayToClientDatagram
255 }
256 }
257 Self::EndpointGone { .. } => FrameType::EndpointGone,
258 Self::Ping { .. } => FrameType::Ping,
259 Self::Pong { .. } => FrameType::Pong,
260 Self::Health { .. } => FrameType::Health,
261 Self::Restarting { .. } => FrameType::Restarting,
262 }
263 }
264
265 #[cfg(feature = "server")]
266 pub(crate) fn to_bytes(&self) -> BytesMut {
267 self.write_to(BytesMut::with_capacity(self.encoded_len()))
268 }
269
270 #[cfg(feature = "server")]
274 pub(crate) fn write_to<O: BufMut>(&self, mut dst: O) -> O {
275 dst = self.typ().write_to(dst);
276 match self {
277 Self::Datagrams {
278 remote_endpoint_id,
279 datagrams,
280 } => {
281 dst.put(remote_endpoint_id.as_ref());
282 dst = datagrams.write_to(dst);
283 }
284 Self::EndpointGone(endpoint_id) => {
285 dst.put(endpoint_id.as_ref());
286 }
287 Self::Ping(data) => {
288 dst.put(&data[..]);
289 }
290 Self::Pong(data) => {
291 dst.put(&data[..]);
292 }
293 Self::Health { problem } => {
294 dst.put(problem.as_ref());
295 }
296 Self::Restarting {
297 reconnect_in,
298 try_for,
299 } => {
300 dst.put_u32(reconnect_in.as_millis() as u32);
301 dst.put_u32(try_for.as_millis() as u32);
302 }
303 }
304 dst
305 }
306
307 #[cfg(feature = "server")]
308 pub(crate) fn encoded_len(&self) -> usize {
309 let payload_len = match self {
310 Self::Datagrams { datagrams, .. } => {
311 32 + datagrams.encoded_len()
313 }
314 Self::EndpointGone(_) => 32,
315 Self::Ping(_) | Self::Pong(_) => 8,
316 Self::Health { problem } => problem.len(),
317 Self::Restarting { .. } => {
318 4 + 4 }
321 };
322 self.typ().encoded_len() + payload_len
323 }
324
325 #[allow(clippy::result_large_err)]
329 pub(crate) fn from_bytes(mut content: Bytes, cache: &KeyCache) -> Result<Self, Error> {
330 let frame_type = FrameType::from_bytes(&mut content)?;
331 let frame_len = content.len();
332 ensure!(
333 frame_len <= MAX_PACKET_SIZE,
334 Error::FrameTooLarge { frame_len }
335 );
336
337 let res = match frame_type {
338 FrameType::RelayToClientDatagram | FrameType::RelayToClientDatagramBatch => {
339 ensure!(content.len() >= EndpointId::LENGTH, Error::InvalidFrame);
340
341 let remote_endpoint_id = cache.key_from_slice(&content[..EndpointId::LENGTH])?;
342 let datagrams = Datagrams::from_bytes(
343 content.slice(EndpointId::LENGTH..),
344 frame_type == FrameType::RelayToClientDatagramBatch,
345 )?;
346 Self::Datagrams {
347 remote_endpoint_id,
348 datagrams,
349 }
350 }
351 FrameType::EndpointGone => {
352 ensure!(content.len() == EndpointId::LENGTH, Error::InvalidFrame);
353 let endpoint_id = cache.key_from_slice(content.as_ref())?;
354 Self::EndpointGone(endpoint_id)
355 }
356 FrameType::Ping => {
357 ensure!(content.len() == 8, Error::InvalidFrame);
358 let mut data = [0u8; 8];
359 data.copy_from_slice(&content[..8]);
360 Self::Ping(data)
361 }
362 FrameType::Pong => {
363 ensure!(content.len() == 8, Error::InvalidFrame);
364 let mut data = [0u8; 8];
365 data.copy_from_slice(&content[..8]);
366 Self::Pong(data)
367 }
368 FrameType::Health => {
369 let problem = std::str::from_utf8(&content)?.to_owned();
370 Self::Health { problem }
371 }
372 FrameType::Restarting => {
373 ensure!(content.len() == 4 + 4, Error::InvalidFrame);
374 let reconnect_in = u32::from_be_bytes(
375 content[..4]
376 .try_into()
377 .map_err(|_| e!(Error::InvalidFrame))?,
378 );
379 let try_for = u32::from_be_bytes(
380 content[4..]
381 .try_into()
382 .map_err(|_| e!(Error::InvalidFrame))?,
383 );
384 let reconnect_in = Duration::from_millis(reconnect_in as u64);
385 let try_for = Duration::from_millis(try_for as u64);
386 Self::Restarting {
387 reconnect_in,
388 try_for,
389 }
390 }
391 _ => {
392 return Err(e!(Error::InvalidFrameType { frame_type }));
393 }
394 };
395 Ok(res)
396 }
397}
398
399impl ClientToRelayMsg {
400 pub(crate) fn typ(&self) -> FrameType {
401 match self {
402 Self::Datagrams { datagrams, .. } => {
403 if datagrams.segment_size.is_some() {
404 FrameType::ClientToRelayDatagramBatch
405 } else {
406 FrameType::ClientToRelayDatagram
407 }
408 }
409 Self::Ping { .. } => FrameType::Ping,
410 Self::Pong { .. } => FrameType::Pong,
411 }
412 }
413
414 pub(crate) fn to_bytes(&self) -> BytesMut {
415 self.write_to(BytesMut::with_capacity(self.encoded_len()))
416 }
417
418 pub(crate) fn write_to<O: BufMut>(&self, mut dst: O) -> O {
422 dst = self.typ().write_to(dst);
423 match self {
424 Self::Datagrams {
425 dst_endpoint_id,
426 datagrams,
427 } => {
428 dst.put(dst_endpoint_id.as_ref());
429 dst = datagrams.write_to(dst);
430 }
431 Self::Ping(data) => {
432 dst.put(&data[..]);
433 }
434 Self::Pong(data) => {
435 dst.put(&data[..]);
436 }
437 }
438 dst
439 }
440
441 pub(crate) fn encoded_len(&self) -> usize {
442 let payload_len = match self {
443 Self::Ping(_) | Self::Pong(_) => 8,
444 Self::Datagrams { datagrams, .. } => {
445 32 + datagrams.encoded_len()
447 }
448 };
449 self.typ().encoded_len() + payload_len
450 }
451
452 #[allow(clippy::result_large_err)]
456 #[cfg(feature = "server")]
457 pub(crate) fn from_bytes(mut content: Bytes, cache: &KeyCache) -> Result<Self, Error> {
458 let frame_type = FrameType::from_bytes(&mut content)?;
459 let frame_len = content.len();
460 ensure!(
461 frame_len <= MAX_PACKET_SIZE,
462 Error::FrameTooLarge { frame_len }
463 );
464
465 let res = match frame_type {
466 FrameType::ClientToRelayDatagram | FrameType::ClientToRelayDatagramBatch => {
467 let dst_endpoint_id = cache.key_from_slice(&content[..EndpointId::LENGTH])?;
468 let datagrams = Datagrams::from_bytes(
469 content.slice(EndpointId::LENGTH..),
470 frame_type == FrameType::ClientToRelayDatagramBatch,
471 )?;
472 Self::Datagrams {
473 dst_endpoint_id,
474 datagrams,
475 }
476 }
477 FrameType::Ping => {
478 ensure!(content.len() == 8, Error::InvalidFrame);
479 let mut data = [0u8; 8];
480 data.copy_from_slice(&content[..8]);
481 Self::Ping(data)
482 }
483 FrameType::Pong => {
484 ensure!(content.len() == 8, Error::InvalidFrame);
485 let mut data = [0u8; 8];
486 data.copy_from_slice(&content[..8]);
487 Self::Pong(data)
488 }
489 _ => {
490 return Err(e!(Error::InvalidFrameType { frame_type }));
491 }
492 };
493 Ok(res)
494 }
495}
496
/// Snapshot tests pinning the exact wire encoding of each message variant.
#[cfg(test)]
#[cfg(feature = "server")]
mod tests {
    use data_encoding::HEXLOWER;
    use iroh_base::SecretKey;
    use n0_error::Result;

    use super::*;

    /// Asserts that each encoded frame equals its expected lowercase-hex snapshot.
    ///
    /// ASCII whitespace in the expected hex is ignored, so snapshots can be laid out over
    /// several lines. Comparison is done on hex strings so failures print readable diffs.
    fn check_expected_bytes(frames: Vec<(Vec<u8>, &str)>) {
        for (bytes, expected_hex) in frames {
            let stripped: Vec<u8> = expected_hex
                .chars()
                .filter_map(|s| {
                    if s.is_ascii_whitespace() {
                        None
                    } else {
                        Some(s as u8)
                    }
                })
                .collect();
            let expected_bytes = HEXLOWER.decode(&stripped).unwrap();
            assert_eq!(HEXLOWER.encode(&bytes), HEXLOWER.encode(&expected_bytes));
        }
    }

    /// Relay-to-client frames: frame type byte followed by the variant's payload.
    #[test]
    fn test_server_client_frames_snapshot() -> Result {
        // Fixed key so the encoded endpoint id bytes are stable.
        let client_key = SecretKey::from_bytes(&[42u8; 32]);

        check_expected_bytes(vec![
            (
                RelayToClientMsg::Health {
                    problem: "Hello? Yes this is dog.".into(),
                }
                .write_to(Vec::new()),
                "0b 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73
                20 69 73 20 64 6f 67 2e",
            ),
            (
                RelayToClientMsg::EndpointGone(client_key.public()).write_to(Vec::new()),
                "08 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e
                a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d
                61",
            ),
            (
                RelayToClientMsg::Ping([42u8; 8]).write_to(Vec::new()),
                "09 2a 2a 2a 2a 2a 2a 2a 2a",
            ),
            (
                RelayToClientMsg::Pong([42u8; 8]).write_to(Vec::new()),
                "0a 2a 2a 2a 2a 2a 2a 2a 2a",
            ),
            // Batch frame: type, endpoint id, ECN byte (Ce = 03), u16 segment size, contents.
            (
                RelayToClientMsg::Datagrams {
                    remote_endpoint_id: client_key.public(),
                    datagrams: Datagrams {
                        ecn: Some(quinn::EcnCodepoint::Ce),
                        segment_size: NonZeroU16::new(6),
                        contents: "Hello World!".into(),
                    },
                }
                .write_to(Vec::new()),
                "07
                19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7
                89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61
                03
                00 06
                48 65 6c 6c 6f 20 57 6f 72 6c 64 21",
            ),
            // Non-batch frame: same layout without the segment size.
            (
                RelayToClientMsg::Datagrams {
                    remote_endpoint_id: client_key.public(),
                    datagrams: Datagrams {
                        ecn: Some(quinn::EcnCodepoint::Ce),
                        segment_size: None,
                        contents: "Hello World!".into(),
                    },
                }
                .write_to(Vec::new()),
                "06
                19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7
                89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61
                03
                48 65 6c 6c 6f 20 57 6f 72 6c 64 21",
            ),
            // Two big-endian u32 millisecond counts: 10 and 20.
            (
                RelayToClientMsg::Restarting {
                    reconnect_in: Duration::from_millis(10),
                    try_for: Duration::from_millis(20),
                }
                .write_to(Vec::new()),
                "0c 00 00 00 0a 00 00 00 14",
            ),
        ]);

        Ok(())
    }

    /// Client-to-relay frames: frame type byte followed by the variant's payload.
    #[test]
    fn test_client_server_frames_snapshot() -> Result {
        // Fixed key so the encoded endpoint id bytes are stable.
        let client_key = SecretKey::from_bytes(&[42u8; 32]);

        check_expected_bytes(vec![
            (
                ClientToRelayMsg::Ping([42u8; 8]).write_to(Vec::new()),
                "09 2a 2a 2a 2a 2a 2a 2a 2a",
            ),
            (
                ClientToRelayMsg::Pong([42u8; 8]).write_to(Vec::new()),
                "0a 2a 2a 2a 2a 2a 2a 2a 2a",
            ),
            // Batch frame: type, endpoint id, ECN byte (Ce = 03), u16 segment size, contents.
            (
                ClientToRelayMsg::Datagrams {
                    dst_endpoint_id: client_key.public(),
                    datagrams: Datagrams {
                        ecn: Some(quinn::EcnCodepoint::Ce),
                        segment_size: NonZeroU16::new(6),
                        contents: "Hello World!".into(),
                    },
                }
                .write_to(Vec::new()),
                "05
                19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7
                89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61
                03
                00 06
                48 65 6c 6c 6f 20 57 6f 72 6c 64 21",
            ),
            // Non-batch frame: same layout without the segment size.
            (
                ClientToRelayMsg::Datagrams {
                    dst_endpoint_id: client_key.public(),
                    datagrams: Datagrams {
                        ecn: Some(quinn::EcnCodepoint::Ce),
                        segment_size: None,
                        contents: "Hello World!".into(),
                    },
                }
                .write_to(Vec::new()),
                "04
                19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7
                89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61
                03
                48 65 6c 6c 6f 20 57 6f 72 6c 64 21",
            ),
        ]);

        Ok(())
    }
}
669
/// Property tests: encode/decode roundtrips and `encoded_len` agreement for random messages.
#[cfg(all(test, feature = "server"))]
mod proptests {
    use iroh_base::SecretKey;
    use proptest::prelude::*;

    use super::*;

    /// Secret keys built from 32 arbitrary bytes.
    fn secret_key() -> impl Strategy<Value = SecretKey> {
        prop::array::uniform32(any::<u8>()).prop_map(SecretKey::from)
    }

    /// Endpoint ids taken as the public half of an arbitrary secret key.
    fn key() -> impl Strategy<Value = EndpointId> {
        secret_key().prop_map(|key| key.public())
    }

    /// Every ECN codepoint plus `None`.
    fn ecn() -> impl Strategy<Value = Option<quinn_proto::EcnCodepoint>> {
        (0..=3).prop_map(|n| match n {
            1 => Some(quinn_proto::EcnCodepoint::Ce),
            2 => Some(quinn_proto::EcnCodepoint::Ect0),
            3 => Some(quinn_proto::EcnCodepoint::Ect1),
            _ => None,
        })
    }

    /// Datagrams whose encoded frame payload stays within `MAX_PACKET_SIZE`.
    fn datagrams() -> impl Strategy<Value = Datagrams> {
        // Reserve room for the endpoint id, the ECN byte (1) and the u16 segment size (2).
        const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - EndpointId::LENGTH - 1 - 2;
        (
            ecn(),
            prop::option::of(MAX_PAYLOAD_SIZE / 20..MAX_PAYLOAD_SIZE),
            prop::collection::vec(any::<u8>(), 0..MAX_PAYLOAD_SIZE),
        )
            .prop_map(|(ecn, segment_size, data)| Datagrams {
                ecn,
                // Clamp the segment size to the data length. Values stay below
                // MAX_PAYLOAD_SIZE (< u16::MAX), so the cast cannot truncate. Empty data
                // yields 0, which `NonZeroU16::new` turns into `None` (not a batch).
                segment_size: segment_size
                    .map(|ss| std::cmp::min(data.len(), ss) as u16)
                    .and_then(NonZeroU16::new),
                contents: Bytes::from(data),
            })
    }

    /// Any relay-to-client message.
    fn server_client_frame() -> impl Strategy<Value = RelayToClientMsg> {
        let recv_packet = (key(), datagrams()).prop_map(|(remote_endpoint_id, datagrams)| {
            RelayToClientMsg::Datagrams {
                remote_endpoint_id,
                datagrams,
            }
        });
        let endpoint_gone = key().prop_map(RelayToClientMsg::EndpointGone);
        let ping = prop::array::uniform8(any::<u8>()).prop_map(RelayToClientMsg::Ping);
        let pong = prop::array::uniform8(any::<u8>()).prop_map(RelayToClientMsg::Pong);
        let health = ".{0,65536}"
            .prop_filter("exceeds MAX_PACKET_SIZE", |s| {
                // `.` may generate multi-byte chars, so filter on the UTF-8 byte length.
                s.len() < MAX_PACKET_SIZE
            })
            .prop_map(|problem| RelayToClientMsg::Health { problem });
        // Durations are generated from u32 milliseconds, the range the wire format carries.
        let restarting = (any::<u32>(), any::<u32>()).prop_map(|(reconnect_in, try_for)| {
            RelayToClientMsg::Restarting {
                reconnect_in: Duration::from_millis(reconnect_in.into()),
                try_for: Duration::from_millis(try_for.into()),
            }
        });
        prop_oneof![recv_packet, endpoint_gone, ping, pong, health, restarting]
    }

    /// Any client-to-relay message.
    fn client_server_frame() -> impl Strategy<Value = ClientToRelayMsg> {
        let send_packet = (key(), datagrams()).prop_map(|(dst_endpoint_id, datagrams)| {
            ClientToRelayMsg::Datagrams {
                dst_endpoint_id,
                datagrams,
            }
        });
        let ping = prop::array::uniform8(any::<u8>()).prop_map(ClientToRelayMsg::Ping);
        let pong = prop::array::uniform8(any::<u8>()).prop_map(ClientToRelayMsg::Pong);
        prop_oneof![send_packet, ping, pong]
    }

    proptest! {
        /// Encoding then decoding a relay-to-client message yields the original.
        #[test]
        fn server_client_frame_roundtrip(frame in server_client_frame()) {
            let encoded = frame.to_bytes().freeze();
            let decoded = RelayToClientMsg::from_bytes(encoded, &KeyCache::test()).unwrap();
            prop_assert_eq!(frame, decoded);
        }

        /// Encoding then decoding a client-to-relay message yields the original.
        #[test]
        fn client_server_frame_roundtrip(frame in client_server_frame()) {
            let encoded = frame.to_bytes().freeze();
            let decoded = ClientToRelayMsg::from_bytes(encoded, &KeyCache::test()).unwrap();
            prop_assert_eq!(frame, decoded);
        }

        /// `encoded_len` matches the actual encoded size (relay-to-client).
        #[test]
        fn server_client_frame_encoded_len(frame in server_client_frame()) {
            let claimed_encoded_len = frame.encoded_len();
            let actual_encoded_len = frame.to_bytes().len();
            prop_assert_eq!(claimed_encoded_len, actual_encoded_len);
        }

        /// `encoded_len` matches the actual encoded size (client-to-relay).
        #[test]
        fn client_server_frame_encoded_len(frame in client_server_frame()) {
            let claimed_encoded_len = frame.encoded_len();
            let actual_encoded_len = frame.to_bytes().len();
            prop_assert_eq!(claimed_encoded_len, actual_encoded_len);
        }

        /// `Datagrams::encoded_len` matches the bytes written by `Datagrams::write_to`.
        #[test]
        fn datagrams_encoded_len(datagrams in datagrams()) {
            let claimed_encoded_len = datagrams.encoded_len();
            let actual_encoded_len = datagrams.write_to(Vec::new()).len();
            prop_assert_eq!(claimed_encoded_len, actual_encoded_len);
        }
    }
}