Skip to main content

iroh_relay/protos/
relay.rs

1//! This module implements the send/recv relaying protocol.
2//!
3//! Protocol flow:
4//!  * server occasionally sends [`FrameType::Ping`]
5//!  * client responds to any [`FrameType::Ping`] with a [`FrameType::Pong`]
//!  * client sends [`FrameType::ClientToRelayDatagram`] or [`FrameType::ClientToRelayDatagramBatch`]
7//!  * server then sends [`FrameType::RelayToClientDatagram`] or [`FrameType::RelayToClientDatagramBatch`] to recipient
8//!  * server sends [`FrameType::EndpointGone`] when the other client disconnects
9
10use std::num::NonZeroU16;
11
12use bytes::{Buf, BufMut, Bytes, BytesMut};
13use iroh_base::{EndpointId, KeyParsingError};
14use n0_error::{e, ensure, stack_error};
15use n0_future::time::Duration;
16
17use super::common::{FrameType, FrameTypeError};
18use crate::{KeyCache, http::ProtocolVersion};
19
/// The maximum size of a packet sent over relay.
///
/// This only counts the data bytes visible to the socket, not the on-wire framing overhead.
pub const MAX_PACKET_SIZE: usize = 64 * 1024;

/// The maximum frame size.
///
/// This is also the minimum burst size that a rate-limiter has to accept.
#[cfg(not(wasm_browser))]
pub(crate) const MAX_FRAME_SIZE: usize = 1024 * 1024;

/// Interval at which pings are sent on a relay connection to check that it is still alive.
///
/// The default QUIC `max_idle_timeout` is 30s; pinging at half of that interval gives the
/// connection some chance of recovering.
#[cfg(feature = "server")]
pub(crate) const PING_INTERVAL: Duration = Duration::from_secs(15);

/// The number of packets buffered for sending per client.
#[cfg(feature = "server")]
pub const PER_CLIENT_SEND_QUEUE_DEPTH: usize = 512;
41
/// Relay protocol errors.
///
/// Within this module they are returned when decoding frames (the `from_bytes` functions of
/// [`RelayToClientMsg`] and [`ClientToRelayMsg`]); variants not constructed in the code
/// visible here are presumably produced elsewhere in the crate.
#[stack_error(derive, add_meta, from_sources)]
#[allow(missing_docs)]
#[non_exhaustive]
pub enum Error {
    /// A frame of type `got` arrived where `expected` was required.
    ///
    /// Not constructed in the code visible in this module.
    #[error("unexpected frame: got {got:?}, expected {expected:?}")]
    UnexpectedFrame { got: FrameType, expected: FrameType },
    /// In this module: the payload following the frame type is longer than
    /// [`MAX_PACKET_SIZE`].
    #[error("Frame is too large, has {frame_len} bytes")]
    FrameTooLarge { frame_len: usize },
    /// A `postcard` (de)serialization failure.
    ///
    /// Not constructed in the code visible in this module.
    #[error(transparent)]
    SerDe {
        #[error(std_err)]
        source: postcard::Error,
    },
    /// The frame type could not be decoded (presumably propagated from
    /// `FrameType::from_bytes` via `?`).
    #[error(transparent)]
    FrameTypeError { source: FrameTypeError },
    /// An endpoint id in a frame could not be parsed as a public key (presumably from
    /// `KeyCache::key_from_slice`).
    #[error("Invalid public key")]
    InvalidPublicKey { source: KeyParsingError },
    /// The payload has the wrong length or structure for its frame type.
    #[error("Invalid frame encoding")]
    InvalidFrame {},
    /// The frame type is not valid for the direction being decoded (e.g. a client-to-relay
    /// frame type while decoding a relay-to-client message).
    #[error("Invalid frame type: {frame_type:?}")]
    InvalidFrameType { frame_type: FrameType },
    /// The problem text of a `Health` frame is not valid UTF-8.
    #[error("Invalid protocol message encoding")]
    InvalidProtocolMessageEncoding {
        #[error(std_err)]
        source: std::str::Utf8Error,
    },
    /// The frame type is not permitted in the negotiated protocol version
    /// (`Health` outside v1, `Status` before v2).
    #[error("Received a frame not allowed in this protocol version.")]
    FrameNotAllowedInVersion,
    /// Not constructed in the code visible in this module.
    #[error("Too few bytes")]
    TooSmall {},
}
74
75/// The messages that a relay sends to clients or the clients receive from the relay.
76#[derive(Debug, Clone, PartialEq, Eq, strum::Display)]
77#[non_exhaustive]
78pub enum RelayToClientMsg {
79    /// Represents datagrams sent from relays (originally sent to them by another client).
80    Datagrams {
81        /// The [`EndpointId`] of the original sender.
82        remote_endpoint_id: EndpointId,
83        /// The datagrams and related metadata.
84        datagrams: Datagrams,
85    },
86    /// Indicates that the client identified by the underlying public key had previously sent you a
87    /// packet but has now disconnected from the relay.
88    EndpointGone(EndpointId),
89    /// A one-way message from relay to client, declaring the connection health state.
90    Status(Status),
91    /// A one-way message from relay to client, advertising that the relay is restarting.
92    Restarting {
93        /// An advisory duration that the client should wait before attempting to reconnect.
94        /// It might be zero. It exists for the relay to smear out the reconnects.
95        reconnect_in: Duration,
96        /// An advisory duration for how long the client should attempt to reconnect
97        /// before giving up and proceeding with its normal connection failure logic. The interval
98        /// between retries is undefined for now. A relay should not send a `try_for` duration more
99        /// than a few seconds.
100        try_for: Duration,
101    },
102    /// Request from the relay to reply to the
103    /// other side with a [`ClientToRelayMsg::Pong`] with the given payload.
104    Ping([u8; 8]),
105    /// Reply to a [`ClientToRelayMsg::Ping`] from a client
106    /// with the payload sent previously in the ping.
107    Pong([u8; 8]),
108
109    // -- Deprecated variants --
110    // We don't use `#[deprecated]` because this would throw warnings for the derived serde impls.
111    /// Removed since relay-protocol-v2:
112    /// A one-way message from relay to client, declaring the connection health state.
113    ///
114    /// Use [`Self::Status`] instead.
115    Health {
116        /// Description of why the connection is unhealthy.
117        ///
118        /// The default condition is healthy, so the relay doesn't broadcast a [`RelayToClientMsg::Health`]
119        /// until a problem exists.
120        problem: String,
121    },
122}
123
124/// One-way message from server to client indicating issues with the relay connection.
125#[derive(Debug, Clone, PartialEq, Eq, derive_more::Display)]
126#[non_exhaustive]
127pub enum Status {
128    /// The connection is healthy and recovered from previous problems.
129    #[display("The connection is healthy and has recovered from previous problems")]
130    Healthy,
131    /// Another endpoint connected with the same endpoint id. No more messages will be received.
132    #[display(
133        "Another endpoint connected with the same endpoint id. No more messages will be received."
134    )]
135    SameEndpointIdConnected,
136    /// Placeholder for backwards-compatibility for future new health status variants.
137    #[display("Unsupported health message ({_0})")]
138    Unknown(u8),
139}
140
141impl Status {
142    #[cfg(feature = "server")]
143    fn write_to<O: BufMut>(&self, mut dst: O) -> O {
144        match self {
145            Status::Healthy => dst.put_u8(0),
146            Status::SameEndpointIdConnected => dst.put_u8(1),
147            Status::Unknown(discriminant) => dst.put_u8(*discriminant),
148        }
149        dst
150    }
151
152    #[cfg(feature = "server")]
153    fn encoded_len(&self) -> usize {
154        1
155    }
156
157    fn from_bytes(mut bytes: Bytes) -> Result<Self, Error> {
158        ensure!(!bytes.is_empty(), Error::InvalidFrame);
159        let discriminant = bytes.get_u8();
160        match discriminant {
161            0 => Ok(Self::Healthy),
162            1 => Ok(Self::SameEndpointIdConnected),
163            n => Ok(Self::Unknown(n)),
164        }
165    }
166}
167
168/// Messages that clients send to relays.
169#[derive(Debug, Clone, PartialEq, Eq)]
170#[non_exhaustive]
171pub enum ClientToRelayMsg {
172    /// Request from the client to the server to reply to the
173    /// other side with a [`RelayToClientMsg::Pong`] with the given payload.
174    Ping([u8; 8]),
175    /// Reply to a [`RelayToClientMsg::Ping`] from a server
176    /// with the payload sent previously in the ping.
177    Pong([u8; 8]),
178    /// Request from the client to relay datagrams to given remote endpoint.
179    Datagrams {
180        /// The remote endpoint to relay to.
181        dst_endpoint_id: EndpointId,
182        /// The datagrams and related metadata to relay.
183        datagrams: Datagrams,
184    },
185}
186
187/// One or multiple datagrams being transferred via the relay.
188///
189/// This type is modeled after [`noq_proto::Transmit`]
190/// (or even more similarly `noq_udp::Transmit`, but we don't depend on that library here).
191#[derive(derive_more::Debug, Clone, PartialEq, Eq)]
192pub struct Datagrams {
193    /// Explicit congestion notification bits
194    pub ecn: Option<noq_proto::EcnCodepoint>,
195    /// The segment size if this transmission contains multiple datagrams.
196    /// This is `None` if the transmit only contains a single datagram
197    pub segment_size: Option<NonZeroU16>,
198    /// The contents of the datagram(s)
199    #[debug(skip)]
200    pub contents: Bytes,
201}
202
203impl<T: AsRef<[u8]>> From<T> for Datagrams {
204    fn from(bytes: T) -> Self {
205        Self {
206            ecn: None,
207            segment_size: None,
208            contents: Bytes::copy_from_slice(bytes.as_ref()),
209        }
210    }
211}
212
213impl Datagrams {
214    /// Splits the current datagram into at maximum `num_segments` segments, returning
215    /// the batch with at most `num_segments` and leaving only the rest in `self`.
216    ///
217    /// Calling this on a datagram batch that only contains a single datagram (`segment_size == None`)
218    /// will result in returning essentially a clone of `self`, while making `self` empty afterwards.
219    ///
220    /// Calling this on a datagram batch with e.g. 15 datagrams with `num_segments == 10` will
221    /// result in returning a datagram batch that contains the first 10 datagrams and leave `self`
222    /// containing the remaining 5 datagrams.
223    ///
224    /// Calling this on a datagram batch with less than `num_segments` datagrams will result in
225    /// making `self` empty and returning essentially a clone of `self`.
226    pub fn take_segments(&mut self, num_segments: usize) -> Datagrams {
227        let Some(segment_size) = self.segment_size else {
228            let contents = std::mem::take(&mut self.contents);
229            return Datagrams {
230                ecn: self.ecn,
231                segment_size: None,
232                contents,
233            };
234        };
235
236        let usize_segment_size = usize::from(u16::from(segment_size));
237        let max_content_len = num_segments * usize_segment_size;
238        let contents = self
239            .contents
240            .split_to(std::cmp::min(max_content_len, self.contents.len()));
241
242        let is_datagram_batch = num_segments > 1 && usize_segment_size < contents.len();
243
244        // If this left our batch with only one more datagram, then remove the segment size
245        // to uphold the invariant that single-datagram batches don't have a segment size set.
246        if self.contents.len() <= usize_segment_size {
247            self.segment_size = None;
248        }
249
250        Datagrams {
251            ecn: self.ecn,
252            segment_size: is_datagram_batch.then_some(segment_size),
253            contents,
254        }
255    }
256
257    fn write_to<O: BufMut>(&self, mut dst: O) -> O {
258        let ecn = self.ecn.map_or(0, |ecn| ecn as u8);
259        dst.put_u8(ecn);
260        if let Some(segment_size) = self.segment_size {
261            dst.put_u16(segment_size.into());
262        }
263        dst.put(self.contents.as_ref());
264        dst
265    }
266
267    fn encoded_len(&self) -> usize {
268        1 // ECN byte
269        + self.segment_size.map_or(0, |_| 2) // segment size, when None, then a packed representation is assumed
270        + self.contents.len()
271    }
272
273    #[allow(clippy::len_zero, clippy::result_large_err)]
274    fn from_bytes(mut bytes: Bytes, is_batch: bool) -> Result<Self, Error> {
275        if is_batch {
276            // 1 bytes ECN, 2 bytes segment size
277            ensure!(bytes.len() >= 3, Error::InvalidFrame);
278        } else {
279            ensure!(bytes.len() >= 1, Error::InvalidFrame);
280        }
281
282        let ecn_byte = bytes.get_u8();
283        let ecn = noq_proto::EcnCodepoint::from_bits(ecn_byte);
284
285        let segment_size = if is_batch {
286            let segment_size = bytes.get_u16(); // length checked above
287            NonZeroU16::new(segment_size)
288        } else {
289            None
290        };
291
292        Ok(Self {
293            ecn,
294            segment_size,
295            contents: bytes,
296        })
297    }
298}
299
300impl RelayToClientMsg {
301    /// Returns this frame's corresponding frame type.
302    pub fn typ(&self) -> FrameType {
303        match self {
304            Self::Datagrams { datagrams, .. } => {
305                if datagrams.segment_size.is_some() {
306                    FrameType::RelayToClientDatagramBatch
307                } else {
308                    FrameType::RelayToClientDatagram
309                }
310            }
311            Self::EndpointGone { .. } => FrameType::EndpointGone,
312            Self::Ping { .. } => FrameType::Ping,
313            Self::Pong { .. } => FrameType::Pong,
314            Self::Status { .. } => FrameType::Status,
315            Self::Restarting { .. } => FrameType::Restarting,
316            Self::Health { .. } => FrameType::Health,
317        }
318    }
319
320    #[cfg(feature = "server")]
321    pub(crate) fn to_bytes(&self) -> BytesMut {
322        self.write_to(BytesMut::with_capacity(self.encoded_len()))
323    }
324
325    /// Encodes this frame for sending over websockets.
326    ///
327    /// Specifically meant for being put into a binary websocket message frame.
328    #[cfg(feature = "server")]
329    pub(crate) fn write_to<O: BufMut>(&self, mut dst: O) -> O {
330        dst = self.typ().write_to(dst);
331        match self {
332            Self::Datagrams {
333                remote_endpoint_id,
334                datagrams,
335            } => {
336                dst.put(remote_endpoint_id.as_ref());
337                dst = datagrams.write_to(dst);
338            }
339            Self::EndpointGone(endpoint_id) => {
340                dst.put(endpoint_id.as_ref());
341            }
342            Self::Ping(data) => {
343                dst.put(&data[..]);
344            }
345            Self::Pong(data) => {
346                dst.put(&data[..]);
347            }
348            Self::Health { problem } => {
349                dst.put(problem.as_ref());
350            }
351            Self::Restarting {
352                reconnect_in,
353                try_for,
354            } => {
355                dst.put_u32(reconnect_in.as_millis() as u32);
356                dst.put_u32(try_for.as_millis() as u32);
357            }
358            Self::Status(status) => {
359                dst = status.write_to(dst);
360            }
361        }
362        dst
363    }
364
365    #[cfg(feature = "server")]
366    pub(crate) fn encoded_len(&self) -> usize {
367        let payload_len = match self {
368            Self::Datagrams { datagrams, .. } => {
369                32 // endpointid
370                + datagrams.encoded_len()
371            }
372            Self::EndpointGone(_) => 32,
373            Self::Ping(_) | Self::Pong(_) => 8,
374            Self::Status(status) => status.encoded_len(),
375            Self::Restarting { .. } => {
376                4 // u32
377                + 4 // u32
378            }
379            Self::Health { problem } => problem.len(),
380        };
381        self.typ().encoded_len() + payload_len
382    }
383
384    /// Tries to decode a frame received over websockets.
385    ///
386    /// Specifically, bytes received from a binary websocket message frame.
387    ///
388    /// `protocol_version` is the negotiated protocol version for this connection.
389    #[allow(clippy::result_large_err)]
390    pub(crate) fn from_bytes(
391        mut content: Bytes,
392        cache: &KeyCache,
393        protocol_version: ProtocolVersion,
394    ) -> Result<Self, Error> {
395        let frame_type = FrameType::from_bytes(&mut content)?;
396        let frame_len = content.len();
397        ensure!(
398            frame_len <= MAX_PACKET_SIZE,
399            Error::FrameTooLarge { frame_len }
400        );
401
402        let res = match frame_type {
403            FrameType::RelayToClientDatagram | FrameType::RelayToClientDatagramBatch => {
404                ensure!(content.len() >= EndpointId::LENGTH, Error::InvalidFrame);
405
406                let remote_endpoint_id = cache.key_from_slice(&content[..EndpointId::LENGTH])?;
407                let datagrams = Datagrams::from_bytes(
408                    content.slice(EndpointId::LENGTH..),
409                    frame_type == FrameType::RelayToClientDatagramBatch,
410                )?;
411                Self::Datagrams {
412                    remote_endpoint_id,
413                    datagrams,
414                }
415            }
416            FrameType::EndpointGone => {
417                ensure!(content.len() == EndpointId::LENGTH, Error::InvalidFrame);
418                let endpoint_id = cache.key_from_slice(content.as_ref())?;
419                Self::EndpointGone(endpoint_id)
420            }
421            FrameType::Ping => {
422                ensure!(content.len() == 8, Error::InvalidFrame);
423                let mut data = [0u8; 8];
424                data.copy_from_slice(&content[..8]);
425                Self::Ping(data)
426            }
427            FrameType::Pong => {
428                ensure!(content.len() == 8, Error::InvalidFrame);
429                let mut data = [0u8; 8];
430                data.copy_from_slice(&content[..8]);
431                Self::Pong(data)
432            }
433            FrameType::Health => {
434                ensure!(
435                    protocol_version == ProtocolVersion::V1,
436                    Error::FrameNotAllowedInVersion
437                );
438                let problem = std::str::from_utf8(&content)?.to_owned();
439                Self::Health { problem }
440            }
441            FrameType::Restarting => {
442                ensure!(content.len() == 4 + 4, Error::InvalidFrame);
443                let reconnect_in = u32::from_be_bytes(
444                    content[..4]
445                        .try_into()
446                        .map_err(|_| e!(Error::InvalidFrame))?,
447                );
448                let try_for = u32::from_be_bytes(
449                    content[4..]
450                        .try_into()
451                        .map_err(|_| e!(Error::InvalidFrame))?,
452                );
453                let reconnect_in = Duration::from_millis(reconnect_in as u64);
454                let try_for = Duration::from_millis(try_for as u64);
455                Self::Restarting {
456                    reconnect_in,
457                    try_for,
458                }
459            }
460            FrameType::Status => {
461                ensure!(
462                    protocol_version >= ProtocolVersion::V2,
463                    Error::FrameNotAllowedInVersion
464                );
465                let status = Status::from_bytes(content)?;
466                Self::Status(status)
467            }
468            _ => {
469                return Err(e!(Error::InvalidFrameType { frame_type }));
470            }
471        };
472        Ok(res)
473    }
474}
475
476impl ClientToRelayMsg {
477    pub(crate) fn typ(&self) -> FrameType {
478        match self {
479            Self::Datagrams { datagrams, .. } => {
480                if datagrams.segment_size.is_some() {
481                    FrameType::ClientToRelayDatagramBatch
482                } else {
483                    FrameType::ClientToRelayDatagram
484                }
485            }
486            Self::Ping { .. } => FrameType::Ping,
487            Self::Pong { .. } => FrameType::Pong,
488        }
489    }
490
491    pub(crate) fn to_bytes(&self) -> BytesMut {
492        self.write_to(BytesMut::with_capacity(self.encoded_len()))
493    }
494
495    /// Encodes this frame for sending over websockets.
496    ///
497    /// Specifically meant for being put into a binary websocket message frame.
498    pub(crate) fn write_to<O: BufMut>(&self, mut dst: O) -> O {
499        dst = self.typ().write_to(dst);
500        match self {
501            Self::Datagrams {
502                dst_endpoint_id,
503                datagrams,
504            } => {
505                dst.put(dst_endpoint_id.as_ref());
506                dst = datagrams.write_to(dst);
507            }
508            Self::Ping(data) => {
509                dst.put(&data[..]);
510            }
511            Self::Pong(data) => {
512                dst.put(&data[..]);
513            }
514        }
515        dst
516    }
517
518    pub(crate) fn encoded_len(&self) -> usize {
519        let payload_len = match self {
520            Self::Ping(_) | Self::Pong(_) => 8,
521            Self::Datagrams { datagrams, .. } => {
522                32 // endpoint id
523                + datagrams.encoded_len()
524            }
525        };
526        self.typ().encoded_len() + payload_len
527    }
528
529    /// Tries to decode a frame received over websockets.
530    ///
531    /// Specifically, bytes received from a binary websocket message frame.
532    #[allow(clippy::result_large_err)]
533    #[cfg(feature = "server")]
534    pub(crate) fn from_bytes(mut content: Bytes, cache: &KeyCache) -> Result<Self, Error> {
535        let frame_type = FrameType::from_bytes(&mut content)?;
536        let frame_len = content.len();
537        ensure!(
538            frame_len <= MAX_PACKET_SIZE,
539            Error::FrameTooLarge { frame_len }
540        );
541
542        let res = match frame_type {
543            FrameType::ClientToRelayDatagram | FrameType::ClientToRelayDatagramBatch => {
544                let dst_endpoint_id = cache.key_from_slice(&content[..EndpointId::LENGTH])?;
545                let datagrams = Datagrams::from_bytes(
546                    content.slice(EndpointId::LENGTH..),
547                    frame_type == FrameType::ClientToRelayDatagramBatch,
548                )?;
549                Self::Datagrams {
550                    dst_endpoint_id,
551                    datagrams,
552                }
553            }
554            FrameType::Ping => {
555                ensure!(content.len() == 8, Error::InvalidFrame);
556                let mut data = [0u8; 8];
557                data.copy_from_slice(&content[..8]);
558                Self::Ping(data)
559            }
560            FrameType::Pong => {
561                ensure!(content.len() == 8, Error::InvalidFrame);
562                let mut data = [0u8; 8];
563                data.copy_from_slice(&content[..8]);
564                Self::Pong(data)
565            }
566            _ => {
567                return Err(e!(Error::InvalidFrameType { frame_type }));
568            }
569        };
570        Ok(res)
571    }
572}
573
#[cfg(test)]
#[cfg(feature = "server")]
mod tests {
    use data_encoding::HEXLOWER;
    use iroh_base::SecretKey;
    use n0_error::Result;

    use super::*;

    /// Asserts that every encoded frame equals its expected hex dump.
    ///
    /// Whitespace inside the dump is ignored, so dumps can be laid out in readable rows.
    fn check_expected_bytes(frames: Vec<(Vec<u8>, &str)>) {
        for (encoded, hex_dump) in frames {
            let digits: Vec<u8> = hex_dump
                .chars()
                .filter(|c| !c.is_ascii_whitespace())
                .map(|c| c as u8)
                .collect();
            let expected = HEXLOWER.decode(&digits).unwrap();
            // Compare as hex strings so a failing assertion prints something readable.
            assert_eq!(HEXLOWER.encode(&encoded), HEXLOWER.encode(&expected));
        }
    }

    #[test]
    fn test_server_client_frames_snapshot() -> Result {
        let client_key = SecretKey::from_bytes(&[42u8; 32]);
        // "Hello World!" marked CE; the cases below only differ in the segment size.
        let hello_world = |segment_size| Datagrams {
            ecn: Some(noq::EcnCodepoint::Ce),
            segment_size,
            contents: "Hello World!".into(),
        };

        check_expected_bytes(vec![
            (
                RelayToClientMsg::Health {
                    problem: "Hello? Yes this is dog.".into(),
                }
                .write_to(Vec::new()),
                "0b 48 65 6c 6c 6f 3f 20 59 65 73 20 74 68 69 73
                20 69 73 20 64 6f 67 2e",
            ),
            (
                RelayToClientMsg::EndpointGone(client_key.public()).write_to(Vec::new()),
                "08 19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e
                a7 89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d
                61",
            ),
            (
                RelayToClientMsg::Ping([42u8; 8]).write_to(Vec::new()),
                "09 2a 2a 2a 2a 2a 2a 2a 2a",
            ),
            (
                RelayToClientMsg::Pong([42u8; 8]).write_to(Vec::new()),
                "0a 2a 2a 2a 2a 2a 2a 2a 2a",
            ),
            (
                RelayToClientMsg::Datagrams {
                    remote_endpoint_id: client_key.public(),
                    datagrams: hello_world(NonZeroU16::new(6)),
                }
                .write_to(Vec::new()),
                // rows: frame type | public key (2 x 16 bytes) | ECN | segment size | contents
                "07
                19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7
                89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61
                03
                00 06
                48 65 6c 6c 6f 20 57 6f 72 6c 64 21",
            ),
            (
                RelayToClientMsg::Datagrams {
                    remote_endpoint_id: client_key.public(),
                    datagrams: hello_world(None),
                }
                .write_to(Vec::new()),
                // rows: frame type | public key (2 x 16 bytes) | ECN | contents
                "06
                19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7
                89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61
                03
                48 65 6c 6c 6f 20 57 6f 72 6c 64 21",
            ),
            (
                RelayToClientMsg::Restarting {
                    reconnect_in: Duration::from_millis(10),
                    try_for: Duration::from_millis(20),
                }
                .write_to(Vec::new()),
                "0c 00 00 00 0a 00 00 00 14",
            ),
            (
                RelayToClientMsg::Status(Status::SameEndpointIdConnected).write_to(Vec::new()),
                "0d 01",
            ),
        ]);

        Ok(())
    }

    #[test]
    fn test_client_server_frames_snapshot() -> Result {
        let client_key = SecretKey::from_bytes(&[42u8; 32]);
        // "Hello World!" marked CE; the cases below only differ in the segment size.
        let hello_world = |segment_size| Datagrams {
            ecn: Some(noq::EcnCodepoint::Ce),
            segment_size,
            contents: "Hello World!".into(),
        };

        check_expected_bytes(vec![
            (
                ClientToRelayMsg::Ping([42u8; 8]).write_to(Vec::new()),
                "09 2a 2a 2a 2a 2a 2a 2a 2a",
            ),
            (
                ClientToRelayMsg::Pong([42u8; 8]).write_to(Vec::new()),
                "0a 2a 2a 2a 2a 2a 2a 2a 2a",
            ),
            (
                ClientToRelayMsg::Datagrams {
                    dst_endpoint_id: client_key.public(),
                    datagrams: hello_world(NonZeroU16::new(6)),
                }
                .write_to(Vec::new()),
                // rows: frame type | public key (2 x 16 bytes) | ECN | segment size | contents
                "05
                19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7
                89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61
                03
                00 06
                48 65 6c 6c 6f 20 57 6f 72 6c 64 21",
            ),
            (
                ClientToRelayMsg::Datagrams {
                    dst_endpoint_id: client_key.public(),
                    datagrams: hello_world(None),
                }
                .write_to(Vec::new()),
                // rows: frame type | public key (2 x 16 bytes) | ECN | contents
                "04
                19 7f 6b 23 e1 6c 85 32 c6 ab c8 38 fa cd 5e a7
                89 be 0c 76 b2 92 03 34 03 9b fa 8b 3d 36 8d 61
                03
                48 65 6c 6c 6f 20 57 6f 72 6c 64 21",
            ),
        ]);

        Ok(())
    }
}
750
751#[cfg(all(test, feature = "server"))]
752mod proptests {
753    use iroh_base::SecretKey;
754    use proptest::prelude::*;
755
756    use super::*;
757
758    fn secret_key() -> impl Strategy<Value = SecretKey> {
759        prop::array::uniform32(any::<u8>()).prop_map(SecretKey::from)
760    }
761
762    fn key() -> impl Strategy<Value = EndpointId> {
763        secret_key().prop_map(|key| key.public())
764    }
765
766    fn ecn() -> impl Strategy<Value = Option<noq_proto::EcnCodepoint>> {
767        (0..=3).prop_map(|n| match n {
768            1 => Some(noq_proto::EcnCodepoint::Ce),
769            2 => Some(noq_proto::EcnCodepoint::Ect0),
770            3 => Some(noq_proto::EcnCodepoint::Ect1),
771            _ => None,
772        })
773    }
774
775    fn datagrams() -> impl Strategy<Value = Datagrams> {
776        // The max payload size (conservatively, since with segment_size = 0 we'd have slightly more space)
777        const MAX_PAYLOAD_SIZE: usize = MAX_PACKET_SIZE - EndpointId::LENGTH - 1 /* ECN bytes */ - 2 /* segment size */;
778        (
779            ecn(),
780            prop::option::of(MAX_PAYLOAD_SIZE / 20..MAX_PAYLOAD_SIZE),
781            prop::collection::vec(any::<u8>(), 0..MAX_PAYLOAD_SIZE),
782        )
783            .prop_map(|(ecn, segment_size, data)| Datagrams {
784                ecn,
785                segment_size: segment_size
786                    .map(|ss| std::cmp::min(data.len(), ss) as u16)
787                    .and_then(NonZeroU16::new),
788                contents: Bytes::from(data),
789            })
790    }
791
792    /// Generates a random valid frame
793    fn server_client_frame() -> impl Strategy<Value = RelayToClientMsg> {
794        let recv_packet = (key(), datagrams()).prop_map(|(remote_endpoint_id, datagrams)| {
795            RelayToClientMsg::Datagrams {
796                remote_endpoint_id,
797                datagrams,
798            }
799        });
800        let endpoint_gone = key().prop_map(RelayToClientMsg::EndpointGone);
801        let ping = prop::array::uniform8(any::<u8>()).prop_map(RelayToClientMsg::Ping);
802        let pong = prop::array::uniform8(any::<u8>()).prop_map(RelayToClientMsg::Pong);
803        let v1health = ".{0,65536}"
804            .prop_filter("exceeds MAX_PACKET_SIZE", |s| {
805                s.len() < MAX_PACKET_SIZE // a single unicode character can match a regex "." but take up multiple bytes
806            })
807            .prop_map(|problem| RelayToClientMsg::Health { problem });
808        let health = Just(Status::SameEndpointIdConnected).prop_map(RelayToClientMsg::Status);
809        let restarting = (any::<u32>(), any::<u32>()).prop_map(|(reconnect_in, try_for)| {
810            RelayToClientMsg::Restarting {
811                reconnect_in: Duration::from_millis(reconnect_in.into()),
812                try_for: Duration::from_millis(try_for.into()),
813            }
814        });
815        prop_oneof![
816            recv_packet,
817            endpoint_gone,
818            ping,
819            pong,
820            v1health,
821            restarting,
822            health
823        ]
824    }
825
826    fn client_server_frame() -> impl Strategy<Value = ClientToRelayMsg> {
827        let send_packet = (key(), datagrams()).prop_map(|(dst_endpoint_id, datagrams)| {
828            ClientToRelayMsg::Datagrams {
829                dst_endpoint_id,
830                datagrams,
831            }
832        });
833        let ping = prop::array::uniform8(any::<u8>()).prop_map(ClientToRelayMsg::Ping);
834        let pong = prop::array::uniform8(any::<u8>()).prop_map(ClientToRelayMsg::Pong);
835        prop_oneof![send_packet, ping, pong]
836    }
837
838    /// The earliest protocol version in which `frame` is allowed.
839    fn allowed_version(frame: &RelayToClientMsg) -> ProtocolVersion {
840        match frame {
841            RelayToClientMsg::Health { .. } => ProtocolVersion::V1,
842            _ => ProtocolVersion::V2,
843        }
844    }
845
846    #[test]
847    fn v1health_rejected_in_v2() {
848        let frame = RelayToClientMsg::Health {
849            problem: "test".into(),
850        };
851        let encoded = frame.to_bytes().freeze();
852        let result = RelayToClientMsg::from_bytes(encoded, &KeyCache::test(), ProtocolVersion::V2);
853        assert!(matches!(
854            result,
855            Err(Error::FrameNotAllowedInVersion { .. })
856        ));
857    }
858
859    #[test]
860    fn status_rejected_in_v1() {
861        let frame = RelayToClientMsg::Status(Status::SameEndpointIdConnected);
862        let encoded = frame.to_bytes().freeze();
863        let result = RelayToClientMsg::from_bytes(encoded, &KeyCache::test(), ProtocolVersion::V1);
864        assert!(matches!(
865            result,
866            Err(Error::FrameNotAllowedInVersion { .. })
867        ));
868    }
869
870    proptest! {
871        #[test]
872        fn server_client_frame_roundtrip(frame in server_client_frame()) {
873            let version = allowed_version(&frame);
874            let encoded = frame.to_bytes().freeze();
875            let decoded = RelayToClientMsg::from_bytes(encoded, &KeyCache::test(), version).unwrap();
876            prop_assert_eq!(frame, decoded);
877        }
878
879        #[test]
880        fn client_server_frame_roundtrip(frame in client_server_frame()) {
881            let encoded = frame.to_bytes().freeze();
882            let decoded = ClientToRelayMsg::from_bytes(encoded, &KeyCache::test()).unwrap();
883            prop_assert_eq!(frame, decoded);
884        }
885
886        #[test]
887        fn server_client_frame_encoded_len(frame in server_client_frame()) {
888            let claimed_encoded_len = frame.encoded_len();
889            let actual_encoded_len = frame.to_bytes().len();
890            prop_assert_eq!(claimed_encoded_len, actual_encoded_len);
891        }
892
893        #[test]
894        fn client_server_frame_encoded_len(frame in client_server_frame()) {
895            let claimed_encoded_len = frame.encoded_len();
896            let actual_encoded_len = frame.to_bytes().len();
897            prop_assert_eq!(claimed_encoded_len, actual_encoded_len);
898        }
899
900        #[test]
901        fn datagrams_encoded_len(datagrams in datagrams()) {
902            let claimed_encoded_len = datagrams.encoded_len();
903            let actual_encoded_len = datagrams.write_to(Vec::new()).len();
904            prop_assert_eq!(claimed_encoded_len, actual_encoded_len);
905        }
906    }
907}