1use super::*;
2
/// Appends the unsigned LEB128-style varint encoding of `n` to `v`.
///
/// Each output byte carries seven bits of `n`, least-significant group first.
/// The high bit (`0x80`) is set on every byte except the last, which marks the
/// end of the varint. Zero encodes as the single byte `0x00`; `u128::MAX`
/// needs 19 bytes.
pub fn encode_to_vec(mut n: u128, v: &mut Vec<u8>) {
    loop {
        // Masking to seven bits first makes the narrowing cast lossless.
        let group = (n & 0x7f) as u8;
        n >>= 7;

        if n == 0 {
            // No higher bits remain: this is the terminating byte.
            v.push(group);
            return;
        }

        v.push(group | 0x80);
    }
}
10
11pub fn decode(buffer: &[u8]) -> Result<(u128, usize), Error> {
12 let mut n = 0u128;
13
14 for (i, &byte) in buffer.iter().enumerate() {
15 if i > 18 {
16 return Err(Error::Overlong);
17 }
18
19 let value = u128::from(byte) & 0b0111_1111;
20
21 if i == 18 && value & 0b0111_1100 != 0 {
22 return Err(Error::Overflow);
23 }
24
25 n |= value << (7 * i);
26
27 if byte & 0b1000_0000 == 0 {
28 return Ok((n, i + 1));
29 }
30 }
31
32 Err(Error::Unterminated)
33}
34
35pub fn encode(n: u128) -> Vec<u8> {
36 let mut v = Vec::new();
37 encode_to_vec(n, &mut v);
38 v
39}
40
/// Errors returned by [`decode`] when the input is not a valid `u128` varint.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum Error {
    /// The first 19 bytes all had the continuation bit (`0x80`) set and more
    /// input followed. No valid `u128` varint is longer than 19 bytes.
    Overlong,
    /// The 19th byte carried payload bits above bit 127, so the value does not
    /// fit in a `u128`.
    Overflow,
    /// The input ended before a terminating byte (continuation bit clear) was
    /// found.
    Unterminated,
}

impl Display for Error {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        match self {
            Self::Overlong => write!(f, "too long"),
            Self::Overflow => write!(f, "overflow"),
            Self::Unterminated => write!(f, "unterminated"),
        }
    }
}

impl std::error::Error for Error {}
59
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 19-byte input whose first 18 bytes are `0x80` (continuation bit
    /// set, zero payload) and whose final byte is `last`.
    fn eighteen_continuations_then(last: u8) -> Vec<u8> {
        let mut bytes = vec![128; 18];
        bytes.push(last);
        bytes
    }

    /// Asserts that `n` survives an encode/decode round trip and that decoding
    /// consumes exactly the encoded bytes.
    #[track_caller]
    fn assert_round_trips(n: u128) {
        let encoded = encode(n);
        assert_eq!(decode(&encoded), Ok((n, encoded.len())));
    }

    #[test]
    fn zero_round_trips_successfully() {
        assert_round_trips(0);
    }

    #[test]
    fn u128_max_round_trips_successfully() {
        assert_round_trips(u128::MAX);
    }

    #[test]
    fn powers_of_two_round_trip_successfully() {
        for i in 0..128 {
            assert_round_trips(1 << i);
        }
    }

    #[test]
    fn alternating_bit_strings_round_trip_successfully() {
        let mut n = 0u128;

        for i in 0..129 {
            n = n << 1 | (i % 2);
            assert_round_trips(n);
        }
    }

    #[test]
    fn known_values_have_expected_encodings() {
        assert_eq!(encode(0), [0]);
        assert_eq!(encode(127), [127]);
        assert_eq!(encode(128), [0x80, 0x01]);
        assert_eq!(encode(300), [0xAC, 0x02]);

        let max = encode(u128::MAX);
        assert_eq!(max.len(), 19);
        assert_eq!(max.last(), Some(&3));
    }

    #[test]
    fn varints_may_not_be_longer_than_19_bytes() {
        assert_eq!(decode(&eighteen_continuations_then(0)), Ok((0, 19)));

        let mut invalid = eighteen_continuations_then(128);
        invalid.push(0);
        assert_eq!(decode(&invalid), Err(Error::Overlong));
    }

    #[test]
    fn varints_may_not_overflow_u128() {
        for &last in &[64, 32, 16, 8, 4] {
            assert_eq!(
                decode(&eighteen_continuations_then(last)),
                Err(Error::Overflow)
            );
        }

        // Overflow is detected even when the 19th byte claims to continue.
        assert_eq!(
            decode(&eighteen_continuations_then(0x84)),
            Err(Error::Overflow)
        );

        assert_eq!(
            decode(&eighteen_continuations_then(2)),
            Ok((2u128.pow(127), 19))
        );
        // 3 is the largest payload the final byte may carry.
        assert_eq!(decode(&eighteen_continuations_then(3)), Ok((3 << 126, 19)));
    }

    #[test]
    fn varints_must_be_terminated() {
        assert_eq!(decode(&[128]), Err(Error::Unterminated));
        assert_eq!(decode(&[]), Err(Error::Unterminated));
    }

    #[test]
    fn trailing_bytes_are_not_consumed() {
        assert_eq!(decode(&[1, 2, 3]), Ok((1, 1)));
    }
}