1use nom::bytes::complete::take;
2use nom::number::complete::{le_u24, le_u8};
3use serde::de::{self, DeserializeSeed, SeqAccess, Visitor};
4use serde::Deserialize;
5
6use super::event::{Version, VersionInfo};
7use crate::packet_parser::PacketError;
8
9
10pub struct Deserializer<'de> {
11 input: &'de [u8],
12
13 de_version: [u16; 4],
15
16 version_info: VersionInfo,
18
19 skip: bool,
22
23 name: &'static str,
26}
27
28impl<'de> Deserializer<'de> {
29 pub fn from_slice(
30 input: &'de [u8], de_version: [u16; 4], version_info: VersionInfo, name: &'static str,
31 ) -> Self {
32 Deserializer {
33 input,
34 de_version,
35 version_info,
36 name,
37 skip: false,
38 }
39 }
40}
41
42pub fn from_slice<'a, T>(input: &'a [u8], de_version: [u16; 4]) -> Result<T, PacketError>
43where
44 T: Deserialize<'a> + Version,
45{
46 let mut deserializer = Deserializer::from_slice(input, de_version, T::version(), T::name());
47 let t = T::deserialize(&mut deserializer)?;
48
49 if !deserializer.input.is_empty() {
50 return Err(PacketError::UnconsumedInput);
51 }
52
53 Ok(t)
54}
55
56pub fn from_slice_unchecked<'a, T>(
58 input: &'a [u8], de_version: [u16; 4],
59) -> Result<(&'a [u8], T), PacketError>
60where
61 T: Deserialize<'a> + Version,
62{
63 let mut deserializer = Deserializer::from_slice(input, de_version, T::version(), T::name());
64 let t = T::deserialize(&mut deserializer)?;
65
66 Ok((deserializer.input, t))
67}
68
69impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
70 type Error = PacketError;
71
72 fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
73 where
74 V: Visitor<'de>,
75 {
76 Err(PacketError::IncorrectUsage)
77 }
78
79 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
80 where
81 V: Visitor<'de>,
82 {
83 let (remaining, result) = le_u8(self.input)?;
84 self.input = remaining;
85 let result = !matches!(result, 0);
86
87 visitor.visit_bool(result)
88 }
89
90 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
91 where
92 V: Visitor<'de>,
93 {
94 use nom::number::complete::le_i8;
95
96 let (remaining, result) = le_i8(self.input)?;
97 self.input = remaining;
98 visitor.visit_i8(result)
99 }
100
101 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
102 where
103 V: Visitor<'de>,
104 {
105 use nom::number::complete::le_i16;
106
107 let (remaining, result) = le_i16(self.input)?;
108 self.input = remaining;
109 visitor.visit_i16(result)
110 }
111
112 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
113 where
114 V: Visitor<'de>,
115 {
116 use nom::number::complete::le_i32;
117
118 let (remaining, result) = le_i32(self.input)?;
119 self.input = remaining;
120 visitor.visit_i32(result)
121 }
122
123 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
124 where
125 V: Visitor<'de>,
126 {
127 use nom::number::complete::le_i64;
128
129 let (remaining, result) = le_i64(self.input)?;
130 self.input = remaining;
131 visitor.visit_i64(result)
132 }
133
134 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
135 where
136 V: Visitor<'de>,
137 {
138 let (remaining, result) = le_u8(self.input)?;
139 self.input = remaining;
140 visitor.visit_u8(result)
141 }
142
143 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
144 where
145 V: Visitor<'de>,
146 {
147 use nom::number::complete::le_u16;
148
149 let (remaining, result) = le_u16(self.input)?;
150 self.input = remaining;
151 visitor.visit_u16(result)
152 }
153
154 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
155 where
156 V: Visitor<'de>,
157 {
158 use nom::number::complete::le_u32;
159
160 let (remaining, result) = le_u32(self.input)?;
161 self.input = remaining;
162 visitor.visit_u32(result)
163 }
164
165 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
166 where
167 V: Visitor<'de>,
168 {
169 use nom::number::complete::le_u64;
170
171 let (remaining, result) = le_u64(self.input)?;
172 self.input = remaining;
173 visitor.visit_u64(result)
174 }
175
176 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
177 where
178 V: Visitor<'de>,
179 {
180 use nom::number::complete::le_f32;
181
182 let (remaining, result) = le_f32(self.input)?;
183 self.input = remaining;
184 visitor.visit_f32(result)
185 }
186
187 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
188 where
189 V: Visitor<'de>,
190 {
191 use nom::number::complete::le_f64;
192
193 let (remaining, result) = le_f64(self.input)?;
194 self.input = remaining;
195 visitor.visit_f64(result)
196 }
197
198 fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
199 where
200 V: Visitor<'de>,
201 {
202 unimplemented!()
203 }
204
205 fn deserialize_str<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
206 where
207 V: Visitor<'de>,
208 {
209 unimplemented!()
210 }
211
212 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
213 where
214 V: Visitor<'de>,
215 {
216 let (remaining, len) = le_u8(self.input)?;
217
218 if (len as usize) > remaining.len() {
219 return Err(PacketError::IncompleteInput("string length is too large".into()));
220 }
221
222 let str_vec = &remaining[..(len as usize)];
223
224 let str = std::str::from_utf8(str_vec)?;
225 self.input = &remaining[(len as usize)..];
226 visitor.visit_string(str.into())
227 }
228
229 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
230 where
231 V: Visitor<'de>,
232 {
233 let (remaining, bytes_array) = parse_byte_array(self.input)?;
234
235 self.input = remaining;
236
237 visitor.visit_borrowed_bytes(bytes_array)
238 }
239
240 fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
241 where
242 V: Visitor<'de>,
243 {
244 unimplemented!()
245 }
246
247 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
248 where
249 V: Visitor<'de>,
250 {
251 if self.skip {
252 self.skip = false;
253 visitor.visit_none()
254 } else {
255 visitor.visit_some(self)
256 }
257 }
258
259 fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
260 where
261 V: Visitor<'de>,
262 {
263 unimplemented!()
264 }
265
266 fn deserialize_unit_struct<V>(self, _name: &'static str, _visitor: V) -> Result<V::Value, Self::Error>
267 where
268 V: Visitor<'de>,
269 {
270 unimplemented!()
271 }
272
273 fn deserialize_newtype_struct<V>(self, _name: &'static str, _visitor: V) -> Result<V::Value, Self::Error>
274 where
275 V: Visitor<'de>,
276 {
277 unimplemented!()
278 }
279
280 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
281 where
282 V: Visitor<'de>,
283 {
284 let (remaining, len) = le_u8(self.input)?;
285 if len == u8::MAX {
286 let (remaining, len) = le_u24(remaining)?;
288
289 self.input = remaining;
290 visitor.visit_seq(SequenceAccess::new(self, len as usize))
291 } else {
292 self.input = remaining;
293 visitor.visit_seq(SequenceAccess::new(self, len as usize))
294 }
295 }
296
297 fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
298 where
299 V: Visitor<'de>,
300 {
301 unimplemented!()
302 }
303
304 fn deserialize_tuple_struct<V>(
305 self, _name: &'static str, _len: usize, _visitor: V,
306 ) -> Result<V::Value, Self::Error>
307 where
308 V: Visitor<'de>,
309 {
310 unimplemented!()
311 }
312
313 fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
314 where
315 V: Visitor<'de>,
316 {
317 unimplemented!()
318 }
319
320 fn deserialize_struct<V>(
321 self, name: &'static str, fields: &'static [&'static str], visitor: V,
322 ) -> Result<V::Value, Self::Error>
323 where
324 V: Visitor<'de>,
325 {
326 if name == self.name {
327 if let VersionInfo::Struct(version_info) = self.version_info {
328 assert!(version_info.len() == fields.len());
329 visitor.visit_seq(VersionedSeqAccess::new(self, fields.len(), version_info))
330 } else {
331 panic!("Struct must always have version info of `Struct` variant")
332 }
333 } else {
334 visitor.visit_seq(SequenceAccess::new(self, fields.len()))
336 }
337 }
338
339 fn deserialize_enum<V>(
340 self, _name: &'static str, _variants: &'static [&'static str], _visitor: V,
341 ) -> Result<V::Value, Self::Error>
342 where
343 V: Visitor<'de>,
344 {
345 unimplemented!()
346 }
347
348 fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
349 where
350 V: Visitor<'de>,
351 {
352 unimplemented!()
353 }
354
355 fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
356 where
357 V: Visitor<'de>,
358 {
359 unimplemented!()
360 }
361}
362
363struct SequenceAccess<'a, 'de: 'a> {
364 de: &'a mut Deserializer<'de>,
365 len: usize,
366 curr: usize,
367}
368
369impl<'a, 'de> SequenceAccess<'a, 'de> {
370 fn new(de: &'a mut Deserializer<'de>, len: usize) -> Self {
371 SequenceAccess { de, len, curr: 0 }
372 }
373}
374
375impl<'de, 'a> SeqAccess<'de> for SequenceAccess<'a, 'de> {
376 type Error = PacketError;
377
378 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
379 where
380 T: DeserializeSeed<'de>,
381 {
382 if self.curr == self.len {
383 Ok(None)
384 } else {
385 self.curr += 1;
386 seed.deserialize(&mut *self.de).map(Some)
387 }
388 }
389}
390struct VersionedSeqAccess<'a, 'de: 'a> {
391 de: &'a mut Deserializer<'de>,
392 version_info: &'static [VersionInfo],
393 len: usize,
394 curr: usize,
395}
396
397impl<'a, 'de> VersionedSeqAccess<'a, 'de> {
398 fn new(de: &'a mut Deserializer<'de>, len: usize, version_info: &'static [VersionInfo]) -> Self {
399 VersionedSeqAccess {
400 de,
401 len,
402 version_info,
403 curr: 0,
404 }
405 }
406}
407
408impl<'de, 'a> SeqAccess<'de> for VersionedSeqAccess<'a, 'de> {
409 type Error = PacketError;
410
411 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
412 where
413 T: DeserializeSeed<'de>,
414 {
415 if self.curr == self.len {
416 Ok(None)
417 } else {
418 let version = &self.version_info[self.curr as usize];
420 self.de.version_info = version.clone();
421
422 if !is_correct_version(&self.de.de_version, version) {
423 self.de.skip = true;
424 }
425
426 self.curr += 1;
427 seed.deserialize(&mut *self.de).map(Some)
428 }
429 }
430}
431
432fn is_correct_version(de_version: &[u16; 4], item_version: &VersionInfo) -> bool {
433 match item_version {
434 VersionInfo::Version(version) => {
435 if de_version == &[0, 0, 0, 0] {
436 return true;
437 }
438
439 de_version >= version
440 }
441 VersionInfo::VersionRange((range_begin, range_end)) => {
442 if de_version == &[0, 0, 0, 0] {
443 return true;
444 }
445
446 de_version >= range_begin && de_version <= range_end
447 }
448 _ => true,
449 }
450}
451
452pub fn parse_byte_array(input: &[u8]) -> Result<(&[u8], &[u8]), PacketError> {
454 let (remaining, len) = le_u8(input)?;
455
456 if len == u8::MAX {
457 let (remaining, len) = le_u24(remaining)?;
459 let (remaining, bytes_array) = take(len)(remaining)?;
460
461 Ok((remaining, bytes_array))
462 } else {
463 let (remaining, bytes_array) = take(len)(remaining)?;
464
465 Ok((remaining, bytes_array))
466 }
467}