simple_tlv/
encoder.rs

1use core::convert::{TryFrom, TryInto};
2use crate::{Encodable, ErrorKind, header::Header, Length, Result, Tag};
3
/// SIMPLE-TLV encoder.
///
/// Writes SIMPLE-TLV-encoded data into a caller-supplied byte slice, tracking
/// how many bytes have been written so far. Errors reported through
/// `Encoder::error` take the buffer out of `bytes`, after which
/// `Encoder::is_failed` returns `true` and `Encoder::finish` returns an error.
#[derive(Debug)]
pub struct Encoder<'a> {
    /// Buffer into which SIMPLE-TLV-encoded message is written.
    ///
    /// `None` means the encoder has failed ("tainted") and must not be used
    /// to produce output any more.
    bytes: Option<&'a mut [u8]>,

    /// Total number of bytes written to buffer so far
    position: Length,
}
13
14impl<'a> Encoder<'a> {
15    /// Create a new encoder with the given byte slice as a backing buffer.
16    pub fn new(bytes: &'a mut [u8]) -> Self {
17        Self {
18            bytes: Some(bytes),
19            position: Length::zero(),
20        }
21    }
22
23    /// Encode a value which impls the [`Encodable`] trait.
24    pub fn encode<T: Encodable>(&mut self, encodable: &T) -> Result<()> {
25        if self.is_failed() {
26            self.error(ErrorKind::Failed)?;
27        }
28
29        encodable.encode(self).map_err(|e| {
30            self.bytes.take();
31            e.nested(self.position)
32        })
33    }
34
35    /// Return an error with the given [`ErrorKind`], annotating it with
36    /// context about where the error occurred.
37    pub fn error<T>(&mut self, kind: ErrorKind) -> Result<T> {
38        self.bytes.take();
39        Err(kind.at(self.position))
40    }
41
42    /// Did the decoding operation fail due to an error?
43    pub fn is_failed(&self) -> bool {
44        self.bytes.is_none()
45    }
46
47    /// Finish encoding to the buffer, returning a slice containing the data
48    /// written to the buffer.
49    pub fn finish(self) -> Result<&'a [u8]> {
50        let position = self.position;
51
52        match self.bytes {
53            Some(bytes) => bytes
54                .get(..self.position.into())
55                .ok_or_else(|| ErrorKind::Truncated.at(position)),
56            None => Err(ErrorKind::Failed.at(position)),
57        }
58    }
59
60    /// Encode a collection of values which impl the [`Encodable`] trait under a given tag.
61    pub fn encode_tagged_collection(&mut self, tag: Tag, encodables: &[&dyn Encodable]) -> Result<()> {
62        let expected_len = Length::try_from(encodables)?;
63        Header::new(tag, expected_len).and_then(|header| header.encode(self))?;
64
65        let mut nested_encoder = Encoder::new(self.reserve(expected_len)?);
66
67        for encodable in encodables {
68            encodable.encode(&mut nested_encoder)?;
69        }
70
71        if nested_encoder.finish()?.len() == expected_len.into() {
72            Ok(())
73        } else {
74            self.error(ErrorKind::Length { tag })
75        }
76    }
77
78    /// Encode a collection of values which impl the [`Encodable`] trait under a given tag.
79    pub fn encode_untagged_collection(&mut self, encodables: &[&dyn Encodable]) -> Result<()> {
80        let expected_len = Length::try_from(encodables)?;
81        let mut nested_encoder = Encoder::new(self.reserve(expected_len)?);
82
83        for encodable in encodables {
84            encodable.encode(&mut nested_encoder)?;
85        }
86        Ok(())
87    }
88
89    /// Encode a single byte into the backing buffer.
90    pub(crate) fn byte(&mut self, byte: u8) -> Result<()> {
91        match self.reserve(1u8)?.first_mut() {
92            Some(b) => {
93                *b = byte;
94                Ok(())
95            }
96            None => self.error(ErrorKind::Truncated),
97        }
98    }
99
100    /// Encode the provided byte slice into the backing buffer.
101    pub(crate) fn bytes(&mut self, slice: &[u8]) -> Result<()> {
102        self.reserve(slice.len())?.copy_from_slice(slice);
103        Ok(())
104    }
105
106    /// Reserve a portion of the internal buffer, updating the internal cursor
107    /// position and returning a mutable slice.
108    fn reserve(&mut self, len: impl TryInto<Length>) -> Result<&mut [u8]> {
109        let len = len
110            .try_into()
111            .or_else(|_| self.error(ErrorKind::Overflow))?;
112
113        if len > self.remaining_len()? {
114            self.error(ErrorKind::Overlength)?;
115        }
116
117        let end = (self.position + len).or_else(|e| self.error(e.kind()))?;
118        let range = self.position.into()..end.into();
119        let position = &mut self.position;
120
121        // TODO(tarcieri): non-panicking version of this code
122        // We ensure above that the buffer is untainted and there is sufficient
123        // space to perform this slicing operation, however it would be nice to
124        // have fully panic-free code.
125        //
126        // Unfortunately tainting the buffer on error is tricky to do when
127        // potentially holding a reference to the buffer, and failure to taint
128        // it would not uphold the invariant that any errors should taint it.
129        let slice = &mut self.bytes.as_mut().expect("DER encoder tainted")[range];
130        *position = end;
131
132        Ok(slice)
133    }
134
135    /// Get the size of the buffer in bytes.
136    fn buffer_len(&self) -> Result<Length> {
137        self.bytes
138            .as_ref()
139            .map(|bytes| bytes.len())
140            .ok_or_else(|| ErrorKind::Failed.at(self.position))
141            .and_then(TryInto::try_into)
142    }
143
144    /// Get the number of bytes still remaining in the buffer.
145    fn remaining_len(&self) -> Result<Length> {
146        self.buffer_len()?
147            .to_usize()
148            .checked_sub(self.position.into())
149            .ok_or_else(|| ErrorKind::Truncated.at(self.position))
150            .and_then(TryInto::try_into)
151    }
152}
153
#[cfg(test)]
mod tests {
    use super::Encoder;
    use crate::{Encodable, ErrorKind, Tag, TaggedSlice};
    use core::convert::TryFrom;

    #[test]
    fn zero_length() {
        let tv = TaggedSlice::from(Tag::try_from(42).unwrap(), &[]).unwrap();
        let mut buf = [0u8; 4];
        assert_eq!(tv.encode_to_slice(&mut buf).unwrap(), &[0x2A, 0x00]);
    }

    /// Writing into a buffer with no room left must fail with `Overlength`
    /// and leave the encoder tainted, so that `finish` refuses to return
    /// partial output.
    #[test]
    fn overlength_message() {
        let mut buffer: [u8; 0] = [];
        let mut encoder = Encoder::new(&mut buffer);

        let err = encoder.byte(0x2A).err().unwrap();
        assert!(matches!(err.kind(), ErrorKind::Overlength));

        assert!(encoder.is_failed());
        assert!(encoder.finish().is_err());
    }
}