1use std::{cmp::Ordering, io, ops::Range, str};
2
3use bytes::{Buf, BufMut, Bytes, BytesMut};
4use thiserror::Error;
5
6use crate::{
7 coding::{self, BufExt, BufMutExt},
8 crypto, ConnectionId,
9};
10
11#[cfg_attr(test, derive(Clone))]
24#[derive(Debug)]
25pub struct PartialDecode {
26 plain_header: ProtectedHeader,
27 buf: io::Cursor<BytesMut>,
28}
29
30#[allow(clippy::len_without_is_empty)]
31impl PartialDecode {
32 pub fn new(
34 bytes: BytesMut,
35 cid_parser: &(impl ConnectionIdParser + ?Sized),
36 supported_versions: &[u32],
37 grease_quic_bit: bool,
38 ) -> Result<(Self, Option<BytesMut>), PacketDecodeError> {
39 let mut buf = io::Cursor::new(bytes);
40 let plain_header =
41 ProtectedHeader::decode(&mut buf, cid_parser, supported_versions, grease_quic_bit)?;
42 let dgram_len = buf.get_ref().len();
43 let packet_len = plain_header
44 .payload_len()
45 .map(|len| (buf.position() + len) as usize)
46 .unwrap_or(dgram_len);
47 match dgram_len.cmp(&packet_len) {
48 Ordering::Equal => Ok((Self { plain_header, buf }, None)),
49 Ordering::Less => Err(PacketDecodeError::InvalidHeader(
50 "packet too short to contain payload length",
51 )),
52 Ordering::Greater => {
53 let rest = Some(buf.get_mut().split_off(packet_len));
54 Ok((Self { plain_header, buf }, rest))
55 }
56 }
57 }
58
59 pub(crate) fn data(&self) -> &[u8] {
61 self.buf.get_ref()
62 }
63
64 pub(crate) fn initial_header(&self) -> Option<&ProtectedInitialHeader> {
65 self.plain_header.as_initial()
66 }
67
68 pub(crate) fn has_long_header(&self) -> bool {
69 !matches!(self.plain_header, ProtectedHeader::Short { .. })
70 }
71
72 pub(crate) fn is_initial(&self) -> bool {
73 self.space() == Some(SpaceId::Initial)
74 }
75
76 pub(crate) fn space(&self) -> Option<SpaceId> {
77 use ProtectedHeader::*;
78 match self.plain_header {
79 Initial { .. } => Some(SpaceId::Initial),
80 Long {
81 ty: LongType::Handshake,
82 ..
83 } => Some(SpaceId::Handshake),
84 Long {
85 ty: LongType::ZeroRtt,
86 ..
87 } => Some(SpaceId::Data),
88 Short { .. } => Some(SpaceId::Data),
89 _ => None,
90 }
91 }
92
93 pub(crate) fn is_0rtt(&self) -> bool {
94 match self.plain_header {
95 ProtectedHeader::Long { ty, .. } => ty == LongType::ZeroRtt,
96 _ => false,
97 }
98 }
99
100 pub fn dst_cid(&self) -> &ConnectionId {
102 self.plain_header.dst_cid()
103 }
104
105 #[allow(unreachable_pub)] pub fn len(&self) -> usize {
108 self.buf.get_ref().len()
109 }
110
111 pub(crate) fn finish(
112 self,
113 header_crypto: Option<&dyn crypto::HeaderKey>,
114 ) -> Result<Packet, PacketDecodeError> {
115 use ProtectedHeader::*;
116 let Self {
117 plain_header,
118 mut buf,
119 } = self;
120
121 if let Initial(ProtectedInitialHeader {
122 dst_cid,
123 src_cid,
124 token_pos,
125 version,
126 ..
127 }) = plain_header
128 {
129 let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
130 let header_len = buf.position() as usize;
131 let mut bytes = buf.into_inner();
132
133 let header_data = bytes.split_to(header_len).freeze();
134 let token = header_data.slice(token_pos.start..token_pos.end);
135 return Ok(Packet {
136 header: Header::Initial(InitialHeader {
137 dst_cid,
138 src_cid,
139 token,
140 number,
141 version,
142 }),
143 header_data,
144 payload: bytes,
145 });
146 }
147
148 let header = match plain_header {
149 Long {
150 ty,
151 dst_cid,
152 src_cid,
153 version,
154 ..
155 } => Header::Long {
156 ty,
157 dst_cid,
158 src_cid,
159 number: Self::decrypt_header(&mut buf, header_crypto.unwrap())?,
160 version,
161 },
162 Retry {
163 dst_cid,
164 src_cid,
165 version,
166 } => Header::Retry {
167 dst_cid,
168 src_cid,
169 version,
170 },
171 Short { spin, dst_cid, .. } => {
172 let number = Self::decrypt_header(&mut buf, header_crypto.unwrap())?;
173 let key_phase = buf.get_ref()[0] & KEY_PHASE_BIT != 0;
174 Header::Short {
175 spin,
176 key_phase,
177 dst_cid,
178 number,
179 }
180 }
181 VersionNegotiate {
182 random,
183 dst_cid,
184 src_cid,
185 } => Header::VersionNegotiate {
186 random,
187 dst_cid,
188 src_cid,
189 },
190 Initial { .. } => unreachable!(),
191 };
192
193 let header_len = buf.position() as usize;
194 let mut bytes = buf.into_inner();
195 Ok(Packet {
196 header,
197 header_data: bytes.split_to(header_len).freeze(),
198 payload: bytes,
199 })
200 }
201
202 fn decrypt_header(
203 buf: &mut io::Cursor<BytesMut>,
204 header_crypto: &dyn crypto::HeaderKey,
205 ) -> Result<PacketNumber, PacketDecodeError> {
206 let packet_length = buf.get_ref().len();
207 let pn_offset = buf.position() as usize;
208 if packet_length < pn_offset + 4 + header_crypto.sample_size() {
209 return Err(PacketDecodeError::InvalidHeader(
210 "packet too short to extract header protection sample",
211 ));
212 }
213
214 header_crypto.decrypt(pn_offset, buf.get_mut());
215
216 let len = PacketNumber::decode_len(buf.get_ref()[0]);
217 PacketNumber::decode(len, buf)
218 }
219}
220
221pub(crate) struct Packet {
222 pub(crate) header: Header,
223 pub(crate) header_data: Bytes,
224 pub(crate) payload: BytesMut,
225}
226
227impl Packet {
228 pub(crate) fn reserved_bits_valid(&self) -> bool {
229 let mask = match self.header {
230 Header::Short { .. } => SHORT_RESERVED_BITS,
231 _ => LONG_RESERVED_BITS,
232 };
233 self.header_data[0] & mask == 0
234 }
235}
236
237pub(crate) struct InitialPacket {
238 pub(crate) header: InitialHeader,
239 pub(crate) header_data: Bytes,
240 pub(crate) payload: BytesMut,
241}
242
243impl From<InitialPacket> for Packet {
244 fn from(x: InitialPacket) -> Self {
245 Self {
246 header: Header::Initial(x.header),
247 header_data: x.header_data,
248 payload: x.payload,
249 }
250 }
251}
252
253#[cfg_attr(test, derive(Clone))]
254#[derive(Debug)]
255pub(crate) enum Header {
256 Initial(InitialHeader),
257 Long {
258 ty: LongType,
259 dst_cid: ConnectionId,
260 src_cid: ConnectionId,
261 number: PacketNumber,
262 version: u32,
263 },
264 Retry {
265 dst_cid: ConnectionId,
266 src_cid: ConnectionId,
267 version: u32,
268 },
269 Short {
270 spin: bool,
271 key_phase: bool,
272 dst_cid: ConnectionId,
273 number: PacketNumber,
274 },
275 VersionNegotiate {
276 random: u8,
277 src_cid: ConnectionId,
278 dst_cid: ConnectionId,
279 },
280}
281
282impl Header {
283 pub(crate) fn encode(&self, w: &mut Vec<u8>) -> PartialEncode {
284 use Header::*;
285 let start = w.len();
286 match *self {
287 Initial(InitialHeader {
288 ref dst_cid,
289 ref src_cid,
290 ref token,
291 number,
292 version,
293 }) => {
294 w.write(u8::from(LongHeaderType::Initial) | number.tag());
295 w.write(version);
296 dst_cid.encode_long(w);
297 src_cid.encode_long(w);
298 w.write_var(token.len() as u64);
299 w.put_slice(token);
300 w.write::<u16>(0); number.encode(w);
302 PartialEncode {
303 start,
304 header_len: w.len() - start,
305 pn: Some((number.len(), true)),
306 }
307 }
308 Long {
309 ty,
310 ref dst_cid,
311 ref src_cid,
312 number,
313 version,
314 } => {
315 w.write(u8::from(LongHeaderType::Standard(ty)) | number.tag());
316 w.write(version);
317 dst_cid.encode_long(w);
318 src_cid.encode_long(w);
319 w.write::<u16>(0); number.encode(w);
321 PartialEncode {
322 start,
323 header_len: w.len() - start,
324 pn: Some((number.len(), true)),
325 }
326 }
327 Retry {
328 ref dst_cid,
329 ref src_cid,
330 version,
331 } => {
332 w.write(u8::from(LongHeaderType::Retry));
333 w.write(version);
334 dst_cid.encode_long(w);
335 src_cid.encode_long(w);
336 PartialEncode {
337 start,
338 header_len: w.len() - start,
339 pn: None,
340 }
341 }
342 Short {
343 spin,
344 key_phase,
345 ref dst_cid,
346 number,
347 } => {
348 w.write(
349 FIXED_BIT
350 | if key_phase { KEY_PHASE_BIT } else { 0 }
351 | if spin { SPIN_BIT } else { 0 }
352 | number.tag(),
353 );
354 w.put_slice(dst_cid);
355 number.encode(w);
356 PartialEncode {
357 start,
358 header_len: w.len() - start,
359 pn: Some((number.len(), false)),
360 }
361 }
362 VersionNegotiate {
363 ref random,
364 ref dst_cid,
365 ref src_cid,
366 } => {
367 w.write(0x80u8 | random);
368 w.write::<u32>(0);
369 dst_cid.encode_long(w);
370 src_cid.encode_long(w);
371 PartialEncode {
372 start,
373 header_len: w.len() - start,
374 pn: None,
375 }
376 }
377 }
378 }
379
380 pub(crate) fn is_protected(&self) -> bool {
382 !matches!(*self, Self::Retry { .. } | Self::VersionNegotiate { .. })
383 }
384
385 pub(crate) fn number(&self) -> Option<PacketNumber> {
386 use Header::*;
387 Some(match *self {
388 Initial(InitialHeader { number, .. }) => number,
389 Long { number, .. } => number,
390 Short { number, .. } => number,
391 _ => {
392 return None;
393 }
394 })
395 }
396
397 pub(crate) fn space(&self) -> SpaceId {
398 use Header::*;
399 match *self {
400 Short { .. } => SpaceId::Data,
401 Long {
402 ty: LongType::ZeroRtt,
403 ..
404 } => SpaceId::Data,
405 Long {
406 ty: LongType::Handshake,
407 ..
408 } => SpaceId::Handshake,
409 _ => SpaceId::Initial,
410 }
411 }
412
413 pub(crate) fn key_phase(&self) -> bool {
414 match *self {
415 Self::Short { key_phase, .. } => key_phase,
416 _ => false,
417 }
418 }
419
420 pub(crate) fn is_short(&self) -> bool {
421 matches!(*self, Self::Short { .. })
422 }
423
424 pub(crate) fn is_1rtt(&self) -> bool {
425 self.is_short()
426 }
427
428 pub(crate) fn is_0rtt(&self) -> bool {
429 matches!(
430 *self,
431 Self::Long {
432 ty: LongType::ZeroRtt,
433 ..
434 }
435 )
436 }
437
438 pub(crate) fn dst_cid(&self) -> ConnectionId {
439 use Header::*;
440 match *self {
441 Initial(InitialHeader { dst_cid, .. }) => dst_cid,
442 Long { dst_cid, .. } => dst_cid,
443 Retry { dst_cid, .. } => dst_cid,
444 Short { dst_cid, .. } => dst_cid,
445 VersionNegotiate { dst_cid, .. } => dst_cid,
446 }
447 }
448
449 pub(crate) fn has_frames(&self) -> bool {
451 use Header::*;
452 match *self {
453 Initial(_) => true,
454 Long { .. } => true,
455 Retry { .. } => false,
456 Short { .. } => true,
457 VersionNegotiate { .. } => false,
458 }
459 }
460}
461
/// Bookkeeping for a header that has been written but not yet finalized.
pub(crate) struct PartialEncode {
    /// Length of the output buffer at the moment the header started being written.
    pub(crate) start: usize,
    /// Number of bytes the header occupies.
    pub(crate) header_len: usize,
    /// `(packet number length, whether a length field precedes it)`, or `None` for headers
    /// that carry no packet number.
    pn: Option<(usize, bool)>,
}
468
469impl PartialEncode {
470 pub(crate) fn finish(
471 self,
472 buf: &mut [u8],
473 header_crypto: &dyn crypto::HeaderKey,
474 crypto: Option<(u64, &dyn crypto::PacketKey)>,
475 ) {
476 let Self { header_len, pn, .. } = self;
477 let (pn_len, write_len) = match pn {
478 Some((pn_len, write_len)) => (pn_len, write_len),
479 None => return,
480 };
481
482 let pn_pos = header_len - pn_len;
483 if write_len {
484 let len = buf.len() - header_len + pn_len;
485 assert!(len < 2usize.pow(14)); let mut slice = &mut buf[pn_pos - 2..pn_pos];
487 slice.put_u16(len as u16 | 0b01 << 14);
488 }
489
490 if let Some((number, crypto)) = crypto {
491 crypto.encrypt(number, buf, header_len);
492 }
493
494 debug_assert!(
495 pn_pos + 4 + header_crypto.sample_size() <= buf.len(),
496 "packet must be padded to at least {} bytes for header protection sampling",
497 pn_pos + 4 + header_crypto.sample_size()
498 );
499 header_crypto.encrypt(pn_pos, buf);
500 }
501}
502
503#[derive(Clone, Debug)]
505pub enum ProtectedHeader {
506 Initial(ProtectedInitialHeader),
508 Long {
510 ty: LongType,
512 dst_cid: ConnectionId,
514 src_cid: ConnectionId,
516 len: u64,
518 version: u32,
520 },
521 Retry {
523 dst_cid: ConnectionId,
525 src_cid: ConnectionId,
527 version: u32,
529 },
530 Short {
532 spin: bool,
534 dst_cid: ConnectionId,
536 },
537 VersionNegotiate {
539 random: u8,
541 dst_cid: ConnectionId,
543 src_cid: ConnectionId,
545 },
546}
547
548impl ProtectedHeader {
549 fn as_initial(&self) -> Option<&ProtectedInitialHeader> {
550 match self {
551 Self::Initial(x) => Some(x),
552 _ => None,
553 }
554 }
555
556 pub fn dst_cid(&self) -> &ConnectionId {
558 use ProtectedHeader::*;
559 match self {
560 Initial(header) => &header.dst_cid,
561 Long { dst_cid, .. } => dst_cid,
562 Retry { dst_cid, .. } => dst_cid,
563 Short { dst_cid, .. } => dst_cid,
564 VersionNegotiate { dst_cid, .. } => dst_cid,
565 }
566 }
567
568 fn payload_len(&self) -> Option<u64> {
569 use ProtectedHeader::*;
570 match self {
571 Initial(ProtectedInitialHeader { len, .. }) | Long { len, .. } => Some(*len),
572 _ => None,
573 }
574 }
575
576 pub fn decode(
578 buf: &mut io::Cursor<BytesMut>,
579 cid_parser: &(impl ConnectionIdParser + ?Sized),
580 supported_versions: &[u32],
581 grease_quic_bit: bool,
582 ) -> Result<Self, PacketDecodeError> {
583 let first = buf.get::<u8>()?;
584 if !grease_quic_bit && first & FIXED_BIT == 0 {
585 return Err(PacketDecodeError::InvalidHeader("fixed bit unset"));
586 }
587 if first & LONG_HEADER_FORM == 0 {
588 let spin = first & SPIN_BIT != 0;
589
590 Ok(Self::Short {
591 spin,
592 dst_cid: cid_parser.parse(buf)?,
593 })
594 } else {
595 let version = buf.get::<u32>()?;
596
597 let dst_cid = ConnectionId::decode_long(buf)
598 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
599 let src_cid = ConnectionId::decode_long(buf)
600 .ok_or(PacketDecodeError::InvalidHeader("malformed cid"))?;
601
602 if version == 0 {
604 let random = first & !LONG_HEADER_FORM;
605 return Ok(Self::VersionNegotiate {
606 random,
607 dst_cid,
608 src_cid,
609 });
610 }
611
612 if !supported_versions.contains(&version) {
613 return Err(PacketDecodeError::UnsupportedVersion {
614 src_cid,
615 dst_cid,
616 version,
617 });
618 }
619
620 match LongHeaderType::from_byte(first)? {
621 LongHeaderType::Initial => {
622 let token_len = buf.get_var()? as usize;
623 let token_start = buf.position() as usize;
624 if token_len > buf.remaining() {
625 return Err(PacketDecodeError::InvalidHeader("token out of bounds"));
626 }
627 buf.advance(token_len);
628
629 let len = buf.get_var()?;
630 Ok(Self::Initial(ProtectedInitialHeader {
631 dst_cid,
632 src_cid,
633 token_pos: token_start..token_start + token_len,
634 len,
635 version,
636 }))
637 }
638 LongHeaderType::Retry => Ok(Self::Retry {
639 dst_cid,
640 src_cid,
641 version,
642 }),
643 LongHeaderType::Standard(ty) => Ok(Self::Long {
644 ty,
645 dst_cid,
646 src_cid,
647 len: buf.get_var()?,
648 version,
649 }),
650 }
651 }
652 }
653}
654
655#[derive(Clone, Debug)]
657pub struct ProtectedInitialHeader {
658 pub dst_cid: ConnectionId,
660 pub src_cid: ConnectionId,
662 pub token_pos: Range<usize>,
664 pub len: u64,
666 pub version: u32,
668}
669
670#[derive(Clone, Debug)]
671pub(crate) struct InitialHeader {
672 pub(crate) dst_cid: ConnectionId,
673 pub(crate) src_cid: ConnectionId,
674 pub(crate) token: Bytes,
675 pub(crate) number: PacketNumber,
676 pub(crate) version: u32,
677}
678
/// A truncated packet number, tagged with the width it occupies on the wire.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub(crate) enum PacketNumber {
    /// One-byte encoding.
    U8(u8),
    /// Two-byte encoding.
    U16(u16),
    /// Three-byte encoding; only the low 24 bits are written.
    U24(u32),
    /// Four-byte encoding.
    U32(u32),
}
687
688impl PacketNumber {
689 pub(crate) fn new(n: u64, largest_acked: u64) -> Self {
690 let range = (n - largest_acked) * 2;
691 if range < 1 << 8 {
692 Self::U8(n as u8)
693 } else if range < 1 << 16 {
694 Self::U16(n as u16)
695 } else if range < 1 << 24 {
696 Self::U24(n as u32)
697 } else if range < 1 << 32 {
698 Self::U32(n as u32)
699 } else {
700 panic!("packet number too large to encode")
701 }
702 }
703
704 pub(crate) fn len(self) -> usize {
705 use PacketNumber::*;
706 match self {
707 U8(_) => 1,
708 U16(_) => 2,
709 U24(_) => 3,
710 U32(_) => 4,
711 }
712 }
713
714 pub(crate) fn encode<W: BufMut>(self, w: &mut W) {
715 use PacketNumber::*;
716 match self {
717 U8(x) => w.write(x),
718 U16(x) => w.write(x),
719 U24(x) => w.put_uint(u64::from(x), 3),
720 U32(x) => w.write(x),
721 }
722 }
723
724 pub(crate) fn decode<R: Buf>(len: usize, r: &mut R) -> Result<Self, PacketDecodeError> {
725 use PacketNumber::*;
726 let pn = match len {
727 1 => U8(r.get()?),
728 2 => U16(r.get()?),
729 3 => U24(r.get_uint(3) as u32),
730 4 => U32(r.get()?),
731 _ => unreachable!(),
732 };
733 Ok(pn)
734 }
735
736 pub(crate) fn decode_len(tag: u8) -> usize {
737 1 + (tag & 0x03) as usize
738 }
739
740 fn tag(self) -> u8 {
741 use PacketNumber::*;
742 match self {
743 U8(_) => 0b00,
744 U16(_) => 0b01,
745 U24(_) => 0b10,
746 U32(_) => 0b11,
747 }
748 }
749
750 pub(crate) fn expand(self, expected: u64) -> u64 {
751 use PacketNumber::*;
753 let truncated = match self {
754 U8(x) => u64::from(x),
755 U16(x) => u64::from(x),
756 U24(x) => u64::from(x),
757 U32(x) => u64::from(x),
758 };
759 let nbits = self.len() * 8;
760 let win = 1 << nbits;
761 let hwin = win / 2;
762 let mask = win - 1;
763 let candidate = (expected & !mask) | truncated;
772 if expected.checked_sub(hwin).is_some_and(|x| candidate <= x) {
773 candidate + win
774 } else if candidate > expected + hwin && candidate > win {
775 candidate - win
776 } else {
777 candidate
778 }
779 }
780}
781
/// A [`ConnectionIdParser`] for connection IDs that all share one fixed length.
pub struct FixedLengthConnectionIdParser {
    /// Number of bytes every parsed connection ID occupies.
    expected_len: usize,
}
786
787impl FixedLengthConnectionIdParser {
788 pub fn new(expected_len: usize) -> Self {
790 Self { expected_len }
791 }
792}
793
794impl ConnectionIdParser for FixedLengthConnectionIdParser {
795 fn parse(&self, buffer: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError> {
796 (buffer.remaining() >= self.expected_len)
797 .then(|| ConnectionId::from_buf(buffer, self.expected_len))
798 .ok_or(PacketDecodeError::InvalidHeader("packet too small"))
799 }
800}
801
802pub trait ConnectionIdParser {
804 fn parse(&self, buf: &mut dyn Buf) -> Result<ConnectionId, PacketDecodeError>;
806}
807
808#[derive(Clone, Copy, Debug, Eq, PartialEq)]
810pub(crate) enum LongHeaderType {
811 Initial,
812 Retry,
813 Standard(LongType),
814}
815
816impl LongHeaderType {
817 fn from_byte(b: u8) -> Result<Self, PacketDecodeError> {
818 use {LongHeaderType::*, LongType::*};
819 debug_assert!(b & LONG_HEADER_FORM != 0, "not a long packet");
820 Ok(match (b & 0x30) >> 4 {
821 0x0 => Initial,
822 0x1 => Standard(ZeroRtt),
823 0x2 => Standard(Handshake),
824 0x3 => Retry,
825 _ => unreachable!(),
826 })
827 }
828}
829
830impl From<LongHeaderType> for u8 {
831 fn from(ty: LongHeaderType) -> Self {
832 use {LongHeaderType::*, LongType::*};
833 match ty {
834 Initial => LONG_HEADER_FORM | FIXED_BIT,
835 Standard(ZeroRtt) => LONG_HEADER_FORM | FIXED_BIT | (0x1 << 4),
836 Standard(Handshake) => LONG_HEADER_FORM | FIXED_BIT | (0x2 << 4),
837 Retry => LONG_HEADER_FORM | FIXED_BIT | (0x3 << 4),
838 }
839 }
840}
841
/// The long-header packet types that share a common header layout.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LongType {
    /// Handshake packet.
    Handshake,
    /// 0-RTT packet.
    ZeroRtt,
}
850
851#[derive(Debug, Error, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
853pub enum PacketDecodeError {
854 #[error("unsupported version {version:x}")]
856 UnsupportedVersion {
857 src_cid: ConnectionId,
859 dst_cid: ConnectionId,
861 version: u32,
863 },
864 #[error("invalid header: {0}")]
866 InvalidHeader(&'static str),
867}
868
869impl From<coding::UnexpectedEnd> for PacketDecodeError {
870 fn from(_: coding::UnexpectedEnd) -> Self {
871 Self::InvalidHeader("unexpected end of packet")
872 }
873}
874
/// Set in the first byte of every long-header packet; clear for short headers.
pub(crate) const LONG_HEADER_FORM: u8 = 0x80;
/// The fixed bit of the first byte.
pub(crate) const FIXED_BIT: u8 = 0x40;
/// The spin bit of a short header's first byte.
pub(crate) const SPIN_BIT: u8 = 0x20;
/// Reserved bits of a short header's first byte.
const SHORT_RESERVED_BITS: u8 = 0x18;
/// Reserved bits of a long header's first byte.
const LONG_RESERVED_BITS: u8 = 0x0c;
/// The key phase bit of a short header's first byte.
const KEY_PHASE_BIT: u8 = 0x04;
881
/// Identifies one of the packet number spaces.
///
/// The variants are ordered `Initial < Handshake < Data`.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd)]
pub enum SpaceId {
    /// Space used by Initial packets.
    Initial = 0,
    /// Space used by Handshake packets.
    Handshake = 1,
    /// Space used by 0-RTT and short-header packets.
    Data = 2,
}

impl SpaceId {
    /// Iterates over every packet number space in ascending order.
    pub fn iter() -> impl Iterator<Item = Self> {
        static ALL: [SpaceId; 3] = [SpaceId::Initial, SpaceId::Handshake, SpaceId::Data];
        ALL.iter().copied()
    }
}
897
#[cfg(test)]
mod tests {
    use super::*;
    use hex_literal::hex;
    use std::io;

    /// Asserts that `typed` encodes to exactly `encoded` and decodes back to itself.
    fn check_pn(typed: PacketNumber, encoded: &[u8]) {
        let mut wire = Vec::new();
        typed.encode(&mut wire);
        assert_eq!(&wire[..], encoded);

        let mut reader = io::Cursor::new(&wire);
        let decoded = PacketNumber::decode(typed.len(), &mut reader).unwrap();
        assert_eq!(typed, decoded);
    }

    /// Explicitly sized packet numbers survive an encode/decode round trip.
    #[test]
    fn roundtrip_packet_numbers() {
        check_pn(PacketNumber::U8(0x7f), &hex!("7f"));
        check_pn(PacketNumber::U16(0x80), &hex!("0080"));
        check_pn(PacketNumber::U16(0x3fff), &hex!("3fff"));
        check_pn(PacketNumber::U32(0x0000_4000), &hex!("0000 4000"));
        check_pn(PacketNumber::U32(0xffff_ffff), &hex!("ffff ffff"));
    }

    /// `PacketNumber::new` picks the expected width for each magnitude.
    #[test]
    fn pn_encode() {
        check_pn(PacketNumber::new(0x10, 0), &hex!("10"));
        check_pn(PacketNumber::new(0x100, 0), &hex!("0100"));
        check_pn(PacketNumber::new(0x10000, 0), &hex!("010000"));
    }

    /// Truncating a packet number and expanding it again yields the original value.
    #[test]
    fn pn_expand_roundtrip() {
        for expected in 0..1024 {
            for actual in expected..1024 {
                let truncated = PacketNumber::new(actual, expected);
                assert_eq!(actual, truncated.expand(expected));
            }
        }
    }

    /// An Initial packet is encoded with the client keys and decoded again with the
    /// server keys, checking the wire bytes against a known-good vector.
    #[cfg(any(feature = "rustls-aws-lc-rs", feature = "rustls-ring"))]
    #[test]
    fn header_encoding() {
        use crate::crypto::rustls::{initial_keys, initial_suite_from_provider};
        use crate::Side;
        #[cfg(all(feature = "rustls-aws-lc-rs", not(feature = "rustls-ring")))]
        use rustls::crypto::aws_lc_rs::default_provider;
        #[cfg(feature = "rustls-ring")]
        use rustls::crypto::ring::default_provider;
        use rustls::quic::Version;

        let dcid = ConnectionId::new(&hex!("06b858ec6f80452b"));
        let provider = default_provider();

        let suite = initial_suite_from_provider(&std::sync::Arc::new(provider)).unwrap();
        let client = initial_keys(Version::V1, dcid, Side::Client, &suite);

        // Encode: header, then 16 zero bytes of payload plus room for the AEAD tag.
        let mut buf = Vec::new();
        let header = Header::Initial(InitialHeader {
            number: PacketNumber::U8(0),
            src_cid: ConnectionId::new(&[]),
            dst_cid: dcid,
            token: Bytes::new(),
            version: crate::DEFAULT_SUPPORTED_VERSIONS[0],
        });
        let encode = header.encode(&mut buf);
        let header_len = buf.len();
        buf.resize(header_len + 16 + client.packet.local.tag_len(), 0);
        encode.finish(
            &mut buf,
            &*client.header.local,
            Some((0, &*client.packet.local)),
        );

        let dump: String = buf.iter().map(|byte| format!("{byte:02x}")).collect();
        println!("{dump}");
        assert_eq!(
            buf[..],
            hex!(
                "c8000000010806b858ec6f80452b00004021be
                 3ef50807b84191a196f760a6dad1e9d1c430c48952cba0148250c21c0a6a70e1"
            )[..]
        );

        // Decode with the server's view of the same keys.
        let server = initial_keys(Version::V1, dcid, Side::Server, &suite);
        let supported_versions = crate::DEFAULT_SUPPORTED_VERSIONS.to_vec();
        let (decode, _rest) = PartialDecode::new(
            buf.as_slice().into(),
            &FixedLengthConnectionIdParser::new(0),
            &supported_versions,
            false,
        )
        .unwrap();
        let mut packet = decode.finish(Some(&*server.header.remote)).unwrap();
        assert_eq!(
            packet.header_data[..],
            hex!("c0000000010806b858ec6f80452b0000402100")[..]
        );
        server
            .packet
            .remote
            .decrypt(0, &packet.header_data, &mut packet.payload)
            .unwrap();
        assert_eq!(packet.payload[..], [0; 16]);
        assert!(
            matches!(
                packet.header,
                Header::Initial(InitialHeader {
                    number: PacketNumber::U8(0),
                    ..
                })
            ),
            "unexpected header {:?}",
            packet.header
        );
    }
}