1use std::convert::{TryFrom, TryInto};
6use std::fmt;
7
8use thiserror::Error;
9
10use super::{Decode, DecodeError, Encode, EncodeError};
11
12#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
14#[error("value out of range")]
15pub struct BoundsExceeded;
16
/// A variable-length integer.
///
/// The checked constructors and conversions only admit values in
/// `0..=VarInt::MAX` (`2^62 - 1`). The leading-ones decoder in this module does
/// not range-check, so a decoded value may exceed [`VarInt::MAX`].
#[derive(Debug, Default, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct VarInt(u64);

impl VarInt {
    /// Largest value accepted by the checked constructors: `2^62 - 1`.
    pub const MAX: Self = Self((1 << 62) - 1);

    /// The value zero.
    pub const ZERO: Self = Self(0);

    /// Builds a `VarInt` from a `u32`. Every `u32` is in range, so this never fails.
    pub const fn from_u32(x: u32) -> Self {
        let widened = x as u64;
        Self(widened)
    }

    /// Returns `Some` when `x <= VarInt::MAX`, otherwise `None`.
    pub const fn from_u64(x: u64) -> Option<Self> {
        if x > Self::MAX.0 {
            return None;
        }
        Some(Self(x))
    }

    /// Returns `Some` when `x <= VarInt::MAX`, otherwise `None`.
    pub const fn from_u128(x: u128) -> Option<Self> {
        let limit = Self::MAX.0 as u128;
        if x > limit {
            return None;
        }
        // In range, so the narrowing cast is lossless.
        Some(Self(x as u64))
    }

    /// Returns the raw integer value.
    pub const fn into_inner(self) -> u64 {
        let Self(raw) = self;
        raw
    }
}
56
57impl From<VarInt> for u64 {
58 fn from(x: VarInt) -> Self {
59 x.0
60 }
61}
62
63impl From<VarInt> for usize {
64 fn from(x: VarInt) -> Self {
65 x.0 as usize
66 }
67}
68
69impl From<VarInt> for u128 {
70 fn from(x: VarInt) -> Self {
71 x.0 as u128
72 }
73}
74
75impl From<u8> for VarInt {
76 fn from(x: u8) -> Self {
77 Self(x.into())
78 }
79}
80
81impl From<u16> for VarInt {
82 fn from(x: u16) -> Self {
83 Self(x.into())
84 }
85}
86
87impl From<u32> for VarInt {
88 fn from(x: u32) -> Self {
89 Self(x.into())
90 }
91}
92
93impl TryFrom<u64> for VarInt {
94 type Error = BoundsExceeded;
95
96 fn try_from(x: u64) -> Result<Self, BoundsExceeded> {
98 let x = Self(x);
99 if x <= Self::MAX { Ok(x) } else { Err(BoundsExceeded) }
100 }
101}
102
103impl TryFrom<u128> for VarInt {
104 type Error = BoundsExceeded;
105
106 fn try_from(x: u128) -> Result<Self, BoundsExceeded> {
108 if x <= Self::MAX.into() {
109 Ok(Self(x as u64))
110 } else {
111 Err(BoundsExceeded)
112 }
113 }
114}
115
116impl TryFrom<usize> for VarInt {
117 type Error = BoundsExceeded;
118
119 fn try_from(x: usize) -> Result<Self, BoundsExceeded> {
121 Self::try_from(x as u64)
122 }
123}
124
125impl TryFrom<VarInt> for u32 {
126 type Error = BoundsExceeded;
127
128 fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
130 if x.0 <= u32::MAX.into() {
131 Ok(x.0 as u32)
132 } else {
133 Err(BoundsExceeded)
134 }
135 }
136}
137
138impl TryFrom<VarInt> for u16 {
139 type Error = BoundsExceeded;
140
141 fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
143 if x.0 <= u16::MAX.into() {
144 Ok(x.0 as u16)
145 } else {
146 Err(BoundsExceeded)
147 }
148 }
149}
150
151impl TryFrom<VarInt> for u8 {
152 type Error = BoundsExceeded;
153
154 fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
156 if x.0 <= u8::MAX.into() {
157 Ok(x.0 as u8)
158 } else {
159 Err(BoundsExceeded)
160 }
161 }
162}
163
164impl fmt::Display for VarInt {
165 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166 self.0.fmt(f)
167 }
168}
169
170impl VarInt {
171 fn decode_quic<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
173 if !r.has_remaining() {
174 return Err(DecodeError::Short);
175 }
176
177 let b = r.get_u8();
178 let tag = b >> 6;
179
180 let mut buf = [0u8; 8];
181 buf[0] = b & 0b0011_1111;
182
183 let x = match tag {
184 0b00 => u64::from(buf[0]),
185 0b01 => {
186 if !r.has_remaining() {
187 return Err(DecodeError::Short);
188 }
189 r.copy_to_slice(buf[1..2].as_mut());
190 u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
191 }
192 0b10 => {
193 if r.remaining() < 3 {
194 return Err(DecodeError::Short);
195 }
196 r.copy_to_slice(buf[1..4].as_mut());
197 u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
198 }
199 0b11 => {
200 if r.remaining() < 7 {
201 return Err(DecodeError::Short);
202 }
203 r.copy_to_slice(buf[1..8].as_mut());
204 u64::from_be_bytes(buf)
205 }
206 _ => unreachable!(),
207 };
208
209 Ok(Self(x))
210 }
211
212 fn encode_quic<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
214 let remaining = w.remaining_mut();
215 if self.0 < (1u64 << 6) {
216 if remaining < 1 {
217 return Err(EncodeError::Short);
218 }
219 w.put_u8(self.0 as u8);
220 } else if self.0 < (1u64 << 14) {
221 if remaining < 2 {
222 return Err(EncodeError::Short);
223 }
224 w.put_u16((0b01 << 14) | self.0 as u16);
225 } else if self.0 < (1u64 << 30) {
226 if remaining < 4 {
227 return Err(EncodeError::Short);
228 }
229 w.put_u32((0b10 << 30) | self.0 as u32);
230 } else if self.0 < (1u64 << 62) {
231 if remaining < 8 {
232 return Err(EncodeError::Short);
233 }
234 w.put_u64((0b11 << 62) | self.0);
235 } else {
236 return Err(BoundsExceeded.into());
237 }
238 Ok(())
239 }
240
241 fn decode_leading_ones<R: bytes::Buf>(r: &mut R, version: ietf::Version) -> Result<Self, DecodeError> {
254 if !r.has_remaining() {
255 return Err(DecodeError::Short);
256 }
257
258 let b = r.get_u8();
259 let ones = b.leading_ones() as usize;
260
261 match ones {
262 0 => {
263 Ok(Self(u64::from(b)))
265 }
266 1 => {
267 if !r.has_remaining() {
269 return Err(DecodeError::Short);
270 }
271 let hi = u64::from(b & 0x3F);
272 let lo = u64::from(r.get_u8());
273 Ok(Self((hi << 8) | lo))
274 }
275 2 => {
276 if r.remaining() < 2 {
278 return Err(DecodeError::Short);
279 }
280 let hi = u64::from(b & 0x1F);
281 let mut buf = [0u8; 2];
282 r.copy_to_slice(&mut buf);
283 Ok(Self((hi << 16) | u64::from(u16::from_be_bytes(buf))))
284 }
285 3 => {
286 if r.remaining() < 3 {
288 return Err(DecodeError::Short);
289 }
290 let hi = u64::from(b & 0x0F);
291 let mut buf = [0u8; 3];
292 r.copy_to_slice(&mut buf);
293 Ok(Self(
294 (hi << 24) | u64::from(buf[0]) << 16 | u64::from(buf[1]) << 8 | u64::from(buf[2]),
295 ))
296 }
297 4 => {
298 if r.remaining() < 4 {
300 return Err(DecodeError::Short);
301 }
302 let hi = u64::from(b & 0x07);
303 let mut buf = [0u8; 4];
304 r.copy_to_slice(&mut buf);
305 Ok(Self((hi << 32) | u64::from(u32::from_be_bytes(buf))))
306 }
307 5 => {
308 if r.remaining() < 5 {
310 return Err(DecodeError::Short);
311 }
312 let hi = u64::from(b & 0x03);
313 let mut buf = [0u8; 5];
314 r.copy_to_slice(&mut buf);
315 let lo = u64::from(buf[0]) << 32
316 | u64::from(buf[1]) << 24
317 | u64::from(buf[2]) << 16
318 | u64::from(buf[3]) << 8
319 | u64::from(buf[4]);
320 Ok(Self((hi << 40) | lo))
321 }
322 6 => {
323 if matches!(version, ietf::Version::Draft17) {
325 return Err(DecodeError::InvalidValue);
326 }
327 if r.remaining() < 6 {
328 return Err(DecodeError::Short);
329 }
330 let hi = u64::from(b & 0x01);
331 let mut buf = [0u8; 8];
332 r.copy_to_slice(&mut buf[2..]);
333 Ok(Self((hi << 48) | u64::from_be_bytes(buf)))
334 }
335 7 => {
336 if r.remaining() < 7 {
338 return Err(DecodeError::Short);
339 }
340 let mut buf = [0u8; 8];
341 buf[0] = 0;
342 r.copy_to_slice(&mut buf[1..]);
343 Ok(Self(u64::from_be_bytes(buf)))
344 }
345 8 => {
346 if r.remaining() < 8 {
348 return Err(DecodeError::Short);
349 }
350 let mut buf = [0u8; 8];
351 r.copy_to_slice(&mut buf);
352 Ok(Self(u64::from_be_bytes(buf)))
353 }
354 _ => unreachable!(),
355 }
356 }
357
358 fn encode_leading_ones<W: bytes::BufMut>(&self, w: &mut W, _version: ietf::Version) -> Result<(), EncodeError> {
364 let x = self.0;
365 let remaining = w.remaining_mut();
366
367 if x < (1 << 7) {
368 if remaining < 1 {
370 return Err(EncodeError::Short);
371 }
372 w.put_u8(x as u8);
373 } else if x < (1 << 14) {
374 if remaining < 2 {
376 return Err(EncodeError::Short);
377 }
378 w.put_u8(0x80 | (x >> 8) as u8);
379 w.put_u8(x as u8);
380 } else if x < (1 << 21) {
381 if remaining < 3 {
383 return Err(EncodeError::Short);
384 }
385 w.put_u8(0xC0 | (x >> 16) as u8);
386 w.put_u16(x as u16);
387 } else if x < (1 << 28) {
388 if remaining < 4 {
390 return Err(EncodeError::Short);
391 }
392 w.put_u8(0xE0 | (x >> 24) as u8);
393 w.put_u8((x >> 16) as u8);
394 w.put_u16(x as u16);
395 } else if x < (1 << 35) {
396 if remaining < 5 {
398 return Err(EncodeError::Short);
399 }
400 w.put_u8(0xF0 | (x >> 32) as u8);
401 w.put_u32(x as u32);
402 } else if x < (1 << 42) {
403 if remaining < 6 {
405 return Err(EncodeError::Short);
406 }
407 w.put_u8(0xF8 | (x >> 40) as u8);
408 w.put_u8((x >> 32) as u8);
409 w.put_u32(x as u32);
410 } else if x < (1 << 56) {
411 if remaining < 8 {
413 return Err(EncodeError::Short);
414 }
415 w.put_u8(0xFE);
416 w.put_u8((x >> 48) as u8);
418 w.put_u16((x >> 32) as u16);
419 w.put_u32(x as u32);
420 } else {
421 if remaining < 9 {
423 return Err(EncodeError::Short);
424 }
425 w.put_u8(0xFF);
426 w.put_u64(x);
427 }
428
429 Ok(())
430 }
431}
432
433use crate::{Version, ietf, lite};
434
435impl Encode<lite::Version> for VarInt {
437 fn encode<W: bytes::BufMut>(&self, w: &mut W, _: lite::Version) -> Result<(), EncodeError> {
438 self.encode_quic(w)
439 }
440}
441
442impl Decode<lite::Version> for VarInt {
443 fn decode<R: bytes::Buf>(r: &mut R, _: lite::Version) -> Result<Self, DecodeError> {
444 Self::decode_quic(r)
445 }
446}
447
448impl Encode<ietf::Version> for VarInt {
450 fn encode<W: bytes::BufMut>(&self, w: &mut W, version: ietf::Version) -> Result<(), EncodeError> {
451 match version {
452 ietf::Version::Draft14 | ietf::Version::Draft15 | ietf::Version::Draft16 => self.encode_quic(w),
453 _ => self.encode_leading_ones(w, version),
454 }
455 }
456}
457
458impl Decode<ietf::Version> for VarInt {
459 fn decode<R: bytes::Buf>(r: &mut R, version: ietf::Version) -> Result<Self, DecodeError> {
460 match version {
461 ietf::Version::Draft14 | ietf::Version::Draft15 | ietf::Version::Draft16 => Self::decode_quic(r),
462 _ => Self::decode_leading_ones(r, version),
463 }
464 }
465}
466
467impl Encode<Version> for VarInt {
469 fn encode<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
470 match version {
471 Version::Lite(v) => self.encode(w, v),
472 Version::Ietf(v) => self.encode(w, v),
473 }
474 }
475}
476
477impl Decode<Version> for VarInt {
478 fn decode<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
479 match version {
480 Version::Lite(v) => Self::decode(r, v),
481 Version::Ietf(v) => Self::decode(r, v),
482 }
483 }
484}
485
486impl<V: Copy> Encode<V> for u64
488where
489 VarInt: Encode<V>,
490{
491 fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
492 VarInt::try_from(*self)?.encode(w, version)
493 }
494}
495
496impl<V: Copy> Decode<V> for u64
497where
498 VarInt: Decode<V>,
499{
500 fn decode<R: bytes::Buf>(r: &mut R, version: V) -> Result<Self, DecodeError> {
501 VarInt::decode(r, version).map(|v| v.into_inner())
502 }
503}
504
505impl<V: Copy> Encode<V> for usize
506where
507 VarInt: Encode<V>,
508{
509 fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
510 VarInt::try_from(*self)?.encode(w, version)
511 }
512}
513
514impl<V: Copy> Decode<V> for usize
515where
516 VarInt: Decode<V>,
517{
518 fn decode<R: bytes::Buf>(r: &mut R, version: V) -> Result<Self, DecodeError> {
519 VarInt::decode(r, version).map(|v| v.into_inner() as usize)
520 }
521}
522
523impl<V: Copy> Encode<V> for u32
524where
525 VarInt: Encode<V>,
526{
527 fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
528 VarInt::from(*self).encode(w, version)
529 }
530}
531
532impl<V: Copy> Decode<V> for u32
533where
534 VarInt: Decode<V>,
535{
536 fn decode<R: bytes::Buf>(r: &mut R, version: V) -> Result<Self, DecodeError> {
537 let v = VarInt::decode(r, version)?;
538 let v = v.try_into().map_err(|_| DecodeError::BoundsExceeded)?;
539 Ok(v)
540 }
541}
542
543#[cfg(test)]
544mod tests {
545 use super::{DecodeError, VarInt};
546 use crate::ietf;
547 use bytes::Bytes;
548
549 #[test]
552 fn leading_ones_spec_examples() {
553 let cases: &[(&[u8], u64)] = &[
554 (&[0x25], 37),
555 (&[0x80, 0x25], 37),
556 (&[0xbb, 0xbd], 15_293),
557 (&[0xfa, 0xa1, 0xa0, 0xe4, 0x03, 0xd8], 2_893_212_287_960),
560 (
561 &[0xfe, 0xfa, 0x31, 0x8f, 0xa8, 0xe3, 0xca, 0x11],
562 70_423_237_261_249_041,
563 ),
564 (
565 &[0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff],
566 18_446_744_073_709_551_615,
567 ),
568 ];
569
570 for (bytes, expected) in cases {
571 let mut buf = Bytes::from(bytes.to_vec());
573 let decoded = VarInt::decode_leading_ones(&mut buf, ietf::Version::Draft17).expect("decode should succeed");
574 assert_eq!(
575 decoded.into_inner(),
576 *expected,
577 "decode mismatch for bytes {bytes:02x?}"
578 );
579 assert_eq!(buf.len(), 0, "all bytes should be consumed for {bytes:02x?}");
580
581 if let Some(varint) = VarInt::from_u64(*expected)
585 && (bytes.len() == 1 || *expected != 37)
586 {
587 let mut encoded = Vec::new();
588 varint
589 .encode_leading_ones(&mut encoded, ietf::Version::Draft17)
590 .expect("encode should succeed");
591 assert_eq!(&encoded, bytes, "encode mismatch for value {expected}");
592 }
593 }
594 }
595
596 #[test]
598 fn leading_ones_invalid_0xfc() {
599 let mut buf = Bytes::from_static(&[0xFC]);
600 assert!(
601 matches!(
602 VarInt::decode_leading_ones(&mut buf, ietf::Version::Draft17),
603 Err(DecodeError::InvalidValue)
604 ),
605 "0xFC should be rejected as invalid on draft-17"
606 );
607 }
608
609 #[test]
610 fn leading_ones_boundaries_round_trip() {
611 let cases = [
612 ((1u64 << 7) - 1, 1usize),
613 (1u64 << 7, 2usize),
614 ((1u64 << 14) - 1, 2usize),
615 (1u64 << 14, 3usize),
616 ((1u64 << 56) - 1, 8usize),
617 (1u64 << 56, 9usize),
618 ];
619
620 for (value, expected_len) in cases {
621 let varint = VarInt::from_u64(value).expect("value should be representable as VarInt");
622 let mut encoded = Vec::new();
623 varint
624 .encode_leading_ones(&mut encoded, ietf::Version::Draft17)
625 .expect("leading-ones encode should succeed");
626 assert_eq!(
627 encoded.len(),
628 expected_len,
629 "unexpected encoded length for value {value}"
630 );
631
632 let mut bytes = Bytes::from(encoded);
633 let decoded = VarInt::decode_leading_ones(&mut bytes, ietf::Version::Draft17)
634 .expect("leading-ones decode should succeed");
635 assert_eq!(decoded.into_inner(), value, "round-trip mismatch for value {value}");
636 }
637 }
638
639 #[test]
640 fn draft17_rejects_7_byte_varint() {
641 let bytes = Bytes::from(vec![0xFC, 0, 0, 0, 0, 0, 0]);
643 let mut buf = bytes.clone();
644 let err = VarInt::decode_leading_ones(&mut buf, ietf::Version::Draft17).unwrap_err();
645 assert!(matches!(err, DecodeError::InvalidValue));
646 }
647
648 #[test]
649 fn draft18_accepts_7_byte_varint() {
650 let value: u64 = 0x1234_5678_9ABC;
652 let mut bytes = Vec::new();
653 let hi_bit = ((value >> 48) & 0x01) as u8;
656 bytes.push(0xFC | hi_bit);
657 for shift in (0..48).step_by(8).rev() {
658 bytes.push(((value >> shift) & 0xFF) as u8);
659 }
660 let mut buf = Bytes::from(bytes);
661 let decoded = VarInt::decode_leading_ones(&mut buf, ietf::Version::Draft18).unwrap();
662 assert_eq!(decoded.into_inner(), value);
663 }
664}