1use std::convert::{TryFrom, TryInto};
6use std::fmt;
7
8use thiserror::Error;
9
10use super::{Decode, DecodeError, Encode, EncodeError};
11
12#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
13#[error("value out of range")]
14pub struct BoundsExceeded;
15
/// An unsigned integer stored in a `u64` that is intended to stay within
/// 62 bits so it can be written with the variable-length wire encoding.
///
/// The default value is zero. Ordering, equality and hashing are those of
/// the wrapped `u64`.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VarInt(u64);
23
24impl VarInt {
25 pub const MAX: Self = Self((1 << 62) - 1);
27
28 pub const ZERO: Self = Self(0);
30
31 pub const fn from_u32(x: u32) -> Self {
34 Self(x as u64)
35 }
36
37 pub const fn into_inner(self) -> u64 {
39 self.0
40 }
41}
42
43impl From<VarInt> for u64 {
44 fn from(x: VarInt) -> Self {
45 x.0
46 }
47}
48
49impl From<VarInt> for usize {
50 fn from(x: VarInt) -> Self {
51 x.0 as usize
52 }
53}
54
55impl From<VarInt> for u128 {
56 fn from(x: VarInt) -> Self {
57 x.0 as u128
58 }
59}
60
61impl From<u8> for VarInt {
62 fn from(x: u8) -> Self {
63 Self(x.into())
64 }
65}
66
67impl From<u16> for VarInt {
68 fn from(x: u16) -> Self {
69 Self(x.into())
70 }
71}
72
73impl From<u32> for VarInt {
74 fn from(x: u32) -> Self {
75 Self(x.into())
76 }
77}
78
79impl TryFrom<u64> for VarInt {
80 type Error = BoundsExceeded;
81
82 fn try_from(x: u64) -> Result<Self, BoundsExceeded> {
84 let x = Self(x);
85 if x <= Self::MAX {
86 Ok(x)
87 } else {
88 Err(BoundsExceeded)
89 }
90 }
91}
92
93impl TryFrom<u128> for VarInt {
94 type Error = BoundsExceeded;
95
96 fn try_from(x: u128) -> Result<Self, BoundsExceeded> {
98 if x <= Self::MAX.into() {
99 Ok(Self(x as u64))
100 } else {
101 Err(BoundsExceeded)
102 }
103 }
104}
105
106impl TryFrom<usize> for VarInt {
107 type Error = BoundsExceeded;
108
109 fn try_from(x: usize) -> Result<Self, BoundsExceeded> {
111 Self::try_from(x as u64)
112 }
113}
114
115impl TryFrom<VarInt> for u32 {
116 type Error = BoundsExceeded;
117
118 fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
120 if x.0 <= u32::MAX.into() {
121 Ok(x.0 as u32)
122 } else {
123 Err(BoundsExceeded)
124 }
125 }
126}
127
128impl TryFrom<VarInt> for u16 {
129 type Error = BoundsExceeded;
130
131 fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
133 if x.0 <= u16::MAX.into() {
134 Ok(x.0 as u16)
135 } else {
136 Err(BoundsExceeded)
137 }
138 }
139}
140
141impl TryFrom<VarInt> for u8 {
142 type Error = BoundsExceeded;
143
144 fn try_from(x: VarInt) -> Result<Self, BoundsExceeded> {
146 if x.0 <= u8::MAX.into() {
147 Ok(x.0 as u8)
148 } else {
149 Err(BoundsExceeded)
150 }
151 }
152}
153
154impl fmt::Debug for VarInt {
155 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
156 self.0.fmt(f)
157 }
158}
159
160impl fmt::Display for VarInt {
161 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
162 self.0.fmt(f)
163 }
164}
165
166impl Decode for VarInt {
167 fn decode<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
169 Self::decode_remaining(r, 1)?;
170
171 let b = r.get_u8();
172 let tag = b >> 6;
173
174 let mut buf = [0u8; 8];
175 buf[0] = b & 0b0011_1111;
176
177 let x = match tag {
178 0b00 => u64::from(buf[0]),
179 0b01 => {
180 Self::decode_remaining(r, 1)?;
181 r.copy_to_slice(buf[1..2].as_mut());
182 u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
183 }
184 0b10 => {
185 Self::decode_remaining(r, 3)?;
186 r.copy_to_slice(buf[1..4].as_mut());
187 u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
188 }
189 0b11 => {
190 Self::decode_remaining(r, 7)?;
191 r.copy_to_slice(buf[1..8].as_mut());
192 u64::from_be_bytes(buf)
193 }
194 _ => unreachable!(),
195 };
196
197 Ok(Self(x))
198 }
199}
200
201impl Encode for VarInt {
202 fn encode<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
204 let x = self.0;
205 if x < 2u64.pow(6) {
206 Self::encode_remaining(w, 1)?;
207 w.put_u8(x as u8)
208 } else if x < 2u64.pow(14) {
209 Self::encode_remaining(w, 2)?;
210 w.put_u16((0b01 << 14) | x as u16)
211 } else if x < 2u64.pow(30) {
212 Self::encode_remaining(w, 4)?;
213 w.put_u32((0b10 << 30) | x as u32)
214 } else if x < 2u64.pow(62) {
215 Self::encode_remaining(w, 8)?;
216 w.put_u64((0b11 << 62) | x)
217 } else {
218 return Err(BoundsExceeded.into());
219 }
220
221 Ok(())
222 }
223}
224
225impl Encode for u64 {
228 fn encode<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
230 VarInt::try_from(*self)?.encode(w)
231 }
232}
233
234impl Decode for u64 {
235 fn decode<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
236 VarInt::decode(r).map(|v| v.into_inner())
237 }
238}
239
240impl Encode for usize {
244 fn encode<W: bytes::BufMut>(&self, w: &mut W) -> Result<(), EncodeError> {
246 let var = VarInt::try_from(*self)?;
247 var.encode(w)
248 }
249}
250
251impl Decode for usize {
252 fn decode<R: bytes::Buf>(r: &mut R) -> Result<Self, DecodeError> {
253 let var = VarInt::decode(r)?;
254 #[allow(clippy::unnecessary_fallible_conversions)]
256 usize::try_from(var).map_err(|_| DecodeError::BoundsExceeded(BoundsExceeded))
257 }
258}
259
#[cfg(test)]
mod tests {
    use super::*;
    use bytes::BytesMut;

    /// `2^62`: the smallest value that is too large for a varint.
    const OUT_OF_RANGE: u64 = 4611686018427387904;

    /// A `usize` survives an encode/decode round trip and uses the two-byte form.
    #[test]
    fn encode_decode_usize() {
        let value: usize = 123;
        let mut buf = BytesMut::new();

        value.encode(&mut buf).unwrap();
        assert_eq!(buf.to_vec(), vec![0x40, 0x7b]);

        let round_trip = usize::decode(&mut buf).unwrap();
        assert_eq!(round_trip, value);
    }

    /// Encoding a `usize` of `2^62` is rejected (only checked where `usize`
    /// is wide enough to hold that value).
    #[test]
    fn encode_usize_overflow() {
        if OUT_OF_RANGE < usize::MAX as u64 {
            let value = OUT_OF_RANGE as usize;
            let mut buf = BytesMut::new();

            let result = value.encode(&mut buf);
            assert!(matches!(
                result.unwrap_err(),
                EncodeError::BoundsExceeded(_)
            ));
        }
    }

    /// A `u64` survives an encode/decode round trip and uses the two-byte form.
    #[test]
    fn encode_decode_u64() {
        let value: u64 = 123;
        let mut buf = BytesMut::new();

        value.encode(&mut buf).unwrap();
        assert_eq!(buf.to_vec(), vec![0x40, 0x7b]);

        let round_trip = u64::decode(&mut buf).unwrap();
        assert_eq!(round_trip, value);
    }

    /// Encoding a `u64` of `2^62` is rejected.
    #[test]
    fn encode_u64_overflow() {
        let mut buf = BytesMut::new();

        let result = OUT_OF_RANGE.encode(&mut buf);
        assert!(matches!(
            result.unwrap_err(),
            EncodeError::BoundsExceeded(_)
        ));
    }

    /// Round-trips the smallest and largest value of every encoded length
    /// and checks the exact bytes produced.
    #[test]
    fn encode_decode_varint() {
        let cases: Vec<(u64, Vec<u8>)> = vec![
            // One-byte form.
            (0, vec![0b0000_0000]),
            (63, vec![0b0011_1111]),
            // Two-byte form.
            (64, vec![0b0100_0000, 0b0100_0000]),
            (16383, vec![0b0111_1111, 0xff]),
            // Four-byte form.
            (16384, vec![0b1000_0000, 0x00, 0x40, 0x00]),
            (1073741823, vec![0b1011_1111, 0xff, 0xff, 0xff]),
            // Eight-byte form.
            (
                1073741824,
                vec![0b1100_0000, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00, 0x00],
            ),
            (
                4611686018427387903,
                vec![0b1111_1111, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff],
            ),
        ];

        // One buffer is shared by all cases; each decode consumes what the
        // preceding encode wrote.
        let mut buf = BytesMut::new();

        for (value, expected) in cases {
            let vi = VarInt(value);
            vi.encode(&mut buf).unwrap();
            assert_eq!(buf.to_vec(), expected);

            let decoded = VarInt::decode(&mut buf).unwrap();
            assert_eq!(decoded, vi);
            assert_eq!(u64::from(decoded), value);
        }
    }

    /// A `VarInt` built directly around `2^62` cannot be encoded.
    #[test]
    fn overflow() {
        let mut buf = BytesMut::new();

        let vi = VarInt(OUT_OF_RANGE);
        let result = vi.encode(&mut buf);
        assert!(matches!(
            result.unwrap_err(),
            EncodeError::BoundsExceeded(_)
        ));
    }
}