1use std::{cmp::Ordering, io, ops::Range, str};
9
10use bytes::{Buf, BufMut, Bytes, BytesMut};
11use thiserror::Error;
12
13use crate::{
14 ConnectionId,
15 coding::{self, BufExt, BufMutExt},
16 crypto,
17};
18
/// A packet whose unprotected header fields have been parsed, but whose header protection has not
/// yet been removed.
///
/// Built by `PartialDecode::new` and converted into a `Packet` by `PartialDecode::finish`.
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub struct PartialDecode {
    /// Header fields obtained from `ProtectedHeader::decode`
    plain_header: ProtectedHeader,
    /// Bytes of this packet; the cursor sits just past the header fields decoded so far
    buf: io::Cursor<BytesMut>,
}
37
38#[allow(clippy::len_without_is_empty)]
39impl PartialDecode {
40 pub fn new(
42 bytes: BytesMut,
43 cid_parser: &(impl ConnectionIdParser + ?Sized),
44 supported_versions: &[u32],
45 grease_quic_bit: bool,
46 ) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
47 let mut buf = io::Cursor::new(bytes);
48 let plain_header =
49 ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?;
50 let dgram_len = buf.get_ref().len();
51 let packet_len = plain_header
52 .payload_len()
53 .map(|len| (buf.position() + len) as usize)
54 .unwrap_or(dgram_len);
55 match dgram_len.cmp(&packet_len) {
56 Ordering::Equal => Ok((Self { plain_header, buf }, None)),
57 Ordering::Less => Err(PacketDecodeError::InvalidHeader(
58 "packet too short to contain payload length",
59 )),
60 Ordering::Greater => {
61 let rest = Some(buf.get_mut().split_off(packet_len));
62 Ok((Self { plain_header, buf }, rest))
63 }
64 }
65 }
66
67 pub(crate) fn data(&self) -> &[u8] {
69 self.buf.get_ref()
70 }
71
72 pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> {
73 self.plain_header.as_initial()
74 }
75
76 pub(crate) fn has_long_header(&self) -> bool {
77 !matches!(self.plain_header, ProtectedHeader::Short { .. })
78 }
79
80 pub(crate) fn is_initial(&self) -> bool {
81 self.space() == Some(SpaceId::Initial)
82 }
83
84 pub(crate) fn space(&self) -> Option<SpaceId> {
85 use ProtectedHeader::*;
86 match self.plain_header {
87 Initial { .. } => Some(SpaceId::Initial),
88 Long {
89 ty: LongType::Handshake,
90 ..
91 } => Some(SpaceId::Handshake),
92 Long {
93 ty: LongType::ZeroRtt,
94 ..
95 } => Some(SpaceId::Data),
96 Short { .. } => Some(SpaceId::Data),
97 _ => None,
98 }
99 }
100
101 pub(crate) fn is_0rtt(&self) -> bool {
102 match self.plain_header {
103 ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt,
104 _ => false,
105 }
106 }
107
108 pub fn dst_cid(&self) -> &ConnectionId {
110 self.plain_header.dst_cid()
111 }
112
113 pub fn len(&self) -> usize {
115 self.buf.get_ref().len()
116 }
117
118 pub(crate) fn finish(
119 self,
120 header_crypto: Option<&dyn crypto::HeaderKey>,
121 ) -> Result<Packet, PacketDecodeError> {
122 use ProtectedHeader::*;
123 let Self {
124 plain_header,
125 mut buf,
126 } = self;
127
128 if let Initial(ProtectedInitialHeader {
129 dst_cid,
130 src_cid,
131 token_pos,
132 version,
133 ..
134 }) = plain_header
135 {
136 let number = match header_crypto {
137 Some(crypto) => Self::decrypt_header(&mut buf, crypto)?,
138 None => {
139 return Err(PacketDecodeError::InvalidHeader(
140 "header crypto should be available".into(),
141 ));
142 }
143 };
144 let header_len = buf.position() as usize;
145 let mut bytes = buf.into_inner();
146
147 let header_data = bytes.split_to(header_len).freeze();
148 let token = header_data.slice(token_pos.start..token_pos.end);
149 return Ok(Packet {
150 header: Header::Initial(InitialHeader {
151 dst_cid,
152 src_cid,
153 token,
154 number,
155 version,
156 }),
157 header_data,
158 payload: bytes,
159 });
160 }
161
162 let header = match plain_header {
163 Long {
164 ty,
165 dst_cid,
166 src_cid,
167 version,
168 ..
169 } => Header::Long {
170 ty,
171 dst_cid,
172 src_cid,
173 number: match header_crypto {
174 Some(crypto) => Self::decrypt_header(&mut buf, crypto)?,
175 None => {
176 return Err(PacketDecodeError::InvalidHeader(
177 "header crypto should be available for long header packet".into(),
178 ));
179 }
180 },
181 version,
182 },
183 Retry {
184 dst_cid,
185 src_cid,
186 version,
187 } => Header::Retry {
188 dst_cid,
189 src_cid,
190 version,
191 },
192 Short { spin, dst_cid, .. } => {
193 let number = match header_crypto {
194 Some(crypto) => Self::decrypt_header(&mut buf, crypto)?,
195 None => {
196 return Err(PacketDecodeError::InvalidHeader(
197 "header crypto should be available for initial packet".into(),
198 ));
199 }
200 };
201 let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0;
202 Header::Short {
203 spin,
204 key_phase,
205 dst_cid,
206 number,
207 }
208 }
209 VersionNegotiate {
210 random,
211 dst_cid,
212 src_cid,
213 } => Header::VersionNegotiate {
214 random,
215 dst_cid,
216 src_cid,
217 },
218 Initial { .. } => unreachable!(),
219 };
220
221 let header_len = buf.position() as usize;
222 let mut bytes = buf.into_inner();
223 Ok(Packet {
224 header,
225 header_data: bytes.split_to(header_len).freeze(),
226 payload: bytes,
227 })
228 }
229
230 fn decrypt_header(
231 buf: &mut io::Cursor<BytesMut>,
232 header_crypto: &dyn crypto::HeaderKey,
233 ) -> Result<PacketNumber, PacketDecodeError> {
234 let packet_length = buf.get_ref().len();
235 let pn_offset = buf.position() as usize;
236 if packet_length < pn_offset + 4 + header_crypto.sample_size() {
237 return Err(PacketDecodeError::InvalidHeader(
238 "packet too short to extract header protection sample",
239 ));
240 }
241
242 header_crypto.decrypt(pn_offset, buf.get_mut());
243
244 let len = PacketNumber::decode_len(buf.get_ref()[0]);
245 PacketNumber::decode(len, buf)
246 }
247}
248
/// A packet whose header has been fully decoded, split into header bytes and payload.
pub(crate) struct Packet {
    /// Decoded header fields
    pub(crate) header: Header,
    /// Raw bytes of the header, as split off the front of the packet by `PartialDecode::finish`
    pub(crate) header_data: Bytes,
    /// The bytes of the packet that follow `header_data`
    pub(crate) payload: BytesMut,
}
254
255impl Packet {
256 pub(crate) fn reserved_bits_valid(&self) -> bool {
257 let mask = match self.header {
258 Header::Short { .. } => SHORT_RESERVED_BITS,
259 _ => LONG_RESERVED_BITS,
260 };
261 self.header_data[0] & mask == 0
262 }
263}
264
/// A decoded packet that is statically known to carry an Initial header.
///
/// Converts into a general `Packet` through `From`.
pub(crate) struct InitialPacket {
    /// Decoded Initial header fields
    pub(crate) header: InitialHeader,
    /// Raw bytes of the header
    pub(crate) header_data: Bytes,
    /// The bytes of the packet that follow the header
    pub(crate) payload: BytesMut,
}
270
271impl From<InitialPacket> for Packet {
272 fn from(x: InitialPacket) -> Self {
273 Self {
274 header: Header::Initial(x.header),
275 header_data: x.header_data,
276 payload: x.payload,
277 }
278 }
279}
280
/// A fully decoded (or to-be-encoded) packet header, one variant per packet type.
#[cfg_attr(test, derive(Clone))]
#[derive(Debug)]
pub(crate) enum Header {
    /// Long-header Initial packet; the only variant that carries a token
    Initial(InitialHeader),
    /// Long-header packet of type `ty` (Handshake or 0-RTT)
    Long {
        ty: LongType,
        dst_cid: ConnectionId,
        src_cid: ConnectionId,
        number: PacketNumber,
        version: u32,
    },
    /// Retry packet; has no packet number
    Retry {
        dst_cid: ConnectionId,
        src_cid: ConnectionId,
        version: u32,
    },
    /// Short-header packet
    Short {
        spin: bool,
        key_phase: bool,
        dst_cid: ConnectionId,
        number: PacketNumber,
    },
    /// Version Negotiation packet (encoded with a version field of zero); has no packet number
    VersionNegotiate {
        /// Bits of the first byte other than the long-header form bit
        random: u8,
        src_cid: ConnectionId,
        dst_cid: ConnectionId,
    },
}
309
310impl Header {
311 pub(crate) fn encode(&self, w: &mut Vec<u8>) -> PartialEncode {
312 use Header::*;
313 let start = w.len();
314 match *self {
315 Initial(InitialHeader {
316 ref dst_cid,
317 ref src_cid,
318 ref token,
319 number,
320 version,
321 }) => {
322 w.write(u8::from(LongHeaderType::Initial) | number.tag());
323 w.write(version);
324 dst_cid.encode_long(w);
325 src_cid.encode_long(w);
326 w.write_var(token.len() as u64);
327 w.put_slice(token);
328 w.write::<u16>(0); number.encode(w);
330 PartialEncode {
331 start,
332 header_len: w.len() - start,
333 pn: Some((number.len(), true)),
334 }
335 }
336 Long {
337 ty,
338 ref dst_cid,
339 ref src_cid,
340 number,
341 version,
342 } => {
343 w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag());
344 w.write(version);
345 dst_cid.encode_long(w);
346 src_cid.encode_long(w);
347 w.write::<u16>(0); number.encode(w);
349 PartialEncode {
350 start,
351 header_len: w.len() - start,
352 pn: Some((number.len(), true)),
353 }
354 }
355 Retry {
356 ref dst_cid,
357 ref src_cid,
358 version,
359 } => {
360 w.write(u8::from(LongHeaderType::Retry));
361 w.write(version);
362 dst_cid.encode_long(w);
363 src_cid.encode_long(w);
364 PartialEncode {
365 start,
366 header_len: w.len() - start,
367 pn: None,
368 }
369 }
370 Short {
371 spin,
372 key_phase,
373 ref dst_cid,
374 number,
375 } => {
376 w.write(
377 FIXED_BIT
378 | if key_phase { KEY_PHASE_BIT } else { 0 }
379 | if spin { SPIN_BIT } else { 0 }
380 | number.tag(),
381 );
382 w.put_slice(dst_cid);
383 number.encode(w);
384 PartialEncode {
385 start,
386 header_len: w.len() - start,
387 pn: Some((number.len(), false)),
388 }
389 }
390 VersionNegotiate {
391 ref random,
392 ref dst_cid,
393 ref src_cid,
394 } => {
395 w.write(0x80u8 | random);
396 w.write::<u32>(0);
397 dst_cid.encode_long(w);
398 src_cid.encode_long(w);
399 PartialEncode {
400 start,
401 header_len: w.len() - start,
402 pn: None,
403 }
404 }
405 }
406 }
407
408 pub(crate) fn is_protected(&self) -> bool {
410 !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. })
411 }
412
413 pub(crate) fn number(&self) -> Option<PacketNumber> {
414 use Header::*;
415 Some(match *self {
416 Initial(InitialHeader { number, .. }) => number,
417 Long { number, .. } => number,
418 Short { number, .. } => number,
419 _ => {
420 return None;
421 }
422 })
423 }
424
425 pub(crate) fn space(&self) -> SpaceId {
426 use Header::*;
427 match *self {
428 Short { .. } => SpaceId::Data,
429 Long {
430 ty: LongType::ZeroRtt,
431 ..
432 } => SpaceId::Data,
433 Long {
434 ty: LongType::Handshake,
435 ..
436 } => SpaceId::Handshake,
437 _ => SpaceId::Initial,
438 }
439 }
440
441 pub(crate) fn key_phase(&self) -> bool {
442 match *self {
443 Self::Short { key_phase, .. } => key_phase,
444 _ => false,
445 }
446 }
447
448 pub(crate) fn is_short(&self) -> bool {
449 matches!(*self, Self::Short { .. })
450 }
451
452 pub(crate) fn is_1rtt(&self) -> bool {
453 self.is_short()
454 }
455
456 pub(crate) fn is_0rtt(&self) -> bool {
457 matches!(
458 *self,
459 Self::Long {
460 ty: LongType::ZeroRtt,
461 ..
462 }
463 )
464 }
465
466 pub(crate) fn dst_cid(&self) -> ConnectionId {
467 use Header::*;
468 match *self {
469 Initial(InitialHeader { dst_cid, .. }) => dst_cid,
470 Long { dst_cid, .. } => dst_cid,
471 Retry { dst_cid, .. } => dst_cid,
472 Short { dst_cid, .. } => dst_cid,
473 VersionNegotiate { dst_cid, .. } => dst_cid,
474 }
475 }
476
477 pub(crate) fn has_frames(&self) -> bool {
479 use Header::*;
480 match *self {
481 Initial(_) => true,
482 Long { .. } => true,
483 Retry { .. } => false,
484 Short { .. } => true,
485 VersionNegotiate { .. } => false,
486 }
487 }
488}
489
/// Bookkeeping produced by `Header::encode`, used by `PartialEncode::finish` to complete the
/// packet.
pub(crate) struct PartialEncode {
    /// Length of the output buffer before the header was written, i.e. the header's offset in it
    pub(crate) start: usize,
    /// Number of header bytes written, packet number included
    pub(crate) header_len: usize,
    /// Packet number length in bytes, and whether a two-byte length placeholder precedes the
    /// packet number; `None` for headers without a packet number
    pn: Option<(usize, bool)>,
}
496
497impl PartialEncode {
498 pub(crate) fn finish(
499 self,
500 buf: &mut [u8],
501 header_crypto: &dyn crypto::HeaderKey,
502 crypto: Option<(u64, &dyn crypto::PacketKey)>,
503 ) {
504 let Self { header_len, pn, .. } = self;
505 let (pn_len, write_len) = match pn {
506 Some((pn_len, write_len)) => (pn_len, write_len),
507 None => return,
508 };
509
510 let pn_pos = header_len - pn_len;
511 if write_len {
512 let len = buf.len() - header_len + pn_len;
513 assert!(len < 2usize.pow(14)); let mut slice = &mut buf[pn_pos - 2..pn_pos];
515 slice.put_u16(len as u16 | (0b01 << 14));
516 }
517
518 if let Some((number, crypto)) = crypto {
519 crypto.encrypt(number, buf, header_len);
520 }
521
522 debug_assert!(
523 pn_pos + 4 + header_crypto.sample_size() <= buf.len(),
524 "packet must be padded to at least {} bytes for header protection sampling",
525 pn_pos + 4 + header_crypto.sample_size()
526 );
527 header_crypto.encrypt(pn_pos, buf);
528 }
529}
530
/// Header fields that can be read before header protection is removed.
///
/// Unlike `Header`, no variant carries a packet number.
#[derive(Clone, Debug)]
pub enum ProtectedHeader {
    /// An Initial packet header
    Initial(ProtectedInitialHeader),
    /// A long header of type `ty` (Handshake or 0-RTT)
    Long {
        /// Type of the long header packet
        ty: LongType,
        /// Destination connection ID
        dst_cid: ConnectionId,
        /// Source connection ID
        src_cid: ConnectionId,
        /// Value of the header's length field
        len: u64,
        /// QUIC version
        version: u32,
    },
    /// A Retry packet header
    Retry {
        /// Destination connection ID
        dst_cid: ConnectionId,
        /// Source connection ID
        src_cid: ConnectionId,
        /// QUIC version
        version: u32,
    },
    /// A short packet header
    Short {
        /// Spin bit
        spin: bool,
        /// Destination connection ID
        dst_cid: ConnectionId,
    },
    /// A Version Negotiation packet header (version field of zero)
    VersionNegotiate {
        /// Bits of the first byte other than the long-header form bit
        random: u8,
        /// Destination connection ID
        dst_cid: ConnectionId,
        /// Source connection ID
        src_cid: ConnectionId,
    },
}
575
576impl ProtectedHeader {
577 fn as_initial(&self) -> Option<&ProtectedInitialHeader> {
578 match self {
579 Self::Initial(x) => Some(x),
580 _ => None,
581 }
582 }
583
584 pub fn dst_cid(&self) -> &ConnectionId {
586 use ProtectedHeader::*;
587 match self {
588 Initial(header) => &header.dst_cid,
589 Long { dst_cid, .. } => dst_cid,
590 Retry { dst_cid, .. } => dst_cid,
591 Short { dst_cid, .. } => dst_cid,
592 VersionNegotiate { dst_cid, .. } => dst_cid,
593 }
594 }
595
596 fn payload_len(&self) -> Option<u64> {
597 use ProtectedHeader::*;
598 match self {
599 Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
600 _ => None,
601 }
602 }
603
604 pub fn decode(
606 buf: &mut io::Cursor<BytesMut>,
607 cid_parser: &(impl ConnectionIdParser + ?Sized),
608 supported_versions: &[u32],
609 grease_quic_bit: bool,
610 ) -> Result<Self, PacketDecodeError> {
611 let first = buf.get::<u8>()?;
612 if !grease_quic_bit && first & FIXED_BIT == 0 {
613 return Err(PacketDecodeError::InvalidHeader("fixed bit unset"));
614 }
615 if first & LONG_HEADER_FORM == 0 {
616 let spin = first & SPIN_BIT != 0;
617
618 Ok(Self::Short {
619 spin,
620 dst_cid: cid_parser.parse(buf)?,
621 })
622 } else {
623 let version = buf.get::<u32>()?;
624
625 let dst_cid = ConnectionId::decode_long(buf)
626 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
627 let src_cid = ConnectionId::decode_long(buf)
628 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
629
630 if version == 0 {
632 let random = first & !LONG_HEADER_FORM;
633 return Ok(Self::VersionNegotiate {
634 random,
635 dst_cid,
636 src_cid,
637 });
638 }
639
640 if !supported_versions.contains(&version) {
641 return Err(PacketDecodeError::UnsupportedVersion {
642 src_cid,
643 dst_cid,
644 version,
645 });
646 }
647
648 match LongHeaderType::from_byte(first)? {
649 LongHeaderType::Initial => {
650 let token_len = buf.get_var()? as usize;
651 let token_start = buf.position() as usize;
652 if token_len > buf.remaining() {
653 return Err(PacketDecodeError::InvalidHeader("token out of bounds"));
654 }
655 buf.advance(token_len);
656
657 let len = buf.get_var()?;
658 Ok(Self::Initial(ProtectedInitialHeader {
659 dst_cid,
660 src_cid,
661 token_pos: token_start..token_start + token_len,
662 len,
663 version,
664 }))
665 }
666 LongHeaderType::Retry => Ok(Self::Retry {
667 dst_cid,
668 src_cid,
669 version,
670 }),
671 LongHeaderType::Standard(ty) => Ok(Self::Long {
672 ty,
673 dst_cid,
674 src_cid,
675 len: buf.get_var()?,
676 version,
677 }),
678 }
679 }
680 }
681}
682
/// Header of an Initial packet, read before header protection is removed
#[derive(Clone, Debug)]
pub struct ProtectedInitialHeader {
    /// Destination connection ID
    pub dst_cid: ConnectionId,
    /// Source connection ID
    pub src_cid: ConnectionId,
    /// Byte range of the token, as offsets from the start of the packet
    pub token_pos: Range<usize>,
    /// Value of the header's length field
    pub len: u64,
    /// QUIC version
    pub version: u32,
}
697
/// Fully decoded header of an Initial packet
#[derive(Clone, Debug)]
pub(crate) struct InitialHeader {
    /// Destination connection ID
    pub(crate) dst_cid: ConnectionId,
    /// Source connection ID
    pub(crate) src_cid: ConnectionId,
    /// Token bytes carried in the header (may be empty)
    pub(crate) token: Bytes,
    /// Truncated packet number
    pub(crate) number: PacketNumber,
    /// QUIC version
    pub(crate) version: u32,
}
706
/// A truncated packet number as it appears on the wire, tagged with its encoded width
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum PacketNumber {
    /// Encoded in one byte
    U8(u8),
    /// Encoded in two bytes
    U16(u16),
    /// Encoded in three bytes; stored in a `u32`
    U24(u32),
    /// Encoded in four bytes
    U32(u32),
}
715
716impl PacketNumber {
717 pub(crate) fn new(n: u64, largest_acked: u64) -> Self {
718 let range = (n - largest_acked) * 2;
719 if range < 1 << 8 {
720 Self::U8(n as u8)
721 } else if range < 1 << 16 {
722 Self::U16(n as u16)
723 } else if range < 1 << 24 {
724 Self::U24(n as u32)
725 } else if range < 1 << 32 {
726 Self::U32(n as u32)
727 } else {
728 Self::U32(n as u32)
731 }
732 }
733
734 pub(crate) fn len(self) -> usize {
735 use PacketNumber::*;
736 match self {
737 U8(_) => 1,
738 U16(_) => 2,
739 U24(_) => 3,
740 U32(_) => 4,
741 }
742 }
743
744 pub(crate) fn encode<W: BufMut>(self, w: &mut W) {
745 use PacketNumber::*;
746 match self {
747 U8(x) => w.write(x),
748 U16(x) => w.write(x),
749 U24(x) => w.put_uint(u64::from(x), 3),
750 U32(x) => w.write(x),
751 }
752 }
753
754 pub(crate) fn decode<R: Buf>(len: usize, r: &mut R) -> Result<Self, PacketDecodeError> {
755 use PacketNumber::*;
756 let pn = match len {
757 1 => U8(r.get()?),
758 2 => U16(r.get()?),
759 3 => U24(r.get_uint(3) as u32),
760 4 => U32(r.get()?),
761 _ => unreachable!(),
762 };
763 Ok(pn)
764 }
765
766 pub(crate) fn decode_len(tag: u8) -> usize {
767 1 + (tag & 0x03) as usize
768 }
769
770 fn tag(self) -> u8 {
771 use PacketNumber::*;
772 match self {
773 U8(_) => 0b00,
774 U16(_) => 0b01,
775 U24(_) => 0b10,
776 U32(_) => 0b11,
777 }
778 }
779
780 pub(crate) fn expand(self, expected: u64) -> u64 {
781 use PacketNumber::*;
783 let truncated = match self {
784 U8(x) => u64::from(x),
785 U16(x) => u64::from(x),
786 U24(x) => u64::from(x),
787 U32(x) => u64::from(x),
788 };
789 let nbits = self.len() * 8;
790 let win = 1 << nbits;
791 let hwin = win / 2;
792 let mask = win - 1;
793 let candidate = (expected & !mask) | truncated;
802 if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) {
803 candidate + win
804 } else if candidate > expected + hwin && candidate > win {
805 candidate - win
806 } else {
807 candidate
808 }
809 }
810}
811
/// A connection ID parser for endpoints whose local connection IDs all have the same length
#[derive(Debug)]
pub struct FixedLengthConnectionIdParser {
    /// Number of bytes every parsed connection ID occupies
    expected_len: usize,
}

impl FixedLengthConnectionIdParser {
    /// Creates a parser that reads connection IDs of exactly `expected_len` bytes
    pub fn new(expected_len: usize) -> Self {
        Self { expected_len }
    }
}
823
824impl ConnectionIdParser for FixedLengthConnectionIdParser {
825 fn parse(&self, buffer: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
826 (buffer.remaining() >= self.expected_len)
827 .then(|| ConnectionId::from_buf(buffer, self.expected_len))
828 .ok_or(PacketDecodeError::InvalidHeader("packet too small"))
829 }
830}
831
/// Extracts a connection ID from the start of a buffer.
///
/// `ProtectedHeader::decode` uses this for short headers, where the destination connection ID is
/// not preceded by a length.
pub trait ConnectionIdParser {
    /// Reads a connection ID from the front of `buf`, or reports why the header is invalid
    fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
}
837
/// Packet type selected by the two type bits (mask `0x30`) of a long header's first byte
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum LongHeaderType {
    /// Type bits `0x0`
    Initial,
    /// Type bits `0x3`
    Retry,
    /// The remaining long header types, see `LongType`
    Standard(LongType),
}
845
846impl LongHeaderType {
847 fn from_byte(b: u8) -> Result<Self, PacketDecodeError> {
848 use {LongHeaderType::*, LongType::*};
849 debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet");
850 Ok(match (b & 0x30) >> 4 {
851 0x0 => Initial,
852 0x1 => Standard(ZeroRtt),
853 0x2 => Standard(Handshake),
854 0x3 => Retry,
855 _ => unreachable!(),
856 })
857 }
858}
859
860impl From<LongHeaderType> for u8 {
861 fn from(ty: LongHeaderType) -> Self {
862 use {LongHeaderType::*, LongType::*};
863 match ty {
864 Initial => LONG_HEADER_FORM | FIXED_BIT,
865 Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4),
866 Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4),
867 Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4),
868 }
869 }
870}
871
/// Long packet types that share the `Long` header layout
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LongType {
    /// Handshake packet
    Handshake,
    /// 0-RTT packet
    ZeroRtt,
}
880
/// Errors that can occur while decoding a packet header
#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub enum PacketDecodeError {
    /// A long header named a non-zero version that is not in the supported set
    #[error("unsupported version {version:x}")]
    UnsupportedVersion {
        /// Source connection ID of the offending packet
        src_cid: ConnectionId,
        /// Destination connection ID of the offending packet
        dst_cid: ConnectionId,
        /// The version that was not recognized
        version: u32,
    },
    /// The header is malformed; the string gives the reason
    #[error("invalid header: {0}")]
    InvalidHeader(&'static str),
}
898
899impl From<coding::UnexpectedEnd> for PacketDecodeError {
900 fn from(_: coding::UnexpectedEnd) -> Self {
901 Self::InvalidHeader("unexpected end of packet")
902 }
903}
904
// Bit masks applied to the first byte of a packet.

/// Set for long headers, clear for short headers
pub(crate) const LONG_HEADER_FORM: u8 = 0x80;
/// The fixed bit; decoding rejects packets with it cleared unless greasing is allowed
pub(crate) const FIXED_BIT: u8 = 0x40;
/// Spin bit of a short header
pub(crate) const SPIN_BIT: u8 = 0x20;
/// Reserved bits of a short header, checked by `Packet::reserved_bits_valid`
const SHORT_RESERVED_BITS: u8 = 0x18;
/// Reserved bits of a long header, checked by `Packet::reserved_bits_valid`
const LONG_RESERVED_BITS: u8 = 0x0c;
/// Key phase bit of a short header
const KEY_PHASE_BIT: u8 = 0x04;
911
/// Packet number space identifiers, ordered `Initial < Handshake < Data`
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
#[allow(missing_docs)]
pub enum SpaceId {
    /// Space used by Initial packets
    Initial = 0,
    /// Space used by Handshake packets
    Handshake = 1,
    /// Space used by 0-RTT and short-header packets
    Data = 2,
}

impl SpaceId {
    /// Yields every space once, in ascending order
    #[allow(missing_docs)]
    pub fn iter() -> impl Iterator<Item = Self> {
        const ALL: [SpaceId; 3] = [SpaceId::Initial, SpaceId::Handshake, SpaceId::Data];
        IntoIterator::into_iter(ALL)
    }
}
929
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use std::io;

    /// Asserts that `typed` encodes to exactly `encoded` and decodes back to itself
    fn check_pn(typed: PacketNumber, encoded: &[u8]) {
        let mut buf = Vec::new();
        typed.encode(&mut buf);
        assert_eq!(&buf[..], encoded);
        let decoded = PacketNumber::decode(typed.len(), &mut io::Cursor::new(&buf)).unwrap();
        assert_eq!(typed, decoded);
    }

    /// Explicitly sized packet numbers survive an encode/decode round trip
    #[test]
    fn roundtrip_packet_numbers() {
        check_pn(PacketNumber::U8(0x7f), &hex!("7f"));
        check_pn(PacketNumber::U16(0x80), &hex!("0080"));
        check_pn(PacketNumber::U16(0x3fff), &hex!("3fff"));
        check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000"));
        check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff"));
    }

    /// `PacketNumber::new` picks the expected width for a given distance from the largest ack
    #[test]
    fn pn_encode() {
        check_pn(PacketNumber::new(0x10, 0), &hex!("10"));
        check_pn(PacketNumber::new(0x100, 0), &hex!("0100"));
        check_pn(PacketNumber::new(0x10000, 0), &hex!("010000"));
    }

    /// Truncating relative to `expected` and expanding against it recovers the original number
    #[test]
    fn pn_expand_roundtrip() {
        for expected in 0..1024 {
            for actual in expected..1024 {
                assert_eq!(actual, PacketNumber::new(actual, expected).expand(expected));
            }
        }
    }

    /// Encodes and protects an Initial packet with the client keys, compares it against a fixed
    /// byte vector, then decodes and decrypts it with the server keys.
    #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
    #[test]
    fn header_encoding() {
        use crate::Side;
        use crate::crypto::rustls::{initial_keys, initial_suite_from_provider};
        #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
        use rustls::crypto::aws_lc_rs::default_provider;
        #[cfg(feature = "rustls-ring")]
        use rustls::crypto::ring::default_provider;
        use rustls::quic::Version;

        let dcid = ConnectionId::new(&hex!("06b858ec6f80452b"));
        let provider = default_provider();

        let suite = initial_suite_from_provider(&std::sync::Arc::new(provider)).unwrap();
        let client = initial_keys(Version::V1, dcid, Side::Client, &suite);
        let mut buf = Vec::new();
        let header = Header::Initial(InitialHeader {
            number: PacketNumber::U8(0),
            src_cid: ConnectionId::new(&[]),
            dst_cid: dcid,
            token: Bytes::new(),
            version: crate::DEFAULT_SUPPORTED_VERSIONS[0],
        });
        let encode = header.encode(&mut buf);
        let header_len = buf.len();
        // Reserve 16 zero bytes of payload plus room for the AEAD tag
        buf.resize(header_len + 16 + client.packet.local.tag_len(), 0);
        encode.finish(
            &mut buf,
            &*client.header.local,
            Some((0, &*client.packet.local)),
        );

        // Dump the protected packet as hex to ease debugging of a failing comparison
        for byte in &buf {
            print!("{byte:02x}");
        }
        println!();
        assert_eq!(
            buf[..],
            hex!(
                "c8000000010806b858ec6f80452b00004021be
                 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1"
            )[..]
        );

        let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
        let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
        let decode = PartialDecode::new(
            buf.as_slice().into(),
            &FixedLengthConnectionIdParser::new(0),
            &supported_versions,
            false,
        )
        .unwrap()
        .0;
        let mut packet = decode.finish(Some(&*server.header.remote)).unwrap();
        // With header protection removed, the first byte and packet number are in the clear
        assert_eq!(
            packet.header_data[..],
            hex!("c0000000010806b858ec6f80452b0000402100")[..]
        );
        server
            .packet
            .remote
            .decrypt(0, &packet.header_data, &mut packet.payload)
            .unwrap();
        assert_eq!(packet.payload[..], [0; 16]);
        match packet.header {
            Header::Initial(InitialHeader {
                number: PacketNumber::U8(0),
                ..
            }) => {}
            _ => {
                panic!("unexpected header {:?}", packet.header);
            }
        }
    }
}