use std::io;
use bytes::Bytes;
use crate::{decode, encode};
use crate::decode::{DecodeError, Source};
use crate::length::Length;
use crate::mode::Mode;
use crate::tag::Tag;
/// A BIT STRING value: the octets holding the bits plus the number of unused
/// padding bits in the final octet.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct BitString {
    /// Number of unused (padding) bits in the final octet of `bits`.
    ///
    /// Both `BitString::new` and the decoder keep this in `0..=7` and zero
    /// whenever `bits` is empty.
    unused: u8,
    /// The octets holding the bits; bit 0 is the most significant bit of the
    /// first octet.
    bits: Bytes,
}
impl BitString {
    /// Creates a new bit string from its unused-bit count and octets.
    ///
    /// `unused` is the number of trailing padding bits in the final octet of
    /// `bits`; those bits are not part of the value.
    ///
    /// # Panics
    ///
    /// Panics if `unused` is greater than 7, or if `bits` is empty while
    /// `unused` is non-zero.
    pub fn new(unused: u8, bits: Bytes) -> Self {
        assert!(
            unused <= 7 && (!bits.is_empty() || unused == 0),
            "invalid bit string: {} unused bits in {} octets",
            unused,
            bits.len()
        );
        Self { unused, bits }
    }

    /// Returns whether the bit at position `bit` is set.
    ///
    /// Bits are numbered from zero, starting with the most significant bit
    /// of the first octet. Positions at or beyond [`bit_len`](Self::bit_len)
    /// are reported as not set. This includes the unused padding bits of the
    /// final octet.
    pub fn bit(&self, bit: usize) -> bool {
        // Anything past `bit_len()` is either outside the octets or padding
        // in the final octet, whose contents are not part of the value.
        if bit >= self.bit_len() {
            return false
        }
        // `bit_len() <= octet_len() * 8`, so `bit >> 3` is in bounds here.
        self.bits[bit >> 3] & (0x80 >> (bit & 7)) != 0
    }

    /// Returns the number of bits in the value, excluding padding bits.
    pub fn bit_len(&self) -> usize {
        // Cannot underflow: `unused` is zero whenever `bits` is empty.
        (self.bits.len() << 3) - usize::from(self.unused)
    }

    /// Returns the number of unused padding bits in the final octet.
    pub fn unused(&self) -> u8 {
        self.unused
    }

    /// Returns the number of octets holding the bits, including padding.
    pub fn octet_len(&self) -> usize {
        self.bits.len()
    }

    /// Returns an iterator over the octets holding the bits.
    pub fn octets(&self) -> BitStringIter<'_> {
        BitStringIter(self.bits.iter())
    }

    /// Returns the octets holding the bits as a slice.
    ///
    /// This implementation always returns `Some`.
    pub fn octet_slice(&self) -> Option<&[u8]> {
        Some(self.bits.as_ref())
    }

    /// Returns the octets holding the bits as a `Bytes` value.
    pub fn octet_bytes(&self) -> Bytes {
        self.bits.clone()
    }
}
impl BitString {
pub fn take_from<S: decode::Source>(
constructed: &mut decode::Constructed<S>
) -> Result<Self, DecodeError<S::Error>> {
constructed.take_value_if(Tag::BIT_STRING, Self::from_content)
}
pub fn skip_in<S: decode::Source>(
cons: &mut decode::Constructed<S>
) -> Result<(), DecodeError<S::Error>> {
cons.take_value_if(Tag::BIT_STRING, Self::skip_content)
}
pub fn from_content<S: decode::Source>(
content: &mut decode::Content<S>
) -> Result<Self, DecodeError<S::Error>> {
match *content {
decode::Content::Primitive(ref mut inner) => {
if inner.mode() == Mode::Cer && inner.remaining() > 1000 {
return Err(content.content_err(
"long bit string component in CER mode"
))
}
let unused = inner.take_u8()?;
if unused > 7 {
return Err(content.content_err(
"invalid bit string with large initial octet"
));
}
if inner.remaining() == 0 && unused > 0 {
return Err(content.content_err(
"invalid bit string \
(non-zero initial with empty bits)"
));
}
let bits = inner.take_all()?;
Ok(BitString { unused, bits })
}
decode::Content::Constructed(ref inner) => {
if inner.mode() == Mode::Der {
Err(content.content_err(
"constructed bit string in DER mode"
))
}
else {
Err(content.content_err(
"constructed bit string not implemented"
))
}
}
}
}
pub fn skip_content<S: decode::Source>(
content: &mut decode::Content<S>
) -> Result<(), DecodeError<S::Error>> {
match *content {
decode::Content::Primitive(ref mut inner) => {
if inner.mode() == Mode::Cer && inner.remaining() > 1000 {
return Err(content.content_err(
"long bit string component in CER mode"
))
}
let unused = inner.take_u8()?;
if unused > 7 {
return Err(content.content_err(
"invalid bit string with large initial octet"
));
}
if inner.remaining() == 0 && unused > 0 {
return Err(content.content_err(
"invalid bit string \
(non-zero initial with empty bits)"
));
}
inner.skip_all()
}
decode::Content::Constructed(ref inner) => {
if inner.mode() == Mode::Der {
Err(content.content_err(
"constructed bit string in DER mode"
))
}
else {
Err(content.content_err(
"constructed bit string not implemented"
))
}
}
}
}
pub fn encode_slice<T>(value: T, unused: u8) -> BitSliceEncoder<T> {
Self::encode_slice_as(value, unused, Tag::BIT_STRING)
}
pub fn encode_slice_as<T>(
value: T,
unused: u8,
tag: Tag
) -> BitSliceEncoder<T> {
BitSliceEncoder::new(value, unused, tag)
}
}
impl encode::PrimitiveContent for BitString {
    const TAG: Tag = Tag::BIT_STRING;

    /// Returns the content length: the unused-bits octet plus all octets
    /// holding the bits. The mode is ignored.
    fn encoded_len(&self, _: Mode) -> usize {
        1 + self.bits.len()
    }

    /// Writes the unused-bits octet followed by the octets holding the bits.
    /// The mode is ignored.
    fn write_encoded<W: io::Write>(
        &self,
        _: Mode,
        target: &mut W
    ) -> Result<(), io::Error> {
        let initial = [self.unused];
        target.write_all(&initial)?;
        target.write_all(&self.bits[..])
    }
}
/// An iterator over the octets of a [`BitString`], yielding them by value.
#[derive(Clone, Debug)]
pub struct BitStringIter<'a>(::std::slice::Iter<'a, u8>);

impl<'a> Iterator for BitStringIter<'a> {
    type Item = u8;

    fn next(&mut self) -> Option<u8> {
        self.0.next().cloned()
    }

    // Forward the underlying slice iterator's exact bounds so that
    // `collect`/`extend` can preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.0.size_hint()
    }
}

impl<'a> ExactSizeIterator for BitStringIter<'a> {}
/// An encoder for a bit string whose bits are given as a slice-like `T`.
///
/// Created via [`BitString::encode_slice`] or [`BitString::encode_slice_as`].
#[derive(Clone, Debug)]
pub struct BitSliceEncoder<T> {
    /// The octets holding the bits.
    slice: T,
    /// The number of unused padding bits in the final octet of `slice`.
    unused: u8,
    /// The tag the value is encoded with.
    tag: Tag,
}

impl<T> BitSliceEncoder<T> {
    /// Assembles an encoder from its parts without any validation.
    fn new(slice: T, unused: u8, tag: Tag) -> Self {
        Self {
            slice,
            unused,
            tag,
        }
    }
}
impl<T: AsRef<[u8]>> BitSliceEncoder<T> {
    /// Returns the number of content octets of the primitive encoding: the
    /// unused-bits octet plus the octets of the slice.
    ///
    /// # Panics
    ///
    /// Panics in CER mode if that exceeds 1000 octets. CER then requires the
    /// constructed form (X.690 §9.2), which is not implemented.
    fn primitive_content_len(&self, mode: Mode) -> usize {
        let len = self.slice.as_ref().len() + 1;
        if mode == Mode::Cer && len > 1000 {
            unimplemented!("constructed bit string encoding in CER mode")
        }
        len
    }
}

impl<T: AsRef<[u8]>> encode::Values for BitSliceEncoder<T> {
    /// Returns the length of the complete encoding: tag, length, and content.
    ///
    /// # Panics
    ///
    /// Panics in CER mode for values of more than 1000 content octets.
    fn encoded_len(&self, mode: Mode) -> usize {
        let len = self.primitive_content_len(mode);
        self.tag.encoded_len() + Length::Definite(len).encoded_len() + len
    }

    /// Writes the value as a primitive encoding with a definite length.
    ///
    /// In CER mode this is exactly the required encoding for values of at
    /// most 1000 content octets.
    ///
    /// # Panics
    ///
    /// Panics in CER mode for values of more than 1000 content octets.
    fn write_encoded<W: io::Write>(
        &self,
        mode: Mode,
        target: &mut W
    ) -> Result<(), io::Error> {
        let len = self.primitive_content_len(mode);
        self.tag.write_encoded(false, target)?;
        Length::Definite(len).write_encoded(target)?;
        target.write_all(&[self.unused])?;
        target.write_all(self.slice.as_ref())
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::decode::IntoSource;

    #[test]
    fn bitstring_from_der_content() {
        /// Decodes `encoded` in DER mode twice: once taking the value and
        /// once skipping it. Both outcomes are compared with `expected`;
        /// `None` means both must fail.
        fn check(encoded: &[u8], expected: Option<(u8, &[u8])>) {
            let taken = Mode::Der.decode(encoded.into_source(), |cons| {
                BitString::take_from(cons)
            });
            let mut skip_source = encoded.into_source();
            let skipped = Mode::Der.decode(&mut skip_source, |cons| {
                BitString::skip_in(cons)
            });
            if let Some((unused, bits)) = expected {
                let value = taken.unwrap();
                assert!(skipped.is_ok());
                assert!(skip_source.slice().is_empty());
                assert_eq!(value.unused, unused);
                assert_eq!(value.bits.as_ref(), bits);
            } else {
                assert!(taken.is_err());
                assert!(skipped.is_err());
            }
        }

        // Valid: four padding bits, six data octets.
        check(b"\x03\x07\x04deadb\xd0", Some((4, b"deadb\xd0")));
        // Valid: empty bit string.
        check(b"\x03\x01\x00", Some((0, b"")));
        // Invalid: initial octet greater than 7.
        check(b"\x03\x07\x12deadb\xd0", None);
        // Invalid: padding bits claimed without any data octets.
        check(b"\x03\x01\x04", None);
        // Invalid: missing initial octet.
        check(b"\x03\x00", None);
    }
}