use crate::{asn1::*, message, Encodable, Error, ErrorKind, Header, Length, Result, Tag};
use core::convert::{TryFrom, TryInto};
/// DER encoder which writes its output into a caller-supplied byte slice.
///
/// Once any operation reports an error through the encoder, the buffer
/// reference is dropped and the encoder is considered failed (see
/// `is_failed`).
#[derive(Debug)]
pub struct Encoder<'a> {
/// Output buffer; `None` once the encoder has been marked as failed.
bytes: Option<&'a mut [u8]>,
/// Offset of the next byte to be written, i.e. the amount of the buffer
/// that has been reserved so far.
position: Length,
}
impl<'a> Encoder<'a> {
/// Creates an encoder that writes into `bytes`, starting at offset zero.
pub fn new(bytes: &'a mut [u8]) -> Self {
    let position = Length::ZERO;
    let bytes = Some(bytes);
    Self { bytes, position }
}
pub fn encode<T: Encodable>(&mut self, encodable: &T) -> Result<()> {
if self.is_failed() {
self.error(ErrorKind::Failed)?;
}
encodable.encode(self).map_err(|e| {
self.bytes.take();
e.nested(self.position)
})
}
/// Marks the encoder as failed and returns an error of the given `kind`
/// located at the current position.
///
/// This always returns `Err`; the `T` parameter only lets it be used in any
/// `Result`-returning context.
pub fn error<T>(&mut self, kind: ErrorKind) -> Result<T> {
    let failed_at = self.position;
    self.bytes = None;
    Err(kind.at(failed_at))
}
/// Marks the encoder as failed and builds the value error associated with
/// `tag`, located at the current position.
pub fn value_error(&mut self, tag: Tag) -> Error {
    self.bytes = None;
    let kind = tag.value_error().kind();
    kind.at(self.position)
}
/// Returns `true` once the encoder has been marked as failed, i.e. its
/// buffer reference has been dropped.
pub fn is_failed(&self) -> bool {
    matches!(self.bytes, None)
}
/// Consumes the encoder and returns the portion of the buffer written so far.
///
/// # Errors
///
/// - `ErrorKind::Failed` if the encoder was marked as failed.
/// - `ErrorKind::Truncated` if the current position lies beyond the buffer.
/// - Any error from converting the position to `usize`.
pub fn finish(self) -> Result<&'a [u8]> {
    let written = self.position;
    let end = usize::try_from(written)?;

    let buffer = self
        .bytes
        .ok_or_else(|| ErrorKind::Failed.at(written))?;

    buffer
        .get(..end)
        .ok_or_else(|| ErrorKind::Truncated.at(written))
}
/// Converts `value` into a `BitString` and encodes it.
///
/// If the conversion fails, the encoder is marked as failed and the value
/// error for `Tag::BitString` is returned.
pub fn bit_string(&mut self, value: impl TryInto<BitString<'a>>) -> Result<()> {
    match value.try_into() {
        Ok(bit_string) => self.encode(&bit_string),
        Err(_) => Err(self.value_error(Tag::BitString)),
    }
}
/// Converts `value` into a `GeneralizedTime` and encodes it.
///
/// If the conversion fails, the encoder is marked as failed and the value
/// error for `Tag::GeneralizedTime` is returned.
pub fn generalized_time(&mut self, value: impl TryInto<GeneralizedTime>) -> Result<()> {
    match value.try_into() {
        Ok(time) => self.encode(&time),
        Err(_) => Err(self.value_error(Tag::GeneralizedTime)),
    }
}
/// Converts `value` into an `Ia5String` and encodes it.
///
/// If the conversion fails, the encoder is marked as failed and the value
/// error for `Tag::Ia5String` is returned.
pub fn ia5_string(&mut self, value: impl TryInto<Ia5String<'a>>) -> Result<()> {
    match value.try_into() {
        Ok(string) => self.encode(&string),
        Err(_) => Err(self.value_error(Tag::Ia5String)),
    }
}
/// Encodes `fields`, in order, as the body of a sequence.
///
/// The body length is computed up front with `message::encoded_len_inner`
/// and handed to `sequence`, which writes the header and provides the nested
/// encoder the fields are written into. Encoding stops at the first field
/// that fails.
pub fn message(&mut self, fields: &[&dyn Encodable]) -> Result<()> {
    let body_len = message::encoded_len_inner(fields)?;

    self.sequence(body_len, |nested_encoder| {
        fields
            .iter()
            .try_for_each(|field| field.encode(&mut *nested_encoder))
    })
}
/// Encodes a `Null` value.
pub fn null(&mut self) -> Result<()> {
    let null = Null;
    self.encode(&null)
}
/// Converts `value` into an `OctetString` and encodes it.
///
/// If the conversion fails, the encoder is marked as failed and the value
/// error for `Tag::OctetString` is returned.
pub fn octet_string(&mut self, value: impl TryInto<OctetString<'a>>) -> Result<()> {
    match value.try_into() {
        Ok(octets) => self.encode(&octets),
        Err(_) => Err(self.value_error(Tag::OctetString)),
    }
}
/// Converts `value` into an `ObjectIdentifier` and encodes it.
///
/// If the conversion fails, the encoder is marked as failed and the value
/// error for `Tag::ObjectIdentifier` is returned.
#[cfg(feature = "oid")]
#[cfg_attr(docsrs, doc(cfg(feature = "oid")))]
pub fn oid(&mut self, value: impl TryInto<ObjectIdentifier>) -> Result<()> {
    match value.try_into() {
        Ok(oid) => self.encode(&oid),
        Err(_) => Err(self.value_error(Tag::ObjectIdentifier)),
    }
}
/// Converts `value` into a `PrintableString` and encodes it.
///
/// If the conversion fails, the encoder is marked as failed and the value
/// error for `Tag::PrintableString` is returned.
pub fn printable_string(&mut self, value: impl TryInto<PrintableString<'a>>) -> Result<()> {
    match value.try_into() {
        Ok(string) => self.encode(&string),
        Err(_) => Err(self.value_error(Tag::PrintableString)),
    }
}
pub fn sequence<F>(&mut self, length: Length, f: F) -> Result<()>
where
F: FnOnce(&mut Encoder<'_>) -> Result<()>,
{
Header::new(Tag::Sequence, length).and_then(|header| header.encode(self))?;
let mut nested_encoder = Encoder::new(self.reserve(length)?);
f(&mut nested_encoder)?;
if nested_encoder.finish()?.len() == length.try_into()? {
Ok(())
} else {
self.error(ErrorKind::Length { tag: Tag::Sequence })
}
}
/// Converts `value` into a `UtcTime` and encodes it.
///
/// If the conversion fails, the encoder is marked as failed and the value
/// error for `Tag::UtcTime` is returned.
pub fn utc_time(&mut self, value: impl TryInto<UtcTime>) -> Result<()> {
    match value.try_into() {
        Ok(time) => self.encode(&time),
        Err(_) => Err(self.value_error(Tag::UtcTime)),
    }
}
/// Converts `value` into a `Utf8String` and encodes it.
///
/// If the conversion fails, the encoder is marked as failed and the value
/// error for `Tag::Utf8String` is returned.
pub fn utf8_string(&mut self, value: impl TryInto<Utf8String<'a>>) -> Result<()> {
    match value.try_into() {
        Ok(string) => self.encode(&string),
        Err(_) => Err(self.value_error(Tag::Utf8String)),
    }
}
fn reserve(&mut self, len: impl TryInto<Length>) -> Result<&mut [u8]> {
let len = len
.try_into()
.or_else(|_| self.error(ErrorKind::Overflow))?;
if len > self.remaining_len()? {
self.error(ErrorKind::Overlength)?;
}
let end = (self.position + len).or_else(|e| self.error(e.kind()))?;
let range = self.position.try_into()?..end.try_into()?;
let position = &mut self.position;
let slice = &mut self.bytes.as_mut().expect("DER encoder tainted")[range];
*position = end;
Ok(slice)
}
pub(crate) fn byte(&mut self, byte: u8) -> Result<()> {
match self.reserve(1u8)?.first_mut() {
Some(b) => {
*b = byte;
Ok(())
}
None => self.error(ErrorKind::Truncated),
}
}
/// Copies `slice` into the buffer at the current position, after reserving
/// exactly `slice.len()` bytes for it.
pub(crate) fn bytes(&mut self, slice: &[u8]) -> Result<()> {
    let destination = self.reserve(slice.len())?;
    destination.copy_from_slice(slice);
    Ok(())
}
/// Returns the total size of the underlying buffer as a `Length`.
///
/// Fails with `ErrorKind::Failed` if the encoder has been marked as failed,
/// or with the conversion error if the size does not fit in a `Length`.
fn buffer_len(&self) -> Result<Length> {
    match self.bytes.as_ref() {
        Some(buffer) => buffer.len().try_into(),
        None => Err(ErrorKind::Failed.at(self.position)),
    }
}
/// Returns how many bytes are left between the current position and the end
/// of the buffer.
///
/// Fails with `ErrorKind::Truncated` if the position is already past the end
/// of the buffer, and propagates errors from `buffer_len` and from the
/// integer conversions.
fn remaining_len(&self) -> Result<Length> {
    let capacity = usize::try_from(self.buffer_len()?)?;
    let used = usize::try_from(self.position)?;

    match capacity.checked_sub(used) {
        Some(remaining) => remaining.try_into(),
        None => Err(ErrorKind::Truncated.at(self.position)),
    }
}
}
#[cfg(test)]
mod tests {
    use super::Encoder;
    use crate::{Encodable, ErrorKind, Length};

    /// Encoding a value into a zero-length buffer must be rejected with
    /// `Overlength`, reported at position zero.
    #[test]
    fn overlength_message() {
        let mut empty = [];
        let mut encoder = Encoder::new(&mut empty);

        let outcome = false.encode(&mut encoder);
        let error = outcome.err().unwrap();

        assert_eq!(error.kind(), ErrorKind::Overlength);
        assert_eq!(error.position(), Some(Length::ZERO));
    }
}