use crate::{CborError, CborResult};
use candid::Principal;
use nom::{
bytes::complete::take,
combinator::{eof, map, peek},
error::{Error, ErrorKind},
multi::{count, fold_many_m_n},
number::complete::{be_u16, be_u32, be_u64, be_u8},
sequence::terminated,
Err, IResult,
};
use std::{collections::HashMap, fmt};
/// A decoded CBOR negative integer (major type 1), tagged with the signed
/// width it is stored in.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CborNegativeInt {
    /// Value stored as an `i8`.
    Int8(i8),
    /// Value stored as an `i16`.
    Int16(i16),
    /// Value stored as an `i32`.
    Int32(i32),
    /// Value stored as an `i64`.
    Int64(i64),
}
/// A decoded CBOR argument / unsigned integer, tagged with the width in which
/// it was encoded (inline and 1-byte arguments use `UInt8`, 2-byte `UInt16`,
/// 4-byte `UInt32`, 8-byte `UInt64`).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CborUnsignedInt {
    /// Inline (0–23) or 1-byte argument.
    UInt8(u8),
    /// 2-byte big-endian argument.
    UInt16(u16),
    /// 4-byte big-endian argument.
    UInt32(u32),
    /// 8-byte big-endian argument.
    UInt64(u64),
}
impl CborUnsignedInt {
fn to_usize(self) -> usize {
match self {
CborUnsignedInt::UInt8(v) => v as usize,
CborUnsignedInt::UInt16(v) => v as usize,
CborUnsignedInt::UInt32(v) => v as usize,
CborUnsignedInt::UInt64(v) => v as usize,
}
}
fn to_negative(self) -> CborNegativeInt {
match self {
CborUnsignedInt::UInt8(n) => CborNegativeInt::Int8(-1 - (n as i8)),
CborUnsignedInt::UInt16(n) => CborNegativeInt::Int16(-1 - (n as i16)),
CborUnsignedInt::UInt32(n) => CborNegativeInt::Int32(-1 - (n as i32)),
CborUnsignedInt::UInt64(n) => CborNegativeInt::Int64(-1 - (n as i64)),
}
}
fn to_u8(self) -> Result<u8, String> {
Ok(match self {
CborUnsignedInt::UInt8(n) => n,
_ => return Err(String::from("Expected u8")),
})
}
}
/// Hash-tree node markers. The parser yields these in place of the one-byte
/// unsigned integers 0 (`Empty`) through 4 (`Pruned`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CborHashTree {
    /// Marker for unsigned value 0.
    Empty,
    /// Marker for unsigned value 1.
    Fork,
    /// Marker for unsigned value 2.
    Labelled,
    /// Marker for unsigned value 3.
    Leaf,
    /// Marker for unsigned value 4.
    Pruned,
}
/// A decoded CBOR data item as produced by this module's parser.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CborValue {
    /// Major type 0 (unsigned integer), except values decoded as `HashTree` markers.
    Unsigned(CborUnsignedInt),
    /// Major type 1 (negative integer).
    Signed(CborNegativeInt),
    /// Major type 2 (byte string) or 3 (text string), kept as raw bytes.
    ByteString(Vec<u8>),
    /// Major type 4 (array).
    Array(Vec<CborValue>),
    /// Major type 5 (map); keys are decoded as UTF-8 strings.
    Map(HashMap<String, CborValue>),
    /// A hash-tree node marker (see `CborHashTree`).
    HashTree(CborHashTree),
}
/// Renders a `CborValue` exactly as its `Debug` form. The formatter (and its
/// flags, e.g. `{:#}` for pretty output) is forwarded unchanged.
impl fmt::Display for CborValue {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        <Self as fmt::Debug>::fmt(self, f)
    }
}
/// Returns the CBOR major type: the top three bits of an item's initial byte.
const fn get_cbor_type(e: u8) -> u8 {
    e >> 5
}
fn extract_cbor_type(i: &[u8]) -> IResult<&[u8], u8> {
map(be_u8, get_cbor_type)(i)
}
/// Looks ahead at the next item's CBOR major type without consuming any input.
fn peek_cbor_type(i: &[u8]) -> IResult<&[u8], u8> {
    let mut look_ahead = peek(extract_cbor_type);
    look_ahead(i)
}
/// Returns the "additional information" field: the low five bits of an item's
/// initial byte.
const fn get_cbor_info(e: u8) -> u8 {
    e & 0x1F
}
fn extract_cbor_info(i: &[u8]) -> IResult<&[u8], u8> {
map(be_u8, get_cbor_info)(i)
}
fn extract_cbor_value(i: &[u8]) -> IResult<&[u8], CborUnsignedInt> {
let (i, cbor_info) = extract_cbor_info(i)?;
match cbor_info {
_n @ 0..=23 => Ok((i, CborUnsignedInt::UInt8(cbor_info))),
24 => map(be_u8, CborUnsignedInt::UInt8)(i),
25 => map(be_u16, CborUnsignedInt::UInt16)(i),
26 => map(be_u32, CborUnsignedInt::UInt32)(i),
27 => map(be_u64, CborUnsignedInt::UInt64)(i),
_ => Err(Err::Error(Error::new(i, ErrorKind::Alt))),
}
}
fn extract_key_val_pair(i: &[u8]) -> IResult<&[u8], (String, CborValue)> {
let (i, key) = parser(i)?;
let key = match key {
CborValue::ByteString(byte_string) => match String::from_utf8(byte_string) {
Ok(str) => Ok(str),
_ => Err(Err::Error(Error::new(i, ErrorKind::Alt))),
},
_ => Err(Err::Error(Error::new(i, ErrorKind::Alt))),
}?;
let (i, val) = parser(i)?;
Ok((i, (key, val)))
}
fn parser(i: &[u8]) -> IResult<&[u8], CborValue> {
let (i, cbor_type) = peek_cbor_type(i)?;
let (i, cbor_value) = extract_cbor_value(i)?;
return match cbor_type {
0 => {
Ok((
i,
match cbor_value.to_u8() {
Ok(0) => CborValue::HashTree(CborHashTree::Empty),
Ok(1) => CborValue::HashTree(CborHashTree::Fork),
Ok(2) => CborValue::HashTree(CborHashTree::Labelled),
Ok(3) => CborValue::HashTree(CborHashTree::Leaf),
Ok(4) => CborValue::HashTree(CborHashTree::Pruned),
_ => CborValue::Unsigned(cbor_value),
},
))
}
1 => Ok((i, CborValue::Signed(cbor_value.to_negative()))),
2 | 3 => {
let data_len = cbor_value.to_usize();
let (i, data) = take(data_len)(i)?;
Ok((i, CborValue::ByteString(data.to_vec())))
}
4 => {
let data_len = cbor_value.to_usize();
let (i, data) = count(parser, data_len)(i)?;
Ok((i, CborValue::Array(data)))
}
5 => {
let data_len = cbor_value.to_usize();
let (i, data) = fold_many_m_n(
0,
data_len,
extract_key_val_pair,
|| HashMap::with_capacity(data_len),
|mut acc, (key, val)| {
acc.insert(key, val);
acc
},
)(i)?;
Ok((i, CborValue::Map(data)))
}
6 => parser(i),
7 => parser(i),
_ => Err(Err::Error(Error::new(i, ErrorKind::Alt))),
};
}
/// Parses `i` as exactly one CBOR data item.
///
/// # Errors
/// Returns the `nom` error if the item is malformed or if any bytes remain
/// after it.
pub fn parse_cbor(i: &[u8]) -> Result<CborValue, nom::Err<Error<&[u8]>>> {
    terminated(parser, eof)(i).map(|(_rest, value)| value)
}
/// Parses a CBOR array of principal pairs, each encoded as an array whose
/// first two items are the principals' raw bytes.
///
/// Items after the first two in each pair-array are ignored.
///
/// # Errors
/// - `CborError::MalformedCbor` if `i` is not a single well-formed CBOR item.
/// - `CborError::MalformedCborCanisterRanges` if the top-level item is not an
///   array, an entry is not an array starting with two byte strings, or a
///   byte string is not a valid principal (e.g. too long).
pub fn parse_cbor_principals_array(i: &[u8]) -> CborResult<Vec<(Principal, Principal)>> {
    let parsed_cbor = parse_cbor(i).map_err(|e| CborError::MalformedCbor(e.to_string()))?;
    let CborValue::Array(ranges_entries) = parsed_cbor else {
        return Err(CborError::MalformedCborCanisterRanges);
    };
    ranges_entries
        .iter()
        .map(|ranges_entry| {
            let CborValue::Array(range) = ranges_entry else {
                return Err(CborError::MalformedCborCanisterRanges);
            };
            let (first_principal, second_principal) = match (range.first(), range.get(1)) {
                (Some(CborValue::ByteString(a)), Some(CborValue::ByteString(b))) => (a, b),
                _ => return Err(CborError::MalformedCborCanisterRanges),
            };
            // `Principal::from_slice` panics when the slice exceeds the maximum
            // principal length; these bytes come from the input, so use the
            // fallible constructor and report a malformed range instead.
            let first = Principal::try_from_slice(first_principal)
                .map_err(|_| CborError::MalformedCborCanisterRanges)?;
            let second = Principal::try_from_slice(second_principal)
                .map_err(|_| CborError::MalformedCborCanisterRanges)?;
            Ok((first, second))
        })
        .collect::<Result<_, _>>()
}
/// Parses a CBOR array whose items are all strings (byte or text strings) and
/// returns them as UTF-8 `String`s.
///
/// # Errors
/// - `CborError::MalformedCbor` if `i` is not a single well-formed CBOR item.
/// - `CborError::UnexpectedCborNodeType` if the top-level item is not an array
///   (`expected_type: "Array"`), or if an item is not a string
///   (`expected_type: "ByteString"`).
/// - `CborError::Utf8ConversionError` if an item is not valid UTF-8.
pub fn parse_cbor_string_array(i: &[u8]) -> CborResult<Vec<String>> {
    let parsed_cbor = parse_cbor(i).map_err(|e| CborError::MalformedCbor(e.to_string()))?;
    let CborValue::Array(elems) = parsed_cbor else {
        return Err(CborError::UnexpectedCborNodeType {
            expected_type: "Array".into(),
            found_type: parsed_cbor.to_string(),
        });
    };
    // The array is owned, so move each byte buffer out instead of cloning it.
    elems
        .into_iter()
        .map(|elem| {
            let CborValue::ByteString(bytes) = elem else {
                // Items (unlike the outer node) are expected to be strings.
                return Err(CborError::UnexpectedCborNodeType {
                    expected_type: "ByteString".into(),
                    found_type: elem.to_string(),
                });
            };
            String::from_utf8(bytes).map_err(CborError::Utf8ConversionError)
        })
        .collect::<Result<_, _>>()
}
#[cfg(test)]
mod tests {
    use super::*;
    use ic_response_verification_test_utils::{cbor_encode, hex_decode};

    /// Text whose raw bytes are used as principal contents in the range tests.
    const PRINCIPAL_TEXT: &str = "rdmx6-jaaaa-aaaaa-aaadq-cai";

    /// Hex-decodes `hex` and parses it as CBOR, panicking if parsing fails.
    fn parse_hex(hex: &str) -> CborValue {
        let bytes = hex_decode(hex);
        parse_cbor(bytes.as_slice()).unwrap()
    }

    /// Expected value for a small (one-byte) unsigned integer.
    fn small_uint(n: u8) -> CborValue {
        CborValue::Unsigned(CborUnsignedInt::UInt8(n))
    }

    /// Expected value for a string item holding the bytes of `s`.
    fn bytes_of(s: &str) -> CborValue {
        CborValue::ByteString(s.as_bytes().to_vec())
    }

    /// Principal built from the raw bytes of `PRINCIPAL_TEXT`.
    fn test_principal() -> Principal {
        Principal::from_slice(PRINCIPAL_TEXT.as_bytes())
    }

    #[test]
    fn decodes_arrays() {
        // [7, 8, 9]
        let expected = CborValue::Array(vec![small_uint(7), small_uint(8), small_uint(9)]);
        assert_eq!(parse_hex("83070809"), expected);
    }

    #[test]
    fn decodes_nested_arrays() {
        // [7, [8, 9], [10, 11]]
        let expected = CborValue::Array(vec![
            small_uint(7),
            CborValue::Array(vec![small_uint(8), small_uint(9)]),
            CborValue::Array(vec![small_uint(10), small_uint(11)]),
        ]);
        assert_eq!(parse_hex("8307820809820A0B"), expected);
    }

    #[test]
    fn decodes_array_with_nested_map() {
        // ["a", {"b": "c"}]
        let nested = HashMap::from([(String::from("b"), bytes_of("c"))]);
        let expected = CborValue::Array(vec![bytes_of("a"), CborValue::Map(nested)]);
        assert_eq!(parse_hex("826161a161626163"), expected);
    }

    #[test]
    fn decodes_map_with_nested_array() {
        // {"a": 7, "b": [8, 9]}
        let expected = CborValue::Map(HashMap::from([
            (String::from("a"), small_uint(7)),
            (
                String::from("b"),
                CborValue::Array(vec![small_uint(8), small_uint(9)]),
            ),
        ]));
        assert_eq!(parse_hex("A26161076162820809"), expected);
    }

    #[test]
    fn can_parse_cbor_principals_array() {
        let encoded = cbor_encode(&vec![(test_principal(), test_principal())]);
        assert_eq!(
            parse_cbor_principals_array(&encoded).unwrap(),
            vec![(test_principal(), test_principal())],
        );
    }

    #[test]
    fn fails_to_parse_cbor_principals_array() {
        // Pairs of raw `&[u8]` slices (rather than `Principal`s) must be
        // rejected as malformed ranges.
        let raw = PRINCIPAL_TEXT.as_bytes();
        let encoded = cbor_encode(&vec![(raw, raw)]);
        let outcome = parse_cbor_principals_array(&encoded);
        assert!(matches!(
            outcome.err(),
            Some(CborError::MalformedCborCanisterRanges),
        ));
    }
}