flexiber/
decoder.rs

1use crate::{Decodable, ErrorKind, Length, Result, TagLike};
2use core::convert::TryInto;
3
4/// BER-TLV decoder.
5#[derive(Debug)]
6pub struct Decoder<'a> {
7    /// Byte slice being decoded.
8    ///
9    /// In the event an error was previously encountered this will be set to
10    /// `None` to prevent further decoding while in a bad state.
11    bytes: Option<&'a [u8]>,
12
13    /// Position within the decoded slice.
14    position: Length,
15}
16
17impl<'a> Decoder<'a> {
18    /// Create a new decoder for the given byte slice.
19    pub fn new(bytes: &'a [u8]) -> Self {
20        Self {
21            bytes: Some(bytes),
22            position: Length::zero(),
23        }
24    }
25
26    /// Decode a value which impls the [`Decodable`] trait.
27    pub fn decode<T: Decodable<'a>>(&mut self) -> Result<T> {
28        if self.is_failed() {
29            self.error(ErrorKind::Failed)?;
30        }
31
32        T::decode(self).map_err(|e| {
33            self.bytes.take();
34            e.nested(self.position)
35        })
36    }
37
38    /// Decode a TaggedValue with tag checked to be as expected, returning the value
39    pub fn decode_tagged_value<T: Decodable<'a> + TagLike, V: Decodable<'a>>(
40        &mut self,
41        tag: T,
42    ) -> Result<V> {
43        let tagged: crate::TaggedSlice<T> = self.decode()?;
44        tagged.tag().assert_eq(tag)?;
45        Self::new(tagged.as_bytes()).decode()
46    }
47
48    /// Decode a TaggedSlice with tag checked to be as expected, returning the value
49    pub fn decode_tagged_slice<T: Decodable<'a> + TagLike>(&mut self, tag: T) -> Result<&'a [u8]> {
50        let tagged: crate::TaggedSlice<T> = self.decode()?;
51        tagged.tag().assert_eq(tag)?;
52        Ok(tagged.as_bytes())
53    }
54
55    /// Return an error with the given [`ErrorKind`], annotating it with
56    /// context about where the error occurred.
57    pub fn error<T>(&mut self, kind: ErrorKind) -> Result<T> {
58        self.bytes.take();
59        Err(kind.at(self.position))
60    }
61
62    /// Did the decoding operation fail due to an error?
63    pub fn is_failed(&self) -> bool {
64        self.bytes.is_none()
65    }
66
67    /// Finish decoding, returning the given value if there is no
68    /// remaining data, or an error otherwise
69    pub fn finish<T>(self, value: T) -> Result<T> {
70        if self.is_failed() {
71            Err(ErrorKind::Failed.at(self.position))
72        } else if !self.is_finished() {
73            Err(ErrorKind::TrailingData {
74                decoded: self.position,
75                remaining: self.remaining_len()?,
76            }
77            .at(self.position))
78        } else {
79            Ok(value)
80        }
81    }
82
83    /// Have we decoded all of the bytes in this [`Decoder`]?
84    ///
85    /// Returns `false` if we're not finished decoding or if a fatal error
86    /// has occurred.
87    pub fn is_finished(&self) -> bool {
88        self.remaining().map(|rem| rem.is_empty()).unwrap_or(false)
89    }
90
91    /// Decode a single byte, updating the internal cursor.
92    pub(crate) fn byte(&mut self) -> Result<u8> {
93        match self.bytes(1u8)? {
94            [byte] => Ok(*byte),
95            _ => self.error(ErrorKind::Truncated),
96        }
97    }
98
99    /// Obtain a slice of bytes of the given length from the current cursor
100    /// position, or return an error if we have insufficient data.
101    pub(crate) fn bytes(&mut self, len: impl TryInto<Length>) -> Result<&'a [u8]> {
102        if self.is_failed() {
103            self.error(ErrorKind::Failed)?;
104        }
105
106        let len = len
107            .try_into()
108            .or_else(|_| self.error(ErrorKind::Overflow))?;
109
110        let result = self
111            .remaining()?
112            .get(..len.to_usize())
113            .ok_or(ErrorKind::Truncated)?;
114
115        self.position = (self.position + len)?;
116        Ok(result)
117    }
118
119    /// Peek at the next byte in the decoder without modifying the cursor.
120    pub(crate) fn peek(&self) -> Option<u8> {
121        self.remaining()
122            .ok()
123            .and_then(|bytes| bytes.first().cloned())
124    }
125
126    /// Obtain the remaining bytes in this decoder from the current cursor
127    /// position.
128    fn remaining(&self) -> Result<&'a [u8]> {
129        self.bytes
130            .and_then(|b| b.get(self.position.into()..))
131            .ok_or_else(|| ErrorKind::Truncated.at(self.position))
132    }
133
134    /// Get the number of bytes still remaining in the buffer.
135    fn remaining_len(&self) -> Result<Length> {
136        self.remaining()?.len().try_into()
137    }
138}
139
140impl<'a> From<&'a [u8]> for Decoder<'a> {
141    fn from(bytes: &'a [u8]) -> Decoder<'a> {
142        Decoder::new(bytes)
143    }
144}
145
#[cfg(test)]
mod tests {
    use crate::{Decodable, Tag, TaggedSlice};

    /// A TLV whose length byte is zero decodes to a tagged slice with an
    /// empty value.
    #[test]
    fn zero_length() {
        let expected = TaggedSlice::from(Tag::universal(0x5), &[]).unwrap();

        let encoded: &[u8] = &[0x05, 0x00];
        let decoded = TaggedSlice::from_bytes(encoded).unwrap();

        assert_eq!(decoded, expected);
    }
}
157// #[cfg(test)]
158// mod tests {
159//     use super::Decoder;
160//     use crate::{Decodable, ErrorKind, Length, Tag};
161
162//     #[test]
163//     fn truncated_message() {
164//         let mut decoder = Decoder::new(&[]);
165//         let err = bool::decode(&mut decoder).err().unwrap();
166//         assert_eq!(ErrorKind::Truncated, err.kind());
167//         assert_eq!(Some(Length::zero()), err.position());
168//     }
169
170//     #[test]
171//     fn invalid_field_length() {
172//         let mut decoder = Decoder::new(&[0x02, 0x01]);
173//         let err = i8::decode(&mut decoder).err().unwrap();
174//         assert_eq!(ErrorKind::Length { tag: Tag::Integer }, err.kind());
175//         assert_eq!(Some(Length::from(2u8)), err.position());
176//     }
177
178//     #[test]
179//     fn trailing_data() {
180//         let mut decoder = Decoder::new(&[0x02, 0x01, 0x2A, 0x00]);
181//         let x = decoder.decode().unwrap();
182//         assert_eq!(42i8, x);
183
184//         let err = decoder.finish(x).err().unwrap();
185//         assert_eq!(
186//             ErrorKind::TrailingData {
187//                 decoded: 3u8.into(),
188//                 remaining: 1u8.into()
189//             },
190//             err.kind()
191//         );
192//         assert_eq!(Some(Length::from(3u8)), err.position());
193//     }
194// }