1use crate::{EncodeSize, Error, FixedSize, Read, ReadExt, Write};
33use bytes::{Buf, BufMut};
34use sealed::{SPrim, UPrim};
35use std::fmt::Debug;
36
/// Number of bits in one byte.
const BITS_PER_BYTE: usize = u8::BITS as usize;

/// Number of payload bits carried by each varint byte; the remaining high bit is the
/// continuation flag.
const DATA_BITS_PER_BYTE: usize = BITS_PER_BYTE - 1;

/// Selects the payload bits of a varint byte.
const DATA_BITS_MASK: u8 = 0b0111_1111;

/// High bit of a varint byte; set when at least one more byte follows.
const CONTINUATION_BIT_MASK: u8 = 0b1000_0000;
51
52#[doc(hidden)]
55mod sealed {
56 use super::*;
57 use std::ops::{BitOrAssign, Shl, ShrAssign};
58
59 pub trait UPrim:
61 Copy
62 + From<u8>
63 + Sized
64 + FixedSize
65 + ShrAssign<usize>
66 + Shl<usize, Output = Self>
67 + BitOrAssign<Self>
68 + PartialOrd
69 + Debug
70 {
71 fn leading_zeros(self) -> u32;
73
74 fn as_u8(self) -> u8;
76 }
77
78 macro_rules! impl_uint {
80 ($type:ty) => {
81 impl UPrim for $type {
82 #[inline(always)]
83 fn leading_zeros(self) -> u32 {
84 self.leading_zeros()
85 }
86
87 #[inline(always)]
88 fn as_u8(self) -> u8 {
89 self as u8
90 }
91 }
92 };
93 }
94 impl_uint!(u16);
95 impl_uint!(u32);
96 impl_uint!(u64);
97 impl_uint!(u128);
98
99 pub trait SPrim: Copy + Sized + FixedSize + PartialOrd + Debug {
106 type UnsignedEquivalent: UPrim;
109
110 #[doc(hidden)]
113 const _COMMIT_OP_ASSERT: () =
114 assert!(std::mem::size_of::<Self>() == std::mem::size_of::<Self::UnsignedEquivalent>());
115
116 fn as_zigzag(&self) -> Self::UnsignedEquivalent;
118
119 fn un_zigzag(value: Self::UnsignedEquivalent) -> Self;
121 }
122
123 macro_rules! impl_sint {
125 ($type:ty, $utype:ty) => {
126 impl SPrim for $type {
127 type UnsignedEquivalent = $utype;
128
129 #[inline]
130 fn as_zigzag(&self) -> $utype {
131 let shr = std::mem::size_of::<$utype>() * 8 - 1;
132 ((self << 1) ^ (self >> shr)) as $utype
133 }
134 #[inline]
135 fn un_zigzag(value: $utype) -> Self {
136 ((value >> 1) as $type) ^ (-((value & 1) as $type))
137 }
138 }
139 };
140 }
141 impl_sint!(i16, u16);
142 impl_sint!(i32, u32);
143 impl_sint!(i64, u64);
144 impl_sint!(i128, u128);
145}
146
147#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
152pub struct UInt<U: UPrim>(pub U);
153
154macro_rules! impl_varuint_into {
157 ($($type:ty),+) => {
158 $(
159 impl From<UInt<$type>> for $type {
160 fn from(val: UInt<$type>) -> Self {
161 val.0
162 }
163 }
164 )+
165 };
166}
167impl_varuint_into!(u16, u32, u64, u128);
168
169impl<U: UPrim> Write for UInt<U> {
170 fn write(&self, buf: &mut impl BufMut) {
171 write(self.0, buf);
172 }
173}
174
175impl<U: UPrim> Read for UInt<U> {
176 type Cfg = ();
177 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
178 read(buf).map(UInt)
179 }
180}
181
182impl<U: UPrim> EncodeSize for UInt<U> {
183 fn encode_size(&self) -> usize {
184 size(self.0)
185 }
186}
187
188#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
191pub struct SInt<S: SPrim>(pub S);
192
193macro_rules! impl_varsint_into {
196 ($($type:ty),+) => {
197 $(
198 impl From<SInt<$type>> for $type {
199 fn from(val: SInt<$type>) -> Self {
200 val.0
201 }
202 }
203 )+
204 };
205}
206impl_varsint_into!(i16, i32, i64, i128);
207
208impl<S: SPrim> Write for SInt<S> {
209 fn write(&self, buf: &mut impl BufMut) {
210 write_signed::<S>(self.0, buf);
211 }
212}
213
214impl<S: SPrim> Read for SInt<S> {
215 type Cfg = ();
216 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
217 read_signed::<S>(buf).map(SInt)
218 }
219}
220
221impl<S: SPrim> EncodeSize for SInt<S> {
222 fn encode_size(&self) -> usize {
223 size_signed::<S>(self.0)
224 }
225}
226
227fn write<T: UPrim>(value: T, buf: &mut impl BufMut) {
231 let continuation_threshold = T::from(CONTINUATION_BIT_MASK);
232 if value < continuation_threshold {
233 buf.put_u8(value.as_u8());
236 return;
237 }
238
239 let mut val = value;
240 while val >= continuation_threshold {
241 buf.put_u8((val.as_u8()) | CONTINUATION_BIT_MASK);
242 val >>= 7;
243 }
244 buf.put_u8(val.as_u8());
245}
246
247fn read<T: UPrim>(buf: &mut impl Buf) -> Result<T, Error> {
253 let max_bits = T::SIZE * BITS_PER_BYTE;
254 let mut result: T = T::from(0);
255 let mut bits_read = 0;
256
257 loop {
259 let byte = u8::read(buf)?;
261
262 if byte == 0 && bits_read > 0 {
267 return Err(Error::InvalidVarint(T::SIZE));
268 }
269
270 let remaining_bits = max_bits.checked_sub(bits_read).unwrap();
278 if remaining_bits <= DATA_BITS_PER_BYTE {
279 let relevant_bits = BITS_PER_BYTE - byte.leading_zeros() as usize;
280 if relevant_bits > remaining_bits {
281 return Err(Error::InvalidVarint(T::SIZE));
282 }
283 }
284
285 result |= T::from(byte & DATA_BITS_MASK) << bits_read;
287
288 if byte & CONTINUATION_BIT_MASK == 0 {
290 return Ok(result);
291 }
292
293 bits_read += DATA_BITS_PER_BYTE;
294 }
295}
296
297fn size<T: UPrim>(value: T) -> usize {
299 let total_bits = std::mem::size_of::<T>() * 8;
300 let leading_zeros = value.leading_zeros() as usize;
301 let data_bits = total_bits - leading_zeros;
302 usize::max(1, data_bits.div_ceil(DATA_BITS_PER_BYTE))
303}
304
305fn write_signed<S: SPrim>(value: S, buf: &mut impl BufMut) {
307 write(value.as_zigzag(), buf);
308}
309
310fn read_signed<S: SPrim>(buf: &mut impl Buf) -> Result<S, Error> {
312 Ok(S::un_zigzag(read(buf)?))
313}
314
315fn size_signed<S: SPrim>(value: S) -> usize {
317 size(value.as_zigzag())
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::Error, DecodeExt, Encode};
    use bytes::Bytes;

    /// Truncated input must surface `Error::EndOfBuffer`, not a partial value.
    #[test]
    fn test_end_of_buffer() {
        // No bytes at all.
        let mut buf: Bytes = Bytes::from_static(&[]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        // Both bytes have the continuation bit set, so a third byte is expected but missing.
        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        // Same as above with non-zero payload bits.
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));
    }

    /// The byte that reaches the top of the target type may only use the bits that remain.
    #[test]
    fn test_overflow() {
        // After four bytes (28 bits) a u32 has 4 bits left; 0x0F fills them exactly.
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x0F]);
        assert_eq!(read::<u32>(&mut buf).unwrap(), u32::MAX);

        // 0x1F needs 5 bits but only 4 remain.
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x1F]);
        assert!(matches!(
            read::<u32>(&mut buf),
            Err(Error::InvalidVarint(u32::SIZE))
        ));

        // After nine bytes (63 bits) a u64 has 1 bit left; 0x02 needs 2.
        let mut buf =
            Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02]);
        assert!(matches!(
            read::<u64>(&mut buf),
            Err(Error::InvalidVarint(u64::SIZE))
        ));
    }

    /// A continuation bit on the last byte a u32 can hold is rejected rather than read past.
    #[test]
    fn test_overcontinuation() {
        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80]);
        let result = read::<u32>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u32::SIZE))));
    }

    /// A trailing `0x00` byte adds nothing to the value (a non-canonical encoding) and is
    /// rejected.
    #[test]
    fn test_zeroed_byte() {
        let mut buf = Bytes::from_static(&[0xFF, 0x00]);
        let result = read::<u64>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u64::SIZE))));
    }

    /// Round-trips a set of boundary values through `write`/`read` and through
    /// `UInt`'s `Encode`/`Decode`, checking that `size` matches the written length.
    /// Values that do not fit in `T` are skipped.
    fn varuint_round_trip<T: Copy + UPrim + TryFrom<u128>>() {
        const CASES: &[u128] = &[
            0,
            1,
            127,
            128,
            129,
            0xFF,
            0x100,
            0x3FFF,
            0x4000,
            0x1_FFFF,
            0xFF_FFFF,
            0x1_FF_FF_FF_FF,
            0xFF_FF_FF_FF_FF_FF,
            0x1_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF,
            u16::MAX as u128,
            u32::MAX as u128,
            u64::MAX as u128,
            u128::MAX,
        ];

        for &raw in CASES {
            // Skip cases that are out of range for `T`.
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            let mut buf = Vec::new();
            write(value, &mut buf);
            assert_eq!(buf.len(), size(value));

            // Decoding must consume exactly the bytes that were written.
            let mut slice = &buf[..];
            let decoded: T = read(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            let encoded = UInt(value).encode();
            assert_eq!(UInt::<T>::decode(encoded).unwrap(), UInt(value));
        }
    }

    #[test]
    fn test_varuint() {
        varuint_round_trip::<u16>();
        varuint_round_trip::<u32>();
        varuint_round_trip::<u64>();
        varuint_round_trip::<u128>();
    }

    /// Signed counterpart of `varuint_round_trip`, going through the ZigZag helpers and
    /// `SInt`. Values that do not fit in `T` are skipped.
    fn varsint_round_trip<T: Copy + SPrim + TryFrom<i128>>() {
        const CASES: &[i128] = &[
            0,
            1,
            -1,
            2,
            -2,
            127,
            -127,
            128,
            -128,
            129,
            -129,
            0x7FFFFFFF,
            -0x7FFFFFFF,
            0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            -0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            i16::MIN as i128,
            i16::MAX as i128,
            i32::MIN as i128,
            i32::MAX as i128,
            i64::MIN as i128,
            i64::MAX as i128,
        ];

        for &raw in CASES {
            // Skip cases that are out of range for `T`.
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            let mut buf = Vec::new();
            write_signed(value, &mut buf);
            assert_eq!(buf.len(), size_signed(value));

            // Decoding must consume exactly the bytes that were written.
            let mut slice = &buf[..];
            let decoded: T = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            let encoded = SInt(value).encode();
            assert_eq!(SInt::<T>::decode(encoded).unwrap(), SInt(value));
        }
    }

    #[test]
    fn test_varsint() {
        varsint_round_trip::<i16>();
        varsint_round_trip::<i32>();
        varsint_round_trip::<i64>();
        varsint_round_trip::<i128>();
    }

    /// `From<UInt<T>> for T` returns the wrapped value unchanged.
    #[test]
    fn test_varuint_into() {
        let v32: u32 = 0x1_FFFF;
        let out32: u32 = UInt(v32).into();
        assert_eq!(v32, out32);

        let v64: u64 = 0x1_FF_FF_FF_FF;
        let out64: u64 = UInt(v64).into();
        assert_eq!(v64, out64);
    }

    /// `From<SInt<T>> for T` returns the wrapped value unchanged.
    #[test]
    fn test_varsint_into() {
        let s32: i32 = -123_456;
        let out32: i32 = SInt(s32).into();
        assert_eq!(s32, out32);

        let s64: i64 = 987_654_321;
        let out64: i64 = SInt(s64).into();
        assert_eq!(s64, out64);
    }

    /// Pins the exact wire bytes of `usize` encoding (implemented elsewhere in the crate):
    /// little-endian 7-bit groups with the high bit marking continuation.
    #[test]
    fn test_conformity() {
        assert_eq!(0usize.encode(), &[0x00][..]);
        assert_eq!(1usize.encode(), &[0x01][..]);
        assert_eq!(127usize.encode(), &[0x7F][..]);
        assert_eq!(128usize.encode(), &[0x80, 0x01][..]);
        assert_eq!(16383usize.encode(), &[0xFF, 0x7F][..]);
        assert_eq!(16384usize.encode(), &[0x80, 0x80, 0x01][..]);
        assert_eq!(2097151usize.encode(), &[0xFF, 0xFF, 0x7F][..]);
        assert_eq!(2097152usize.encode(), &[0x80, 0x80, 0x80, 0x01][..]);
        assert_eq!(
            (u32::MAX as usize).encode(),
            &[0xFF, 0xFF, 0xFF, 0xFF, 0x0F][..]
        );
    }

    /// Exhaustively checks that `size`, the actual encoded length, and `UInt::encode_size`
    /// agree for every `u16`.
    #[test]
    fn test_all_u16_values() {
        for i in 0..=u16::MAX {
            let value = i;
            let calculated_size = size(value);

            let mut buf = Vec::new();
            write(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for u16 value {value}",
            );

            let uint = UInt(value);
            assert_eq!(
                uint.encode_size(),
                buf.len(),
                "UInt encode_size mismatch for value {value}",
            );
        }
    }

    /// Exhaustively checks size agreement and round-trip decoding for every `i16`.
    #[test]
    fn test_all_i16_values() {
        for i in i16::MIN..=i16::MAX {
            let value = i;
            let calculated_size = size_signed(value);

            let mut buf = Vec::new();
            write_signed(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for i16 value {value}",
            );

            let sint = SInt(value);
            assert_eq!(
                sint.encode_size(),
                buf.len(),
                "SInt encode_size mismatch for value {value}",
            );

            let mut slice = &buf[..];
            let decoded: i16 = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value, "Decode mismatch for value {value}");
            assert!(
                slice.is_empty(),
                "Buffer not fully consumed for value {value}",
            );
        }
    }

    /// `(1 << bits) - 1` has exactly `bits` significant bits, so it must encode in
    /// `ceil(bits / 7)` bytes.
    #[test]
    fn test_exact_bit_boundaries() {
        fn test_exact_bits<T: UPrim + TryFrom<u128> + std::fmt::Display>() {
            for bits in 1..=128 {
                // `1u128 << 128` would overflow, so the all-ones value is spelled directly.
                let val = if bits == 128 {
                    u128::MAX
                } else {
                    (1u128 << bits) - 1
                };
                // Skip widths larger than `T`.
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                let expected_size = (bits as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size calculation wrong for {val} with {bits} bits",
                );

                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for {val} with {bits} bits",
                );
            }
        }

        test_exact_bits::<u16>();
        test_exact_bits::<u32>();
        test_exact_bits::<u64>();
        test_exact_bits::<u128>();
    }

    /// `1 << bit_pos` has `bit_pos + 1` significant bits, so it must encode in
    /// `ceil((bit_pos + 1) / 7)` bytes.
    #[test]
    fn test_single_bit_boundaries() {
        fn test_single_bits<T: UPrim + TryFrom<u128> + std::fmt::Display>() {
            for bit_pos in 0..128 {
                let val = 1u128 << bit_pos;
                // Skip bit positions beyond the width of `T`.
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                let expected_size = ((bit_pos + 1) as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size wrong for 1<<{bit_pos} = {val}",
                );

                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for 1<<{bit_pos} = {val}",
                );
            }
        }

        test_single_bits::<u16>();
        test_single_bits::<u32>();
        test_single_bits::<u64>();
        test_single_bits::<u128>();
    }
}