ant_quic/
frame.rs

1use std::{
2    fmt::{self, Write},
3    mem,
4    net::SocketAddr,
5    ops::{Range, RangeInclusive},
6};
7
8use bytes::{Buf, BufMut, Bytes};
9use tinyvec::TinyVec;
10
11use crate::{
12    Dir, MAX_CID_SIZE, RESET_TOKEN_SIZE, ResetToken, StreamId, TransportError, TransportErrorCode,
13    VarInt,
14    coding::{self, BufExt, BufMutExt, UnexpectedEnd},
15    range_set::ArrayRangeSet,
16    shared::{ConnectionId, EcnCodepoint},
17};
18
19#[cfg(feature = "arbitrary")]
20use arbitrary::Arbitrary;
21
22/// A QUIC frame type
23#[derive(Copy, Clone, Eq, PartialEq)]
24pub struct FrameType(u64);
25
26impl FrameType {
27    fn stream(self) -> Option<StreamInfo> {
28        if STREAM_TYS.contains(&self.0) {
29            Some(StreamInfo(self.0 as u8))
30        } else {
31            None
32        }
33    }
34    fn datagram(self) -> Option<DatagramInfo> {
35        if DATAGRAM_TYS.contains(&self.0) {
36            Some(DatagramInfo(self.0 as u8))
37        } else {
38            None
39        }
40    }
41}
42
43impl coding::Codec for FrameType {
44    fn decode<B: Buf>(buf: &mut B) -> coding::Result<Self> {
45        Ok(Self(buf.get_var()?))
46    }
47    fn encode<B: BufMut>(&self, buf: &mut B) {
48        buf.write_var(self.0);
49    }
50}
51
/// Frames whose fixed (non-payload) fields have a known worst-case encoded size.
pub(crate) trait FrameStruct {
    /// Number of bytes guaranteed to hold this frame's type and fixed fields. Implementors in
    /// this module count each varint field at its 8-byte maximum and exclude any
    /// variable-length payload (stream data, reason phrases, ...).
    const SIZE_BOUND: usize;
}
56
/// Declares the named `FrameType` constants and derives `Debug`/`Display` for `FrameType`
/// from the same `NAME = code` table, so names and codes are written in one place only.
///
/// `Display` additionally recognises the STREAM and DATAGRAM ranges, which have no single
/// named constant.
macro_rules! frame_types {
    {$($label:ident = $code:expr,)*} => {
        impl FrameType {
            $(pub(crate) const $label: FrameType = FrameType($code);)*
        }

        impl fmt::Debug for FrameType {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                match self.0 {
                    $($code => f.write_str(stringify!($label)),)*
                    other => write!(f, "Type({other:02x})"),
                }
            }
        }

        impl fmt::Display for FrameType {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                match self.0 {
                    $($code => f.write_str(stringify!($label)),)*
                    other if STREAM_TYS.contains(&other) => f.write_str("STREAM"),
                    other if DATAGRAM_TYS.contains(&other) => f.write_str("DATAGRAM"),
                    other => write!(f, "<unknown {other:02x}>"),
                }
            }
        }
    }
}
84
/// Flag bits carried in the low three bits of a STREAM frame type (0x08..=0x0f).
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
struct StreamInfo(u8);

impl StreamInfo {
    const FIN_BIT: u8 = 0x01;
    const LEN_BIT: u8 = 0x02;
    const OFF_BIT: u8 = 0x04;

    /// FIN flag (bit 0x01).
    fn fin(self) -> bool {
        self.has(Self::FIN_BIT)
    }

    /// LEN flag (bit 0x02): an explicit Length field precedes the data; without it the
    /// data extends to the end of the packet.
    fn len(self) -> bool {
        self.has(Self::LEN_BIT)
    }

    /// OFF flag (bit 0x04): an explicit Offset field is present; without it the offset is 0.
    fn off(self) -> bool {
        self.has(Self::OFF_BIT)
    }

    /// Whether the given single-bit mask is set.
    fn has(self, bit: u8) -> bool {
        self.0 & bit != 0
    }
}
99
/// Flag bits carried in a DATAGRAM frame type (0x30..=0x31).
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
struct DatagramInfo(u8);

impl DatagramInfo {
    /// LEN flag (bit 0x01): an explicit Length field precedes the payload; without it the
    /// payload extends to the end of the packet.
    fn len(self) -> bool {
        (self.0 & 0x01) == 0x01
    }
}
108
109frame_types! {
110    PADDING = 0x00,
111    PING = 0x01,
112    ACK = 0x02,
113    ACK_ECN = 0x03,
114    RESET_STREAM = 0x04,
115    STOP_SENDING = 0x05,
116    CRYPTO = 0x06,
117    NEW_TOKEN = 0x07,
118    // STREAM
119    MAX_DATA = 0x10,
120    MAX_STREAM_DATA = 0x11,
121    MAX_STREAMS_BIDI = 0x12,
122    MAX_STREAMS_UNI = 0x13,
123    DATA_BLOCKED = 0x14,
124    STREAM_DATA_BLOCKED = 0x15,
125    STREAMS_BLOCKED_BIDI = 0x16,
126    STREAMS_BLOCKED_UNI = 0x17,
127    NEW_CONNECTION_ID = 0x18,
128    RETIRE_CONNECTION_ID = 0x19,
129    PATH_CHALLENGE = 0x1a,
130    PATH_RESPONSE = 0x1b,
131    CONNECTION_CLOSE = 0x1c,
132    APPLICATION_CLOSE = 0x1d,
133    HANDSHAKE_DONE = 0x1e,
134    // ACK Frequency
135    ACK_FREQUENCY = 0xaf,
136    IMMEDIATE_ACK = 0x1f,
137    // NAT Traversal Extension
138    ADD_ADDRESS = 0x40,
139    PUNCH_ME_NOW = 0x41,
140    REMOVE_ADDRESS = 0x42,
141    // DATAGRAM
142}
143
/// STREAM frame types; the low three bits are the FIN/LEN/OFF flags decoded by `StreamInfo`.
const STREAM_TYS: RangeInclusive<u64> = 0x08..=0x0f;
/// DATAGRAM frame types; the low bit is the LEN flag decoded by `DatagramInfo`.
const DATAGRAM_TYS: RangeInclusive<u64> = 0x30..=0x31;
146
147#[derive(Debug)]
148pub(crate) enum Frame {
149    Padding,
150    Ping,
151    Ack(Ack),
152    ResetStream(ResetStream),
153    StopSending(StopSending),
154    Crypto(Crypto),
155    NewToken(NewToken),
156    Stream(Stream),
157    MaxData(VarInt),
158    MaxStreamData { id: StreamId, offset: u64 },
159    MaxStreams { dir: Dir, count: u64 },
160    DataBlocked { offset: u64 },
161    StreamDataBlocked { id: StreamId, offset: u64 },
162    StreamsBlocked { dir: Dir, limit: u64 },
163    NewConnectionId(NewConnectionId),
164    RetireConnectionId { sequence: u64 },
165    PathChallenge(u64),
166    PathResponse(u64),
167    Close(Close),
168    Datagram(Datagram),
169    AckFrequency(AckFrequency),
170    ImmediateAck,
171    HandshakeDone,
172    AddAddress(AddAddress),
173    PunchMeNow(PunchMeNow),
174    RemoveAddress(RemoveAddress),
175}
176
177impl Frame {
178    pub(crate) fn ty(&self) -> FrameType {
179        use Frame::*;
180        match *self {
181            Padding => FrameType::PADDING,
182            ResetStream(_) => FrameType::RESET_STREAM,
183            Close(self::Close::Connection(_)) => FrameType::CONNECTION_CLOSE,
184            Close(self::Close::Application(_)) => FrameType::APPLICATION_CLOSE,
185            MaxData(_) => FrameType::MAX_DATA,
186            MaxStreamData { .. } => FrameType::MAX_STREAM_DATA,
187            MaxStreams { dir: Dir::Bi, .. } => FrameType::MAX_STREAMS_BIDI,
188            MaxStreams { dir: Dir::Uni, .. } => FrameType::MAX_STREAMS_UNI,
189            Ping => FrameType::PING,
190            DataBlocked { .. } => FrameType::DATA_BLOCKED,
191            StreamDataBlocked { .. } => FrameType::STREAM_DATA_BLOCKED,
192            StreamsBlocked { dir: Dir::Bi, .. } => FrameType::STREAMS_BLOCKED_BIDI,
193            StreamsBlocked { dir: Dir::Uni, .. } => FrameType::STREAMS_BLOCKED_UNI,
194            StopSending { .. } => FrameType::STOP_SENDING,
195            RetireConnectionId { .. } => FrameType::RETIRE_CONNECTION_ID,
196            Ack(_) => FrameType::ACK,
197            Stream(ref x) => {
198                let mut ty = *STREAM_TYS.start();
199                if x.fin {
200                    ty |= 0x01;
201                }
202                if x.offset != 0 {
203                    ty |= 0x04;
204                }
205                FrameType(ty)
206            }
207            PathChallenge(_) => FrameType::PATH_CHALLENGE,
208            PathResponse(_) => FrameType::PATH_RESPONSE,
209            NewConnectionId { .. } => FrameType::NEW_CONNECTION_ID,
210            Crypto(_) => FrameType::CRYPTO,
211            NewToken(_) => FrameType::NEW_TOKEN,
212            Datagram(_) => FrameType(*DATAGRAM_TYS.start()),
213            AckFrequency(_) => FrameType::ACK_FREQUENCY,
214            ImmediateAck => FrameType::IMMEDIATE_ACK,
215            HandshakeDone => FrameType::HANDSHAKE_DONE,
216            AddAddress(_) => FrameType::ADD_ADDRESS,
217            PunchMeNow(_) => FrameType::PUNCH_ME_NOW,
218            RemoveAddress(_) => FrameType::REMOVE_ADDRESS,
219        }
220    }
221
222    pub(crate) fn is_ack_eliciting(&self) -> bool {
223        !matches!(*self, Self::Ack(_) | Self::Padding | Self::Close(_))
224    }
225}
226
227#[derive(Clone, Debug)]
228pub enum Close {
229    Connection(ConnectionClose),
230    Application(ApplicationClose),
231}
232
233impl Close {
234    pub(crate) fn encode<W: BufMut>(&self, out: &mut W, max_len: usize) {
235        match *self {
236            Self::Connection(ref x) => x.encode(out, max_len),
237            Self::Application(ref x) => x.encode(out, max_len),
238        }
239    }
240
241    pub(crate) fn is_transport_layer(&self) -> bool {
242        matches!(*self, Self::Connection(_))
243    }
244}
245
246impl From<TransportError> for Close {
247    fn from(x: TransportError) -> Self {
248        Self::Connection(x.into())
249    }
250}
251impl From<ConnectionClose> for Close {
252    fn from(x: ConnectionClose) -> Self {
253        Self::Connection(x)
254    }
255}
256impl From<ApplicationClose> for Close {
257    fn from(x: ApplicationClose) -> Self {
258        Self::Application(x)
259    }
260}
261
262/// Reason given by the transport for closing the connection
263#[derive(Debug, Clone, PartialEq, Eq)]
264pub struct ConnectionClose {
265    /// Class of error as encoded in the specification
266    pub error_code: TransportErrorCode,
267    /// Type of frame that caused the close
268    pub frame_type: Option<FrameType>,
269    /// Human-readable reason for the close
270    pub reason: Bytes,
271}
272
273impl fmt::Display for ConnectionClose {
274    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
275        self.error_code.fmt(f)?;
276        if !self.reason.as_ref().is_empty() {
277            f.write_str(": ")?;
278            f.write_str(&String::from_utf8_lossy(&self.reason))?;
279        }
280        Ok(())
281    }
282}
283
284impl From<TransportError> for ConnectionClose {
285    fn from(x: TransportError) -> Self {
286        Self {
287            error_code: x.code,
288            frame_type: x.frame,
289            reason: x.reason.into(),
290        }
291    }
292}
293
294impl FrameStruct for ConnectionClose {
295    const SIZE_BOUND: usize = 1 + 8 + 8 + 8;
296}
297
298impl ConnectionClose {
299    pub(crate) fn encode<W: BufMut>(&self, out: &mut W, max_len: usize) {
300        out.write(FrameType::CONNECTION_CLOSE); // 1 byte
301        out.write(self.error_code); // <= 8 bytes
302        let ty = self.frame_type.map_or(0, |x| x.0);
303        out.write_var(ty); // <= 8 bytes
304        let max_len = max_len
305            - 3
306            - VarInt::from_u64(ty).unwrap().size()
307            - VarInt::from_u64(self.reason.len() as u64).unwrap().size();
308        let actual_len = self.reason.len().min(max_len);
309        out.write_var(actual_len as u64); // <= 8 bytes
310        out.put_slice(&self.reason[0..actual_len]); // whatever's left
311    }
312}
313
314/// Reason given by an application for closing the connection
315#[derive(Debug, Clone, PartialEq, Eq)]
316pub struct ApplicationClose {
317    /// Application-specific reason code
318    pub error_code: VarInt,
319    /// Human-readable reason for the close
320    pub reason: Bytes,
321}
322
323impl fmt::Display for ApplicationClose {
324    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
325        if !self.reason.as_ref().is_empty() {
326            f.write_str(&String::from_utf8_lossy(&self.reason))?;
327            f.write_str(" (code ")?;
328            self.error_code.fmt(f)?;
329            f.write_str(")")?;
330        } else {
331            self.error_code.fmt(f)?;
332        }
333        Ok(())
334    }
335}
336
337impl FrameStruct for ApplicationClose {
338    const SIZE_BOUND: usize = 1 + 8 + 8;
339}
340
341impl ApplicationClose {
342    pub(crate) fn encode<W: BufMut>(&self, out: &mut W, max_len: usize) {
343        out.write(FrameType::APPLICATION_CLOSE); // 1 byte
344        out.write(self.error_code); // <= 8 bytes
345        let max_len = max_len - 3 - VarInt::from_u64(self.reason.len() as u64).unwrap().size();
346        let actual_len = self.reason.len().min(max_len);
347        out.write_var(actual_len as u64); // <= 8 bytes
348        out.put_slice(&self.reason[0..actual_len]); // whatever's left
349    }
350}
351
352#[derive(Clone, Eq, PartialEq)]
353pub struct Ack {
354    pub largest: u64,
355    pub delay: u64,
356    pub additional: Bytes,
357    pub ecn: Option<EcnCounts>,
358}
359
360impl fmt::Debug for Ack {
361    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
362        let mut ranges = "[".to_string();
363        let mut first = true;
364        for range in self.iter() {
365            if !first {
366                ranges.push(',');
367            }
368            write!(ranges, "{range:?}").unwrap();
369            first = false;
370        }
371        ranges.push(']');
372
373        f.debug_struct("Ack")
374            .field("largest", &self.largest)
375            .field("delay", &self.delay)
376            .field("ecn", &self.ecn)
377            .field("ranges", &ranges)
378            .finish()
379    }
380}
381
382impl<'a> IntoIterator for &'a Ack {
383    type Item = RangeInclusive<u64>;
384    type IntoIter = AckIter<'a>;
385
386    fn into_iter(self) -> AckIter<'a> {
387        AckIter::new(self.largest, &self.additional[..])
388    }
389}
390
391impl Ack {
392    pub fn encode<W: BufMut>(
393        delay: u64,
394        ranges: &ArrayRangeSet,
395        ecn: Option<&EcnCounts>,
396        buf: &mut W,
397    ) {
398        let mut rest = ranges.iter().rev();
399        let first = rest.next().unwrap();
400        let largest = first.end - 1;
401        let first_size = first.end - first.start;
402        buf.write(if ecn.is_some() {
403            FrameType::ACK_ECN
404        } else {
405            FrameType::ACK
406        });
407        buf.write_var(largest);
408        buf.write_var(delay);
409        buf.write_var(ranges.len() as u64 - 1);
410        buf.write_var(first_size - 1);
411        let mut prev = first.start;
412        for block in rest {
413            let size = block.end - block.start;
414            buf.write_var(prev - block.end - 1);
415            buf.write_var(size - 1);
416            prev = block.start;
417        }
418        if let Some(x) = ecn {
419            x.encode(buf)
420        }
421    }
422
423    pub fn iter(&self) -> AckIter<'_> {
424        self.into_iter()
425    }
426}
427
428#[derive(Debug, Copy, Clone, Eq, PartialEq)]
429pub struct EcnCounts {
430    pub ect0: u64,
431    pub ect1: u64,
432    pub ce: u64,
433}
434
435impl std::ops::AddAssign<EcnCodepoint> for EcnCounts {
436    fn add_assign(&mut self, rhs: EcnCodepoint) {
437        match rhs {
438            EcnCodepoint::Ect0 => {
439                self.ect0 += 1;
440            }
441            EcnCodepoint::Ect1 => {
442                self.ect1 += 1;
443            }
444            EcnCodepoint::Ce => {
445                self.ce += 1;
446            }
447        }
448    }
449}
450
451impl EcnCounts {
452    pub const ZERO: Self = Self {
453        ect0: 0,
454        ect1: 0,
455        ce: 0,
456    };
457
458    pub fn encode<W: BufMut>(&self, out: &mut W) {
459        out.write_var(self.ect0);
460        out.write_var(self.ect1);
461        out.write_var(self.ce);
462    }
463}
464
/// A STREAM frame: a chunk of `data` for stream `id`, starting at byte `offset`.
#[derive(Debug, Clone)]
pub(crate) struct Stream {
    pub(crate) id: StreamId,
    /// Byte offset of `data` within the stream; 0 when the frame carried no Offset field.
    pub(crate) offset: u64,
    /// FIN flag (bit 0x01 of the frame type).
    pub(crate) fin: bool,
    pub(crate) data: Bytes,
}

impl FrameStruct for Stream {
    // type (1) + stream id (<= 8) + offset (<= 8) + length (<= 8); the data itself is excluded.
    const SIZE_BOUND: usize = 1 + 8 + 8 + 8;
}
476
477/// Metadata from a stream frame
478#[derive(Debug, Clone)]
479pub(crate) struct StreamMeta {
480    pub(crate) id: StreamId,
481    pub(crate) offsets: Range<u64>,
482    pub(crate) fin: bool,
483}
484
485// This manual implementation exists because `Default` is not implemented for `StreamId`
486impl Default for StreamMeta {
487    fn default() -> Self {
488        Self {
489            id: StreamId(0),
490            offsets: 0..0,
491            fin: false,
492        }
493    }
494}
495
496impl StreamMeta {
497    pub(crate) fn encode<W: BufMut>(&self, length: bool, out: &mut W) {
498        let mut ty = *STREAM_TYS.start();
499        if self.offsets.start != 0 {
500            ty |= 0x04;
501        }
502        if length {
503            ty |= 0x02;
504        }
505        if self.fin {
506            ty |= 0x01;
507        }
508        out.write_var(ty); // 1 byte
509        out.write(self.id); // <=8 bytes
510        if self.offsets.start != 0 {
511            out.write_var(self.offsets.start); // <=8 bytes
512        }
513        if length {
514            out.write_var(self.offsets.end - self.offsets.start); // <=8 bytes
515        }
516    }
517}
518
/// A vector of [`StreamMeta`] that stores one element inline and only moves to a heap
/// allocation once a second element is added (the common case is a single entry)
pub(crate) type StreamMetaVec = TinyVec<[StreamMeta; 1]>;
521
522#[derive(Debug, Clone)]
523pub(crate) struct Crypto {
524    pub(crate) offset: u64,
525    pub(crate) data: Bytes,
526}
527
528impl Crypto {
529    pub(crate) const SIZE_BOUND: usize = 17;
530
531    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
532        out.write(FrameType::CRYPTO);
533        out.write_var(self.offset);
534        out.write_var(self.data.len() as u64);
535        out.put_slice(&self.data);
536    }
537}
538
539#[derive(Debug, Clone)]
540pub(crate) struct NewToken {
541    pub(crate) token: Bytes,
542}
543
544impl NewToken {
545    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
546        out.write(FrameType::NEW_TOKEN);
547        out.write_var(self.token.len() as u64);
548        out.put_slice(&self.token);
549    }
550
551    pub(crate) fn size(&self) -> usize {
552        1 + VarInt::from_u64(self.token.len() as u64).unwrap().size() + self.token.len()
553    }
554}
555
556pub(crate) struct Iter {
557    bytes: Bytes,
558    last_ty: Option<FrameType>,
559}
560
561impl Iter {
562    pub(crate) fn new(payload: Bytes) -> Result<Self, TransportError> {
563        if payload.is_empty() {
564            // "An endpoint MUST treat receipt of a packet containing no frames as a
565            // connection error of type PROTOCOL_VIOLATION."
566            // https://www.rfc-editor.org/rfc/rfc9000.html#name-frames-and-frame-types
567            return Err(TransportError::PROTOCOL_VIOLATION(
568                "packet payload is empty",
569            ));
570        }
571
572        Ok(Self {
573            bytes: payload,
574            last_ty: None,
575        })
576    }
577
578    fn take_len(&mut self) -> Result<Bytes, UnexpectedEnd> {
579        let len = self.bytes.get_var()?;
580        if len > self.bytes.remaining() as u64 {
581            return Err(UnexpectedEnd);
582        }
583        Ok(self.bytes.split_to(len as usize))
584    }
585
586    fn try_next(&mut self) -> Result<Frame, IterErr> {
587        let ty = self.bytes.get::<FrameType>()?;
588        self.last_ty = Some(ty);
589        Ok(match ty {
590            FrameType::PADDING => Frame::Padding,
591            FrameType::RESET_STREAM => Frame::ResetStream(ResetStream {
592                id: self.bytes.get()?,
593                error_code: self.bytes.get()?,
594                final_offset: self.bytes.get()?,
595            }),
596            FrameType::CONNECTION_CLOSE => Frame::Close(Close::Connection(ConnectionClose {
597                error_code: self.bytes.get()?,
598                frame_type: {
599                    let x = self.bytes.get_var()?;
600                    if x == 0 { None } else { Some(FrameType(x)) }
601                },
602                reason: self.take_len()?,
603            })),
604            FrameType::APPLICATION_CLOSE => Frame::Close(Close::Application(ApplicationClose {
605                error_code: self.bytes.get()?,
606                reason: self.take_len()?,
607            })),
608            FrameType::MAX_DATA => Frame::MaxData(self.bytes.get()?),
609            FrameType::MAX_STREAM_DATA => Frame::MaxStreamData {
610                id: self.bytes.get()?,
611                offset: self.bytes.get_var()?,
612            },
613            FrameType::MAX_STREAMS_BIDI => Frame::MaxStreams {
614                dir: Dir::Bi,
615                count: self.bytes.get_var()?,
616            },
617            FrameType::MAX_STREAMS_UNI => Frame::MaxStreams {
618                dir: Dir::Uni,
619                count: self.bytes.get_var()?,
620            },
621            FrameType::PING => Frame::Ping,
622            FrameType::DATA_BLOCKED => Frame::DataBlocked {
623                offset: self.bytes.get_var()?,
624            },
625            FrameType::STREAM_DATA_BLOCKED => Frame::StreamDataBlocked {
626                id: self.bytes.get()?,
627                offset: self.bytes.get_var()?,
628            },
629            FrameType::STREAMS_BLOCKED_BIDI => Frame::StreamsBlocked {
630                dir: Dir::Bi,
631                limit: self.bytes.get_var()?,
632            },
633            FrameType::STREAMS_BLOCKED_UNI => Frame::StreamsBlocked {
634                dir: Dir::Uni,
635                limit: self.bytes.get_var()?,
636            },
637            FrameType::STOP_SENDING => Frame::StopSending(StopSending {
638                id: self.bytes.get()?,
639                error_code: self.bytes.get()?,
640            }),
641            FrameType::RETIRE_CONNECTION_ID => Frame::RetireConnectionId {
642                sequence: self.bytes.get_var()?,
643            },
644            FrameType::ACK | FrameType::ACK_ECN => {
645                let largest = self.bytes.get_var()?;
646                let delay = self.bytes.get_var()?;
647                let extra_blocks = self.bytes.get_var()? as usize;
648                let n = scan_ack_blocks(&self.bytes, largest, extra_blocks)?;
649                Frame::Ack(Ack {
650                    delay,
651                    largest,
652                    additional: self.bytes.split_to(n),
653                    ecn: if ty != FrameType::ACK_ECN {
654                        None
655                    } else {
656                        Some(EcnCounts {
657                            ect0: self.bytes.get_var()?,
658                            ect1: self.bytes.get_var()?,
659                            ce: self.bytes.get_var()?,
660                        })
661                    },
662                })
663            }
664            FrameType::PATH_CHALLENGE => Frame::PathChallenge(self.bytes.get()?),
665            FrameType::PATH_RESPONSE => Frame::PathResponse(self.bytes.get()?),
666            FrameType::NEW_CONNECTION_ID => {
667                let sequence = self.bytes.get_var()?;
668                let retire_prior_to = self.bytes.get_var()?;
669                if retire_prior_to > sequence {
670                    return Err(IterErr::Malformed);
671                }
672                let length = self.bytes.get::<u8>()? as usize;
673                if length > MAX_CID_SIZE || length == 0 {
674                    return Err(IterErr::Malformed);
675                }
676                if length > self.bytes.remaining() {
677                    return Err(IterErr::UnexpectedEnd);
678                }
679                let mut stage = [0; MAX_CID_SIZE];
680                self.bytes.copy_to_slice(&mut stage[0..length]);
681                let id = ConnectionId::new(&stage[..length]);
682                if self.bytes.remaining() < 16 {
683                    return Err(IterErr::UnexpectedEnd);
684                }
685                let mut reset_token = [0; RESET_TOKEN_SIZE];
686                self.bytes.copy_to_slice(&mut reset_token);
687                Frame::NewConnectionId(NewConnectionId {
688                    sequence,
689                    retire_prior_to,
690                    id,
691                    reset_token: reset_token.into(),
692                })
693            }
694            FrameType::CRYPTO => Frame::Crypto(Crypto {
695                offset: self.bytes.get_var()?,
696                data: self.take_len()?,
697            }),
698            FrameType::NEW_TOKEN => Frame::NewToken(NewToken {
699                token: self.take_len()?,
700            }),
701            FrameType::HANDSHAKE_DONE => Frame::HandshakeDone,
702            FrameType::ACK_FREQUENCY => Frame::AckFrequency(AckFrequency {
703                sequence: self.bytes.get()?,
704                ack_eliciting_threshold: self.bytes.get()?,
705                request_max_ack_delay: self.bytes.get()?,
706                reordering_threshold: self.bytes.get()?,
707            }),
708            FrameType::IMMEDIATE_ACK => Frame::ImmediateAck,
709            FrameType::ADD_ADDRESS => Frame::AddAddress(AddAddress::decode(&mut self.bytes)?),
710            FrameType::PUNCH_ME_NOW => Frame::PunchMeNow(PunchMeNow::decode(&mut self.bytes)?),
711            FrameType::REMOVE_ADDRESS => Frame::RemoveAddress(RemoveAddress::decode(&mut self.bytes)?),
712            _ => {
713                if let Some(s) = ty.stream() {
714                    Frame::Stream(Stream {
715                        id: self.bytes.get()?,
716                        offset: if s.off() { self.bytes.get_var()? } else { 0 },
717                        fin: s.fin(),
718                        data: if s.len() {
719                            self.take_len()?
720                        } else {
721                            self.take_remaining()
722                        },
723                    })
724                } else if let Some(d) = ty.datagram() {
725                    Frame::Datagram(Datagram {
726                        data: if d.len() {
727                            self.take_len()?
728                        } else {
729                            self.take_remaining()
730                        },
731                    })
732                } else {
733                    return Err(IterErr::InvalidFrameId);
734                }
735            }
736        })
737    }
738
739    fn take_remaining(&mut self) -> Bytes {
740        mem::take(&mut self.bytes)
741    }
742}
743
744impl Iterator for Iter {
745    type Item = Result<Frame, InvalidFrame>;
746    fn next(&mut self) -> Option<Self::Item> {
747        if !self.bytes.has_remaining() {
748            return None;
749        }
750        match self.try_next() {
751            Ok(x) => Some(Ok(x)),
752            Err(e) => {
753                // Corrupt frame, skip it and everything that follows
754                self.bytes.clear();
755                Some(Err(InvalidFrame {
756                    ty: self.last_ty,
757                    reason: e.reason(),
758                }))
759            }
760        }
761    }
762}
763
764#[derive(Debug)]
765pub(crate) struct InvalidFrame {
766    pub(crate) ty: Option<FrameType>,
767    pub(crate) reason: &'static str,
768}
769
770impl From<InvalidFrame> for TransportError {
771    fn from(err: InvalidFrame) -> Self {
772        let mut te = Self::FRAME_ENCODING_ERROR(err.reason);
773        te.frame = err.ty;
774        te
775    }
776}
777
778/// Validate exactly `n` ACK ranges in `buf` and return the number of bytes they cover
779fn scan_ack_blocks(mut buf: &[u8], largest: u64, n: usize) -> Result<usize, IterErr> {
780    let total_len = buf.remaining();
781    let first_block = buf.get_var()?;
782    let mut smallest = largest.checked_sub(first_block).ok_or(IterErr::Malformed)?;
783    for _ in 0..n {
784        let gap = buf.get_var()?;
785        smallest = smallest.checked_sub(gap + 2).ok_or(IterErr::Malformed)?;
786        let block = buf.get_var()?;
787        smallest = smallest.checked_sub(block).ok_or(IterErr::Malformed)?;
788    }
789    Ok(total_len - buf.remaining())
790}
791
/// Internal classification of frame-decoding failures; `reason` supplies the message that
/// ends up in `InvalidFrame`.
enum IterErr {
    UnexpectedEnd,
    InvalidFrameId,
    Malformed,
}

impl IterErr {
    /// Static, human-readable description of the failure.
    fn reason(&self) -> &'static str {
        match self {
            Self::UnexpectedEnd => "unexpected end",
            Self::InvalidFrameId => "invalid frame ID",
            Self::Malformed => "malformed",
        }
    }
}
808
809impl From<UnexpectedEnd> for IterErr {
810    fn from(_: UnexpectedEnd) -> Self {
811        Self::UnexpectedEnd
812    }
813}
814
815#[derive(Debug, Clone)]
816pub struct AckIter<'a> {
817    largest: u64,
818    data: &'a [u8],
819}
820
821impl<'a> AckIter<'a> {
822    fn new(largest: u64, data: &'a [u8]) -> Self {
823        Self { largest, data }
824    }
825}
826
827impl Iterator for AckIter<'_> {
828    type Item = RangeInclusive<u64>;
829    fn next(&mut self) -> Option<RangeInclusive<u64>> {
830        if !self.data.has_remaining() {
831            return None;
832        }
833        let block = self.data.get_var().unwrap();
834        let largest = self.largest;
835        if let Ok(gap) = self.data.get_var() {
836            self.largest -= block + gap + 2;
837        }
838        Some(largest - block..=largest)
839    }
840}
841
842#[allow(unreachable_pub)] // fuzzing only
843#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
844#[derive(Debug, Copy, Clone)]
845pub struct ResetStream {
846    pub(crate) id: StreamId,
847    pub(crate) error_code: VarInt,
848    pub(crate) final_offset: VarInt,
849}
850
851impl FrameStruct for ResetStream {
852    const SIZE_BOUND: usize = 1 + 8 + 8 + 8;
853}
854
855impl ResetStream {
856    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
857        out.write(FrameType::RESET_STREAM); // 1 byte
858        out.write(self.id); // <= 8 bytes
859        out.write(self.error_code); // <= 8 bytes
860        out.write(self.final_offset); // <= 8 bytes
861    }
862}
863
864#[derive(Debug, Copy, Clone)]
865pub(crate) struct StopSending {
866    pub(crate) id: StreamId,
867    pub(crate) error_code: VarInt,
868}
869
870impl FrameStruct for StopSending {
871    const SIZE_BOUND: usize = 1 + 8 + 8;
872}
873
874impl StopSending {
875    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
876        out.write(FrameType::STOP_SENDING); // 1 byte
877        out.write(self.id); // <= 8 bytes
878        out.write(self.error_code) // <= 8 bytes
879    }
880}
881
882#[derive(Debug, Copy, Clone)]
883pub(crate) struct NewConnectionId {
884    pub(crate) sequence: u64,
885    pub(crate) retire_prior_to: u64,
886    pub(crate) id: ConnectionId,
887    pub(crate) reset_token: ResetToken,
888}
889
890impl NewConnectionId {
891    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
892        out.write(FrameType::NEW_CONNECTION_ID);
893        out.write_var(self.sequence);
894        out.write_var(self.retire_prior_to);
895        out.write(self.id.len() as u8);
896        out.put_slice(&self.id);
897        out.put_slice(&self.reset_token);
898    }
899}
900
/// Encoded-size bound for RETIRE_CONNECTION_ID: type (1) + sequence number (<= 8).
pub(crate) const RETIRE_CONNECTION_ID_SIZE_BOUND: usize = 1 + 8;
903
904/// An unreliable datagram
905#[derive(Debug, Clone)]
906pub struct Datagram {
907    /// Payload
908    pub data: Bytes,
909}
910
911impl FrameStruct for Datagram {
912    const SIZE_BOUND: usize = 1 + 8;
913}
914
915impl Datagram {
916    pub(crate) fn encode(&self, length: bool, out: &mut Vec<u8>) {
917        out.write(FrameType(*DATAGRAM_TYS.start() | u64::from(length))); // 1 byte
918        if length {
919            // Safe to unwrap because we check length sanity before queueing datagrams
920            out.write(VarInt::from_u64(self.data.len() as u64).unwrap()); // <= 8 bytes
921        }
922        out.extend_from_slice(&self.data);
923    }
924
925    pub(crate) fn size(&self, length: bool) -> usize {
926        1 + if length {
927            VarInt::from_u64(self.data.len() as u64).unwrap().size()
928        } else {
929            0
930        } + self.data.len()
931    }
932}
933
934#[derive(Debug, Copy, Clone, PartialEq, Eq)]
935pub(crate) struct AckFrequency {
936    pub(crate) sequence: VarInt,
937    pub(crate) ack_eliciting_threshold: VarInt,
938    pub(crate) request_max_ack_delay: VarInt,
939    pub(crate) reordering_threshold: VarInt,
940}
941
942impl AckFrequency {
943    pub(crate) fn encode<W: BufMut>(&self, buf: &mut W) {
944        buf.write(FrameType::ACK_FREQUENCY);
945        buf.write(self.sequence);
946        buf.write(self.ack_eliciting_threshold);
947        buf.write(self.request_max_ack_delay);
948        buf.write(self.reordering_threshold);
949    }
950}
951
/// NAT traversal frame for advertising candidate addresses
///
/// Carried in an ADD_ADDRESS frame (type 0x40). On the wire, `priority` is written
/// before `address`, even though the fields are declared in the opposite order here.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct AddAddress {
    /// Sequence number for this address advertisement
    pub(crate) sequence: VarInt,
    /// Socket address being advertised
    pub(crate) address: SocketAddr,
    /// Priority of this address candidate
    pub(crate) priority: VarInt,
}
962
963impl AddAddress {
964    pub(crate) fn encode<W: BufMut>(&self, buf: &mut W) {
965        buf.write(FrameType::ADD_ADDRESS);
966        buf.write(self.sequence);
967        buf.write(self.priority);
968        
969        match self.address {
970            SocketAddr::V4(addr) => {
971                buf.put_u8(4); // IPv4 indicator
972                buf.put_slice(&addr.ip().octets());
973                buf.put_u16(addr.port());
974            }
975            SocketAddr::V6(addr) => {
976                buf.put_u8(6); // IPv6 indicator
977                buf.put_slice(&addr.ip().octets());
978                buf.put_u16(addr.port());
979                buf.put_u32(addr.flowinfo());
980                buf.put_u32(addr.scope_id());
981            }
982        }
983    }
984    
985    pub(crate) fn decode<R: Buf>(r: &mut R) -> Result<Self, UnexpectedEnd> {
986        let sequence = r.get()?;
987        let priority = r.get()?;
988        let ip_version = r.get::<u8>()?;
989        
990        let address = match ip_version {
991            4 => {
992                let mut octets = [0u8; 4];
993                r.copy_to_slice(&mut octets);
994                let port = r.get::<u16>()?;
995                SocketAddr::V4(std::net::SocketAddrV4::new(
996                    std::net::Ipv4Addr::from(octets),
997                    port,
998                ))
999            }
1000            6 => {
1001                let mut octets = [0u8; 16];
1002                r.copy_to_slice(&mut octets);
1003                let port = r.get::<u16>()?;
1004                let flowinfo = r.get::<u32>()?;
1005                let scope_id = r.get::<u32>()?;
1006                SocketAddr::V6(std::net::SocketAddrV6::new(
1007                    std::net::Ipv6Addr::from(octets),
1008                    port,
1009                    flowinfo,
1010                    scope_id,
1011                ))
1012            }
1013            _ => return Err(UnexpectedEnd),
1014        };
1015        
1016        Ok(Self {
1017            sequence,
1018            address,
1019            priority,
1020        })
1021    }
1022}
1023
1024impl FrameStruct for AddAddress {
1025    const SIZE_BOUND: usize = 1 + 9 + 9 + 1 + 16 + 2 + 4 + 4; // Worst case IPv6
1026}
1027
1028/// NAT traversal frame for requesting simultaneous hole punching
1029#[derive(Debug, Clone, PartialEq, Eq)]
1030pub(crate) struct PunchMeNow {
1031    /// Round number for coordination
1032    pub(crate) round: VarInt,
1033    /// Sequence number of the address to punch to (from AddAddress)
1034    pub(crate) target_sequence: VarInt,
1035    /// Local address for this punch attempt
1036    pub(crate) local_address: SocketAddr,
1037    /// Target peer ID for relay by bootstrap nodes (optional)
1038    /// When present, this frame should be relayed to the specified peer
1039    pub(crate) target_peer_id: Option<[u8; 32]>,
1040}
1041
1042impl PunchMeNow {
1043    pub(crate) fn encode<W: BufMut>(&self, buf: &mut W) {
1044        buf.write(FrameType::PUNCH_ME_NOW);
1045        buf.write(self.round);
1046        buf.write(self.target_sequence);
1047        
1048        match self.local_address {
1049            SocketAddr::V4(addr) => {
1050                buf.put_u8(4); // IPv4 indicator
1051                buf.put_slice(&addr.ip().octets());
1052                buf.put_u16(addr.port());
1053            }
1054            SocketAddr::V6(addr) => {
1055                buf.put_u8(6); // IPv6 indicator
1056                buf.put_slice(&addr.ip().octets());
1057                buf.put_u16(addr.port());
1058                buf.put_u32(addr.flowinfo());
1059                buf.put_u32(addr.scope_id());
1060            }
1061        }
1062        
1063        // Encode target_peer_id if present
1064        match &self.target_peer_id {
1065            Some(peer_id) => {
1066                buf.put_u8(1); // Presence indicator
1067                buf.put_slice(peer_id);
1068            }
1069            None => {
1070                buf.put_u8(0); // Absence indicator
1071            }
1072        }
1073    }
1074    
1075    pub(crate) fn decode<R: Buf>(r: &mut R) -> Result<Self, UnexpectedEnd> {
1076        let round = r.get()?;
1077        let target_sequence = r.get()?;
1078        let ip_version = r.get::<u8>()?;
1079        
1080        let local_address = match ip_version {
1081            4 => {
1082                let mut octets = [0u8; 4];
1083                r.copy_to_slice(&mut octets);
1084                let port = r.get::<u16>()?;
1085                SocketAddr::V4(std::net::SocketAddrV4::new(
1086                    std::net::Ipv4Addr::from(octets),
1087                    port,
1088                ))
1089            }
1090            6 => {
1091                let mut octets = [0u8; 16];
1092                r.copy_to_slice(&mut octets);
1093                let port = r.get::<u16>()?;
1094                let flowinfo = r.get::<u32>()?;
1095                let scope_id = r.get::<u32>()?;
1096                SocketAddr::V6(std::net::SocketAddrV6::new(
1097                    std::net::Ipv6Addr::from(octets),
1098                    port,
1099                    flowinfo,
1100                    scope_id,
1101                ))
1102            }
1103            _ => return Err(UnexpectedEnd),
1104        };
1105        
1106        // Decode target_peer_id if present
1107        let target_peer_id = if r.remaining() > 0 {
1108            let has_peer_id = r.get::<u8>()?;
1109            if has_peer_id == 1 {
1110                let mut peer_id = [0u8; 32];
1111                r.copy_to_slice(&mut peer_id);
1112                Some(peer_id)
1113            } else {
1114                None
1115            }
1116        } else {
1117            None
1118        };
1119        
1120        Ok(Self {
1121            round,
1122            target_sequence,
1123            local_address,
1124            target_peer_id,
1125        })
1126    }
1127}
1128
1129impl FrameStruct for PunchMeNow {
1130    const SIZE_BOUND: usize = 1 + 9 + 9 + 1 + 16 + 2 + 4 + 4 + 1 + 32; // Worst case IPv6 + peer ID
1131}
1132
1133/// NAT traversal frame for removing stale addresses
1134#[derive(Debug, Clone, PartialEq, Eq)]
1135pub(crate) struct RemoveAddress {
1136    /// Sequence number of the address to remove (from AddAddress)
1137    pub(crate) sequence: VarInt,
1138}
1139
1140impl RemoveAddress {
1141    pub(crate) fn encode<W: BufMut>(&self, buf: &mut W) {
1142        buf.write(FrameType::REMOVE_ADDRESS);
1143        buf.write(self.sequence);
1144    }
1145    
1146    pub(crate) fn decode<R: Buf>(r: &mut R) -> Result<Self, UnexpectedEnd> {
1147        let sequence = r.get()?;
1148        Ok(Self { sequence })
1149    }
1150}
1151
1152impl FrameStruct for RemoveAddress {
1153    const SIZE_BOUND: usize = 1 + 9; // frame type + sequence
1154}
1155
#[cfg(test)]
mod test {
    use super::*;
    use crate::coding::Codec;
    use assert_matches::assert_matches;

    /// Parses `buf` as a packet payload and returns every frame, panicking on any error.
    fn frames(buf: Vec<u8>) -> Vec<Frame> {
        Iter::new(Bytes::from(buf))
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap()
    }

    #[test]
    fn ack_coding() {
        const PACKETS: &[u64] = &[1, 2, 3, 5, 10, 11, 14];
        let mut ranges = ArrayRangeSet::new();
        for &packet in PACKETS {
            ranges.insert(packet..packet + 1);
        }
        let mut buf = Vec::new();
        const ECN: EcnCounts = EcnCounts {
            ect0: 42,
            ect1: 24,
            ce: 12,
        };
        Ack::encode(42, &ranges, Some(&ECN), &mut buf);
        let frames = frames(buf);
        assert_eq!(frames.len(), 1);
        match frames[0] {
            Frame::Ack(ref ack) => {
                let mut packets = ack.iter().flatten().collect::<Vec<_>>();
                packets.sort_unstable();
                assert_eq!(&packets[..], PACKETS);
                assert_eq!(ack.ecn, Some(ECN));
            }
            ref x => panic!("incorrect frame {x:?}"),
        }
    }

    #[test]
    fn ack_frequency_coding() {
        let mut buf = Vec::new();
        let original = AckFrequency {
            sequence: VarInt(42),
            ack_eliciting_threshold: VarInt(20),
            request_max_ack_delay: VarInt(50_000),
            reordering_threshold: VarInt(1),
        };
        original.encode(&mut buf);
        let frames = frames(buf);
        assert_eq!(frames.len(), 1);
        match &frames[0] {
            Frame::AckFrequency(decoded) => assert_eq!(decoded, &original),
            x => panic!("incorrect frame {x:?}"),
        }
    }

    #[test]
    fn immediate_ack_coding() {
        let mut buf = Vec::new();
        FrameType::IMMEDIATE_ACK.encode(&mut buf);
        let frames = frames(buf);
        assert_eq!(frames.len(), 1);
        assert_matches!(&frames[0], Frame::ImmediateAck);
    }

    #[test]
    fn add_address_ipv4_coding() {
        let mut buf = Vec::new();
        let addr = SocketAddr::from(([127, 0, 0, 1], 8080));
        let original = AddAddress {
            sequence: VarInt(42),
            address: addr,
            priority: VarInt(100),
        };
        original.encode(&mut buf);
        let frames = frames(buf);
        assert_eq!(frames.len(), 1);
        match &frames[0] {
            Frame::AddAddress(decoded) => {
                assert_eq!(decoded.sequence, original.sequence);
                assert_eq!(decoded.address, original.address);
                assert_eq!(decoded.priority, original.priority);
            }
            x => panic!("incorrect frame {x:?}"),
        }
    }

    #[test]
    fn add_address_ipv6_coding() {
        let mut buf = Vec::new();
        let addr = SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 8080));
        let original = AddAddress {
            sequence: VarInt(123),
            address: addr,
            priority: VarInt(200),
        };
        original.encode(&mut buf);
        let frames = frames(buf);
        assert_eq!(frames.len(), 1);
        match &frames[0] {
            Frame::AddAddress(decoded) => {
                assert_eq!(decoded.sequence, original.sequence);
                assert_eq!(decoded.address, original.address);
                assert_eq!(decoded.priority, original.priority);
            }
            x => panic!("incorrect frame {x:?}"),
        }
    }

    #[test]
    fn punch_me_now_ipv4_coding() {
        let mut buf = Vec::new();
        let addr = SocketAddr::from(([192, 168, 1, 1], 9000));
        let original = PunchMeNow {
            round: VarInt(1),
            target_sequence: VarInt(42),
            local_address: addr,
            target_peer_id: None,
        };
        original.encode(&mut buf);
        let frames = frames(buf);
        assert_eq!(frames.len(), 1);
        match &frames[0] {
            Frame::PunchMeNow(decoded) => {
                assert_eq!(decoded.round, original.round);
                assert_eq!(decoded.target_sequence, original.target_sequence);
                assert_eq!(decoded.local_address, original.local_address);
                assert_eq!(decoded.target_peer_id, None);
            }
            x => panic!("incorrect frame {x:?}"),
        }
    }

    #[test]
    fn punch_me_now_ipv6_coding() {
        let mut buf = Vec::new();
        let addr = SocketAddr::from(([0xfe80, 0, 0, 0, 0, 0, 0, 1], 9000));
        let original = PunchMeNow {
            round: VarInt(2),
            target_sequence: VarInt(100),
            local_address: addr,
            target_peer_id: None,
        };
        original.encode(&mut buf);
        let frames = frames(buf);
        assert_eq!(frames.len(), 1);
        match &frames[0] {
            Frame::PunchMeNow(decoded) => {
                assert_eq!(decoded.round, original.round);
                assert_eq!(decoded.target_sequence, original.target_sequence);
                assert_eq!(decoded.local_address, original.local_address);
                assert_eq!(decoded.target_peer_id, None);
            }
            x => panic!("incorrect frame {x:?}"),
        }
    }

    #[test]
    fn remove_address_coding() {
        let mut buf = Vec::new();
        let original = RemoveAddress {
            sequence: VarInt(42),
        };
        original.encode(&mut buf);
        let frames = frames(buf);
        assert_eq!(frames.len(), 1);
        match &frames[0] {
            Frame::RemoveAddress(decoded) => {
                assert_eq!(decoded.sequence, original.sequence);
            }
            x => panic!("incorrect frame {x:?}"),
        }
    }

    #[test]
    fn nat_traversal_frame_size_bounds() {
        // Test that the SIZE_BOUND constants are correct
        let mut buf = Vec::new();

        // AddAddress with IPv6 (worst case)
        let addr = AddAddress {
            sequence: VarInt::MAX,
            address: SocketAddr::from(([0xffff; 8], 65535)),
            priority: VarInt::MAX,
        };
        addr.encode(&mut buf);
        assert!(buf.len() <= AddAddress::SIZE_BOUND);
        buf.clear();

        // PunchMeNow with IPv6 (worst case)
        let punch = PunchMeNow {
            round: VarInt::MAX,
            target_sequence: VarInt::MAX,
            local_address: SocketAddr::from(([0xffff; 8], 65535)),
            target_peer_id: Some([0xff; 32]),
        };
        punch.encode(&mut buf);
        assert!(buf.len() <= PunchMeNow::SIZE_BOUND);
        buf.clear();

        // RemoveAddress
        let remove = RemoveAddress {
            sequence: VarInt::MAX,
        };
        remove.encode(&mut buf);
        assert!(buf.len() <= RemoveAddress::SIZE_BOUND);
    }

    #[test]
    fn punch_me_now_with_target_peer_id() {
        let mut buf = Vec::new();
        let target_peer_id = [0x42; 32]; // Test peer ID
        let addr = SocketAddr::from(([192, 168, 1, 100], 12345));
        let original = PunchMeNow {
            round: VarInt(5),
            target_sequence: VarInt(999),
            local_address: addr,
            target_peer_id: Some(target_peer_id),
        };
        original.encode(&mut buf);
        let frames = frames(buf);
        assert_eq!(frames.len(), 1);
        match &frames[0] {
            Frame::PunchMeNow(decoded) => {
                assert_eq!(decoded.round, original.round);
                assert_eq!(decoded.target_sequence, original.target_sequence);
                assert_eq!(decoded.local_address, original.local_address);
                assert_eq!(decoded.target_peer_id, Some(target_peer_id));
            }
            x => panic!("incorrect frame {x:?}"),
        }
    }

    #[test]
    fn nat_traversal_frame_edge_cases() {
        // Test minimum values
        let mut buf = Vec::new();

        // AddAddress with minimum values
        let min_addr = AddAddress {
            sequence: VarInt(0),
            address: SocketAddr::from(([0, 0, 0, 0], 0)),
            priority: VarInt(0),
        };
        min_addr.encode(&mut buf);
        let frames1 = frames(buf.clone());
        assert_eq!(frames1.len(), 1);
        buf.clear();

        // PunchMeNow with minimum values
        let min_punch = PunchMeNow {
            round: VarInt(0),
            target_sequence: VarInt(0),
            local_address: SocketAddr::from(([0, 0, 0, 0], 0)),
            target_peer_id: None,
        };
        min_punch.encode(&mut buf);
        let frames2 = frames(buf.clone());
        assert_eq!(frames2.len(), 1);
        buf.clear();

        // RemoveAddress with minimum values
        let min_remove = RemoveAddress {
            sequence: VarInt(0),
        };
        min_remove.encode(&mut buf);
        let frames3 = frames(buf);
        assert_eq!(frames3.len(), 1);
    }

    #[test]
    fn nat_traversal_frame_boundary_values() {
        // Test VarInt boundary values
        let mut buf = Vec::new();

        // Test VarInt boundary values for AddAddress
        let boundary_values = [
            VarInt(0),
            VarInt(63),         // Maximum 1-byte VarInt
            VarInt(64),         // Minimum 2-byte VarInt
            VarInt(16383),      // Maximum 2-byte VarInt
            VarInt(16384),      // Minimum 4-byte VarInt
            VarInt(1073741823), // Maximum 4-byte VarInt
            VarInt(1073741824), // Minimum 8-byte VarInt
        ];

        for &sequence in &boundary_values {
            for &priority in &boundary_values {
                let addr = AddAddress {
                    sequence,
                    address: SocketAddr::from(([127, 0, 0, 1], 8080)),
                    priority,
                };
                addr.encode(&mut buf);
                let parsed_frames = frames(buf.clone());
                assert_eq!(parsed_frames.len(), 1);
                match &parsed_frames[0] {
                    Frame::AddAddress(decoded) => {
                        assert_eq!(decoded.sequence, sequence);
                        assert_eq!(decoded.priority, priority);
                    }
                    x => panic!("incorrect frame {x:?}"),
                }
                buf.clear();
            }
        }
    }

    #[test]
    fn nat_traversal_frame_error_handling() {
        /// Encodes `ty` followed by `fields` as varints: a frame cut short after them.
        fn frame_prefix(ty: FrameType, fields: &[u64]) -> Vec<u8> {
            let mut buf = Vec::new();
            buf.write(ty);
            for &field in fields {
                buf.write(VarInt(field));
            }
            buf
        }

        let mut malformed_frames = vec![
            // The frame type itself is truncated: 0x40..=0x42 need a 2-byte varint.
            vec![0x40],
            // ADD_ADDRESS cut off before sequence, priority and IP version.
            frame_prefix(FrameType::ADD_ADDRESS, &[]),
            frame_prefix(FrameType::ADD_ADDRESS, &[1]),
            frame_prefix(FrameType::ADD_ADDRESS, &[1, 2]),
            // PUNCH_ME_NOW cut off before round, target_sequence and IP version.
            frame_prefix(FrameType::PUNCH_ME_NOW, &[]),
            frame_prefix(FrameType::PUNCH_ME_NOW, &[1]),
            frame_prefix(FrameType::PUNCH_ME_NOW, &[1, 2]),
            // REMOVE_ADDRESS without its sequence number.
            frame_prefix(FrameType::REMOVE_ADDRESS, &[]),
        ];

        // Unknown IP version indicator (neither 4 nor 6), followed by plausible bytes.
        for ty in [FrameType::ADD_ADDRESS, FrameType::PUNCH_ME_NOW] {
            let mut buf = frame_prefix(ty, &[1, 2]);
            buf.extend_from_slice(&[0x99, 1, 2, 3, 4, 0x1f, 0x90]);
            malformed_frames.push(buf);
        }

        for malformed in malformed_frames {
            let mut iter = Iter::new(Bytes::from(malformed.clone())).unwrap();
            assert!(
                matches!(iter.next(), Some(Err(_))),
                "malformed frame {malformed:02x?} was not rejected"
            );
        }
    }

    #[test]
    fn nat_traversal_frame_roundtrip_consistency() {
        // Test that encoding and then decoding produces identical frames

        // Test AddAddress frames
        let add_test_cases = vec![
            AddAddress {
                sequence: VarInt(42),
                address: SocketAddr::from(([127, 0, 0, 1], 8080)),
                priority: VarInt(100),
            },
            AddAddress {
                sequence: VarInt(1000),
                address: SocketAddr::from(([0, 0, 0, 0, 0, 0, 0, 1], 443)),
                priority: VarInt(255),
            },
        ];

        for original_add in add_test_cases {
            let mut buf = Vec::new();
            original_add.encode(&mut buf);

            let decoded_frames = frames(buf);
            assert_eq!(decoded_frames.len(), 1);

            match &decoded_frames[0] {
                Frame::AddAddress(decoded) => {
                    assert_eq!(original_add.sequence, decoded.sequence);
                    assert_eq!(original_add.address, decoded.address);
                    assert_eq!(original_add.priority, decoded.priority);
                }
                _ => panic!("Expected AddAddress frame"),
            }
        }

        // Test PunchMeNow frames
        let punch_test_cases = vec![
            PunchMeNow {
                round: VarInt(1),
                target_sequence: VarInt(42),
                local_address: SocketAddr::from(([192, 168, 1, 1], 9000)),
                target_peer_id: None,
            },
            PunchMeNow {
                round: VarInt(10),
                target_sequence: VarInt(500),
                local_address: SocketAddr::from(([2001, 0xdb8, 0, 0, 0, 0, 0, 1], 12345)),
                target_peer_id: Some([0xaa; 32]),
            },
        ];

        for original_punch in punch_test_cases {
            let mut buf = Vec::new();
            original_punch.encode(&mut buf);

            let decoded_frames = frames(buf);
            assert_eq!(decoded_frames.len(), 1);

            match &decoded_frames[0] {
                Frame::PunchMeNow(decoded) => {
                    assert_eq!(original_punch.round, decoded.round);
                    assert_eq!(original_punch.target_sequence, decoded.target_sequence);
                    assert_eq!(original_punch.local_address, decoded.local_address);
                    assert_eq!(original_punch.target_peer_id, decoded.target_peer_id);
                }
                _ => panic!("Expected PunchMeNow frame"),
            }
        }

        // Test RemoveAddress frames
        let remove_test_cases = vec![
            RemoveAddress { sequence: VarInt(123) },
            RemoveAddress { sequence: VarInt(0) },
        ];

        for original_remove in remove_test_cases {
            let mut buf = Vec::new();
            original_remove.encode(&mut buf);

            let decoded_frames = frames(buf);
            assert_eq!(decoded_frames.len(), 1);

            match &decoded_frames[0] {
                Frame::RemoveAddress(decoded) => {
                    assert_eq!(original_remove.sequence, decoded.sequence);
                }
                _ => panic!("Expected RemoveAddress frame"),
            }
        }
    }

    #[test]
    fn nat_traversal_frame_type_constants() {
        // Verify that the frame type constants match the NAT traversal draft specification
        assert_eq!(FrameType::ADD_ADDRESS.0, 0x40);
        assert_eq!(FrameType::PUNCH_ME_NOW.0, 0x41);
        assert_eq!(FrameType::REMOVE_ADDRESS.0, 0x42);
    }
}