ant_quic/
frame.rs

1use std::{
2    fmt::{self, Write},
3    mem,
4    net::SocketAddr,
5    ops::{Range, RangeInclusive},
6};
7
8use bytes::{Buf, BufMut, Bytes};
9use tinyvec::TinyVec;
10
11use crate::{
12    Dir, MAX_CID_SIZE, RESET_TOKEN_SIZE, ResetToken, StreamId, TransportError, TransportErrorCode,
13    VarInt,
14    coding::{self, BufExt, BufMutExt, UnexpectedEnd},
15    range_set::ArrayRangeSet,
16    shared::{ConnectionId, EcnCodepoint},
17};
18
19#[cfg(feature = "arbitrary")]
20use arbitrary::Arbitrary;
21
22/// A QUIC frame type
23#[derive(Copy, Clone, Eq, PartialEq)]
24pub struct FrameType(u64);
25
26impl FrameType {
27    fn stream(self) -> Option<StreamInfo> {
28        if STREAM_TYS.contains(&self.0) {
29            Some(StreamInfo(self.0 as u8))
30        } else {
31            None
32        }
33    }
34    fn datagram(self) -> Option<DatagramInfo> {
35        if DATAGRAM_TYS.contains(&self.0) {
36            Some(DatagramInfo(self.0 as u8))
37        } else {
38            None
39        }
40    }
41}
42
43impl coding::Codec for FrameType {
44    fn decode<B: Buf>(buf: &mut B) -> coding::Result<Self> {
45        Ok(Self(buf.get_var()?))
46    }
47    fn encode<B: BufMut>(&self, buf: &mut B) {
48        buf.write_var(self.0);
49    }
50}
51
/// Implemented by frame types that advertise a fixed encoded-size bound.
pub(crate) trait FrameStruct {
    /// Smallest number of bytes this type of frame is guaranteed to fit within.
    const SIZE_BOUND: usize;
}
56
/// Generates, from a list of `NAME = value,` entries, one associated
/// [`FrameType`] constant per entry plus `Debug` and `Display` impls that
/// print those names.
macro_rules! frame_types {
    {$($name:ident = $val:expr,)*} => {
        // One `pub(crate)` associated constant per entry, e.g. `FrameType::PING`.
        impl FrameType {
            $(pub(crate) const $name: FrameType = FrameType($val);)*
        }

        // Debug: the constant's name, or `Type(xx)` in hex for unnamed values
        // (including the STREAM and DATAGRAM ranges).
        impl fmt::Debug for FrameType {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                match self.0 {
                    $($val => f.write_str(stringify!($name)),)*
                    _ => write!(f, "Type({:02x})", self.0)
                }
            }
        }

        // Display: the constant's name; values inside `STREAM_TYS` /
        // `DATAGRAM_TYS` print as "STREAM" / "DATAGRAM"; anything else prints
        // as `<unknown xx>` in hex.
        impl fmt::Display for FrameType {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                match self.0 {
                    $($val => f.write_str(stringify!($name)),)*
                    x if STREAM_TYS.contains(&x) => f.write_str("STREAM"),
                    x if DATAGRAM_TYS.contains(&x) => f.write_str("DATAGRAM"),
                    _ => write!(f, "<unknown {:02x}>", self.0),
                }
            }
        }
    }
}
84
/// Flag bits taken from the low bits of a STREAM frame type.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
struct StreamInfo(u8);

impl StreamInfo {
    const FIN_BIT: u8 = 0x01;
    const LEN_BIT: u8 = 0x02;
    const OFF_BIT: u8 = 0x04;

    /// Whether the FIN bit (0x01) is set.
    fn fin(self) -> bool {
        self.has(Self::FIN_BIT)
    }

    /// Whether the LEN bit (0x02) is set, i.e. an explicit length field follows.
    fn len(self) -> bool {
        self.has(Self::LEN_BIT)
    }

    /// Whether the OFF bit (0x04) is set, i.e. an offset field is present.
    fn off(self) -> bool {
        self.has(Self::OFF_BIT)
    }

    fn has(self, bit: u8) -> bool {
        (self.0 & bit) != 0
    }
}

/// Flag bits taken from the low bit of a DATAGRAM frame type.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
struct DatagramInfo(u8);

impl DatagramInfo {
    /// Whether the LEN bit (0x01) is set, i.e. an explicit length field follows.
    fn len(self) -> bool {
        (self.0 & 0x01) == 0x01
    }
}
108
109frame_types! {
110    PADDING = 0x00,
111    PING = 0x01,
112    ACK = 0x02,
113    ACK_ECN = 0x03,
114    RESET_STREAM = 0x04,
115    STOP_SENDING = 0x05,
116    CRYPTO = 0x06,
117    NEW_TOKEN = 0x07,
118    // STREAM
119    MAX_DATA = 0x10,
120    MAX_STREAM_DATA = 0x11,
121    MAX_STREAMS_BIDI = 0x12,
122    MAX_STREAMS_UNI = 0x13,
123    DATA_BLOCKED = 0x14,
124    STREAM_DATA_BLOCKED = 0x15,
125    STREAMS_BLOCKED_BIDI = 0x16,
126    STREAMS_BLOCKED_UNI = 0x17,
127    NEW_CONNECTION_ID = 0x18,
128    RETIRE_CONNECTION_ID = 0x19,
129    PATH_CHALLENGE = 0x1a,
130    PATH_RESPONSE = 0x1b,
131    CONNECTION_CLOSE = 0x1c,
132    APPLICATION_CLOSE = 0x1d,
133    HANDSHAKE_DONE = 0x1e,
134    // ACK Frequency
135    ACK_FREQUENCY = 0xaf,
136    IMMEDIATE_ACK = 0x1f,
137    // NAT Traversal Extension
138    ADD_ADDRESS = 0x40,
139    PUNCH_ME_NOW = 0x41,
140    REMOVE_ADDRESS = 0x42,
141    // DATAGRAM
142}
143
/// Frame types 0x08..=0x0f are STREAM frames; their low three bits are flags.
const STREAM_TYS: RangeInclusive<u64> = 0x08..=0x0f;
/// Frame types 0x30..=0x31 are DATAGRAM frames; their low bit is the LEN flag.
const DATAGRAM_TYS: RangeInclusive<u64> = 0x30..=0x31;
146
147#[derive(Debug)]
148pub(crate) enum Frame {
149    Padding,
150    Ping,
151    Ack(Ack),
152    ResetStream(ResetStream),
153    StopSending(StopSending),
154    Crypto(Crypto),
155    NewToken(NewToken),
156    Stream(Stream),
157    MaxData(VarInt),
158    MaxStreamData { id: StreamId, offset: u64 },
159    MaxStreams { dir: Dir, count: u64 },
160    DataBlocked { offset: u64 },
161    StreamDataBlocked { id: StreamId, offset: u64 },
162    StreamsBlocked { dir: Dir, limit: u64 },
163    NewConnectionId(NewConnectionId),
164    RetireConnectionId { sequence: u64 },
165    PathChallenge(u64),
166    PathResponse(u64),
167    Close(Close),
168    Datagram(Datagram),
169    AckFrequency(AckFrequency),
170    ImmediateAck,
171    HandshakeDone,
172    AddAddress(AddAddress),
173    PunchMeNow(PunchMeNow),
174    RemoveAddress(RemoveAddress),
175}
176
177impl Frame {
178    pub(crate) fn ty(&self) -> FrameType {
179        use Frame::*;
180        match *self {
181            Padding => FrameType::PADDING,
182            ResetStream(_) => FrameType::RESET_STREAM,
183            Close(self::Close::Connection(_)) => FrameType::CONNECTION_CLOSE,
184            Close(self::Close::Application(_)) => FrameType::APPLICATION_CLOSE,
185            MaxData(_) => FrameType::MAX_DATA,
186            MaxStreamData { .. } => FrameType::MAX_STREAM_DATA,
187            MaxStreams { dir: Dir::Bi, .. } => FrameType::MAX_STREAMS_BIDI,
188            MaxStreams { dir: Dir::Uni, .. } => FrameType::MAX_STREAMS_UNI,
189            Ping => FrameType::PING,
190            DataBlocked { .. } => FrameType::DATA_BLOCKED,
191            StreamDataBlocked { .. } => FrameType::STREAM_DATA_BLOCKED,
192            StreamsBlocked { dir: Dir::Bi, .. } => FrameType::STREAMS_BLOCKED_BIDI,
193            StreamsBlocked { dir: Dir::Uni, .. } => FrameType::STREAMS_BLOCKED_UNI,
194            StopSending { .. } => FrameType::STOP_SENDING,
195            RetireConnectionId { .. } => FrameType::RETIRE_CONNECTION_ID,
196            Ack(_) => FrameType::ACK,
197            Stream(ref x) => {
198                let mut ty = *STREAM_TYS.start();
199                if x.fin {
200                    ty |= 0x01;
201                }
202                if x.offset != 0 {
203                    ty |= 0x04;
204                }
205                FrameType(ty)
206            }
207            PathChallenge(_) => FrameType::PATH_CHALLENGE,
208            PathResponse(_) => FrameType::PATH_RESPONSE,
209            NewConnectionId { .. } => FrameType::NEW_CONNECTION_ID,
210            Crypto(_) => FrameType::CRYPTO,
211            NewToken(_) => FrameType::NEW_TOKEN,
212            Datagram(_) => FrameType(*DATAGRAM_TYS.start()),
213            AckFrequency(_) => FrameType::ACK_FREQUENCY,
214            ImmediateAck => FrameType::IMMEDIATE_ACK,
215            HandshakeDone => FrameType::HANDSHAKE_DONE,
216            AddAddress(_) => FrameType::ADD_ADDRESS,
217            PunchMeNow(_) => FrameType::PUNCH_ME_NOW,
218            RemoveAddress(_) => FrameType::REMOVE_ADDRESS,
219        }
220    }
221
222    pub(crate) fn is_ack_eliciting(&self) -> bool {
223        !matches!(*self, Self::Ack(_) | Self::Padding | Self::Close(_))
224    }
225}
226
227#[derive(Clone, Debug)]
228pub enum Close {
229    Connection(ConnectionClose),
230    Application(ApplicationClose),
231}
232
233impl Close {
234    pub(crate) fn encode<W: BufMut>(&self, out: &mut W, max_len: usize) {
235        match *self {
236            Self::Connection(ref x) => x.encode(out, max_len),
237            Self::Application(ref x) => x.encode(out, max_len),
238        }
239    }
240
241    pub(crate) fn is_transport_layer(&self) -> bool {
242        matches!(*self, Self::Connection(_))
243    }
244}
245
246impl From<TransportError> for Close {
247    fn from(x: TransportError) -> Self {
248        Self::Connection(x.into())
249    }
250}
251impl From<ConnectionClose> for Close {
252    fn from(x: ConnectionClose) -> Self {
253        Self::Connection(x)
254    }
255}
256impl From<ApplicationClose> for Close {
257    fn from(x: ApplicationClose) -> Self {
258        Self::Application(x)
259    }
260}
261
262/// Reason given by the transport for closing the connection
263#[derive(Debug, Clone, PartialEq, Eq)]
264pub struct ConnectionClose {
265    /// Class of error as encoded in the specification
266    pub error_code: TransportErrorCode,
267    /// Type of frame that caused the close
268    pub frame_type: Option<FrameType>,
269    /// Human-readable reason for the close
270    pub reason: Bytes,
271}
272
273impl fmt::Display for ConnectionClose {
274    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
275        self.error_code.fmt(f)?;
276        if !self.reason.as_ref().is_empty() {
277            f.write_str(": ")?;
278            f.write_str(&String::from_utf8_lossy(&self.reason))?;
279        }
280        Ok(())
281    }
282}
283
284impl From<TransportError> for ConnectionClose {
285    fn from(x: TransportError) -> Self {
286        Self {
287            error_code: x.code,
288            frame_type: x.frame,
289            reason: x.reason.into(),
290        }
291    }
292}
293
294impl FrameStruct for ConnectionClose {
295    const SIZE_BOUND: usize = 1 + 8 + 8 + 8;
296}
297
298impl ConnectionClose {
299    pub(crate) fn encode<W: BufMut>(&self, out: &mut W, max_len: usize) {
300        out.write(FrameType::CONNECTION_CLOSE); // 1 byte
301        out.write(self.error_code); // <= 8 bytes
302        let ty = self.frame_type.map_or(0, |x| x.0);
303        out.write_var(ty); // <= 8 bytes
304        let max_len = max_len
305            - 3
306            - VarInt::from_u64(ty).unwrap().size()
307            - VarInt::from_u64(self.reason.len() as u64).unwrap().size();
308        let actual_len = self.reason.len().min(max_len);
309        out.write_var(actual_len as u64); // <= 8 bytes
310        out.put_slice(&self.reason[0..actual_len]); // whatever's left
311    }
312}
313
314/// Reason given by an application for closing the connection
315#[derive(Debug, Clone, PartialEq, Eq)]
316pub struct ApplicationClose {
317    /// Application-specific reason code
318    pub error_code: VarInt,
319    /// Human-readable reason for the close
320    pub reason: Bytes,
321}
322
323impl fmt::Display for ApplicationClose {
324    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
325        if !self.reason.as_ref().is_empty() {
326            f.write_str(&String::from_utf8_lossy(&self.reason))?;
327            f.write_str(" (code ")?;
328            self.error_code.fmt(f)?;
329            f.write_str(")")?;
330        } else {
331            self.error_code.fmt(f)?;
332        }
333        Ok(())
334    }
335}
336
337impl FrameStruct for ApplicationClose {
338    const SIZE_BOUND: usize = 1 + 8 + 8;
339}
340
341impl ApplicationClose {
342    pub(crate) fn encode<W: BufMut>(&self, out: &mut W, max_len: usize) {
343        out.write(FrameType::APPLICATION_CLOSE); // 1 byte
344        out.write(self.error_code); // <= 8 bytes
345        let max_len = max_len - 3 - VarInt::from_u64(self.reason.len() as u64).unwrap().size();
346        let actual_len = self.reason.len().min(max_len);
347        out.write_var(actual_len as u64); // <= 8 bytes
348        out.put_slice(&self.reason[0..actual_len]); // whatever's left
349    }
350}
351
352#[derive(Clone, Eq, PartialEq)]
353pub struct Ack {
354    pub largest: u64,
355    pub delay: u64,
356    pub additional: Bytes,
357    pub ecn: Option<EcnCounts>,
358}
359
360impl fmt::Debug for Ack {
361    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
362        let mut ranges = "[".to_string();
363        let mut first = true;
364        for range in self.iter() {
365            if !first {
366                ranges.push(',');
367            }
368            write!(ranges, "{range:?}").unwrap();
369            first = false;
370        }
371        ranges.push(']');
372
373        f.debug_struct("Ack")
374            .field("largest", &self.largest)
375            .field("delay", &self.delay)
376            .field("ecn", &self.ecn)
377            .field("ranges", &ranges)
378            .finish()
379    }
380}
381
382impl<'a> IntoIterator for &'a Ack {
383    type Item = RangeInclusive<u64>;
384    type IntoIter = AckIter<'a>;
385
386    fn into_iter(self) -> AckIter<'a> {
387        AckIter::new(self.largest, &self.additional[..])
388    }
389}
390
391impl Ack {
392    pub fn encode<W: BufMut>(
393        delay: u64,
394        ranges: &ArrayRangeSet,
395        ecn: Option<&EcnCounts>,
396        buf: &mut W,
397    ) {
398        let mut rest = ranges.iter().rev();
399        let first = rest.next().unwrap();
400        let largest = first.end - 1;
401        let first_size = first.end - first.start;
402        buf.write(if ecn.is_some() {
403            FrameType::ACK_ECN
404        } else {
405            FrameType::ACK
406        });
407        buf.write_var(largest);
408        buf.write_var(delay);
409        buf.write_var(ranges.len() as u64 - 1);
410        buf.write_var(first_size - 1);
411        let mut prev = first.start;
412        for block in rest {
413            let size = block.end - block.start;
414            buf.write_var(prev - block.end - 1);
415            buf.write_var(size - 1);
416            prev = block.start;
417        }
418        if let Some(x) = ecn {
419            x.encode(buf)
420        }
421    }
422
423    pub fn iter(&self) -> AckIter<'_> {
424        self.into_iter()
425    }
426}
427
428#[derive(Debug, Copy, Clone, Eq, PartialEq)]
429pub struct EcnCounts {
430    pub ect0: u64,
431    pub ect1: u64,
432    pub ce: u64,
433}
434
435impl std::ops::AddAssign<EcnCodepoint> for EcnCounts {
436    fn add_assign(&mut self, rhs: EcnCodepoint) {
437        match rhs {
438            EcnCodepoint::Ect0 => {
439                self.ect0 += 1;
440            }
441            EcnCodepoint::Ect1 => {
442                self.ect1 += 1;
443            }
444            EcnCodepoint::Ce => {
445                self.ce += 1;
446            }
447        }
448    }
449}
450
451impl EcnCounts {
452    pub const ZERO: Self = Self {
453        ect0: 0,
454        ect1: 0,
455        ce: 0,
456    };
457
458    pub fn encode<W: BufMut>(&self, out: &mut W) {
459        out.write_var(self.ect0);
460        out.write_var(self.ect1);
461        out.write_var(self.ce);
462    }
463}
464
465#[derive(Debug, Clone)]
466pub(crate) struct Stream {
467    pub(crate) id: StreamId,
468    pub(crate) offset: u64,
469    pub(crate) fin: bool,
470    pub(crate) data: Bytes,
471}
472
473impl FrameStruct for Stream {
474    const SIZE_BOUND: usize = 1 + 8 + 8 + 8;
475}
476
477/// Metadata from a stream frame
478#[derive(Debug, Clone)]
479pub(crate) struct StreamMeta {
480    pub(crate) id: StreamId,
481    pub(crate) offsets: Range<u64>,
482    pub(crate) fin: bool,
483}
484
485// This manual implementation exists because `Default` is not implemented for `StreamId`
486impl Default for StreamMeta {
487    fn default() -> Self {
488        Self {
489            id: StreamId(0),
490            offsets: 0..0,
491            fin: false,
492        }
493    }
494}
495
496impl StreamMeta {
497    pub(crate) fn encode<W: BufMut>(&self, length: bool, out: &mut W) {
498        let mut ty = *STREAM_TYS.start();
499        if self.offsets.start != 0 {
500            ty |= 0x04;
501        }
502        if length {
503            ty |= 0x02;
504        }
505        if self.fin {
506            ty |= 0x01;
507        }
508        out.write_var(ty); // 1 byte
509        out.write(self.id); // <=8 bytes
510        if self.offsets.start != 0 {
511            out.write_var(self.offsets.start); // <=8 bytes
512        }
513        if length {
514            out.write_var(self.offsets.end - self.offsets.start); // <=8 bytes
515        }
516    }
517}
518
519/// A vector of [`StreamMeta`] with optimization for the single element case
520pub(crate) type StreamMetaVec = TinyVec<[StreamMeta; 1]>;
521
522#[derive(Debug, Clone)]
523pub(crate) struct Crypto {
524    pub(crate) offset: u64,
525    pub(crate) data: Bytes,
526}
527
528impl Crypto {
529    pub(crate) const SIZE_BOUND: usize = 17;
530
531    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
532        out.write(FrameType::CRYPTO);
533        out.write_var(self.offset);
534        out.write_var(self.data.len() as u64);
535        out.put_slice(&self.data);
536    }
537}
538
539#[derive(Debug, Clone)]
540pub(crate) struct NewToken {
541    pub(crate) token: Bytes,
542}
543
544impl NewToken {
545    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
546        out.write(FrameType::NEW_TOKEN);
547        out.write_var(self.token.len() as u64);
548        out.put_slice(&self.token);
549    }
550
551    pub(crate) fn size(&self) -> usize {
552        1 + VarInt::from_u64(self.token.len() as u64).unwrap().size() + self.token.len()
553    }
554}
555
556pub(crate) struct Iter {
557    bytes: Bytes,
558    last_ty: Option<FrameType>,
559}
560
561impl Iter {
562    pub(crate) fn new(payload: Bytes) -> Result<Self, TransportError> {
563        if payload.is_empty() {
564            // "An endpoint MUST treat receipt of a packet containing no frames as a
565            // connection error of type PROTOCOL_VIOLATION."
566            // https://www.rfc-editor.org/rfc/rfc9000.html#name-frames-and-frame-types
567            return Err(TransportError::PROTOCOL_VIOLATION(
568                "packet payload is empty",
569            ));
570        }
571
572        Ok(Self {
573            bytes: payload,
574            last_ty: None,
575        })
576    }
577
578    fn take_len(&mut self) -> Result<Bytes, UnexpectedEnd> {
579        let len = self.bytes.get_var()?;
580        if len > self.bytes.remaining() as u64 {
581            return Err(UnexpectedEnd);
582        }
583        Ok(self.bytes.split_to(len as usize))
584    }
585
586    fn try_next(&mut self) -> Result<Frame, IterErr> {
587        let ty = self.bytes.get::<FrameType>()?;
588        self.last_ty = Some(ty);
589        Ok(match ty {
590            FrameType::PADDING => Frame::Padding,
591            FrameType::RESET_STREAM => Frame::ResetStream(ResetStream {
592                id: self.bytes.get()?,
593                error_code: self.bytes.get()?,
594                final_offset: self.bytes.get()?,
595            }),
596            FrameType::CONNECTION_CLOSE => Frame::Close(Close::Connection(ConnectionClose {
597                error_code: self.bytes.get()?,
598                frame_type: {
599                    let x = self.bytes.get_var()?;
600                    if x == 0 { None } else { Some(FrameType(x)) }
601                },
602                reason: self.take_len()?,
603            })),
604            FrameType::APPLICATION_CLOSE => Frame::Close(Close::Application(ApplicationClose {
605                error_code: self.bytes.get()?,
606                reason: self.take_len()?,
607            })),
608            FrameType::MAX_DATA => Frame::MaxData(self.bytes.get()?),
609            FrameType::MAX_STREAM_DATA => Frame::MaxStreamData {
610                id: self.bytes.get()?,
611                offset: self.bytes.get_var()?,
612            },
613            FrameType::MAX_STREAMS_BIDI => Frame::MaxStreams {
614                dir: Dir::Bi,
615                count: self.bytes.get_var()?,
616            },
617            FrameType::MAX_STREAMS_UNI => Frame::MaxStreams {
618                dir: Dir::Uni,
619                count: self.bytes.get_var()?,
620            },
621            FrameType::PING => Frame::Ping,
622            FrameType::DATA_BLOCKED => Frame::DataBlocked {
623                offset: self.bytes.get_var()?,
624            },
625            FrameType::STREAM_DATA_BLOCKED => Frame::StreamDataBlocked {
626                id: self.bytes.get()?,
627                offset: self.bytes.get_var()?,
628            },
629            FrameType::STREAMS_BLOCKED_BIDI => Frame::StreamsBlocked {
630                dir: Dir::Bi,
631                limit: self.bytes.get_var()?,
632            },
633            FrameType::STREAMS_BLOCKED_UNI => Frame::StreamsBlocked {
634                dir: Dir::Uni,
635                limit: self.bytes.get_var()?,
636            },
637            FrameType::STOP_SENDING => Frame::StopSending(StopSending {
638                id: self.bytes.get()?,
639                error_code: self.bytes.get()?,
640            }),
641            FrameType::RETIRE_CONNECTION_ID => Frame::RetireConnectionId {
642                sequence: self.bytes.get_var()?,
643            },
644            FrameType::ACK | FrameType::ACK_ECN => {
645                let largest = self.bytes.get_var()?;
646                let delay = self.bytes.get_var()?;
647                let extra_blocks = self.bytes.get_var()? as usize;
648                let n = scan_ack_blocks(&self.bytes, largest, extra_blocks)?;
649                Frame::Ack(Ack {
650                    delay,
651                    largest,
652                    additional: self.bytes.split_to(n),
653                    ecn: if ty != FrameType::ACK_ECN {
654                        None
655                    } else {
656                        Some(EcnCounts {
657                            ect0: self.bytes.get_var()?,
658                            ect1: self.bytes.get_var()?,
659                            ce: self.bytes.get_var()?,
660                        })
661                    },
662                })
663            }
664            FrameType::PATH_CHALLENGE => Frame::PathChallenge(self.bytes.get()?),
665            FrameType::PATH_RESPONSE => Frame::PathResponse(self.bytes.get()?),
666            FrameType::NEW_CONNECTION_ID => {
667                let sequence = self.bytes.get_var()?;
668                let retire_prior_to = self.bytes.get_var()?;
669                if retire_prior_to > sequence {
670                    return Err(IterErr::Malformed);
671                }
672                let length = self.bytes.get::<u8>()? as usize;
673                if length > MAX_CID_SIZE || length == 0 {
674                    return Err(IterErr::Malformed);
675                }
676                if length > self.bytes.remaining() {
677                    return Err(IterErr::UnexpectedEnd);
678                }
679                let mut stage = [0; MAX_CID_SIZE];
680                self.bytes.copy_to_slice(&mut stage[0..length]);
681                let id = ConnectionId::new(&stage[..length]);
682                if self.bytes.remaining() < 16 {
683                    return Err(IterErr::UnexpectedEnd);
684                }
685                let mut reset_token = [0; RESET_TOKEN_SIZE];
686                self.bytes.copy_to_slice(&mut reset_token);
687                Frame::NewConnectionId(NewConnectionId {
688                    sequence,
689                    retire_prior_to,
690                    id,
691                    reset_token: reset_token.into(),
692                })
693            }
694            FrameType::CRYPTO => Frame::Crypto(Crypto {
695                offset: self.bytes.get_var()?,
696                data: self.take_len()?,
697            }),
698            FrameType::NEW_TOKEN => Frame::NewToken(NewToken {
699                token: self.take_len()?,
700            }),
701            FrameType::HANDSHAKE_DONE => Frame::HandshakeDone,
702            FrameType::ACK_FREQUENCY => Frame::AckFrequency(AckFrequency {
703                sequence: self.bytes.get()?,
704                ack_eliciting_threshold: self.bytes.get()?,
705                request_max_ack_delay: self.bytes.get()?,
706                reordering_threshold: self.bytes.get()?,
707            }),
708            FrameType::IMMEDIATE_ACK => Frame::ImmediateAck,
709            FrameType::ADD_ADDRESS => Frame::AddAddress(AddAddress::decode(&mut self.bytes)?),
710            FrameType::PUNCH_ME_NOW => Frame::PunchMeNow(PunchMeNow::decode(&mut self.bytes)?),
711            FrameType::REMOVE_ADDRESS => Frame::RemoveAddress(RemoveAddress::decode(&mut self.bytes)?),
712            _ => {
713                if let Some(s) = ty.stream() {
714                    Frame::Stream(Stream {
715                        id: self.bytes.get()?,
716                        offset: if s.off() { self.bytes.get_var()? } else { 0 },
717                        fin: s.fin(),
718                        data: if s.len() {
719                            self.take_len()?
720                        } else {
721                            self.take_remaining()
722                        },
723                    })
724                } else if let Some(d) = ty.datagram() {
725                    Frame::Datagram(Datagram {
726                        data: if d.len() {
727                            self.take_len()?
728                        } else {
729                            self.take_remaining()
730                        },
731                    })
732                } else {
733                    return Err(IterErr::InvalidFrameId);
734                }
735            }
736        })
737    }
738
739    fn take_remaining(&mut self) -> Bytes {
740        mem::take(&mut self.bytes)
741    }
742}
743
744impl Iterator for Iter {
745    type Item = Result<Frame, InvalidFrame>;
746    fn next(&mut self) -> Option<Self::Item> {
747        if !self.bytes.has_remaining() {
748            return None;
749        }
750        match self.try_next() {
751            Ok(x) => Some(Ok(x)),
752            Err(e) => {
753                // Corrupt frame, skip it and everything that follows
754                self.bytes.clear();
755                Some(Err(InvalidFrame {
756                    ty: self.last_ty,
757                    reason: e.reason(),
758                }))
759            }
760        }
761    }
762}
763
764#[derive(Debug)]
765pub(crate) struct InvalidFrame {
766    pub(crate) ty: Option<FrameType>,
767    pub(crate) reason: &'static str,
768}
769
770impl From<InvalidFrame> for TransportError {
771    fn from(err: InvalidFrame) -> Self {
772        let mut te = Self::FRAME_ENCODING_ERROR(err.reason);
773        te.frame = err.ty;
774        te
775    }
776}
777
778/// Validate exactly `n` ACK ranges in `buf` and return the number of bytes they cover
779fn scan_ack_blocks(mut buf: &[u8], largest: u64, n: usize) -> Result<usize, IterErr> {
780    let total_len = buf.remaining();
781    let first_block = buf.get_var()?;
782    let mut smallest = largest.checked_sub(first_block).ok_or(IterErr::Malformed)?;
783    for _ in 0..n {
784        let gap = buf.get_var()?;
785        smallest = smallest.checked_sub(gap + 2).ok_or(IterErr::Malformed)?;
786        let block = buf.get_var()?;
787        smallest = smallest.checked_sub(block).ok_or(IterErr::Malformed)?;
788    }
789    Ok(total_len - buf.remaining())
790}
791
/// Why a single frame failed to decode.
enum IterErr {
    /// The payload ended before the frame was complete.
    UnexpectedEnd,
    /// The frame type is not recognized.
    InvalidFrameId,
    /// The frame's fields are inconsistent or out of range.
    Malformed,
}

impl IterErr {
    /// Short, static, human-readable description of the error.
    fn reason(&self) -> &'static str {
        match self {
            Self::UnexpectedEnd => "unexpected end",
            Self::InvalidFrameId => "invalid frame ID",
            Self::Malformed => "malformed",
        }
    }
}
808
809impl From<UnexpectedEnd> for IterErr {
810    fn from(_: UnexpectedEnd) -> Self {
811        Self::UnexpectedEnd
812    }
813}
814
815#[derive(Debug, Clone)]
816pub struct AckIter<'a> {
817    largest: u64,
818    data: &'a [u8],
819}
820
821impl<'a> AckIter<'a> {
822    fn new(largest: u64, data: &'a [u8]) -> Self {
823        Self { largest, data }
824    }
825}
826
827impl Iterator for AckIter<'_> {
828    type Item = RangeInclusive<u64>;
829    fn next(&mut self) -> Option<RangeInclusive<u64>> {
830        if !self.data.has_remaining() {
831            return None;
832        }
833        let block = self.data.get_var().unwrap();
834        let largest = self.largest;
835        if let Ok(gap) = self.data.get_var() {
836            self.largest -= block + gap + 2;
837        }
838        Some(largest - block..=largest)
839    }
840}
841
842#[allow(unreachable_pub)] // fuzzing only
843#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
844#[derive(Debug, Copy, Clone)]
845pub struct ResetStream {
846    pub(crate) id: StreamId,
847    pub(crate) error_code: VarInt,
848    pub(crate) final_offset: VarInt,
849}
850
851impl FrameStruct for ResetStream {
852    const SIZE_BOUND: usize = 1 + 8 + 8 + 8;
853}
854
855impl ResetStream {
856    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
857        out.write(FrameType::RESET_STREAM); // 1 byte
858        out.write(self.id); // <= 8 bytes
859        out.write(self.error_code); // <= 8 bytes
860        out.write(self.final_offset); // <= 8 bytes
861    }
862}
863
864#[derive(Debug, Copy, Clone)]
865pub(crate) struct StopSending {
866    pub(crate) id: StreamId,
867    pub(crate) error_code: VarInt,
868}
869
870impl FrameStruct for StopSending {
871    const SIZE_BOUND: usize = 1 + 8 + 8;
872}
873
874impl StopSending {
875    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
876        out.write(FrameType::STOP_SENDING); // 1 byte
877        out.write(self.id); // <= 8 bytes
878        out.write(self.error_code) // <= 8 bytes
879    }
880}
881
882#[derive(Debug, Copy, Clone)]
883pub(crate) struct NewConnectionId {
884    pub(crate) sequence: u64,
885    pub(crate) retire_prior_to: u64,
886    pub(crate) id: ConnectionId,
887    pub(crate) reset_token: ResetToken,
888}
889
890impl NewConnectionId {
891    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
892        out.write(FrameType::NEW_CONNECTION_ID);
893        out.write_var(self.sequence);
894        out.write_var(self.retire_prior_to);
895        out.write(self.id.len() as u8);
896        out.put_slice(&self.id);
897        out.put_slice(&self.reset_token);
898    }
899}
900
/// Smallest number of bytes a RETIRE_CONNECTION_ID frame is guaranteed to fit
/// within: a 1-byte type plus a sequence number of at most 8 bytes.
pub(crate) const RETIRE_CONNECTION_ID_SIZE_BOUND: usize = 1 + 8;
903
904/// An unreliable datagram
905#[derive(Debug, Clone)]
906pub struct Datagram {
907    /// Payload
908    pub data: Bytes,
909}
910
911impl FrameStruct for Datagram {
912    const SIZE_BOUND: usize = 1 + 8;
913}
914
915impl Datagram {
916    pub(crate) fn encode(&self, length: bool, out: &mut Vec<u8>) {
917        out.write(FrameType(*DATAGRAM_TYS.start() | u64::from(length))); // 1 byte
918        if length {
919            // Safe to unwrap because we check length sanity before queueing datagrams
920            out.write(VarInt::from_u64(self.data.len() as u64).unwrap()); // <= 8 bytes
921        }
922        out.extend_from_slice(&self.data);
923    }
924
925    pub(crate) fn size(&self, length: bool) -> usize {
926        1 + if length {
927            VarInt::from_u64(self.data.len() as u64).unwrap().size()
928        } else {
929            0
930        } + self.data.len()
931    }
932}
933
934#[derive(Debug, Copy, Clone, PartialEq, Eq)]
935pub(crate) struct AckFrequency {
936    pub(crate) sequence: VarInt,
937    pub(crate) ack_eliciting_threshold: VarInt,
938    pub(crate) request_max_ack_delay: VarInt,
939    pub(crate) reordering_threshold: VarInt,
940}
941
942impl AckFrequency {
943    pub(crate) fn encode<W: BufMut>(&self, buf: &mut W) {
944        buf.write(FrameType::ACK_FREQUENCY);
945        buf.write(self.sequence);
946        buf.write(self.ack_eliciting_threshold);
947        buf.write(self.request_max_ack_delay);
948        buf.write(self.reordering_threshold);
949    }
950}
951
/// NAT traversal frame for advertising candidate addresses
///
/// Wire order (see `AddAddress::encode`) differs from field order: frame type, `sequence`
/// and `priority` as varints, then a one-byte address-family tag (`4` or `6`), the IP
/// octets, the big-endian port, and for IPv6 additionally `flowinfo` and `scope_id` as
/// big-endian `u32`s.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct AddAddress {
    /// Sequence number for this address advertisement
    pub(crate) sequence: VarInt,
    /// Socket address being advertised
    pub(crate) address: SocketAddr,
    /// Priority of this address candidate
    pub(crate) priority: VarInt,
}
962
963impl AddAddress {
964    pub(crate) fn encode<W: BufMut>(&self, buf: &mut W) {
965        buf.write(FrameType::ADD_ADDRESS);
966        buf.write(self.sequence);
967        buf.write(self.priority);
968        
969        match self.address {
970            SocketAddr::V4(addr) => {
971                buf.put_u8(4); // IPv4 indicator
972                buf.put_slice(&addr.ip().octets());
973                buf.put_u16(addr.port());
974            }
975            SocketAddr::V6(addr) => {
976                buf.put_u8(6); // IPv6 indicator
977                buf.put_slice(&addr.ip().octets());
978                buf.put_u16(addr.port());
979                buf.put_u32(addr.flowinfo());
980                buf.put_u32(addr.scope_id());
981            }
982        }
983    }
984    
985    pub(crate) fn decode<R: Buf>(r: &mut R) -> Result<Self, UnexpectedEnd> {
986        let sequence = r.get()?;
987        let priority = r.get()?;
988        let ip_version = r.get::<u8>()?;
989        
990        let address = match ip_version {
991            4 => {
992                let mut octets = [0u8; 4];
993                r.copy_to_slice(&mut octets);
994                let port = r.get::<u16>()?;
995                SocketAddr::V4(std::net::SocketAddrV4::new(
996                    std::net::Ipv4Addr::from(octets),
997                    port,
998                ))
999            }
1000            6 => {
1001                let mut octets = [0u8; 16];
1002                r.copy_to_slice(&mut octets);
1003                let port = r.get::<u16>()?;
1004                let flowinfo = r.get::<u32>()?;
1005                let scope_id = r.get::<u32>()?;
1006                SocketAddr::V6(std::net::SocketAddrV6::new(
1007                    std::net::Ipv6Addr::from(octets),
1008                    port,
1009                    flowinfo,
1010                    scope_id,
1011                ))
1012            }
1013            _ => return Err(UnexpectedEnd),
1014        };
1015        
1016        Ok(Self {
1017            sequence,
1018            address,
1019            priority,
1020        })
1021    }
1022}
1023
1024impl FrameStruct for AddAddress {
1025    const SIZE_BOUND: usize = 1 + 8 + 8 + 1 + 16 + 2 + 4 + 4; // Worst case IPv6
1026}
1027
/// NAT traversal frame for requesting simultaneous hole punching
///
/// Wire order (see `PunchMeNow::encode`): frame type, `round` and `target_sequence` as
/// varints, then a one-byte address-family tag (`4` or `6`), the IP octets, the big-endian
/// port, and for IPv6 additionally `flowinfo` and `scope_id` as big-endian `u32`s.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct PunchMeNow {
    /// Round number for coordination
    pub(crate) round: VarInt,
    /// Sequence number of the address to punch to (from AddAddress)
    pub(crate) target_sequence: VarInt,
    /// Local address for this punch attempt
    pub(crate) local_address: SocketAddr,
}
1038
1039impl PunchMeNow {
1040    pub(crate) fn encode<W: BufMut>(&self, buf: &mut W) {
1041        buf.write(FrameType::PUNCH_ME_NOW);
1042        buf.write(self.round);
1043        buf.write(self.target_sequence);
1044        
1045        match self.local_address {
1046            SocketAddr::V4(addr) => {
1047                buf.put_u8(4); // IPv4 indicator
1048                buf.put_slice(&addr.ip().octets());
1049                buf.put_u16(addr.port());
1050            }
1051            SocketAddr::V6(addr) => {
1052                buf.put_u8(6); // IPv6 indicator
1053                buf.put_slice(&addr.ip().octets());
1054                buf.put_u16(addr.port());
1055                buf.put_u32(addr.flowinfo());
1056                buf.put_u32(addr.scope_id());
1057            }
1058        }
1059    }
1060    
1061    pub(crate) fn decode<R: Buf>(r: &mut R) -> Result<Self, UnexpectedEnd> {
1062        let round = r.get()?;
1063        let target_sequence = r.get()?;
1064        let ip_version = r.get::<u8>()?;
1065        
1066        let local_address = match ip_version {
1067            4 => {
1068                let mut octets = [0u8; 4];
1069                r.copy_to_slice(&mut octets);
1070                let port = r.get::<u16>()?;
1071                SocketAddr::V4(std::net::SocketAddrV4::new(
1072                    std::net::Ipv4Addr::from(octets),
1073                    port,
1074                ))
1075            }
1076            6 => {
1077                let mut octets = [0u8; 16];
1078                r.copy_to_slice(&mut octets);
1079                let port = r.get::<u16>()?;
1080                let flowinfo = r.get::<u32>()?;
1081                let scope_id = r.get::<u32>()?;
1082                SocketAddr::V6(std::net::SocketAddrV6::new(
1083                    std::net::Ipv6Addr::from(octets),
1084                    port,
1085                    flowinfo,
1086                    scope_id,
1087                ))
1088            }
1089            _ => return Err(UnexpectedEnd),
1090        };
1091        
1092        Ok(Self {
1093            round,
1094            target_sequence,
1095            local_address,
1096        })
1097    }
1098}
1099
1100impl FrameStruct for PunchMeNow {
1101    const SIZE_BOUND: usize = 1 + 8 + 8 + 1 + 16 + 2 + 4 + 4; // Worst case IPv6
1102}
1103
/// NAT traversal frame for removing stale addresses
///
/// On the wire (see `RemoveAddress::encode`): the frame type followed by `sequence` as a
/// varint.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) struct RemoveAddress {
    /// Sequence number of the address to remove (from AddAddress)
    pub(crate) sequence: VarInt,
}
1110
1111impl RemoveAddress {
1112    pub(crate) fn encode<W: BufMut>(&self, buf: &mut W) {
1113        buf.write(FrameType::REMOVE_ADDRESS);
1114        buf.write(self.sequence);
1115    }
1116    
1117    pub(crate) fn decode<R: Buf>(r: &mut R) -> Result<Self, UnexpectedEnd> {
1118        let sequence = r.get()?;
1119        Ok(Self { sequence })
1120    }
1121}
1122
1123impl FrameStruct for RemoveAddress {
1124    const SIZE_BOUND: usize = 1 + 8; // frame type + sequence
1125}
1126
#[cfg(test)]
mod test {
    use super::*;
    use crate::coding::Codec;
    use assert_matches::assert_matches;

    fn frames(buf: Vec<u8>) -> Vec<Frame> {
        Iter::new(Bytes::from(buf))
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap()
    }

    #[test]
    fn ack_coding() {
        const PACKETS: &[u64] = &[1, 2, 3, 5, 10, 11, 14];
        let mut ranges = ArrayRangeSet::new();
        for &packet in PACKETS {
            ranges.insert(packet..packet + 1);
        }
        let mut buf = Vec::new();
        const ECN: EcnCounts = EcnCounts {
            ect0: 42,
            ect1: 24,
            ce: 12,
        };
        Ack::encode(42, &ranges, Some(&ECN), &mut buf);
        let frames = frames(buf);
        assert_eq!(frames.len(), 1);
        match frames[0] {
            Frame::Ack(ref ack) => {
                let mut packets = ack.iter().flatten().collect::<Vec<_>>();
                packets.sort_unstable();
                assert_eq!(&packets[..], PACKETS);
                assert_eq!(ack.ecn, Some(ECN));
            }
            ref x => panic!("incorrect frame {x:?}"),
        }
    }

    #[test]
    fn ack_frequency_coding() {
        let mut buf = Vec::new();
        let original = AckFrequency {
            sequence: VarInt(42),
            ack_eliciting_threshold: VarInt(20),
            request_max_ack_delay: VarInt(50_000),
            reordering_threshold: VarInt(1),
        };
        original.encode(&mut buf);
        let frames = frames(buf);
        assert_eq!(frames.len(), 1);
        match &frames[0] {
            Frame::AckFrequency(decoded) => assert_eq!(decoded, &original),
            x => panic!("incorrect frame {x:?}"),
        }
    }

    #[test]
    fn immediate_ack_coding() {
        let mut buf = Vec::new();
        FrameType::IMMEDIATE_ACK.encode(&mut buf);
        let frames = frames(buf);
        assert_eq!(frames.len(), 1);
        assert_matches!(&frames[0], Frame::ImmediateAck);
    }

    /// Decodes the leading frame type of `buf`, asserts it equals `expected`, and returns
    /// the remaining frame body.
    fn strip_frame_type(buf: &[u8], expected: FrameType) -> &[u8] {
        let mut rest = buf;
        assert_eq!(FrameType::decode(&mut rest).unwrap(), expected);
        rest
    }

    /// One IPv4 and one IPv6 address (with non-zero flowinfo and scope_id), so both
    /// address-family branches of the NAT traversal codecs are exercised.
    fn nat_addresses() -> [SocketAddr; 2] {
        [
            "192.0.2.1:4433".parse().unwrap(),
            SocketAddr::V6(std::net::SocketAddrV6::new(
                std::net::Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 1),
                9000,
                5,
                2,
            )),
        ]
    }

    #[test]
    fn add_address_coding() {
        for address in nat_addresses().iter().copied() {
            let original = AddAddress {
                sequence: VarInt(7),
                address,
                // Large enough to need the 8-byte varint encoding.
                priority: VarInt(1 << 30),
            };
            let mut buf = Vec::new();
            original.encode(&mut buf);
            let mut body = strip_frame_type(&buf, FrameType::ADD_ADDRESS);
            let decoded = AddAddress::decode(&mut body).unwrap();
            assert_eq!(decoded, original);
            assert!(body.is_empty(), "decode left trailing bytes");
        }
    }

    #[test]
    fn punch_me_now_coding() {
        for local_address in nat_addresses().iter().copied() {
            let original = PunchMeNow {
                round: VarInt(3),
                target_sequence: VarInt(1 << 14),
                local_address,
            };
            let mut buf = Vec::new();
            original.encode(&mut buf);
            let mut body = strip_frame_type(&buf, FrameType::PUNCH_ME_NOW);
            let decoded = PunchMeNow::decode(&mut body).unwrap();
            assert_eq!(decoded, original);
            assert!(body.is_empty(), "decode left trailing bytes");
        }
    }

    #[test]
    fn remove_address_coding() {
        let original = RemoveAddress {
            sequence: VarInt(1 << 14),
        };
        let mut buf = Vec::new();
        original.encode(&mut buf);
        let mut body = strip_frame_type(&buf, FrameType::REMOVE_ADDRESS);
        let decoded = RemoveAddress::decode(&mut body).unwrap();
        assert_eq!(decoded, original);
        assert!(body.is_empty(), "decode left trailing bytes");
    }

    #[test]
    fn nat_address_rejects_unknown_family() {
        // Two single-byte varints (1 and 2), then an address-family tag that is neither 4
        // nor 6, followed by filler bytes.
        let body: &[u8] = &[0x01, 0x02, 0x05, 0, 0, 0, 0, 0, 0];
        assert!(AddAddress::decode(&mut &body[..]).is_err());
        assert!(PunchMeNow::decode(&mut &body[..]).is_err());
    }
}