use crate::{
asn1::sequence, BitString, Encodable, ErrorKind, Header, Length, Null, OctetString, Result, Tag,
};
use core::convert::TryInto;
#[cfg(feature = "oid")]
use crate::ObjectIdentifier;
/// DER encoder which writes into a caller-provided byte buffer.
///
/// The encoder tracks how many bytes have been written; on any error it drops
/// its buffer, after which it is considered failed (see `Encoder::is_failed`).
#[derive(Debug)]
pub struct Encoder<'a> {
    /// Output buffer; set to `None` once an error has occurred.
    bytes: Option<&'a mut [u8]>,
    /// Number of bytes written so far, i.e. the offset of the next write.
    position: Length,
}
impl<'a> Encoder<'a> {
pub fn new(bytes: &'a mut [u8]) -> Self {
Self {
bytes: Some(bytes),
position: Length::zero(),
}
}
pub fn encode<T: Encodable>(&mut self, encodable: &T) -> Result<()> {
if self.is_failed() {
self.error(ErrorKind::Failed)?;
}
encodable.encode(self).map_err(|e| {
self.bytes.take();
e.nested(self.position)
})
}
pub fn error<T>(&mut self, kind: ErrorKind) -> Result<T> {
self.bytes.take();
Err(kind.at(self.position))
}
pub fn is_failed(&self) -> bool {
self.bytes.is_none()
}
pub fn finish(self) -> Result<&'a [u8]> {
let position = self.position;
match self.bytes {
Some(bytes) => bytes
.get(..self.position.into())
.ok_or_else(|| ErrorKind::Truncated.at(position)),
None => Err(ErrorKind::Failed.at(position)),
}
}
pub fn bit_string(&mut self, value: impl TryInto<BitString<'a>>) -> Result<()> {
value
.try_into()
.or_else(|_| {
self.error(ErrorKind::Value {
tag: Tag::BitString,
})
})
.and_then(|value| self.encode(&value))
}
pub fn null(&mut self) -> Result<()> {
self.encode(&Null)
}
pub fn octet_string(&mut self, value: impl TryInto<OctetString<'a>>) -> Result<()> {
value
.try_into()
.or_else(|_| {
self.error(ErrorKind::Value {
tag: Tag::OctetString,
})
})
.and_then(|value| self.encode(&value))
}
#[cfg(feature = "oid")]
#[cfg_attr(docsrs, doc(cfg(feature = "oid")))]
pub fn oid(&mut self, oid: impl TryInto<ObjectIdentifier>) -> Result<()> {
let oid: ObjectIdentifier = oid.try_into().or_else(|_| {
self.error(ErrorKind::Value {
tag: Tag::ObjectIdentifier,
})
})?;
let expected_len = oid.ber_len();
Header::new(Tag::ObjectIdentifier, expected_len).and_then(|header| header.encode(self))?;
let buffer = self.reserve(expected_len)?;
if oid.write_ber(buffer)?.len() == expected_len {
Ok(())
} else {
self.error(ErrorKind::Length {
tag: Tag::ObjectIdentifier,
})
}
}
pub fn sequence(&mut self, encodables: &[&dyn Encodable]) -> Result<()> {
let expected_len = sequence::encoded_len_inner(encodables)?;
Header::new(Tag::Sequence, expected_len).and_then(|header| header.encode(self))?;
let mut nested_encoder = Encoder::new(self.reserve(expected_len)?);
for encodable in encodables {
encodable.encode(&mut nested_encoder)?;
}
if nested_encoder.finish()?.len() == expected_len.into() {
Ok(())
} else {
self.error(ErrorKind::Length {
tag: Tag::ObjectIdentifier,
})
}
}
pub(crate) fn byte(&mut self, byte: u8) -> Result<()> {
match self.reserve(1u8)?.first_mut() {
Some(b) => {
*b = byte;
Ok(())
}
None => self.error(ErrorKind::Truncated),
}
}
pub(crate) fn bytes(&mut self, slice: &[u8]) -> Result<()> {
self.reserve(slice.len())?.copy_from_slice(slice);
Ok(())
}
fn reserve(&mut self, len: impl TryInto<Length>) -> Result<&mut [u8]> {
let len = len
.try_into()
.or_else(|_| self.error(ErrorKind::Overflow))?;
if len > self.remaining_len()? {
self.error(ErrorKind::Overlength)?;
}
let end = (self.position + len).or_else(|e| self.error(e.kind()))?;
let range = self.position.into()..end.into();
let position = &mut self.position;
let slice = &mut self.bytes.as_mut().expect("DER encoder tainted")[range];
*position = end;
Ok(slice)
}
fn buffer_len(&self) -> Result<Length> {
self.bytes
.as_ref()
.map(|bytes| bytes.len())
.ok_or_else(|| ErrorKind::Failed.at(self.position))
.and_then(TryInto::try_into)
}
fn remaining_len(&self) -> Result<Length> {
self.buffer_len()?
.to_usize()
.checked_sub(self.position.into())
.ok_or_else(|| ErrorKind::Truncated.at(self.position))
.and_then(TryInto::try_into)
}
}
#[cfg(test)]
mod tests {
    use super::Encoder;
    use crate::{Encodable, ErrorKind, Length};

    /// Encoding anything into a zero-capacity buffer must fail with
    /// `Overlength`, reported at the very start of the output.
    #[test]
    fn overlength_message() {
        let mut storage: [u8; 0] = [];
        let mut enc = Encoder::new(&mut storage);

        let failure = false
            .encode(&mut enc)
            .expect_err("encoding into an empty buffer must fail");

        assert_eq!(failure.kind(), ErrorKind::Overlength);
        assert_eq!(failure.position(), Some(Length::zero()));
    }
}