Skip to main content

moq_lite/coding/
varint.rs

1// Based on quinn-proto
2// https://github.com/quinn-rs/quinn/blob/main/quinn-proto/src/varint.rs
3// Licensed via Apache 2.0 and MIT
4
5use std::convert::{TryFrom, TryInto};
6use std::fmt;
7
8use thiserror::Error;
9
10use super::{Decode, DecodeError, Encode, EncodeError};
11
12/// The number is too large to fit in a VarInt (62 bits).
13#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
14#[error("value out of range")]
15pub struct BoundsExceeded;
16
/// A variable-length integer, normally less than 2^62.
///
/// Every checked constructor and conversion in this module keeps the value
/// below 2^62, which is the range the QUIC-style variable-length encoding can
/// carry. The draft-17 leading-ones decoder, however, accepts 9-byte encodings
/// of any `u64`, so a decoded `VarInt` may exceed [`VarInt::MAX`].
///
/// It would be neat if we could express to Rust that the top two bits are
/// available for use as enum discriminants.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
pub struct VarInt(u64);
25
26impl VarInt {
27	/// The largest possible value.
28	pub const MAX: Self = Self((1 << 62) - 1);
29
30	/// The smallest possible value.
31	pub const ZERO: Self = Self(0);
32
33	/// Construct a `VarInt` infallibly using the largest available type.
34	/// Larger values need to use `try_from` instead.
35	pub const fn from_u32(x: u32) -> Self {
36		Self(x as u64)
37	}
38
39	pub const fn from_u64(x: u64) -> Option<Self> {
40		if x <= Self::MAX.0 { Some(Self(x)) } else { None }
41	}
42
43	pub const fn from_u128(x: u128) -> Option<Self> {
44		if x <= Self::MAX.0 as u128 {
45			Some(Self(x as u64))
46		} else {
47			None
48		}
49	}
50
51	/// Extract the integer value
52	pub const fn into_inner(self) -> u64 {
53		self.0
54	}
55}
56
57impl From<VarInt> for u64 {
58	fn from(x: VarInt) -> Self {
59		x.0
60	}
61}
62
63impl From<VarInt> for usize {
64	fn from(x: VarInt) -> Self {
65		x.0 as usize
66	}
67}
68
69impl From<VarInt> for u128 {
70	fn from(x: VarInt) -> Self {
71		x.0 as u128
72	}
73}
74
75impl From<u8> for VarInt {
76	fn from(x: u8) -> Self {
77		Self(x.into())
78	}
79}
80
81impl From<u16> for VarInt {
82	fn from(x: u16) -> Self {
83		Self(x.into())
84	}
85}
86
87impl From<u32> for VarInt {
88	fn from(x: u32) -> Self {
89		Self(x.into())
90	}
91}
92
93impl TryFrom<u64> for VarInt {
94	type Error = BoundsExceeded;
95
96	/// Succeeds iff `x` < 2^62
97	fn try_from(x: u64) -> Result<Self, BoundsExceeded> {
98		let x = Self(x);
99		if x <= Self::MAX { Ok(x) } else { Err(BoundsExceeded) }
100	}
101}
102
103impl TryFrom<u128> for VarInt {
104	type Error = BoundsExceeded;
105
106	/// Succeeds iff `x` < 2^62
107	fn try_from(x: u128) -> Result<Self, BoundsExceeded> {
108		if x <= Self::MAX.into() {
109			Ok(Self(x as u64))
110		} else {
111			Err(BoundsExceeded)
112		}
113	}
114}
115
116impl TryFrom<usize> for VarInt {
117	type Error = BoundsExceeded;
118
119	/// Succeeds iff `x` < 2^62
120	fn try_from(x: usize) -> Result<Self, BoundsExceeded> {
121		Self::try_from(x as u64)
122	}
123}
124
125impl TryFrom<VarInt> for u32 {
126	type Error = BoundsExceeded;
127
128	/// Succeeds iff `x` < 2^32
129	fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
130		if x.0 <= u32::MAX.into() {
131			Ok(x.0 as u32)
132		} else {
133			Err(BoundsExceeded)
134		}
135	}
136}
137
138impl TryFrom<VarInt> for u16 {
139	type Error = BoundsExceeded;
140
141	/// Succeeds iff `x` < 2^16
142	fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
143		if x.0 <= u16::MAX.into() {
144			Ok(x.0 as u16)
145		} else {
146			Err(BoundsExceeded)
147		}
148	}
149}
150
151impl TryFrom<VarInt> for u8 {
152	type Error = BoundsExceeded;
153
154	/// Succeeds iff `x` < 2^8
155	fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
156		if x.0 <= u8::MAX.into() {
157			Ok(x.0 as u8)
158		} else {
159			Err(BoundsExceeded)
160		}
161	}
162}
163
164impl fmt::Display for VarInt {
165	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
166		self.0.fmt(f)
167	}
168}
169
170impl VarInt {
171	/// Decode a QUIC-style varint (2-bit length tag in top bits).
172	fn decode_quic<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
173		if !r.has_remaining() {
174			return Err(DecodeError::Short);
175		}
176
177		let b = r.get_u8();
178		let tag = b >> 6;
179
180		let mut buf = [0u8; 8];
181		buf[0] = b & 0b0011_1111;
182
183		let x = match tag {
184			0b00 => u64::from(buf[0]),
185			0b01 => {
186				if !r.has_remaining() {
187					return Err(DecodeError::Short);
188				}
189				r.copy_to_slice(buf[1..2].as_mut());
190				u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
191			}
192			0b10 => {
193				if r.remaining() < 3 {
194					return Err(DecodeError::Short);
195				}
196				r.copy_to_slice(buf[1..4].as_mut());
197				u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
198			}
199			0b11 => {
200				if r.remaining() < 7 {
201					return Err(DecodeError::Short);
202				}
203				r.copy_to_slice(buf[1..8].as_mut());
204				u64::from_be_bytes(buf)
205			}
206			_ => unreachable!(),
207		};
208
209		Ok(Self(x))
210	}
211
212	/// Encode a QUIC-style varint (2-bit length tag in top bits).
213	fn encode_quic<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
214		let remaining = w.remaining_mut();
215		if self.0 < (1u64 << 6) {
216			if remaining < 1 {
217				return Err(EncodeError::Short);
218			}
219			w.put_u8(self.0 as u8);
220		} else if self.0 < (1u64 << 14) {
221			if remaining < 2 {
222				return Err(EncodeError::Short);
223			}
224			w.put_u16((0b01 << 14) | self.0 as u16);
225		} else if self.0 < (1u64 << 30) {
226			if remaining < 4 {
227				return Err(EncodeError::Short);
228			}
229			w.put_u32((0b10 << 30) | self.0 as u32);
230		} else if self.0 < (1u64 << 62) {
231			if remaining < 8 {
232				return Err(EncodeError::Short);
233			}
234			w.put_u64((0b11 << 62) | self.0);
235		} else {
236			return Err(BoundsExceeded.into());
237		}
238		Ok(())
239	}
240
241	/// Decode a leading-1-bits varint (draft-17 Section 1.4.1).
242	///
243	/// The number of leading 1-bits determines the byte length:
244	/// - `0xxxxxxx` → 1 byte, 7 usable bits
245	/// - `10xxxxxx` → 2 bytes, 14 usable bits
246	/// - `110xxxxx` → 3 bytes, 21 usable bits
247	/// - `1110xxxx` → 4 bytes, 28 usable bits
248	/// - `11110xxx` → 5 bytes, 35 usable bits
249	/// - `111110xx` → 6 bytes, 42 usable bits
250	/// - `11111110` → 8 bytes, 56 usable bits (skips 7)
251	/// - `11111111` → 9 bytes, 64 usable bits
252	fn decode_leading_ones<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
253		if !r.has_remaining() {
254			return Err(DecodeError::Short);
255		}
256
257		let b = r.get_u8();
258		let ones = b.leading_ones() as usize;
259
260		match ones {
261			0 => {
262				// 0xxxxxxx — 7 bits
263				Ok(Self(u64::from(b)))
264			}
265			1 => {
266				// 10xxxxxx + 1 byte — 14 bits
267				if !r.has_remaining() {
268					return Err(DecodeError::Short);
269				}
270				let hi = u64::from(b & 0x3F);
271				let lo = u64::from(r.get_u8());
272				Ok(Self((hi << 8) | lo))
273			}
274			2 => {
275				// 110xxxxx + 2 bytes — 21 bits
276				if r.remaining() < 2 {
277					return Err(DecodeError::Short);
278				}
279				let hi = u64::from(b & 0x1F);
280				let mut buf = [0u8; 2];
281				r.copy_to_slice(&mut buf);
282				Ok(Self((hi << 16) | u64::from(u16::from_be_bytes(buf))))
283			}
284			3 => {
285				// 1110xxxx + 3 bytes — 28 bits
286				if r.remaining() < 3 {
287					return Err(DecodeError::Short);
288				}
289				let hi = u64::from(b & 0x0F);
290				let mut buf = [0u8; 3];
291				r.copy_to_slice(&mut buf);
292				Ok(Self(
293					(hi << 24) | u64::from(buf[0]) << 16 | u64::from(buf[1]) << 8 | u64::from(buf[2]),
294				))
295			}
296			4 => {
297				// 11110xxx + 4 bytes — 35 bits
298				if r.remaining() < 4 {
299					return Err(DecodeError::Short);
300				}
301				let hi = u64::from(b & 0x07);
302				let mut buf = [0u8; 4];
303				r.copy_to_slice(&mut buf);
304				Ok(Self((hi << 32) | u64::from(u32::from_be_bytes(buf))))
305			}
306			5 => {
307				// 111110xx + 5 bytes — 42 bits
308				if r.remaining() < 5 {
309					return Err(DecodeError::Short);
310				}
311				let hi = u64::from(b & 0x03);
312				let mut buf = [0u8; 5];
313				r.copy_to_slice(&mut buf);
314				let lo = u64::from(buf[0]) << 32
315					| u64::from(buf[1]) << 24
316					| u64::from(buf[2]) << 16
317					| u64::from(buf[3]) << 8
318					| u64::from(buf[4]);
319				Ok(Self((hi << 40) | lo))
320			}
321			6 => {
322				// 1111110x — INVALID per draft-17
323				Err(DecodeError::InvalidValue)?
324			}
325			7 => {
326				// 11111110 + 7 bytes — 56 bits
327				if r.remaining() < 7 {
328					return Err(DecodeError::Short);
329				}
330				let mut buf = [0u8; 8];
331				buf[0] = 0;
332				r.copy_to_slice(&mut buf[1..]);
333				Ok(Self(u64::from_be_bytes(buf)))
334			}
335			8 => {
336				// 11111111 + 8 bytes — 64 bits
337				if r.remaining() < 8 {
338					return Err(DecodeError::Short);
339				}
340				let mut buf = [0u8; 8];
341				r.copy_to_slice(&mut buf);
342				Ok(Self(u64::from_be_bytes(buf)))
343			}
344			_ => unreachable!(),
345		}
346	}
347
348	/// Encode a leading-1-bits varint (draft-17 Section 1.4.1).
349	fn encode_leading_ones<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
350		let x = self.0;
351		let remaining = w.remaining_mut();
352
353		if x < (1 << 7) {
354			// 0xxxxxxx — 1 byte
355			if remaining < 1 {
356				return Err(EncodeError::Short);
357			}
358			w.put_u8(x as u8);
359		} else if x < (1 << 14) {
360			// 10xxxxxx — 2 bytes
361			if remaining < 2 {
362				return Err(EncodeError::Short);
363			}
364			w.put_u8(0x80 | (x >> 8) as u8);
365			w.put_u8(x as u8);
366		} else if x < (1 << 21) {
367			// 110xxxxx — 3 bytes
368			if remaining < 3 {
369				return Err(EncodeError::Short);
370			}
371			w.put_u8(0xC0 | (x >> 16) as u8);
372			w.put_u16(x as u16);
373		} else if x < (1 << 28) {
374			// 1110xxxx — 4 bytes
375			if remaining < 4 {
376				return Err(EncodeError::Short);
377			}
378			w.put_u8(0xE0 | (x >> 24) as u8);
379			w.put_u8((x >> 16) as u8);
380			w.put_u16(x as u16);
381		} else if x < (1 << 35) {
382			// 11110xxx — 5 bytes
383			if remaining < 5 {
384				return Err(EncodeError::Short);
385			}
386			w.put_u8(0xF0 | (x >> 32) as u8);
387			w.put_u32(x as u32);
388		} else if x < (1 << 42) {
389			// 111110xx — 6 bytes
390			if remaining < 6 {
391				return Err(EncodeError::Short);
392			}
393			w.put_u8(0xF8 | (x >> 40) as u8);
394			w.put_u8((x >> 32) as u8);
395			w.put_u32(x as u32);
396		} else if x < (1 << 56) {
397			// 11111110 — 8 bytes (skips 7)
398			if remaining < 8 {
399				return Err(EncodeError::Short);
400			}
401			w.put_u8(0xFE);
402			// Write 7 bytes: high byte then low 6 bytes
403			w.put_u8((x >> 48) as u8);
404			w.put_u16((x >> 32) as u16);
405			w.put_u32(x as u32);
406		} else {
407			// 11111111 — 9 bytes
408			if remaining < 9 {
409				return Err(EncodeError::Short);
410			}
411			w.put_u8(0xFF);
412			w.put_u64(x);
413		}
414
415		Ok(())
416	}
417}
418
419use crate::{Version, ietf, lite};
420
421// All lite versions use QUIC-style varint encoding.
422impl Encode<lite::Version> for VarInt {
423	fn encode<W: bytes::BufMut>(&self, w: &mut W, _: lite::Version) -> Result<(), EncodeError> {
424		self.encode_quic(w)
425	}
426}
427
428impl Decode<lite::Version> for VarInt {
429	fn decode<R: bytes::Buf>(r: &mut R, _: lite::Version) -> Result<Self, DecodeError> {
430		Self::decode_quic(r)
431	}
432}
433
434// IETF versions use QUIC-style except Draft17 which uses leading-ones.
435impl Encode<ietf::Version> for VarInt {
436	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: ietf::Version) -> Result<(), EncodeError> {
437		match version {
438			ietf::Version::Draft14 | ietf::Version::Draft15 | ietf::Version::Draft16 => self.encode_quic(w),
439			ietf::Version::Draft17 => self.encode_leading_ones(w),
440		}
441	}
442}
443
444impl Decode<ietf::Version> for VarInt {
445	fn decode<R: bytes::Buf>(r: &mut R, version: ietf::Version) -> Result<Self, DecodeError> {
446		match version {
447			ietf::Version::Draft14 | ietf::Version::Draft15 | ietf::Version::Draft16 => Self::decode_quic(r),
448			ietf::Version::Draft17 => Self::decode_leading_ones(r),
449		}
450	}
451}
452
453// The top-level Version delegates to the sub-version impls.
454impl Encode<Version> for VarInt {
455	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: Version) -> Result<(), EncodeError> {
456		match version {
457			Version::Lite(v) => self.encode(w, v),
458			Version::Ietf(v) => self.encode(w, v),
459		}
460	}
461}
462
463impl Decode<Version> for VarInt {
464	fn decode<R: bytes::Buf>(r: &mut R, version: Version) -> Result<Self, DecodeError> {
465		match version {
466			Version::Lite(v) => Self::decode(r, v),
467			Version::Ietf(v) => Self::decode(r, v),
468		}
469	}
470}
471
472// Blanket impls for integer types that delegate to VarInt.
473impl<V: Copy> Encode<V> for u64
474where
475	VarInt: Encode<V>,
476{
477	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
478		VarInt::try_from(*self)?.encode(w, version)
479	}
480}
481
482impl<V: Copy> Decode<V> for u64
483where
484	VarInt: Decode<V>,
485{
486	fn decode<R: bytes::Buf>(r: &mut R, version: V) -> Result<Self, DecodeError> {
487		VarInt::decode(r, version).map(|v| v.into_inner())
488	}
489}
490
491impl<V: Copy> Encode<V> for usize
492where
493	VarInt: Encode<V>,
494{
495	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
496		VarInt::try_from(*self)?.encode(w, version)
497	}
498}
499
500impl<V: Copy> Decode<V> for usize
501where
502	VarInt: Decode<V>,
503{
504	fn decode<R: bytes::Buf>(r: &mut R, version: V) -> Result<Self, DecodeError> {
505		VarInt::decode(r, version).map(|v| v.into_inner() as usize)
506	}
507}
508
509impl<V: Copy> Encode<V> for u32
510where
511	VarInt: Encode<V>,
512{
513	fn encode<W: bytes::BufMut>(&self, w: &mut W, version: V) -> Result<(), EncodeError> {
514		VarInt::from(*self).encode(w, version)
515	}
516}
517
518impl<V: Copy> Decode<V> for u32
519where
520	VarInt: Decode<V>,
521{
522	fn decode<R: bytes::Buf>(r: &mut R, version: V) -> Result<Self, DecodeError> {
523		let v = VarInt::decode(r, version)?;
524		let v = v.try_into().map_err(|_| DecodeError::BoundsExceeded)?;
525		Ok(v)
526	}
527}
528
#[cfg(test)]
mod tests {
	use super::{DecodeError, VarInt};
	use bytes::Bytes;

	/// Checks the example encodings from the draft-17 spec (Table 2: Example Integer
	/// Encodings), except example 4 (0xdd7f3e7d), which is wrong in the spec.
	#[test]
	fn leading_ones_spec_examples() {
		// (wire bytes, decoded value)
		let vectors: &[(&[u8], u64)] = &[
			(&[0x25], 37),
			(&[0x80, 0x25], 37),
			(&[0xbb, 0xbd], 15_293),
			// Example 4 (0xdd7f3e7d = 494,878,333) is omitted — the spec has a bug.
			// See https://github.com/moq-wg/moq-transport/pull/1521
			(&[0xfa, 0xa1, 0xa0, 0xe4, 0x03, 0xd8], 2_893_212_287_960),
			(&[0xfe, 0xfa, 0x31, 0x8f, 0xa8, 0xe3, 0xca, 0x11], 70_423_237_261_249_041),
			(&[0xff; 9], u64::MAX),
		];

		for &(bytes, expected) in vectors {
			let mut buf = Bytes::copy_from_slice(bytes);
			let decoded = VarInt::decode_leading_ones(&mut buf).expect("decode should succeed");
			assert_eq!(decoded.into_inner(), expected, "decode mismatch for bytes {bytes:02x?}");
			assert_eq!(buf.len(), 0, "all bytes should be consumed for {bytes:02x?}");

			// Re-encode and compare, except for:
			// - the non-minimal two-byte encoding of 37 (0x8025),
			// - u64::MAX, which decodes fine but is above VarInt::MAX (2^62-1).
			let is_minimal = bytes.len() == 1 || expected != 37;
			if let Some(varint) = VarInt::from_u64(expected).filter(|_| is_minimal) {
				let mut encoded = Vec::new();
				varint.encode_leading_ones(&mut encoded).expect("encode should succeed");
				assert_eq!(&encoded, bytes, "encode mismatch for value {expected}");
			}
		}
	}

	/// A first byte of 11111100 (0xFC) has no defined length and must be rejected.
	#[test]
	fn leading_ones_invalid_0xfc() {
		let mut buf = Bytes::from_static(&[0xFC]);
		let result = VarInt::decode_leading_ones(&mut buf);
		assert!(
			matches!(result, Err(DecodeError::InvalidValue)),
			"0xFC should be rejected as invalid"
		);
	}

	/// Values on either side of a length change encode to the expected size and round-trip.
	#[test]
	fn leading_ones_boundaries_round_trip() {
		// (value, expected encoded length in bytes)
		let cases: [(u64, usize); 6] = [
			((1 << 7) - 1, 1),
			(1 << 7, 2),
			((1 << 14) - 1, 2),
			(1 << 14, 3),
			((1 << 56) - 1, 8),
			(1 << 56, 9),
		];

		for (value, expected_len) in cases {
			let mut encoded = Vec::new();
			VarInt::from_u64(value)
				.expect("value should be representable as VarInt")
				.encode_leading_ones(&mut encoded)
				.expect("leading-ones encode should succeed");
			assert_eq!(
				encoded.len(),
				expected_len,
				"unexpected encoded length for value {value}"
			);

			let mut bytes = Bytes::from(encoded);
			let decoded = VarInt::decode_leading_ones(&mut bytes).expect("leading-ones decode should succeed");
			assert_eq!(decoded.into_inner(), value, "round-trip mismatch for value {value}");
		}
	}
}