use std::{fmt, hash, io, str::FromStr};
use bytes::Bytes;
use crate::encode;
use crate::decode::{Constructed, DecodeError, Primitive, Source};
use crate::mode::Mode;
use crate::tag::Tag;
/// An ASN.1 object identifier, kept as the content octets of its encoding.
///
/// The wrapped value holds the encoded subidentifiers exactly as they appear
/// in the content of an OID primitive; no decoding happens on construction.
#[derive(Debug, Clone)]
pub struct Oid<T: AsRef<[u8]> = Bytes>(pub T);

/// An object identifier whose encoded content is a `'static` byte slice,
/// which makes it usable for constants.
pub type ConstOid = Oid<&'static [u8]>;
impl Oid<Bytes> {
    /// Skips over an object identifier value in `cons`.
    ///
    /// The value must be a primitive tagged as OID. Its content is validated
    /// the same way as by `from_primitive` and then discarded.
    pub fn skip_in<S: Source>(
        cons: &mut Constructed<S>
    ) -> Result<(), DecodeError<S::Error>> {
        cons.take_primitive_if(Tag::OID, Self::skip_primitive)
    }

    /// Skips over an optional object identifier value in `cons`.
    ///
    /// Returns `Ok(None)` when `take_opt_primitive_if` finds no OID value
    /// (presumably because the next value is absent or has another tag —
    /// confirm against `Constructed`).
    pub fn skip_opt_in<S: Source>(
        cons: &mut Constructed<S>
    ) -> Result<Option<()>, DecodeError<S::Error>> {
        cons.take_opt_primitive_if(Tag::OID, Self::skip_primitive)
    }

    /// Takes an object identifier value from `constructed`.
    pub fn take_from<S: Source>(
        constructed: &mut Constructed<S>
    ) -> Result<Self, DecodeError<S::Error>> {
        constructed.take_primitive_if(Tag::OID, Self::from_primitive)
    }

    /// Takes an optional object identifier value from `constructed`.
    pub fn take_opt_from<S: Source>(
        constructed: &mut Constructed<S>
    ) -> Result<Option<Self>, DecodeError<S::Error>> {
        constructed.take_opt_primitive_if(Tag::OID, Self::from_primitive)
    }

    /// Validates the content of an OID primitive without keeping it.
    pub fn skip_primitive<S: Source>(
        prim: &mut Primitive<S>
    ) -> Result<(), DecodeError<S::Error>> {
        prim.with_slice_all(Self::check_content)
    }

    /// Decodes an object identifier from the content of a primitive value.
    ///
    /// # Errors
    ///
    /// Returns a content error if the content is empty, if its final octet
    /// has bit 8 set, or if any subidentifier is not minimally encoded.
    pub fn from_primitive<S: Source>(
        prim: &mut Primitive<S>
    ) -> Result<Self, DecodeError<S::Error>> {
        let content = prim.take_all()?;
        Self::check_content(content.as_ref()).map_err(|err| {
            prim.content_err(err)
        })?;
        Ok(Oid(content))
    }

    /// Checks that `content` is a well-formed encoded object identifier.
    fn check_content(content: &[u8]) -> Result<(), &'static str> {
        // Each subidentifier ends with an octet whose bit 8 is clear, so the
        // content must be non-empty and must end on such an octet.
        match content.last() {
            None => return Err("empty object identifier"),
            Some(last) if last & 0x80 != 0 => {
                return Err("illegal object identifier")
            }
            Some(_) => { }
        }

        // X.690 8.19.2: subidentifiers use the fewest possible octets, so
        // the first octet of a subidentifier must never be 0x80. Allowing
        // it would give one identifier several encodings, which defeats the
        // byte-wise equality and hashing of `Oid`.
        let mut at_subidentifier_start = true;
        for &octet in content {
            if at_subidentifier_start && octet == 0x80 {
                return Err(
                    "illegal object identifier (non-minimal subidentifier)"
                )
            }
            at_subidentifier_start = octet & 0x80 == 0;
        }
        Ok(())
    }
}
impl<T: AsRef<[u8]>> Oid<T> {
    /// Takes an OID primitive from `constructed` and checks that its content
    /// octets are exactly those of `self`.
    ///
    /// If the content differs, the closure passed to `with_slice_all`
    /// rejects it with the message "object identifier mismatch".
    pub fn skip_if<S: Source>(
        &self, constructed: &mut Constructed<S>,
    ) -> Result<(), DecodeError<S::Error>> {
        constructed.take_primitive_if(Tag::OID, |prim| {
            prim.with_slice_all(|content| {
                if content == self.0.as_ref() {
                    Ok(())
                }
                else {
                    Err("object identifier mismatch")
                }
            })
        })
    }
}
impl<T: AsRef<[u8]>> Oid<T> {
    /// Returns an iterator over the components of this identifier.
    ///
    /// The leading subidentifier is yielded twice, once as the first and
    /// once as the second component. Advancing the iterator over content
    /// whose final octet has bit 8 set panics.
    pub fn iter(&self) -> Iter {
        let content: &[u8] = self.0.as_ref();
        Iter::new(content)
    }
}
impl<T: AsRef<[u8]>> AsRef<[u8]> for Oid<T> {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}
impl<T: AsRef<[u8]>, U: AsRef<[u8]>> PartialEq<Oid<U>> for Oid<T> {
    /// Identifiers are equal when their encoded content octets match,
    /// whatever container types hold them.
    fn eq(&self, other: &Oid<U>) -> bool {
        let lhs: &[u8] = self.0.as_ref();
        let rhs: &[u8] = other.0.as_ref();
        lhs == rhs
    }
}

/// Byte-wise comparison is a full equivalence relation.
impl<T: AsRef<[u8]>> Eq for Oid<T> { }
impl<T: AsRef<[u8]>> hash::Hash for Oid<T> {
    /// Hashes the encoded content octets as a byte slice, consistent with
    /// the byte-wise equality of `Oid`.
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        let content: &[u8] = self.0.as_ref();
        hash::Hash::hash(content, state)
    }
}
impl<T: AsRef<[u8]>> fmt::Display for Oid<T> {
    /// Writes the identifier in dotted-decimal form, e.g. `2.5.29.19`.
    ///
    /// An identifier with empty content writes nothing.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let mut components = self.iter();
        let leading = match components.next() {
            Some(component) => component,
            None => return Ok(()),
        };
        fmt::Display::fmt(&leading, f)?;
        for component in components {
            write!(f, ".{}", component)?;
        }
        Ok(())
    }
}
impl<T: AsRef<[u8]>> encode::PrimitiveContent for Oid<T> {
    const TAG: Tag = Tag::OID;

    /// The content is stored already encoded, so its length is the same in
    /// every mode.
    fn encoded_len(&self, _mode: Mode) -> usize {
        let content: &[u8] = self.0.as_ref();
        content.len()
    }

    /// Writes the stored content octets unchanged, regardless of mode.
    fn write_encoded<W: io::Write>(
        &self,
        _mode: Mode,
        target: &mut W
    ) -> Result<(), io::Error> {
        let content: &[u8] = self.0.as_ref();
        target.write_all(content)
    }
}
impl<T: AsRef<[u8]> + From<Vec<u8>>> FromStr for Oid<T> {
    type Err = &'static str;

    /// Parses an object identifier from dotted-decimal notation, such as
    /// `"1.2.840.113549"`, into its encoded content octets.
    ///
    /// At least two components are required. The first must be 0, 1, or 2.
    /// If it is 0 or 1, the second must be below 40. Every component must
    /// fit into a `u32`.
    ///
    /// # Errors
    ///
    /// Returns a static message describing the first rule that is violated.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Parses a single dotted component.
        fn parse_component(s: &str) -> Result<u32, &'static str> {
            u32::from_str(s).map_err(|_| "only integer components allowed")
        }

        // Appends `value` as a base-128 subidentifier (X.690 8.19.2):
        // big-endian 7-bit groups, bit 8 set on every octet but the last,
        // and no superfluous leading octets.
        fn push_subidentifier(bytes: &mut Vec<u8>, value: u64) {
            // Locate the most significant non-zero 7-bit group.
            let mut shift: u32 = 0;
            while shift + 7 < 64 && value >> (shift + 7) != 0 {
                shift += 7;
            }
            while shift > 0 {
                bytes.push((((value >> shift) & 0x7F) | 0x80) as u8);
                shift -= 7;
            }
            bytes.push((value & 0x7F) as u8);
        }

        let mut components = s.split('.');
        let (first, second) = match (components.next(), components.next()) {
            (Some(first), Some(second)) => (first, second),
            _ => { return Err("at least two components required"); }
        };
        let first = parse_component(first)?;
        if first > 2 {
            return Err("first component can only be 0, 1, or 2.")
        }
        let second = parse_component(second)?;
        if first < 2 && second >= 40 {
            return Err("second component for 0. and 1. must be less than 40");
        }

        let mut bytes = Vec::new();
        // The first two components share one subidentifier, 40 * first +
        // second. When the first component is 2, the second may be as large
        // as u32::MAX. The sum can then exceed u32, so compute it in u64.
        push_subidentifier(
            &mut bytes, 40 * u64::from(first) + u64::from(second)
        );
        for item in components {
            push_subidentifier(&mut bytes, u64::from(parse_component(item)?));
        }
        Ok(Oid(bytes.into()))
    }
}
/// A single component (arc) of an object identifier.
///
/// The first two components are both derived from the leading
/// subidentifier, so a component records its position as well as the
/// encoded octets it was taken from.
#[derive(Clone, Copy, Debug)]
pub struct Component<'a> {
    /// Where in the identifier this component sits.
    position: Position,
    /// The base-128 encoded subidentifier octets.
    slice: &'a [u8],
}

/// The position of a component within an object identifier.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
enum Position {
    /// The first component, derived from the leading subidentifier.
    First,
    /// The second component, derived from the same leading subidentifier.
    Second,
    /// Any later component, each with its own subidentifier.
    Other,
}

impl<'a> Component<'a> {
    fn new(position: Position, slice: &'a [u8]) -> Self {
        Component { position, slice }
    }

    /// Attempts to convert the component into a `u32`.
    ///
    /// Returns `None` if the component's value does not fit into a `u32`.
    /// The first component is always 0, 1, or 2 and therefore always
    /// converts.
    pub fn to_u32(self) -> Option<u32> {
        // Narrows a decoded value to u32, failing if it doesn't fit.
        fn narrow(value: u64) -> Option<u32> {
            if value > u64::from(u32::MAX) { None } else { Some(value as u32) }
        }

        let value = self.subidentifier();
        match self.position {
            Position::First => Some(match value {
                Some(v) if v < 40 => 0,
                Some(v) if v < 80 => 1,
                // Every value of 80 or more, including values too large to
                // decode at all, belongs under arc 2.
                _ => 2,
            }),
            Position::Second => {
                let v = value?;
                narrow(if v < 80 { v % 40 } else { v - 80 })
            }
            Position::Other => narrow(value?),
        }
    }

    /// Decodes the base-128 subidentifier in `slice`.
    ///
    /// Returns `None` instead of wrapping if the value exceeds `u64`.
    /// Leading 0x80 octets are tolerated because they contribute nothing to
    /// the value.
    fn subidentifier(self) -> Option<u64> {
        self.slice.iter().try_fold(0u64, |acc, &octet| {
            if acc > u64::MAX >> 7 {
                None
            }
            else {
                Some((acc << 7) | u64::from(octet & 0x7F))
            }
        })
    }
}

impl PartialEq for Component<'_> {
    /// Components are equal if they sit at the same position and were taken
    /// from identical octets.
    fn eq(&self, other: &Self) -> bool {
        self.position == other.position && self.slice == other.slice
    }
}

impl Eq for Component<'_> { }

impl fmt::Display for Component<'_> {
    /// Writes the component's decimal value, or `(very large component)`
    /// when it does not fit into a `u32`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self.to_u32() {
            Some(val) => fmt::Display::fmt(&val, f),
            None => f.write_str("(very large component)"),
        }
    }
}
/// An iterator over the components of an object identifier.
///
/// Because the first two components share the leading subidentifier, that
/// subidentifier is yielded twice: first as `Position::First`, then as
/// `Position::Second`.
///
/// # Panics
///
/// Advancing panics if the remaining content has no octet with bit 8 clear,
/// i.e. if it does not end a subidentifier.
pub struct Iter<'a> {
    /// Encoded content that has not been consumed yet.
    slice: &'a [u8],
    /// Position of the next component to yield.
    position: Position,
}

impl<'a> Iter<'a> {
    fn new(slice: &'a [u8]) -> Self {
        Iter { position: Position::First, slice }
    }

    /// Moves on to the following position and returns the current one.
    fn advance_position(&mut self) -> Position {
        let following = if self.position == Position::First {
            Position::Second
        }
        else {
            Position::Other
        };
        std::mem::replace(&mut self.position, following)
    }
}

impl<'a> Iterator for Iter<'a> {
    type Item = Component<'a>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.slice.is_empty() {
            return None
        }
        // A subidentifier ends at the first octet with bit 8 clear.
        let end = match self.slice.iter().position(|&octet| octet & 0x80 == 0) {
            Some(idx) => idx + 1,
            None => panic!("illegal object identifier (last octet has bit 8 set)"),
        };
        let (head, rest) = self.slice.split_at(end);
        let position = self.advance_position();
        // The leading subidentifier also yields the second component, so it
        // is only consumed after the first component has been produced.
        if position != Position::First {
            self.slice = rest;
        }
        Some(Component::new(position, head))
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// The dotted rendering splits the leading subidentifier into two arcs.
    #[test]
    fn display() {
        let oid = Oid(&[85u8, 29, 19]);
        assert_eq!(oid.to_string(), "2.5.29.19");
    }

    /// Taking and skipping must agree on which contents are acceptable, and
    /// taking must preserve the content octets unchanged.
    #[test]
    fn take_and_skip_primitive() {
        let cases: [(&[u8], bool); 6] = [
            (&b""[..], false),
            (&b"\x81\x34"[..], true),
            (&b"\x81\x34\x03"[..], true),
            (&b"\x81\x34\x83\x03"[..], true),
            (&b"\x81\x34\x83\x83\x03\x03"[..], true),
            (&b"\x81\x34\x83"[..], false),
        ];
        for &(content, should_succeed) in cases.iter() {
            let taken = Primitive::decode_slice(
                content, Mode::Der, |prim| Oid::from_primitive(prim)
            );
            assert_eq!(taken.is_ok(), should_succeed);
            if let Ok(oid) = taken {
                assert_eq!(oid.0.as_ref(), content);
            }
            let skipped = Primitive::decode_slice(
                content, Mode::Der, |prim| Oid::skip_primitive(prim)
            );
            assert_eq!(skipped.is_ok(), should_succeed);
        }
    }
}