1use crate::{EncodeSize, Error, FixedSize, Read, Write};
33use bytes::{Buf, BufMut};
34use sealed::{SPrim, UPrim};
35use std::fmt::Debug;
36
/// Width of one byte, in bits.
const BITS_PER_BYTE: usize = 8;

/// Payload bits carried by each encoded byte; the remaining top bit is the continuation flag.
const DATA_BITS_PER_BYTE: usize = 7;

/// Selects the seven payload bits of an encoded byte.
const DATA_BITS_MASK: u8 = 0b0111_1111;

/// Selects the continuation flag (most significant bit); set when more bytes follow.
const CONTINUATION_BIT_MASK: u8 = 0b1000_0000;
51
52#[doc(hidden)]
55mod sealed {
56 use super::*;
57 use std::ops::{BitOrAssign, Shl, ShrAssign};
58
59 pub trait UPrim:
61 Copy
62 + From<u8>
63 + Sized
64 + FixedSize
65 + ShrAssign<usize>
66 + Shl<usize, Output = Self>
67 + BitOrAssign<Self>
68 + PartialOrd
69 + Debug
70 {
71 fn leading_zeros(self) -> u32;
73
74 fn as_u8(self) -> u8;
76 }
77
78 macro_rules! impl_uint {
80 ($type:ty) => {
81 impl UPrim for $type {
82 #[inline(always)]
83 fn leading_zeros(self) -> u32 {
84 self.leading_zeros()
85 }
86
87 #[inline(always)]
88 fn as_u8(self) -> u8 {
89 self as u8
90 }
91 }
92 };
93 }
94 impl_uint!(u16);
95 impl_uint!(u32);
96 impl_uint!(u64);
97 impl_uint!(u128);
98
99 pub trait SPrim: Copy + Sized + FixedSize + PartialOrd + Debug {
106 type UnsignedEquivalent: UPrim;
109
110 #[doc(hidden)]
113 const _COMMIT_OP_ASSERT: () =
114 assert!(std::mem::size_of::<Self>() == std::mem::size_of::<Self::UnsignedEquivalent>());
115
116 fn as_zigzag(&self) -> Self::UnsignedEquivalent;
118
119 fn un_zigzag(value: Self::UnsignedEquivalent) -> Self;
121 }
122
123 macro_rules! impl_sint {
125 ($type:ty, $utype:ty) => {
126 impl SPrim for $type {
127 type UnsignedEquivalent = $utype;
128
129 #[inline]
130 fn as_zigzag(&self) -> $utype {
131 let shr = std::mem::size_of::<$utype>() * 8 - 1;
132 ((self << 1) ^ (self >> shr)) as $utype
133 }
134 #[inline]
135 fn un_zigzag(value: $utype) -> Self {
136 ((value >> 1) as $type) ^ (-((value & 1) as $type))
137 }
138 }
139 };
140 }
141 impl_sint!(i16, u16);
142 impl_sint!(i32, u32);
143 impl_sint!(i64, u64);
144 impl_sint!(i128, u128);
145}
146
147#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
152pub struct UInt<U: UPrim>(pub U);
153
154macro_rules! impl_varuint_into {
157 ($($type:ty),+) => {
158 $(
159 impl From<UInt<$type>> for $type {
160 fn from(val: UInt<$type>) -> Self {
161 val.0
162 }
163 }
164 )+
165 };
166}
167impl_varuint_into!(u16, u32, u64, u128);
168
169impl<U: UPrim> Write for UInt<U> {
170 fn write(&self, buf: &mut impl BufMut) {
171 write(self.0, buf);
172 }
173}
174
175impl<U: UPrim> Read for UInt<U> {
176 type Cfg = ();
177 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
178 read(buf).map(UInt)
179 }
180}
181
182impl<U: UPrim> EncodeSize for UInt<U> {
183 fn encode_size(&self) -> usize {
184 size(self.0)
185 }
186}
187
188#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
191pub struct SInt<S: SPrim>(pub S);
192
193macro_rules! impl_varsint_into {
196 ($($type:ty),+) => {
197 $(
198 impl From<SInt<$type>> for $type {
199 fn from(val: SInt<$type>) -> Self {
200 val.0
201 }
202 }
203 )+
204 };
205}
206impl_varsint_into!(i16, i32, i64, i128);
207
208impl<S: SPrim> Write for SInt<S> {
209 fn write(&self, buf: &mut impl BufMut) {
210 write_signed::<S>(self.0, buf);
211 }
212}
213
214impl<S: SPrim> Read for SInt<S> {
215 type Cfg = ();
216 fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
217 read_signed::<S>(buf).map(SInt)
218 }
219}
220
221impl<S: SPrim> EncodeSize for SInt<S> {
222 fn encode_size(&self) -> usize {
223 size_signed::<S>(self.0)
224 }
225}
226
227fn write<T: UPrim>(value: T, buf: &mut impl BufMut) {
231 let continuation_threshold = T::from(CONTINUATION_BIT_MASK);
232 if value < continuation_threshold {
233 buf.put_u8(value.as_u8());
236 return;
237 }
238
239 let mut val = value;
240 while val >= continuation_threshold {
241 buf.put_u8((val.as_u8()) | CONTINUATION_BIT_MASK);
242 val >>= 7;
243 }
244 buf.put_u8(val.as_u8());
245}
246
247fn read<T: UPrim>(buf: &mut impl Buf) -> Result<T, Error> {
253 let max_bits = T::SIZE * BITS_PER_BYTE;
254 let mut result: T = T::from(0);
255 let mut bits_read = 0;
256
257 loop {
259 if !buf.has_remaining() {
261 return Err(Error::EndOfBuffer);
262 }
263 let byte = buf.get_u8();
264
265 if byte == 0 && bits_read > 0 {
270 return Err(Error::InvalidVarint(T::SIZE));
271 }
272
273 let remaining_bits = max_bits.checked_sub(bits_read).unwrap();
281 if remaining_bits <= DATA_BITS_PER_BYTE {
282 let relevant_bits = BITS_PER_BYTE - byte.leading_zeros() as usize;
283 if relevant_bits > remaining_bits {
284 return Err(Error::InvalidVarint(T::SIZE));
285 }
286 }
287
288 result |= T::from(byte & DATA_BITS_MASK) << bits_read;
290
291 if byte & CONTINUATION_BIT_MASK == 0 {
293 return Ok(result);
294 }
295
296 bits_read += DATA_BITS_PER_BYTE;
297 }
298}
299
300fn size<T: UPrim>(value: T) -> usize {
302 let total_bits = std::mem::size_of::<T>() * 8;
303 let leading_zeros = value.leading_zeros() as usize;
304 let data_bits = total_bits - leading_zeros;
305 usize::max(1, data_bits.div_ceil(DATA_BITS_PER_BYTE))
306}
307
308fn write_signed<S: SPrim>(value: S, buf: &mut impl BufMut) {
310 write(value.as_zigzag(), buf);
311}
312
313fn read_signed<S: SPrim>(buf: &mut impl Buf) -> Result<S, Error> {
315 Ok(S::un_zigzag(read(buf)?))
316}
317
318fn size_signed<S: SPrim>(value: S) -> usize {
320 size(value.as_zigzag())
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{error::Error, DecodeExt, Encode};
    use bytes::Bytes;

    #[test]
    fn test_end_of_buffer() {
        let mut buf: Bytes = Bytes::from_static(&[]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0x8F]);
        assert!(matches!(read::<u32>(&mut buf), Err(Error::EndOfBuffer)));
    }

    #[test]
    fn test_overflow() {
        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x0F]);
        assert_eq!(read::<u32>(&mut buf).unwrap(), u32::MAX);

        let mut buf: Bytes = Bytes::from_static(&[0xFF, 0xFF, 0xFF, 0xFF, 0x1F]);
        assert!(matches!(
            read::<u32>(&mut buf),
            Err(Error::InvalidVarint(u32::SIZE))
        ));

        let mut buf =
            Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x02]);
        assert!(matches!(
            read::<u64>(&mut buf),
            Err(Error::InvalidVarint(u64::SIZE))
        ));
    }

    #[test]
    fn test_u16_overflow_boundary() {
        // Two full 7-bit groups leave only 2 bits of a u16 for the third byte.
        let mut buf = Bytes::from_static(&[0xFF, 0xFF, 0x03]);
        assert_eq!(read::<u16>(&mut buf).unwrap(), u16::MAX);

        let mut buf = Bytes::from_static(&[0xFF, 0xFF, 0x04]);
        assert!(matches!(
            read::<u16>(&mut buf),
            Err(Error::InvalidVarint(u16::SIZE))
        ));
    }

    #[test]
    fn test_overcontinuation() {
        let mut buf: Bytes = Bytes::from_static(&[0x80, 0x80, 0x80, 0x80, 0x80]);
        let result = read::<u32>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u32::SIZE))));
    }

    #[test]
    fn test_zeroed_byte() {
        let mut buf = Bytes::from_static(&[0xFF, 0x00]);
        let result = read::<u64>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u64::SIZE))));

        // A zero terminator after an empty continuation group is also non-minimal.
        let mut buf = Bytes::from_static(&[0x80, 0x00]);
        let result = read::<u32>(&mut buf);
        assert!(matches!(result, Err(Error::InvalidVarint(u32::SIZE))));
    }

    fn varuint_round_trip<T: Copy + UPrim + TryFrom<u128>>() {
        const CASES: &[u128] = &[
            0,
            1,
            127,
            128,
            129,
            0xFF,
            0x100,
            0x3FFF,
            0x4000,
            0x1_FFFF,
            0xFF_FFFF,
            0x1_FF_FF_FF_FF,
            0xFF_FF_FF_FF_FF_FF,
            0x1_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF_FF,
            u16::MAX as u128,
            u32::MAX as u128,
            u64::MAX as u128,
            u128::MAX,
        ];

        for &raw in CASES {
            // Skip cases that do not fit in `T`.
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            let mut buf = Vec::new();
            write(value, &mut buf);
            assert_eq!(buf.len(), size(value));

            let mut slice = &buf[..];
            let decoded: T = read(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            let encoded = UInt(value).encode();
            assert_eq!(UInt::<T>::decode(encoded).unwrap(), UInt(value));
        }
    }

    #[test]
    fn test_varuint() {
        varuint_round_trip::<u16>();
        varuint_round_trip::<u32>();
        varuint_round_trip::<u64>();
        varuint_round_trip::<u128>();
    }

    fn varsint_round_trip<T: Copy + SPrim + TryFrom<i128>>() {
        const CASES: &[i128] = &[
            0,
            1,
            -1,
            2,
            -2,
            127,
            -127,
            128,
            -128,
            129,
            -129,
            0x7FFFFFFF,
            -0x7FFFFFFF,
            0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            -0x7FFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF_FFFF,
            i16::MIN as i128,
            i16::MAX as i128,
            i32::MIN as i128,
            i32::MAX as i128,
            i64::MIN as i128,
            i64::MAX as i128,
        ];

        for &raw in CASES {
            // Skip cases that do not fit in `T`.
            let Ok(value) = raw.try_into() else { continue };
            let value: T = value;

            let mut buf = Vec::new();
            write_signed(value, &mut buf);
            assert_eq!(buf.len(), size_signed(value));

            let mut slice = &buf[..];
            let decoded: T = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value);
            assert!(slice.is_empty());

            let encoded = SInt(value).encode();
            assert_eq!(SInt::<T>::decode(encoded).unwrap(), SInt(value));
        }
    }

    #[test]
    fn test_varsint() {
        varsint_round_trip::<i16>();
        varsint_round_trip::<i32>();
        varsint_round_trip::<i64>();
        varsint_round_trip::<i128>();
    }

    #[test]
    fn test_zigzag_wire_format() {
        // Zigzag interleaves signs: 0, -1, 1, -2, … map to 0, 1, 2, 3, …
        let cases: &[(i32, &[u8])] = &[
            (0, &[0x00]),
            (-1, &[0x01]),
            (1, &[0x02]),
            (-2, &[0x03]),
            (63, &[0x7E]),
            (-64, &[0x7F]),
            (64, &[0x80, 0x01]),
        ];
        for &(value, expected) in cases {
            let mut buf = Vec::new();
            write_signed(value, &mut buf);
            assert_eq!(buf, expected, "Encoding mismatch for {value}");
        }

        // The minimum of a signed type maps to the maximum of its unsigned equivalent.
        let mut buf = Vec::new();
        write_signed(i16::MIN, &mut buf);
        assert_eq!(buf, [0xFF, 0xFF, 0x03]);
        let mut slice = &buf[..];
        assert_eq!(read_signed::<i16>(&mut slice).unwrap(), i16::MIN);
        assert!(slice.is_empty());
    }

    #[test]
    fn test_varuint_into() {
        let v32: u32 = 0x1_FFFF;
        let out32: u32 = UInt(v32).into();
        assert_eq!(v32, out32);

        let v64: u64 = 0x1_FF_FF_FF_FF;
        let out64: u64 = UInt(v64).into();
        assert_eq!(v64, out64);
    }

    #[test]
    fn test_varsint_into() {
        let s32: i32 = -123_456;
        let out32: i32 = SInt(s32).into();
        assert_eq!(s32, out32);

        let s64: i64 = 987_654_321;
        let out64: i64 = SInt(s64).into();
        assert_eq!(s64, out64);
    }

    #[test]
    fn test_conformity() {
        assert_eq!(0usize.encode(), &[0x00][..]);
        assert_eq!(1usize.encode(), &[0x01][..]);
        assert_eq!(127usize.encode(), &[0x7F][..]);
        assert_eq!(128usize.encode(), &[0x80, 0x01][..]);
        assert_eq!(16383usize.encode(), &[0xFF, 0x7F][..]);
        assert_eq!(16384usize.encode(), &[0x80, 0x80, 0x01][..]);
        assert_eq!(2097151usize.encode(), &[0xFF, 0xFF, 0x7F][..]);
        assert_eq!(2097152usize.encode(), &[0x80, 0x80, 0x80, 0x01][..]);
        assert_eq!(
            (u32::MAX as usize).encode(),
            &[0xFF, 0xFF, 0xFF, 0xFF, 0x0F][..]
        );
    }

    #[test]
    fn test_all_u16_values() {
        for value in 0..=u16::MAX {
            let calculated_size = size(value);

            let mut buf = Vec::new();
            write(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for u16 value {value}",
            );

            let uint = UInt(value);
            assert_eq!(
                uint.encode_size(),
                buf.len(),
                "UInt encode_size mismatch for value {value}",
            );

            // Decode as well, so every u16 is checked end-to-end like the i16 sweep.
            let mut slice = &buf[..];
            let decoded: u16 = read(&mut slice).unwrap();
            assert_eq!(decoded, value, "Decode mismatch for value {value}");
            assert!(
                slice.is_empty(),
                "Buffer not fully consumed for value {value}",
            );
        }
    }

    #[test]
    fn test_all_i16_values() {
        for value in i16::MIN..=i16::MAX {
            let calculated_size = size_signed(value);

            let mut buf = Vec::new();
            write_signed(value, &mut buf);

            assert_eq!(
                buf.len(),
                calculated_size,
                "Size mismatch for i16 value {value}",
            );

            let sint = SInt(value);
            assert_eq!(
                sint.encode_size(),
                buf.len(),
                "SInt encode_size mismatch for value {value}",
            );

            let mut slice = &buf[..];
            let decoded: i16 = read_signed(&mut slice).unwrap();
            assert_eq!(decoded, value, "Decode mismatch for value {value}");
            assert!(
                slice.is_empty(),
                "Buffer not fully consumed for value {value}",
            );
        }
    }

    #[test]
    fn test_exact_bit_boundaries() {
        fn test_exact_bits<T: UPrim + TryFrom<u128> + std::fmt::Display>() {
            for bits in 1..=128 {
                // The all-ones value with exactly `bits` significant bits.
                let val = if bits == 128 {
                    u128::MAX
                } else {
                    (1u128 << bits) - 1
                };
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                let expected_size = (bits as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size calculation wrong for {val} with {bits} bits",
                );

                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for {val} with {bits} bits",
                );
            }
        }

        test_exact_bits::<u16>();
        test_exact_bits::<u32>();
        test_exact_bits::<u64>();
        test_exact_bits::<u128>();
    }

    #[test]
    fn test_single_bit_boundaries() {
        fn test_single_bits<T: UPrim + TryFrom<u128> + std::fmt::Display>() {
            for bit_pos in 0..128 {
                let val = 1u128 << bit_pos;
                let Ok(value) = T::try_from(val) else {
                    continue;
                };

                let expected_size = ((bit_pos + 1) as usize).div_ceil(DATA_BITS_PER_BYTE);
                let calculated_size = size(value);
                assert_eq!(
                    calculated_size, expected_size,
                    "Size wrong for 1<<{bit_pos} = {val}",
                );

                let mut buf = Vec::new();
                write(value, &mut buf);
                assert_eq!(
                    buf.len(),
                    expected_size,
                    "Encoded size wrong for 1<<{bit_pos} = {val}",
                );
            }
        }

        test_single_bits::<u16>();
        test_single_bits::<u32>();
        test_single_bits::<u64>();
        test_single_bits::<u128>();
    }
}