flexiber/
encoder.rs

1use crate::{header::Header, Encodable, ErrorKind, Length, Result, Tag};
2use core::convert::{TryFrom, TryInto};
3
4/// BER-TLV encoder.
5#[derive(Debug)]
6pub struct Encoder<'a> {
7    /// Buffer into which BER-TLV-encoded message is written
8    bytes: Option<&'a mut [u8]>,
9
10    /// Total number of bytes written to buffer so far
11    position: Length,
12}
13
14impl<'a> Encoder<'a> {
15    /// Create a new encoder with the given byte slice as a backing buffer.
16    pub fn new(bytes: &'a mut [u8]) -> Self {
17        Self {
18            bytes: Some(bytes),
19            position: Length::zero(),
20        }
21    }
22
23    /// Encode a value which impls the [`Encodable`] trait.
24    pub fn encode<T: Encodable>(&mut self, encodable: &T) -> Result<()> {
25        if self.is_failed() {
26            self.error(ErrorKind::Failed)?;
27        }
28
29        encodable.encode(self).map_err(|e| {
30            self.bytes.take();
31            e.nested(self.position)
32        })
33    }
34
35    /// Return an error with the given [`ErrorKind`], annotating it with
36    /// context about where the error occurred.
37    pub fn error<T>(&mut self, kind: ErrorKind) -> Result<T> {
38        self.bytes.take();
39        Err(kind.at(self.position))
40    }
41
42    /// Did the decoding operation fail due to an error?
43    pub fn is_failed(&self) -> bool {
44        self.bytes.is_none()
45    }
46
47    /// Finish encoding to the buffer, returning a slice containing the data
48    /// written to the buffer.
49    pub fn finish(self) -> Result<&'a [u8]> {
50        let position = self.position;
51
52        match self.bytes {
53            Some(bytes) => bytes
54                .get(..self.position.into())
55                .ok_or_else(|| ErrorKind::Truncated.at(position)),
56            None => Err(ErrorKind::Failed.at(position)),
57        }
58    }
59
60    /// Encode a collection of values which impl the [`Encodable`] trait under a given tag.
61    pub fn encode_tagged_collection(
62        &mut self,
63        tag: Tag,
64        encodables: &[&dyn Encodable],
65    ) -> Result<()> {
66        let expected_len = Length::try_from(encodables)?;
67        Header::new(tag, expected_len).and_then(|header| header.encode(self))?;
68
69        let mut nested_encoder = Encoder::new(self.reserve(expected_len)?);
70
71        for encodable in encodables {
72            encodable.encode(&mut nested_encoder)?;
73        }
74
75        if nested_encoder.finish()?.len() == expected_len.into() {
76            Ok(())
77        } else {
78            self.error(ErrorKind::Length { tag })
79        }
80    }
81
82    /// Encode a collection of values which impl the [`Encodable`] trait under a given tag.
83    pub fn encode_untagged_collection(&mut self, encodables: &[&dyn Encodable]) -> Result<()> {
84        let expected_len = Length::try_from(encodables)?;
85        let mut nested_encoder = Encoder::new(self.reserve(expected_len)?);
86
87        for encodable in encodables {
88            encodable.encode(&mut nested_encoder)?;
89        }
90        Ok(())
91    }
92
93    /// Encode a single byte into the backing buffer.
94    pub(crate) fn byte(&mut self, byte: u8) -> Result<()> {
95        match self.reserve(1u8)?.first_mut() {
96            Some(b) => {
97                *b = byte;
98                Ok(())
99            }
100            None => self.error(ErrorKind::Truncated),
101        }
102    }
103
104    /// Encode the provided byte slice into the backing buffer.
105    pub(crate) fn bytes(&mut self, slice: &[u8]) -> Result<()> {
106        self.reserve(slice.len())?.copy_from_slice(slice);
107        Ok(())
108    }
109
110    /// Reserve a portion of the internal buffer, updating the internal cursor
111    /// position and returning a mutable slice.
112    fn reserve(&mut self, len: impl TryInto<Length>) -> Result<&mut [u8]> {
113        let len = len
114            .try_into()
115            .or_else(|_| self.error(ErrorKind::Overflow))?;
116
117        if len > self.remaining_len()? {
118            self.error(ErrorKind::Overlength)?;
119        }
120
121        let end = (self.position + len).or_else(|e| self.error(e.kind()))?;
122        let range = self.position.into()..end.into();
123        let position = &mut self.position;
124
125        // TODO(tarcieri): non-panicking version of this code
126        // We ensure above that the buffer is untainted and there is sufficient
127        // space to perform this slicing operation, however it would be nice to
128        // have fully panic-free code.
129        //
130        // Unfortunately tainting the buffer on error is tricky to do when
131        // potentially holding a reference to the buffer, and failure to taint
132        // it would not uphold the invariant that any errors should taint it.
133        let slice = &mut self.bytes.as_mut().expect("DER encoder tainted")[range];
134        *position = end;
135
136        Ok(slice)
137    }
138
139    /// Get the size of the buffer in bytes.
140    fn buffer_len(&self) -> Result<Length> {
141        self.bytes
142            .as_ref()
143            .map(|bytes| bytes.len())
144            .ok_or_else(|| ErrorKind::Failed.at(self.position))
145            .and_then(TryInto::try_into)
146    }
147
148    /// Get the number of bytes still remaining in the buffer.
149    fn remaining_len(&self) -> Result<Length> {
150        self.buffer_len()?
151            .to_usize()
152            .checked_sub(self.position.into())
153            .ok_or_else(|| ErrorKind::Truncated.at(self.position))
154            .and_then(TryInto::try_into)
155    }
156}
157
#[cfg(test)]
mod tests {
    use crate::{Encodable, Tag, TaggedSlice};

    /// An empty value encodes as just its tag byte followed by a zero length.
    #[test]
    fn zero_length() {
        // Universal class, primitive, tag number 5.
        let primitive = TaggedSlice::from(Tag::universal(5), &[]).unwrap();
        let mut buffer = [0u8; 4];
        let encoded = primitive.encode_to_slice(&mut buffer).unwrap();
        assert_eq!(encoded, &[0x5, 0x00]);

        // Application class (0b01 in the top two bits), constructed bit set,
        // tag number 5.
        let constructed = TaggedSlice::from(Tag::application(5).constructed(), &[]).unwrap();
        let expected_tag: u8 = (0b01 << 6) | (1 << 5) | 5;
        let mut buffer = [0u8; 4];
        let encoded = constructed.encode_to_slice(&mut buffer).unwrap();
        assert_eq!(encoded, &[expected_tag, 0x00]);
    }

    // TODO: restore an `overlength_message` test. Encoding any value into an
    // empty buffer via `super::Encoder` should fail with
    // `ErrorKind::Overlength` at position `Length::zero()`.
}