use std::{cmp, hash, io, mem};
use bytes::{BytesMut, Bytes};
use crate::captured::Captured;
use crate::{decode, encode};
use crate::mode::Mode;
use crate::length::Length;
use crate::tag::Tag;
/// An ASN.1 OCTET STRING value.
///
/// The content is kept either as the bytes of a primitive encoding or as
/// the captured content of a constructed encoding. In the constructed case
/// the segments are walked lazily, e.g. by [`OctetString::iter`].
#[derive(Clone, Debug)]
pub struct OctetString(Inner<Bytes, Captured>);
/// Storage shape shared by `OctetString` and `OctetStringIter`.
///
/// `P` holds the content of a primitive encoding. `C` holds the content of
/// a constructed encoding, with the headers of the nested segments still in
/// place.
#[derive(Clone, Debug)]
enum Inner<P, C> {
    /// Content octets of a primitive encoding.
    Primitive(P),
    /// Encoded content (segment headers included) of a constructed encoding.
    Constructed(C),
}
impl OctetString {
pub fn new(bytes: Bytes) -> Self {
OctetString(Inner::Primitive(bytes))
}
pub fn iter(&self) -> OctetStringIter {
match self.0 {
Inner::Primitive(ref inner) => {
OctetStringIter(Inner::Primitive(inner.as_ref()))
}
Inner::Constructed(ref inner) => {
OctetStringIter(Inner::Constructed(inner.as_ref()))
}
}
}
pub fn octets(&self) -> OctetStringOctets {
OctetStringOctets::new(self.iter())
}
pub fn as_slice(&self) -> Option<&[u8]> {
match self.0 {
Inner::Primitive(ref inner) => Some(inner.as_ref()),
Inner::Constructed(_) => None
}
}
pub fn to_bytes(&self) -> Bytes {
if let Inner::Primitive(ref inner) = self.0 {
return inner.clone()
}
let mut res = BytesMut::new();
self.iter().for_each(|x| res.extend_from_slice(x));
res.freeze()
}
pub fn into_bytes(self) -> Bytes {
if let Inner::Primitive(inner) = self.0 {
return inner
}
let mut res = BytesMut::new();
self.iter().for_each(|x| res.extend_from_slice(x));
res.freeze()
}
pub fn len(&self) -> usize {
if let Inner::Primitive(ref inner) = self.0 {
return inner.len()
}
self.iter().fold(0, |len, x| len + x.len())
}
pub fn is_empty(&self) -> bool {
if let Inner::Primitive(ref inner) = self.0 {
return inner.is_empty()
}
!self.iter().any(|s| !s.is_empty())
}
pub fn to_source(&self) -> OctetStringSource {
OctetStringSource::new(self)
}
}
impl OctetString {
    /// Takes a single octet string value from the beginning of `cons`.
    ///
    /// The value must carry the universal OCTET STRING tag.
    pub fn take_from<S: decode::Source>(
        cons: &mut decode::Constructed<S>
    ) -> Result<Self, S::Err> {
        cons.take_value_if(Tag::OCTET_STRING, Self::from_content)
    }

    /// Decodes an octet string from the content of a value.
    ///
    /// Mode-specific rules applied here:
    /// * CER: a primitive value may hold at most 1000 octets. Constructed
    ///   values are checked by `take_constructed_cer`.
    /// * DER: the constructed form is rejected.
    /// * BER: both forms are accepted.
    ///
    /// Violations produce `decode::Error::Malformed`.
    pub fn from_content<S: decode::Source>(
        content: &mut decode::Content<S>
    ) -> Result<Self, S::Err> {
        match *content {
            decode::Content::Primitive(ref mut inner) => {
                if inner.mode() == Mode::Cer && inner.remaining() > 1000 {
                    xerr!(return Err(decode::Error::Malformed.into()))
                }
                Ok(OctetString(Inner::Primitive(inner.take_all()?)))
            }
            decode::Content::Constructed(ref mut inner) => {
                match inner.mode() {
                    Mode::Ber => Self::take_constructed_ber(inner),
                    Mode::Cer => Self::take_constructed_cer(inner),
                    Mode::Der => {
                        xerr!(Err(decode::Error::Malformed.into()))
                    }
                }
            }
        }
    }

    /// Captures the content of a BER constructed octet string.
    ///
    /// `skip_nested` walks the nested OCTET STRING segments recursively.
    /// The raw content walked over becomes the stored constructed value.
    fn take_constructed_ber<S: decode::Source>(
        constructed: &mut decode::Constructed<S>
    ) -> Result<Self, S::Err> {
        constructed.capture(|constructed| skip_nested(constructed))
            .map(|captured| OctetString(Inner::Constructed(captured)))
    }

    /// Captures the content of a CER constructed octet string.
    ///
    /// Every fragment must be a primitive OCTET STRING of at most 1000
    /// octets. Only the last fragment may be shorter than 1000 octets, so
    /// any fragment after a short one is rejected as malformed.
    fn take_constructed_cer<S: decode::Source>(
        constructed: &mut decode::Constructed<S>
    ) -> Result<Self, S::Err> {
        // Set once a fragment shorter than 1000 octets has been seen.
        let mut short = false;
        constructed.capture(|con| {
            while let Some(()) = con.take_opt_primitive_if(Tag::OCTET_STRING,
                |primitive| {
                    // A short fragment must be the final one, so nothing
                    // at all (not even a full fragment) may follow it.
                    if short {
                        xerr!(return Err(decode::Error::Malformed.into()));
                    }
                    if primitive.remaining() > 1000 {
                        xerr!(return Err(decode::Error::Malformed.into()));
                    }
                    if primitive.remaining() < 1000 {
                        short = true
                    }
                    primitive.skip_all()
                })? { }
            Ok(())
        }).map(|captured| OctetString(Inner::Constructed(captured)))
    }

    /// Returns an encoder for the value using the universal OCTET STRING tag.
    pub fn encode(self) -> impl encode::Values {
        self.encode_as(Tag::OCTET_STRING)
    }

    /// Returns an encoder for the value using the given `tag`.
    pub fn encode_as(self, tag: Tag) -> impl encode::Values {
        OctetStringEncoder::new(self, tag)
    }

    /// Returns an encoder that borrows the value and uses the universal
    /// OCTET STRING tag.
    pub fn encode_ref<'a>(&'a self) -> impl encode::Values + 'a {
        self.encode_ref_as(Tag::OCTET_STRING)
    }

    /// Returns an encoder that borrows the value and uses the given `tag`.
    pub fn encode_ref_as<'a>(&'a self, tag: Tag) -> impl encode::Values + 'a {
        OctetStringEncoder::new(self, tag)
    }

    /// Returns an encoder that places the encoding of `values`, produced in
    /// `mode`, inside a primitive OCTET STRING.
    pub fn encode_wrapped<V: encode::Values>(
        mode: Mode,
        values: V
    ) -> impl encode::Values {
        WrappingOctetStringEncoder::new(mode, values)
    }

    /// Returns an encoder for a primitive OCTET STRING whose content is
    /// `value`.
    pub fn encode_slice<T>(value: T) -> OctetSliceEncoder<T> {
        Self::encode_slice_as(value, Tag::OCTET_STRING)
    }

    /// Returns an encoder for a primitive octet string with content `value`
    /// and the given `tag`.
    pub fn encode_slice_as<T>(value: T, tag: Tag) -> OctetSliceEncoder<T> {
        OctetSliceEncoder::new(value, tag)
    }
}
impl AsRef<OctetString> for OctetString {
    /// Returns the value itself.
    ///
    /// This lets both `OctetString` and `&OctetString` be used where
    /// `T: AsRef<OctetString>` is required, e.g. by `OctetStringEncoder`.
    fn as_ref(&self) -> &OctetString {
        &*self
    }
}
impl PartialEq for OctetString {
    /// Compares the content octets of two octet strings.
    ///
    /// The encoding form and the way a constructed value is split into
    /// segments do not matter. Only the sequence of octets does.
    fn eq(&self, other: &OctetString) -> bool {
        // Fast path: both primitive, compare the slices directly.
        if let (Some(l), Some(r)) = (self.as_slice(), other.as_slice()) {
            return l == r
        }
        // Zero-length segments are legal in a constructed encoding but carry
        // no octets. Filtering them out guarantees that an empty `ssl`/`osl`
        // after refilling means "this side is exhausted", not "we just hit
        // an empty segment".
        let mut sit = self.iter().filter(|part| !part.is_empty());
        let mut oit = other.iter().filter(|part| !part.is_empty());
        let mut ssl: &[u8] = b"";
        let mut osl: &[u8] = b"";
        loop {
            if ssl.is_empty() {
                ssl = sit.next().unwrap_or(b"");
            }
            if osl.is_empty() {
                osl = oit.next().unwrap_or(b"");
            }
            match (ssl.is_empty(), osl.is_empty()) {
                (true, true) => return true,
                (false, false) => { },
                _ => return false,
            }
            // Compare the overlapping part of the current segments.
            let len = cmp::min(ssl.len(), osl.len());
            if ssl[..len] != osl[..len] {
                return false
            }
            ssl = &ssl[len..];
            osl = &osl[len..];
        }
    }
}
impl<T: AsRef<[u8]>> PartialEq<T> for OctetString {
    /// Compares the content octets with a byte slice.
    ///
    /// True only if the concatenated content is exactly `other`: same
    /// octets, same length, regardless of segmentation.
    fn eq(&self, other: &T) -> bool {
        let mut other = other.as_ref();
        if let Some(slice) = self.as_slice() {
            return slice == other
        }
        for part in self.iter() {
            // Our content is longer than what is left of `other`.
            if part.len() > other.len() {
                return false
            }
            let (head, tail) = other.split_at(part.len());
            if part != head {
                return false
            }
            other = tail;
        }
        // Every segment matched; equal only if `other` is fully consumed.
        other.is_empty()
    }
}
// `PartialEq` compares content octets, which is reflexive, symmetric and
// transitive, so octet strings form a full equivalence.
impl Eq for OctetString { }
impl PartialOrd for OctetString {
    /// Defers to the total order given by `Ord::cmp`.
    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
        let ordering = Ord::cmp(self, other);
        Some(ordering)
    }
}
impl<T: AsRef<[u8]>> PartialOrd<T> for OctetString {
    /// Orders the content octets lexicographically against a byte slice.
    ///
    /// A value orders before any longer value it is a proper prefix of,
    /// and segment boundaries do not affect the result.
    fn partial_cmp(&self, other: &T) -> Option<cmp::Ordering> {
        let mut other = other.as_ref();
        if let Some(slice) = self.as_slice() {
            return slice.partial_cmp(other)
        }
        for part in self.iter() {
            let len = cmp::min(part.len(), other.len());
            match part[..len].cmp(&other[..len]) {
                cmp::Ordering::Equal => { }
                ordering => return Some(ordering)
            }
            // `other` ran out while we still have octets: it is a proper
            // prefix of our content, so we are greater.
            if part.len() > other.len() {
                return Some(cmp::Ordering::Greater)
            }
            other = &other[len..];
        }
        // Our content is exhausted: equal if `other` is too, else shorter.
        if other.is_empty() {
            Some(cmp::Ordering::Equal)
        } else {
            Some(cmp::Ordering::Less)
        }
    }
}
impl Ord for OctetString {
    /// Orders octet strings lexicographically by their content octets.
    ///
    /// The encoding form and segment boundaries do not affect the result.
    /// A proper prefix orders before the longer value.
    fn cmp(&self, other: &Self) -> cmp::Ordering {
        if let (Some(l), Some(r)) = (self.as_slice(), other.as_slice()) {
            return l.cmp(&r)
        }
        // Zero-length segments carry no octets. Filter them out so that an
        // empty `spart`/`opart` after refilling always means "that side is
        // exhausted" rather than "we just hit an empty segment".
        let mut siter = self.iter().filter(|part| !part.is_empty());
        let mut oiter = other.iter().filter(|part| !part.is_empty());
        let mut spart: &[u8] = b"";
        let mut opart: &[u8] = b"";
        loop {
            if spart.is_empty() {
                spart = siter.next().unwrap_or(b"");
            }
            if opart.is_empty() {
                opart = oiter.next().unwrap_or(b"");
            }
            match (spart.is_empty(), opart.is_empty()) {
                (true, true) => return cmp::Ordering::Equal,
                (true, false) => return cmp::Ordering::Less,
                (false, true) => return cmp::Ordering::Greater,
                (false, false) => { },
            }
            // Compare the overlapping part of the current segments.
            let len = cmp::min(spart.len(), opart.len());
            match spart[..len].cmp(&opart[..len]) {
                cmp::Ordering::Equal => { }
                other => return other
            }
            spart = &spart[len..];
            opart = &opart[len..];
        }
    }
}
impl hash::Hash for OctetString {
    /// Hashes the content octets as one contiguous slice.
    ///
    /// `PartialEq` treats a primitive value and any segmentation of a
    /// constructed value with the same octets as equal. The hash must
    /// therefore not depend on segment boundaries. Hashing each segment as
    /// its own slice (which writes a length prefix per segment) would break
    /// the `a == b => hash(a) == hash(b)` requirement.
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        match self.as_slice() {
            Some(slice) => hash::Hash::hash(slice, state),
            None => {
                // Constructed: flatten first (allocates) so the result is
                // identical to hashing the equivalent primitive value.
                let bytes = self.to_bytes();
                let slice: &[u8] = bytes.as_ref();
                hash::Hash::hash(slice, state)
            }
        }
    }
}
impl<'a> IntoIterator for &'a OctetString {
type Item = &'a [u8];
type IntoIter = OctetStringIter<'a>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// A `decode::Source` that reads the content octets of an `OctetString`.
///
/// For a primitive value, all octets start out in `current`. For a
/// constructed value, `remainder` holds the encoded segments not yet
/// consumed, and `current` is refilled from them as needed.
pub struct OctetStringSource {
    /// Octets immediately available to the decoder.
    current: Bytes,
    /// Not-yet-consumed encoded segments of a constructed value.
    remainder: Bytes,
}
impl OctetStringSource {
    /// Creates a source positioned at the start of `from`'s content.
    fn new(from: &OctetString) -> Self {
        let (current, remainder) = match from.0 {
            Inner::Primitive(ref inner) => (inner.clone(), Bytes::new()),
            Inner::Constructed(ref inner) => {
                (Bytes::new(), inner.clone().into_bytes())
            }
        };
        OctetStringSource { current, remainder }
    }

    /// Splits the content of the next primitive segment off `remainder`.
    ///
    /// Headers of nested constructed segments and end-of-contents markers
    /// are skipped. Returns `None` once `remainder` is empty.
    ///
    /// # Panics
    ///
    /// Panics on headers that cannot be parsed, on tags other than
    /// OCTET STRING / end-of-contents, and on primitive segments without a
    /// definite length. Presumably the content was already checked when the
    /// value was decoded; confirm before feeding unchecked data.
    fn next_primitive(&mut self) -> Option<Bytes> {
        while !self.remainder.is_empty() {
            let (tag, constructed) = Tag::take_from(&mut self.remainder).unwrap();
            let length = Length::take_from(
                &mut self.remainder, Mode::Ber
            ).unwrap();
            match tag {
                Tag::END_OF_VALUE => continue,
                Tag::OCTET_STRING if constructed => continue,
                Tag::OCTET_STRING => { }
                _ => unreachable!()
            }
            let len = match length {
                Length::Definite(len) => len,
                _ => unreachable!()
            };
            return Some(self.remainder.split_to(len))
        }
        None
    }
}
impl decode::Source for OctetStringSource {
    type Err = decode::Error;

    /// Tries to make at least `len` octets available in `current`.
    ///
    /// Segments are pulled from `remainder` and appended to a copy of
    /// `current` until it is long enough or the segments run out. Returns
    /// the number of octets now available, which may be less than `len`.
    fn request(&mut self, len: usize) -> Result<usize, decode::Error> {
        let too_short = self.current.len() < len;
        if too_short && !self.remainder.is_empty() {
            let mut joined = BytesMut::from(self.current.clone());
            while joined.len() < len {
                match self.next_primitive() {
                    Some(segment) => joined.extend_from_slice(segment.as_ref()),
                    None => break,
                }
            }
            self.current = joined.freeze();
        }
        Ok(self.current.len())
    }

    /// Skips `len` octets, moving on to later segments as needed.
    ///
    /// Fails with `decode::Error::Malformed` if the segments run out before
    /// `len` octets have been skipped.
    fn advance(&mut self, mut len: usize) -> Result<(), decode::Error> {
        while len > self.current.len() {
            len -= self.current.len();
            match self.next_primitive() {
                Some(segment) => self.current = segment,
                None => {
                    xerr!(return Err(decode::Error::Malformed))
                }
            }
        }
        self.current.advance(len);
        Ok(())
    }

    /// Returns the octets currently available.
    fn slice(&self) -> &[u8] {
        self.current.as_ref()
    }

    /// Returns the available octets in `start..end` as `Bytes`.
    fn bytes(&self, start: usize, end: usize) -> Bytes {
        self.current.slice(start, end)
    }
}
/// Iterator over the content segments of an `OctetString`.
///
/// A primitive value yields its content once, or nothing if it is empty.
/// A constructed value yields the content of each nested primitive segment
/// in order; this may include zero-length segments.
#[derive(Clone, Debug)]
pub struct OctetStringIter<'a>(Inner<&'a [u8], &'a [u8]>);
impl<'a> Iterator for OctetStringIter<'a> {
    type Item = &'a [u8];

    /// Returns the next content segment.
    ///
    /// # Panics
    ///
    /// For a constructed value, panics on headers that cannot be parsed, on
    /// tags other than OCTET STRING / end-of-contents, and on primitive
    /// segments without a definite length.
    fn next(&mut self) -> Option<Self::Item> {
        match self.0 {
            Inner::Primitive(ref mut content) => {
                if content.is_empty() {
                    return None
                }
                // Hand out the whole content once, then behave as exhausted.
                Some(mem::replace(content, &b""[..]))
            }
            Inner::Constructed(ref mut encoded) => {
                while !encoded.is_empty() {
                    let (tag, constructed) = Tag::take_from(encoded).unwrap();
                    let length = Length::take_from(encoded, Mode::Ber).unwrap();
                    match tag {
                        // Nested constructed headers and end-of-contents
                        // markers are skipped; their segments follow inline.
                        Tag::END_OF_VALUE => continue,
                        Tag::OCTET_STRING if constructed => continue,
                        Tag::OCTET_STRING => { }
                        _ => unreachable!()
                    }
                    let len = match length {
                        Length::Definite(len) => len,
                        _ => unreachable!()
                    };
                    let data: &'a [u8] = *encoded;
                    let segment = &data[..len];
                    *encoded = &data[len..];
                    return Some(segment)
                }
                None
            }
        }
    }
}
/// Iterator over the individual content octets of an `OctetString`.
#[derive(Clone, Debug)]
pub struct OctetStringOctets<'a> {
    /// Octets of the current segment not yet returned.
    cur: &'a [u8],
    /// Source of further segments.
    iter: OctetStringIter<'a>,
}
impl<'a> OctetStringOctets<'a> {
fn new(iter: OctetStringIter<'a>) -> Self {
OctetStringOctets {
cur: b"",
iter
}
}
}
impl<'a> Iterator for OctetStringOctets<'a> {
    type Item = u8;

    /// Returns the next content octet, or `None` once all segments are
    /// exhausted.
    fn next(&mut self) -> Option<u8> {
        loop {
            let cur: &'a [u8] = self.cur;
            if let Some((&first, rest)) = cur.split_first() {
                self.cur = rest;
                return Some(first)
            }
            // A constructed value may contain zero-length segments, so keep
            // pulling until one actually has an octet instead of indexing
            // into a possibly empty slice.
            self.cur = self.iter.next()?;
        }
    }
}
/// Value encoder for an `OctetString`, created by the `OctetString::encode*`
/// methods.
#[derive(Clone, Debug)]
pub struct OctetStringEncoder<T> {
    /// The octet string to encode, owned or borrowed.
    value: T,
    /// The tag to encode the value with.
    tag: Tag,
}
impl<T> OctetStringEncoder<T> {
fn new(value: T, tag: Tag) -> Self {
OctetStringEncoder { value, tag }
}
}
impl<T: AsRef<OctetString>> encode::Values for OctetStringEncoder<T> {
    /// Returns the length of the complete encoding, header included.
    ///
    /// BER re-emits the stored form (primitive or constructed) unchanged.
    /// DER always produces a primitive encoding of the flattened content.
    ///
    /// # Panics
    ///
    /// CER is not implemented and panics.
    fn encoded_len(&self, mode: Mode) -> usize {
        let content_len = match mode {
            Mode::Ber => match self.value.as_ref().0 {
                Inner::Primitive(ref bytes) => bytes.len(),
                Inner::Constructed(ref captured) => captured.len(),
            },
            Mode::Cer => unimplemented!(),
            Mode::Der => self.value.as_ref().len(),
        };
        self.tag.encoded_len()
            + Length::Definite(content_len).encoded_len()
            + content_len
    }

    /// Writes the complete encoding to `target`.
    ///
    /// The same form rules as `encoded_len` apply, and CER panics.
    fn write_encoded<W: io::Write>(
        &self,
        mode: Mode,
        target: &mut W
    ) -> Result<(), io::Error> {
        let value = self.value.as_ref();
        match mode {
            Mode::Cer => unimplemented!(),
            Mode::Ber => match value.0 {
                Inner::Primitive(ref bytes) => {
                    self.tag.write_encoded(false, target)?;
                    Length::Definite(bytes.len()).write_encoded(target)?;
                    target.write_all(bytes.as_ref())
                }
                Inner::Constructed(ref captured) => {
                    // Keep the constructed form; the captured content
                    // already contains the nested segment encodings.
                    self.tag.write_encoded(true, target)?;
                    Length::Definite(captured.len()).write_encoded(target)?;
                    target.write_all(captured.as_slice())
                }
            },
            Mode::Der => {
                self.tag.write_encoded(false, target)?;
                Length::Definite(value.len()).write_encoded(target)?;
                value.iter().try_for_each(|part| target.write_all(part))
            }
        }
    }
}
/// Value encoder that writes a byte slice as a primitive octet string.
#[derive(Clone, Debug)]
pub struct OctetSliceEncoder<T> {
    /// The content octets.
    slice: T,
    /// The tag to encode the value with.
    tag: Tag,
}
impl<T> OctetSliceEncoder<T> {
fn new(slice: T, tag: Tag) -> Self {
OctetSliceEncoder { slice, tag }
}
}
impl<T: AsRef<[u8]>> encode::Values for OctetSliceEncoder<T> {
    /// Returns the length of the primitive encoding, header included.
    ///
    /// # Panics
    ///
    /// CER is not implemented and panics.
    fn encoded_len(&self, mode: Mode) -> usize {
        if let Mode::Cer = mode {
            unimplemented!()
        }
        let content_len = self.slice.as_ref().len();
        let header_len = self.tag.encoded_len()
            + Length::Definite(content_len).encoded_len();
        header_len + content_len
    }

    /// Writes the primitive encoding (header and content) to `target`.
    ///
    /// CER panics, as in `encoded_len`.
    fn write_encoded<W: io::Write>(
        &self,
        mode: Mode,
        target: &mut W
    ) -> Result<(), io::Error> {
        if let Mode::Cer = mode {
            unimplemented!()
        }
        let content = self.slice.as_ref();
        self.tag.write_encoded(false, target)?;
        Length::Definite(content.len()).write_encoded(target)?;
        target.write_all(content)
    }
}
/// Value encoder that wraps the encoding of other values in a primitive
/// OCTET STRING.
pub struct WrappingOctetStringEncoder<V: encode::Values> {
    /// The values whose encoding forms the octet string's content.
    values: V,
    /// The mode the wrapped values are encoded in, independent of the mode
    /// the outer octet string is encoded in.
    mode: Mode
}
impl<V: encode::Values> WrappingOctetStringEncoder<V> {
    /// Creates an encoder that wraps `values`, encoded in `mode`.
    fn new(mode: Mode, values: V) -> Self {
        Self { mode, values }
    }
}
impl<V: encode::Values> encode::Values for WrappingOctetStringEncoder<V> {
    /// Returns the length of the wrapping OCTET STRING, header included.
    ///
    /// The content length comes from encoding the wrapped values in
    /// `self.mode`. The `mode` argument is only used to reject CER.
    ///
    /// # Panics
    ///
    /// CER is not implemented and panics.
    fn encoded_len(&self, mode: Mode) -> usize {
        if let Mode::Cer = mode {
            unimplemented!()
        }
        let content_len = self.values.encoded_len(self.mode);
        encode::total_encoded_len(Tag::OCTET_STRING, content_len)
    }

    /// Writes a primitive OCTET STRING header followed by the wrapped
    /// values encoded in `self.mode`.
    ///
    /// CER panics, as in `encoded_len`.
    fn write_encoded<W: io::Write>(
        &self,
        mode: Mode,
        target: &mut W
    ) -> Result<(), io::Error> {
        if let Mode::Cer = mode {
            unimplemented!()
        }
        let content_len = self.values.encoded_len(self.mode);
        encode::write_header(target, Tag::OCTET_STRING, false, content_len)?;
        self.values.write_encoded(self.mode, target)
    }
}
/// Skips over consecutive OCTET STRING values in `con`, recursing into
/// constructed ones.
///
/// Stops at the first value that is not an OCTET STRING, or when `con` has
/// no more values; errors from the source are propagated.
fn skip_nested<S>(con: &mut decode::Constructed<S>) -> Result<(), S::Err>
where S: decode::Source {
    loop {
        let skipped = con.take_opt_value_if(Tag::OCTET_STRING, |content| {
            match content {
                decode::Content::Primitive(ref mut inner) => inner.skip_all()?,
                decode::Content::Constructed(ref mut inner) => skip_nested(inner)?,
            }
            Ok(())
        })?;
        if skipped.is_none() {
            return Ok(())
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use encode::{Values, PrimitiveContent};

    #[test]
    fn should_wrap_content_in_octetstring() {
        // DER BOOLEAN TRUE (01 01 FF) wrapped in a primitive OCTET STRING
        // of length 3 (04 03).
        let encoder = OctetString::encode_wrapped(Mode::Der, true.encode());
        let mut output = Vec::new();
        encoder.write_encoded(Mode::Der, &mut output).unwrap();
        assert_eq!(vec![4, 3, 1, 1, 255], output);
        // The predicted length must match what was actually written.
        let predicted = encoder.encoded_len(Mode::Der);
        assert_eq!(predicted, output.len());
    }
}