Skip to main content

ntex_grpc/
encoding.rs

1/// protobuf encoding utils
2/// cloned from `<https://github.com/hyperium/tonic/>`
3use std::{borrow::Cow, cmp::min, convert::TryFrom, fmt, rc::Rc};
4
5use ntex_bytes::{Buf, BufMut, Bytes, BytesMut};
6
/// Smallest field tag accepted in a Protobuf field key.
pub const MIN_TAG: u32 = 1;
/// Largest field tag accepted in a Protobuf field key (`2^29 - 1`).
///
/// A key is `tag << 3 | wire_type` stored in a `u32`, which leaves 29 bits for the tag.
pub const MAX_TAG: u32 = u32::MAX >> 3;
9
/// Wire type designator stored in the low three bits of a Protobuf field key.
///
/// The discriminants are the on-the-wire values.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum WireType {
    /// A LEB128 variable-length integer.
    Varint = 0,
    /// A fixed 8-byte value.
    SixtyFourBit = 1,
    /// A varint length followed by that many bytes of payload.
    LengthDelimited = 2,
    /// Opens a group; closed by an `EndGroup` key carrying the same tag.
    StartGroup = 3,
    /// Closes the group opened by the matching `StartGroup` key.
    EndGroup = 4,
    /// A fixed 4-byte value.
    ThirtyTwoBit = 5,
}
20
21impl TryFrom<u64> for WireType {
22    type Error = DecodeError;
23
24    #[inline]
25    fn try_from(value: u64) -> Result<Self, Self::Error> {
26        match value {
27            0 => Ok(WireType::Varint),
28            1 => Ok(WireType::SixtyFourBit),
29            2 => Ok(WireType::LengthDelimited),
30            3 => Ok(WireType::StartGroup),
31            4 => Ok(WireType::EndGroup),
32            5 => Ok(WireType::ThirtyTwoBit),
33            _ => Err(DecodeError::new(format!(
34                "invalid wire type value: {value}"
35            ))),
36        }
37    }
38}
39
/// Computes how many bytes `value` occupies once LEB128-encoded.
///
/// The result is always in `1..=10`.
#[inline]
pub fn encoded_len_varint(value: u64) -> usize {
    // Based on [VarintSize64][1].
    // [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.h#L1301-L1309

    // Index (0..=63) of the most significant set bit; `| 1` makes zero behave like one.
    let top_bit = 63 - (value | 1).leading_zeros();
    // Division-free approximation of `ceil((top_bit + 1) / 7)`, exact for every `top_bit` in 0..=63.
    let encoded_bytes = (top_bit * 9 + 73) / 64;
    encoded_bytes as usize
}
48
49/// Encodes an integer value into LEB128 variable length format, and writes it to the buffer.
50/// The buffer must have enough remaining space (maximum 10 bytes).
51#[inline]
52pub fn encode_varint(mut value: u64, buf: &mut BytesMut) {
53    loop {
54        if value < 0x80 {
55            buf.put_u8(value as u8);
56            break;
57        }
58        buf.put_u8(((value & 0x7F) | 0x80) as u8);
59        value >>= 7;
60    }
61}
62
63/// Decodes a LEB128-encoded variable length integer from the buffer.
64#[inline]
65pub fn decode_varint(buf: &mut Bytes) -> Result<u64, DecodeError> {
66    let bytes = buf.chunk();
67    let len = bytes.len();
68    if len == 0 {
69        return Err(DecodeError::new("invalid varint"));
70    }
71
72    let byte = bytes[0];
73    if byte < 0x80 {
74        buf.advance(1);
75        Ok(u64::from(byte))
76    } else if len > 10 || bytes[len - 1] < 0x80 {
77        let (value, advance) = decode_varint_slice(bytes)?;
78        buf.advance(advance);
79        Ok(value)
80    } else {
81        decode_varint_slow(buf)
82    }
83}
84
85/// Decodes a LEB128-encoded variable length integer from the slice, returning the value and the
86/// number of bytes read.
87///
88/// Based loosely on [`ReadVarint64FromArray`][1] with a varint overflow check from
89/// [`ConsumeVarint`][2].
90///
91/// ## Safety
92///
93/// The caller must ensure that `bytes` is non-empty and either `bytes.len() >= 10` or the last
94/// element in bytes is < `0x80`.
95///
96/// [1]: https://github.com/google/protobuf/blob/3.3.x/src/google/protobuf/io/coded_stream.cc#L365-L406
97/// [2]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
98#[inline]
99fn decode_varint_slice(bytes: &[u8]) -> Result<(u64, usize), DecodeError> {
100    // Fully unrolled varint decoding loop. Splitting into 32-bit pieces gives better performance.
101
102    // Use assertions to ensure memory safety, but it should always be optimized after inline.
103    assert!(!bytes.is_empty());
104    assert!(bytes.len() > 10 || bytes[bytes.len() - 1] < 0x80);
105
106    let mut b: u8 = unsafe { *bytes.get_unchecked(0) };
107    let mut part0: u32 = u32::from(b);
108    if b < 0x80 {
109        return Ok((u64::from(part0), 1));
110    }
111    part0 -= 0x80;
112    b = unsafe { *bytes.get_unchecked(1) };
113    part0 += u32::from(b) << 7;
114    if b < 0x80 {
115        return Ok((u64::from(part0), 2));
116    }
117    part0 -= 0x80 << 7;
118    b = unsafe { *bytes.get_unchecked(2) };
119    part0 += u32::from(b) << 14;
120    if b < 0x80 {
121        return Ok((u64::from(part0), 3));
122    }
123    part0 -= 0x80 << 14;
124    b = unsafe { *bytes.get_unchecked(3) };
125    part0 += u32::from(b) << 21;
126    if b < 0x80 {
127        return Ok((u64::from(part0), 4));
128    }
129    part0 -= 0x80 << 21;
130    let value = u64::from(part0);
131
132    b = unsafe { *bytes.get_unchecked(4) };
133    let mut part1: u32 = u32::from(b);
134    if b < 0x80 {
135        return Ok((value + (u64::from(part1) << 28), 5));
136    }
137    part1 -= 0x80;
138    b = unsafe { *bytes.get_unchecked(5) };
139    part1 += u32::from(b) << 7;
140    if b < 0x80 {
141        return Ok((value + (u64::from(part1) << 28), 6));
142    }
143    part1 -= 0x80 << 7;
144    b = unsafe { *bytes.get_unchecked(6) };
145    part1 += u32::from(b) << 14;
146    if b < 0x80 {
147        return Ok((value + (u64::from(part1) << 28), 7));
148    }
149    part1 -= 0x80 << 14;
150    b = unsafe { *bytes.get_unchecked(7) };
151    part1 += u32::from(b) << 21;
152    if b < 0x80 {
153        return Ok((value + (u64::from(part1) << 28), 8));
154    }
155    part1 -= 0x80 << 21;
156    let value = value + ((u64::from(part1)) << 28);
157
158    b = unsafe { *bytes.get_unchecked(8) };
159    let mut part2: u32 = u32::from(b);
160    if b < 0x80 {
161        return Ok((value + (u64::from(part2) << 56), 9));
162    }
163    part2 -= 0x80;
164    b = unsafe { *bytes.get_unchecked(9) };
165    part2 += u32::from(b) << 7;
166    // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
167    // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
168    if b < 0x02 {
169        return Ok((value + (u64::from(part2) << 56), 10));
170    }
171
172    // We have overrun the maximum size of a varint (10 bytes) or the final byte caused an overflow.
173    // Assume the data is corrupt.
174    Err(DecodeError::new("invalid varint"))
175}
176
177/// Decodes a LEB128-encoded variable length integer from the buffer, advancing the buffer as
178/// necessary.
179///
180/// Contains a varint overflow check from [`ConsumeVarint`][1].
181///
182/// [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
183#[inline(never)]
184#[cold]
185fn decode_varint_slow<B>(buf: &mut B) -> Result<u64, DecodeError>
186where
187    B: Buf,
188{
189    let mut value = 0;
190    for count in 0..min(10, buf.remaining()) {
191        let byte = buf.get_u8();
192        value |= u64::from(byte & 0x7F) << (count * 7);
193        if byte <= 0x7F {
194            // Check for u64::MAX overflow. See [`ConsumeVarint`][1] for details.
195            // [1]: https://github.com/protocolbuffers/protobuf-go/blob/v1.27.1/encoding/protowire/wire.go#L358
196            return if count == 9 && byte >= 0x02 {
197                Err(DecodeError::new("invalid varint"))
198            } else {
199                Ok(value)
200            };
201        }
202    }
203
204    Err(DecodeError::new("invalid varint"))
205}
206
207/// Encodes a Protobuf field key, which consists of a wire type designator and
208/// the field tag.
209#[inline]
210pub fn encode_key(tag: u32, wire_type: WireType, buf: &mut BytesMut) {
211    debug_assert!((MIN_TAG..=MAX_TAG).contains(&tag));
212    let key = (tag << 3) | wire_type as u32;
213    encode_varint(u64::from(key), buf);
214}
215
216/// Decodes a Protobuf field key, which consists of a wire type designator and
217/// the field tag.
218#[inline]
219pub fn decode_key(buf: &mut Bytes) -> Result<(u32, WireType), DecodeError> {
220    let key = decode_varint(buf)?;
221    if key > u64::from(u32::MAX) {
222        return Err(DecodeError::new(format!("invalid key value: {key}")));
223    }
224    let wire_type = WireType::try_from(key & 0x07)?;
225    let tag = key as u32 >> 3;
226
227    if tag < MIN_TAG {
228        return Err(DecodeError::new("invalid tag value: 0"));
229    }
230
231    Ok((tag, wire_type))
232}
233
234/// Returns the width of an encoded Protobuf field key with the given tag.
235/// The returned width will be between 1 and 5 bytes (inclusive).
236#[inline]
237pub fn key_len(tag: u32) -> usize {
238    encoded_len_varint(u64::from(tag << 3))
239}
240
241/// Checks that the expected wire type matches the actual wire type,
242/// or returns an error result.
243#[inline]
244pub fn check_wire_type(expected: WireType, actual: WireType) -> Result<(), DecodeError> {
245    if expected != actual {
246        return Err(DecodeError::new(format!(
247            "invalid wire type: {actual:?} (expected {expected:?})",
248        )));
249    }
250    Ok(())
251}
252
253pub fn skip_field(wire_type: WireType, tag: u32, buf: &mut Bytes) -> Result<(), DecodeError> {
254    let len = match wire_type {
255        WireType::Varint => decode_varint(buf).map(|_| 0)?,
256        WireType::ThirtyTwoBit => 4,
257        WireType::SixtyFourBit => 8,
258        WireType::LengthDelimited => decode_varint(buf)?,
259        WireType::StartGroup => loop {
260            let (inner_tag, inner_wire_type) = decode_key(buf)?;
261            match inner_wire_type {
262                WireType::EndGroup => {
263                    if inner_tag != tag {
264                        return Err(DecodeError::new("unexpected end group tag"));
265                    }
266                    break 0;
267                }
268                _ => skip_field(inner_wire_type, inner_tag, buf)?,
269            }
270        },
271        WireType::EndGroup => return Err(DecodeError::new("unexpected end group tag")),
272    };
273
274    buf.split_to_checked(len as usize)
275        .ok_or_else(DecodeError::incomplete)?;
276    Ok(())
277}
278
/// Error produced when decoding a Protobuf message fails.
///
/// The payload lives behind an `Rc`, so cloning the error does not copy it.
#[derive(Clone, PartialEq, Eq)]
pub struct DecodeError {
    inner: Rc<Inner>,
}

/// Shared payload of a [`DecodeError`].
#[derive(Clone, PartialEq, Eq)]
struct Inner {
    /// A 'best effort' root cause description.
    description: Cow<'static, str>,
    /// (message, field) name pairs locating where decoding failed, in the
    /// order they were pushed; one entry per level of nesting.
    stack: Vec<(&'static str, &'static str)>,
}

impl DecodeError {
    /// Builds a `DecodeError` from a 'best effort' root cause description,
    /// with an empty location stack.
    ///
    /// Meant to be used only by `Message` implementations.
    #[doc(hidden)]
    #[cold]
    pub fn new(description: impl Into<Cow<'static, str>>) -> DecodeError {
        let inner = Inner {
            description: description.into(),
            stack: Vec::new(),
        };
        DecodeError {
            inner: Rc::new(inner),
        }
    }

    /// Records a (message, field) name location pair on the location stack
    /// and returns the updated error.
    ///
    /// Meant to be used only by `Message` implementations.
    #[doc(hidden)]
    #[must_use]
    pub fn push(mut self, message: &'static str, field: &'static str) -> Self {
        // Copy-on-write: the payload is cloned only if another handle shares it,
        // so clones of this error never observe the new entry.
        Rc::make_mut(&mut self.inner).stack.push((message, field));
        self
    }

    /// Error reported when the buffer holds less data than required.
    pub(crate) fn incomplete() -> Self {
        Self::new("Not enough data")
    }
}

impl fmt::Debug for DecodeError {
    /// Formats the error as a flat struct, hiding the `Rc` indirection.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Inner { description, stack } = &*self.inner;
        let mut out = f.debug_struct("DecodeError");
        out.field("description", description);
        out.field("stack", stack);
        out.finish()
    }
}

impl fmt::Display for DecodeError {
    /// Writes a fixed prefix, then every `message.field: ` location in push
    /// order, then the root cause description.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Inner { description, stack } = &*self.inner;
        f.write_str("failed to decode Protobuf message: ")?;
        stack
            .iter()
            .try_for_each(|(message, field)| write!(f, "{message}.{field}: "))?;
        f.write_str(description)
    }
}

impl std::error::Error for DecodeError {}