1use crate::{header::Header, Encodable, ErrorKind, Length, Result, Tag};
2use core::convert::{TryFrom, TryInto};
3
/// Encoder which writes values implementing [`Encodable`] into a
/// caller-provided byte buffer.
#[derive(Debug)]
pub struct Encoder<'a> {
    /// Output buffer. Set to `None` once the encoder has been tainted by an
    /// error (see `Encoder::error`), after which every further operation fails.
    bytes: Option<&'a mut [u8]>,

    /// Offset of the next byte to be written into `bytes`, i.e. the number of
    /// bytes written so far.
    position: Length,
}
13
14impl<'a> Encoder<'a> {
15 pub fn new(bytes: &'a mut [u8]) -> Self {
17 Self {
18 bytes: Some(bytes),
19 position: Length::zero(),
20 }
21 }
22
23 pub fn encode<T: Encodable>(&mut self, encodable: &T) -> Result<()> {
25 if self.is_failed() {
26 self.error(ErrorKind::Failed)?;
27 }
28
29 encodable.encode(self).map_err(|e| {
30 self.bytes.take();
31 e.nested(self.position)
32 })
33 }
34
35 pub fn error<T>(&mut self, kind: ErrorKind) -> Result<T> {
38 self.bytes.take();
39 Err(kind.at(self.position))
40 }
41
42 pub fn is_failed(&self) -> bool {
44 self.bytes.is_none()
45 }
46
47 pub fn finish(self) -> Result<&'a [u8]> {
50 let position = self.position;
51
52 match self.bytes {
53 Some(bytes) => bytes
54 .get(..self.position.into())
55 .ok_or_else(|| ErrorKind::Truncated.at(position)),
56 None => Err(ErrorKind::Failed.at(position)),
57 }
58 }
59
60 pub fn encode_tagged_collection(
62 &mut self,
63 tag: Tag,
64 encodables: &[&dyn Encodable],
65 ) -> Result<()> {
66 let expected_len = Length::try_from(encodables)?;
67 Header::new(tag, expected_len).and_then(|header| header.encode(self))?;
68
69 let mut nested_encoder = Encoder::new(self.reserve(expected_len)?);
70
71 for encodable in encodables {
72 encodable.encode(&mut nested_encoder)?;
73 }
74
75 if nested_encoder.finish()?.len() == expected_len.into() {
76 Ok(())
77 } else {
78 self.error(ErrorKind::Length { tag })
79 }
80 }
81
82 pub fn encode_untagged_collection(&mut self, encodables: &[&dyn Encodable]) -> Result<()> {
84 let expected_len = Length::try_from(encodables)?;
85 let mut nested_encoder = Encoder::new(self.reserve(expected_len)?);
86
87 for encodable in encodables {
88 encodable.encode(&mut nested_encoder)?;
89 }
90 Ok(())
91 }
92
93 pub(crate) fn byte(&mut self, byte: u8) -> Result<()> {
95 match self.reserve(1u8)?.first_mut() {
96 Some(b) => {
97 *b = byte;
98 Ok(())
99 }
100 None => self.error(ErrorKind::Truncated),
101 }
102 }
103
104 pub(crate) fn bytes(&mut self, slice: &[u8]) -> Result<()> {
106 self.reserve(slice.len())?.copy_from_slice(slice);
107 Ok(())
108 }
109
110 fn reserve(&mut self, len: impl TryInto<Length>) -> Result<&mut [u8]> {
113 let len = len
114 .try_into()
115 .or_else(|_| self.error(ErrorKind::Overflow))?;
116
117 if len > self.remaining_len()? {
118 self.error(ErrorKind::Overlength)?;
119 }
120
121 let end = (self.position + len).or_else(|e| self.error(e.kind()))?;
122 let range = self.position.into()..end.into();
123 let position = &mut self.position;
124
125 let slice = &mut self.bytes.as_mut().expect("DER encoder tainted")[range];
134 *position = end;
135
136 Ok(slice)
137 }
138
139 fn buffer_len(&self) -> Result<Length> {
141 self.bytes
142 .as_ref()
143 .map(|bytes| bytes.len())
144 .ok_or_else(|| ErrorKind::Failed.at(self.position))
145 .and_then(TryInto::try_into)
146 }
147
148 fn remaining_len(&self) -> Result<Length> {
150 self.buffer_len()?
151 .to_usize()
152 .checked_sub(self.position.into())
153 .ok_or_else(|| ErrorKind::Truncated.at(self.position))
154 .and_then(TryInto::try_into)
155 }
156}
157
#[cfg(test)]
mod tests {
    use crate::{Encodable, Tag, TaggedSlice};

    /// An empty value must encode as its tag byte followed by a zero length
    /// byte, for both a primitive universal tag and a constructed
    /// application tag.
    #[test]
    fn zero_length() {
        let primitive = TaggedSlice::from(Tag::universal(5), &[]).unwrap();
        let mut scratch = [0u8; 4];
        let encoded = primitive.encode_to_slice(&mut scratch).unwrap();
        assert_eq!(encoded, &[0x5, 0x00]);

        let constructed = TaggedSlice::from(Tag::application(5).constructed(), &[]).unwrap();
        let mut scratch = [0u8; 4];
        let encoded = constructed.encode_to_slice(&mut scratch).unwrap();
        // Class bits `01` in the top two bits, the constructed bit, then the
        // tag number 5 in the low bits.
        let expected_tag: u8 = (0b01 << 6) | (1 << 5) | 5;
        assert_eq!(encoded, &[expected_tag, 0x00]);
    }
}