wot_replay_parser/packet_parser/
serde_packet.rs

1use nom::bytes::complete::take;
2use nom::number::complete::{le_u24, le_u8};
3use serde::de::{self, DeserializeSeed, SeqAccess, Visitor};
4use serde::Deserialize;
5
6use super::event::{Version, VersionInfo};
7use crate::packet_parser::PacketError;
8
9
10pub struct Deserializer<'de> {
11    input: &'de [u8],
12
13    /// The version the `Deserializer` expect the data format to be
14    de_version: [u16; 4],
15
16    /// Versions of each field. (only used when deserialzing to a struct)
17    version_info: VersionInfo,
18
19    /// Whether to skip deserialzing current item. This flag is set by `VersionedSeqAccess`.
20    /// When set, the current item is deserialized to `None` and the flag will be unset
21    skip: bool,
22
23    /// Name of struct we are deserialzing into. We use this to make sure we call the correct
24    /// visitor for children of this struct who are also structs
25    name: &'static str,
26}
27
28impl<'de> Deserializer<'de> {
29    pub fn from_slice(
30        input: &'de [u8], de_version: [u16; 4], version_info: VersionInfo, name: &'static str,
31    ) -> Self {
32        Deserializer {
33            input,
34            de_version,
35            version_info,
36            name,
37            skip: false,
38        }
39    }
40}
41
42pub fn from_slice<'a, T>(input: &'a [u8], de_version: [u16; 4]) -> Result<T, PacketError>
43where
44    T: Deserialize<'a> + Version,
45{
46    let mut deserializer = Deserializer::from_slice(input, de_version, T::version(), T::name());
47    let t = T::deserialize(&mut deserializer)?;
48
49    if !deserializer.input.is_empty() {
50        return Err(PacketError::UnconsumedInput);
51    }
52
53    Ok(t)
54}
55
56/// Does not check if the input was fully consumed
57pub fn from_slice_unchecked<'a, T>(
58    input: &'a [u8], de_version: [u16; 4],
59) -> Result<(&'a [u8], T), PacketError>
60where
61    T: Deserialize<'a> + Version,
62{
63    let mut deserializer = Deserializer::from_slice(input, de_version, T::version(), T::name());
64    let t = T::deserialize(&mut deserializer)?;
65
66    Ok((deserializer.input, t))
67}
68
69impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
70    type Error = PacketError;
71
72    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
73    where
74        V: Visitor<'de>,
75    {
76        Err(PacketError::IncorrectUsage)
77    }
78
79    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
80    where
81        V: Visitor<'de>,
82    {
83        let (remaining, result) = le_u8(self.input)?;
84        self.input = remaining;
85        let result = !matches!(result, 0);
86
87        visitor.visit_bool(result)
88    }
89
90    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
91    where
92        V: Visitor<'de>,
93    {
94        use nom::number::complete::le_i8;
95
96        let (remaining, result) = le_i8(self.input)?;
97        self.input = remaining;
98        visitor.visit_i8(result)
99    }
100
101    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
102    where
103        V: Visitor<'de>,
104    {
105        use nom::number::complete::le_i16;
106
107        let (remaining, result) = le_i16(self.input)?;
108        self.input = remaining;
109        visitor.visit_i16(result)
110    }
111
112    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
113    where
114        V: Visitor<'de>,
115    {
116        use nom::number::complete::le_i32;
117
118        let (remaining, result) = le_i32(self.input)?;
119        self.input = remaining;
120        visitor.visit_i32(result)
121    }
122
123    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
124    where
125        V: Visitor<'de>,
126    {
127        use nom::number::complete::le_i64;
128
129        let (remaining, result) = le_i64(self.input)?;
130        self.input = remaining;
131        visitor.visit_i64(result)
132    }
133
134    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
135    where
136        V: Visitor<'de>,
137    {
138        let (remaining, result) = le_u8(self.input)?;
139        self.input = remaining;
140        visitor.visit_u8(result)
141    }
142
143    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
144    where
145        V: Visitor<'de>,
146    {
147        use nom::number::complete::le_u16;
148
149        let (remaining, result) = le_u16(self.input)?;
150        self.input = remaining;
151        visitor.visit_u16(result)
152    }
153
154    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
155    where
156        V: Visitor<'de>,
157    {
158        use nom::number::complete::le_u32;
159
160        let (remaining, result) = le_u32(self.input)?;
161        self.input = remaining;
162        visitor.visit_u32(result)
163    }
164
165    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
166    where
167        V: Visitor<'de>,
168    {
169        use nom::number::complete::le_u64;
170
171        let (remaining, result) = le_u64(self.input)?;
172        self.input = remaining;
173        visitor.visit_u64(result)
174    }
175
176    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
177    where
178        V: Visitor<'de>,
179    {
180        use nom::number::complete::le_f32;
181
182        let (remaining, result) = le_f32(self.input)?;
183        self.input = remaining;
184        visitor.visit_f32(result)
185    }
186
187    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
188    where
189        V: Visitor<'de>,
190    {
191        use nom::number::complete::le_f64;
192
193        let (remaining, result) = le_f64(self.input)?;
194        self.input = remaining;
195        visitor.visit_f64(result)
196    }
197
198    fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
199    where
200        V: Visitor<'de>,
201    {
202        unimplemented!()
203    }
204
205    fn deserialize_str<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
206    where
207        V: Visitor<'de>,
208    {
209        unimplemented!()
210    }
211
212    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
213    where
214        V: Visitor<'de>,
215    {
216        let (remaining, len) = le_u8(self.input)?;
217
218        if (len as usize) > remaining.len() {
219            return Err(PacketError::IncompleteInput("string length is too large".into()));
220        }
221
222        let str_vec = &remaining[..(len as usize)];
223
224        let str = std::str::from_utf8(str_vec)?;
225        self.input = &remaining[(len as usize)..];
226        visitor.visit_string(str.into())
227    }
228
229    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
230    where
231        V: Visitor<'de>,
232    {
233        let (remaining, bytes_array) = parse_byte_array(self.input)?;
234
235        self.input = remaining;
236
237        visitor.visit_borrowed_bytes(bytes_array)
238    }
239
240    fn deserialize_byte_buf<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
241    where
242        V: Visitor<'de>,
243    {
244        unimplemented!()
245    }
246
247    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
248    where
249        V: Visitor<'de>,
250    {
251        if self.skip {
252            self.skip = false;
253            visitor.visit_none()
254        } else {
255            visitor.visit_some(self)
256        }
257    }
258
259    fn deserialize_unit<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
260    where
261        V: Visitor<'de>,
262    {
263        unimplemented!()
264    }
265
266    fn deserialize_unit_struct<V>(self, _name: &'static str, _visitor: V) -> Result<V::Value, Self::Error>
267    where
268        V: Visitor<'de>,
269    {
270        unimplemented!()
271    }
272
273    fn deserialize_newtype_struct<V>(self, _name: &'static str, _visitor: V) -> Result<V::Value, Self::Error>
274    where
275        V: Visitor<'de>,
276    {
277        unimplemented!()
278    }
279
280    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
281    where
282        V: Visitor<'de>,
283    {
284        let (remaining, len) = le_u8(self.input)?;
285        if len == u8::MAX {
286            // This is a packed int spanning 3 bytes
287            let (remaining, len) = le_u24(remaining)?;
288
289            self.input = remaining;
290            visitor.visit_seq(SequenceAccess::new(self, len as usize))
291        } else {
292            self.input = remaining;
293            visitor.visit_seq(SequenceAccess::new(self, len as usize))
294        }
295    }
296
297    fn deserialize_tuple<V>(self, _len: usize, _visitor: V) -> Result<V::Value, Self::Error>
298    where
299        V: Visitor<'de>,
300    {
301        unimplemented!()
302    }
303
304    fn deserialize_tuple_struct<V>(
305        self, _name: &'static str, _len: usize, _visitor: V,
306    ) -> Result<V::Value, Self::Error>
307    where
308        V: Visitor<'de>,
309    {
310        unimplemented!()
311    }
312
313    fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
314    where
315        V: Visitor<'de>,
316    {
317        unimplemented!()
318    }
319
320    fn deserialize_struct<V>(
321        self, name: &'static str, fields: &'static [&'static str], visitor: V,
322    ) -> Result<V::Value, Self::Error>
323    where
324        V: Visitor<'de>,
325    {
326        if name == self.name {
327            if let VersionInfo::Struct(version_info) = self.version_info {
328                assert!(version_info.len() == fields.len());
329                visitor.visit_seq(VersionedSeqAccess::new(self, fields.len(), version_info))
330            } else {
331                panic!("Struct must always have version info of `Struct` variant")
332            }
333        } else {
334            // This is for children structs of the main struct. We do not support versioning for those
335            visitor.visit_seq(SequenceAccess::new(self, fields.len()))
336        }
337    }
338
339    fn deserialize_enum<V>(
340        self, _name: &'static str, _variants: &'static [&'static str], _visitor: V,
341    ) -> Result<V::Value, Self::Error>
342    where
343        V: Visitor<'de>,
344    {
345        unimplemented!()
346    }
347
348    fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
349    where
350        V: Visitor<'de>,
351    {
352        unimplemented!()
353    }
354
355    fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
356    where
357        V: Visitor<'de>,
358    {
359        unimplemented!()
360    }
361}
362
363struct SequenceAccess<'a, 'de: 'a> {
364    de:   &'a mut Deserializer<'de>,
365    len:  usize,
366    curr: usize,
367}
368
369impl<'a, 'de> SequenceAccess<'a, 'de> {
370    fn new(de: &'a mut Deserializer<'de>, len: usize) -> Self {
371        SequenceAccess { de, len, curr: 0 }
372    }
373}
374
375impl<'de, 'a> SeqAccess<'de> for SequenceAccess<'a, 'de> {
376    type Error = PacketError;
377
378    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
379    where
380        T: DeserializeSeed<'de>,
381    {
382        if self.curr == self.len {
383            Ok(None)
384        } else {
385            self.curr += 1;
386            seed.deserialize(&mut *self.de).map(Some)
387        }
388    }
389}
390struct VersionedSeqAccess<'a, 'de: 'a> {
391    de:           &'a mut Deserializer<'de>,
392    version_info: &'static [VersionInfo],
393    len:          usize,
394    curr:         usize,
395}
396
397impl<'a, 'de> VersionedSeqAccess<'a, 'de> {
398    fn new(de: &'a mut Deserializer<'de>, len: usize, version_info: &'static [VersionInfo]) -> Self {
399        VersionedSeqAccess {
400            de,
401            len,
402            version_info,
403            curr: 0,
404        }
405    }
406}
407
408impl<'de, 'a> SeqAccess<'de> for VersionedSeqAccess<'a, 'de> {
409    type Error = PacketError;
410
411    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
412    where
413        T: DeserializeSeed<'de>,
414    {
415        if self.curr == self.len {
416            Ok(None)
417        } else {
418            // Version Check
419            let version = &self.version_info[self.curr as usize];
420            self.de.version_info = version.clone();
421
422            if !is_correct_version(&self.de.de_version, version) {
423                self.de.skip = true;
424            }
425
426            self.curr += 1;
427            seed.deserialize(&mut *self.de).map(Some)
428        }
429    }
430}
431
432fn is_correct_version(de_version: &[u16; 4], item_version: &VersionInfo) -> bool {
433    match item_version {
434        VersionInfo::Version(version) => {
435            if de_version == &[0, 0, 0, 0] {
436                return true;
437            }
438
439            de_version >= version
440        }
441        VersionInfo::VersionRange((range_begin, range_end)) => {
442            if de_version == &[0, 0, 0, 0] {
443                return true;
444            }
445
446            de_version >= range_begin && de_version <= range_end
447        }
448        _ => true,
449    }
450}
451
452/// Return the remaining input and the byte_array that was parsed
453pub fn parse_byte_array(input: &[u8]) -> Result<(&[u8], &[u8]), PacketError> {
454    let (remaining, len) = le_u8(input)?;
455
456    if len == u8::MAX {
457        // This is a packed int spanning 3 bytes Ex: 0xFF080100
458        let (remaining, len) = le_u24(remaining)?;
459        let (remaining, bytes_array) = take(len)(remaining)?;
460
461        Ok((remaining, bytes_array))
462    } else {
463        let (remaining, bytes_array) = take(len)(remaining)?;
464
465        Ok((remaining, bytes_array))
466    }
467}