iroh_quinn_proto/
packet.rs

1use std::{cmp::Ordering, io, ops::Range, str};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5
6use crate::{
7    coding::{self, BufExt, BufMutExt},
8    crypto, ConnectionId,
9};
10
11/// Decodes a QUIC packet's invariant header
12///
13/// Due to packet number encryption, it is impossible to fully decode a header
14/// (which includes a variable-length packet number) without crypto context.
15/// The crypto context (represented by the `Crypto` type in Quinn) is usually
16/// part of the `Connection`, or can be derived from the destination CID for
17/// Initial packets.
18///
19/// To cope with this, we decode the invariant header (which should be stable
20/// across QUIC versions), which gives us the destination CID and allows us
21/// to inspect the version and packet type (which depends on the version).
22/// This information allows us to fully decode and decrypt the packet.
23#[cfg_attr(test, derive(Clone))]
24#[derive(Debug)]
25pub struct PartialDecode {
26    plain_header: ProtectedHeader,
27    buf: io::Cursor<BytesMut>,
28}
29
30#[allow(clippy::len_without_is_empty)]
31impl PartialDecode {
32    /// Begin decoding a QUIC packet from `bytes`, returning any trailing data not part of that packet
33    pub fn new(
34        bytes: BytesMut,
35        cid_parser: &(impl ConnectionIdParser + ?Sized),
36        supported_versions: &[u32],
37        grease_quic_bit: bool,
38    ) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
39        let mut buf = io::Cursor::new(bytes);
40        let plain_header =
41            ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?;
42        let dgram_len = buf.get_ref().len();
43        let packet_len = plain_header
44            .payload_len()
45            .map(|len| (buf.position() + len) as usize)
46            .unwrap_or(dgram_len);
47        match dgram_len.cmp(&packet_len) {
48            Ordering::Equal => Ok((Self { plain_header, buf }, None)),
49            Ordering::Less => Err(PacketDecodeError::InvalidHeader(
50                "packet too short to contain payload length",
51            )),
52            Ordering::Greater => {
53                let rest = Some(buf.get_mut().split_off(packet_len));
54                Ok((Self { plain_header, buf }, rest))
55            }
56        }
57    }
58
59    /// The underlying partially-decoded packet data
60    pub(crate) fn data(&self) -> &[u8] {
61        self.buf.get_ref()
62    }
63
64    pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> {
65        self.plain_header.as_initial()
66    }
67
68    pub(crate) fn has_long_header(&self) -> bool {
69        !matches!(self.plain_header, ProtectedHeader::Short { .. })
70    }
71
72    pub(crate) fn is_initial(&self) -> bool {
73        self.space() == Some(SpaceId::Initial)
74    }
75
76    pub(crate) fn space(&self) -> Option<SpaceId> {
77        use ProtectedHeader::*;
78        match self.plain_header {
79            Initial { .. } => Some(SpaceId::Initial),
80            Long {
81                ty: LongType::Handshake,
82                ..
83            } => Some(SpaceId::Handshake),
84            Long {
85                ty: LongType::ZeroRtt,
86                ..
87            } => Some(SpaceId::Data),
88            Short { .. } => Some(SpaceId::Data),
89            _ => None,
90        }
91    }
92
93    pub(crate) fn is_0rtt(&self) -> bool {
94        match self.plain_header {
95            ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt,
96            _ => false,
97        }
98    }
99
100    /// The destination connection ID of the packet
101    pub fn dst_cid(&self) -> &ConnectionId {
102        self.plain_header.dst_cid()
103    }
104
105    /// Length of QUIC packet being decoded
106    #[allow(unreachable_pub)] // fuzzing only
107    pub fn len(&self) -> usize {
108        self.buf.get_ref().len()
109    }
110
111    pub(crate) fn finish(
112        self,
113        header_crypto: Option<&dyn crypto::HeaderKey>,
114    ) -> Result<Packet, PacketDecodeError> {
115        use ProtectedHeader::*;
116        let Self {
117            plain_header,
118            mut buf,
119        } = self;
120
121        if let Initial(ProtectedInitialHeader {
122            dst_cid,
123            src_cid,
124            token_pos,
125            version,
126            ..
127        }) = plain_header
128        {
129            let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
130            let header_len = buf.position() as usize;
131            let mut bytes = buf.into_inner();
132
133            let header_data = bytes.split_to(header_len).freeze();
134            let token = header_data.slice(token_pos.start..token_pos.end);
135            return Ok(Packet {
136                header: Header::Initial(InitialHeader {
137                    dst_cid,
138                    src_cid,
139                    token,
140                    number,
141                    version,
142                }),
143                header_data,
144                payload: bytes,
145            });
146        }
147
148        let header = match plain_header {
149            Long {
150                ty,
151                dst_cid,
152                src_cid,
153                version,
154                ..
155            } => Header::Long {
156                ty,
157                dst_cid,
158                src_cid,
159                number: Self::decrypt_header(&mut buf, header_crypto.unwrap())?,
160                version,
161            },
162            Retry {
163                dst_cid,
164                src_cid,
165                version,
166            } => Header::Retry {
167                dst_cid,
168                src_cid,
169                version,
170            },
171            Short { spin, dst_cid, .. } => {
172                let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
173                let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0;
174                Header::Short {
175                    spin,
176                    key_phase,
177                    dst_cid,
178                    number,
179                }
180            }
181            VersionNegotiate {
182                random,
183                dst_cid,
184                src_cid,
185            } => Header::VersionNegotiate {
186                random,
187                dst_cid,
188                src_cid,
189            },
190            Initial { .. } => unreachable!(),
191        };
192
193        let header_len = buf.position() as usize;
194        let mut bytes = buf.into_inner();
195        Ok(Packet {
196            header,
197            header_data: bytes.split_to(header_len).freeze(),
198            payload: bytes,
199        })
200    }
201
202    fn decrypt_header(
203        buf: &mut io::Cursor<BytesMut>,
204        header_crypto: &dyn crypto::HeaderKey,
205    ) -> Result<PacketNumber, PacketDecodeError> {
206        let packet_length = buf.get_ref().len();
207        let pn_offset = buf.position() as usize;
208        if packet_length < pn_offset + 4 + header_crypto.sample_size() {
209            return Err(PacketDecodeError::InvalidHeader(
210                "packet too short to extract header protection sample",
211            ));
212        }
213
214        header_crypto.decrypt(pn_offset, buf.get_mut());
215
216        let len = PacketNumber::decode_len(buf.get_ref()[0]);
217        PacketNumber::decode(len, buf)
218    }
219}
220
221pub(crate) struct Packet {
222    pub(crate) header: Header,
223    pub(crate) header_data: Bytes,
224    pub(crate) payload: BytesMut,
225}
226
227impl Packet {
228    pub(crate) fn reserved_bits_valid(&self) -> bool {
229        let mask = match self.header {
230            Header::Short { .. } => SHORT_RESERVED_BITS,
231            _ => LONG_RESERVED_BITS,
232        };
233        self.header_data[0] & mask == 0
234    }
235}
236
237pub(crate) struct InitialPacket {
238    pub(crate) header: InitialHeader,
239    pub(crate) header_data: Bytes,
240    pub(crate) payload: BytesMut,
241}
242
243impl From<InitialPacket> for Packet {
244    fn from(x: InitialPacket) -> Self {
245        Self {
246            header: Header::Initial(x.header),
247            header_data: x.header_data,
248            payload: x.payload,
249        }
250    }
251}
252
253#[cfg_attr(test, derive(Clone))]
254#[derive(Debug)]
255pub(crate) enum Header {
256    Initial(InitialHeader),
257    Long {
258        ty: LongType,
259        dst_cid: ConnectionId,
260        src_cid: ConnectionId,
261        number: PacketNumber,
262        version: u32,
263    },
264    Retry {
265        dst_cid: ConnectionId,
266        src_cid: ConnectionId,
267        version: u32,
268    },
269    Short {
270        spin: bool,
271        key_phase: bool,
272        dst_cid: ConnectionId,
273        number: PacketNumber,
274    },
275    VersionNegotiate {
276        random: u8,
277        src_cid: ConnectionId,
278        dst_cid: ConnectionId,
279    },
280}
281
282impl Header {
283    pub(crate) fn encode(&self, w: &mut Vec<u8>) -> PartialEncode {
284        use Header::*;
285        let start = w.len();
286        match *self {
287            Initial(InitialHeader {
288                ref dst_cid,
289                ref src_cid,
290                ref token,
291                number,
292                version,
293            }) => {
294                w.write(u8::from(LongHeaderType::Initial) | number.tag());
295                w.write(version);
296                dst_cid.encode_long(w);
297                src_cid.encode_long(w);
298                w.write_var(token.len() as u64);
299                w.put_slice(token);
300                w.write::<u16>(0); // Placeholder for payload length; see `set_payload_length`
301                number.encode(w);
302                PartialEncode {
303                    start,
304                    header_len: w.len() - start,
305                    pn: Some((number.len(), true)),
306                }
307            }
308            Long {
309                ty,
310                ref dst_cid,
311                ref src_cid,
312                number,
313                version,
314            } => {
315                w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag());
316                w.write(version);
317                dst_cid.encode_long(w);
318                src_cid.encode_long(w);
319                w.write::<u16>(0); // Placeholder for payload length; see `set_payload_length`
320                number.encode(w);
321                PartialEncode {
322                    start,
323                    header_len: w.len() - start,
324                    pn: Some((number.len(), true)),
325                }
326            }
327            Retry {
328                ref dst_cid,
329                ref src_cid,
330                version,
331            } => {
332                w.write(u8::from(LongHeaderType::Retry));
333                w.write(version);
334                dst_cid.encode_long(w);
335                src_cid.encode_long(w);
336                PartialEncode {
337                    start,
338                    header_len: w.len() - start,
339                    pn: None,
340                }
341            }
342            Short {
343                spin,
344                key_phase,
345                ref dst_cid,
346                number,
347            } => {
348                w.write(
349                    FIXED_BIT
350                        | if key_phase { KEY_PHASE_BIT } else { 0 }
351                        | if spin { SPIN_BIT } else { 0 }
352                        | number.tag(),
353                );
354                w.put_slice(dst_cid);
355                number.encode(w);
356                PartialEncode {
357                    start,
358                    header_len: w.len() - start,
359                    pn: Some((number.len(), false)),
360                }
361            }
362            VersionNegotiate {
363                ref random,
364                ref dst_cid,
365                ref src_cid,
366            } => {
367                w.write(0x80u8 | random);
368                w.write::<u32>(0);
369                dst_cid.encode_long(w);
370                src_cid.encode_long(w);
371                PartialEncode {
372                    start,
373                    header_len: w.len() - start,
374                    pn: None,
375                }
376            }
377        }
378    }
379
380    /// Whether the packet is encrypted on the wire
381    pub(crate) fn is_protected(&self) -> bool {
382        !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. })
383    }
384
385    pub(crate) fn number(&self) -> Option<PacketNumber> {
386        use Header::*;
387        Some(match *self {
388            Initial(InitialHeader { number, .. }) => number,
389            Long { number, .. } => number,
390            Short { number, .. } => number,
391            _ => {
392                return None;
393            }
394        })
395    }
396
397    pub(crate) fn space(&self) -> SpaceId {
398        use Header::*;
399        match *self {
400            Short { .. } => SpaceId::Data,
401            Long {
402                ty: LongType::ZeroRtt,
403                ..
404            } => SpaceId::Data,
405            Long {
406                ty: LongType::Handshake,
407                ..
408            } => SpaceId::Handshake,
409            _ => SpaceId::Initial,
410        }
411    }
412
413    pub(crate) fn key_phase(&self) -> bool {
414        match *self {
415            Self::Short { key_phase, .. } => key_phase,
416            _ => false,
417        }
418    }
419
420    pub(crate) fn is_short(&self) -> bool {
421        matches!(*self, Self::Short { .. })
422    }
423
424    pub(crate) fn is_1rtt(&self) -> bool {
425        self.is_short()
426    }
427
428    pub(crate) fn is_0rtt(&self) -> bool {
429        matches!(
430            *self,
431            Self::Long {
432                ty: LongType::ZeroRtt,
433                ..
434            }
435        )
436    }
437
438    pub(crate) fn dst_cid(&self) -> ConnectionId {
439        use Header::*;
440        match *self {
441            Initial(InitialHeader { dst_cid, .. }) => dst_cid,
442            Long { dst_cid, .. } => dst_cid,
443            Retry { dst_cid, .. } => dst_cid,
444            Short { dst_cid, .. } => dst_cid,
445            VersionNegotiate { dst_cid, .. } => dst_cid,
446        }
447    }
448
449    /// Whether the payload of this packet contains QUIC frames
450    pub(crate) fn has_frames(&self) -> bool {
451        use Header::*;
452        match *self {
453            Initial(_) => true,
454            Long { .. } => true,
455            Retry { .. } => false,
456            Short { .. } => true,
457            VersionNegotiate { .. } => false,
458        }
459    }
460}
461
/// A header that has been written but still needs its length field patched and
/// protection applied by `PartialEncode::finish`.
pub(crate) struct PartialEncode {
    /// Offset in the output buffer at which the header starts
    pub(crate) start: usize,
    /// Number of header bytes written, including the packet number
    pub(crate) header_len: usize,
    /// Packet number length and whether the payload length field must be written;
    /// `None` for unprotected headers
    pn: Option<(usize, bool)>,
}
468
469impl PartialEncode {
470    pub(crate) fn finish(
471        self,
472        buf: &mut [u8],
473        header_crypto: &dyn crypto::HeaderKey,
474        crypto: Option<(u64, &dyn crypto::PacketKey)>,
475    ) {
476        let Self { header_len, pn, .. } = self;
477        let (pn_len, write_len) = match pn {
478            Some((pn_len, write_len)) => (pn_len, write_len),
479            None => return,
480        };
481
482        let pn_pos = header_len - pn_len;
483        if write_len {
484            let len = buf.len() - header_len + pn_len;
485            assert!(len < 2usize.pow(14)); // Fits in reserved space
486            let mut slice = &mut buf[pn_pos - 2..pn_pos];
487            slice.put_u16(len as u16 | 0b01 << 14);
488        }
489
490        if let Some((number, crypto)) = crypto {
491            crypto.encrypt(number, buf, header_len);
492        }
493
494        debug_assert!(
495            pn_pos + 4 + header_crypto.sample_size() <= buf.len(),
496            "packet must be padded to at least {} bytes for header protection sampling",
497            pn_pos + 4 + header_crypto.sample_size()
498        );
499        header_crypto.encrypt(pn_pos, buf);
500    }
501}
502
503/// Plain packet header
504#[derive(Clone, Debug)]
505pub enum ProtectedHeader {
506    /// An Initial packet header
507    Initial(ProtectedInitialHeader),
508    /// A Long packet header, as used during the handshake
509    Long {
510        /// Type of the Long header packet
511        ty: LongType,
512        /// Destination Connection ID
513        dst_cid: ConnectionId,
514        /// Source Connection ID
515        src_cid: ConnectionId,
516        /// Length of the packet payload
517        len: u64,
518        /// QUIC version
519        version: u32,
520    },
521    /// A Retry packet header
522    Retry {
523        /// Destination Connection ID
524        dst_cid: ConnectionId,
525        /// Source Connection ID
526        src_cid: ConnectionId,
527        /// QUIC version
528        version: u32,
529    },
530    /// A short packet header, as used during the data phase
531    Short {
532        /// Spin bit
533        spin: bool,
534        /// Destination Connection ID
535        dst_cid: ConnectionId,
536    },
537    /// A Version Negotiation packet header
538    VersionNegotiate {
539        /// Random value
540        random: u8,
541        /// Destination Connection ID
542        dst_cid: ConnectionId,
543        /// Source Connection ID
544        src_cid: ConnectionId,
545    },
546}
547
548impl ProtectedHeader {
549    fn as_initial(&self) -> Option<&ProtectedInitialHeader> {
550        match self {
551            Self::Initial(x) => Some(x),
552            _ => None,
553        }
554    }
555
556    /// The destination Connection ID of the packet
557    pub fn dst_cid(&self) -> &ConnectionId {
558        use ProtectedHeader::*;
559        match self {
560            Initial(header) => &header.dst_cid,
561            Long { dst_cid, .. } => dst_cid,
562            Retry { dst_cid, .. } => dst_cid,
563            Short { dst_cid, .. } => dst_cid,
564            VersionNegotiate { dst_cid, .. } => dst_cid,
565        }
566    }
567
568    fn payload_len(&self) -> Option<u64> {
569        use ProtectedHeader::*;
570        match self {
571            Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
572            _ => None,
573        }
574    }
575
576    /// Decode a plain header from given buffer, with given [`ConnectionIdParser`].
577    pub fn decode(
578        buf: &mut io::Cursor<BytesMut>,
579        cid_parser: &(impl ConnectionIdParser + ?Sized),
580        supported_versions: &[u32],
581        grease_quic_bit: bool,
582    ) -> Result<Self, PacketDecodeError> {
583        let first = buf.get::<u8>()?;
584        if !grease_quic_bit && first & FIXED_BIT == 0 {
585            return Err(PacketDecodeError::InvalidHeader("fixed bit unset"));
586        }
587        if first & LONG_HEADER_FORM == 0 {
588            let spin = first & SPIN_BIT != 0;
589
590            Ok(Self::Short {
591                spin,
592                dst_cid: cid_parser.parse(buf)?,
593            })
594        } else {
595            let version = buf.get::<u32>()?;
596
597            let dst_cid = ConnectionId::decode_long(buf)
598                .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
599            let src_cid = ConnectionId::decode_long(buf)
600                .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
601
602            // TODO: Support long CIDs for compatibility with future QUIC versions
603            if version == 0 {
604                let random = first & !LONG_HEADER_FORM;
605                return Ok(Self::VersionNegotiate {
606                    random,
607                    dst_cid,
608                    src_cid,
609                });
610            }
611
612            if !supported_versions.contains(&version) {
613                return Err(PacketDecodeError::UnsupportedVersion {
614                    src_cid,
615                    dst_cid,
616                    version,
617                });
618            }
619
620            match LongHeaderType::from_byte(first)? {
621                LongHeaderType::Initial => {
622                    let token_len = buf.get_var()? as usize;
623                    let token_start = buf.position() as usize;
624                    if token_len > buf.remaining() {
625                        return Err(PacketDecodeError::InvalidHeader("token out of bounds"));
626                    }
627                    buf.advance(token_len);
628
629                    let len = buf.get_var()?;
630                    Ok(Self::Initial(ProtectedInitialHeader {
631                        dst_cid,
632                        src_cid,
633                        token_pos: token_start..token_start + token_len,
634                        len,
635                        version,
636                    }))
637                }
638                LongHeaderType::Retry => Ok(Self::Retry {
639                    dst_cid,
640                    src_cid,
641                    version,
642                }),
643                LongHeaderType::Standard(ty) => Ok(Self::Long {
644                    ty,
645                    dst_cid,
646                    src_cid,
647                    len: buf.get_var()?,
648                    version,
649                }),
650            }
651        }
652    }
653}
654
655/// Header of an Initial packet, before decryption
656#[derive(Clone, Debug)]
657pub struct ProtectedInitialHeader {
658    /// Destination Connection ID
659    pub dst_cid: ConnectionId,
660    /// Source Connection ID
661    pub src_cid: ConnectionId,
662    /// The position of a token in the packet buffer
663    pub token_pos: Range<usize>,
664    /// Length of the packet payload
665    pub len: u64,
666    /// QUIC version
667    pub version: u32,
668}
669
670#[derive(Clone, Debug)]
671pub(crate) struct InitialHeader {
672    pub(crate) dst_cid: ConnectionId,
673    pub(crate) src_cid: ConnectionId,
674    pub(crate) token: Bytes,
675    pub(crate) number: PacketNumber,
676    pub(crate) version: u32,
677}
678
/// A packet number truncated to the 1-4 bytes that are sent on the wire
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum PacketNumber {
    /// One-byte encoding
    U8(u8),
    /// Two-byte encoding
    U16(u16),
    /// Three-byte encoding, held in the low 24 bits
    U24(u32),
    /// Four-byte encoding
    U32(u32),
}
687
688impl PacketNumber {
689    pub(crate) fn new(n: u64, largest_acked: u64) -> Self {
690        let range = (n - largest_acked) * 2;
691        if range < 1 << 8 {
692            Self::U8(n as u8)
693        } else if range < 1 << 16 {
694            Self::U16(n as u16)
695        } else if range < 1 << 24 {
696            Self::U24(n as u32)
697        } else if range < 1 << 32 {
698            Self::U32(n as u32)
699        } else {
700            panic!("packet number too large to encode")
701        }
702    }
703
704    pub(crate) fn len(self) -> usize {
705        use PacketNumber::*;
706        match self {
707            U8(_) => 1,
708            U16(_) => 2,
709            U24(_) => 3,
710            U32(_) => 4,
711        }
712    }
713
714    pub(crate) fn encode<W: BufMut>(self, w: &mut W) {
715        use PacketNumber::*;
716        match self {
717            U8(x) => w.write(x),
718            U16(x) => w.write(x),
719            U24(x) => w.put_uint(u64::from(x), 3),
720            U32(x) => w.write(x),
721        }
722    }
723
724    pub(crate) fn decode<R: Buf>(len: usize, r: &mut R) -> Result<Self, PacketDecodeError> {
725        use PacketNumber::*;
726        let pn = match len {
727            1 => U8(r.get()?),
728            2 => U16(r.get()?),
729            3 => U24(r.get_uint(3) as u32),
730            4 => U32(r.get()?),
731            _ => unreachable!(),
732        };
733        Ok(pn)
734    }
735
736    pub(crate) fn decode_len(tag: u8) -> usize {
737        1 + (tag & 0x03) as usize
738    }
739
740    fn tag(self) -> u8 {
741        use PacketNumber::*;
742        match self {
743            U8(_) => 0b00,
744            U16(_) => 0b01,
745            U24(_) => 0b10,
746            U32(_) => 0b11,
747        }
748    }
749
750    pub(crate) fn expand(self, expected: u64) -> u64 {
751        // From Appendix A
752        use PacketNumber::*;
753        let truncated = match self {
754            U8(x) => u64::from(x),
755            U16(x) => u64::from(x),
756            U24(x) => u64::from(x),
757            U32(x) => u64::from(x),
758        };
759        let nbits = self.len() * 8;
760        let win = 1 << nbits;
761        let hwin = win / 2;
762        let mask = win - 1;
763        // The incoming packet number should be greater than expected - hwin and less than or equal
764        // to expected + hwin
765        //
766        // This means we can't just strip the trailing bits from expected and add the truncated
767        // because that might yield a value outside the window.
768        //
769        // The following code calculates a candidate value and makes sure it's within the packet
770        // number window.
771        let candidate = (expected & !mask) | truncated;
772        if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) {
773            candidate + win
774        } else if candidate > expected + hwin && candidate > win {
775            candidate - win
776        } else {
777            candidate
778        }
779    }
780}
781
/// A [`ConnectionIdParser`] implementation that assumes the connection ID is of fixed length
pub struct FixedLengthConnectionIdParser {
    /// Number of bytes read for every connection ID
    expected_len: usize,
}
786
787impl FixedLengthConnectionIdParser {
788    /// Create a new instance of `FixedLengthConnectionIdParser`
789    pub fn new(expected_len: usize) -> Self {
790        Self { expected_len }
791    }
792}
793
794impl ConnectionIdParser for FixedLengthConnectionIdParser {
795    fn parse(&self, buffer: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
796        (buffer.remaining() >= self.expected_len)
797            .then(|| ConnectionId::from_buf(buffer, self.expected_len))
798            .ok_or(PacketDecodeError::InvalidHeader("packet too small"))
799    }
800}
801
802/// Parse connection id in short header packet
803pub trait ConnectionIdParser {
804    /// Parse a connection id from given buffer
805    fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
806}
807
808/// Long packet type including non-uniform cases
809#[derive(Clone, Copy, Debug, Eq, PartialEq)]
810pub(crate) enum LongHeaderType {
811    Initial,
812    Retry,
813    Standard(LongType),
814}
815
816impl LongHeaderType {
817    fn from_byte(b: u8) -> Result<Self, PacketDecodeError> {
818        use {LongHeaderType::*, LongType::*};
819        debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet");
820        Ok(match (b & 0x30) >> 4 {
821            0x0 => Initial,
822            0x1 => Standard(ZeroRtt),
823            0x2 => Standard(Handshake),
824            0x3 => Retry,
825            _ => unreachable!(),
826        })
827    }
828}
829
830impl From<LongHeaderType> for u8 {
831    fn from(ty: LongHeaderType) -> Self {
832        use {LongHeaderType::*, LongType::*};
833        match ty {
834            Initial => LONG_HEADER_FORM | FIXED_BIT,
835            Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4),
836            Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4),
837            Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4),
838        }
839    }
840}
841
/// Long packet types that share the same header layout
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LongType {
    /// Handshake packet
    Handshake,
    /// 0-RTT packet
    ZeroRtt,
}
850
851/// Packet decode error
852#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
853pub enum PacketDecodeError {
854    /// Packet uses a QUIC version that is not supported
855    #[error("unsupported version {version:x}")]
856    UnsupportedVersion {
857        /// Source Connection ID
858        src_cid: ConnectionId,
859        /// Destination Connection ID
860        dst_cid: ConnectionId,
861        /// The version that was unsupported
862        version: u32,
863    },
864    /// The packet header is invalid
865    #[error("invalid header: {0}")]
866    InvalidHeader(&'static str),
867}
868
869impl From<coding::UnexpectedEnd> for PacketDecodeError {
870    fn from(_: coding::UnexpectedEnd) -> Self {
871        Self::InvalidHeader("unexpected end of packet")
872    }
873}
874
/// Header form bit: set for long headers, clear for short headers
pub(crate) const LONG_HEADER_FORM: u8 = 0b1000_0000;
/// Fixed bit; packets with it clear are rejected unless QUIC bit greasing is allowed
pub(crate) const FIXED_BIT: u8 = 0b0100_0000;
/// Latency spin bit of short headers
pub(crate) const SPIN_BIT: u8 = 0b0010_0000;
/// Reserved bits of a short header, expected to be zero once header protection is removed
const SHORT_RESERVED_BITS: u8 = 0b0001_1000;
/// Reserved bits of a long header, expected to be zero once header protection is removed
const LONG_RESERVED_BITS: u8 = 0b0000_1100;
/// Key phase bit of short headers
const KEY_PHASE_BIT: u8 = 0b0000_0100;
881
/// Packet number space identifiers
///
/// Variants are ordered by handshake progression, and their discriminants are `0..=2`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum SpaceId {
    /// Initial packets, used to bootstrap the handshake; their keys are derived from the
    /// destination connection ID rather than negotiated
    Initial = 0,
    /// Handshake packets, protected with keys established during the handshake
    Handshake = 1,
    /// Application data space, used for 0-RTT and post-handshake/1-RTT packets
    Data = 2,
}
891
892impl SpaceId {
893    pub fn iter() -> impl Iterator<Item = Self> {
894        [Self::Initial, Self::Handshake, Self::Data].iter().cloned()
895    }
896}
897
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use std::io;

    /// Encode `pn`, compare the wire bytes with `expected`, then decode them back.
    fn check_pn(pn: PacketNumber, expected: &[u8]) {
        let mut wire = Vec::new();
        pn.encode(&mut wire);
        assert_eq!(&wire[..], expected);
        let mut reader = io::Cursor::new(&wire);
        let decoded = PacketNumber::decode(pn.len(), &mut reader).unwrap();
        assert_eq!(pn, decoded);
    }

    #[test]
    fn roundtrip_packet_numbers() {
        check_pn(PacketNumber::U8(0x7f), &hex!("7f"));
        check_pn(PacketNumber::U16(0x80), &hex!("0080"));
        check_pn(PacketNumber::U16(0x3fff), &hex!("3fff"));
        check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000"));
        check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff"));
    }

    #[test]
    fn pn_encode() {
        // With nothing acknowledged yet, the width grows with the packet number itself.
        check_pn(PacketNumber::new(0x10, 0), &hex!("10"));
        check_pn(PacketNumber::new(0x100, 0), &hex!("0100"));
        check_pn(PacketNumber::new(0x10000, 0), &hex!("010000"));
    }

    #[test]
    fn pn_expand_roundtrip() {
        let pairs = (0..1024u64)
            .flat_map(|expected| (expected..1024u64).map(move |actual| (expected, actual)));
        for (expected, actual) in pairs {
            assert_eq!(actual, PacketNumber::new(actual, expected).expand(expected));
        }
    }

    #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
    #[test]
    fn header_encoding() {
        use crate::crypto::rustls::{initial_keys, initial_suite_from_provider};
        use crate::Side;
        #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
        use rustls::crypto::aws_lc_rs::default_provider;
        #[cfg(feature = "rustls-ring")]
        use rustls::crypto::ring::default_provider;
        use rustls::quic::Version;

        let dcid = ConnectionId::new(&hex!("06b858ec6f80452b"));
        let provider = default_provider();

        let suite = initial_suite_from_provider(&std::sync::Arc::new(provider)).unwrap();
        let client = initial_keys(Version::V1, dcid, Side::Client, &suite);

        // Client side: encode, pad and protect an Initial packet.
        let mut buf = Vec::new();
        let header = Header::Initial(InitialHeader {
            number: PacketNumber::U8(0),
            src_cid: ConnectionId::new(&[]),
            dst_cid: dcid,
            token: Bytes::new(),
            version: crate::DEFAULT_SUPPORTED_VERSIONS[0],
        });
        let encode = header.encode(&mut buf);
        let header_len = buf.len();
        buf.resize(header_len + 16 + client.packet.local.tag_len(), 0);
        encode.finish(
            &mut buf,
            &*client.header.local,
            Some((0, &*client.packet.local)),
        );

        let dump: String = buf.iter().map(|byte| format!("{byte:02x}")).collect();
        println!("{dump}");
        assert_eq!(
            buf[..],
            hex!(
                "c8000000010806b858ec6f80452b00004021be
                 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1"
            )[..]
        );

        // Server side: strip header protection and decrypt the payload again.
        let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
        let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
        let (decode, _) = PartialDecode::new(
            buf.as_slice().into(),
            &FixedLengthConnectionIdParser::new(0),
            &supported_versions,
            false,
        )
        .unwrap();
        let mut packet = decode.finish(Some(&*server.header.remote)).unwrap();
        assert_eq!(
            packet.header_data[..],
            hex!("c0000000010806b858ec6f80452b0000402100")[..]
        );
        server
            .packet
            .remote
            .decrypt(0, &packet.header_data, &mut packet.payload)
            .unwrap();
        assert_eq!(packet.payload[..], [0; 16]);
        assert!(
            matches!(
                packet.header,
                Header::Initial(InitialHeader {
                    number: PacketNumber::U8(0),
                    ..
                })
            ),
            "unexpected header {:?}",
            packet.header
        );
    }
}