use core::convert::{TryFrom, TryInto};
use crate::{Encodable, ErrorKind, header::Header, Length, Result, Tag};
/// Encoder that writes encoded values sequentially into a caller-provided
/// byte buffer.
///
/// Once any operation fails, the buffer is dropped (set to `None`) and every
/// subsequent operation reports a failure instead of writing.
#[derive(Debug)]
pub struct Encoder<'a> {
/// Output buffer. `None` means the encoder has failed (see `is_failed`).
bytes: Option<&'a mut [u8]>,
/// Offset in `bytes` where the next write begins, i.e. the number of bytes
/// reserved so far; `finish` returns `bytes[..position]`.
position: Length,
}
impl<'a> Encoder<'a> {
pub fn new(bytes: &'a mut [u8]) -> Self {
Self {
bytes: Some(bytes),
position: Length::zero(),
}
}
pub fn encode<T: Encodable>(&mut self, encodable: &T) -> Result<()> {
if self.is_failed() {
self.error(ErrorKind::Failed)?;
}
encodable.encode(self).map_err(|e| {
self.bytes.take();
e.nested(self.position)
})
}
pub fn error<T>(&mut self, kind: ErrorKind) -> Result<T> {
self.bytes.take();
Err(kind.at(self.position))
}
pub fn is_failed(&self) -> bool {
self.bytes.is_none()
}
pub fn finish(self) -> Result<&'a [u8]> {
let position = self.position;
match self.bytes {
Some(bytes) => bytes
.get(..self.position.into())
.ok_or_else(|| ErrorKind::Truncated.at(position)),
None => Err(ErrorKind::Failed.at(position)),
}
}
pub fn encode_tagged_collection(&mut self, tag: Tag, encodables: &[&dyn Encodable]) -> Result<()> {
let expected_len = Length::try_from(encodables)?;
Header::new(tag, expected_len).and_then(|header| header.encode(self))?;
let mut nested_encoder = Encoder::new(self.reserve(expected_len)?);
for encodable in encodables {
encodable.encode(&mut nested_encoder)?;
}
if nested_encoder.finish()?.len() == expected_len.into() {
Ok(())
} else {
self.error(ErrorKind::Length { tag })
}
}
pub fn encode_untagged_collection(&mut self, encodables: &[&dyn Encodable]) -> Result<()> {
let expected_len = Length::try_from(encodables)?;
let mut nested_encoder = Encoder::new(self.reserve(expected_len)?);
for encodable in encodables {
encodable.encode(&mut nested_encoder)?;
}
Ok(())
}
pub(crate) fn byte(&mut self, byte: u8) -> Result<()> {
match self.reserve(1u8)?.first_mut() {
Some(b) => {
*b = byte;
Ok(())
}
None => self.error(ErrorKind::Truncated),
}
}
pub(crate) fn bytes(&mut self, slice: &[u8]) -> Result<()> {
self.reserve(slice.len())?.copy_from_slice(slice);
Ok(())
}
fn reserve(&mut self, len: impl TryInto<Length>) -> Result<&mut [u8]> {
let len = len
.try_into()
.or_else(|_| self.error(ErrorKind::Overflow))?;
if len > self.remaining_len()? {
self.error(ErrorKind::Overlength)?;
}
let end = (self.position + len).or_else(|e| self.error(e.kind()))?;
let range = self.position.into()..end.into();
let position = &mut self.position;
let slice = &mut self.bytes.as_mut().expect("DER encoder tainted")[range];
*position = end;
Ok(slice)
}
fn buffer_len(&self) -> Result<Length> {
self.bytes
.as_ref()
.map(|bytes| bytes.len())
.ok_or_else(|| ErrorKind::Failed.at(self.position))
.and_then(TryInto::try_into)
}
fn remaining_len(&self) -> Result<Length> {
self.buffer_len()?
.to_usize()
.checked_sub(self.position.into())
.ok_or_else(|| ErrorKind::Truncated.at(self.position))
.and_then(TryInto::try_into)
}
}
#[cfg(test)]
mod tests {
    use crate::{Encodable, Tag, TaggedSlice};

    /// Encodes an empty value under `tag` into a 4-byte scratch buffer and
    /// checks the produced bytes against `expected`.
    fn check_empty_value(tag: Tag, expected: &[u8]) {
        let mut scratch = [0u8; 4];
        let value = TaggedSlice::from(tag, &[]).unwrap();
        let encoded = value.encode_to_slice(&mut scratch).unwrap();
        assert_eq!(encoded, expected);
    }

    #[test]
    fn zero_length() {
        // Universal, primitive tag number 5 followed by a zero length byte.
        check_empty_value(Tag::universal(5), &[0x5, 0x00]);

        // Application class (0b01), constructed bit set, tag number 5.
        check_empty_value(
            Tag::application(5).constructed(),
            &[(0b01 << 6) | (1 << 5) | 5, 0x00],
        );
    }
}