1use std::{borrow::Cow, str::Utf8Error};
2
3use bytes::Buf;
4
5use crate::{wire_types::*, Message, VarInt};
6
7pub trait Decodable<'de>: Sized {
8 type Wire: WireType;
9
10 fn decode<B: Buf>(deserializer: &mut Deserializer<'de, B>) -> Result<Self>;
11
12 fn merge_from<B: Buf>(&mut self, deserializer: &mut Deserializer<'de, B>) -> Result<()> {
13 self.merge(Self::decode(deserializer)?);
14 Ok(())
15 }
16
17 fn merge(&mut self, other: Self) {
21 *self = other;
22 }
23}
24
25pub trait DecodableMessage<'de>: Sized {
26 type Tag: VarInt;
31
32 fn decode_field<B: Buf>(
36 &mut self,
37 deserializer: &mut Deserializer<'de, B>,
38 tag: Self::Tag,
39 ) -> Result<()>;
40
41 fn decode<B: Buf>(deserializer: &mut Deserializer<'de, B>) -> Result<Self>
42 where
43 Self: Default,
44 {
45 let mut message = Self::default();
46 loop {
47 if !deserializer.has_remaining() {
48 break;
49 }
50 match Self::Tag::read_field_tag(deserializer) {
51 Ok(tag) => message.decode_field(deserializer, tag)?,
52 Err(Ok(wire)) => wire.skip(deserializer)?,
53 Err(Err(e)) => return Err(e),
54 }
55 }
56 Ok(message)
57 }
58}
59
/// Errors produced while decoding.
#[derive(Debug)]
pub enum DecodingError {
    /// The input ended before the requested bytes were available, or a read
    /// asked for more bytes than the active length limit allows (the
    /// `Deserializer` reports both conditions with this variant).
    Eof,
    /// NOTE(review): presumably a varint that does not fit its target integer
    /// type; not produced anywhere in this file — confirm against `VarInt::read`.
    VarIntOverflow,
    /// A byte payload decoded as a string was not valid UTF-8.
    Utf8Error(Utf8Error),
    /// NOTE(review): presumably an unrecognized wire-type value from a field tag;
    /// not produced anywhere in this file — confirm.
    UnknownWireType(u8),
}
67
68impl From<Utf8Error> for DecodingError {
69 fn from(e: Utf8Error) -> Self {
70 Self::Utf8Error(e)
71 }
72}
73
/// Result type used throughout the decoder; the error type defaults to [`DecodingError`].
pub type Result<T, E = DecodingError> = std::result::Result<T, E>;
75
/// Receipt returned by [`Deserializer::set_limit`].
///
/// Hand it back to [`Deserializer::reset_limit`] to restore the limit that was in
/// effect before the matching `set_limit` call.
#[must_use]
pub struct LimitToken {
    /// The deserializer's limit before `set_limit` was called.
    prev_limit: usize,
    /// The limit `set_limit` installed: the requested limit clamped to `prev_limit`.
    set_to: usize,
}
81
/// Reads encoded values from a buffer, optionally bounded by a byte limit.
pub struct Deserializer<'de, B> {
    /// The underlying buffer being read.
    pub(crate) buf: &'de mut B,
    /// Bytes still readable in the current limited region. `usize::MAX` means
    /// "no limit": `new` starts with it, and `get_u8`/`check_limit` skip
    /// accounting while it is set.
    limit: usize,
}
86
87impl<'de, B: Buf> Deserializer<'de, B> {
88 pub fn new(buf: &'de mut B) -> Self {
89 Self {
90 buf,
91 limit: usize::MAX,
92 }
93 }
94
95 pub fn set_limit(&mut self, limit: usize) -> LimitToken {
96 let prev_limit = self.limit;
97 let set_to = limit.min(self.limit);
98 self.limit = set_to;
99
100 LimitToken { prev_limit, set_to }
101 }
102
103 pub fn reset_limit(&mut self, token: LimitToken) {
104 let limit_used = token.set_to - self.limit;
105 self.limit = token.prev_limit - limit_used;
106
107 std::mem::forget(token);
109 }
110
111 pub fn get_u8(&mut self) -> u8 {
113 if self.limit != usize::MAX {
114 self.limit -= 1
115 }
116 self.buf.get_u8()
117 }
118
119 pub fn has_remaining(&self) -> bool {
120 self.buf.remaining() != 0 && self.limit != 0
121 }
122
123 pub fn check_limit<'a, F: FnOnce(&'a mut B) -> Result<V>, V>(
124 &'a mut self,
125 len: usize,
126 f: F,
127 ) -> Result<V> {
128 if self.limit == usize::MAX {
129 f(self.buf)
130 } else if self.limit < len {
131 Err(DecodingError::Eof)
132 } else {
133 self.limit -= len;
134 f(self.buf)
135 }
136 }
137
138 pub fn read_varint<V: VarInt>(&mut self) -> Result<V> {
139 V::read(self)
140 }
141
142 pub fn read_bytes_borrowed<'a>(&mut self, len: usize) -> Result<&'a [u8]> {
143 self.check_limit(len, |buf| {
144 let c = buf.chunk();
145 if c.len() >= len {
146 let c_raw = unsafe { c.get_unchecked(..len) } as *const [u8];
148 buf.advance(len);
149 Ok(unsafe { &*c_raw })
150 } else {
151 Err(DecodingError::Eof)
152 }
153 })
154 }
155
156 pub fn read_bytes<'a>(&mut self, len: usize) -> Result<Cow<'a, [u8]>> {
157 use bytes::BufMut;
158 self.check_limit(len, |buf| {
159 let c = buf.chunk();
160 if buf.remaining() < len {
161 Err(DecodingError::Eof)
162 } else if c.len() >= len {
163 let c_raw = unsafe { c.get_unchecked(..len) } as *const [u8];
165 buf.advance(len);
166 Ok(Cow::Borrowed(unsafe { &*c_raw }))
167 } else {
168 let mut v = Vec::with_capacity(len);
169 (&mut v).put(buf.take(len));
170 Ok(Cow::Owned(v))
171 }
172 })
173 }
174
175 pub fn read_bytes_owned(&mut self, len: usize) -> Result<Box<[u8]>> {
176 use bytes::BufMut;
177 self.check_limit(len, |buf| {
178 if buf.remaining() < len {
179 Err(DecodingError::Eof)
180 } else {
181 let mut v= Vec::with_capacity(len);
182 (&mut v).put(buf.take(len));
183 Ok(v.into_boxed_slice())
184 }
185 })
186 }
187}
188
189impl<'de, B: Buf> From<&'de mut B> for Deserializer<'de, B> {
190 fn from(b: &'de mut B) -> Self {
191 Self::new(b)
192 }
193}
194
195impl<'de> Decodable<'de> for &'de [u8] {
196 type Wire = LengthDelimitedWire;
197
198 fn decode<B: Buf>(deserializer: &mut Deserializer<'de, B>) -> Result<Self> {
199 let len = deserializer.read_varint()?;
200 deserializer.read_bytes_borrowed(len)
201 }
202}
203
204impl Decodable<'_> for Vec<u8> {
205 type Wire = LengthDelimitedWire;
206
207 fn decode<B: Buf>(deserializer: &mut Deserializer<'_, B>) -> Result<Self> {
208 let len = deserializer.read_varint()?;
209 deserializer.read_bytes_owned(len).map(From::from)
210 }
211}
212
213impl Decodable<'_> for Box<[u8]> {
214 type Wire = LengthDelimitedWire;
215
216 fn decode<B: Buf>(deserializer: &mut Deserializer<'_, B>) -> Result<Self> {
217 let len = deserializer.read_varint()?;
218 deserializer.read_bytes_owned(len)
219 }
220}
221
222impl<'de> Decodable<'de> for &'de str {
223 type Wire = LengthDelimitedWire;
224
225 fn decode<B: Buf>(deserializer: &mut Deserializer<'de, B>) -> Result<Self> {
226 let len = deserializer.read_varint()?;
227 let bytes = deserializer.read_bytes_borrowed(len)?;
228 Ok(std::str::from_utf8(bytes)?)
229 }
230}
231
232impl Decodable<'_> for String {
233 type Wire = LengthDelimitedWire;
234
235 fn decode<B: Buf>(deserializer: &mut Deserializer<'_, B>) -> Result<Self> {
236 let len = deserializer.read_varint()?;
237 let bytes = deserializer.read_bytes_owned(len)?;
238 Ok(String::from_utf8(bytes.into()).map_err(|e| e.utf8_error())?)
239 }
240}
241
242impl Decodable<'_> for Box<str> {
243 type Wire = LengthDelimitedWire;
244
245 fn decode<B: Buf>(deserializer: &mut Deserializer<'_, B>) -> Result<Self> {
246 String::decode(deserializer).map(String::into_boxed_str)
247 }
248}
249
250impl<'de, M: DecodableMessage<'de> + Default> Decodable<'de> for Message<M> {
251 type Wire = LengthDelimitedWire;
252
253 fn decode<B: Buf>(deserializer: &mut Deserializer<'de, B>) -> Result<Self> {
254 let len = deserializer.read_varint()?;
255 let tk = deserializer.set_limit(len);
256 let message = M::decode(deserializer);
257 deserializer.reset_limit(tk);
258 Ok(Message(message?))
259 }
260}