ant_quic/
packet.rs

1// Copyright 2024 Saorsa Labs Ltd.
2//
3// This Saorsa Network Software is licensed under the General Public License (GPL), version 3.
4// Please see the file LICENSE-GPL, or visit <http://www.gnu.org/licenses/> for the full text.
5//
6// Full details available at https://saorsalabs.com/licenses
7
8use std::{cmp::Ordering, io, ops::Range, str};
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11use thiserror::Error;
12
13use crate::{
14    ConnectionId,
15    coding::{self, BufExt, BufMutExt},
16    crypto,
17};
18
19/// Decodes a QUIC packet's invariant header
20///
21/// Due to packet number encryption, it is impossible to fully decode a header
22/// (which includes a variable-length packet number) without crypto context.
23/// The crypto context (represented by the `Crypto` type in Quinn) is usually
24/// part of the `Connection`, or can be derived from the destination CID for
25/// Initial packets.
26///
27/// To cope with this, we decode the invariant header (which should be stable
28/// across QUIC versions), which gives us the destination CID and allows us
29/// to inspect the version and packet type (which depends on the version).
30/// This information allows us to fully decode and decrypt the packet.
31#[cfg_attr(test, derive(Clone))]
32#[derive(Debug)]
33pub struct PartialDecode {
34    plain_header: ProtectedHeader,
35    buf: io::Cursor<BytesMut>,
36}
37
38#[allow(clippy::len_without_is_empty)]
39impl PartialDecode {
40    /// Begin decoding a QUIC packet from `bytes`, returning any trailing data not part of that packet
41    pub fn new(
42        bytes: BytesMut,
43        cid_parser: &(impl ConnectionIdParser + ?Sized),
44        supported_versions: &[u32],
45        grease_quic_bit: bool,
46    ) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
47        let mut buf = io::Cursor::new(bytes);
48        let plain_header =
49            ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?;
50        let dgram_len = buf.get_ref().len();
51        let packet_len = plain_header
52            .payload_len()
53            .map(|len| (buf.position() + len) as usize)
54            .unwrap_or(dgram_len);
55        match dgram_len.cmp(&packet_len) {
56            Ordering::Equal => Ok((Self { plain_header, buf }, None)),
57            Ordering::Less => Err(PacketDecodeError::InvalidHeader(
58                "packet too short to contain payload length",
59            )),
60            Ordering::Greater => {
61                let rest = Some(buf.get_mut().split_off(packet_len));
62                Ok((Self { plain_header, buf }, rest))
63            }
64        }
65    }
66
67    /// The underlying partially-decoded packet data
68    pub(crate) fn data(&self) -> &[u8] {
69        self.buf.get_ref()
70    }
71
72    pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> {
73        self.plain_header.as_initial()
74    }
75
76    pub(crate) fn has_long_header(&self) -> bool {
77        !matches!(self.plain_header, ProtectedHeader::Short { .. })
78    }
79
80    pub(crate) fn is_initial(&self) -> bool {
81        self.space() == Some(SpaceId::Initial)
82    }
83
84    pub(crate) fn space(&self) -> Option<SpaceId> {
85        use ProtectedHeader::*;
86        match self.plain_header {
87            Initial { .. } => Some(SpaceId::Initial),
88            Long {
89                ty: LongType::Handshake,
90                ..
91            } => Some(SpaceId::Handshake),
92            Long {
93                ty: LongType::ZeroRtt,
94                ..
95            } => Some(SpaceId::Data),
96            Short { .. } => Some(SpaceId::Data),
97            _ => None,
98        }
99    }
100
101    pub(crate) fn is_0rtt(&self) -> bool {
102        match self.plain_header {
103            ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt,
104            _ => false,
105        }
106    }
107
108    /// The destination connection ID of the packet
109    pub fn dst_cid(&self) -> &ConnectionId {
110        self.plain_header.dst_cid()
111    }
112
113    /// Length of QUIC packet being decoded
114    pub fn len(&self) -> usize {
115        self.buf.get_ref().len()
116    }
117
118    pub(crate) fn finish(
119        self,
120        header_crypto: Option<&dyn crypto::HeaderKey>,
121    ) -> Result<Packet, PacketDecodeError> {
122        use ProtectedHeader::*;
123        let Self {
124            plain_header,
125            mut buf,
126        } = self;
127
128        if let Initial(ProtectedInitialHeader {
129            dst_cid,
130            src_cid,
131            token_pos,
132            version,
133            ..
134        }) = plain_header
135        {
136            let number = match header_crypto {
137                Some(crypto) => Self::decrypt_header(&mut buf, crypto)?,
138                None => {
139                    return Err(PacketDecodeError::InvalidHeader(
140                        "header crypto should be available".into(),
141                    ));
142                }
143            };
144            let header_len = buf.position() as usize;
145            let mut bytes = buf.into_inner();
146
147            let header_data = bytes.split_to(header_len).freeze();
148            let token = header_data.slice(token_pos.start..token_pos.end);
149            return Ok(Packet {
150                header: Header::Initial(InitialHeader {
151                    dst_cid,
152                    src_cid,
153                    token,
154                    number,
155                    version,
156                }),
157                header_data,
158                payload: bytes,
159            });
160        }
161
162        let header = match plain_header {
163            Long {
164                ty,
165                dst_cid,
166                src_cid,
167                version,
168                ..
169            } => Header::Long {
170                ty,
171                dst_cid,
172                src_cid,
173                number: match header_crypto {
174                    Some(crypto) => Self::decrypt_header(&mut buf, crypto)?,
175                    None => {
176                        return Err(PacketDecodeError::InvalidHeader(
177                            "header crypto should be available for long header packet".into(),
178                        ));
179                    }
180                },
181                version,
182            },
183            Retry {
184                dst_cid,
185                src_cid,
186                version,
187            } => Header::Retry {
188                dst_cid,
189                src_cid,
190                version,
191            },
192            Short { spin, dst_cid, .. } => {
193                let number = match header_crypto {
194                    Some(crypto) => Self::decrypt_header(&mut buf, crypto)?,
195                    None => {
196                        return Err(PacketDecodeError::InvalidHeader(
197                            "header crypto should be available for initial packet".into(),
198                        ));
199                    }
200                };
201                let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0;
202                Header::Short {
203                    spin,
204                    key_phase,
205                    dst_cid,
206                    number,
207                }
208            }
209            VersionNegotiate {
210                random,
211                dst_cid,
212                src_cid,
213            } => Header::VersionNegotiate {
214                random,
215                dst_cid,
216                src_cid,
217            },
218            Initial { .. } => unreachable!(),
219        };
220
221        let header_len = buf.position() as usize;
222        let mut bytes = buf.into_inner();
223        Ok(Packet {
224            header,
225            header_data: bytes.split_to(header_len).freeze(),
226            payload: bytes,
227        })
228    }
229
230    fn decrypt_header(
231        buf: &mut io::Cursor<BytesMut>,
232        header_crypto: &dyn crypto::HeaderKey,
233    ) -> Result<PacketNumber, PacketDecodeError> {
234        let packet_length = buf.get_ref().len();
235        let pn_offset = buf.position() as usize;
236        if packet_length < pn_offset + 4 + header_crypto.sample_size() {
237            return Err(PacketDecodeError::InvalidHeader(
238                "packet too short to extract header protection sample",
239            ));
240        }
241
242        header_crypto.decrypt(pn_offset, buf.get_mut());
243
244        let len = PacketNumber::decode_len(buf.get_ref()[0]);
245        PacketNumber::decode(len, buf)
246    }
247}
248
249pub(crate) struct Packet {
250    pub(crate) header: Header,
251    pub(crate) header_data: Bytes,
252    pub(crate) payload: BytesMut,
253}
254
255impl Packet {
256    pub(crate) fn reserved_bits_valid(&self) -> bool {
257        let mask = match self.header {
258            Header::Short { .. } => SHORT_RESERVED_BITS,
259            _ => LONG_RESERVED_BITS,
260        };
261        self.header_data[0] & mask == 0
262    }
263}
264
265pub(crate) struct InitialPacket {
266    pub(crate) header: InitialHeader,
267    pub(crate) header_data: Bytes,
268    pub(crate) payload: BytesMut,
269}
270
271impl From<InitialPacket> for Packet {
272    fn from(x: InitialPacket) -> Self {
273        Self {
274            header: Header::Initial(x.header),
275            header_data: x.header_data,
276            payload: x.payload,
277        }
278    }
279}
280
281#[cfg_attr(test, derive(Clone))]
282#[derive(Debug)]
283pub(crate) enum Header {
284    Initial(InitialHeader),
285    Long {
286        ty: LongType,
287        dst_cid: ConnectionId,
288        src_cid: ConnectionId,
289        number: PacketNumber,
290        version: u32,
291    },
292    Retry {
293        dst_cid: ConnectionId,
294        src_cid: ConnectionId,
295        version: u32,
296    },
297    Short {
298        spin: bool,
299        key_phase: bool,
300        dst_cid: ConnectionId,
301        number: PacketNumber,
302    },
303    VersionNegotiate {
304        random: u8,
305        src_cid: ConnectionId,
306        dst_cid: ConnectionId,
307    },
308}
309
310impl Header {
311    pub(crate) fn encode(&self, w: &mut Vec<u8>) -> PartialEncode {
312        use Header::*;
313        let start = w.len();
314        match *self {
315            Initial(InitialHeader {
316                ref dst_cid,
317                ref src_cid,
318                ref token,
319                number,
320                version,
321            }) => {
322                w.write(u8::from(LongHeaderType::Initial) | number.tag());
323                w.write(version);
324                dst_cid.encode_long(w);
325                src_cid.encode_long(w);
326                w.write_var(token.len() as u64);
327                w.put_slice(token);
328                w.write::<u16>(0); // Placeholder for payload length; see `set_payload_length`
329                number.encode(w);
330                PartialEncode {
331                    start,
332                    header_len: w.len() - start,
333                    pn: Some((number.len(), true)),
334                }
335            }
336            Long {
337                ty,
338                ref dst_cid,
339                ref src_cid,
340                number,
341                version,
342            } => {
343                w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag());
344                w.write(version);
345                dst_cid.encode_long(w);
346                src_cid.encode_long(w);
347                w.write::<u16>(0); // Placeholder for payload length; see `set_payload_length`
348                number.encode(w);
349                PartialEncode {
350                    start,
351                    header_len: w.len() - start,
352                    pn: Some((number.len(), true)),
353                }
354            }
355            Retry {
356                ref dst_cid,
357                ref src_cid,
358                version,
359            } => {
360                w.write(u8::from(LongHeaderType::Retry));
361                w.write(version);
362                dst_cid.encode_long(w);
363                src_cid.encode_long(w);
364                PartialEncode {
365                    start,
366                    header_len: w.len() - start,
367                    pn: None,
368                }
369            }
370            Short {
371                spin,
372                key_phase,
373                ref dst_cid,
374                number,
375            } => {
376                w.write(
377                    FIXED_BIT
378                        | if key_phase { KEY_PHASE_BIT } else { 0 }
379                        | if spin { SPIN_BIT } else { 0 }
380                        | number.tag(),
381                );
382                w.put_slice(dst_cid);
383                number.encode(w);
384                PartialEncode {
385                    start,
386                    header_len: w.len() - start,
387                    pn: Some((number.len(), false)),
388                }
389            }
390            VersionNegotiate {
391                ref random,
392                ref dst_cid,
393                ref src_cid,
394            } => {
395                w.write(0x80u8 | random);
396                w.write::<u32>(0);
397                dst_cid.encode_long(w);
398                src_cid.encode_long(w);
399                PartialEncode {
400                    start,
401                    header_len: w.len() - start,
402                    pn: None,
403                }
404            }
405        }
406    }
407
408    /// Whether the packet is encrypted on the wire
409    pub(crate) fn is_protected(&self) -> bool {
410        !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. })
411    }
412
413    pub(crate) fn number(&self) -> Option<PacketNumber> {
414        use Header::*;
415        Some(match *self {
416            Initial(InitialHeader { number, .. }) => number,
417            Long { number, .. } => number,
418            Short { number, .. } => number,
419            _ => {
420                return None;
421            }
422        })
423    }
424
425    pub(crate) fn space(&self) -> SpaceId {
426        use Header::*;
427        match *self {
428            Short { .. } => SpaceId::Data,
429            Long {
430                ty: LongType::ZeroRtt,
431                ..
432            } => SpaceId::Data,
433            Long {
434                ty: LongType::Handshake,
435                ..
436            } => SpaceId::Handshake,
437            _ => SpaceId::Initial,
438        }
439    }
440
441    pub(crate) fn key_phase(&self) -> bool {
442        match *self {
443            Self::Short { key_phase, .. } => key_phase,
444            _ => false,
445        }
446    }
447
448    pub(crate) fn is_short(&self) -> bool {
449        matches!(*self, Self::Short { .. })
450    }
451
452    pub(crate) fn is_1rtt(&self) -> bool {
453        self.is_short()
454    }
455
456    pub(crate) fn is_0rtt(&self) -> bool {
457        matches!(
458            *self,
459            Self::Long {
460                ty: LongType::ZeroRtt,
461                ..
462            }
463        )
464    }
465
466    pub(crate) fn dst_cid(&self) -> ConnectionId {
467        use Header::*;
468        match *self {
469            Initial(InitialHeader { dst_cid, .. }) => dst_cid,
470            Long { dst_cid, .. } => dst_cid,
471            Retry { dst_cid, .. } => dst_cid,
472            Short { dst_cid, .. } => dst_cid,
473            VersionNegotiate { dst_cid, .. } => dst_cid,
474        }
475    }
476
477    /// Whether the payload of this packet contains QUIC frames
478    pub(crate) fn has_frames(&self) -> bool {
479        use Header::*;
480        match *self {
481            Initial(_) => true,
482            Long { .. } => true,
483            Retry { .. } => false,
484            Short { .. } => true,
485            VersionNegotiate { .. } => false,
486        }
487    }
488}
489
490pub(crate) struct PartialEncode {
491    pub(crate) start: usize,
492    pub(crate) header_len: usize,
493    // Packet number length, payload length needed
494    pn: Option<(usize, bool)>,
495}
496
497impl PartialEncode {
498    pub(crate) fn finish(
499        self,
500        buf: &mut [u8],
501        header_crypto: &dyn crypto::HeaderKey,
502        crypto: Option<(u64, &dyn crypto::PacketKey)>,
503    ) {
504        let Self { header_len, pn, .. } = self;
505        let (pn_len, write_len) = match pn {
506            Some((pn_len, write_len)) => (pn_len, write_len),
507            None => return,
508        };
509
510        let pn_pos = header_len - pn_len;
511        if write_len {
512            let len = buf.len() - header_len + pn_len;
513            assert!(len < 2usize.pow(14)); // Fits in reserved space
514            let mut slice = &mut buf[pn_pos - 2..pn_pos];
515            slice.put_u16(len as u16 | (0b01 << 14));
516        }
517
518        if let Some((number, crypto)) = crypto {
519            crypto.encrypt(number, buf, header_len);
520        }
521
522        debug_assert!(
523            pn_pos + 4 + header_crypto.sample_size() <= buf.len(),
524            "packet must be padded to at least {} bytes for header protection sampling",
525            pn_pos + 4 + header_crypto.sample_size()
526        );
527        header_crypto.encrypt(pn_pos, buf);
528    }
529}
530
531/// Plain packet header
532#[derive(Clone, Debug)]
533pub enum ProtectedHeader {
534    /// An Initial packet header
535    Initial(ProtectedInitialHeader),
536    /// A Long packet header, as used during the handshake
537    Long {
538        /// Type of the Long header packet
539        ty: LongType,
540        /// Destination Connection ID
541        dst_cid: ConnectionId,
542        /// Source Connection ID
543        src_cid: ConnectionId,
544        /// Length of the packet payload
545        len: u64,
546        /// QUIC version
547        version: u32,
548    },
549    /// A Retry packet header
550    Retry {
551        /// Destination Connection ID
552        dst_cid: ConnectionId,
553        /// Source Connection ID
554        src_cid: ConnectionId,
555        /// QUIC version
556        version: u32,
557    },
558    /// A short packet header, as used during the data phase
559    Short {
560        /// Spin bit
561        spin: bool,
562        /// Destination Connection ID
563        dst_cid: ConnectionId,
564    },
565    /// A Version Negotiation packet header
566    VersionNegotiate {
567        /// Random value
568        random: u8,
569        /// Destination Connection ID
570        dst_cid: ConnectionId,
571        /// Source Connection ID
572        src_cid: ConnectionId,
573    },
574}
575
576impl ProtectedHeader {
577    fn as_initial(&self) -> Option<&ProtectedInitialHeader> {
578        match self {
579            Self::Initial(x) => Some(x),
580            _ => None,
581        }
582    }
583
584    /// The destination Connection ID of the packet
585    pub fn dst_cid(&self) -> &ConnectionId {
586        use ProtectedHeader::*;
587        match self {
588            Initial(header) => &header.dst_cid,
589            Long { dst_cid, .. } => dst_cid,
590            Retry { dst_cid, .. } => dst_cid,
591            Short { dst_cid, .. } => dst_cid,
592            VersionNegotiate { dst_cid, .. } => dst_cid,
593        }
594    }
595
596    fn payload_len(&self) -> Option<u64> {
597        use ProtectedHeader::*;
598        match self {
599            Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
600            _ => None,
601        }
602    }
603
604    /// Decode a plain header from given buffer, with given [`ConnectionIdParser`].
605    pub fn decode(
606        buf: &mut io::Cursor<BytesMut>,
607        cid_parser: &(impl ConnectionIdParser + ?Sized),
608        supported_versions: &[u32],
609        grease_quic_bit: bool,
610    ) -> Result<Self, PacketDecodeError> {
611        let first = buf.get::<u8>()?;
612        if !grease_quic_bit && first & FIXED_BIT == 0 {
613            return Err(PacketDecodeError::InvalidHeader("fixed bit unset"));
614        }
615        if first & LONG_HEADER_FORM == 0 {
616            let spin = first & SPIN_BIT != 0;
617
618            Ok(Self::Short {
619                spin,
620                dst_cid: cid_parser.parse(buf)?,
621            })
622        } else {
623            let version = buf.get::<u32>()?;
624
625            let dst_cid = ConnectionId::decode_long(buf)
626                .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
627            let src_cid = ConnectionId::decode_long(buf)
628                .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
629
630            // TODO: Support long CIDs for compatibility with future QUIC versions
631            if version == 0 {
632                let random = first & !LONG_HEADER_FORM;
633                return Ok(Self::VersionNegotiate {
634                    random,
635                    dst_cid,
636                    src_cid,
637                });
638            }
639
640            if !supported_versions.contains(&version) {
641                return Err(PacketDecodeError::UnsupportedVersion {
642                    src_cid,
643                    dst_cid,
644                    version,
645                });
646            }
647
648            match LongHeaderType::from_byte(first)? {
649                LongHeaderType::Initial => {
650                    let token_len = buf.get_var()? as usize;
651                    let token_start = buf.position() as usize;
652                    if token_len > buf.remaining() {
653                        return Err(PacketDecodeError::InvalidHeader("token out of bounds"));
654                    }
655                    buf.advance(token_len);
656
657                    let len = buf.get_var()?;
658                    Ok(Self::Initial(ProtectedInitialHeader {
659                        dst_cid,
660                        src_cid,
661                        token_pos: token_start..token_start + token_len,
662                        len,
663                        version,
664                    }))
665                }
666                LongHeaderType::Retry => Ok(Self::Retry {
667                    dst_cid,
668                    src_cid,
669                    version,
670                }),
671                LongHeaderType::Standard(ty) => Ok(Self::Long {
672                    ty,
673                    dst_cid,
674                    src_cid,
675                    len: buf.get_var()?,
676                    version,
677                }),
678            }
679        }
680    }
681}
682
683/// Header of an Initial packet, before decryption
684#[derive(Clone, Debug)]
685pub struct ProtectedInitialHeader {
686    /// Destination Connection ID
687    pub dst_cid: ConnectionId,
688    /// Source Connection ID
689    pub src_cid: ConnectionId,
690    /// The position of a token in the packet buffer
691    pub token_pos: Range<usize>,
692    /// Length of the packet payload
693    pub len: u64,
694    /// QUIC version
695    pub version: u32,
696}
697
698#[derive(Clone, Debug)]
699pub(crate) struct InitialHeader {
700    pub(crate) dst_cid: ConnectionId,
701    pub(crate) src_cid: ConnectionId,
702    pub(crate) token: Bytes,
703    pub(crate) number: PacketNumber,
704    pub(crate) version: u32,
705}
706
/// A truncated packet number, tagged with its on-the-wire width
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum PacketNumber {
    /// One-byte encoding
    U8(u8),
    /// Two-byte encoding
    U16(u16),
    /// Three-byte encoding, held in a `u32`
    U24(u32),
    /// Four-byte encoding
    U32(u32),
}
715
716impl PacketNumber {
717    pub(crate) fn new(n: u64, largest_acked: u64) -> Self {
718        let range = (n - largest_acked) * 2;
719        if range < 1 << 8 {
720            Self::U8(n as u8)
721        } else if range < 1 << 16 {
722            Self::U16(n as u16)
723        } else if range < 1 << 24 {
724            Self::U24(n as u32)
725        } else if range < 1 << 32 {
726            Self::U32(n as u32)
727        } else {
728            // Out-of-range packet number difference; clamp to 32 bits to avoid panic
729            // and let higher layers handle any resulting protocol error.
730            Self::U32(n as u32)
731        }
732    }
733
734    pub(crate) fn len(self) -> usize {
735        use PacketNumber::*;
736        match self {
737            U8(_) => 1,
738            U16(_) => 2,
739            U24(_) => 3,
740            U32(_) => 4,
741        }
742    }
743
744    pub(crate) fn encode<W: BufMut>(self, w: &mut W) {
745        use PacketNumber::*;
746        match self {
747            U8(x) => w.write(x),
748            U16(x) => w.write(x),
749            U24(x) => w.put_uint(u64::from(x), 3),
750            U32(x) => w.write(x),
751        }
752    }
753
754    pub(crate) fn decode<R: Buf>(len: usize, r: &mut R) -> Result<Self, PacketDecodeError> {
755        use PacketNumber::*;
756        let pn = match len {
757            1 => U8(r.get()?),
758            2 => U16(r.get()?),
759            3 => U24(r.get_uint(3) as u32),
760            4 => U32(r.get()?),
761            _ => unreachable!(),
762        };
763        Ok(pn)
764    }
765
766    pub(crate) fn decode_len(tag: u8) -> usize {
767        1 + (tag & 0x03) as usize
768    }
769
770    fn tag(self) -> u8 {
771        use PacketNumber::*;
772        match self {
773            U8(_) => 0b00,
774            U16(_) => 0b01,
775            U24(_) => 0b10,
776            U32(_) => 0b11,
777        }
778    }
779
780    pub(crate) fn expand(self, expected: u64) -> u64 {
781        // From Appendix A
782        use PacketNumber::*;
783        let truncated = match self {
784            U8(x) => u64::from(x),
785            U16(x) => u64::from(x),
786            U24(x) => u64::from(x),
787            U32(x) => u64::from(x),
788        };
789        let nbits = self.len() * 8;
790        let win = 1 << nbits;
791        let hwin = win / 2;
792        let mask = win - 1;
793        // The incoming packet number should be greater than expected - hwin and less than or equal
794        // to expected + hwin
795        //
796        // This means we can't just strip the trailing bits from expected and add the truncated
797        // because that might yield a value outside the window.
798        //
799        // The following code calculates a candidate value and makes sure it's within the packet
800        // number window.
801        let candidate = (expected & !mask) | truncated;
802        if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) {
803            candidate + win
804        } else if candidate > expected + hwin && candidate > win {
805            candidate - win
806        } else {
807            candidate
808        }
809    }
810}
811
812/// A [`ConnectionIdParser`] implementation that assumes the connection ID is of fixed length
813pub struct FixedLengthConnectionIdParser {
814    expected_len: usize,
815}
816
817impl FixedLengthConnectionIdParser {
818    /// Create a new instance of `FixedLengthConnectionIdParser`
819    pub fn new(expected_len: usize) -> Self {
820        Self { expected_len }
821    }
822}
823
824impl ConnectionIdParser for FixedLengthConnectionIdParser {
825    fn parse(&self, buffer: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
826        (buffer.remaining() >= self.expected_len)
827            .then(|| ConnectionId::from_buf(buffer, self.expected_len))
828            .ok_or(PacketDecodeError::InvalidHeader("packet too small"))
829    }
830}
831
832/// Parse connection id in short header packet
833pub trait ConnectionIdParser {
834    /// Parse a connection id from given buffer
835    fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
836}
837
838/// Long packet type including non-uniform cases
839#[derive(Clone, Copy, Debug, Eq, PartialEq)]
840pub(crate) enum LongHeaderType {
841    Initial,
842    Retry,
843    Standard(LongType),
844}
845
846impl LongHeaderType {
847    fn from_byte(b: u8) -> Result<Self, PacketDecodeError> {
848        use {LongHeaderType::*, LongType::*};
849        debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet");
850        Ok(match (b & 0x30) >> 4 {
851            0x0 => Initial,
852            0x1 => Standard(ZeroRtt),
853            0x2 => Standard(Handshake),
854            0x3 => Retry,
855            _ => unreachable!(),
856        })
857    }
858}
859
860impl From<LongHeaderType> for u8 {
861    fn from(ty: LongHeaderType) -> Self {
862        use {LongHeaderType::*, LongType::*};
863        match ty {
864            Initial => LONG_HEADER_FORM | FIXED_BIT,
865            Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4),
866            Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4),
867            Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4),
868        }
869    }
870}
871
/// Long-header packet types that share one header layout (no token, no Retry tag)
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LongType {
    /// A Handshake packet
    Handshake,
    /// A 0-RTT packet
    ZeroRtt,
}
880
881/// Packet decode error
882#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
883pub enum PacketDecodeError {
884    /// Packet uses a QUIC version that is not supported
885    #[error("unsupported version {version:x}")]
886    UnsupportedVersion {
887        /// Source Connection ID
888        src_cid: ConnectionId,
889        /// Destination Connection ID
890        dst_cid: ConnectionId,
891        /// The version that was unsupported
892        version: u32,
893    },
894    /// The packet header is invalid
895    #[error("invalid header: {0}")]
896    InvalidHeader(&'static str),
897}
898
899impl From<coding::UnexpectedEnd> for PacketDecodeError {
900    fn from(_: coding::UnexpectedEnd) -> Self {
901        Self::InvalidHeader("unexpected end of packet")
902    }
903}
904
/// Header form bit of the first byte: set for long headers, clear for short headers
pub(crate) const LONG_HEADER_FORM: u8 = 0b1000_0000;
/// Fixed bit of the first byte
pub(crate) const FIXED_BIT: u8 = 0b0100_0000;
/// Spin bit of a short header's first byte
pub(crate) const SPIN_BIT: u8 = 0b0010_0000;
/// Reserved bits of a short header's first byte, expected to be zero
const SHORT_RESERVED_BITS: u8 = 0b0001_1000;
/// Reserved bits of a long header's first byte, expected to be zero
const LONG_RESERVED_BITS: u8 = 0b0000_1100;
/// Key phase bit of a short header's first byte
const KEY_PHASE_BIT: u8 = 0b0000_0100;
911
/// Packet number space identifiers
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum SpaceId {
    /// Space of Initial packets, used to bootstrap the handshake
    Initial = 0,
    /// Space of Handshake packets
    Handshake = 1,
    /// Application data space, used for 0-RTT and post-handshake/1-RTT packets
    Data = 2,
}

impl SpaceId {
    /// Iterates over every packet number space, in ascending order
    pub fn iter() -> impl Iterator<Item = Self> {
        static ALL: [SpaceId; 3] = [SpaceId::Initial, SpaceId::Handshake, SpaceId::Data];
        ALL.iter().copied()
    }
}
929
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use std::io;

    /// Encodes `typed`, compares the wire bytes with `encoded`, then checks the decode round trip
    fn check_pn(typed: PacketNumber, encoded: &[u8]) {
        let mut wire = Vec::new();
        typed.encode(&mut wire);
        assert_eq!(&wire[..], encoded);
        let mut reader = io::Cursor::new(&wire);
        let decoded = PacketNumber::decode(typed.len(), &mut reader).unwrap();
        assert_eq!(typed, decoded);
    }

    #[test]
    fn roundtrip_packet_numbers() {
        check_pn(PacketNumber::U8(0x7f), &hex!("7f"));
        check_pn(PacketNumber::U16(0x80), &hex!("0080"));
        check_pn(PacketNumber::U16(0x3fff), &hex!("3fff"));
        check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000"));
        check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff"));
    }

    #[test]
    fn pn_encode() {
        check_pn(PacketNumber::new(0x10, 0), &hex!("10"));
        check_pn(PacketNumber::new(0x100, 0), &hex!("0100"));
        check_pn(PacketNumber::new(0x10000, 0), &hex!("010000"));
    }

    #[test]
    fn pn_expand_roundtrip() {
        for expected in 0..1024 {
            for actual in expected..1024 {
                let truncated = PacketNumber::new(actual, expected);
                assert_eq!(actual, truncated.expand(expected));
            }
        }
    }

    #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
    #[test]
    fn header_encoding() {
        use crate::Side;
        use crate::crypto::rustls::{initial_keys, initial_suite_from_provider};
        #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
        use rustls::crypto::aws_lc_rs::default_provider;
        #[cfg(feature = "rustls-ring")]
        use rustls::crypto::ring::default_provider;
        use rustls::quic::Version;

        let dcid = ConnectionId::new(&hex!("06b858ec6f80452b"));
        let suite = initial_suite_from_provider(&std::sync::Arc::new(default_provider())).unwrap();

        // Client side: encode and protect an Initial packet carrying 16 zero bytes.
        let client = initial_keys(Version::V1, dcid, Side::Client, &suite);
        let header = Header::Initial(InitialHeader {
            number: PacketNumber::U8(0),
            src_cid: ConnectionId::new(&[]),
            dst_cid: dcid,
            token: Bytes::new(),
            version: crate::DEFAULT_SUPPORTED_VERSIONS[0],
        });
        let mut buf = Vec::new();
        let encode = header.encode(&mut buf);
        let header_len = buf.len();
        buf.resize(header_len + 16 + client.packet.local.tag_len(), 0);
        encode.finish(
            &mut buf,
            &*client.header.local,
            Some((0, &*client.packet.local)),
        );

        let hex_dump: String = buf.iter().map(|byte| format!("{byte:02x}")).collect();
        println!("{hex_dump}");
        assert_eq!(
            buf[..],
            hex!(
                "c8000000010806b858ec6f80452b00004021be
                 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1"
            )[..]
        );

        // Server side: parse, unprotect and decrypt the same bytes.
        let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
        let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
        let (decode, _) = PartialDecode::new(
            buf.as_slice().into(),
            &FixedLengthConnectionIdParser::new(0),
            &supported_versions,
            false,
        )
        .unwrap();
        let mut packet = decode.finish(Some(&*server.header.remote)).unwrap();
        assert_eq!(
            packet.header_data[..],
            hex!("c0000000010806b858ec6f80452b0000402100")[..]
        );
        server
            .packet
            .remote
            .decrypt(0, &packet.header_data, &mut packet.payload)
            .unwrap();
        assert_eq!(packet.payload[..], [0; 16]);
        if !matches!(
            packet.header,
            Header::Initial(InitialHeader {
                number: PacketNumber::U8(0),
                ..
            })
        ) {
            panic!("unexpected header {:?}", packet.header);
        }
    }
}