1use crate::{EncodeSize, Error, FixedSize, Read, ReadExt, Write};
33use bytes::{Buf, BufMut};
34use core::{fmt::Debug, mem::size_of};
35use sealed::{SPrim, UPrim};
36
/// Number of bits in one byte.
const BITS_PER_BYTE: usize = 8;

/// Payload bits carried by each encoded byte: every bit except the continuation flag.
const DATA_BITS_PER_BYTE: usize = BITS_PER_BYTE - 1;

/// Selects the payload (low) bits of an encoded byte.
const DATA_BITS_MASK: u8 = (1u8 << DATA_BITS_PER_BYTE) - 1;

/// Top bit of an encoded byte; when set, at least one more byte follows.
const CONTINUATION_BIT_MASK: u8 = 1u8 << DATA_BITS_PER_BYTE;
51
52#[doc(hidden)]
55mod sealed {
56 use super::*;
57 use core::ops::{BitOrAssign, Shl, ShrAssign};
58
59 pub trait UPrim:
61 Copy
62 + From<u8>
63 + Sized
64 + FixedSize
65 + ShrAssign<usize>
66 + Shl<usize, Output = Self>
67 + BitOrAssign<Self>
68 + PartialOrd
69 + Debug
70 {
71 fn leading_zeros(self) -> u32;
73
74 fn as_u8(self) -> u8;
76 }
77
78 macro_rules! impl_uint {
80 ($type:ty) => {
81 impl UPrim for $type {
82 #[inline(always)]
83 fn leading_zeros(self) -> u32 {
84 self.leading_zeros()
85 }
86
87 #[inline(always)]
88 fn as_u8(self) -> u8 {
89 self as u8
90 }
91 }
92 };
93 }
94 impl_uint!(u16);
95 impl_uint!(u32);
96 impl_uint!(u64);
97 impl_uint!(u128);
98
99 pub trait SPrim: Copy + Sized + FixedSize + PartialOrd + Debug {
106 type UnsignedEquivalent: UPrim;
109
110 #[doc(hidden)]
113 const _COMMIT_OP_ASSERT: () =
114 assert!(size_of::<Self>() == size_of::<Self::UnsignedEquivalent>());
115
116 fn as_zigzag(&self) -> Self::UnsignedEquivalent;
118
119 fn un_zigzag(value: Self::UnsignedEquivalent) -> Self;
121 }
122
123 macro_rules! impl_sint {
125 ($type:ty, $utype:ty) => {
126 impl SPrim for $type {
127 type UnsignedEquivalent = $utype;
128
129 #[inline]
130 fn as_zigzag(&self) -> $utype {
131 let shr = size_of::<$utype>() * 8 - 1;
132 ((self << 1) ^ (self >> shr)) as $utype
133 }
134 #[inline]
135 fn un_zigzag(value: $utype) -> Self {
136 ((value >> 1) as $type) ^ (-((value & 1) as $type))
137 }
138 }
139 };
140 }
141 impl_sint!(i16, u16);
142 impl_sint!(i32, u32);
143 impl_sint!(i64, u64);
144 impl_sint!(i128, u128);
145}
146
147#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
152pub struct UInt<U: UPrim>(pub U);
153
154macro_rules! impl_varuint_into {
157 ($($type:ty),+) => {
158 $(
159 impl From<UInt<$type>> for $type {
160 fn from(val: UInt<$type>) -> Self {
161 val.0
162 }
163 }
164 )+
165 };
166}
167impl_varuint_into!(u16, u32, u64, u128);
168
169impl<U: UPrim> Write for UInt<U> {
170 fn write(&self, buf: &mut impl BufMut) {
171 write(self.0, buf);
172 }
173}
174
175impl<U: UPrim> Read for UInt<U> {
176 type Cfg = ();
177 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
178 read(buf).map(UInt)
179 }
180}
181
182impl<U: UPrim> EncodeSize for UInt<U> {
183 fn encode_size(&self) -> usize {
184 size(self.0)
185 }
186}
187
188#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
191pub struct SInt<S: SPrim>(pub S);
192
193macro_rules! impl_varsint_into {
196 ($($type:ty),+) => {
197 $(
198 impl From<SInt<$type>> for $type {
199 fn from(val: SInt<$type>) -> Self {
200 val.0
201 }
202 }
203 )+
204 };
205}
206impl_varsint_into!(i16, i32, i64, i128);
207
208impl<S: SPrim> Write for SInt<S> {
209 fn write(&self, buf: &mut impl BufMut) {
210 write_signed::<S>(self.0, buf);
211 }
212}
213
214impl<S: SPrim> Read for SInt<S> {
215 type Cfg = ();
216 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
217 read_signed::<S>(buf).map(SInt)
218 }
219}
220
221impl<S: SPrim> EncodeSize for SInt<S> {
222 fn encode_size(&self) -> usize {
223 size_signed::<S>(self.0)
224 }
225}
226
227fn write<T: UPrim>(value: T, buf: &mut impl BufMut) {
231 let continuation_threshold = T::from(CONTINUATION_BIT_MASK);
232 if value < continuation_threshold {
233 buf.put_u8(value.as_u8());
236 return;
237 }
238
239 let mut val = value;
240 while val >= continuation_threshold {
241 buf.put_u8((val.as_u8()) | CONTINUATION_BIT_MASK);
242 val >>= 7;
243 }
244 buf.put_u8(val.as_u8());
245}
246
247fn read<T: UPrim>(buf: &mut impl Buf) -> Result<T, Error> {
253 let max_bits = T::SIZE * BITS_PER_BYTE;
254 let mut result: T = T::from(0);
255 let mut bits_read = 0;
256
257 loop {
259 let byte = u8::read(buf)?;
261
262 if byte == 0 && bits_read > 0 {
267 return Err(Error::InvalidVarint(T::SIZE));
268 }
269
270 let remaining_bits = max_bits.checked_sub(bits_read).unwrap();
278 if remaining_bits <= DATA_BITS_PER_BYTE {
279 let relevant_bits = BITS_PER_BYTE - byte.leading_zeros() as usize;
280 if relevant_bits > remaining_bits {
281 return Err(Error::InvalidVarint(T::SIZE));
282 }
283 }
284
285 result |= T::from(byte & DATA_BITS_MASK) << bits_read;
287
288 if byte & CONTINUATION_BIT_MASK == 0 {
290 return Ok(result);
291 }
292
293 bits_read += DATA_BITS_PER_BYTE;
294 }
295}
296
297fn size<T: UPrim>(value: T) -> usize {
299 let total_bits = size_of::<T>() * 8;
300 let leading_zeros = value.leading_zeros() as usize;
301 let data_bits = total_bits - leading_zeros;
302 usize::max(1, data_bits.div_ceil(DATA_BITS_PER_BYTE))
303}
304
305fn write_signed<S: SPrim>(value: S, buf: &mut impl BufMut) {
307 write(value.as_zigzag(), buf);
308}
309
310fn read_signed<S: SPrim>(buf: &mut impl Buf) -> Result<S, Error> {
312 Ok(S::un_zigzag(read(buf)?))
313}
314
315fn size_signed<S: SPrim>(value: S) -> usize {
317 size(value.as_zigzag())
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::Error, DecodeExt, Encode};
    #[cfg(not(feature = "std"))]
    use alloc::vec::Vec;
    use bytes::Bytes;

    /// Input that ends while the continuation flag is still set must report
    /// `EndOfBuffer`.
    #[test]
    fn test_end_of_buffer() {
        let mut buf: Bytes = Bytes::from_static(&[]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));
    }

    /// The last byte may only carry as many significant bits as the target type has left.
    #[test]
    fn test_overflow() {
        // Exactly 32 bits: accepted.
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x0F]);
        assert_eq!(read::<u32>(&mut buf).unwrap(), u32::MAX);

        // 33 bits: rejected.
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x1F]);
        assert!(matches!(
            read::<u32>(&mut buf),
            Err(Error::InvalidVarint(u32::SIZE))
        ));

        // Tenth byte of a u64 has one bit of room, but 0x02 needs two.
        let mut buf =
            Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02]);
        assert!(matches!(
            read::<u64>(&mut buf),
            Err(Error::InvalidVarint(u64::SIZE))
        ));
    }

    /// A continuation flag on the byte that must be the last one is rejected.
    #[test]
    fn test_overcontinuation() {
        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80]);
        let result = read::<u32>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u32::SIZE))));
    }

    /// A zero byte after the first byte is rejected.
    #[test]
    fn test_zeroed_byte() {
        let mut buf = Bytes::from_static(&[0xFF, 0x00]);
        let result = read::<u64>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u64::SIZE))));
    }

    /// Round-trips every case that fits in `T` through the free functions and through
    /// the `UInt` wrapper.
    fn varuint_round_trip<T: Copy + UPrim + TryFrom<u128>>() {
        const CASES: &[u128] = &[
            0,
            1,
            127,
            128,
            129,
            0xFF,
            0x100,
            0x3FFF,
            0x4000,
            0x1_FFFF,
            0xFF_FFFF,
            0x1_FF_FF_FF_FF,
            0xFF_FF_FF_FF_FF_FF,
            0x1_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF,
            u16::MAX as u128,
            u32::MAX as u128,
            u64::MAX as u128,
            u128::MAX,
        ];

        for &raw in CASES {
            // Skip cases that do not fit in `T`.
            let Ok(value) = T::try_from(raw) else { continue };

            // Encode and check the predicted size.
            let mut buf = Vec::new();
            write(value, &mut buf);
            assert_eq!(buf.len(), size(value));

            // Decode and check the whole buffer was consumed.
            let mut slice = &buf[..];
            let decoded: T = read(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            // Same round trip through the wrapper type.
            let encoded = UInt(value).encode();
            assert_eq!(UInt::<T>::decode(encoded).unwrap(), UInt(value));
        }
    }

    #[test]
    fn test_varuint() {
        varuint_round_trip::<u16>();
        varuint_round_trip::<u32>();
        varuint_round_trip::<u64>();
        varuint_round_trip::<u128>();
    }

    /// Round-trips every case that fits in `T` through the signed free functions and
    /// through the `SInt` wrapper.
    fn varsint_round_trip<T: Copy + SPrim + TryFrom<i128>>() {
        const CASES: &[i128] = &[
            0,
            1,
            -1,
            2,
            -2,
            127,
            -127,
            128,
            -128,
            129,
            -129,
            0x7FFFFFFF,
            -0x7FFFFFFF,
            0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            -0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            i16::MIN as i128,
            i16::MAX as i128,
            i32::MIN as i128,
            i32::MAX as i128,
            i64::MIN as i128,
            i64::MAX as i128,
            i128::MIN,
        ];

        for &raw in CASES {
            // Skip cases that do not fit in `T`.
            let Ok(value) = T::try_from(raw) else { continue };

            // Encode and check the predicted size.
            let mut buf = Vec::new();
            write_signed(value, &mut buf);
            assert_eq!(buf.len(), size_signed(value));

            // Decode and check the whole buffer was consumed.
            let mut slice = &buf[..];
            let decoded: T = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            // Same round trip through the wrapper type.
            let encoded = SInt(value).encode();
            assert_eq!(SInt::<T>::decode(encoded).unwrap(), SInt(value));
        }
    }

    #[test]
    fn test_varsint() {
        varsint_round_trip::<i16>();
        varsint_round_trip::<i32>();
        varsint_round_trip::<i64>();
        varsint_round_trip::<i128>();
    }

    /// `UInt<P>` converts back into `P`.
    #[test]
    fn test_varuint_into() {
        let v32: u32 = 0x1_FFFF;
        let out32: u32 = UInt(v32).into();
        assert_eq!(v32, out32);

        let v64: u64 = 0x1_FF_FF_FF_FF;
        let out64: u64 = UInt(v64).into();
        assert_eq!(v64, out64);
    }

    /// `SInt<P>` converts back into `P`.
    #[test]
    fn test_varsint_into() {
        let s32: i32 = -123_456;
        let out32: i32 = SInt(s32).into();
        assert_eq!(s32, out32);

        let s64: i64 = 987_654_321;
        let out64: i64 = SInt(s64).into();
        assert_eq!(s64, out64);
    }

    /// Pins the exact bytes produced for `usize` values at group boundaries.
    #[test]
    fn test_conformity() {
        assert_eq!(0usize.encode(), &[0x00][..]);
        assert_eq!(1usize.encode(), &[0x01][..]);
        assert_eq!(127usize.encode(), &[0x7F][..]);
        assert_eq!(128usize.encode(), &[0x80, 0x01][..]);
        assert_eq!(16383usize.encode(), &[0xFF, 0x7F][..]);
        assert_eq!(16384usize.encode(), &[0x80, 0x80, 0x01][..]);
        assert_eq!(2097151usize.encode(), &[0xFF, 0xFF, 0x7F][..]);
        assert_eq!(2097152usize.encode(), &[0x80, 0x80, 0x80, 0x01][..]);
        assert_eq!(
            (u32::MAX as usize).encode(),
            &[0xFF, 0xFF, 0xFF, 0xFF, 0x0F][..]
        );
    }

    /// Exhaustive check over `u16`: predicted size, encoded length and decode agree.
    #[test]
    fn test_all_u16_values() {
        for i in 0..=u16::MAX {
            let value = i;
            let calculated_size = size(value);

            let mut buf = Vec::new();
            write(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for u16 value {value}",
            );

            let uint = UInt(value);
            assert_eq!(
                uint.encode_size(),
                buf.len(),
                "UInt encode_size mismatch for value {value}",
            );

            // Decode as well, mirroring the exhaustive i16 test.
            let mut slice = &buf[..];
            let decoded: u16 = read(&mut slice).unwrap();
            assert_eq!(decoded, value, "Decode mismatch for value {value}");
            assert!(
                slice.is_empty(),
                "Buffer not fully consumed for value {value}",
            );
        }
    }

    /// Exhaustive check over `i16`: predicted size, encoded length and decode agree.
    #[test]
    fn test_all_i16_values() {
        for i in i16::MIN..=i16::MAX {
            let value = i;
            let calculated_size = size_signed(value);

            let mut buf = Vec::new();
            write_signed(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for i16 value {value}",
            );

            let sint = SInt(value);
            assert_eq!(
                sint.encode_size(),
                buf.len(),
                "SInt encode_size mismatch for value {value}",
            );

            let mut slice = &buf[..];
            let decoded: i16 = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value, "Decode mismatch for value {value}");
            assert!(
                slice.is_empty(),
                "Buffer not fully consumed for value {value}",
            );
        }
    }

    /// Values with exactly `bits` low bits set need `ceil(bits / 7)` bytes and decode
    /// back to themselves.
    #[test]
    fn test_exact_bit_boundaries() {
        fn test_exact_bits<T: UPrim + TryFrom<u128> + core::fmt::Display>() {
            for bits in 1..=128 {
                // `1 << 128` would overflow, so the full-width case is spelled out.
                let val = if bits == 128 {
                    u128::MAX
                } else {
                    (1u128 << bits) - 1
                };
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                let expected_size = (bits as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size calculation wrong for {val} with {bits} bits",
                );

                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for {val} with {bits} bits",
                );

                let mut slice = &buf[..];
                let decoded: T = read(&mut slice).unwrap();
                assert_eq!(
                    decoded, value,
                    "Decode mismatch for {val} with {bits} bits",
                );
                assert!(
                    slice.is_empty(),
                    "Buffer not fully consumed for {val} with {bits} bits",
                );
            }
        }

        test_exact_bits::<u16>();
        test_exact_bits::<u32>();
        test_exact_bits::<u64>();
        test_exact_bits::<u128>();
    }

    /// Values with a single bit set at `bit_pos` need `ceil((bit_pos + 1) / 7)` bytes
    /// and decode back to themselves.
    #[test]
    fn test_single_bit_boundaries() {
        fn test_single_bits<T: UPrim + TryFrom<u128> + core::fmt::Display>() {
            for bit_pos in 0..128 {
                let val = 1u128 << bit_pos;
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                let expected_size = ((bit_pos + 1) as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size wrong for 1<<{bit_pos} = {val}",
                );

                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for 1<<{bit_pos} = {val}",
                );

                let mut slice = &buf[..];
                let decoded: T = read(&mut slice).unwrap();
                assert_eq!(decoded, value, "Decode mismatch for 1<<{bit_pos} = {val}");
                assert!(
                    slice.is_empty(),
                    "Buffer not fully consumed for 1<<{bit_pos} = {val}",
                );
            }
        }

        test_single_bits::<u16>();
        test_single_bits::<u32>();
        test_single_bits::<u64>();
        test_single_bits::<u128>();
    }
}