1use std::fmt;
2
3use bytes::{Buf, BufMut};
4
5use crate::proto::coding::{self, BufExt, BufMutExt};
6
/// Errors produced while decoding a prefixed integer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The encoded integer is too large to be represented as a `usize`.
    Overflow,
    /// The input ended before the integer was complete
    /// (converted from `coding::UnexpectedEnd`).
    UnexpectedEnd,
}
12
13impl std::fmt::Display for Error {
14 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> fmt::Result {
15 match self {
16 Error::Overflow => write!(f, "value overflow"),
17 Error::UnexpectedEnd => write!(f, "unexpected end"),
18 }
19 }
20}
21
22pub fn decode<B: Buf>(size: u8, buf: &mut B) -> Result<(u8, usize), Error> {
23 assert!(size <= 8);
24 let mut first = buf.get::<u8>()?;
25
26 let flags = ((first as usize) >> size) as u8;
29 let mask = 0xFF >> (8 - size);
30 first &= mask;
31
32 if first < mask {
34 return Ok((flags, first as usize));
35 }
36
37 let mut value = mask as usize;
38 let mut power = 0usize;
39 loop {
40 let byte = buf.get::<u8>()? as usize;
41 value += (byte & 127) << power;
42 power += 7;
43
44 if byte & 128 == 0 {
45 break;
46 }
47
48 if power >= MAX_POWER {
49 return Err(Error::Overflow);
50 }
51 }
52
53 Ok((flags, value))
54}
55
56pub fn encode<B: BufMut>(size: u8, flags: u8, value: usize, buf: &mut B) {
57 assert!(size <= 8);
58 let mask = !(0xFF << size) as u8;
61 let flags = ((flags as usize) << size) as u8;
62
63 if value < (mask as usize) {
65 buf.write(flags | value as u8);
66 return;
67 }
68
69 buf.write(mask | flags);
70 let mut remaining = value - mask as usize;
71
72 while remaining >= 128 {
73 let rest = (remaining % 128) as u8;
74 buf.write(rest + 128);
75 remaining /= 128;
76 }
77 buf.write(remaining as u8);
78}
79
/// Shift limit for continuation groups in `decode`: once `power` reaches this
/// value and another group is signalled, decoding fails with `Error::Overflow`.
///
/// It is the bit width of `usize` rounded up to a multiple of 7 (70 on 64-bit,
/// 35 on 32-bit, 21 on 16-bit). Deriving it from the pointer size keeps it
/// defined on every target instead of only the widths listed in `cfg`s.
const MAX_POWER: usize = (std::mem::size_of::<usize>() * 8 + 6) / 7 * 7;
85
86impl From<coding::UnexpectedEnd> for Error {
87 fn from(_: coding::UnexpectedEnd) -> Self {
88 Error::UnexpectedEnd
89 }
90}
91
#[cfg(test)]
mod test {
    use std::io::Cursor;

    /// Encodes `value` with the given prefix `size` and `flags`, asserts the
    /// output equals `data`, then decodes it and asserts the round trip.
    fn check_codec(size: u8, flags: u8, value: usize, data: &[u8]) {
        let mut buf = Vec::new();
        super::encode(size, flags, value, &mut buf);
        assert_eq!(buf, data);
        let mut read = Cursor::new(&buf);
        assert_eq!((flags, value), super::decode(size, &mut read).unwrap());
    }

    #[test]
    fn codec_5_bits() {
        check_codec(5, 0b101, 10, &[0b1010_1010]);
        check_codec(5, 0b101, 0, &[0b1010_0000]);
        check_codec(5, 0b010, 1337, &[0b0101_1111, 154, 10]);
        check_codec(5, 0b010, 31, &[0b0101_1111, 0]);
    }

    #[test]
    fn codec_8_bits() {
        check_codec(8, 0, 42, &[0b0010_1010]);
        check_codec(8, 0, 424_242, &[255, 179, 240, 25]);
    }

    // The encoding of `usize::max_value()` depends on the pointer width, so
    // each supported width is checked against its own expected bytes.
    #[cfg(target_pointer_width = "64")]
    #[test]
    fn codec_max_value() {
        check_codec(
            5,
            0b010,
            usize::max_value(),
            &[95, 224, 255, 255, 255, 255, 255, 255, 255, 255, 1],
        );
        check_codec(
            8,
            0,
            usize::max_value(),
            &[255, 128, 254, 255, 255, 255, 255, 255, 255, 255, 1],
        );
    }

    #[cfg(target_pointer_width = "32")]
    #[test]
    fn codec_max_value() {
        check_codec(5, 0b010, usize::max_value(), &[95, 224, 255, 255, 255, 15]);
        check_codec(8, 0, usize::max_value(), &[255, 128, 254, 255, 255, 15]);
    }

    /// Round-trips values around the prefix boundary for every valid prefix size.
    #[test]
    fn round_trip_prefix_boundaries() {
        for size in 1..=8u8 {
            let mask = (1usize << size) - 1;
            // With an 8-bit prefix there is no room left for flag bits.
            let flags = if size == 8 { 0 } else { 1 };
            for &value in &[0, mask - 1, mask, mask + 1, mask + 127, mask + 128, 1 << 20] {
                let mut buf = Vec::new();
                super::encode(size, flags, value, &mut buf);
                let mut read = Cursor::new(&buf);
                assert_eq!((flags, value), super::decode(size, &mut read).unwrap());
            }
        }
    }

    #[test]
    #[should_panic]
    fn size_too_big_value() {
        let mut buf = vec![];
        super::encode(9, 1, 1, &mut buf);
    }

    #[test]
    #[should_panic]
    fn size_too_big_of_size() {
        let buf = vec![];
        let mut read = Cursor::new(&buf);
        super::decode(9, &mut read).unwrap();
    }

    #[cfg(target_pointer_width = "64")]
    #[test]
    fn overflow() {
        let buf = vec![255, 128, 254, 255, 255, 255, 255, 255, 255, 255, 255, 1];
        let mut read = Cursor::new(&buf);
        assert!(super::decode(8, &mut read).is_err());
    }

    #[cfg(target_pointer_width = "32")]
    #[test]
    fn overflow() {
        let buf = vec![255, 128, 254, 255, 255, 255, 1];
        let mut read = Cursor::new(&buf);
        assert!(super::decode(8, &mut read).is_err());
    }

    #[test]
    fn number_never_ends_with_0x80() {
        check_codec(4, 0b0001, 143, &[31, 128, 1]);
    }
}