1use crate::{EncodeSize, Error, FixedSize, Read, Write};
33use bytes::{Buf, BufMut};
34use sealed::{SPrim, UPrim};
35use std::fmt::Debug;
36
/// Width of one byte in bits.
const BITS_PER_BYTE: usize = 8;

/// Payload bits carried by each encoded byte; the remaining high bit is the continuation flag.
const DATA_BITS_PER_BYTE: usize = 7;

/// High bit of an encoded byte; set on every byte except the last one of an encoding.
const CONTINUATION_BIT_MASK: u8 = 0x80;

/// Low seven payload bits of an encoded byte (the complement of the continuation bit).
const DATA_BITS_MASK: u8 = !CONTINUATION_BIT_MASK;
51
52mod sealed {
54 use super::*;
55 use std::ops::{BitOrAssign, Shl, ShrAssign};
56
57 pub trait UPrim:
59 Copy
60 + From<u8>
61 + Sized
62 + FixedSize
63 + ShrAssign<usize>
64 + Shl<usize, Output = Self>
65 + BitOrAssign<Self>
66 + PartialOrd
67 + Debug
68 {
69 fn leading_zeros(self) -> u32;
71
72 fn as_u8(self) -> u8;
74 }
75
76 macro_rules! impl_uint {
78 ($type:ty) => {
79 impl UPrim for $type {
80 #[inline(always)]
81 fn leading_zeros(self) -> u32 {
82 self.leading_zeros()
83 }
84
85 #[inline(always)]
86 fn as_u8(self) -> u8 {
87 self as u8
88 }
89 }
90 };
91 }
92 impl_uint!(u16);
93 impl_uint!(u32);
94 impl_uint!(u64);
95 impl_uint!(u128);
96
97 pub trait SPrim: Copy + Sized + FixedSize + PartialOrd + Debug {
104 type UnsignedEquivalent: UPrim;
107
108 #[doc(hidden)]
111 const _COMMIT_OP_ASSERT: () =
112 assert!(std::mem::size_of::<Self>() == std::mem::size_of::<Self::UnsignedEquivalent>());
113
114 fn as_zigzag(&self) -> Self::UnsignedEquivalent;
116
117 fn un_zigzag(value: Self::UnsignedEquivalent) -> Self;
119 }
120
121 macro_rules! impl_sint {
123 ($type:ty, $utype:ty) => {
124 impl SPrim for $type {
125 type UnsignedEquivalent = $utype;
126
127 #[inline]
128 fn as_zigzag(&self) -> $utype {
129 let shr = std::mem::size_of::<$utype>() * 8 - 1;
130 ((self << 1) ^ (self >> shr)) as $utype
131 }
132 #[inline]
133 fn un_zigzag(value: $utype) -> Self {
134 ((value >> 1) as $type) ^ (-((value & 1) as $type))
135 }
136 }
137 };
138 }
139 impl_sint!(i16, u16);
140 impl_sint!(i32, u32);
141 impl_sint!(i64, u64);
142 impl_sint!(i128, u128);
143}
144
145#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
150pub struct UInt<U: UPrim>(pub U);
151
152macro_rules! impl_varuint_into {
155 ($($type:ty),+) => {
156 $(
157 impl From<UInt<$type>> for $type {
158 fn from(val: UInt<$type>) -> Self {
159 val.0
160 }
161 }
162 )+
163 };
164}
165impl_varuint_into!(u16, u32, u64, u128);
166
167impl<U: UPrim> Write for UInt<U> {
168 fn write(&self, buf: &mut impl BufMut) {
169 write(self.0, buf);
170 }
171}
172
173impl<U: UPrim> Read for UInt<U> {
174 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
175 read(buf).map(UInt)
176 }
177}
178
179impl<U: UPrim> EncodeSize for UInt<U> {
180 fn encode_size(&self) -> usize {
181 size(self.0)
182 }
183}
184
185#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
188pub struct SInt<S: SPrim>(pub S);
189
190macro_rules! impl_varsint_into {
193 ($($type:ty),+) => {
194 $(
195 impl From<SInt<$type>> for $type {
196 fn from(val: SInt<$type>) -> Self {
197 val.0
198 }
199 }
200 )+
201 };
202}
203impl_varsint_into!(i16, i32, i64, i128);
204
205impl<S: SPrim> Write for SInt<S> {
206 fn write(&self, buf: &mut impl BufMut) {
207 write_signed::<S>(self.0, buf);
208 }
209}
210
211impl<S: SPrim> Read for SInt<S> {
212 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
213 read_signed::<S>(buf).map(SInt)
214 }
215}
216
217impl<S: SPrim> EncodeSize for SInt<S> {
218 fn encode_size(&self) -> usize {
219 size_signed::<S>(self.0)
220 }
221}
222
223fn write<T: UPrim>(value: T, buf: &mut impl BufMut) {
227 let continuation_threshold = T::from(CONTINUATION_BIT_MASK);
228 if value < continuation_threshold {
229 buf.put_u8(value.as_u8());
232 return;
233 }
234
235 let mut val = value;
236 while val >= continuation_threshold {
237 buf.put_u8((val.as_u8()) | CONTINUATION_BIT_MASK);
238 val >>= 7;
239 }
240 buf.put_u8(val.as_u8());
241}
242
243fn read<T: UPrim>(buf: &mut impl Buf) -> Result<T, Error> {
249 let max_bits = T::SIZE * BITS_PER_BYTE;
250 let mut result: T = T::from(0);
251 let mut bits_read = 0;
252
253 loop {
255 if !buf.has_remaining() {
257 return Err(Error::EndOfBuffer);
258 }
259 let byte = buf.get_u8();
260
261 if byte == 0 && bits_read > 0 {
266 return Err(Error::InvalidVarint(T::SIZE));
267 }
268
269 let remaining_bits = max_bits.checked_sub(bits_read).unwrap();
277 if remaining_bits <= DATA_BITS_PER_BYTE {
278 let relevant_bits = BITS_PER_BYTE - byte.leading_zeros() as usize;
279 if relevant_bits > remaining_bits {
280 return Err(Error::InvalidVarint(T::SIZE));
281 }
282 }
283
284 result |= T::from(byte & DATA_BITS_MASK) << bits_read;
286
287 if byte & CONTINUATION_BIT_MASK == 0 {
289 return Ok(result);
290 }
291
292 bits_read += DATA_BITS_PER_BYTE;
293 }
294}
295
296fn size<T: UPrim>(value: T) -> usize {
298 let total_bits = std::mem::size_of::<T>() * 8;
299 let leading_zeros = value.leading_zeros() as usize;
300 let data_bits = total_bits - leading_zeros;
301 usize::max(1, data_bits.div_ceil(DATA_BITS_PER_BYTE))
302}
303
304fn write_signed<S: SPrim>(value: S, buf: &mut impl BufMut) {
306 write(value.as_zigzag(), buf);
307}
308
309fn read_signed<S: SPrim>(buf: &mut impl Buf) -> Result<S, Error> {
311 Ok(S::un_zigzag(read(buf)?))
312}
313
314fn size_signed<S: SPrim>(value: S) -> usize {
316 size(value.as_zigzag())
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::Error, DecodeExt, Encode};
    use bytes::Bytes;

    #[test]
    fn test_end_of_buffer() {
        let mut buf: Bytes = Bytes::from_static(&[]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_overflow() {
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x0F]);
        assert_eq!(read::<u32>(&mut buf).unwrap(), u32::MAX);

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x1F]);
        assert!(matches!(
            read::<u32>(&mut buf),
            Err(Error::InvalidVarint(u32::SIZE))
        ));

        let mut buf =
            Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02]);
        assert!(matches!(
            read::<u64>(&mut buf),
            Err(Error::InvalidVarint(u64::SIZE))
        ));

        // u16 is 7 + 7 + 2 bits, so the third byte may carry at most two bits.
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0x03]);
        assert_eq!(read::<u16>(&mut buf).unwrap(), u16::MAX);

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0x07]);
        assert!(matches!(
            read::<u16>(&mut buf),
            Err(Error::InvalidVarint(u16::SIZE))
        ));
    }

    #[test]
    fn test_overcontinuation() {
        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80]);
        let result = read::<u32>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u32::SIZE))));
    }

    #[test]
    fn test_zeroed_byte() {
        let mut buf = Bytes::from_static(&[0xFF, 0x00]);
        let result = read::<u64>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u64::SIZE))));

        // A trailing zero group after a continuation byte is a non-canonical encoding of 0.
        let mut buf = Bytes::from_static(&[0x80, 0x00]);
        let result = read::<u32>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u32::SIZE))));
    }

    #[test]
    fn test_known_encodings() {
        // Pin the wire format: least-significant 7-bit group first, high bit = more follows.
        // Round-trip tests alone would not catch a symmetric encode/decode format change.
        let cases: &[(u64, &[u8])] = &[
            (0, &[0x00]),
            (1, &[0x01]),
            (127, &[0x7F]),
            (128, &[0x80, 0x01]),
            (300, &[0xAC, 0x02]),
            (16_384, &[0x80, 0x80, 0x01]),
        ];
        for &(value, expected) in cases {
            let mut buf = Vec::new();
            write(value, &mut buf);
            assert_eq!(buf, expected);
            assert_eq!(size(value), expected.len());
        }
    }

    #[test]
    fn test_zigzag_mapping() {
        let cases: &[(i32, u32)] = &[
            (0, 0),
            (-1, 1),
            (1, 2),
            (-2, 3),
            (2, 4),
            (i32::MAX, u32::MAX - 1),
            (i32::MIN, u32::MAX),
        ];
        for &(signed, unsigned) in cases {
            assert_eq!(signed.as_zigzag(), unsigned);
            assert_eq!(i32::un_zigzag(unsigned), signed);
        }
    }

    fn varuint_round_trip<T: Copy + UPrim + TryFrom<u128>>() {
        const CASES: &[u128] = &[
            0,
            1,
            127,
            128,
            129,
            0xFF,
            0x100,
            0x3FFF,
            0x4000,
            0x1_FFFF,
            0xFF_FFFF,
            0x1_FF_FF_FF_FF,
            0xFF_FF_FF_FF_FF_FF,
            0x1_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF,
            u16::MAX as u128,
            u32::MAX as u128,
            u64::MAX as u128,
            u128::MAX,
        ];

        for &raw in CASES {
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            let mut buf = Vec::new();
            write(value, &mut buf);
            assert_eq!(buf.len(), size(value));

            let mut slice = &buf[..];
            let decoded: T = read(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            let encoded = UInt(value).encode();
            assert_eq!(UInt::<T>::decode(encoded).unwrap(), UInt(value));
        }
    }

    #[test]
    fn test_varuint() {
        varuint_round_trip::<u16>();
        varuint_round_trip::<u32>();
        varuint_round_trip::<u64>();
        varuint_round_trip::<u128>();
    }

    fn varsint_round_trip<T: Copy + SPrim + TryFrom<i128>>() {
        const CASES: &[i128] = &[
            0,
            1,
            -1,
            2,
            -2,
            127,
            -127,
            128,
            -128,
            129,
            -129,
            0x7FFFFFFF,
            -0x7FFFFFFF,
            0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            -0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            i16::MIN as i128,
            i16::MAX as i128,
            i32::MIN as i128,
            i32::MAX as i128,
            i64::MIN as i128,
            i64::MAX as i128,
            i128::MIN,
            i128::MAX,
        ];

        for &raw in CASES {
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            let mut buf = Vec::new();
            write_signed(value, &mut buf);
            assert_eq!(buf.len(), size_signed(value));

            let mut slice = &buf[..];
            let decoded: T = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            let encoded = SInt(value).encode();
            assert_eq!(SInt::<T>::decode(encoded).unwrap(), SInt(value));
        }
    }

    #[test]
    fn test_varsint() {
        varsint_round_trip::<i16>();
        varsint_round_trip::<i32>();
        varsint_round_trip::<i64>();
        varsint_round_trip::<i128>();
    }

    #[test]
    fn test_varuint_into() {
        let v32: u32 = 0x1_FFFF;
        let out32: u32 = UInt(v32).into();
        assert_eq!(v32, out32);

        let v64: u64 = 0x1_FF_FF_FF_FF;
        let out64: u64 = UInt(v64).into();
        assert_eq!(v64, out64);
    }

    #[test]
    fn test_varsint_into() {
        let s32: i32 = -123_456;
        let out32: i32 = SInt(s32).into();
        assert_eq!(s32, out32);

        let s64: i64 = 987_654_321;
        let out64: i64 = SInt(s64).into();
        assert_eq!(s64, out64);
    }
}