use std::cmp::Ordering;
use thiserror::Error;
/// Error type returned by the canonical encoder and by both decoders in this file.
///
/// Every variant renders with the `MALFORMED_CBOR:` prefix; the `String` payload
/// carries the human-readable detail appended to that prefix.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
pub enum CanonicalCborError {
/// An indefinite-length marker (additional information 31) was found where the
/// strict decoder requires a definite-length item or a value.
#[error("MALFORMED_CBOR: indefinite-length items are not permitted in canonical CBOR: {0}")]
IndefiniteLength(String),
/// Any other rejection: truncated input, non-shortest arguments, bad UTF-8,
/// key ordering or duplicate-key violations, unsupported item kinds, and so on.
#[error("MALFORMED_CBOR: {0}")]
Malformed(String),
}
impl CanonicalCborError {
pub const MALFORMED_CBOR: &'static str = "MALFORMED_CBOR";
#[must_use]
pub const fn code(&self) -> &'static str {
match self {
CanonicalCborError::IndefiniteLength(_) | CanonicalCborError::Malformed(_) => {
Self::MALFORMED_CBOR
}
}
}
}
/// In-memory CBOR data item, limited to the kinds the canonical encoder and
/// strict decoder in this file handle (no tags, floats or `undefined`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CborValue {
/// Major type 0: a non-negative integer.
Unsigned(u64),
/// Major type 1: the stored argument `m` stands for the integer `-1 - m`,
/// so `Negative(0)` is `-1` and `Negative(9)` is `-10`.
Negative(u64),
/// Major type 2: a byte string.
Bytes(Vec<u8>),
/// Major type 3: a text string.
Text(String),
/// Major type 4: an ordered sequence of items.
Array(Vec<CborValue>),
/// Major type 5: key/value pairs in caller-supplied order. The encoder sorts
/// them by encoded key and rejects duplicate keys.
Map(Vec<(CborValue, CborValue)>),
/// Major type 7, additional information 20 (`false`) or 21 (`true`).
Bool(bool),
/// Major type 7, additional information 22.
Null,
}
impl CborValue {
#[must_use]
pub fn int(n: i64) -> Self {
if n >= 0 {
CborValue::Unsigned(n as u64)
} else {
CborValue::Negative((-(n + 1)) as u64)
}
}
#[must_use]
pub fn text(s: impl Into<String>) -> Self {
CborValue::Text(s.into())
}
#[must_use]
pub fn bytes(b: impl Into<Vec<u8>>) -> Self {
CborValue::Bytes(b.into())
}
}
/// Serializes `value` into a freshly allocated byte vector.
///
/// Integers and lengths use the shortest argument form, and map entries are
/// emitted in the order defined by `compare_encoded_keys`.
///
/// # Errors
/// Returns `CanonicalCborError::Malformed` when any map inside `value`
/// contains two keys with identical encodings.
pub fn encode_canonical_cbor(value: &CborValue) -> Result<Vec<u8>, CanonicalCborError> {
    let mut buffer = Vec::new();
    write_value(value, &mut buffer).map(|()| buffer)
}
/// Appends one initial byte: `major` in the top three bits, `additional` in
/// the low five bits.
fn write_header(out: &mut Vec<u8>, major: u8, additional: u8) {
    let initial_byte = (major << 5) | additional;
    out.push(initial_byte);
}
/// Appends the head of an item: the initial byte for `major` followed by
/// `argument` in the narrowest of the five available widths.
///
/// Arguments up to 23 live in the initial byte itself; larger ones follow it
/// as a big-endian 1-, 2-, 4- or 8-byte integer (additional information
/// 24, 25, 26 or 27).
fn write_type_and_argument(out: &mut Vec<u8>, major: u8, argument: u64) {
    match argument {
        0..=23 => write_header(out, major, argument as u8),
        24..=0xFF => {
            write_header(out, major, 24);
            out.push(argument as u8);
        }
        0x100..=0xFFFF => {
            write_header(out, major, 25);
            out.extend_from_slice(&(argument as u16).to_be_bytes());
        }
        0x1_0000..=0xFFFF_FFFF => {
            write_header(out, major, 26);
            out.extend_from_slice(&(argument as u32).to_be_bytes());
        }
        _ => {
            write_header(out, major, 27);
            out.extend_from_slice(&argument.to_be_bytes());
        }
    }
}
/// Appends the encoding of `value` (and, recursively, of its children) to `out`.
///
/// # Errors
/// Propagates the duplicate-key error raised by `write_map`. Bytes written
/// before the failure remain in `out`.
fn write_value(value: &CborValue, out: &mut Vec<u8>) -> Result<(), CanonicalCborError> {
    match value {
        // The two container kinds are the only fallible ones; return directly.
        CborValue::Map(pairs) => return write_map(pairs, out),
        CborValue::Array(items) => {
            write_type_and_argument(out, 4, items.len() as u64);
            return items.iter().try_for_each(|item| write_value(item, out));
        }
        CborValue::Unsigned(n) => write_type_and_argument(out, 0, *n),
        CborValue::Negative(m) => write_type_and_argument(out, 1, *m),
        CborValue::Bytes(b) => {
            write_type_and_argument(out, 2, b.len() as u64);
            out.extend_from_slice(b);
        }
        CborValue::Text(s) => {
            let utf8 = s.as_bytes();
            write_type_and_argument(out, 3, utf8.len() as u64);
            out.extend_from_slice(utf8);
        }
        CborValue::Bool(flag) => write_header(out, 7, if *flag { 21 } else { 20 }),
        CborValue::Null => write_header(out, 7, 22),
    }
    Ok(())
}
/// Appends a map to `out` with its entries ordered by encoded key.
///
/// Each key is encoded into a scratch buffer first so entries can be sorted
/// with `compare_encoded_keys`. The map head and the entries are only written
/// once the keys are known to be unique.
///
/// # Errors
/// Returns `CanonicalCborError::Malformed` when two keys have identical
/// encodings, and propagates any error from encoding a key or a value.
fn write_map(
    pairs: &[(CborValue, CborValue)],
    out: &mut Vec<u8>,
) -> Result<(), CanonicalCborError> {
    let mut entries: Vec<(Vec<u8>, &CborValue)> = Vec::with_capacity(pairs.len());
    for pair in pairs {
        let mut encoded_key = Vec::new();
        write_value(&pair.0, &mut encoded_key)?;
        entries.push((encoded_key, &pair.1));
    }

    entries.sort_by(|(lhs, _), (rhs, _)| compare_encoded_keys(lhs, rhs));

    // After sorting, equal keys are adjacent, so one pass over neighbours suffices.
    let has_duplicate = entries.windows(2).any(|pair| pair[0].0 == pair[1].0);
    if has_duplicate {
        return Err(CanonicalCborError::Malformed(
            "map contains a duplicate key".to_string(),
        ));
    }

    write_type_and_argument(out, 5, entries.len() as u64);
    entries.into_iter().try_for_each(|(encoded_key, val)| {
        out.extend_from_slice(&encoded_key);
        write_value(val, out)
    })
}
/// Orders two already-encoded map keys by plain bytewise lexicographic
/// comparison; a strict prefix sorts before the longer slice.
///
/// Both the encoder (when sorting) and the strict decoder (when validating)
/// go through this function, so they always agree on the order.
// NOTE(review): this is not a length-first ordering. A longer encoding can sort
// before a shorter one when its first byte is smaller, e.g. `18 18` before `20`.
// Confirm this is the ordering the record format specifies.
fn compare_encoded_keys(a: &[u8], b: &[u8]) -> Ordering {
    a.iter().cmp(b.iter())
}
/// Decodes exactly one canonically encoded item from `bytes`.
///
/// # Errors
/// Returns an error when the item violates any rule enforced by the strict
/// decoder, or when bytes remain after the top-level item.
pub fn decode_canonical_cbor(bytes: &[u8]) -> Result<CborValue, CanonicalCborError> {
    let mut cursor = Decoder {
        data: bytes,
        pos: 0,
    };
    let root = cursor.read_value()?;
    if cursor.pos == bytes.len() {
        Ok(root)
    } else {
        Err(CanonicalCborError::Malformed(
            "trailing bytes after the top-level item".to_string(),
        ))
    }
}
/// Read cursor used by the strict (canonical) decoder.
struct Decoder<'a> {
/// The complete input being decoded.
data: &'a [u8],
/// Index into `data` of the next unread byte.
pos: usize,
}
impl Decoder<'_> {
fn next_byte(&mut self) -> Result<u8, CanonicalCborError> {
let b = *self
.data
.get(self.pos)
.ok_or_else(|| CanonicalCborError::Malformed("unexpected end of input".to_string()))?;
self.pos += 1;
Ok(b)
}
fn take(&mut self, len: usize) -> Result<&[u8], CanonicalCborError> {
let end = self.pos.checked_add(len).ok_or_else(|| {
CanonicalCborError::Malformed("length overflows the input".to_string())
})?;
let slice = self
.data
.get(self.pos..end)
.ok_or_else(|| CanonicalCborError::Malformed("truncated item".to_string()))?;
self.pos = end;
Ok(slice)
}
fn read_argument(&mut self, additional: u8) -> Result<u64, CanonicalCborError> {
match additional {
0..=23 => Ok(u64::from(additional)),
24 => {
let v = u64::from(self.next_byte()?);
if v <= 23 {
return Err(CanonicalCborError::Malformed(
"non-shortest integer encoding (1-byte argument < 24)".to_string(),
));
}
Ok(v)
}
25 => {
let raw = self.take(2)?;
let v = u64::from(u16::from_be_bytes([raw[0], raw[1]]));
if v <= u64::from(u8::MAX) {
return Err(CanonicalCborError::Malformed(
"non-shortest integer encoding (2-byte argument fits in fewer bytes)"
.to_string(),
));
}
Ok(v)
}
26 => {
let raw = self.take(4)?;
let v = u64::from(u32::from_be_bytes([raw[0], raw[1], raw[2], raw[3]]));
if v <= u64::from(u16::MAX) {
return Err(CanonicalCborError::Malformed(
"non-shortest integer encoding (4-byte argument fits in fewer bytes)"
.to_string(),
));
}
Ok(v)
}
27 => {
let raw = self.take(8)?;
let v = u64::from_be_bytes([
raw[0], raw[1], raw[2], raw[3], raw[4], raw[5], raw[6], raw[7],
]);
if v <= u64::from(u32::MAX) {
return Err(CanonicalCborError::Malformed(
"non-shortest integer encoding (8-byte argument fits in fewer bytes)"
.to_string(),
));
}
Ok(v)
}
31 => Err(CanonicalCborError::IndefiniteLength(
"indefinite-length item is not canonical".to_string(),
)),
_ => Err(CanonicalCborError::Malformed(
"reserved additional-information value".to_string(),
)),
}
}
fn read_value(&mut self) -> Result<CborValue, CanonicalCborError> {
let initial = self.next_byte()?;
let major = initial >> 5;
let additional = initial & 0x1F;
match major {
0 => Ok(CborValue::Unsigned(self.read_argument(additional)?)),
1 => Ok(CborValue::Negative(self.read_argument(additional)?)),
2 => {
let len = self.read_length(additional)?;
Ok(CborValue::Bytes(self.take(len)?.to_vec()))
}
3 => {
let len = self.read_length(additional)?;
let raw = self.take(len)?;
let s = std::str::from_utf8(raw).map_err(|_| {
CanonicalCborError::Malformed("text string is not valid UTF-8".to_string())
})?;
Ok(CborValue::Text(s.to_string()))
}
4 => {
let len = self.read_length(additional)?;
let mut items = Vec::with_capacity(len.min(self.data.len()));
for _ in 0..len {
items.push(self.read_value()?);
}
Ok(CborValue::Array(items))
}
5 => self.read_map(additional),
6 => Err(CanonicalCborError::Malformed(
"tags are not permitted in canonical CBOR".to_string(),
)),
7 => self.read_simple(additional),
_ => unreachable!("major type is a 3-bit value"),
}
}
fn read_length(&mut self, additional: u8) -> Result<usize, CanonicalCborError> {
let argument = self.read_argument(additional)?;
usize::try_from(argument)
.map_err(|_| CanonicalCborError::Malformed("length exceeds platform usize".to_string()))
}
fn read_map(&mut self, additional: u8) -> Result<CborValue, CanonicalCborError> {
let len = self.read_length(additional)?;
let mut pairs = Vec::with_capacity(len.min(self.data.len()));
let mut prev_key_bytes: Option<&[u8]> = None;
for _ in 0..len {
let key_start = self.pos;
let key = self.read_value()?;
let key_bytes = &self.data[key_start..self.pos];
if let Some(prev) = prev_key_bytes {
match compare_encoded_keys(prev, key_bytes) {
Ordering::Less => {}
Ordering::Equal => {
return Err(CanonicalCborError::Malformed(
"duplicate map key".to_string(),
))
}
Ordering::Greater => {
return Err(CanonicalCborError::Malformed(
"map keys are not in canonical order".to_string(),
))
}
}
}
prev_key_bytes = Some(key_bytes);
let value = self.read_value()?;
pairs.push((key, value));
}
Ok(CborValue::Map(pairs))
}
fn read_simple(&mut self, additional: u8) -> Result<CborValue, CanonicalCborError> {
match additional {
20 => Ok(CborValue::Bool(false)),
21 => Ok(CborValue::Bool(true)),
22 => Ok(CborValue::Null),
23 => Err(CanonicalCborError::Malformed(
"the `undefined` simple value is not valid in a canonical record".to_string(),
)),
24 => Err(CanonicalCborError::Malformed(
"simple values other than false/true/null are not permitted".to_string(),
)),
25..=27 => Err(CanonicalCborError::Malformed(
"floats are not permitted in a canonical record".to_string(),
)),
31 => Err(CanonicalCborError::IndefiniteLength(
"indefinite-length break is not a value".to_string(),
)),
_ => Err(CanonicalCborError::Malformed(
"unassigned or reserved simple value".to_string(),
)),
}
}
}
/// Data item produced by the permissive decoder. It covers the item kinds the
/// strict `CborValue` leaves out: tags, `undefined`, other simple values and
/// floats.
#[derive(Debug, Clone, PartialEq)]
pub enum PermissiveValue {
/// Major type 0: a non-negative integer.
Unsigned(u64),
/// Major type 1: the raw argument as read from the input.
Negative(u64),
/// Major type 2: a byte string (indefinite-length chunks are concatenated).
Bytes(Vec<u8>),
/// Major type 3: a text string (indefinite-length chunks are concatenated).
Text(String),
/// Major type 4: items in input order.
Array(Vec<PermissiveValue>),
/// Major type 5: key/value pairs in input order; neither sorted nor deduplicated.
Map(Vec<(PermissiveValue, PermissiveValue)>),
/// Major type 6: the tag number and the item it wraps.
Tag(u64, Box<PermissiveValue>),
/// Major type 7, additional information 20 (`false`) or 21 (`true`).
Bool(bool),
/// Major type 7, additional information 22.
Null,
/// Major type 7, additional information 23.
Undefined,
/// Any other major-type-7 simple value, identified by its number.
Simple(u8),
/// A 2-, 4- or 8-byte float, widened to `f64`.
Float(f64),
}
/// Decodes exactly one item from `bytes` without enforcing the canonical
/// rules: indefinite lengths, tags, floats, other simple values, non-shortest
/// arguments and arbitrary map key order are all accepted.
///
/// # Errors
/// Returns `CanonicalCborError::Malformed` when the permissive decoder
/// rejects the item, or when bytes remain after the top-level item.
pub fn decode_cbor_permissive(bytes: &[u8]) -> Result<PermissiveValue, CanonicalCborError> {
    let mut cursor = PermissiveDecoder {
        data: bytes,
        pos: 0,
    };
    let root = cursor.read_value()?;
    if cursor.pos == bytes.len() {
        Ok(root)
    } else {
        Err(CanonicalCborError::Malformed(
            "trailing bytes after the top-level item".to_string(),
        ))
    }
}
/// Result of reading one item in the permissive decoder: either a real value
/// or the break marker that ends an indefinite-length array or map.
enum PermissiveItem {
/// A fully decoded data item.
Value(PermissiveValue),
/// Major type 7 with additional information 31.
Break,
}
/// Read cursor used by the permissive decoder.
struct PermissiveDecoder<'a> {
/// The complete input being decoded.
data: &'a [u8],
/// Index into `data` of the next unread byte.
pos: usize,
}
impl PermissiveDecoder<'_> {
    /// Upper bound on how many arrays, maps and tags may be nested inside one
    /// another.
    ///
    /// Decoding recurses once per nesting level. Without a bound, a short
    /// hostile input (a long run of `0x9f` array openers or `0xc1` tags) would
    /// exhaust the call stack and abort the process instead of producing an
    /// error.
    const MAX_NESTING_DEPTH: usize = 128;

    /// Consumes and returns the next input byte.
    ///
    /// # Errors
    /// `Malformed` when the input is exhausted.
    fn next_byte(&mut self) -> Result<u8, CanonicalCborError> {
        let byte = *self
            .data
            .get(self.pos)
            .ok_or_else(|| CanonicalCborError::Malformed("unexpected end of input".to_string()))?;
        self.pos += 1;
        Ok(byte)
    }

    /// Consumes the next `len` bytes and returns them as a slice of the input.
    ///
    /// # Errors
    /// `Malformed` when `pos + len` overflows or runs past the end of the
    /// input; the cursor is left untouched in both cases.
    fn take(&mut self, len: usize) -> Result<&[u8], CanonicalCborError> {
        let end = self.pos.checked_add(len).ok_or_else(|| {
            CanonicalCborError::Malformed("length overflows the input".to_string())
        })?;
        let slice = self
            .data
            .get(self.pos..end)
            .ok_or_else(|| CanonicalCborError::Malformed("truncated item".to_string()))?;
        self.pos = end;
        Ok(slice)
    }

    /// Reads the argument selected by `additional`, accepting any width (no
    /// shortest-form requirement).
    ///
    /// # Errors
    /// `Malformed` on truncated input or when `additional` is 28..=31, none of
    /// which denotes an argument.
    fn read_argument(&mut self, additional: u8) -> Result<u64, CanonicalCborError> {
        match additional {
            0..=23 => Ok(u64::from(additional)),
            24 => Ok(u64::from(self.next_byte()?)),
            25 => {
                let raw = self.take(2)?;
                Ok(u64::from(u16::from_be_bytes([raw[0], raw[1]])))
            }
            26 => {
                let raw = self.take(4)?;
                Ok(u64::from(u32::from_be_bytes([
                    raw[0], raw[1], raw[2], raw[3],
                ])))
            }
            27 => {
                let raw = self.take(8)?;
                Ok(u64::from_be_bytes([
                    raw[0], raw[1], raw[2], raw[3], raw[4], raw[5], raw[6], raw[7],
                ]))
            }
            _ => Err(CanonicalCborError::Malformed(
                "reserved additional-information value".to_string(),
            )),
        }
    }

    /// Reads an argument and converts it to a `usize` length or element count.
    ///
    /// # Errors
    /// Everything `read_argument` reports, plus `Malformed` when the value
    /// does not fit in `usize`.
    fn read_length(&mut self, additional: u8) -> Result<usize, CanonicalCborError> {
        let argument = self.read_argument(additional)?;
        usize::try_from(argument)
            .map_err(|_| CanonicalCborError::Malformed("length exceeds platform usize".to_string()))
    }

    /// Reads one complete value starting at the cursor, treating it as the
    /// outermost item for the purpose of the nesting limit.
    ///
    /// # Errors
    /// `Malformed` when the item is rejected, when a break marker appears
    /// where a value is required, or when nesting exceeds
    /// [`Self::MAX_NESTING_DEPTH`].
    fn read_value(&mut self) -> Result<PermissiveValue, CanonicalCborError> {
        self.read_value_at(0)
    }

    /// Reads one value located `depth` levels below the top-level item; a
    /// break marker in this position is an error.
    fn read_value_at(&mut self, depth: usize) -> Result<PermissiveValue, CanonicalCborError> {
        match self.read_item(depth)? {
            PermissiveItem::Value(v) => Ok(v),
            PermissiveItem::Break => Err(CanonicalCborError::Malformed(
                "unexpected indefinite-length break".to_string(),
            )),
        }
    }

    /// Returns the depth at which the children of a container or tag found at
    /// `depth` are decoded, or an error when it would exceed the limit.
    fn child_depth(depth: usize) -> Result<usize, CanonicalCborError> {
        if depth >= Self::MAX_NESTING_DEPTH {
            return Err(CanonicalCborError::Malformed(
                "nesting depth exceeds the decoder limit".to_string(),
            ));
        }
        Ok(depth + 1)
    }

    /// Reads one item, which may be the break marker, and dispatches on its
    /// major type.
    fn read_item(&mut self, depth: usize) -> Result<PermissiveItem, CanonicalCborError> {
        let initial = self.next_byte()?;
        let major = initial >> 5;
        let additional = initial & 0x1F;
        match major {
            0 => Ok(PermissiveItem::Value(PermissiveValue::Unsigned(
                self.read_argument(additional)?,
            ))),
            1 => Ok(PermissiveItem::Value(PermissiveValue::Negative(
                self.read_argument(additional)?,
            ))),
            2 => self.read_byte_or_text(additional, false),
            3 => self.read_byte_or_text(additional, true),
            4 => self.read_array(additional, depth),
            5 => self.read_map(additional, depth),
            6 => {
                // Tags wrap another item, so a chain of tags recurses as well.
                let child_depth = Self::child_depth(depth)?;
                let tag = self.read_argument(additional)?;
                let inner = self.read_value_at(child_depth)?;
                Ok(PermissiveItem::Value(PermissiveValue::Tag(
                    tag,
                    Box::new(inner),
                )))
            }
            7 => self.read_simple(additional),
            _ => unreachable!("major type is a 3-bit value"),
        }
    }

    /// Reads a byte string (`is_text == false`) or text string
    /// (`is_text == true`), in definite or indefinite-length form.
    ///
    /// Indefinite-length strings are a run of definite-length chunks of the
    /// same major type ended by `0xFF`. The chunks are concatenated, and for
    /// text the UTF-8 check runs on the concatenated result.
    ///
    /// # Errors
    /// `Malformed` for a chunk of the wrong major type, a nested
    /// indefinite-length chunk, truncated input or invalid UTF-8.
    fn read_byte_or_text(
        &mut self,
        additional: u8,
        is_text: bool,
    ) -> Result<PermissiveItem, CanonicalCborError> {
        if additional == 31 {
            let expected_major = if is_text { 3 } else { 2 };
            let mut buf = Vec::new();
            loop {
                let initial = self.next_byte()?;
                if initial == 0xFF {
                    break;
                }
                let chunk_major = initial >> 5;
                let chunk_additional = initial & 0x1F;
                if chunk_major != expected_major || chunk_additional == 31 {
                    return Err(CanonicalCborError::Malformed(
                        "invalid chunk inside an indefinite-length string".to_string(),
                    ));
                }
                let len = self.read_length(chunk_additional)?;
                buf.extend_from_slice(self.take(len)?);
            }
            return Ok(PermissiveItem::Value(self.finish_string(buf, is_text)?));
        }
        let len = self.read_length(additional)?;
        let raw = self.take(len)?.to_vec();
        Ok(PermissiveItem::Value(self.finish_string(raw, is_text)?))
    }

    /// Wraps collected string bytes in the matching variant, validating UTF-8
    /// for text.
    fn finish_string(
        &self,
        buf: Vec<u8>,
        is_text: bool,
    ) -> Result<PermissiveValue, CanonicalCborError> {
        if is_text {
            let s = String::from_utf8(buf).map_err(|_| {
                CanonicalCborError::Malformed("text string is not valid UTF-8".to_string())
            })?;
            Ok(PermissiveValue::Text(s))
        } else {
            Ok(PermissiveValue::Bytes(buf))
        }
    }

    /// Reads the elements of an array found at `depth`, either a counted
    /// sequence or, for `additional == 31`, everything up to the break marker.
    fn read_array(
        &mut self,
        additional: u8,
        depth: usize,
    ) -> Result<PermissiveItem, CanonicalCborError> {
        let child_depth = Self::child_depth(depth)?;
        let mut items = Vec::new();
        if additional == 31 {
            loop {
                match self.read_item(child_depth)? {
                    PermissiveItem::Break => break,
                    PermissiveItem::Value(v) => items.push(v),
                }
            }
        } else {
            let len = self.read_length(additional)?;
            for _ in 0..len {
                items.push(self.read_value_at(child_depth)?);
            }
        }
        Ok(PermissiveItem::Value(PermissiveValue::Array(items)))
    }

    /// Reads the entries of a map found at `depth`. In the indefinite-length
    /// form a break marker is only accepted where a key would start.
    fn read_map(
        &mut self,
        additional: u8,
        depth: usize,
    ) -> Result<PermissiveItem, CanonicalCborError> {
        let child_depth = Self::child_depth(depth)?;
        let mut pairs = Vec::new();
        if additional == 31 {
            loop {
                let key = match self.read_item(child_depth)? {
                    PermissiveItem::Break => break,
                    PermissiveItem::Value(v) => v,
                };
                let value = self.read_value_at(child_depth)?;
                pairs.push((key, value));
            }
        } else {
            let len = self.read_length(additional)?;
            for _ in 0..len {
                let key = self.read_value_at(child_depth)?;
                let value = self.read_value_at(child_depth)?;
                pairs.push((key, value));
            }
        }
        Ok(PermissiveItem::Value(PermissiveValue::Map(pairs)))
    }

    /// Interprets the `additional` bits of a major-type-7 initial byte: simple
    /// values, floats of all three widths, and the break marker.
    ///
    /// # Errors
    /// `Malformed` for the reserved values 28..=30, for truncated input, and
    /// for a two-byte simple value below 32. RFC 8949 section 3.3 declares the
    /// sequences `0xf8 0x00..=0x1f` not well-formed, since those values
    /// already have a one-byte form.
    fn read_simple(&mut self, additional: u8) -> Result<PermissiveItem, CanonicalCborError> {
        match additional {
            0..=19 => Ok(PermissiveItem::Value(PermissiveValue::Simple(additional))),
            20 => Ok(PermissiveItem::Value(PermissiveValue::Bool(false))),
            21 => Ok(PermissiveItem::Value(PermissiveValue::Bool(true))),
            22 => Ok(PermissiveItem::Value(PermissiveValue::Null)),
            23 => Ok(PermissiveItem::Value(PermissiveValue::Undefined)),
            24 => {
                let v = self.next_byte()?;
                if v < 32 {
                    return Err(CanonicalCborError::Malformed(
                        "two-byte simple value below 32 is not well-formed".to_string(),
                    ));
                }
                Ok(PermissiveItem::Value(PermissiveValue::Simple(v)))
            }
            25 => {
                let raw = self.take(2)?;
                Ok(PermissiveItem::Value(PermissiveValue::Float(decode_f16(
                    u16::from_be_bytes([raw[0], raw[1]]),
                ))))
            }
            26 => {
                let raw = self.take(4)?;
                Ok(PermissiveItem::Value(PermissiveValue::Float(f64::from(
                    f32::from_be_bytes([raw[0], raw[1], raw[2], raw[3]]),
                ))))
            }
            27 => {
                let raw = self.take(8)?;
                Ok(PermissiveItem::Value(PermissiveValue::Float(
                    f64::from_be_bytes([
                        raw[0], raw[1], raw[2], raw[3], raw[4], raw[5], raw[6], raw[7],
                    ]),
                )))
            }
            31 => Ok(PermissiveItem::Break),
            _ => Err(CanonicalCborError::Malformed(
                "reserved additional-information value".to_string(),
            )),
        }
    }
}
/// Widens a half-precision float, given as its 16 raw bits, to `f64`.
///
/// Bit 15 is the sign, bits 14..=10 the exponent and bits 9..=0 the mantissa.
/// An all-zero exponent is the subnormal range (including signed zero); an
/// all-one exponent is infinity (zero mantissa) or NaN.
fn decode_f16(bits: u16) -> f64 {
    let negative = bits & 0x8000 != 0;
    let exponent = (bits >> 10) & 0x1F;
    let mantissa = f64::from(bits & 0x3FF);

    let magnitude = match exponent {
        0 => mantissa * 2f64.powi(-24),
        0x1F if mantissa == 0.0 => f64::INFINITY,
        0x1F => f64::NAN,
        // Normal numbers: restore the implicit leading one (1024 == 1 << 10).
        biased => (1024.0 + mantissa) * 2f64.powi(i32::from(biased) - 25),
    };

    if negative {
        -magnitude
    } else {
        magnitude
    }
}
// Unit tests for the canonical encoder, the strict decoder and the permissive
// decoder. Expected encodings are compared as lowercase hex strings produced
// by the crate's `hex::encode` helper, which is defined outside this file.
#[cfg(test)]
mod tests {
use super::*;
// `CborValue::int` picks `Unsigned` for non-negative input and `Negative`
// (holding `-1 - n`) otherwise, including the `i64::MIN` extreme.
#[test]
fn int_constructor_dispatches_major_type() {
assert_eq!(CborValue::int(0), CborValue::Unsigned(0));
assert_eq!(CborValue::int(23), CborValue::Unsigned(23));
assert_eq!(CborValue::int(-1), CborValue::Negative(0));
assert_eq!(CborValue::int(-10), CborValue::Negative(9));
assert_eq!(CborValue::int(-100), CborValue::Negative(99));
assert_eq!(
CborValue::int(i64::MIN),
CborValue::Negative(i64::MAX as u64)
);
}
// Values on both sides of the 23/24, 255/256 and 65535/65536 width
// boundaries, plus `u64::MAX`, must each use the narrowest argument form.
// The 32-bit/64-bit boundary itself is not covered here.
#[test]
fn shortest_form_integer_boundaries() {
let cases: [(CborValue, &str); 8] = [
(CborValue::Unsigned(0), "00"),
(CborValue::Unsigned(23), "17"),
(CborValue::Unsigned(24), "1818"),
(CborValue::Unsigned(255), "18ff"),
(CborValue::Unsigned(256), "190100"),
(CborValue::Unsigned(65535), "19ffff"),
(CborValue::Unsigned(65536), "1a00010000"),
(CborValue::Unsigned(u64::MAX), "1bffffffffffffffff"),
];
for (value, expected) in cases {
assert_eq!(
crate::hex::encode(&encode_canonical_cbor(&value).unwrap()),
expected
);
}
}
// The largest negative argument is written as major type 1 with an 8-byte
// argument.
#[test]
fn negative_full_range() {
assert_eq!(
crate::hex::encode(&encode_canonical_cbor(&CborValue::Negative(u64::MAX)).unwrap()),
"3bffffffffffffffff"
);
}
// Two entries whose keys encode identically make the encoder fail.
#[test]
fn encode_rejects_duplicate_keys() {
let map = CborValue::Map(vec![
(CborValue::text("a"), CborValue::Unsigned(1)),
(CborValue::text("a"), CborValue::Unsigned(2)),
]);
let err = encode_canonical_cbor(&map).unwrap_err();
assert_eq!(err.code(), CanonicalCborError::MALFORMED_CBOR);
}
// "kem" (head byte 0x63) is emitted before "aead" (head byte 0x64) even
// though it was inserted second: the shorter text key has the smaller head.
#[test]
fn map_keys_sort_length_first() {
let map = CborValue::Map(vec![
(CborValue::text("aead"), CborValue::Unsigned(1)),
(CborValue::text("kem"), CborValue::Unsigned(2)),
]);
let encoded = encode_canonical_cbor(&map).unwrap();
assert_eq!(crate::hex::encode(&encoded), "a2636b656d02646165616401");
}
// Encode, decode, re-encode: the bytes must be identical. The decoded value
// is not compared with `value` directly because the encoder reorders map
// entries.
#[test]
fn decode_round_trips_a_nested_value() {
let value = CborValue::Map(vec![
(CborValue::text("v"), CborValue::Unsigned(1)),
(
CborValue::text("items"),
CborValue::Array(vec![CborValue::Bytes(vec![0xde, 0xad])]),
),
]);
let bytes = encode_canonical_cbor(&value).unwrap();
let decoded = decode_canonical_cbor(&bytes).unwrap();
assert_eq!(encode_canonical_cbor(&decoded).unwrap(), bytes);
}
// A second item after the top-level one is an error, not silently ignored.
#[test]
fn decode_rejects_trailing_bytes() {
let err = decode_canonical_cbor(&[0x00, 0x00]).unwrap_err();
assert_eq!(err.code(), CanonicalCborError::MALFORMED_CBOR);
}
// `9f 01 02 ff` is an indefinite-length array holding 1 and 2; the
// permissive decoder accepts it.
#[test]
fn permissive_accepts_indefinite_array() {
let decoded = decode_cbor_permissive(&[0x9f, 0x01, 0x02, 0xff]).unwrap();
assert_eq!(
decoded,
PermissiveValue::Array(vec![
PermissiveValue::Unsigned(1),
PermissiveValue::Unsigned(2)
])
);
}
}