1use std::{error::Error, fmt, marker::PhantomData, str, sync::Arc};
2
3#[derive(Clone, Debug)]
4pub(crate) struct Decoder<'a> {
5 bytes: &'a [u8],
6 position: usize,
7}
8
9impl<'a> Decoder<'a> {
10 pub(crate) fn new(bytes: &'a [u8]) -> Self {
11 Self { bytes, position: 0 }
12 }
13
14 #[inline]
15 pub(crate) fn is_at_end(&self) -> bool {
16 self.position == self.bytes.len()
17 }
18
19 #[inline]
20 pub(crate) fn read_byte(&mut self) -> Result<u8, DecodeError> {
21 let byte = self
22 .bytes
23 .get(self.position)
24 .copied()
25 .ok_or_else(|| DecodeError::new("unexpected end"))?;
26 self.position += 1;
27 Ok(byte)
28 }
29
30 #[inline]
31 pub(crate) fn read_bytes(&mut self, count: usize) -> Result<&'a [u8], DecodeError> {
32 let bytes = self
33 .bytes
34 .get(self.position..self.position + count)
35 .ok_or_else(|| DecodeError::new("unexpected end"))?;
36 self.position += count;
37 Ok(bytes)
38 }
39
40 #[inline]
41 pub(crate) fn read_bytes_until_end(&mut self) -> &'a [u8] {
42 let bytes = &self.bytes[self.position..];
43 self.position = self.bytes.len();
44 bytes
45 }
46
47 #[inline]
48 pub(crate) fn decode_bytes(&mut self) -> Result<&'a [u8], DecodeError> {
49 let len: u32 = self.decode()?;
50 Ok(self.read_bytes(len as usize)?)
51 }
52
53 #[inline]
54 pub(crate) fn decode_string(&mut self) -> Result<&'a str, DecodeError> {
55 let len: u32 = self.decode()?;
56 Ok(str::from_utf8(self.read_bytes(len as usize)?)
57 .map_err(|_| DecodeError::new("malformed string"))?)
58 }
59
60 pub(crate) fn decode_iter<T>(&mut self) -> Result<DecodeIter<'_, 'a, T>, DecodeError>
61 where
62 T: Decode,
63 {
64 let count: u32 = self.decode()?;
65 Ok(DecodeIter {
66 decoder: self,
67 count: count as usize,
68 phantom: PhantomData,
69 })
70 }
71
72 pub(crate) fn decode_decoder(&mut self) -> Result<Decoder<'a>, DecodeError> {
73 let count: u32 = self.decode()?;
74 Ok(Decoder::new(self.read_bytes(count as usize)?))
75 }
76
77 pub(crate) fn decode<T>(&mut self) -> Result<T, DecodeError>
78 where
79 T: Decode,
80 {
81 T::decode(self)
82 }
83}
84
85#[derive(Debug)]
86pub(crate) struct DecodeIter<'a, 'b, T> {
87 decoder: &'a mut Decoder<'b>,
88 count: usize,
89 phantom: PhantomData<T>,
90}
91
92impl<'a, 'b, T> Iterator for DecodeIter<'a, 'b, T>
93where
94 T: Decode,
95{
96 type Item = Result<T, DecodeError>;
97
98 #[inline]
99 fn next(&mut self) -> Option<Self::Item> {
100 if self.count == 0 {
101 return None;
102 }
103 self.count -= 1;
104 Some(self.decoder.decode())
105 }
106
107 fn size_hint(&self) -> (usize, Option<usize>) {
108 (self.count, Some(self.count))
109 }
110}
111
112pub(crate) trait Decode: Sized {
113 fn decode(decoder: &mut Decoder<'_>) -> Result<Self, DecodeError>;
114}
115
116impl Decode for i32 {
117 #[inline]
118 fn decode(decoder: &mut Decoder<'_>) -> Result<Self, DecodeError> {
119 fn decode_i32_tail(decoder: &mut Decoder<'_>, mut val: i32) -> Result<i32, DecodeError> {
120 let mut shift = 7;
121 loop {
122 let byte = decoder.read_byte()?;
123 if shift >= 25 {
124 let bits = (byte << 1) as i8 >> (32 - shift);
125 if byte & 0x80 != 0 || bits != 0 && bits != -1 {
126 return Err(DecodeError::new("malformed i32"));
127 }
128 }
129 val |= ((byte & 0x7F) as i32) << shift;
130 if byte & 0x80 == 0 {
131 break;
132 }
133 shift += 7;
134 }
135 let shift = 25 - shift.min(25);
136 Ok(val << shift >> shift)
137 }
138
139 let byte = decoder.read_byte()?;
140 let val = (byte & 0x7F) as i32;
141 if byte & 0x80 == 0 {
142 Ok(val << 25 >> 25)
143 } else {
144 decode_i32_tail(decoder, val)
145 }
146 }
147}
148
149impl Decode for u32 {
150 #[inline]
151 fn decode(decoder: &mut Decoder<'_>) -> Result<Self, DecodeError> {
152 fn decode_u32_tail(decoder: &mut Decoder<'_>, mut val: u32) -> Result<u32, DecodeError> {
153 let mut shift = 7;
154 loop {
155 let byte = decoder.read_byte()?;
156 if shift >= 25 && byte >> 32 - shift != 0 {
157 return Err(DecodeError::new("malformed u32"));
158 }
159 val |= ((byte & 0x7F) as u32) << shift;
160 if byte & 0x80 == 0 {
161 break;
162 }
163 shift += 7;
164 }
165 Ok(val)
166 }
167
168 let byte = decoder.read_byte()?;
169 let val = (byte & 0x7F) as u32;
170 if byte & 0x80 == 0 {
171 Ok(val)
172 } else {
173 decode_u32_tail(decoder, val)
174 }
175 }
176}
177
178impl Decode for i64 {
179 #[inline]
180 fn decode(decoder: &mut Decoder<'_>) -> Result<Self, DecodeError> {
181 fn decode_i64_tail(decoder: &mut Decoder<'_>, mut val: i64) -> Result<i64, DecodeError> {
182 let mut shift = 7;
183 loop {
184 let byte = decoder.read_byte()?;
185 if shift >= 57 {
186 let bits = (byte << 1) as i8 >> (64 - shift);
187 if byte & 0x80 != 0 || bits != 0 && bits != -1 {
188 return Err(DecodeError::new("malformed i64"));
189 }
190 }
191 val |= ((byte & 0x7F) as i64) << shift;
192 if byte & 0x80 == 0 {
193 break;
194 }
195 shift += 7;
196 }
197 let shift = 57 - shift.min(57);
198 Ok(val << shift >> shift)
199 }
200
201 let byte = decoder.read_byte()?;
202 let val = (byte & 0x7F) as i64;
203 if byte & 0x80 == 0 {
204 Ok(val << 57 >> 57)
205 } else {
206 decode_i64_tail(decoder, val)
207 }
208 }
209}
210
211impl Decode for usize {
212 #[inline]
213 fn decode(decoder: &mut Decoder<'_>) -> Result<Self, DecodeError> {
214 Ok(usize::try_from(decoder.decode::<u32>()?).unwrap())
215 }
216}
217
218impl Decode for f32 {
219 #[inline]
220 fn decode(decoder: &mut Decoder<'_>) -> Result<Self, DecodeError> {
221 Ok(Self::from_le_bytes(
222 decoder.read_bytes(4)?.try_into().unwrap(),
223 ))
224 }
225}
226
227impl Decode for f64 {
228 #[inline]
229 fn decode(decoder: &mut Decoder<'_>) -> Result<Self, DecodeError> {
230 Ok(Self::from_le_bytes(
231 decoder.read_bytes(8)?.try_into().unwrap(),
232 ))
233 }
234}
235
236impl Decode for Arc<[u8]> {
237 fn decode(decoder: &mut Decoder<'_>) -> Result<Self, DecodeError> {
238 Ok(decoder.decode_bytes()?.into())
239 }
240}
241
242impl Decode for Arc<str> {
243 fn decode(decoder: &mut Decoder<'_>) -> Result<Self, DecodeError> {
244 Ok(decoder.decode_string()?.into())
245 }
246}
247
/// Error returned when input bytes cannot be decoded.
///
/// Holds a short human-readable message, which `Display` prints verbatim.
#[derive(Clone, Debug)]
pub struct DecodeError {
    message: Box<str>,
}

impl DecodeError {
    /// Creates an error carrying `message`.
    pub fn new(message: impl Into<Box<str>>) -> Self {
        let message: Box<str> = message.into();
        DecodeError { message }
    }
}

impl fmt::Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate to `str`'s `Display` so width, fill and precision flags apply.
        fmt::Display::fmt(&*self.message, f)
    }
}

impl Error for DecodeError {}