1use std::{borrow::Cow, cmp::min, convert::TryFrom, fmt, rc::Rc};
4
5use ntex_bytes::{Buf, BufMut, Bytes, BytesMut};
6
/// Smallest field number (tag) accepted by `decode_key`; tag 0 is rejected.
pub const MIN_TAG: u32 = 1;
/// Largest field number (tag), i.e. `2^29 - 1`. Shifting it left by three bits
/// for the wire type still fits in a `u32` key.
pub const MAX_TAG: u32 = 0x1FFF_FFFF;
9
/// Wire type of an encoded field.
///
/// The discriminant is the number stored in the low three bits of a field key
/// (`encode_key` ORs it in; `decode_key` extracts it with `& 0x07`).
#[repr(u8)]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub enum WireType {
    /// Variable-length integer.
    Varint = 0,
    /// Fixed eight-byte value.
    SixtyFourBit = 1,
    /// Varint length prefix followed by that many bytes.
    LengthDelimited = 2,
    /// Opens a group.
    StartGroup = 3,
    /// Closes a group.
    EndGroup = 4,
    /// Fixed four-byte value.
    ThirtyTwoBit = 5,
}
20
21impl TryFrom<u64> for WireType {
22 type Error = DecodeError;
23
24 #[inline]
25 fn try_from(value: u64) -> Result<Self, Self::Error> {
26 match value {
27 0 => Ok(WireType::Varint),
28 1 => Ok(WireType::SixtyFourBit),
29 2 => Ok(WireType::LengthDelimited),
30 3 => Ok(WireType::StartGroup),
31 4 => Ok(WireType::EndGroup),
32 5 => Ok(WireType::ThirtyTwoBit),
33 _ => Err(DecodeError::new(format!(
34 "invalid wire type value: {value}"
35 ))),
36 }
37 }
38}
39
/// Returns the number of bytes `encode_varint` emits for `value` (between 1 and
/// 10 inclusive).
#[inline]
pub fn encoded_len_varint(value: u64) -> usize {
    // Count of significant bits; `| 1` makes zero count as one bit.
    let significant_bits = u64::BITS - (value | 1).leading_zeros();
    // Each varint byte carries seven payload bits, so round the bit count up to
    // whole 7-bit groups.
    ((significant_bits + 6) / 7) as usize
}
48
49#[inline]
52pub fn encode_varint(mut value: u64, buf: &mut BytesMut) {
53 loop {
54 if value < 0x80 {
55 buf.put_u8(value as u8);
56 break;
57 }
58 buf.put_u8(((value & 0x7F) | 0x80) as u8);
59 value >>= 7;
60 }
61}
62
63#[inline]
65pub fn decode_varint(buf: &mut Bytes) -> Result<u64, DecodeError> {
66 let bytes = buf.chunk();
67 let len = bytes.len();
68 if len == 0 {
69 return Err(DecodeError::new("invalid varint"));
70 }
71
72 let byte = bytes[0];
73 if byte < 0x80 {
74 buf.advance(1);
75 Ok(u64::from(byte))
76 } else if len > 10 || bytes[len - 1] < 0x80 {
77 let (value, advance) = decode_varint_slice(bytes)?;
78 buf.advance(advance);
79 Ok(value)
80 } else {
81 decode_varint_slow(buf)
82 }
83}
84
/// Decodes a LEB128 varint from the front of a contiguous byte slice.
///
/// Returns the decoded value together with the number of bytes it occupied
/// (1 to 10). This is the fast path behind `decode_varint`. The loop is fully
/// unrolled, and the value is assembled in three `u32` pieces covering bytes
/// 0-3, 4-7 and 8-9. These pieces are widened to `u64` and shifted into place
/// by 0, 28 and 56 bits respectively.
///
/// # Panics
///
/// Panics if `bytes` is empty. It also panics if `bytes` is at most 10 bytes
/// long and its last byte has the continuation bit (0x80) set. `decode_varint`
/// rules out both conditions before calling this function.
///
/// # Errors
///
/// Returns "invalid varint" if the tenth byte is `0x02` or greater. Such a byte
/// either sets bits beyond bit 63 or still carries a continuation bit.
#[inline]
fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> {
    // These two checks are what make every `get_unchecked` below sound.
    // SAFETY (for all unchecked reads in this function):
    // - If `bytes.len() > 10`, indices 0..=9 are all in bounds.
    // - Otherwise the last byte is < 0x80, so decoding returns at or before the
    //   final index and never reads past it. Index 9 is only reached when the
    //   slice holds at least 10 bytes.
    assert!(!bytes.is_empty());
    assert!(bytes.len() > 10 || bytes[bytes.len() - 1] < 0x80);

    // Bytes 0-3 accumulate into `part0`. When a byte still has its continuation
    // bit set, that bit is subtracted back out at its shifted position.
    let mut b: u8 = unsafe { *bytes.get_unchecked(0) };
    let mut part0: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((u64::from(part0), 1));
    }
    part0 -= 0x80;
    b = unsafe { *bytes.get_unchecked(1) };
    part0 += u32::from(b) << 7;
    if b < 0x80 {
        return Ok((u64::from(part0), 2));
    }
    part0 -= 0x80 << 7;
    b = unsafe { *bytes.get_unchecked(2) };
    part0 += u32::from(b) << 14;
    if b < 0x80 {
        return Ok((u64::from(part0), 3));
    }
    part0 -= 0x80 << 14;
    b = unsafe { *bytes.get_unchecked(3) };
    part0 += u32::from(b) << 21;
    if b < 0x80 {
        return Ok((u64::from(part0), 4));
    }
    part0 -= 0x80 << 21;
    // Bits 0..28 of the result.
    let value = u64::from(part0);

    // Bytes 4-7 accumulate into `part1`, later shifted up by 28 bits.
    b = unsafe { *bytes.get_unchecked(4) };
    let mut part1: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 5));
    }
    part1 -= 0x80;
    b = unsafe { *bytes.get_unchecked(5) };
    part1 += u32::from(b) << 7;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 6));
    }
    part1 -= 0x80 << 7;
    b = unsafe { *bytes.get_unchecked(6) };
    part1 += u32::from(b) << 14;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 7));
    }
    part1 -= 0x80 << 14;
    b = unsafe { *bytes.get_unchecked(7) };
    part1 += u32::from(b) << 21;
    if b < 0x80 {
        return Ok((value + (u64::from(part1) << 28), 8));
    }
    part1 -= 0x80 << 21;
    // Bits 0..56 of the result.
    let value = value + ((u64::from(part1)) << 28);

    // Bytes 8-9 accumulate into `part2`, later shifted up by 56 bits.
    b = unsafe { *bytes.get_unchecked(8) };
    let mut part2: u32 = u32::from(b);
    if b < 0x80 {
        return Ok((value + (u64::from(part2) << 56), 9));
    }
    part2 -= 0x80;
    b = unsafe { *bytes.get_unchecked(9) };
    part2 += u32::from(b) << 7;
    // Only bit 0 of the tenth byte fits in a u64 (it becomes bit 63). Any larger
    // value overflows or has a continuation bit, so it is rejected.
    if b < 0x02 {
        return Ok((value + (u64::from(part2) << 56), 10));
    }

    Err(DecodeError::new("invalid varint"))
}
176
177#[inline(never)]
184#[cold]
185fn decode_varint_slow<B>(buf: &mut B) -> Result<u64, DecodeError>
186where
187 B: Buf,
188{
189 let mut value = 0;
190 for count in 0..min(10, buf.remaining()) {
191 let byte = buf.get_u8();
192 value |= u64::from(byte & 0x7F) << (count * 7);
193 if byte <= 0x7F {
194 return if count == 9 && byte >= 0x02 {
197 Err(DecodeError::new("invalid varint"))
198 } else {
199 Ok(value)
200 };
201 }
202 }
203
204 Err(DecodeError::new("invalid varint"))
205}
206
207#[inline]
210pub fn encode_key(tag: u32, wire_type: WireType, buf: &mut BytesMut) {
211 debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag));
212 let key = (tag << 3) | wire_type as u32;
213 encode_varint(u64::from(key), buf);
214}
215
216#[inline]
219pub fn decode_key(buf: &mut Bytes) -> Result<(u32, WireType), DecodeError> {
220 let key = decode_varint(buf)?;
221 if key > u64::from(u32::MAX) {
222 return Err(DecodeError::new(format!("invalid key value: {key}")));
223 }
224 let wire_type = WireType::try_from(key & 0x07)?;
225 let tag = key as u32 >> 3;
226
227 if tag < MIN_TAG {
228 return Err(DecodeError::new("invalid tag value: 0"));
229 }
230
231 Ok((tag, wire_type))
232}
233
234#[inline]
237pub fn key_len(tag: u32) -> usize {
238 encoded_len_varint(u64::from(tag << 3))
239}
240
241#[inline]
244pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> {
245 if expected != actual {
246 return Err(DecodeError::new(format!(
247 "invalid wire type: {actual:?} (expected {expected:?})",
248 )));
249 }
250 Ok(())
251}
252
253pub fn skip_field(wire_type: WireType, tag: u32, buf: &mut Bytes) -> Result<(), DecodeError> {
254 let len = match wire_type {
255 WireType::Varint => decode_varint(buf).map(|_| 0)?,
256 WireType::ThirtyTwoBit => 4,
257 WireType::SixtyFourBit => 8,
258 WireType::LengthDelimited => decode_varint(buf)?,
259 WireType::StartGroup => loop {
260 let (inner_tag, inner_wire_type) = decode_key(buf)?;
261 match inner_wire_type {
262 WireType::EndGroup => {
263 if inner_tag != tag {
264 return Err(DecodeError::new("unexpected end group tag"));
265 }
266 break 0;
267 }
268 _ => skip_field(inner_wire_type, inner_tag, buf)?,
269 }
270 },
271 WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")),
272 };
273
274 buf.split_to_checked(len as usize)
275 .ok_or_else(DecodeError::incomplete)?;
276 Ok(())
277}
278
279#[derive(Clone, PartialEq, Eq)]
281pub struct DecodeError {
282 inner: Rc<Inner>,
283}
284
/// Payload shared by clones of a `DecodeError`.
///
/// It holds the error description plus the `(message, field)` context entries
/// appended by `DecodeError::push`.
#[derive(Clone, Eq, PartialEq)]
struct Inner {
    /// Description of what went wrong.
    description: Cow<'static, str>,
    /// `(message, field)` pairs, in the order they were pushed.
    stack: Vec<(&'static str, &'static str)>,
}
294
295impl DecodeError {
296 #[doc(hidden)]
300 #[cold]
301 pub fn new(description: impl Into<Cow<'static, str>>) -> DecodeError {
302 DecodeError {
303 inner: Rc::new(Inner {
304 description: description.into(),
305 stack: Vec::new(),
306 }),
307 }
308 }
309
310 #[doc(hidden)]
314 #[must_use]
315 pub fn push(mut self, message: &'static str, field: &'static str) -> Self {
316 let inner = if let Some(inner) = Rc::get_mut(&mut self.inner) {
317 inner
318 } else {
319 self.inner = Rc::new(Inner {
320 description: self.inner.description.clone(),
321 stack: self.inner.stack.clone(),
322 });
323 Rc::get_mut(&mut self.inner).unwrap()
324 };
325 inner.stack.push((message, field));
326 self
327 }
328
329 pub(crate) fn incomplete() -> Self {
330 Self::new("Not enough data")
331 }
332}
333
334impl fmt::Debug for DecodeError {
335 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
336 f.debug_struct("DecodeError")
337 .field("description", &self.inner.description)
338 .field("stack", &self.inner.stack)
339 .finish()
340 }
341}
342
343impl fmt::Display for DecodeError {
344 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
345 f.write_str("failed to decode Protobuf message: ")?;
346 for &(message, field) in &self.inner.stack {
347 write!(f, "{message}.{field}: ")?;
348 }
349 f.write_str(&self.inner.description)
350 }
351}
352
353impl std::error::Error for DecodeError {}