1use crate::{EncodeSize, Error, FixedSize, Read, ReadExt, Write};
33use bytes::{Buf, BufMut};
34use core::{fmt::Debug, mem::size_of};
35use sealed::{SPrim, UPrim};
36
/// Number of bits in one byte.
const BITS_PER_BYTE: usize = 8;

/// Number of payload bits carried by each encoded byte; the remaining high bit
/// is the continuation flag.
const DATA_BITS_PER_BYTE: usize = 7;

/// Selects the seven payload bits of an encoded byte.
const DATA_BITS_MASK: u8 = 0b0111_1111;

/// High bit of an encoded byte; set when at least one more byte follows.
const CONTINUATION_BIT_MASK: u8 = 0b1000_0000;
51
52#[doc(hidden)]
55mod sealed {
56 use super::*;
57 use core::ops::{BitOrAssign, Shl, ShrAssign};
58
59 pub trait UPrim:
61 Copy
62 + From<u8>
63 + Sized
64 + FixedSize
65 + ShrAssign<usize>
66 + Shl<usize, Output = Self>
67 + BitOrAssign<Self>
68 + PartialOrd
69 + Debug
70 {
71 fn leading_zeros(self) -> u32;
73
74 fn as_u8(self) -> u8;
76 }
77
78 macro_rules! impl_uint {
80 ($type:ty) => {
81 impl UPrim for $type {
82 #[inline(always)]
83 fn leading_zeros(self) -> u32 {
84 self.leading_zeros()
85 }
86
87 #[inline(always)]
88 fn as_u8(self) -> u8 {
89 self as u8
90 }
91 }
92 };
93 }
94 impl_uint!(u16);
95 impl_uint!(u32);
96 impl_uint!(u64);
97 impl_uint!(u128);
98
99 pub trait SPrim: Copy + Sized + FixedSize + PartialOrd + Debug {
106 type UnsignedEquivalent: UPrim;
109
110 fn as_zigzag(&self) -> Self::UnsignedEquivalent;
112
113 fn un_zigzag(value: Self::UnsignedEquivalent) -> Self;
115 }
116
117 #[inline(always)]
119 const fn assert_equal_size<T: Sized, U: Sized>() {
120 assert!(
121 size_of::<T>() == size_of::<U>(),
122 "Unsigned integer must be the same size as the signed integer"
123 );
124 }
125
126 macro_rules! impl_sint {
128 ($type:ty, $utype:ty) => {
129 impl SPrim for $type {
130 type UnsignedEquivalent = $utype;
131
132 #[inline]
133 fn as_zigzag(&self) -> $utype {
134 const {
136 assert_equal_size::<$type, $utype>();
137 }
138
139 let shr = size_of::<$utype>() * 8 - 1;
140 ((self << 1) ^ (self >> shr)) as $utype
141 }
142 #[inline]
143 fn un_zigzag(value: $utype) -> Self {
144 const {
146 assert_equal_size::<$type, $utype>();
147 }
148
149 ((value >> 1) as $type) ^ (-((value & 1) as $type))
150 }
151 }
152 };
153 }
154 impl_sint!(i16, u16);
155 impl_sint!(i32, u32);
156 impl_sint!(i64, u64);
157 impl_sint!(i128, u128);
158}
159
160#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
165pub struct UInt<U: UPrim>(pub U);
166
167macro_rules! impl_varuint_into {
170 ($($type:ty),+) => {
171 $(
172 impl From<UInt<$type>> for $type {
173 fn from(val: UInt<$type>) -> Self {
174 val.0
175 }
176 }
177 )+
178 };
179}
180impl_varuint_into!(u16, u32, u64, u128);
181
182impl<U: UPrim> Write for UInt<U> {
183 fn write(&self, buf: &mut impl BufMut) {
184 write(self.0, buf);
185 }
186}
187
188impl<U: UPrim> Read for UInt<U> {
189 type Cfg = ();
190 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
191 read(buf).map(UInt)
192 }
193}
194
195impl<U: UPrim> EncodeSize for UInt<U> {
196 fn encode_size(&self) -> usize {
197 size(self.0)
198 }
199}
200
201#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
204pub struct SInt<S: SPrim>(pub S);
205
206macro_rules! impl_varsint_into {
209 ($($type:ty),+) => {
210 $(
211 impl From<SInt<$type>> for $type {
212 fn from(val: SInt<$type>) -> Self {
213 val.0
214 }
215 }
216 )+
217 };
218}
219impl_varsint_into!(i16, i32, i64, i128);
220
221impl<S: SPrim> Write for SInt<S> {
222 fn write(&self, buf: &mut impl BufMut) {
223 write_signed::<S>(self.0, buf);
224 }
225}
226
227impl<S: SPrim> Read for SInt<S> {
228 type Cfg = ();
229 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
230 read_signed::<S>(buf).map(SInt)
231 }
232}
233
234impl<S: SPrim> EncodeSize for SInt<S> {
235 fn encode_size(&self) -> usize {
236 size_signed::<S>(self.0)
237 }
238}
239
240fn write<T: UPrim>(value: T, buf: &mut impl BufMut) {
244 let continuation_threshold = T::from(CONTINUATION_BIT_MASK);
245 if value < continuation_threshold {
246 buf.put_u8(value.as_u8());
249 return;
250 }
251
252 let mut val = value;
253 while val >= continuation_threshold {
254 buf.put_u8((val.as_u8()) | CONTINUATION_BIT_MASK);
255 val >>= 7;
256 }
257 buf.put_u8(val.as_u8());
258}
259
260fn read<T: UPrim>(buf: &mut impl Buf) -> Result<T, Error> {
266 let max_bits = T::SIZE * BITS_PER_BYTE;
267 let mut result: T = T::from(0);
268 let mut bits_read = 0;
269
270 loop {
272 let byte = u8::read(buf)?;
274
275 if byte == 0 && bits_read > 0 {
280 return Err(Error::InvalidVarint(T::SIZE));
281 }
282
283 let remaining_bits = max_bits.checked_sub(bits_read).unwrap();
291 if remaining_bits <= DATA_BITS_PER_BYTE {
292 let relevant_bits = BITS_PER_BYTE - byte.leading_zeros() as usize;
293 if relevant_bits > remaining_bits {
294 return Err(Error::InvalidVarint(T::SIZE));
295 }
296 }
297
298 result |= T::from(byte & DATA_BITS_MASK) << bits_read;
300
301 if byte & CONTINUATION_BIT_MASK == 0 {
303 return Ok(result);
304 }
305
306 bits_read += DATA_BITS_PER_BYTE;
307 }
308}
309
310fn size<T: UPrim>(value: T) -> usize {
312 let total_bits = size_of::<T>() * 8;
313 let leading_zeros = value.leading_zeros() as usize;
314 let data_bits = total_bits - leading_zeros;
315 usize::max(1, data_bits.div_ceil(DATA_BITS_PER_BYTE))
316}
317
318fn write_signed<S: SPrim>(value: S, buf: &mut impl BufMut) {
320 write(value.as_zigzag(), buf);
321}
322
323fn read_signed<S: SPrim>(buf: &mut impl Buf) -> Result<S, Error> {
325 Ok(S::un_zigzag(read(buf)?))
326}
327
328fn size_signed<S: SPrim>(value: S) -> usize {
330 size(value.as_zigzag())
331}
332
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::Error, DecodeExt, Encode};
    #[cfg(not(feature = "std"))]
    use alloc::vec::Vec;
    use bytes::Bytes;

    #[test]
    fn test_end_of_buffer() {
        let mut buf: Bytes = Bytes::from_static(&[]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_overflow() {
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x0F]);
        assert_eq!(read::<u32>(&mut buf).unwrap(), u32::MAX);

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x1F]);
        assert!(matches!(
            read::<u32>(&mut buf),
            Err(Error::InvalidVarint(u32::SIZE))
        ));

        let mut buf =
            Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02]);
        assert!(matches!(
            read::<u64>(&mut buf),
            Err(Error::InvalidVarint(u64::SIZE))
        ));
    }

    #[test]
    fn test_u16_overflow() {
        // 16 bits = two full seven-bit groups plus a two-bit final group.
        let mut buf = Bytes::from_static(&[0xFF, 0xFF, 0x03]);
        assert_eq!(read::<u16>(&mut buf).unwrap(), u16::MAX);

        // A third bit in the final group no longer fits in a u16.
        let mut buf = Bytes::from_static(&[0xFF, 0xFF, 0x04]);
        assert!(matches!(
            read::<u16>(&mut buf),
            Err(Error::InvalidVarint(u16::SIZE))
        ));
    }

    #[test]
    fn test_overcontinuation() {
        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80]);
        let result = read::<u32>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u32::SIZE))));
    }

    #[test]
    fn test_zeroed_byte() {
        let mut buf = Bytes::from_static(&[0xFF, 0x00]);
        let result = read::<u64>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u64::SIZE))));
    }

    #[test]
    fn test_non_canonical_trailing_zero() {
        // A lone zero byte is the canonical encoding of 0.
        let mut buf = Bytes::from_static(&[0x00]);
        assert_eq!(read::<u32>(&mut buf).unwrap(), 0);

        // `0x80 0x00` is numerically 0 but not the shortest form.
        let mut buf = Bytes::from_static(&[0x80, 0x00]);
        assert!(matches!(
            read::<u32>(&mut buf),
            Err(Error::InvalidVarint(u32::SIZE))
        ));

        // A non-zero value padded with a redundant zero group is rejected too.
        let mut buf = Bytes::from_static(&[0x81, 0x80, 0x00]);
        assert!(matches!(
            read::<u64>(&mut buf),
            Err(Error::InvalidVarint(u64::SIZE))
        ));
    }

    fn varuint_round_trip<T: Copy + UPrim + TryFrom<u128>>() {
        const CASES: &[u128] = &[
            0,
            1,
            127,
            128,
            129,
            0xFF,
            0x100,
            0x3FFF,
            0x4000,
            0x1_FFFF,
            0xFF_FFFF,
            0x1_FF_FF_FF_FF,
            0xFF_FF_FF_FF_FF_FF,
            0x1_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF,
            u16::MAX as u128,
            u32::MAX as u128,
            u64::MAX as u128,
            u128::MAX,
        ];

        for &raw in CASES {
            // Skip cases that do not fit in `T`.
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            let mut buf = Vec::new();
            write(value, &mut buf);
            assert_eq!(buf.len(), size(value));

            let mut slice = &buf[..];
            let decoded: T = read(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            let encoded = UInt(value).encode();
            assert_eq!(UInt::<T>::decode(encoded).unwrap(), UInt(value));
        }
    }

    #[test]
    fn test_varuint() {
        varuint_round_trip::<u16>();
        varuint_round_trip::<u32>();
        varuint_round_trip::<u64>();
        varuint_round_trip::<u128>();
    }

    fn varsint_round_trip<T: Copy + SPrim + TryFrom<i128>>() {
        const CASES: &[i128] = &[
            0,
            1,
            -1,
            2,
            -2,
            127,
            -127,
            128,
            -128,
            129,
            -129,
            0x7FFFFFFF,
            -0x7FFFFFFF,
            0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            -0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            i16::MIN as i128,
            i16::MAX as i128,
            i32::MIN as i128,
            i32::MAX as i128,
            i64::MIN as i128,
            i64::MAX as i128,
            // Zigzags to u128::MAX, the longest possible encoding.
            i128::MIN,
            i128::MAX,
        ];

        for &raw in CASES {
            // Skip cases that do not fit in `T`.
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            let mut buf = Vec::new();
            write_signed(value, &mut buf);
            assert_eq!(buf.len(), size_signed(value));

            let mut slice = &buf[..];
            let decoded: T = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            let encoded = SInt(value).encode();
            assert_eq!(SInt::<T>::decode(encoded).unwrap(), SInt(value));
        }
    }

    #[test]
    fn test_varsint() {
        varsint_round_trip::<i16>();
        varsint_round_trip::<i32>();
        varsint_round_trip::<i64>();
        varsint_round_trip::<i128>();
    }

    #[test]
    fn test_zigzag_conformity() {
        // Zigzag interleaves signs: 0, -1, 1, -2, 2 map to 0, 1, 2, 3, 4.
        for (signed, unsigned) in [(0i32, 0u32), (-1, 1), (1, 2), (-2, 3), (2, 4)] {
            assert_eq!(signed.as_zigzag(), unsigned);
            assert_eq!(i32::un_zigzag(unsigned), signed);
        }
        assert_eq!(i32::MAX.as_zigzag(), u32::MAX - 1);
        assert_eq!(i32::MIN.as_zigzag(), u32::MAX);
        assert_eq!(i32::un_zigzag(u32::MAX), i32::MIN);

        // -64 and 63 are the largest magnitudes that still fit in one byte.
        let mut buf = Vec::new();
        write_signed(-64i32, &mut buf);
        assert_eq!(buf, [0x7F]);

        let mut buf = Vec::new();
        write_signed(63i32, &mut buf);
        assert_eq!(buf, [0x7E]);

        let mut buf = Vec::new();
        write_signed(64i32, &mut buf);
        assert_eq!(buf, [0x80, 0x01]);

        let mut buf = Vec::new();
        write_signed(-65i32, &mut buf);
        assert_eq!(buf, [0x81, 0x01]);
    }

    #[test]
    fn test_varuint_into() {
        let v32: u32 = 0x1_FFFF;
        let out32: u32 = UInt(v32).into();
        assert_eq!(v32, out32);

        let v64: u64 = 0x1_FF_FF_FF_FF;
        let out64: u64 = UInt(v64).into();
        assert_eq!(v64, out64);
    }

    #[test]
    fn test_varsint_into() {
        let s32: i32 = -123_456;
        let out32: i32 = SInt(s32).into();
        assert_eq!(s32, out32);

        let s64: i64 = 987_654_321;
        let out64: i64 = SInt(s64).into();
        assert_eq!(s64, out64);
    }

    #[test]
    fn test_conformity() {
        assert_eq!(0usize.encode(), &[0x00][..]);
        assert_eq!(1usize.encode(), &[0x01][..]);
        assert_eq!(127usize.encode(), &[0x7F][..]);
        assert_eq!(128usize.encode(), &[0x80, 0x01][..]);
        assert_eq!(16383usize.encode(), &[0xFF, 0x7F][..]);
        assert_eq!(16384usize.encode(), &[0x80, 0x80, 0x01][..]);
        assert_eq!(2097151usize.encode(), &[0xFF, 0xFF, 0x7F][..]);
        assert_eq!(2097152usize.encode(), &[0x80, 0x80, 0x80, 0x01][..]);
        assert_eq!(
            (u32::MAX as usize).encode(),
            &[0xFF, 0xFF, 0xFF, 0xFF, 0x0F][..]
        );
    }

    #[test]
    fn test_all_u16_values() {
        for i in 0..=u16::MAX {
            let value = i;
            let calculated_size = size(value);

            let mut buf = Vec::new();
            write(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for u16 value {value}",
            );

            let uint = UInt(value);
            assert_eq!(
                uint.encode_size(),
                buf.len(),
                "UInt encode_size mismatch for value {value}",
            );
        }
    }

    #[test]
    fn test_all_i16_values() {
        for i in i16::MIN..=i16::MAX {
            let value = i;
            let calculated_size = size_signed(value);

            let mut buf = Vec::new();
            write_signed(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for i16 value {value}",
            );

            let sint = SInt(value);
            assert_eq!(
                sint.encode_size(),
                buf.len(),
                "SInt encode_size mismatch for value {value}",
            );

            let mut slice = &buf[..];
            let decoded: i16 = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value, "Decode mismatch for value {value}");
            assert!(
                slice.is_empty(),
                "Buffer not fully consumed for value {value}",
            );
        }
    }

    #[test]
    fn test_exact_bit_boundaries() {
        fn test_exact_bits<T: UPrim + TryFrom<u128> + core::fmt::Display>() {
            for bits in 1..=128 {
                // All-ones value with exactly `bits` significant bits.
                let val = if bits == 128 {
                    u128::MAX
                } else {
                    (1u128 << bits) - 1
                };
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                let expected_size = (bits as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size calculation wrong for {val} with {bits} bits",
                );

                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for {val} with {bits} bits",
                );
            }
        }

        test_exact_bits::<u16>();
        test_exact_bits::<u32>();
        test_exact_bits::<u64>();
        test_exact_bits::<u128>();
    }

    #[test]
    fn test_single_bit_boundaries() {
        fn test_single_bits<T: UPrim + TryFrom<u128> + core::fmt::Display>() {
            for bit_pos in 0..128 {
                let val = 1u128 << bit_pos;
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                let expected_size = ((bit_pos + 1) as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size wrong for 1<<{bit_pos} = {val}",
                );

                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for 1<<{bit_pos} = {val}",
                );
            }
        }

        test_single_bits::<u16>();
        test_single_bits::<u32>();
        test_single_bits::<u64>();
        test_single_bits::<u128>();
    }
}