benu/format/
decoder.rs

1use crate::{Header, Token};
2use crate::format::varint::{READ_MAX, read_varint};
3
4#[cfg_attr(feature = "debug", derive(Debug))]
5pub enum DecoderError {
6    ExpectedPrefix,
7    ExpectedHeader,
8    ExpectedBytes {
9        at: usize,
10        length: usize,
11        header: Header,
12    },
13    InvalidInt {
14        at: usize,
15        header: Header,
16    }
17}
18
// Human-readable messages for `DecoderError`; only compiled with the
// `debug` feature, which is also what derives `Debug` for the enum.
#[cfg(feature = "debug")]
impl std::fmt::Display for DecoderError {
    /// Writes a one-line description of the error.
    ///
    /// Variants that carry a position append the offending header in
    /// parentheses, using its `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DecoderError::ExpectedPrefix => f.write_str("expected prefix (0xBE)"),
            DecoderError::ExpectedHeader => f.write_str("expected header"),
            // The header suffix is now closed with `)`; the previous format
            // strings opened the parenthesis but never closed it.
            DecoderError::ExpectedBytes { at, length, header } => {
                write!(f, "expected {} bytes at {} ({:?})", length, at, header)
            }
            DecoderError::InvalidInt { at, header } => {
                write!(f, "invalid int at {} ({:?})", at, header)
            }
        }
    }
}

// With `Debug` and `Display` both available under the `debug` feature, the
// type can also implement the standard `Error` trait; no methods are overridden.
#[cfg(feature = "debug")]
impl std::error::Error for DecoderError {}
34
35
36pub(crate) fn decode<'a>(input: &'a [u8]) -> Result<(Token<'a>, &'a [u8], &'a [u8]), DecoderError> {
37    if input.len() < 1 || input[0] != 0xBE {
38        return Err(DecoderError::ExpectedPrefix)
39    }
40
41    if input.len() < 2 {
42        return Err(DecoderError::ExpectedHeader)
43    }
44
45    let header = Header(input[1]);
46    let mut cursor = 2;
47    let mut token = Token::empty();
48
49
50    if header.compare(Header::TYP) {
51        if input.len() < 3 {
52            return Err(DecoderError::ExpectedBytes {
53                at: cursor,
54                length: 1,
55                header: Header::TYP,
56            });
57        }
58        token.typ = input[2];
59        cursor = 3;
60    };
61
62    if header.compare(Header::SUBJECT) {
63        token.subject = get_data(input, &mut cursor, Header::SUBJECT)?;
64    };
65
66    if header.compare(Header::INCREMENT) {
67        token.increment = get_varint(input, &mut cursor, Header::INCREMENT)?;
68    };
69
70    if header.compare(Header::BEFORE) {
71        token.before = get_varint(input, &mut cursor, Header::BEFORE)?;
72    };
73
74    if header.compare(Header::AFTER) {
75        token.after = get_varint(input, &mut cursor, Header::AFTER)?;
76    };
77
78    if header.compare(Header::DATA) {
79        token.data = get_data(input, &mut cursor, Header::DATA)?;
80    };
81
82    if header.compare(Header::SALT) {
83        token.salt = get_data(input, &mut cursor, Header::SALT)?;
84    };
85
86    token.header = header;
87
88    Ok((token, &input[0..cursor], &input[cursor..]))
89}
90
91fn get_varint(input: &[u8], cursor: &mut usize, header: Header) -> Result<u64, DecoderError> {
92    if input.len() < *cursor + READ_MAX {
93        return Err(DecoderError::ExpectedBytes { at: *cursor, length: READ_MAX, header })
94    }
95    let (count, value) = read_varint(&input[*cursor..] );
96
97    if count == 0 {
98        return Err(DecoderError::InvalidInt {
99            at: *cursor,
100            header,
101        })
102    }
103
104    *cursor += count;
105    Ok(value)
106}
107
108fn get_range<'a>(input: &'a [u8], cursor: &mut usize, length: usize, header: Header) -> Result<&'a [u8], DecoderError> {
109    return if input.len() < *cursor + length {
110        Err(DecoderError::ExpectedBytes { at: *cursor, length, header, })
111    } else {
112        let slice = &input[*cursor..*cursor + length];
113        *cursor += length;
114        Ok(slice)
115    }
116}
117
118fn get_data<'a>(input: &'a [u8], cursor: &mut usize, header: Header) -> Result<&'a [u8], DecoderError> {
119    let size = get_varint(input, cursor, header)? as usize;
120    get_range(input, cursor, size, header)
121}