open_protocol_codec/
decode.rs

1use thiserror;
2use crate::{FieldNumber, FIELD_NUMBER_LEN};
3use chrono::{DateTime, Local, MappedLocalTime, TimeZone};
4
5#[derive(Debug, Eq, PartialEq, thiserror::Error)]
6pub enum Error {
7    #[error("Invalid character '{0}' (not a digit) on position {1}.")]
8    InvalidDigit(u8, usize),
9    #[error("Cannot parse '{0}' as boolean on position {1}.")]
10    InvalidBoolean(char, usize),
11    #[error("Invalid character {0} on position {1}.")]
12    InvalidCharacter(char, usize),
13    #[error("Invalid number {0} for enum on position {1}.")]
14    InvalidEnumNumber(u16, usize),
15    #[error("Invalid MID {mid} revision {revision}")]
16    InvalidMessage { mid: u16, revision: u16 },
17    #[error("Cannot parse bytes to UTF-8 string")]
18    InvalidArgNumber { wanted: u8, actual: u8 },
19    #[error("This type does not allow for unsized decodes.")]
20    UnsizedDecodeNotAllowed,
21    #[error("Out of bounds, wants {request}, but total size is {size}.")]
22    OutOfRightBound { request: usize, size: usize },
23    #[error("Out of bounds, wants to go back {request} positions, but cursor position is {cursor}.")]
24    OutOfLeftBound { request: usize, cursor: usize },
25    #[error("Invalid timestamp")]
26    InvalidTimestamp,
27    #[error("Expected character '{expected_char}', but got character '{decoded_char}' on position {pos}.")]
28    ExpectedCharacter { decoded_char: char, expected_char: char, pos: usize },
29    #[error("Type {type_name} cannot be decoded with size {requested_size}.")]
30    SizeMismatch { requested_size: usize, type_name: String },
31    #[error("Integer {number} does not fit the type {type_name}.")]
32    IntegerOverflow { type_name: String, number: u128 },
33    #[error("Insufficient bytes to decode message, header indicates {need} bytes but only have {have} bytes.")]
34    InsufficientBytes { have: usize, need: usize },
35
36    #[error("Not implemented")]
37    NotImplemented,
38}
39
40pub type Result<T> = core::result::Result<T, Error>;
41
42pub trait Decode: Sized {
43    /// This will take the decoder and return the data itself, parsed from the decoder's input
44    /// bytes.
45    fn decode(decoder: &mut Decoder) -> Result<Self>;
46
47    fn decode_sized(decoder: &mut Decoder, _size: usize) -> Result<Self> {
48        Self::decode(decoder)
49    }
50}
51
/// Forward-reading cursor over a borrowed byte buffer.
///
/// `Clone` is cheap (a slice reference plus an index) and lets callers snapshot the cursor
/// before a speculative parse and restore it afterwards.
#[derive(Debug, Clone)]
pub struct Decoder<'a> {
    /// The complete input; never modified.
    bytes: &'a [u8],
    /// Index of the next byte to read; kept within `0..=bytes.len()` by the methods.
    cursor: usize,
}
57
58impl<'a> Decoder<'a> {
59    pub fn new(bytes: &'a [u8]) -> Self {
60        Self { bytes, cursor: 0 }
61    }
62
63    pub fn read_byte(&mut self) -> Result<u8> {
64        if self.cursor >= self.bytes.len() {
65            return Err(Error::OutOfRightBound { request: self.cursor + 1, size: self.bytes.len()});
66        }
67
68        let byte = self.bytes[self.cursor];
69
70        self.skip(1)?;
71        Ok(byte)
72    }
73
74    pub fn read_bytes(&mut self, len: usize) -> Result<&'a [u8]> {
75        if (self.cursor + len) > self.bytes.len() {
76            return Err(Error::OutOfRightBound { request: self.cursor + len, size: self.bytes.len()});
77        }
78
79        let bytes = &self.bytes[self.cursor..(self.cursor + len)];
80
81        self.skip(len)?;
82        Ok(bytes)
83    }
84
85    pub fn skip(&mut self, len: usize) -> Result<()> {
86        if (self.cursor + len) > self.bytes.len() {
87            return Err(Error::OutOfRightBound { request: self.cursor + len, size: self.bytes.len()});
88        }
89
90        self.cursor += len;
91        Ok(())
92    }
93
94    pub fn back(&mut self, len: usize) -> Result<()> {
95        if self.cursor < len {
96            return Err(Error::OutOfLeftBound { request: len, cursor: self.cursor });
97        }
98
99        self.cursor -= len;
100        Ok(())
101    }
102
103    pub fn pos(&self) -> usize {
104        self.cursor
105    }
106
107    pub fn len(&self) -> usize {
108        self.bytes.len()
109    }
110
111    pub fn expect_char(&mut self, expected_char: char) -> Result<()> {
112        let decoded_char = char::decode(self)?;
113
114        if decoded_char != expected_char {
115            return Err(Error::ExpectedCharacter { decoded_char, expected_char, pos: self.pos() - 1});
116        }
117
118        Ok(())
119    }
120
121    pub fn read_sized_field<T: Decode>(&mut self, size: usize) -> Result<T> {
122        Ok(T::decode_sized(self, size)?)
123    }
124
125    pub fn read_numbered_sized_field<T: Decode>(
126        &mut self,
127        number: FieldNumber,
128        size: usize,
129    ) -> Result<T> {
130        let decoded_number = FieldNumber::decode_sized(self, FIELD_NUMBER_LEN)?;
131
132        if decoded_number != number {
133            return Err(Error::InvalidArgNumber { wanted: number, actual: decoded_number });
134        }
135
136        self.read_sized_field(size)
137    }
138
139    pub fn read_numbered_sized_optional_field<T: Decode>(
140        &mut self,
141        number: FieldNumber,
142        size: usize,
143    ) -> Result<Option<T>> {
144        let decoded_number = FieldNumber::decode_sized(self, FIELD_NUMBER_LEN)?;
145
146        if decoded_number != number {
147            self.back(FIELD_NUMBER_LEN)?;
148            return Ok(None)
149        }
150
151        Ok(Some(self.read_sized_field(size)?))
152    }
153
154    pub fn read_numbered_field<T: Decode>(&mut self, number: FieldNumber) -> Result<T> {
155        let decoded_number = FieldNumber::decode_sized(self, FIELD_NUMBER_LEN)?;
156
157        if decoded_number != number {
158            return Err(Error::InvalidArgNumber { wanted: number, actual: decoded_number })
159        }
160
161        T::decode(self)
162    }
163
164    pub fn read_sized_list<T: Decode>(
165        &mut self,
166        list_length: usize,
167        item_size: usize,
168    ) -> Result<Vec<T>> {
169        let mut list = Vec::with_capacity(list_length);
170        for _ in 0..list_length {
171            list.push(T::decode_sized(self, item_size)?);
172        }
173        Ok(list)
174    }
175
176    pub fn read_list<T: Decode>(
177        &mut self,
178        list_length: usize
179    ) -> Result<Vec<T>> {
180        let mut list = Vec::with_capacity(list_length);
181        for _ in 0..list_length {
182            list.push(T::decode(self)?);
183        }
184        Ok(list)
185    }
186}
187
188impl<'a> From<&'a str> for Decoder<'a> {
189    fn from(value: &'a str) -> Self {
190        Self { bytes: value.as_bytes(), cursor: 0 }
191    }
192}
193
194impl<'a> From<&'a [u8]> for Decoder<'a> {
195    fn from(value: &'a [u8]) -> Self {
196        Self { bytes: value, cursor: 0 }
197    }
198}
199
200
201pub fn decode<T: Decode>(bytes: &[u8]) -> Result<T> {
202    let mut decoder = Decoder::new(bytes);
203    T::decode(&mut decoder)
204}
205
206
207/// Values ranging 0..256, length 1-3
208impl Decode for u8 {
209    fn decode(_: &mut Decoder) -> Result<Self> {
210        Err(Error::UnsizedDecodeNotAllowed)
211    }
212
213    fn decode_sized(decoder: &mut Decoder, size: usize) -> Result<Self> {
214        if !(1..=3).contains(&size) {
215            return Err(Error::SizeMismatch { requested_size: size, type_name: "u8".into() });
216        }
217
218        let mut result: u8 = 0;
219
220        for _ in 0..size {
221            let raw = decoder.read_byte()?;
222            if raw < b'0' || raw > b'9' {
223                return Err(Error::InvalidDigit(raw, decoder.pos() - 1));
224            }
225            let digit = raw - b'0';
226
227            if result > 25 || (result == 25 && digit > 5) {
228                return Err(Error::IntegerOverflow { number: (result * 10 + digit) as u128, type_name: "u8".into() });
229            }
230            result = result * 10 + digit;
231        }
232
233        Ok(result)
234    }
235}
236
237/// Values ranging 0..65536, length 1-5
238impl Decode for u16 {
239    fn decode(_: &mut Decoder) -> Result<Self> {
240        Err(Error::UnsizedDecodeNotAllowed)
241    }
242
243    fn decode_sized(decoder: &mut Decoder, size: usize) -> Result<Self> {
244        if !(1..=5).contains(&size) {
245            return Err(Error::SizeMismatch { requested_size: size, type_name: "u16".into() });
246        }
247
248        let mut result: u32 = 0;
249
250        for _ in 0..size {
251            let raw = decoder.read_byte()?;
252            if raw < b'0' || raw > b'9' {
253                return Err(Error::InvalidDigit(raw, decoder.pos() - 1));
254            }
255
256            let digit = (raw - b'0') as u32;
257
258            if result > 6553 || (result == 6553 && digit > 5) {
259                return Err(Error::IntegerOverflow { number: (result * 10 + digit) as u128, type_name: "u16".into() });
260            }
261            result = result * 10 + digit;
262        }
263
264        Ok(result as u16)
265    }
266}
267
268/// Values ranging 0..4294967296, length 1-10
269impl Decode for u32 {
270    fn decode(_: &mut Decoder) -> Result<Self> {
271        // If the protocol always requires a size for u32, we keep this an error.
272        Err(Error::UnsizedDecodeNotAllowed)
273    }
274
275    fn decode_sized(decoder: &mut Decoder, size: usize) -> Result<Self> {
276        if !(1..=10).contains(&size) {
277            return Err(Error::SizeMismatch { requested_size: size, type_name: "u32".into() });
278        }
279
280        let mut result: u64 = 0;
281
282        for _ in 0..size {
283            let raw = decoder.read_byte()?;
284            if raw < b'0' || raw > b'9' {
285                return Err(Error::InvalidDigit(raw, decoder.pos() - 1));
286            }
287            let digit = (raw - b'0') as u64;
288
289            if result > 429496729 || (result == 429496729 && digit > 5) {
290                return Err(Error::IntegerOverflow { number: (result * 10 + digit) as u128, type_name: "u32".into() });
291            }
292            result = result * 10 + digit;
293        }
294
295        Ok(result as u32)
296    }
297}
298
299
300/// Values ranging 0..18446744073709551616, length 1-20
301impl Decode for u64 {
302    fn decode(_: &mut Decoder) -> Result<Self> {
303        // If the protocol always requires a size for u64, we keep this an error.
304        Err(Error::UnsizedDecodeNotAllowed)
305    }
306
307    fn decode_sized(decoder: &mut Decoder, size: usize) -> Result<Self> {
308        if !(1..=20).contains(&size) {
309            return Err(Error::SizeMismatch { requested_size: size, type_name: "u64".into() });
310        }
311
312        let mut result: u128 = 0;
313
314        for _ in 0..size {
315            let raw = decoder.read_byte()?;
316            if raw < b'0' || raw > b'9' {
317                return Err(Error::InvalidDigit(raw, decoder.pos() - 1));
318            }
319            let digit = (raw - b'0') as u128;
320
321            if result > 1844674407370955161 || (result == 1844674407370955161 && digit > 5) {
322                return Err(Error::IntegerOverflow { number: result * 10 + digit, type_name: "u64".into() });
323            }
324            result = result * 10 + digit;
325        }
326
327        Ok(result as u64)
328    }
329}
330
331/// Raw ASCII character, length 1
332impl Decode for char {
333    fn decode(decoder: &mut Decoder) -> Result<Self> {
334        let byte = decoder.read_byte()?;
335        Ok(byte.into())
336    }
337
338    fn decode_sized(decoder: &mut Decoder, size: usize) -> Result<Self> {
339        if size != 1 {
340            return Err(Error::SizeMismatch { requested_size: size, type_name: "char".into() });
341        }
342        Self::decode(decoder)
343    }
344}
345
346/// String based on the ASCII characters in the decoder, can be length 0-infinite.
347impl Decode for String {
348    fn decode(decoder: &mut Decoder) -> Result<Self> {
349        let mut chars = Vec::new();
350
351        while decoder.pos() < decoder.len() {
352            let next_char = char::decode(decoder)?;
353
354            if next_char == ' ' || next_char == '\0' {
355                break;
356            }
357
358            chars.push(char::decode(decoder)?);
359        }
360
361        Ok(String::from_iter(chars))
362    }
363
364    fn decode_sized(decoder: &mut Decoder, size: usize) -> Result<Self> {
365        let chars = decoder.read_bytes(size)?;
366
367        let mut end = size;
368        while chars[end - 1] == b' ' {
369            end -= 1;
370        }
371
372        let string = String::from_utf8_lossy(&chars[0..end]).to_string();
373        Ok(string)
374    }
375}
376
377/// Values 0 and 1 only. Length is always 1.
378impl Decode for bool {
379    fn decode(decoder: &mut Decoder) -> Result<Self> {
380        let decoded_char = char::decode(decoder)?;
381
382        match decoded_char {
383            '1' => Ok(true),
384            '0' => Ok(false),
385            _ => Err(Error::InvalidBoolean(decoded_char, decoder.pos() - 1)),
386        }
387    }
388
389    fn decode_sized(decoder: &mut Decoder, size: usize) -> Result<Self> {
390        if size != 1 {
391            return Err(Error::SizeMismatch { requested_size: size, type_name: "bool".into() });
392        }
393        Self::decode(decoder)
394    }
395}
396
397impl<T> Decode for Option<T> where T: Decode {
398    fn decode(_: &mut Decoder) -> Result<Self> {
399        Err(Error::UnsizedDecodeNotAllowed)
400    }
401
402    fn decode_sized(decoder: &mut Decoder, size: usize) -> Result<Self> {
403        let bytes = decoder.read_bytes(size)?;
404
405        if bytes.iter().all(|&byte| byte == b' ') {
406            Ok(None)
407        } else {
408            decoder.back(size)?;
409            Ok(Some(T::decode_sized(decoder, size)?))
410        }
411    }
412}
413
414impl Decode for DateTime<Local> {
415    fn decode(decoder: &mut Decoder) -> Result<Self> {
416        let year = u32::decode_sized(decoder, 4)?;
417        decoder.expect_char('-')?;
418        let month = u32::decode_sized(decoder, 2)?;
419        decoder.expect_char('-')?;
420        let day = u32::decode_sized(decoder, 2)?;
421        decoder.expect_char(':')?;
422        let hour = u32::decode_sized(decoder, 2)?;
423        decoder.expect_char(':')?;
424        let min = u32::decode_sized(decoder, 2)?;
425        decoder.expect_char(':')?;
426        let sec = u32::decode_sized(decoder, 2)?;
427
428        match Local.with_ymd_and_hms(year as i32, month, day, hour, min, sec) {
429            MappedLocalTime::Single(timestamp) => Ok(timestamp),
430            _ => Err(Error::InvalidTimestamp)
431        }
432    }
433
434    fn decode_sized(decoder: &mut Decoder, size: usize) -> Result<Self> {
435        if size != 19 {
436            return Err(Error::SizeMismatch { requested_size: size, type_name: "DateTime".into() });
437        }
438        Self::decode(decoder)
439    }
440}
441
#[cfg(test)]
mod tests {
    use chrono::{DateTime, Local, TimeZone};
    use crate::decode::Error;
    use crate::decode::{Decode, Decoder, Result};

    #[test]
    fn test_read_byte() {
        let bytes = [b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8'];

        let mut decoder = Decoder::new(bytes.as_slice());

        assert_eq!(decoder.read_byte(), Ok(b'1'));
        assert_eq!(decoder.read_byte(), Ok(b'2'));
        assert_eq!(decoder.read_byte(), Ok(b'3'));
        assert_eq!(decoder.read_byte(), Ok(b'4'));

        assert_eq!(decoder.pos(), 4usize);
    }

    #[test]
    fn test_read_bytes() {
        let bytes = [b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8'];

        let mut decoder = Decoder::new(&bytes[..]);

        assert_eq!(
            decoder.read_bytes(4),
            Ok([b'1', b'2', b'3', b'4'].as_slice())
        );
        assert_eq!(decoder.read_bytes(2), Ok([b'5', b'6'].as_slice()));

        assert_eq!(decoder.pos(), 6usize);
    }

    #[test]
    fn test_skip_and_back_bounds() {
        let bytes = [b'1'; 8];
        let mut decoder = Decoder::new(&bytes[..]);

        assert_eq!(decoder.skip(3), Ok(()));
        assert_eq!(decoder.back(5), Err(Error::OutOfLeftBound { request: 5, cursor: 3 }));
        assert_eq!(decoder.skip(6), Err(Error::OutOfRightBound { request: 9, size: 8 }));
        // Failed moves leave the cursor where it was.
        assert_eq!(decoder.pos(), 3);
    }

    #[test]
    fn test_expect_char_mismatch() {
        let mut decoder = Decoder::from("x");

        assert_eq!(
            decoder.expect_char('y'),
            Err(Error::ExpectedCharacter { decoded_char: 'x', expected_char: 'y', pos: 0 })
        );
    }

    #[test]
    fn test_read_bool() {
        let bytes = [b'1', b'0', b'1', b'0'];
        let mut decoder = Decoder::new(&bytes[..]);

        assert_eq!(bool::decode(&mut decoder), Ok(true));
        assert_eq!(bool::decode(&mut decoder), Ok(false));
        assert_eq!(bool::decode(&mut decoder), Ok(true));
        assert_eq!(bool::decode(&mut decoder), Ok(false));
        assert_eq!(decoder.pos(), 4);
    }

    #[test]
    fn test_read_u8() {
        let bytes = [b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8'];
        let mut decoder = Decoder::new(&bytes[..]);

        assert_eq!(u8::decode_sized(&mut decoder, 3), Ok(123));
        assert_eq!(u8::decode_sized(&mut decoder, 2), Ok(45));
        assert_eq!(decoder.pos(), 5);
    }

    #[test]
    fn test_read_u8_unsized() {
        let bytes = [b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8'];
        let mut decoder = Decoder::new(&bytes[..]);

        assert_eq!(u8::decode(&mut decoder), Err(Error::UnsizedDecodeNotAllowed));
        assert_eq!(decoder.pos(), 0);
    }

    #[test]
    fn test_read_u8_too_large() {
        // A u8 field is at most 3 digits wide; a wider request is rejected before any byte
        // is consumed, so this is a size error rather than an integer overflow.
        let bytes = [b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8'];
        let mut decoder = Decoder::new(&bytes[..]);

        assert_eq!(
            u8::decode_sized(&mut decoder, 5),
            Err(Error::SizeMismatch { requested_size: 5, type_name: "u8".into() })
        );
        assert_eq!(decoder.pos(), 0);
    }

    #[test]
    fn test_read_u8_invalid_digit() {
        let bytes = [b'1', b'a'];
        let mut decoder = Decoder::new(&bytes[..]);

        assert_eq!(u8::decode_sized(&mut decoder, 2), Err(Error::InvalidDigit(b'a', 1)));
    }

    #[test]
    fn test_read_u16() {
        let bytes = [b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8'];
        let mut decoder = Decoder::new(&bytes[..]);

        assert_eq!(u16::decode_sized(&mut decoder, 5), Ok(12345));
        assert_eq!(decoder.pos(), 5);
    }

    #[test]
    fn test_read_u32() {
        let bytes = [b'1', b'2', b'3', b'4', b'5', b'6', b'7', b'8'];
        let mut decoder = Decoder::new(&bytes[..]);

        assert_eq!(u32::decode_sized(&mut decoder, 8), Ok(12345678));
        assert_eq!(decoder.pos(), 8);
    }

    #[test]
    fn test_read_string() {
        let bytes = [b'H', b'e', b'l', b'l', b'o', b'6', b'7', b'8'];
        let mut decoder = Decoder::new(&bytes[..]);

        assert_eq!(String::decode_sized(&mut decoder, 5), Ok("Hello".to_string()));
        assert_eq!(decoder.pos(), 5);
    }

    #[test]
    fn test_read_string_trims_trailing_spaces() {
        let bytes = [b'H', b'i', b' ', b' ', b' '];
        let mut decoder = Decoder::new(&bytes[..]);

        assert_eq!(String::decode_sized(&mut decoder, 5), Ok("Hi".to_string()));
        assert_eq!(decoder.pos(), 5);
    }

    #[test]
    fn test_read_option_string() {
        let bytes = [b'H', b'e', b'l', b'l', b'o', b'6', b'7', b'8'];
        let mut decoder = Decoder::new(&bytes[..]);

        let val: Result<Option<String>> = Option::decode_sized(&mut decoder, 5);

        assert_eq!(val, Ok(Some("Hello".to_string())));
        assert_eq!(decoder.pos(), 5);
    }

    #[test]
    fn test_read_option_none() {
        let bytes = [b' ', b' ', b' ', b' ', b' ', b'6', b'7', b'8'];
        let mut decoder = Decoder::new(&bytes[..]);

        let val: Result<Option<String>> = Option::decode_sized(&mut decoder, 5);

        assert_eq!(val, Ok(None));
        assert_eq!(decoder.pos(), 5);
    }

    #[test]
    fn test_read_sized_field() {
        let bytes = [b'H', b'e', b'l', b'l', b'o', b'6', b'7', b'8'];
        let mut decoder = Decoder::new(&bytes[..]);

        assert_eq!(decoder.read_sized_field(5), Ok("Hello".to_string()));
        assert_eq!(decoder.pos(), 5);
    }

    #[test]
    fn test_read_sized_list() {
        let bytes = [b'1', b'2', b'3'];
        let mut decoder = Decoder::new(&bytes[..]);

        let list: Result<Vec<u8>> = decoder.read_sized_list(3, 1);

        assert_eq!(list, Ok(vec![1, 2, 3]));
        assert_eq!(decoder.pos(), 3);
    }

    #[test]
    fn test_read_numbered_sized_field() {
        let bytes = [b'0', b'1', b'H', b'e', b'l', b'l', b'o', b'0', b'2', b'1'];
        let mut decoder = Decoder::new(&bytes[..]);

        assert_eq!(decoder.read_numbered_sized_field(1, 5), Ok("Hello".to_string()));
        assert_eq!(decoder.read_numbered_sized_field(2, 1), Ok(true));
        assert_eq!(decoder.pos(), 10);
    }

    #[test]
    fn test_read_numbered_sized_field_invalid_number() {
        let bytes = [b'0', b'1', b'H', b'e', b'l', b'l', b'o', b'0', b'2', b'1'];
        let mut decoder = Decoder::new(&bytes[..]);

        assert_eq!(decoder.read_numbered_sized_field::<String>(4, 5), Err(Error::InvalidArgNumber { wanted: 4, actual: 1 }));
        assert_eq!(decoder.pos(), 2);
    }

    #[test]
    fn test_read_numbered_sized_optional_field() {
        let bytes = [b'0', b'2', b'1'];
        let mut decoder = Decoder::new(&bytes[..]);

        // A different field number means "absent": nothing is consumed.
        assert_eq!(decoder.read_numbered_sized_optional_field::<bool>(1, 1), Ok(None));
        assert_eq!(decoder.pos(), 0);

        assert_eq!(decoder.read_numbered_sized_optional_field::<bool>(2, 1), Ok(Some(true)));
        assert_eq!(decoder.pos(), 3);
    }

    #[test]
    fn test_read_timestamp() {
        let str = "2001-12-01:20:12:45000000";
        let mut decoder = Decoder::new(str.as_bytes());
        let actual_timestamp = Local.with_ymd_and_hms(2001, 12, 1, 20, 12, 45).unwrap();

        let timestamp_res: Result<DateTime<Local>> = DateTime::decode(&mut decoder);

        assert_eq!(timestamp_res, Ok(actual_timestamp));
    }

    #[test]
    fn test_read_invalid_timestamp() {
        let str = "2001:12:01:20:12:45000000";
        let mut decoder = Decoder::new(str.as_bytes());

        let timestamp_res: Result<DateTime<Local>> = DateTime::decode(&mut decoder);

        assert_eq!(timestamp_res, Err(Error::ExpectedCharacter { decoded_char: ':', expected_char: '-', pos: 4 }));
    }
}