iroh_relay/protos/
relay.rs

1//! This module implements the send/recv relaying protocol.
2//!
3//! Protocol flow:
4//!  * server occasionally sends [`FrameType::Ping`]
5//!  * client responds to any [`FrameType::Ping`] with a [`FrameType::Pong`]
//!  * client sends [`FrameType::ClientToRelayDatagram`] or [`FrameType::ClientToRelayDatagramBatch`]
//!  * server then sends [`FrameType::RelayToClientDatagram`] or [`FrameType::RelayToClientDatagramBatch`] to the recipient
8//!  * server sends [`FrameType::EndpointGone`] when the other client disconnects
9
10use std::num::NonZeroU16;
11
12use bytes::{Buf, BufMut, Bytes, BytesMut};
13use iroh_base::{EndpointId, KeyParsingError};
14use n0_error::{e, ensure, stack_error};
15use n0_future::time::Duration;
16
17use super::common::{FrameType, FrameTypeError};
18use crate::KeyCache;
19
/// Upper bound on the size of a single packet relayed on behalf of a client.
///
/// Only the datagram bytes visible to magicsock are counted; the relay's own
/// on-wire framing overhead is not included.
pub const MAX_PACKET_SIZE: usize = 1 << 16;

/// Upper bound on the size of a whole frame.
///
/// A rate-limiter has to accept bursts of at least this many bytes.
#[cfg(not(wasm_browser))]
pub(crate) const MAX_FRAME_SIZE: usize = 1 << 20;

/// How often a ping is sent over a relay connection to check that it is still alive.
///
/// QUIC's default `max_idle_timeout` is 30 seconds; pinging at half of that leaves
/// some room to recover before the connection would be considered idle.
#[cfg(feature = "server")]
pub(crate) const PING_INTERVAL: Duration = Duration::from_millis(15_000);

/// How many packets may be buffered for sending to each individual client.
#[cfg(feature = "server")]
pub(crate) const PER_CLIENT_SEND_QUEUE_DEPTH: usize = 512;
41
/// Errors produced by the relay protocol, most notably when decoding a received frame fails.
#[stack_error(derive, add_meta, from_sources)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum Error {
    /// A frame of type `got` arrived where a frame of type `expected` was required.
    #[error("unexpected frame: got {got:?}, expected {expected:?}")]
    UnexpectedFrame { got: FrameType, expected: FrameType },
    /// The frame payload (everything after the frame type) is larger than [`MAX_PACKET_SIZE`].
    #[error("Frame is too large, has {frame_len} bytes")]
    FrameTooLarge { frame_len: usize },
    /// A `postcard` (de)serialization error.
    #[error(transparent)]
    SerDe {
        #[error(std_err)]
        source: postcard::Error,
    },
    /// Parsing the frame type from the start of the frame failed.
    #[error(transparent)]
    FrameTypeError { source: FrameTypeError },
    /// An endpoint id contained in the frame is not a valid public key.
    #[error("Invalid public key")]
    InvalidPublicKey { source: KeyParsingError },
    /// The frame payload has the wrong length or layout for its frame type.
    #[error("Invalid frame encoding")]
    InvalidFrame {},
    /// The frame type was parsed but is not accepted in this direction of the protocol.
    #[error("Invalid frame type: {frame_type:?}")]
    InvalidFrameType { frame_type: FrameType },
    /// A text payload (e.g. the health problem description) is not valid UTF-8.
    #[error("Invalid protocol message encoding")]
    InvalidProtocolMessageEncoding {
        #[error(std_err)]
        source: std::str::Utf8Error,
    },
    /// The input was shorter than required.
    #[error("Too few bytes")]
    TooSmall {},
}
72
/// The messages that a relay sends to clients or the clients receive from the relay.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RelayToClientMsg {
    /// Represents datagrams sent from relays (originally sent to them by another client).
    Datagrams {
        /// The [`EndpointId`] of the original sender.
        remote_endpoint_id: EndpointId,
        /// The datagrams and related metadata.
        datagrams: Datagrams,
    },
    /// Indicates that the client identified by the underlying public key had previously sent you a
    /// packet but has now disconnected from the relay.
    EndpointGone(EndpointId),
    /// A one-way message from relay to client, declaring the connection health state.
    Health {
        /// A description of why the connection is unhealthy.
        ///
        /// The default condition is healthy, so the relay doesn't broadcast a [`RelayToClientMsg::Health`]
        /// until a problem exists.
        ///
        /// On the wire this is the raw UTF-8 text filling the rest of the frame.
        problem: String,
    },
    /// A one-way message from relay to client, advertising that the relay is restarting.
    ///
    /// Both durations are encoded on the wire as big-endian `u32` milliseconds.
    Restarting {
        /// An advisory duration that the client should wait before attempting to reconnect.
        /// It might be zero. It exists for the relay to smear out the reconnects.
        reconnect_in: Duration,
        /// An advisory duration for how long the client should attempt to reconnect
        /// before giving up and proceeding with its normal connection failure logic. The interval
        /// between retries is undefined for now. A relay should not send a `try_for` duration more
        /// than a few seconds.
        try_for: Duration,
    },
    /// Request from the relay to reply to the
    /// other side with a [`ClientToRelayMsg::Pong`] with the given payload.
    Ping([u8; 8]),
    /// Reply to a [`ClientToRelayMsg::Ping`] from a client
    /// with the payload sent previously in the ping.
    Pong([u8; 8]),
}
114
/// Messages that clients send to relays.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ClientToRelayMsg {
    /// Request from the client to the server to reply to the
    /// other side with a [`RelayToClientMsg::Pong`] with the given payload.
    Ping([u8; 8]),
    /// Reply to a [`RelayToClientMsg::Ping`] from a server
    /// with the payload sent previously in the ping.
    Pong([u8; 8]),
    /// Request from the client to relay datagrams to given remote endpoint.
    ///
    /// On the wire: the 32-byte destination endpoint id followed by the encoded [`Datagrams`].
    Datagrams {
        /// The remote endpoint to relay to.
        dst_endpoint_id: EndpointId,
        /// The datagrams and related metadata to relay.
        datagrams: Datagrams,
    },
}
132
/// One or multiple datagrams being transferred via the relay.
///
/// This type is modeled after [`quinn_proto::Transmit`]
/// (or even more similarly `quinn_udp::Transmit`, but we don't depend on that library here).
///
/// When `segment_size` is set, `contents` holds several datagrams back to back, each
/// `segment_size` bytes long except possibly the last (see [`Datagrams::take_segments`]).
#[derive(derive_more::Debug, Clone, PartialEq, Eq)]
pub struct Datagrams {
    /// Explicit congestion notification bits
    pub ecn: Option<quinn_proto::EcnCodepoint>,
    /// The segment size if this transmission contains multiple datagrams.
    /// This is `None` if the transmit only contains a single datagram
    pub segment_size: Option<NonZeroU16>,
    /// The contents of the datagram(s)
    // Skipped in `Debug` output to keep logs readable for large payloads.
    #[debug(skip)]
    pub contents: Bytes,
}
148
149impl<T: AsRef<[u8]>> From<T> for Datagrams {
150    fn from(bytes: T) -> Self {
151        Self {
152            ecn: None,
153            segment_size: None,
154            contents: Bytes::copy_from_slice(bytes.as_ref()),
155        }
156    }
157}
158
159impl Datagrams {
160    /// Splits the current datagram into at maximum `num_segments` segments, returning
161    /// the batch with at most `num_segments` and leaving only the rest in `self`.
162    ///
163    /// Calling this on a datagram batch that only contains a single datagram (`segment_size == None`)
164    /// will result in returning essentially a clone of `self`, while making `self` empty afterwards.
165    ///
166    /// Calling this on a datagram batch with e.g. 15 datagrams with `num_segments == 10` will
167    /// result in returning a datagram batch that contains the first 10 datagrams and leave `self`
168    /// containing the remaining 5 datagrams.
169    ///
170    /// Calling this on a datagram batch with less than `num_segments` datagrams will result in
171    /// making `self` empty and returning essentially a clone of `self`.
172    pub fn take_segments(&mut self, num_segments: usize) -> Datagrams {
173        let Some(segment_size) = self.segment_size else {
174            let contents = std::mem::take(&mut self.contents);
175            return Datagrams {
176                ecn: self.ecn,
177                segment_size: None,
178                contents,
179            };
180        };
181
182        let usize_segment_size = usize::from(u16::from(segment_size));
183        let max_content_len = num_segments * usize_segment_size;
184        let contents = self
185            .contents
186            .split_to(std::cmp::min(max_content_len, self.contents.len()));
187
188        let is_datagram_batch = num_segments > 1 && usize_segment_size < contents.len();
189
190        // If this left our batch with only one more datagram, then remove the segment size
191        // to uphold the invariant that single-datagram batches don't have a segment size set.
192        if self.contents.len() <= usize_segment_size {
193            self.segment_size = None;
194        }
195
196        Datagrams {
197            ecn: self.ecn,
198            segment_size: is_datagram_batch.then_some(segment_size),
199            contents,
200        }
201    }
202
203    fn write_to<O: BufMut>(&self, mut dst: O) -> O {
204        let ecn = self.ecn.map_or(0, |ecn| ecn as u8);
205        dst.put_u8(ecn);
206        if let Some(segment_size) = self.segment_size {
207            dst.put_u16(segment_size.into());
208        }
209        dst.put(self.contents.as_ref());
210        dst
211    }
212
213    fn encoded_len(&self) -> usize {
214        1 // ECN byte
215        + self.segment_size.map_or(0, |_| 2) // segment size, when None, then a packed representation is assumed
216        + self.contents.len()
217    }
218
219    #[allow(clippy::len_zero, clippy::result_large_err)]
220    fn from_bytes(mut bytes: Bytes, is_batch: bool) -> Result<Self, Error> {
221        if is_batch {
222            // 1 bytes ECN, 2 bytes segment size
223            ensure!(bytes.len() >= 3, Error::InvalidFrame);
224        } else {
225            ensure!(bytes.len() >= 1, Error::InvalidFrame);
226        }
227
228        let ecn_byte = bytes.get_u8();
229        let ecn = quinn_proto::EcnCodepoint::from_bits(ecn_byte);
230
231        let segment_size = if is_batch {
232            let segment_size = bytes.get_u16(); // length checked above
233            NonZeroU16::new(segment_size)
234        } else {
235            None
236        };
237
238        Ok(Self {
239            ecn,
240            segment_size,
241            contents: bytes,
242        })
243    }
244}
245
246impl RelayToClientMsg {
247    /// Returns this frame's corresponding frame type.
248    pub fn typ(&self) -> FrameType {
249        match self {
250            Self::Datagrams { datagrams, .. } => {
251                if datagrams.segment_size.is_some() {
252                    FrameType::RelayToClientDatagramBatch
253                } else {
254                    FrameType::RelayToClientDatagram
255                }
256            }
257            Self::EndpointGone { .. } => FrameType::EndpointGone,
258            Self::Ping { .. } => FrameType::Ping,
259            Self::Pong { .. } => FrameType::Pong,
260            Self::Health { .. } => FrameType::Health,
261            Self::Restarting { .. } => FrameType::Restarting,
262        }
263    }
264
265    #[cfg(feature = "server")]
266    pub(crate) fn to_bytes(&self) -> BytesMut {
267        self.write_to(BytesMut::with_capacity(self.encoded_len()))
268    }
269
270    /// Encodes this frame for sending over websockets.
271    ///
272    /// Specifically meant for being put into a binary websocket message frame.
273    #[cfg(feature = "server")]
274    pub(crate) fn write_to<O: BufMut>(&self, mut dst: O) -> O {
275        dst = self.typ().write_to(dst);
276        match self {
277            Self::Datagrams {
278                remote_endpoint_id,
279                datagrams,
280            } => {
281                dst.put(remote_endpoint_id.as_ref());
282                dst = datagrams.write_to(dst);
283            }
284            Self::EndpointGone(endpoint_id) => {
285                dst.put(endpoint_id.as_ref());
286            }
287            Self::Ping(data) => {
288                dst.put(&data[..]);
289            }
290            Self::Pong(data) => {
291                dst.put(&data[..]);
292            }
293            Self::Health { problem } => {
294                dst.put(problem.as_ref());
295            }
296            Self::Restarting {
297                reconnect_in,
298                try_for,
299            } => {
300                dst.put_u32(reconnect_in.as_millis() as u32);
301                dst.put_u32(try_for.as_millis() as u32);
302            }
303        }
304        dst
305    }
306
307    #[cfg(feature = "server")]
308    pub(crate) fn encoded_len(&self) -> usize {
309        let payload_len = match self {
310            Self::Datagrams { datagrams, .. } => {
311                32 // endpointid
312                + datagrams.encoded_len()
313            }
314            Self::EndpointGone(_) => 32,
315            Self::Ping(_) | Self::Pong(_) => 8,
316            Self::Health { problem } => problem.len(),
317            Self::Restarting { .. } => {
318                4 // u32
319                + 4 // u32
320            }
321        };
322        self.typ().encoded_len() + payload_len
323    }
324
325    /// Tries to decode a frame received over websockets.
326    ///
327    /// Specifically, bytes received from a binary websocket message frame.
328    #[allow(clippy::result_large_err)]
329    pub(crate) fn from_bytes(mut content: Bytes, cache: &KeyCache) -> Result<Self, Error> {
330        let frame_type = FrameType::from_bytes(&mut content)?;
331        let frame_len = content.len();
332        ensure!(
333            frame_len <= MAX_PACKET_SIZE,
334            Error::FrameTooLarge { frame_len }
335        );
336
337        let res = match frame_type {
338            FrameType::RelayToClientDatagram | FrameType::RelayToClientDatagramBatch => {
339                ensure!(content.len() >= EndpointId::LENGTH, Error::InvalidFrame);
340
341                let remote_endpoint_id = cache.key_from_slice(&content[..EndpointId::LENGTH])?;
342                let datagrams = Datagrams::from_bytes(
343                    content.slice(EndpointId::LENGTH..),
344                    frame_type == FrameType::RelayToClientDatagramBatch,
345                )?;
346                Self::Datagrams {
347                    remote_endpoint_id,
348                    datagrams,
349                }
350            }
351            FrameType::EndpointGone => {
352                ensure!(content.len() == EndpointId::LENGTH, Error::InvalidFrame);
353                let endpoint_id = cache.key_from_slice(content.as_ref())?;
354                Self::EndpointGone(endpoint_id)
355            }
356            FrameType::Ping => {
357                ensure!(content.len() == 8, Error::InvalidFrame);
358                let mut data = [0u8; 8];
359                data.copy_from_slice(&content[..8]);
360                Self::Ping(data)
361            }
362            FrameType::Pong => {
363                ensure!(content.len() == 8, Error::InvalidFrame);
364                let mut data = [0u8; 8];
365                data.copy_from_slice(&content[..8]);
366                Self::Pong(data)
367            }
368            FrameType::Health => {
369                let problem = std::str::from_utf8(&content)?.to_owned();
370                Self::Health { problem }
371            }
372            FrameType::Restarting => {
373                ensure!(content.len() == 4 + 4, Error::InvalidFrame);
374                let reconnect_in = u32::from_be_bytes(
375                    content[..4]
376                        .try_into()
377                        .map_err(|_| e!(Error::InvalidFrame))?,
378                );
379                let try_for = u32::from_be_bytes(
380                    content[4..]
381                        .try_into()
382                        .map_err(|_| e!(Error::InvalidFrame))?,
383                );
384                let reconnect_in = Duration::from_millis(reconnect_in as u64);
385                let try_for = Duration::from_millis(try_for as u64);
386                Self::Restarting {
387                    reconnect_in,
388                    try_for,
389                }
390            }
391            _ => {
392                return Err(e!(Error::InvalidFrameType { frame_type }));
393            }
394        };
395        Ok(res)
396    }
397}
398
399impl ClientToRelayMsg {
400    pub(crate) fn typ(&self) -> FrameType {
401        match self {
402            Self::Datagrams { datagrams, .. } => {
403                if datagrams.segment_size.is_some() {
404                    FrameType::ClientToRelayDatagramBatch
405                } else {
406                    FrameType::ClientToRelayDatagram
407                }
408            }
409            Self::Ping { .. } => FrameType::Ping,
410            Self::Pong { .. } => FrameType::Pong,
411        }
412    }
413
414    pub(crate) fn to_bytes(&self) -> BytesMut {
415        self.write_to(BytesMut::with_capacity(self.encoded_len()))
416    }
417
418    /// Encodes this frame for sending over websockets.
419    ///
420    /// Specifically meant for being put into a binary websocket message frame.
421    pub(crate) fn write_to<O: BufMut>(&self, mut dst: O) -> O {
422        dst = self.typ().write_to(dst);
423        match self {
424            Self::Datagrams {
425                dst_endpoint_id,
426                datagrams,
427            } => {
428                dst.put(dst_endpoint_id.as_ref());
429                dst = datagrams.write_to(dst);
430            }
431            Self::Ping(data) => {
432                dst.put(&data[..]);
433            }
434            Self::Pong(data) => {
435                dst.put(&data[..]);
436            }
437        }
438        dst
439    }
440
441    pub(crate) fn encoded_len(&self) -> usize {
442        let payload_len = match self {
443            Self::Ping(_) | Self::Pong(_) => 8,
444            Self::Datagrams { datagrams, .. } => {
445                32 // endpoint id
446                + datagrams.encoded_len()
447            }
448        };
449        self.typ().encoded_len() + payload_len
450    }
451
452    /// Tries to decode a frame received over websockets.
453    ///
454    /// Specifically, bytes received from a binary websocket message frame.
455    #[allow(clippy::result_large_err)]
456    #[cfg(feature = "server")]
457    pub(crate) fn from_bytes(mut content: Bytes, cache: &KeyCache) -> Result<Self, Error> {
458        let frame_type = FrameType::from_bytes(&mut content)?;
459        let frame_len = content.len();
460        ensure!(
461            frame_len <= MAX_PACKET_SIZE,
462            Error::FrameTooLarge { frame_len }
463        );
464
465        let res = match frame_type {
466            FrameType::ClientToRelayDatagram | FrameType::ClientToRelayDatagramBatch => {
467                let dst_endpoint_id = cache.key_from_slice(&content[..EndpointId::LENGTH])?;
468                let datagrams = Datagrams::from_bytes(
469                    content.slice(EndpointId::LENGTH..),
470                    frame_type == FrameType::ClientToRelayDatagramBatch,
471                )?;
472                Self::Datagrams {
473                    dst_endpoint_id,
474                    datagrams,
475                }
476            }
477            FrameType::Ping => {
478                ensure!(content.len() == 8, Error::InvalidFrame);
479                let mut data = [0u8; 8];
480                data.copy_from_slice(&content[..8]);
481                Self::Ping(data)
482            }
483            FrameType::Pong => {
484                ensure!(content.len() == 8, Error::InvalidFrame);
485                let mut data = [0u8; 8];
486                data.copy_from_slice(&content[..8]);
487                Self::Pong(data)
488            }
489            _ => {
490                return Err(e!(Error::InvalidFrameType { frame_type }));
491            }
492        };
493        Ok(res)
494    }
495}
496
#[cfg(test)]
#[cfg(feature = "server")]
mod tests {
    // Snapshot tests pinning the exact wire encoding of every message variant.
    use data_encoding::HEXLOWER;
    use iroh_base::SecretKey;
    use n0_error::Result;

    use super::*;

    /// Asserts that each encoded frame equals the bytes described by its hex string.
    ///
    /// Whitespace in the expected hex is ignored so the strings can be laid out for
    /// readability. Both sides are compared as lowercase hex so a mismatch prints readably.
    fn check_expected_bytes(frames: Vec<(Vec<u8>, &str)>) {
        for (bytes, expected_hex) in frames {
            let stripped: Vec<u8> = expected_hex
                .chars()
                .filter_map(|s| {
                    if s.is_ascii_whitespace() {
                        None
                    } else {
                        Some(s as u8)
                    }
                })
                .collect();
            let expected_bytes = HEXLOWER.decode(&stripped).unwrap();
            assert_eq!(HEXLOWER.encode(&bytes), HEXLOWER.encode(&expected_bytes));
        }
    }

    /// Pins the encoding of every `RelayToClientMsg` variant.
    #[test]
    fn test_server_client_frames_snapshot() -> Result {
        let client_key = SecretKey::from_bytes(&[42u8; 32]);

        check_expected_bytes(vec![
            (
                RelayToClientMsg::Health {
                    problem: "Hello? Yes this is dog.".into(),
                }
                .write_to(Vec::new()),
                "0b 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73
                20 69 73 20 64 6f 67 2e",
            ),
            (
                RelayToClientMsg::EndpointGone(client_key.public()).write_to(Vec::new()),
                "08 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e
                a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d
                61",
            ),
            (
                RelayToClientMsg::Ping([42u8; 8]).write_to(Vec::new()),
                "09 2a 2a 2a 2a 2a 2a 2a 2a",
            ),
            (
                RelayToClientMsg::Pong([42u8; 8]).write_to(Vec::new()),
                "0a 2a 2a 2a 2a 2a 2a 2a 2a",
            ),
            (
                RelayToClientMsg::Datagrams {
                    remote_endpoint_id: client_key.public(),
                    datagrams: Datagrams {
                        ecn: Some(quinn::EcnCodepoint::Ce),
                        segment_size: NonZeroU16::new(6),
                        contents: "Hello World!".into(),
                    },
                }
                .write_to(Vec::new()),
                // frame type
                // public key first 16 bytes
                // public key second 16 bytes
                // ECN byte
                // segment size
                // hello world contents bytes
                "07
                19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7
                89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61
                03
                00 06
                48 65 6c 6c 6f 20 57 6f 72 6c 64 21",
            ),
            (
                RelayToClientMsg::Datagrams {
                    remote_endpoint_id: client_key.public(),
                    datagrams: Datagrams {
                        ecn: Some(quinn::EcnCodepoint::Ce),
                        segment_size: None,
                        contents: "Hello World!".into(),
                    },
                }
                .write_to(Vec::new()),
                // frame type
                // public key first 16 bytes
                // public key second 16 bytes
                // ECN byte
                // hello world contents bytes
                "06
                19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7
                89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61
                03
                48 65 6c 6c 6f 20 57 6f 72 6c 64 21",
            ),
            (
                RelayToClientMsg::Restarting {
                    reconnect_in: Duration::from_millis(10),
                    try_for: Duration::from_millis(20),
                }
                .write_to(Vec::new()),
                // frame type, then two big-endian u32 millisecond values
                "0c 00 00 00 0a 00 00 00 14",
            ),
        ]);

        Ok(())
    }

    /// Pins the encoding of every `ClientToRelayMsg` variant.
    #[test]
    fn test_client_server_frames_snapshot() -> Result {
        let client_key = SecretKey::from_bytes(&[42u8; 32]);

        check_expected_bytes(vec![
            (
                ClientToRelayMsg::Ping([42u8; 8]).write_to(Vec::new()),
                "09 2a 2a 2a 2a 2a 2a 2a 2a",
            ),
            (
                ClientToRelayMsg::Pong([42u8; 8]).write_to(Vec::new()),
                "0a 2a 2a 2a 2a 2a 2a 2a 2a",
            ),
            (
                ClientToRelayMsg::Datagrams {
                    dst_endpoint_id: client_key.public(),
                    datagrams: Datagrams {
                        ecn: Some(quinn::EcnCodepoint::Ce),
                        segment_size: NonZeroU16::new(6),
                        contents: "Hello World!".into(),
                    },
                }
                .write_to(Vec::new()),
                // frame type
                // public key first 16 bytes
                // public key second 16 bytes
                // ECN byte
                // Segment size
                // hello world contents
                "05
                19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7
                89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61
                03
                00 06
                48 65 6c 6c 6f 20 57 6f 72 6c 64 21",
            ),
            (
                ClientToRelayMsg::Datagrams {
                    dst_endpoint_id: client_key.public(),
                    datagrams: Datagrams {
                        ecn: Some(quinn::EcnCodepoint::Ce),
                        segment_size: None,
                        contents: "Hello World!".into(),
                    },
                }
                .write_to(Vec::new()),
                // frame type
                // public key first 16 bytes
                // public key second 16 bytes
                // ECN byte
                // hello world contents
                "04
                19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7
                89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61
                03
                48 65 6c 6c 6f 20 57 6f 72 6c 64 21",
            ),
        ]);

        Ok(())
    }
}
669
#[cfg(all(test, feature = "server"))]
mod proptests {
    // Property tests: encode/decode roundtrips and `encoded_len` agreement with `write_to`.
    use iroh_base::SecretKey;
    use proptest::prelude::*;

    use super::*;

    /// Strategy producing secret keys from 32 arbitrary bytes.
    fn secret_key() -> impl Strategy<Value = SecretKey> {
        prop::array::uniform32(any::<u8>()).prop_map(SecretKey::from)
    }

    /// Strategy producing endpoint ids as the public half of a random secret key.
    fn key() -> impl Strategy<Value = EndpointId> {
        secret_key().prop_map(|key| key.public())
    }

    /// Strategy producing `None` or one of the three ECN codepoints.
    fn ecn() -> impl Strategy<Value = Option<quinn_proto::EcnCodepoint>> {
        (0..=3).prop_map(|n| match n {
            1 => Some(quinn_proto::EcnCodepoint::Ce),
            2 => Some(quinn_proto::EcnCodepoint::Ect0),
            3 => Some(quinn_proto::EcnCodepoint::Ect1),
            _ => None,
        })
    }

    /// Strategy producing datagrams whose encoded frame stays within `MAX_PACKET_SIZE`.
    ///
    /// A generated segment size is clamped to the contents length; if that yields zero
    /// (empty contents) it becomes `None`, since the segment size is a `NonZeroU16`.
    fn datagrams() -> impl Strategy<Value = Datagrams> {
        // The max payload size (conservatively, since with segment_size = 0 we'd have slightly more space)
        const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - EndpointId::LENGTH - 1 /* ECN bytes */ - 2 /* segment size */;
        (
            ecn(),
            prop::option::of(MAX_PAYLOAD_SIZE / 20..MAX_PAYLOAD_SIZE),
            prop::collection::vec(any::<u8>(), 0..MAX_PAYLOAD_SIZE),
        )
            .prop_map(|(ecn, segment_size, data)| Datagrams {
                ecn,
                segment_size: segment_size
                    .map(|ss| std::cmp::min(data.len(), ss) as u16)
                    .and_then(NonZeroU16::new),
                contents: Bytes::from(data),
            })
    }

    /// Generates a random valid frame
    fn server_client_frame() -> impl Strategy<Value = RelayToClientMsg> {
        let recv_packet = (key(), datagrams()).prop_map(|(remote_endpoint_id, datagrams)| {
            RelayToClientMsg::Datagrams {
                remote_endpoint_id,
                datagrams,
            }
        });
        let endpoint_gone = key().prop_map(RelayToClientMsg::EndpointGone);
        let ping = prop::array::uniform8(any::<u8>()).prop_map(RelayToClientMsg::Ping);
        let pong = prop::array::uniform8(any::<u8>()).prop_map(RelayToClientMsg::Pong);
        let health = ".{0,65536}"
            .prop_filter("exceeds MAX_PACKET_SIZE", |s| {
                s.len() < MAX_PACKET_SIZE // a single unicode character can match a regex "." but take up multiple bytes
            })
            .prop_map(|problem| RelayToClientMsg::Health { problem });
        // Durations are whole `u32` milliseconds so they survive the wire encoding exactly.
        let restarting = (any::<u32>(), any::<u32>()).prop_map(|(reconnect_in, try_for)| {
            RelayToClientMsg::Restarting {
                reconnect_in: Duration::from_millis(reconnect_in.into()),
                try_for: Duration::from_millis(try_for.into()),
            }
        });
        prop_oneof![recv_packet, endpoint_gone, ping, pong, health, restarting]
    }

    /// Generates a random valid client-to-relay frame.
    fn client_server_frame() -> impl Strategy<Value = ClientToRelayMsg> {
        let send_packet = (key(), datagrams()).prop_map(|(dst_endpoint_id, datagrams)| {
            ClientToRelayMsg::Datagrams {
                dst_endpoint_id,
                datagrams,
            }
        });
        let ping = prop::array::uniform8(any::<u8>()).prop_map(ClientToRelayMsg::Ping);
        let pong = prop::array::uniform8(any::<u8>()).prop_map(ClientToRelayMsg::Pong);
        prop_oneof![send_packet, ping, pong]
    }

    proptest! {
        // Encoding then decoding a relay-to-client frame yields the original frame.
        #[test]
        fn server_client_frame_roundtrip(frame in server_client_frame()) {
            let encoded = frame.to_bytes().freeze();
            let decoded = RelayToClientMsg::from_bytes(encoded, &KeyCache::test()).unwrap();
            prop_assert_eq!(frame, decoded);
        }

        // Encoding then decoding a client-to-relay frame yields the original frame.
        #[test]
        fn client_server_frame_roundtrip(frame in client_server_frame()) {
            let encoded = frame.to_bytes().freeze();
            let decoded = ClientToRelayMsg::from_bytes(encoded, &KeyCache::test()).unwrap();
            prop_assert_eq!(frame, decoded);
        }

        // `encoded_len` must match the number of bytes actually written.
        #[test]
        fn server_client_frame_encoded_len(frame in server_client_frame()) {
            let claimed_encoded_len = frame.encoded_len();
            let actual_encoded_len = frame.to_bytes().len();
            prop_assert_eq!(claimed_encoded_len, actual_encoded_len);
        }

        // `encoded_len` must match the number of bytes actually written.
        #[test]
        fn client_server_frame_encoded_len(frame in client_server_frame()) {
            let claimed_encoded_len = frame.encoded_len();
            let actual_encoded_len = frame.to_bytes().len();
            prop_assert_eq!(claimed_encoded_len, actual_encoded_len);
        }

        // `Datagrams::encoded_len` must match the number of bytes `write_to` produces.
        #[test]
        fn datagrams_encoded_len(datagrams in datagrams()) {
            let claimed_encoded_len = datagrams.encoded_len();
            let actual_encoded_len = datagrams.write_to(Vec::new()).len();
            prop_assert_eq!(claimed_encoded_len, actual_encoded_len);
        }
    }
}