otopr/
decoding.rs

1use std::{borrow::Cow, str::Utf8Error};
2
3use bytes::Buf;
4
5use crate::{wire_types::*, Message, VarInt};
6
7pub trait Decodable<'de>: Sized {
8    type Wire: WireType;
9
10    fn decode<B: Buf>(deserializer: &mut Deserializer<'de, B>) -> Result<Self>;
11
12    fn merge_from<B: Buf>(&mut self, deserializer: &mut Deserializer<'de, B>) -> Result<()> {
13        self.merge(Self::decode(deserializer)?);
14        Ok(())
15    }
16
17    /// If this is a `message`, call `merge()` on all fields,
18    /// if this is `repeated`, extend this with the elements of `other`.
19    /// for all other types simply overwrite this with `other`, which is the default.
20    fn merge(&mut self, other: Self) {
21        *self = other;
22    }
23}
24
25pub trait DecodableMessage<'de>: Sized {
26    /// How big the tag message gets. This is an unsigned varint.
27    ///
28    /// It is not an error if any field tag overflows this type,
29    /// since there can be removed fields exceeding the current storage type.
30    type Tag: VarInt;
31
32    /// Decodes a field with the given tag.
33    ///
34    /// Skips the field if there are no matches for the tag.
35    fn decode_field<B: Buf>(
36        &mut self,
37        deserializer: &mut Deserializer<'de, B>,
38        tag: Self::Tag,
39    ) -> Result<()>;
40
41    fn decode<B: Buf>(deserializer: &mut Deserializer<'de, B>) -> Result<Self>
42    where
43        Self: Default,
44    {
45        let mut message = Self::default();
46        loop {
47            if !deserializer.has_remaining() {
48                break;
49            }
50            match Self::Tag::read_field_tag(deserializer) {
51                Ok(tag) => message.decode_field(deserializer, tag)?,
52                Err(Ok(wire)) => wire.skip(deserializer)?,
53                Err(Err(e)) => return Err(e),
54            }
55        }
56        Ok(message)
57    }
58}
59
/// Errors that can occur while decoding.
#[derive(Debug)]
pub enum DecodingError {
    /// The input (or the active byte limit) ended before a value was complete.
    Eof,
    /// A varint did not fit its target integer type (presumably raised by the
    /// `VarInt` readers).
    VarIntOverflow,
    /// A string value was not valid UTF-8.
    Utf8Error(Utf8Error),
    /// A field tag carried an unrecognized wire type (the raw value, presumably).
    UnknownWireType(u8),
}

impl std::fmt::Display for DecodingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Eof => f.write_str("unexpected end of input"),
            Self::VarIntOverflow => f.write_str("varint overflows its target type"),
            Self::Utf8Error(e) => write!(f, "invalid UTF-8: {}", e),
            Self::UnknownWireType(w) => write!(f, "unknown wire type {}", w),
        }
    }
}

impl std::error::Error for DecodingError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Utf8Error(e) => Some(e),
            Self::Eof | Self::VarIntOverflow | Self::UnknownWireType(_) => None,
        }
    }
}

impl From<Utf8Error> for DecodingError {
    /// Wraps a UTF-8 validation failure so `?` can propagate it.
    fn from(e: Utf8Error) -> Self {
        Self::Utf8Error(e)
    }
}

/// Result type used throughout decoding; the error defaults to [`DecodingError`].
pub type Result<T, E = DecodingError> = std::result::Result<T, E>;
75
/// Receipt returned by `Deserializer::set_limit`, recording what is needed
/// to restore the previous limit.
///
/// It must be handed back to `Deserializer::reset_limit`.
#[derive(Debug)]
#[must_use = "pass this token to `Deserializer::reset_limit` to restore the previous limit"]
pub struct LimitToken {
    /// The limit that was in effect before `set_limit` was called.
    prev_limit: usize,
    /// The (possibly clamped) limit that `set_limit` installed.
    set_to: usize,
}
81
/// Reads encoded values from a borrowed buffer, optionally bounded by a byte limit.
///
/// `limit == usize::MAX` means "unlimited"; any other value is the number of
/// bytes that may still be consumed before `has_remaining` reports `false`.
#[derive(Debug)]
pub struct Deserializer<'de, B> {
    pub(crate) buf: &'de mut B,
    limit: usize,
}
86
87impl<'de, B: Buf> Deserializer<'de, B> {
88    pub fn new(buf: &'de mut B) -> Self {
89        Self {
90            buf,
91            limit: usize::MAX,
92        }
93    }
94
95    pub fn set_limit(&mut self, limit: usize) -> LimitToken {
96        let prev_limit = self.limit;
97        let set_to = limit.min(self.limit);
98        self.limit = set_to;
99
100        LimitToken { prev_limit, set_to }
101    }
102
103    pub fn reset_limit(&mut self, token: LimitToken) {
104        let limit_used = token.set_to - self.limit;
105        self.limit = token.prev_limit - limit_used;
106
107        // avoid panicking.
108        std::mem::forget(token);
109    }
110
111    /// get an u8 from the underlying buffer, assuming this is within limits.
112    pub fn get_u8(&mut self) -> u8 {
113        if self.limit != usize::MAX {
114            self.limit -= 1
115        }
116        self.buf.get_u8()
117    }
118
119    pub fn has_remaining(&self) -> bool {
120        self.buf.remaining() != 0 && self.limit != 0
121    }
122
123    pub fn check_limit<'a, F: FnOnce(&'a mut B) -> Result<V>, V>(
124        &'a mut self,
125        len: usize,
126        f: F,
127    ) -> Result<V> {
128        if self.limit == usize::MAX {
129            f(self.buf)
130        } else if self.limit < len {
131            Err(DecodingError::Eof)
132        } else {
133            self.limit -= len;
134            f(self.buf)
135        }
136    }
137
138    pub fn read_varint<V: VarInt>(&mut self) -> Result<V> {
139        V::read(self)
140    }
141
142    pub fn read_bytes_borrowed<'a>(&mut self, len: usize) -> Result<&'a [u8]> {
143        self.check_limit(len, |buf| {
144            let c = buf.chunk();
145            if c.len() >= len {
146                // SAFETY: already checked above
147                let c_raw = unsafe { c.get_unchecked(..len) } as *const [u8];
148                buf.advance(len);
149                Ok(unsafe { &*c_raw })
150            } else {
151                Err(DecodingError::Eof)
152            }
153        })
154    }
155
156    pub fn read_bytes<'a>(&mut self, len: usize) -> Result<Cow<'a, [u8]>> {
157        use bytes::BufMut;
158        self.check_limit(len, |buf| {
159            let c = buf.chunk();
160            if buf.remaining() < len {
161                Err(DecodingError::Eof)
162            } else if c.len() >= len {
163                // SAFETY: already checked above
164                let c_raw = unsafe { c.get_unchecked(..len) } as *const [u8];
165                buf.advance(len);
166                Ok(Cow::Borrowed(unsafe { &*c_raw }))
167            } else {
168                let mut v = Vec::with_capacity(len);
169                (&mut v).put(buf.take(len));
170                Ok(Cow::Owned(v))
171            }
172        })
173    }
174
175    pub fn read_bytes_owned(&mut self, len: usize) -> Result<Box<[u8]>> {
176        use bytes::BufMut;
177        self.check_limit(len, |buf| {
178            if buf.remaining() < len {
179                Err(DecodingError::Eof)
180            } else {
181                let mut v= Vec::with_capacity(len);
182                (&mut v).put(buf.take(len));
183                Ok(v.into_boxed_slice())
184            }
185        })
186    }
187}
188
189impl<'de, B: Buf> From<&'de mut B> for Deserializer<'de, B> {
190    fn from(b: &'de mut B) -> Self {
191        Self::new(b)
192    }
193}
194
195impl<'de> Decodable<'de> for &'de [u8] {
196    type Wire = LengthDelimitedWire;
197
198    fn decode<B: Buf>(deserializer: &mut Deserializer<'de, B>) -> Result<Self> {
199        let len = deserializer.read_varint()?;
200        deserializer.read_bytes_borrowed(len)
201    }
202}
203
204impl Decodable<'_> for Vec<u8> {
205    type Wire = LengthDelimitedWire;
206
207    fn decode<B: Buf>(deserializer: &mut Deserializer<'_, B>) -> Result<Self> {
208        let len = deserializer.read_varint()?;
209        deserializer.read_bytes_owned(len).map(From::from)
210    }
211}
212
213impl Decodable<'_> for Box<[u8]> {
214    type Wire = LengthDelimitedWire;
215
216    fn decode<B: Buf>(deserializer: &mut Deserializer<'_, B>) -> Result<Self> {
217        let len = deserializer.read_varint()?;
218        deserializer.read_bytes_owned(len)
219    }
220}
221
222impl<'de> Decodable<'de> for &'de str {
223    type Wire = LengthDelimitedWire;
224
225    fn decode<B: Buf>(deserializer: &mut Deserializer<'de, B>) -> Result<Self> {
226        let len = deserializer.read_varint()?;
227        let bytes = deserializer.read_bytes_borrowed(len)?;
228        Ok(std::str::from_utf8(bytes)?)
229    }
230}
231
232impl Decodable<'_> for String {
233    type Wire = LengthDelimitedWire;
234
235    fn decode<B: Buf>(deserializer: &mut Deserializer<'_, B>) -> Result<Self> {
236        let len = deserializer.read_varint()?;
237        let bytes = deserializer.read_bytes_owned(len)?;
238        Ok(String::from_utf8(bytes.into()).map_err(|e| e.utf8_error())?)
239    }
240}
241
242impl Decodable<'_> for Box<str> {
243    type Wire = LengthDelimitedWire;
244
245    fn decode<B: Buf>(deserializer: &mut Deserializer<'_, B>) -> Result<Self> {
246        String::decode(deserializer).map(String::into_boxed_str)
247    }
248}
249
250impl<'de, M: DecodableMessage<'de> + Default> Decodable<'de> for Message<M> {
251    type Wire = LengthDelimitedWire;
252
253    fn decode<B: Buf>(deserializer: &mut Deserializer<'de, B>) -> Result<Self> {
254        let len = deserializer.read_varint()?;
255        let tk = deserializer.set_limit(len);
256        let message = M::decode(deserializer);
257        deserializer.reset_limit(tk);
258        Ok(Message(message?))
259    }
260}