1use std::convert::{TryFrom, TryInto};
6use std::fmt;
7
8use thiserror::Error;
9
10use super::{Decode, DecodeError, Encode, EncodeError};
11
12#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
14#[error("value out of range")]
15pub struct BoundsExceeded;
16
/// An unsigned integer carried on the wire in a variable-length encoding.
///
/// The value is stored as a `u64`. The checked constructors (`from_u64`, `from_u128` and the
/// `TryFrom` impls) only accept values up to [`VarInt::MAX`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct VarInt(u64);
25
26impl VarInt {
27 pub const MAX: Self = Self((1 << 62) - 1);
29
30 pub const ZERO: Self = Self(0);
32
33 pub const fn from_u32(x: u32) -> Self {
36 Self(x as u64)
37 }
38
39 pub const fn from_u64(x: u64) -> Option<Self> {
40 if x <= Self::MAX.0 { Some(Self(x)) } else { None }
41 }
42
43 pub const fn from_u128(x: u128) -> Option<Self> {
44 if x <= Self::MAX.0 as u128 {
45 Some(Self(x as u64))
46 } else {
47 None
48 }
49 }
50
51 pub const fn into_inner(self) -> u64 {
53 self.0
54 }
55}
56
57impl From<VarInt> for u64 {
58 fn from(x: VarInt) -> Self {
59 x.0
60 }
61}
62
63impl From<VarInt> for usize {
64 fn from(x: VarInt) -> Self {
65 x.0 as usize
66 }
67}
68
69impl From<VarInt> for u128 {
70 fn from(x: VarInt) -> Self {
71 x.0 as u128
72 }
73}
74
75impl From<u8> for VarInt {
76 fn from(x: u8) -> Self {
77 Self(x.into())
78 }
79}
80
81impl From<u16> for VarInt {
82 fn from(x: u16) -> Self {
83 Self(x.into())
84 }
85}
86
87impl From<u32> for VarInt {
88 fn from(x: u32) -> Self {
89 Self(x.into())
90 }
91}
92
93impl TryFrom<u64> for VarInt {
94 type Error = BoundsExceeded;
95
96 fn try_from(x: u64) -> Result<Self, BoundsExceeded> {
98 let x = Self(x);
99 if x <= Self::MAX { Ok(x) } else { Err(BoundsExceeded) }
100 }
101}
102
103impl TryFrom<u128> for VarInt {
104 type Error = BoundsExceeded;
105
106 fn try_from(x: u128) -> Result<Self, BoundsExceeded> {
108 if x <= Self::MAX.into() {
109 Ok(Self(x as u64))
110 } else {
111 Err(BoundsExceeded)
112 }
113 }
114}
115
116impl TryFrom<usize> for VarInt {
117 type Error = BoundsExceeded;
118
119 fn try_from(x: usize) -> Result<Self, BoundsExceeded> {
121 Self::try_from(x as u64)
122 }
123}
124
125impl TryFrom<VarInt> for u32 {
126 type Error = BoundsExceeded;
127
128 fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
130 if x.0 <= u32::MAX.into() {
131 Ok(x.0 as u32)
132 } else {
133 Err(BoundsExceeded)
134 }
135 }
136}
137
138impl TryFrom<VarInt> for u16 {
139 type Error = BoundsExceeded;
140
141 fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
143 if x.0 <= u16::MAX.into() {
144 Ok(x.0 as u16)
145 } else {
146 Err(BoundsExceeded)
147 }
148 }
149}
150
151impl TryFrom<VarInt> for u8 {
152 type Error = BoundsExceeded;
153
154 fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
156 if x.0 <= u8::MAX.into() {
157 Ok(x.0 as u8)
158 } else {
159 Err(BoundsExceeded)
160 }
161 }
162}
163
164impl fmt::Display for VarInt {
165 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166 self.0.fmt(f)
167 }
168}
169
170impl VarInt {
171 fn decode_quic<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
173 if !r.has_remaining() {
174 return Err(DecodeError::Short);
175 }
176
177 let b = r.get_u8();
178 let tag = b >> 6;
179
180 let mut buf = [0u8; 8];
181 buf[0] = b & 0b0011_1111;
182
183 let x = match tag {
184 0b00 => u64::from(buf[0]),
185 0b01 => {
186 if !r.has_remaining() {
187 return Err(DecodeError::Short);
188 }
189 r.copy_to_slice(buf[1..2].as_mut());
190 u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
191 }
192 0b10 => {
193 if r.remaining() < 3 {
194 return Err(DecodeError::Short);
195 }
196 r.copy_to_slice(buf[1..4].as_mut());
197 u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
198 }
199 0b11 => {
200 if r.remaining() < 7 {
201 return Err(DecodeError::Short);
202 }
203 r.copy_to_slice(buf[1..8].as_mut());
204 u64::from_be_bytes(buf)
205 }
206 _ => unreachable!(),
207 };
208
209 Ok(Self(x))
210 }
211
212 fn encode_quic<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
214 let remaining = w.remaining_mut();
215 if self.0 < (1u64 << 6) {
216 if remaining < 1 {
217 return Err(EncodeError::Short);
218 }
219 w.put_u8(self.0 as u8);
220 } else if self.0 < (1u64 << 14) {
221 if remaining < 2 {
222 return Err(EncodeError::Short);
223 }
224 w.put_u16((0b01 << 14) | self.0 as u16);
225 } else if self.0 < (1u64 << 30) {
226 if remaining < 4 {
227 return Err(EncodeError::Short);
228 }
229 w.put_u32((0b10 << 30) | self.0 as u32);
230 } else if self.0 < (1u64 << 62) {
231 if remaining < 8 {
232 return Err(EncodeError::Short);
233 }
234 w.put_u64((0b11 << 62) | self.0);
235 } else {
236 return Err(BoundsExceeded.into());
237 }
238 Ok(())
239 }
240
241 fn decode_leading_ones<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
253 if !r.has_remaining() {
254 return Err(DecodeError::Short);
255 }
256
257 let b = r.get_u8();
258 let ones = b.leading_ones() as usize;
259
260 match ones {
261 0 => {
262 Ok(Self(u64::from(b)))
264 }
265 1 => {
266 if !r.has_remaining() {
268 return Err(DecodeError::Short);
269 }
270 let hi = u64::from(b & 0x3F);
271 let lo = u64::from(r.get_u8());
272 Ok(Self((hi << 8) | lo))
273 }
274 2 => {
275 if r.remaining() < 2 {
277 return Err(DecodeError::Short);
278 }
279 let hi = u64::from(b & 0x1F);
280 let mut buf = [0u8; 2];
281 r.copy_to_slice(&mut buf);
282 Ok(Self((hi << 16) | u64::from(u16::from_be_bytes(buf))))
283 }
284 3 => {
285 if r.remaining() < 3 {
287 return Err(DecodeError::Short);
288 }
289 let hi = u64::from(b & 0x0F);
290 let mut buf = [0u8; 3];
291 r.copy_to_slice(&mut buf);
292 Ok(Self(
293 (hi << 24) | u64::from(buf[0]) << 16 | u64::from(buf[1]) << 8 | u64::from(buf[2]),
294 ))
295 }
296 4 => {
297 if r.remaining() < 4 {
299 return Err(DecodeError::Short);
300 }
301 let hi = u64::from(b & 0x07);
302 let mut buf = [0u8; 4];
303 r.copy_to_slice(&mut buf);
304 Ok(Self((hi << 32) | u64::from(u32::from_be_bytes(buf))))
305 }
306 5 => {
307 if r.remaining() < 5 {
309 return Err(DecodeError::Short);
310 }
311 let hi = u64::from(b & 0x03);
312 let mut buf = [0u8; 5];
313 r.copy_to_slice(&mut buf);
314 let lo = u64::from(buf[0]) << 32
315 | u64::from(buf[1]) << 24
316 | u64::from(buf[2]) << 16
317 | u64::from(buf[3]) << 8
318 | u64::from(buf[4]);
319 Ok(Self((hi << 40) | lo))
320 }
321 6 => {
322 Err(DecodeError::InvalidValue)?
324 }
325 7 => {
326 if r.remaining() < 7 {
328 return Err(DecodeError::Short);
329 }
330 let mut buf = [0u8; 8];
331 buf[0] = 0;
332 r.copy_to_slice(&mut buf[1..]);
333 Ok(Self(u64::from_be_bytes(buf)))
334 }
335 8 => {
336 if r.remaining() < 8 {
338 return Err(DecodeError::Short);
339 }
340 let mut buf = [0u8; 8];
341 r.copy_to_slice(&mut buf);
342 Ok(Self(u64::from_be_bytes(buf)))
343 }
344 _ => unreachable!(),
345 }
346 }
347
348 fn encode_leading_ones<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
350 let x = self.0;
351 let remaining = w.remaining_mut();
352
353 if x < (1 << 7) {
354 if remaining < 1 {
356 return Err(EncodeError::Short);
357 }
358 w.put_u8(x as u8);
359 } else if x < (1 << 14) {
360 if remaining < 2 {
362 return Err(EncodeError::Short);
363 }
364 w.put_u8(0x80 | (x >> 8) as u8);
365 w.put_u8(x as u8);
366 } else if x < (1 << 21) {
367 if remaining < 3 {
369 return Err(EncodeError::Short);
370 }
371 w.put_u8(0xC0 | (x >> 16) as u8);
372 w.put_u16(x as u16);
373 } else if x < (1 << 28) {
374 if remaining < 4 {
376 return Err(EncodeError::Short);
377 }
378 w.put_u8(0xE0 | (x >> 24) as u8);
379 w.put_u8((x >> 16) as u8);
380 w.put_u16(x as u16);
381 } else if x < (1 << 35) {
382 if remaining < 5 {
384 return Err(EncodeError::Short);
385 }
386 w.put_u8(0xF0 | (x >> 32) as u8);
387 w.put_u32(x as u32);
388 } else if x < (1 << 42) {
389 if remaining < 6 {
391 return Err(EncodeError::Short);
392 }
393 w.put_u8(0xF8 | (x >> 40) as u8);
394 w.put_u8((x >> 32) as u8);
395 w.put_u32(x as u32);
396 } else if x < (1 << 56) {
397 if remaining < 8 {
399 return Err(EncodeError::Short);
400 }
401 w.put_u8(0xFE);
402 w.put_u8((x >> 48) as u8);
404 w.put_u16((x >> 32) as u16);
405 w.put_u32(x as u32);
406 } else {
407 if remaining < 9 {
409 return Err(EncodeError::Short);
410 }
411 w.put_u8(0xFF);
412 w.put_u64(x);
413 }
414
415 Ok(())
416 }
417}
418
419use crate::{Version, ietf, lite};
420
421impl Encode<lite::Version> for VarInt {
423 fn encode<W: bytes::BufMut>(&self, w: &mut W, _: lite::Version) -> Result<(), EncodeError> {
424 self.encode_quic(w)
425 }
426}
427
428impl Decode<lite::Version> for VarInt {
429 fn decode<R: bytes::Buf>(r: &mut R, _: lite::Version) -> Result<Self, DecodeError> {
430 Self::decode_quic(r)
431 }
432}
433
434impl Encode<ietf::Version> for VarInt {
436 fn encode<W: bytes::BufMut>(&self, w: &mut W, version: ietf::Version) -> Result<(), EncodeError> {
437 match version {
438 ietf::Version::Draft14 | ietf::Version::Draft15 | ietf::Version::Draft16 => self.encode_quic(w),
439 ietf::Version::Draft17 => self.encode_leading_ones(w),
440 }
441 }
442}
443
444impl Decode<ietf::Version> for VarInt {
445 fn decode<R: bytes::Buf>(r: &mut R, version: ietf::Version) -> Result<Self, DecodeError> {
446 match version {
447 ietf::Version::Draft14 | ietf::Version::Draft15 | ietf::Version::Draft16 => Self::decode_quic(r),
448 ietf::Version::Draft17 => Self::decode_leading_ones(r),
449 }
450 }
451}
452
453impl Encode<Version> for VarInt {
455 fn encode<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
456 match version {
457 Version::Lite(v) => self.encode(w, v),
458 Version::Ietf(v) => self.encode(w, v),
459 }
460 }
461}
462
463impl Decode<Version> for VarInt {
464 fn decode<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
465 match version {
466 Version::Lite(v) => Self::decode(r, v),
467 Version::Ietf(v) => Self::decode(r, v),
468 }
469 }
470}
471
472impl<V: Copy> Encode<V> for u64
474where
475 VarInt: Encode<V>,
476{
477 fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
478 VarInt::try_from(*self)?.encode(w, version)
479 }
480}
481
482impl<V: Copy> Decode<V> for u64
483where
484 VarInt: Decode<V>,
485{
486 fn decode<R: bytes::Buf>(r: &mut R, version: V) -> Result<Self, DecodeError> {
487 VarInt::decode(r, version).map(|v| v.into_inner())
488 }
489}
490
491impl<V: Copy> Encode<V> for usize
492where
493 VarInt: Encode<V>,
494{
495 fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
496 VarInt::try_from(*self)?.encode(w, version)
497 }
498}
499
500impl<V: Copy> Decode<V> for usize
501where
502 VarInt: Decode<V>,
503{
504 fn decode<R: bytes::Buf>(r: &mut R, version: V) -> Result<Self, DecodeError> {
505 VarInt::decode(r, version).map(|v| v.into_inner() as usize)
506 }
507}
508
509impl<V: Copy> Encode<V> for u32
510where
511 VarInt: Encode<V>,
512{
513 fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
514 VarInt::from(*self).encode(w, version)
515 }
516}
517
518impl<V: Copy> Decode<V> for u32
519where
520 VarInt: Decode<V>,
521{
522 fn decode<R: bytes::Buf>(r: &mut R, version: V) -> Result<Self, DecodeError> {
523 let v = VarInt::decode(r, version)?;
524 let v = v.try_into().map_err(|_| DecodeError::BoundsExceeded)?;
525 Ok(v)
526 }
527}
528
#[cfg(test)]
mod tests {
    use super::{DecodeError, EncodeError, VarInt};
    use bytes::Bytes;

    /// Worked examples for the leading-ones encoding: each byte string must decode to the
    /// expected value and be fully consumed, and minimal encodings must re-encode byte-for-byte.
    #[test]
    fn leading_ones_spec_examples() {
        let cases: &[(&[u8], u64)] = &[
            (&[0x25], 37),
            (&[0x80, 0x25], 37),
            (&[0xbb, 0xbd], 15_293),
            (&[0xfa, 0xa1, 0xa0, 0xe4, 0x03, 0xd8], 2_893_212_287_960),
            (
                &[0xfe, 0xfa, 0x31, 0x8f, 0xa8, 0xe3, 0xca, 0x11],
                70_423_237_261_249_041,
            ),
            (
                &[0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff],
                18_446_744_073_709_551_615,
            ),
        ];

        for (bytes, expected) in cases {
            let mut buf = Bytes::from(bytes.to_vec());
            let decoded = VarInt::decode_leading_ones(&mut buf).expect("decode should succeed");
            assert_eq!(
                decoded.into_inner(),
                *expected,
                "decode mismatch for bytes {bytes:02x?}"
            );
            assert_eq!(buf.len(), 0, "all bytes should be consumed for {bytes:02x?}");

            // `[0x80, 0x25]` is a non-minimal two-byte encoding of 37; the encoder always picks
            // the shortest form, so only minimal encodings are expected to round-trip.
            // Re-encoding the decoded value (rather than going through `VarInt::from_u64`)
            // also exercises the 9-byte form for values above `VarInt::MAX`.
            let is_minimal = bytes.len() == 1 || *expected != 37;
            if is_minimal {
                let mut encoded = Vec::new();
                decoded.encode_leading_ones(&mut encoded).expect("encode should succeed");
                assert_eq!(&encoded, bytes, "encode mismatch for value {expected}");
            }
        }
    }

    /// A first byte with exactly six leading ones is not a valid prefix.
    #[test]
    fn leading_ones_invalid_0xfc() {
        for prefix in [0xFCu8, 0xFD] {
            let mut buf = Bytes::copy_from_slice(&[prefix]);
            assert!(
                matches!(VarInt::decode_leading_ones(&mut buf), Err(DecodeError::InvalidValue)),
                "{prefix:#04x} should be rejected as invalid"
            );
        }
    }

    /// Every length boundary of the leading-ones encoding round-trips with the expected size.
    #[test]
    fn leading_ones_boundaries_round_trip() {
        let cases = [
            ((1u64 << 7) - 1, 1usize),
            (1u64 << 7, 2usize),
            ((1u64 << 14) - 1, 2usize),
            (1u64 << 14, 3usize),
            ((1u64 << 21) - 1, 3usize),
            (1u64 << 21, 4usize),
            ((1u64 << 28) - 1, 4usize),
            (1u64 << 28, 5usize),
            ((1u64 << 35) - 1, 5usize),
            (1u64 << 35, 6usize),
            ((1u64 << 42) - 1, 6usize),
            (1u64 << 42, 8usize),
            ((1u64 << 56) - 1, 8usize),
            (1u64 << 56, 9usize),
        ];

        for (value, expected_len) in cases {
            let varint = VarInt::from_u64(value).expect("value should be representable as VarInt");
            let mut encoded = Vec::new();
            varint
                .encode_leading_ones(&mut encoded)
                .expect("leading-ones encode should succeed");
            assert_eq!(
                encoded.len(),
                expected_len,
                "unexpected encoded length for value {value}"
            );

            let mut bytes = Bytes::from(encoded);
            let decoded = VarInt::decode_leading_ones(&mut bytes).expect("leading-ones decode should succeed");
            assert_eq!(decoded.into_inner(), value, "round-trip mismatch for value {value}");
            assert_eq!(bytes.len(), 0, "all bytes should be consumed for value {value}");
        }
    }

    /// Truncated leading-ones input is reported as short, not as a wrong value or a panic.
    #[test]
    fn leading_ones_truncated_input_is_short() {
        let cases: &[&[u8]] = &[&[], &[0x80], &[0xC0, 0x00], &[0xFE, 0x00, 0x00], &[0xFF; 8]];
        for bytes in cases {
            let mut buf = Bytes::copy_from_slice(bytes);
            assert!(
                matches!(VarInt::decode_leading_ones(&mut buf), Err(DecodeError::Short)),
                "truncated input {bytes:02x?} should be reported as short"
            );
        }
    }

    /// Every length boundary of the QUIC encoding round-trips with the expected size.
    #[test]
    fn quic_boundaries_round_trip() {
        let cases = [
            (0u64, 1usize),
            ((1u64 << 6) - 1, 1usize),
            (1u64 << 6, 2usize),
            ((1u64 << 14) - 1, 2usize),
            (1u64 << 14, 4usize),
            ((1u64 << 30) - 1, 4usize),
            (1u64 << 30, 8usize),
            (VarInt::MAX.into_inner(), 8usize),
        ];

        for (value, expected_len) in cases {
            let varint = VarInt::from_u64(value).expect("value should be representable as VarInt");
            let mut encoded = Vec::new();
            varint.encode_quic(&mut encoded).expect("QUIC encode should succeed");
            assert_eq!(
                encoded.len(),
                expected_len,
                "unexpected encoded length for value {value}"
            );

            let mut bytes = Bytes::from(encoded);
            let decoded = VarInt::decode_quic(&mut bytes).expect("QUIC decode should succeed");
            assert_eq!(decoded.into_inner(), value, "round-trip mismatch for value {value}");
            assert_eq!(bytes.len(), 0, "all bytes should be consumed for value {value}");
        }
    }

    /// The QUIC encoding cannot represent `2^62` or more; nothing should be written.
    #[test]
    fn quic_rejects_out_of_range() {
        let mut encoded = Vec::<u8>::new();
        assert!(VarInt(1 << 62).encode_quic(&mut encoded).is_err());
        assert!(encoded.is_empty(), "nothing should be written on failure");
    }

    /// Truncated QUIC input is reported as short.
    #[test]
    fn quic_truncated_input_is_short() {
        let cases: &[&[u8]] = &[&[], &[0x40], &[0x80, 0x00, 0x00], &[0xC0, 0, 0, 0, 0, 0, 0]];
        for bytes in cases {
            let mut buf = Bytes::copy_from_slice(bytes);
            assert!(
                matches!(VarInt::decode_quic(&mut buf), Err(DecodeError::Short)),
                "truncated input {bytes:02x?} should be reported as short"
            );
        }
    }

    /// Both encoders report a too-small output buffer instead of writing a partial value.
    #[test]
    fn encode_into_small_buffer_is_short() {
        let mut storage = [0u8; 1];

        let mut out = &mut storage[..];
        assert!(matches!(
            VarInt::from_u32(300).encode_quic(&mut out),
            Err(EncodeError::Short)
        ));

        let mut out = &mut storage[..];
        assert!(matches!(
            VarInt::from_u32(300).encode_leading_ones(&mut out),
            Err(EncodeError::Short)
        ));
    }
}