iroh_quinn_proto/
frame.rs

1use std::{
2    fmt::{self, Write},
3    mem,
4    net::{IpAddr, SocketAddr},
5    ops::{Range, RangeInclusive},
6};
7
8use bytes::{Buf, BufMut, Bytes};
9use tinyvec::TinyVec;
10
11use crate::{
12    coding::{self, BufExt, BufMutExt, UnexpectedEnd},
13    range_set::ArrayRangeSet,
14    shared::{ConnectionId, EcnCodepoint},
15    Dir, ResetToken, StreamId, TransportError, TransportErrorCode, VarInt, MAX_CID_SIZE,
16    RESET_TOKEN_SIZE,
17};
18
19#[cfg(feature = "arbitrary")]
20use arbitrary::Arbitrary;
21
/// A QUIC frame type
///
/// Thin wrapper around the raw `u64` type value. Its `coding::Codec` impl reads and
/// writes that value as a variable-length integer.
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct FrameType(u64);
25
26impl FrameType {
27    fn stream(self) -> Option<StreamInfo> {
28        if STREAM_TYS.contains(&self.0) {
29            Some(StreamInfo(self.0 as u8))
30        } else {
31            None
32        }
33    }
34    fn datagram(self) -> Option<DatagramInfo> {
35        if DATAGRAM_TYS.contains(&self.0) {
36            Some(DatagramInfo(self.0 as u8))
37        } else {
38            None
39        }
40    }
41}
42
43impl coding::Codec for FrameType {
44    fn decode<B: Buf>(buf: &mut B) -> coding::Result<Self> {
45        Ok(Self(buf.get_var()?))
46    }
47    fn encode<B: BufMut>(&self, buf: &mut B) {
48        buf.write_var(self.0);
49    }
50}
51
/// Associates a frame struct with a compile-time size bound.
pub(crate) trait FrameStruct {
    /// Smallest number of bytes this type of frame is guaranteed to fit within.
    const SIZE_BOUND: usize;
}
56
/// Declares the frame types that have a single fixed numeric value.
///
/// For every `NAME = value,` entry this generates:
/// - an associated constant `FrameType::NAME`,
/// - a `Debug` match arm that prints `NAME`,
/// - a `Display` match arm that prints `NAME`.
///
/// Values that are not listed fall through to the catch-all arms: `Debug` prints
/// `Type(xx)` in hex, while `Display` first recognises the `STREAM_TYS` and
/// `DATAGRAM_TYS` ranges and only then falls back to `<unknown xx>`.
macro_rules! frame_types {
    {$($name:ident = $val:expr,)*} => {
        impl FrameType {
            $(pub(crate) const $name: FrameType = FrameType($val);)*
        }

        impl fmt::Debug for FrameType {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                match self.0 {
                    $($val => f.write_str(stringify!($name)),)*
                    _ => write!(f, "Type({:02x})", self.0)
                }
            }
        }

        impl fmt::Display for FrameType {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                match self.0 {
                    $($val => f.write_str(stringify!($name)),)*
                    // Ranged types have no single constant, so they are matched by range.
                    x if STREAM_TYS.contains(&x) => f.write_str("STREAM"),
                    x if DATAGRAM_TYS.contains(&x) => f.write_str("DATAGRAM"),
                    _ => write!(f, "<unknown {:02x}>", self.0),
                }
            }
        }
    }
}
84
/// Low byte of a STREAM frame type, whose three lowest bits are flags.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
struct StreamInfo(u8);

impl StreamInfo {
    /// Flag bit reported by [`StreamInfo::fin`].
    const FIN_BIT: u8 = 0x01;
    /// Flag bit reported by [`StreamInfo::len`].
    const LEN_BIT: u8 = 0x02;
    /// Flag bit reported by [`StreamInfo::off`].
    const OFF_BIT: u8 = 0x04;

    /// Whether the FIN flag (`0x01`) is set.
    fn fin(self) -> bool {
        (self.0 & Self::FIN_BIT) == Self::FIN_BIT
    }

    /// Whether the LEN flag (`0x02`) is set.
    fn len(self) -> bool {
        (self.0 & Self::LEN_BIT) == Self::LEN_BIT
    }

    /// Whether the OFF flag (`0x04`) is set.
    fn off(self) -> bool {
        (self.0 & Self::OFF_BIT) == Self::OFF_BIT
    }
}
99
/// Low byte of a DATAGRAM frame type, whose lowest bit is a flag.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
struct DatagramInfo(u8);

impl DatagramInfo {
    /// Flag bit reported by [`DatagramInfo::len`].
    const LEN_BIT: u8 = 0x01;

    /// Whether the LEN flag (`0x01`) is set.
    fn len(self) -> bool {
        (self.0 & Self::LEN_BIT) == Self::LEN_BIT
    }
}
108
// Fixed-value frame types. Ranged types (STREAM, DATAGRAM) are not listed here; they
// are recognised through `STREAM_TYS` and `DATAGRAM_TYS` instead.
frame_types! {
    PADDING = 0x00,
    PING = 0x01,
    ACK = 0x02,
    ACK_ECN = 0x03,
    RESET_STREAM = 0x04,
    STOP_SENDING = 0x05,
    CRYPTO = 0x06,
    NEW_TOKEN = 0x07,
    // STREAM occupies 0x08..=0x0f (see `STREAM_TYS`)
    MAX_DATA = 0x10,
    MAX_STREAM_DATA = 0x11,
    MAX_STREAMS_BIDI = 0x12,
    MAX_STREAMS_UNI = 0x13,
    DATA_BLOCKED = 0x14,
    STREAM_DATA_BLOCKED = 0x15,
    STREAMS_BLOCKED_BIDI = 0x16,
    STREAMS_BLOCKED_UNI = 0x17,
    NEW_CONNECTION_ID = 0x18,
    RETIRE_CONNECTION_ID = 0x19,
    PATH_CHALLENGE = 0x1a,
    PATH_RESPONSE = 0x1b,
    CONNECTION_CLOSE = 0x1c,
    APPLICATION_CLOSE = 0x1d,
    HANDSHAKE_DONE = 0x1e,
    // ACK Frequency
    ACK_FREQUENCY = 0xaf,
    IMMEDIATE_ACK = 0x1f,
    // DATAGRAM occupies 0x30..=0x31 (see `DATAGRAM_TYS`)
    // ADDRESS DISCOVERY REPORT
    OBSERVED_IPV4_ADDR = 0x9f81a6,
    OBSERVED_IPV6_ADDR = 0x9f81a7,
}
142
/// Inclusive range of type values that denote STREAM frames.
const STREAM_TYS: RangeInclusive<u64> = 0x08..=0x0f;
/// Inclusive range of type values that denote DATAGRAM frames.
const DATAGRAM_TYS: RangeInclusive<u64> = 0x30..=0x31;
145
/// A single decoded frame, as produced by `Iter`.
#[derive(Debug)]
pub(crate) enum Frame {
    Padding,
    Ping,
    /// Decoded from either an `ACK` or an `ACK_ECN` frame
    Ack(Ack),
    ResetStream(ResetStream),
    StopSending(StopSending),
    Crypto(Crypto),
    NewToken(NewToken),
    /// Decoded from any type in `STREAM_TYS`
    Stream(Stream),
    MaxData(VarInt),
    MaxStreamData { id: StreamId, offset: u64 },
    /// `dir` distinguishes `MAX_STREAMS_BIDI` from `MAX_STREAMS_UNI`
    MaxStreams { dir: Dir, count: u64 },
    DataBlocked { offset: u64 },
    StreamDataBlocked { id: StreamId, offset: u64 },
    /// `dir` distinguishes `STREAMS_BLOCKED_BIDI` from `STREAMS_BLOCKED_UNI`
    StreamsBlocked { dir: Dir, limit: u64 },
    NewConnectionId(NewConnectionId),
    RetireConnectionId { sequence: u64 },
    PathChallenge(u64),
    PathResponse(u64),
    /// Decoded from either `CONNECTION_CLOSE` or `APPLICATION_CLOSE`
    Close(Close),
    /// Decoded from any type in `DATAGRAM_TYS`
    Datagram(Datagram),
    AckFrequency(AckFrequency),
    ImmediateAck,
    HandshakeDone,
    /// Decoded from either `OBSERVED_IPV4_ADDR` or `OBSERVED_IPV6_ADDR`
    ObservedAddr(ObservedAddr),
}
173
174impl Frame {
175    pub(crate) fn ty(&self) -> FrameType {
176        use Frame::*;
177        match *self {
178            Padding => FrameType::PADDING,
179            ResetStream(_) => FrameType::RESET_STREAM,
180            Close(self::Close::Connection(_)) => FrameType::CONNECTION_CLOSE,
181            Close(self::Close::Application(_)) => FrameType::APPLICATION_CLOSE,
182            MaxData(_) => FrameType::MAX_DATA,
183            MaxStreamData { .. } => FrameType::MAX_STREAM_DATA,
184            MaxStreams { dir: Dir::Bi, .. } => FrameType::MAX_STREAMS_BIDI,
185            MaxStreams { dir: Dir::Uni, .. } => FrameType::MAX_STREAMS_UNI,
186            Ping => FrameType::PING,
187            DataBlocked { .. } => FrameType::DATA_BLOCKED,
188            StreamDataBlocked { .. } => FrameType::STREAM_DATA_BLOCKED,
189            StreamsBlocked { dir: Dir::Bi, .. } => FrameType::STREAMS_BLOCKED_BIDI,
190            StreamsBlocked { dir: Dir::Uni, .. } => FrameType::STREAMS_BLOCKED_UNI,
191            StopSending { .. } => FrameType::STOP_SENDING,
192            RetireConnectionId { .. } => FrameType::RETIRE_CONNECTION_ID,
193            Ack(_) => FrameType::ACK,
194            Stream(ref x) => {
195                let mut ty = *STREAM_TYS.start();
196                if x.fin {
197                    ty |= 0x01;
198                }
199                if x.offset != 0 {
200                    ty |= 0x04;
201                }
202                FrameType(ty)
203            }
204            PathChallenge(_) => FrameType::PATH_CHALLENGE,
205            PathResponse(_) => FrameType::PATH_RESPONSE,
206            NewConnectionId { .. } => FrameType::NEW_CONNECTION_ID,
207            Crypto(_) => FrameType::CRYPTO,
208            NewToken(_) => FrameType::NEW_TOKEN,
209            Datagram(_) => FrameType(*DATAGRAM_TYS.start()),
210            AckFrequency(_) => FrameType::ACK_FREQUENCY,
211            ImmediateAck => FrameType::IMMEDIATE_ACK,
212            HandshakeDone => FrameType::HANDSHAKE_DONE,
213            ObservedAddr(ref observed) => observed.get_type(),
214        }
215    }
216
217    pub(crate) fn is_ack_eliciting(&self) -> bool {
218        !matches!(*self, Self::Ack(_) | Self::Padding | Self::Close(_))
219    }
220}
221
/// A close frame, originating either from the transport or from the application.
#[derive(Clone, Debug)]
pub enum Close {
    /// Transport-level close; encoded as a `CONNECTION_CLOSE` frame
    Connection(ConnectionClose),
    /// Application-level close; encoded as an `APPLICATION_CLOSE` frame
    Application(ApplicationClose),
}
227
228impl Close {
229    pub(crate) fn encode<W: BufMut>(&self, out: &mut W, max_len: usize) {
230        match *self {
231            Self::Connection(ref x) => x.encode(out, max_len),
232            Self::Application(ref x) => x.encode(out, max_len),
233        }
234    }
235
236    pub(crate) fn is_transport_layer(&self) -> bool {
237        matches!(*self, Self::Connection(_))
238    }
239}
240
241impl From<TransportError> for Close {
242    fn from(x: TransportError) -> Self {
243        Self::Connection(x.into())
244    }
245}
246impl From<ConnectionClose> for Close {
247    fn from(x: ConnectionClose) -> Self {
248        Self::Connection(x)
249    }
250}
251impl From<ApplicationClose> for Close {
252    fn from(x: ApplicationClose) -> Self {
253        Self::Application(x)
254    }
255}
256
/// Reason given by the transport for closing the connection
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ConnectionClose {
    /// Class of error as encoded in the specification
    pub error_code: TransportErrorCode,
    /// Type of frame that caused the close
    ///
    /// `None` is encoded as, and decoded from, a frame type value of `0`.
    pub frame_type: Option<FrameType>,
    /// Human-readable reason for the close
    ///
    /// Not required to be valid UTF-8; `Display` renders it lossily.
    pub reason: Bytes,
}
267
268impl fmt::Display for ConnectionClose {
269    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270        self.error_code.fmt(f)?;
271        if !self.reason.as_ref().is_empty() {
272            f.write_str(": ")?;
273            f.write_str(&String::from_utf8_lossy(&self.reason))?;
274        }
275        Ok(())
276    }
277}
278
279impl From<TransportError> for ConnectionClose {
280    fn from(x: TransportError) -> Self {
281        Self {
282            error_code: x.code,
283            frame_type: x.frame,
284            reason: x.reason.into(),
285        }
286    }
287}
288
289impl FrameStruct for ConnectionClose {
290    const SIZE_BOUND: usize = 1 + 8 + 8 + 8;
291}
292
293impl ConnectionClose {
294    pub(crate) fn encode<W: BufMut>(&self, out: &mut W, max_len: usize) {
295        out.write(FrameType::CONNECTION_CLOSE); // 1 byte
296        out.write(self.error_code); // <= 8 bytes
297        let ty = self.frame_type.map_or(0, |x| x.0);
298        out.write_var(ty); // <= 8 bytes
299        let max_len = max_len
300            - 3
301            - VarInt::from_u64(ty).unwrap().size()
302            - VarInt::from_u64(self.reason.len() as u64).unwrap().size();
303        let actual_len = self.reason.len().min(max_len);
304        out.write_var(actual_len as u64); // <= 8 bytes
305        out.put_slice(&self.reason[0..actual_len]); // whatever's left
306    }
307}
308
/// Reason given by an application for closing the connection
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ApplicationClose {
    /// Application-specific reason code
    pub error_code: VarInt,
    /// Human-readable reason for the close
    ///
    /// Not required to be valid UTF-8; `Display` renders it lossily.
    pub reason: Bytes,
}
317
318impl fmt::Display for ApplicationClose {
319    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
320        if !self.reason.as_ref().is_empty() {
321            f.write_str(&String::from_utf8_lossy(&self.reason))?;
322            f.write_str(" (code ")?;
323            self.error_code.fmt(f)?;
324            f.write_str(")")?;
325        } else {
326            self.error_code.fmt(f)?;
327        }
328        Ok(())
329    }
330}
331
332impl FrameStruct for ApplicationClose {
333    const SIZE_BOUND: usize = 1 + 8 + 8;
334}
335
336impl ApplicationClose {
337    pub(crate) fn encode<W: BufMut>(&self, out: &mut W, max_len: usize) {
338        out.write(FrameType::APPLICATION_CLOSE); // 1 byte
339        out.write(self.error_code); // <= 8 bytes
340        let max_len = max_len - 3 - VarInt::from_u64(self.reason.len() as u64).unwrap().size();
341        let actual_len = self.reason.len().min(max_len);
342        out.write_var(actual_len as u64); // <= 8 bytes
343        out.put_slice(&self.reason[0..actual_len]); // whatever's left
344    }
345}
346
/// An ACK frame with its range data kept in encoded form.
#[derive(Clone, Eq, PartialEq)]
pub struct Ack {
    /// Largest acknowledged value; the upper end of the first range yielded by `iter`
    pub largest: u64,
    /// ACK delay field exactly as read from the frame
    pub delay: u64,
    /// Still-encoded range data: the first range length followed by gap/length pairs
    pub additional: Bytes,
    /// ECN counts, present only when decoded from an `ACK_ECN` frame
    pub ecn: Option<EcnCounts>,
}
354
355impl fmt::Debug for Ack {
356    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
357        let mut ranges = "[".to_string();
358        let mut first = true;
359        for range in self.iter() {
360            if !first {
361                ranges.push(',');
362            }
363            write!(ranges, "{range:?}").unwrap();
364            first = false;
365        }
366        ranges.push(']');
367
368        f.debug_struct("Ack")
369            .field("largest", &self.largest)
370            .field("delay", &self.delay)
371            .field("ecn", &self.ecn)
372            .field("ranges", &ranges)
373            .finish()
374    }
375}
376
377impl<'a> IntoIterator for &'a Ack {
378    type Item = RangeInclusive<u64>;
379    type IntoIter = AckIter<'a>;
380
381    fn into_iter(self) -> AckIter<'a> {
382        AckIter::new(self.largest, &self.additional[..])
383    }
384}
385
386impl Ack {
387    pub fn encode<W: BufMut>(
388        delay: u64,
389        ranges: &ArrayRangeSet,
390        ecn: Option<&EcnCounts>,
391        buf: &mut W,
392    ) {
393        let mut rest = ranges.iter().rev();
394        let first = rest.next().unwrap();
395        let largest = first.end - 1;
396        let first_size = first.end - first.start;
397        buf.write(if ecn.is_some() {
398            FrameType::ACK_ECN
399        } else {
400            FrameType::ACK
401        });
402        buf.write_var(largest);
403        buf.write_var(delay);
404        buf.write_var(ranges.len() as u64 - 1);
405        buf.write_var(first_size - 1);
406        let mut prev = first.start;
407        for block in rest {
408            let size = block.end - block.start;
409            buf.write_var(prev - block.end - 1);
410            buf.write_var(size - 1);
411            prev = block.start;
412        }
413        if let Some(x) = ecn {
414            x.encode(buf)
415        }
416    }
417
418    pub fn iter(&self) -> AckIter<'_> {
419        self.into_iter()
420    }
421}
422
/// The three ECN counters carried at the end of an `ACK_ECN` frame.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub struct EcnCounts {
    /// Counter incremented for `EcnCodepoint::Ect0`
    pub ect0: u64,
    /// Counter incremented for `EcnCodepoint::Ect1`
    pub ect1: u64,
    /// Counter incremented for `EcnCodepoint::Ce`
    pub ce: u64,
}
429
430impl std::ops::AddAssign<EcnCodepoint> for EcnCounts {
431    fn add_assign(&mut self, rhs: EcnCodepoint) {
432        match rhs {
433            EcnCodepoint::Ect0 => {
434                self.ect0 += 1;
435            }
436            EcnCodepoint::Ect1 => {
437                self.ect1 += 1;
438            }
439            EcnCodepoint::Ce => {
440                self.ce += 1;
441            }
442        }
443    }
444}
445
446impl EcnCounts {
447    pub const ZERO: Self = Self {
448        ect0: 0,
449        ect1: 0,
450        ce: 0,
451    };
452
453    pub fn encode<W: BufMut>(&self, out: &mut W) {
454        out.write_var(self.ect0);
455        out.write_var(self.ect1);
456        out.write_var(self.ce);
457    }
458}
459
/// A decoded STREAM frame.
#[derive(Debug, Clone)]
pub(crate) struct Stream {
    /// Stream the data belongs to
    pub(crate) id: StreamId,
    /// Offset of `data`; decoded as 0 when the frame type's OFF bit is clear
    pub(crate) offset: u64,
    /// Value of the frame type's FIN bit
    pub(crate) fin: bool,
    /// Frame payload
    pub(crate) data: Bytes,
}

impl FrameStruct for Stream {
    // Matches the per-field bounds noted in `StreamMeta::encode`: 1 byte for the type plus
    // up to 8 bytes each for id, offset and length. The payload itself is not included.
    const SIZE_BOUND: usize = 1 + 8 + 8 + 8;
}
471
/// Metadata from a stream frame
#[derive(Debug, Clone)]
pub(crate) struct StreamMeta {
    /// Stream the frame belongs to
    pub(crate) id: StreamId,
    /// Byte range covered by the frame; `encode` writes `start` as the offset and
    /// `end - start` as the length
    pub(crate) offsets: Range<u64>,
    /// Whether the FIN bit is set when encoding
    pub(crate) fin: bool,
}

// This manual implementation exists because `Default` is not implemented for `StreamId`
impl Default for StreamMeta {
    /// Stream 0, an empty `0..0` range, and no FIN.
    fn default() -> Self {
        Self {
            id: StreamId(0),
            offsets: 0..0,
            fin: false,
        }
    }
}
490
491impl StreamMeta {
492    pub(crate) fn encode<W: BufMut>(&self, length: bool, out: &mut W) {
493        let mut ty = *STREAM_TYS.start();
494        if self.offsets.start != 0 {
495            ty |= 0x04;
496        }
497        if length {
498            ty |= 0x02;
499        }
500        if self.fin {
501            ty |= 0x01;
502        }
503        out.write_var(ty); // 1 byte
504        out.write(self.id); // <=8 bytes
505        if self.offsets.start != 0 {
506            out.write_var(self.offsets.start); // <=8 bytes
507        }
508        if length {
509            out.write_var(self.offsets.end - self.offsets.start); // <=8 bytes
510        }
511    }
512}
513
/// A vector of [`StreamMeta`] with optimization for the single element case
///
/// Backed by a `TinyVec` over a one-element array.
pub(crate) type StreamMetaVec = TinyVec<[StreamMeta; 1]>;
516
/// A CRYPTO frame.
#[derive(Debug, Clone)]
pub(crate) struct Crypto {
    /// Offset field of the frame
    pub(crate) offset: u64,
    /// Frame payload; its length is written as a prefix when encoding
    pub(crate) data: Bytes,
}
522
523impl Crypto {
524    pub(crate) const SIZE_BOUND: usize = 17;
525
526    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
527        out.write(FrameType::CRYPTO);
528        out.write_var(self.offset);
529        out.write_var(self.data.len() as u64);
530        out.put_slice(&self.data);
531    }
532}
533
/// A NEW_TOKEN frame.
#[derive(Debug, Clone)]
pub(crate) struct NewToken {
    /// Token bytes; their length is written as a prefix when encoding
    pub(crate) token: Bytes,
}
538
539impl NewToken {
540    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
541        out.write(FrameType::NEW_TOKEN);
542        out.write_var(self.token.len() as u64);
543        out.put_slice(&self.token);
544    }
545
546    pub(crate) fn size(&self) -> usize {
547        1 + VarInt::from_u64(self.token.len() as u64).unwrap().size() + self.token.len()
548    }
549}
550
/// Iterator that decodes the frames of a packet payload one at a time.
pub(crate) struct Iter {
    /// Payload bytes that have not been decoded yet
    bytes: Bytes,
    /// Type of the most recently started frame; reported alongside decode errors
    last_ty: Option<FrameType>,
}
555
556impl Iter {
557    pub(crate) fn new(payload: Bytes) -> Result<Self, TransportError> {
558        if payload.is_empty() {
559            // "An endpoint MUST treat receipt of a packet containing no frames as a
560            // connection error of type PROTOCOL_VIOLATION."
561            // https://www.rfc-editor.org/rfc/rfc9000.html#name-frames-and-frame-types
562            return Err(TransportError::PROTOCOL_VIOLATION(
563                "packet payload is empty",
564            ));
565        }
566
567        Ok(Self {
568            bytes: payload,
569            last_ty: None,
570        })
571    }
572
573    fn take_len(&mut self) -> Result<Bytes, UnexpectedEnd> {
574        let len = self.bytes.get_var()?;
575        if len > self.bytes.remaining() as u64 {
576            return Err(UnexpectedEnd);
577        }
578        Ok(self.bytes.split_to(len as usize))
579    }
580
581    fn try_next(&mut self) -> Result<Frame, IterErr> {
582        let ty = self.bytes.get::<FrameType>()?;
583        self.last_ty = Some(ty);
584        Ok(match ty {
585            FrameType::PADDING => Frame::Padding,
586            FrameType::RESET_STREAM => Frame::ResetStream(ResetStream {
587                id: self.bytes.get()?,
588                error_code: self.bytes.get()?,
589                final_offset: self.bytes.get()?,
590            }),
591            FrameType::CONNECTION_CLOSE => Frame::Close(Close::Connection(ConnectionClose {
592                error_code: self.bytes.get()?,
593                frame_type: {
594                    let x = self.bytes.get_var()?;
595                    if x == 0 {
596                        None
597                    } else {
598                        Some(FrameType(x))
599                    }
600                },
601                reason: self.take_len()?,
602            })),
603            FrameType::APPLICATION_CLOSE => Frame::Close(Close::Application(ApplicationClose {
604                error_code: self.bytes.get()?,
605                reason: self.take_len()?,
606            })),
607            FrameType::MAX_DATA => Frame::MaxData(self.bytes.get()?),
608            FrameType::MAX_STREAM_DATA => Frame::MaxStreamData {
609                id: self.bytes.get()?,
610                offset: self.bytes.get_var()?,
611            },
612            FrameType::MAX_STREAMS_BIDI => Frame::MaxStreams {
613                dir: Dir::Bi,
614                count: self.bytes.get_var()?,
615            },
616            FrameType::MAX_STREAMS_UNI => Frame::MaxStreams {
617                dir: Dir::Uni,
618                count: self.bytes.get_var()?,
619            },
620            FrameType::PING => Frame::Ping,
621            FrameType::DATA_BLOCKED => Frame::DataBlocked {
622                offset: self.bytes.get_var()?,
623            },
624            FrameType::STREAM_DATA_BLOCKED => Frame::StreamDataBlocked {
625                id: self.bytes.get()?,
626                offset: self.bytes.get_var()?,
627            },
628            FrameType::STREAMS_BLOCKED_BIDI => Frame::StreamsBlocked {
629                dir: Dir::Bi,
630                limit: self.bytes.get_var()?,
631            },
632            FrameType::STREAMS_BLOCKED_UNI => Frame::StreamsBlocked {
633                dir: Dir::Uni,
634                limit: self.bytes.get_var()?,
635            },
636            FrameType::STOP_SENDING => Frame::StopSending(StopSending {
637                id: self.bytes.get()?,
638                error_code: self.bytes.get()?,
639            }),
640            FrameType::RETIRE_CONNECTION_ID => Frame::RetireConnectionId {
641                sequence: self.bytes.get_var()?,
642            },
643            FrameType::ACK | FrameType::ACK_ECN => {
644                let largest = self.bytes.get_var()?;
645                let delay = self.bytes.get_var()?;
646                let extra_blocks = self.bytes.get_var()? as usize;
647                let n = scan_ack_blocks(&self.bytes, largest, extra_blocks)?;
648                Frame::Ack(Ack {
649                    delay,
650                    largest,
651                    additional: self.bytes.split_to(n),
652                    ecn: if ty != FrameType::ACK_ECN {
653                        None
654                    } else {
655                        Some(EcnCounts {
656                            ect0: self.bytes.get_var()?,
657                            ect1: self.bytes.get_var()?,
658                            ce: self.bytes.get_var()?,
659                        })
660                    },
661                })
662            }
663            FrameType::PATH_CHALLENGE => Frame::PathChallenge(self.bytes.get()?),
664            FrameType::PATH_RESPONSE => Frame::PathResponse(self.bytes.get()?),
665            FrameType::NEW_CONNECTION_ID => {
666                let sequence = self.bytes.get_var()?;
667                let retire_prior_to = self.bytes.get_var()?;
668                if retire_prior_to > sequence {
669                    return Err(IterErr::Malformed);
670                }
671                let length = self.bytes.get::<u8>()? as usize;
672                if length > MAX_CID_SIZE || length == 0 {
673                    return Err(IterErr::Malformed);
674                }
675                if length > self.bytes.remaining() {
676                    return Err(IterErr::UnexpectedEnd);
677                }
678                let mut stage = [0; MAX_CID_SIZE];
679                self.bytes.copy_to_slice(&mut stage[0..length]);
680                let id = ConnectionId::new(&stage[..length]);
681                if self.bytes.remaining() < 16 {
682                    return Err(IterErr::UnexpectedEnd);
683                }
684                let mut reset_token = [0; RESET_TOKEN_SIZE];
685                self.bytes.copy_to_slice(&mut reset_token);
686                Frame::NewConnectionId(NewConnectionId {
687                    sequence,
688                    retire_prior_to,
689                    id,
690                    reset_token: reset_token.into(),
691                })
692            }
693            FrameType::CRYPTO => Frame::Crypto(Crypto {
694                offset: self.bytes.get_var()?,
695                data: self.take_len()?,
696            }),
697            FrameType::NEW_TOKEN => Frame::NewToken(NewToken {
698                token: self.take_len()?,
699            }),
700            FrameType::HANDSHAKE_DONE => Frame::HandshakeDone,
701            FrameType::ACK_FREQUENCY => Frame::AckFrequency(AckFrequency {
702                sequence: self.bytes.get()?,
703                ack_eliciting_threshold: self.bytes.get()?,
704                request_max_ack_delay: self.bytes.get()?,
705                reordering_threshold: self.bytes.get()?,
706            }),
707            FrameType::IMMEDIATE_ACK => Frame::ImmediateAck,
708            FrameType::OBSERVED_IPV4_ADDR | FrameType::OBSERVED_IPV6_ADDR => {
709                let is_ipv6 = ty == FrameType::OBSERVED_IPV6_ADDR;
710                let observed = ObservedAddr::read(&mut self.bytes, is_ipv6)?;
711                Frame::ObservedAddr(observed)
712            }
713            _ => {
714                if let Some(s) = ty.stream() {
715                    Frame::Stream(Stream {
716                        id: self.bytes.get()?,
717                        offset: if s.off() { self.bytes.get_var()? } else { 0 },
718                        fin: s.fin(),
719                        data: if s.len() {
720                            self.take_len()?
721                        } else {
722                            self.take_remaining()
723                        },
724                    })
725                } else if let Some(d) = ty.datagram() {
726                    Frame::Datagram(Datagram {
727                        data: if d.len() {
728                            self.take_len()?
729                        } else {
730                            self.take_remaining()
731                        },
732                    })
733                } else {
734                    return Err(IterErr::InvalidFrameId);
735                }
736            }
737        })
738    }
739
740    fn take_remaining(&mut self) -> Bytes {
741        mem::take(&mut self.bytes)
742    }
743}
744
745impl Iterator for Iter {
746    type Item = Result<Frame, InvalidFrame>;
747    fn next(&mut self) -> Option<Self::Item> {
748        if !self.bytes.has_remaining() {
749            return None;
750        }
751        match self.try_next() {
752            Ok(x) => Some(Ok(x)),
753            Err(e) => {
754                // Corrupt frame, skip it and everything that follows
755                self.bytes.clear();
756                Some(Err(InvalidFrame {
757                    ty: self.last_ty,
758                    reason: e.reason(),
759                }))
760            }
761        }
762    }
763}
764
/// Error yielded by `Iter` when a frame cannot be decoded.
#[derive(Debug)]
pub(crate) struct InvalidFrame {
    /// Type of the frame being decoded when the failure occurred, if one had been read
    pub(crate) ty: Option<FrameType>,
    /// Short static description of the failure
    pub(crate) reason: &'static str,
}
770
771impl From<InvalidFrame> for TransportError {
772    fn from(err: InvalidFrame) -> Self {
773        let mut te = Self::FRAME_ENCODING_ERROR(err.reason);
774        te.frame = err.ty;
775        te
776    }
777}
778
779/// Validate exactly `n` ACK ranges in `buf` and return the number of bytes they cover
780fn scan_ack_blocks(mut buf: &[u8], largest: u64, n: usize) -> Result<usize, IterErr> {
781    let total_len = buf.remaining();
782    let first_block = buf.get_var()?;
783    let mut smallest = largest.checked_sub(first_block).ok_or(IterErr::Malformed)?;
784    for _ in 0..n {
785        let gap = buf.get_var()?;
786        smallest = smallest.checked_sub(gap + 2).ok_or(IterErr::Malformed)?;
787        let block = buf.get_var()?;
788        smallest = smallest.checked_sub(block).ok_or(IterErr::Malformed)?;
789    }
790    Ok(total_len - buf.remaining())
791}
792
/// Reasons why decoding a frame can fail.
enum IterErr {
    /// The payload ended before the frame was complete
    UnexpectedEnd,
    /// The frame type is not recognised
    InvalidFrameId,
    /// A field held a value that is not permitted
    Malformed,
}

impl IterErr {
    /// Short static description of the error.
    fn reason(&self) -> &'static str {
        match self {
            Self::UnexpectedEnd => "unexpected end",
            Self::InvalidFrameId => "invalid frame ID",
            Self::Malformed => "malformed",
        }
    }
}
809
810impl From<UnexpectedEnd> for IterErr {
811    fn from(_: UnexpectedEnd) -> Self {
812        Self::UnexpectedEnd
813    }
814}
815
/// Iterator over the ranges of an [`Ack`], decoding them on demand.
#[derive(Debug, Clone)]
pub struct AckIter<'a> {
    /// Upper end of the next range to yield
    largest: u64,
    /// Encoded range data that has not been consumed yet
    data: &'a [u8],
}

impl<'a> AckIter<'a> {
    /// Creates an iterator whose first range ends at `largest`, reading from `data`.
    fn new(largest: u64, data: &'a [u8]) -> Self {
        Self { largest, data }
    }
}
827
828impl Iterator for AckIter<'_> {
829    type Item = RangeInclusive<u64>;
830    fn next(&mut self) -> Option<RangeInclusive<u64>> {
831        if !self.data.has_remaining() {
832            return None;
833        }
834        let block = self.data.get_var().unwrap();
835        let largest = self.largest;
836        if let Ok(gap) = self.data.get_var() {
837            self.largest -= block + gap + 2;
838        }
839        Some(largest - block..=largest)
840    }
841}
842
/// A RESET_STREAM frame.
#[allow(unreachable_pub)] // fuzzing only
#[cfg_attr(feature = "arbitrary", derive(Arbitrary))]
#[derive(Debug, Copy, Clone)]
pub struct ResetStream {
    /// Stream being reset
    pub(crate) id: StreamId,
    /// Error code field of the frame
    pub(crate) error_code: VarInt,
    /// Final offset field of the frame
    pub(crate) final_offset: VarInt,
}
851
852impl FrameStruct for ResetStream {
853    const SIZE_BOUND: usize = 1 + 8 + 8 + 8;
854}
855
856impl ResetStream {
857    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
858        out.write(FrameType::RESET_STREAM); // 1 byte
859        out.write(self.id); // <= 8 bytes
860        out.write(self.error_code); // <= 8 bytes
861        out.write(self.final_offset); // <= 8 bytes
862    }
863}
864
/// A STOP_SENDING frame.
#[derive(Debug, Copy, Clone)]
pub(crate) struct StopSending {
    /// Stream the frame refers to
    pub(crate) id: StreamId,
    /// Error code field of the frame
    pub(crate) error_code: VarInt,
}
870
871impl FrameStruct for StopSending {
872    const SIZE_BOUND: usize = 1 + 8 + 8;
873}
874
875impl StopSending {
876    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
877        out.write(FrameType::STOP_SENDING); // 1 byte
878        out.write(self.id); // <= 8 bytes
879        out.write(self.error_code) // <= 8 bytes
880    }
881}
882
/// A NEW_CONNECTION_ID frame.
#[derive(Debug, Copy, Clone)]
pub(crate) struct NewConnectionId {
    /// Sequence number field of the frame
    pub(crate) sequence: u64,
    /// Retire-prior-to field; frames where it exceeds `sequence` are rejected when decoding
    pub(crate) retire_prior_to: u64,
    /// The connection ID; decoding rejects empty IDs and IDs longer than `MAX_CID_SIZE`
    pub(crate) id: ConnectionId,
    /// Reset token that follows the connection ID on the wire
    pub(crate) reset_token: ResetToken,
}
890
891impl NewConnectionId {
892    pub(crate) fn encode<W: BufMut>(&self, out: &mut W) {
893        out.write(FrameType::NEW_CONNECTION_ID);
894        out.write_var(self.sequence);
895        out.write_var(self.retire_prior_to);
896        out.write(self.id.len() as u8);
897        out.put_slice(&self.id);
898        out.put_slice(&self.reset_token);
899    }
900}
901
/// Smallest number of bytes this type of frame is guaranteed to fit within.
///
/// One byte for the frame type plus up to eight for the sequence number.
pub(crate) const RETIRE_CONNECTION_ID_SIZE_BOUND: usize = 1 + 8;
904
/// An unreliable datagram
#[derive(Debug, Clone)]
pub struct Datagram {
    /// Payload
    pub data: Bytes,
}

impl FrameStruct for Datagram {
    // Type byte plus a length field of at most 8 bytes; the payload is not included.
    const SIZE_BOUND: usize = 1 + 8;
}
915
916impl Datagram {
917    pub(crate) fn encode(&self, length: bool, out: &mut Vec<u8>) {
918        out.write(FrameType(*DATAGRAM_TYS.start() | u64::from(length))); // 1 byte
919        if length {
920            // Safe to unwrap because we check length sanity before queueing datagrams
921            out.write(VarInt::from_u64(self.data.len() as u64).unwrap()); // <= 8 bytes
922        }
923        out.extend_from_slice(&self.data);
924    }
925
926    pub(crate) fn size(&self, length: bool) -> usize {
927        1 + if length {
928            VarInt::from_u64(self.data.len() as u64).unwrap().size()
929        } else {
930            0
931        } + self.data.len()
932    }
933}
934
/// An ACK_FREQUENCY frame; the fields are listed in the order they appear on the wire.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub(crate) struct AckFrequency {
    pub(crate) sequence: VarInt,
    pub(crate) ack_eliciting_threshold: VarInt,
    pub(crate) request_max_ack_delay: VarInt,
    pub(crate) reordering_threshold: VarInt,
}
942
943impl AckFrequency {
944    pub(crate) fn encode<W: BufMut>(&self, buf: &mut W) {
945        buf.write(FrameType::ACK_FREQUENCY);
946        buf.write(self.sequence);
947        buf.write(self.ack_eliciting_threshold);
948        buf.write(self.request_max_ack_delay);
949        buf.write(self.reordering_threshold);
950    }
951}
952
953/* Address Discovery https://datatracker.ietf.org/doc/draft-seemann-quic-address-discovery/ */
954
/// Conjunction of the information contained in the address discovery frames
/// ([`FrameType::OBSERVED_IPV4_ADDR`], [`FrameType::OBSERVED_IPV6_ADDR`]).
#[derive(Debug, PartialEq, Eq, Clone)]
pub(crate) struct ObservedAddr {
    /// Monotonically increasing integer within the same connection.
    pub(crate) seq_no: VarInt,
    /// Reported observed address.
    pub(crate) ip: IpAddr,
    /// Reported observed port.
    pub(crate) port: u16,
}
966
967impl ObservedAddr {
968    pub(crate) fn new<N: Into<VarInt>>(remote: std::net::SocketAddr, seq_no: N) -> Self {
969        Self {
970            ip: remote.ip(),
971            port: remote.port(),
972            seq_no: seq_no.into(),
973        }
974    }
975
976    /// Get the [`FrameType`] for this frame.
977    pub(crate) fn get_type(&self) -> FrameType {
978        if self.ip.is_ipv6() {
979            FrameType::OBSERVED_IPV6_ADDR
980        } else {
981            FrameType::OBSERVED_IPV4_ADDR
982        }
983    }
984
985    /// Compute the number of bytes needed to encode the frame.
986    pub(crate) fn size(&self) -> usize {
987        let type_size = VarInt(self.get_type().0).size();
988        let req_id_bytes = self.seq_no.size();
989        let ip_bytes = if self.ip.is_ipv6() { 16 } else { 4 };
990        let port_bytes = 2;
991        type_size + req_id_bytes + ip_bytes + port_bytes
992    }
993
994    /// Unconditionally write this frame to `buf`.
995    pub(crate) fn write<W: BufMut>(&self, buf: &mut W) {
996        buf.write(self.get_type());
997        buf.write(self.seq_no);
998        match self.ip {
999            IpAddr::V4(ipv4_addr) => {
1000                buf.write(ipv4_addr);
1001            }
1002            IpAddr::V6(ipv6_addr) => {
1003                buf.write(ipv6_addr);
1004            }
1005        }
1006        buf.write::<u16>(self.port);
1007    }
1008
1009    /// Reads the frame contents from the buffer.
1010    ///
1011    /// Should only be called when the fram type has been identified as
1012    /// [`FrameType::OBSERVED_IPV4_ADDR`] or [`FrameType::OBSERVED_IPV6_ADDR`].
1013    pub(crate) fn read<R: Buf>(bytes: &mut R, is_ipv6: bool) -> coding::Result<Self> {
1014        let seq_no = bytes.get()?;
1015        let ip = if is_ipv6 {
1016            IpAddr::V6(bytes.get()?)
1017        } else {
1018            IpAddr::V4(bytes.get()?)
1019        };
1020        let port = bytes.get()?;
1021        Ok(Self { seq_no, ip, port })
1022    }
1023
1024    /// Gives the [`SocketAddr`] reported in the frame.
1025    pub(crate) fn socket_addr(&self) -> SocketAddr {
1026        (self.ip, self.port).into()
1027    }
1028}
1029
#[cfg(test)]
mod test {

    use super::*;
    use crate::coding::Codec;
    use assert_matches::assert_matches;

    /// Decodes every frame contained in `buf`, panicking on any decoding error.
    fn frames(buf: Vec<u8>) -> Vec<Frame> {
        Iter::new(Bytes::from(buf))
            .unwrap()
            .collect::<Result<Vec<_>, _>>()
            .unwrap()
    }

    /// Encodes `observed_addr`, checks that `size()` matches the number of bytes
    /// written, then decodes the buffer and checks the result equals the input.
    fn assert_observed_addr_roundtrip(observed_addr: ObservedAddr) {
        let mut buf = Vec::with_capacity(observed_addr.size());
        observed_addr.write(&mut buf);

        assert_eq!(
            observed_addr.size(),
            buf.len(),
            "expected written bytes and actual size differ"
        );

        let mut decoded = frames(buf);
        assert_eq!(decoded.len(), 1);
        match decoded.pop().expect("non empty") {
            Frame::ObservedAddr(decoded) => assert_eq!(decoded, observed_addr),
            x => panic!("incorrect frame {x:?}"),
        }
    }

    /// An ACK frame with several ranges and ECN counts survives a round trip.
    #[test]
    fn ack_coding() {
        const PACKETS: &[u64] = &[1, 2, 3, 5, 10, 11, 14];
        let mut ranges = ArrayRangeSet::new();
        for &packet in PACKETS {
            ranges.insert(packet..packet + 1);
        }
        let mut buf = Vec::new();
        const ECN: EcnCounts = EcnCounts {
            ect0: 42,
            ect1: 24,
            ce: 12,
        };
        Ack::encode(42, &ranges, Some(&ECN), &mut buf);
        let frames = frames(buf);
        assert_eq!(frames.len(), 1);
        match frames[0] {
            Frame::Ack(ref ack) => {
                let mut packets = ack.iter().flatten().collect::<Vec<_>>();
                packets.sort_unstable();
                assert_eq!(&packets[..], PACKETS);
                assert_eq!(ack.ecn, Some(ECN));
            }
            ref x => panic!("incorrect frame {x:?}"),
        }
    }

    /// An ACK_FREQUENCY frame decodes to the value that was encoded.
    #[test]
    fn ack_frequency_coding() {
        let mut buf = Vec::new();
        let original = AckFrequency {
            sequence: VarInt(42),
            ack_eliciting_threshold: VarInt(20),
            request_max_ack_delay: VarInt(50_000),
            reordering_threshold: VarInt(1),
        };
        original.encode(&mut buf);
        let frames = frames(buf);
        assert_eq!(frames.len(), 1);
        match &frames[0] {
            Frame::AckFrequency(decoded) => assert_eq!(decoded, &original),
            x => panic!("incorrect frame {x:?}"),
        }
    }

    /// A bare IMMEDIATE_ACK frame type decodes to [`Frame::ImmediateAck`].
    #[test]
    fn immediate_ack_coding() {
        let mut buf = Vec::new();
        FrameType::IMMEDIATE_ACK.encode(&mut buf);
        let frames = frames(buf);
        assert_eq!(frames.len(), 1);
        assert_matches!(&frames[0], Frame::ImmediateAck);
    }

    /// Test that encoding and decoding an IPv4 [`ObservedAddr`] produces the same result.
    #[test]
    fn test_observed_addr_roundrip() {
        assert_observed_addr_roundtrip(ObservedAddr {
            seq_no: VarInt(42),
            ip: std::net::Ipv4Addr::LOCALHOST.into(),
            port: 4242,
        });
    }

    /// Test that encoding and decoding an IPv6 [`ObservedAddr`] produces the same result.
    ///
    /// Exercises the IPv6 branches of `get_type`, `size`, `write` and `read`.
    #[test]
    fn test_observed_addr_roundtrip_ipv6() {
        assert_observed_addr_roundtrip(ObservedAddr {
            seq_no: VarInt(42),
            ip: std::net::Ipv6Addr::LOCALHOST.into(),
            port: 4242,
        });
    }
}