1use crate::{CborError, CborResult};
2use candid::Principal;
3use nom::{
4 bytes::complete::take,
5 combinator::{eof, map, peek},
6 error::{Error, ErrorKind},
7 multi::{count, fold_many_m_n},
8 number::complete::{be_u16, be_u32, be_u64, be_u8},
9 sequence::terminated,
10 Err, IResult,
11};
12use std::{collections::HashMap, fmt};
13
/// A negative integer decoded from a CBOR major-type-1 item.
///
/// The value lives in the narrowest signed variant, at or above the width the
/// argument was encoded with, that can hold it (see `CborUnsignedInt::to_negative`).
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum CborNegativeInt {
    Int8(i8),
    Int16(i16),
    Int32(i32),
    Int64(i64),
}

/// An unsigned integer read from a CBOR item header, tagged with the width
/// (8/16/32/64 bits) it was stored in.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum CborUnsignedInt {
    UInt8(u8),
    UInt16(u16),
    UInt32(u32),
    UInt64(u64),
}

impl CborUnsignedInt {
    /// Converts the value to a `usize` length/count.
    ///
    /// Values that do not fit in `usize` (a u64 on a 32-bit target such as wasm32)
    /// saturate to `usize::MAX` instead of being truncated. Truncation could make a
    /// huge declared length alias a small one.
    fn to_usize(self) -> usize {
        match self {
            CborUnsignedInt::UInt8(v) => usize::from(v),
            CborUnsignedInt::UInt16(v) => usize::from(v),
            CborUnsignedInt::UInt32(v) => usize::try_from(v).unwrap_or(usize::MAX),
            CborUnsignedInt::UInt64(v) => usize::try_from(v).unwrap_or(usize::MAX),
        }
    }

    /// Interprets `self` as the argument `n` of a negative-integer item and returns
    /// the encoded value `-1 - n`.
    ///
    /// `n` may exceed the signed range of its own width (e.g. a u8 of 200 encodes
    /// -201). In that case the result is widened into the next signed variant
    /// instead of wrapping through an `as` cast.
    ///
    /// A `UInt64` argument above `i64::MAX` encodes a value below `i64::MIN`, which no
    /// variant can hold. It saturates to `i64::MIN`, so callers that need exactness
    /// must reject such arguments first.
    fn to_negative(self) -> CborNegativeInt {
        match self {
            CborUnsignedInt::UInt8(n) => match i8::try_from(n) {
                Ok(signed) => CborNegativeInt::Int8(-1 - signed),
                Err(_) => CborNegativeInt::Int16(-1 - i16::from(n)),
            },
            CborUnsignedInt::UInt16(n) => match i16::try_from(n) {
                Ok(signed) => CborNegativeInt::Int16(-1 - signed),
                Err(_) => CborNegativeInt::Int32(-1 - i32::from(n)),
            },
            CborUnsignedInt::UInt32(n) => match i32::try_from(n) {
                Ok(signed) => CborNegativeInt::Int32(-1 - signed),
                Err(_) => CborNegativeInt::Int64(-1 - i64::from(n)),
            },
            CborUnsignedInt::UInt64(n) => {
                CborNegativeInt::Int64(i64::try_from(n).map_or(i64::MIN, |signed| -1 - signed))
            }
        }
    }

    /// Returns the value only if it is stored as `UInt8`.
    ///
    /// A small value held in a wider variant is still an error.
    fn to_u8(self) -> Result<u8, String> {
        Ok(match self {
            CborUnsignedInt::UInt8(n) => n,
            _ => return Err(String::from("Expected u8")),
        })
    }
}
61
/// Hash-tree node tags.
///
/// The parser maps the unsigned integers 0 through 4 onto these variants in
/// declaration order.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum CborHashTree {
    /// Tag `0`.
    Empty,
    /// Tag `1`.
    Fork,
    /// Tag `2`.
    Labelled,
    /// Tag `3`.
    Leaf,
    /// Tag `4`.
    Pruned,
}
70
71#[derive(Debug, Clone, Eq, PartialEq)]
72pub enum CborValue {
73 Unsigned(CborUnsignedInt),
74 Signed(CborNegativeInt),
75 ByteString(Vec<u8>),
76 Array(Vec<CborValue>),
77 Map(HashMap<String, CborValue>),
78 HashTree(CborHashTree),
79}
80
81impl fmt::Display for CborValue {
82 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
83 fmt::Debug::fmt(self, f)
84 }
85}
86
/// Returns the CBOR major type of an initial byte: its top three bits (0..=7).
const fn get_cbor_type(e: u8) -> u8 {
    e >> 5
}
91
92fn extract_cbor_type(i: &[u8]) -> IResult<&[u8], u8> {
93 map(be_u8, get_cbor_type)(i)
94}
95
96fn peek_cbor_type(i: &[u8]) -> IResult<&[u8], u8> {
97 peek(extract_cbor_type)(i)
98}
99
/// Returns the "additional information" field of an initial byte: its low five
/// bits (0..=31).
const fn get_cbor_info(e: u8) -> u8 {
    e & 0x1F
}
106
107fn extract_cbor_info(i: &[u8]) -> IResult<&[u8], u8> {
108 map(be_u8, get_cbor_info)(i)
109}
110
111fn extract_cbor_value(i: &[u8]) -> IResult<&[u8], CborUnsignedInt> {
112 let (i, cbor_info) = extract_cbor_info(i)?;
113
114 match cbor_info {
115 _n @ 0..=23 => Ok((i, CborUnsignedInt::UInt8(cbor_info))),
116 24 => map(be_u8, CborUnsignedInt::UInt8)(i),
117 25 => map(be_u16, CborUnsignedInt::UInt16)(i),
118 26 => map(be_u32, CborUnsignedInt::UInt32)(i),
119 27 => map(be_u64, CborUnsignedInt::UInt64)(i),
120 _ => Err(Err::Error(Error::new(i, ErrorKind::Alt))),
121 }
122}
123
124fn extract_key_val_pair(i: &[u8]) -> IResult<&[u8], (String, CborValue)> {
125 let (i, key) = parser(i)?;
126
127 let key = match key {
128 CborValue::ByteString(byte_string) => match String::from_utf8(byte_string) {
129 Ok(str) => Ok(str),
130 _ => Err(Err::Error(Error::new(i, ErrorKind::Alt))),
131 },
132 _ => Err(Err::Error(Error::new(i, ErrorKind::Alt))),
133 }?;
134
135 let (i, val) = parser(i)?;
136
137 Ok((i, (key, val)))
138}
139
140fn parser(i: &[u8]) -> IResult<&[u8], CborValue> {
141 let (i, cbor_type) = peek_cbor_type(i)?;
142 let (i, cbor_value) = extract_cbor_value(i)?;
143
144 return match cbor_type {
145 0 => {
146 Ok((
149 i,
150 match cbor_value.to_u8() {
151 Ok(0) => CborValue::HashTree(CborHashTree::Empty),
152 Ok(1) => CborValue::HashTree(CborHashTree::Fork),
153 Ok(2) => CborValue::HashTree(CborHashTree::Labelled),
154 Ok(3) => CborValue::HashTree(CborHashTree::Leaf),
155 Ok(4) => CborValue::HashTree(CborHashTree::Pruned),
156 _ => CborValue::Unsigned(cbor_value),
157 },
158 ))
159 }
160
161 1 => Ok((i, CborValue::Signed(cbor_value.to_negative()))),
162
163 2 | 3 => {
164 let data_len = cbor_value.to_usize();
165 let (i, data) = take(data_len)(i)?;
166
167 Ok((i, CborValue::ByteString(data.to_vec())))
168 }
169
170 4 => {
171 let data_len = cbor_value.to_usize();
172 let (i, data) = count(parser, data_len)(i)?;
173
174 Ok((i, CborValue::Array(data)))
175 }
176
177 5 => {
178 let data_len = cbor_value.to_usize();
179 let (i, data) = fold_many_m_n(
180 0,
181 data_len,
182 extract_key_val_pair,
183 || HashMap::with_capacity(data_len),
184 |mut acc, (key, val)| {
185 acc.insert(key, val);
186 acc
187 },
188 )(i)?;
189
190 Ok((i, CborValue::Map(data)))
191 }
192
193 6 => parser(i),
195 7 => parser(i),
196
197 _ => Err(Err::Error(Error::new(i, ErrorKind::Alt))),
198 };
199}
200
201pub fn parse_cbor(i: &[u8]) -> Result<CborValue, nom::Err<Error<&[u8]>>> {
202 let (_remaining, result) = terminated(parser, eof)(i)?;
203
204 Ok(result)
205}
206
207pub fn parse_cbor_principals_array(i: &[u8]) -> CborResult<Vec<(Principal, Principal)>> {
208 let parsed_cbor = parse_cbor(i).map_err(|e| CborError::MalformedCbor(e.to_string()))?;
209
210 let CborValue::Array(ranges_entries) = parsed_cbor else {
211 return Err(CborError::MalformedCborCanisterRanges);
212 };
213
214 ranges_entries
215 .iter()
216 .map(|ranges_entry| {
217 let CborValue::Array(range) = ranges_entry else {
218 return Err(CborError::MalformedCborCanisterRanges);
219 };
220
221 let (first_principal, second_principal) = match (range.first(), range.get(1)) {
222 (Some(CborValue::ByteString(a)), Some(CborValue::ByteString(b))) => (a, b),
223 _ => return Err(CborError::MalformedCborCanisterRanges),
224 };
225
226 Ok((
227 Principal::from_slice(first_principal),
228 Principal::from_slice(second_principal),
229 ))
230 })
231 .collect::<Result<_, _>>()
232}
233
234pub fn parse_cbor_string_array(i: &[u8]) -> CborResult<Vec<String>> {
235 let parsed_cbor = parse_cbor(i).map_err(|e| CborError::MalformedCbor(e.to_string()))?;
236
237 let CborValue::Array(elems) = parsed_cbor else {
238 return Err(CborError::UnexpectedCborNodeType {
239 expected_type: "Array".into(),
240 found_type: parsed_cbor.to_string(),
241 });
242 };
243
244 elems
245 .iter()
246 .map(|elem| {
247 let CborValue::ByteString(elem) = elem else {
248 return Err(CborError::UnexpectedCborNodeType {
249 expected_type: "ByteString".into(),
250 found_type: elem.to_string(),
251 });
252 };
253
254 String::from_utf8(elem.to_owned()).map_err(CborError::Utf8ConversionError)
255 })
256 .collect::<Result<_, _>>()
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use ic_response_verification_test_utils::{cbor_encode, hex_decode};

    #[test]
    fn decodes_arrays() {
        let cbor_hex = "83070809";
        let cbor = hex_decode(cbor_hex);

        let result = parse_cbor(cbor.as_slice()).unwrap();

        assert_eq!(
            result,
            CborValue::Array(vec![
                CborValue::Unsigned(CborUnsignedInt::UInt8(7)),
                CborValue::Unsigned(CborUnsignedInt::UInt8(8)),
                CborValue::Unsigned(CborUnsignedInt::UInt8(9)),
            ])
        );
    }

    #[test]
    fn decodes_nested_arrays() {
        let cbor_hex = "8307820809820A0B";
        let cbor = hex_decode(cbor_hex);

        let result = parse_cbor(cbor.as_slice()).unwrap();

        assert_eq!(
            result,
            CborValue::Array(vec![
                CborValue::Unsigned(CborUnsignedInt::UInt8(7)),
                CborValue::Array(vec![
                    CborValue::Unsigned(CborUnsignedInt::UInt8(8)),
                    CborValue::Unsigned(CborUnsignedInt::UInt8(9)),
                ]),
                CborValue::Array(vec![
                    CborValue::Unsigned(CborUnsignedInt::UInt8(10)),
                    CborValue::Unsigned(CborUnsignedInt::UInt8(11)),
                ]),
            ])
        );
    }

    #[test]
    fn decodes_array_with_nested_map() {
        let cbor_hex = "826161a161626163";
        let cbor = hex_decode(cbor_hex);

        let result = parse_cbor(cbor.as_slice()).unwrap();

        assert_eq!(
            result,
            CborValue::Array(vec![
                CborValue::ByteString(Vec::from("a")),
                CborValue::Map(HashMap::from([(
                    String::from("b"),
                    CborValue::ByteString(Vec::from("c"))
                )])),
            ])
        );
    }

    #[test]
    fn decodes_map_with_nested_array() {
        let cbor_hex = "A26161076162820809";
        let cbor = hex_decode(cbor_hex);

        let result = parse_cbor(cbor.as_slice()).unwrap();

        assert_eq!(
            result,
            CborValue::Map(HashMap::from([
                (
                    String::from("a"),
                    CborValue::Unsigned(CborUnsignedInt::UInt8(7))
                ),
                (
                    String::from("b"),
                    CborValue::Array(vec![
                        CborValue::Unsigned(CborUnsignedInt::UInt8(8)),
                        CborValue::Unsigned(CborUnsignedInt::UInt8(9)),
                    ])
                ),
            ]))
        )
    }

    #[test]
    fn decodes_negative_ints() {
        // 0x20 encodes -1 and 0x37 encodes -24 (argument stored in the initial byte).
        let cbor = hex_decode("822037");

        let result = parse_cbor(cbor.as_slice()).unwrap();

        assert_eq!(
            result,
            CborValue::Array(vec![
                CborValue::Signed(CborNegativeInt::Int8(-1)),
                CborValue::Signed(CborNegativeInt::Int8(-24)),
            ])
        );
    }

    #[test]
    fn decodes_hash_tree_tags() {
        let cbor = hex_decode("850001020304");

        let result = parse_cbor(cbor.as_slice()).unwrap();

        assert_eq!(
            result,
            CborValue::Array(vec![
                CborValue::HashTree(CborHashTree::Empty),
                CborValue::HashTree(CborHashTree::Fork),
                CborValue::HashTree(CborHashTree::Labelled),
                CborValue::HashTree(CborHashTree::Leaf),
                CborValue::HashTree(CborHashTree::Pruned),
            ])
        );
    }

    #[test]
    fn skips_tags() {
        // d9d9f7 is tag 55799 (self-describe CBOR) wrapping the array [7, 8, 9].
        let cbor = hex_decode("d9d9f783070809");

        let result = parse_cbor(cbor.as_slice()).unwrap();

        assert_eq!(
            result,
            CborValue::Array(vec![
                CborValue::Unsigned(CborUnsignedInt::UInt8(7)),
                CborValue::Unsigned(CborUnsignedInt::UInt8(8)),
                CborValue::Unsigned(CborUnsignedInt::UInt8(9)),
            ])
        );
    }

    #[test]
    fn rejects_trailing_bytes() {
        let cbor = hex_decode("0707");

        assert!(parse_cbor(cbor.as_slice()).is_err());
    }

    #[test]
    fn can_parse_cbor_string_array() {
        let cbor = hex_decode("8261616162");

        assert_eq!(
            parse_cbor_string_array(&cbor).unwrap(),
            vec![String::from("a"), String::from("b")],
        );
    }

    #[test]
    fn fails_to_parse_cbor_string_array_from_non_array() {
        let cbor = hex_decode("07");

        assert!(matches!(
            parse_cbor_string_array(&cbor),
            Err(CborError::UnexpectedCborNodeType { .. }),
        ));
    }

    #[test]
    fn can_parse_cbor_principals_array() {
        let expected_cbor = vec![(
            Principal::from_slice("rdmx6-jaaaa-aaaaa-aaadq-cai".as_bytes()),
            Principal::from_slice("rdmx6-jaaaa-aaaaa-aaadq-cai".as_bytes()),
        )];

        assert_eq!(
            parse_cbor_principals_array(&cbor_encode(&expected_cbor)).unwrap(),
            vec![(
                Principal::from_slice("rdmx6-jaaaa-aaaaa-aaadq-cai".as_bytes()),
                Principal::from_slice("rdmx6-jaaaa-aaaaa-aaadq-cai".as_bytes())
            )],
        )
    }

    #[test]
    fn fails_to_parse_cbor_principals_array() {
        let expected_cbor = vec![(
            "rdmx6-jaaaa-aaaaa-aaadq-cai".as_bytes(),
            "rdmx6-jaaaa-aaaaa-aaadq-cai".as_bytes(),
        )];

        assert!(matches!(
            parse_cbor_principals_array(&cbor_encode(&expected_cbor)).err(),
            Some(CborError::MalformedCborCanisterRanges),
        ));
    }
}