1use std::borrow::Cow;
7
8use aws_smithy_types::{BigInteger, Blob, DateTime};
9use minicbor::decode::Error;
10
11use crate::data::Type;
12
13#[derive(Debug, Clone)]
21pub struct Decoder<'b> {
22 decoder: minicbor::Decoder<'b>,
23}
24
25#[derive(Debug)]
28pub struct DeserializeError {
29 #[allow(dead_code)]
30 _inner: Error,
31}
32
33impl std::fmt::Display for DeserializeError {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 self._inner.fmt(f)
36 }
37}
38
39impl std::error::Error for DeserializeError {}
40
41impl DeserializeError {
42 pub(crate) fn new(inner: Error) -> Self {
43 Self { _inner: inner }
44 }
45
46 pub fn unexpected_union_variant(unexpected_type: Type, at: usize) -> Self {
48 Self {
49 _inner: Error::type_mismatch(unexpected_type.into_minicbor_type())
50 .with_message("encountered unexpected union variant; expected end of union")
51 .at(at),
52 }
53 }
54
55 pub fn unknown_union_variant(variant_name: &str, at: usize) -> Self {
57 Self {
58 _inner: Error::message(format!("encountered unknown union variant {variant_name}"))
59 .at(at),
60 }
61 }
62
63 pub fn mixed_union_variants(at: usize) -> Self {
67 Self {
68 _inner: Error::message(
69 "encountered mixed variants in union; expected a single union variant to be set",
70 )
71 .at(at),
72 }
73 }
74
75 pub fn expected_end_of_stream(at: usize) -> Self {
77 Self {
78 _inner: Error::message("encountered additional data; expected end of stream").at(at),
79 }
80 }
81
82 pub fn custom(message: impl Into<Cow<'static, str>>, at: usize) -> Self {
84 Self {
85 _inner: Error::message(message.into()).at(at),
86 }
87 }
88
89 pub fn is_type_mismatch(&self) -> bool {
93 self._inner.is_type_mismatch()
94 }
95}
96
/// Expands, inside an `impl` block whose type has a `decoder` field, into one method per
/// entry that forwards to the same-named-or-mapped method on `self.decoder` and converts
/// its error with `DeserializeError::new`.
///
/// Each entry has the form `wrapper_name => decoder_method(ResultType);` and may be preceded
/// by attributes, including `///` doc comments. Those attributes are attached to the
/// generated method; previously they were matched but silently discarded.
macro_rules! delegate_method {
    ($($(#[$meta:meta])* $wrapper_name:ident => $decoder_name:ident($result_type:ty);)+) => {
        $(
            $(#[$meta])*
            pub fn $wrapper_name(&mut self) -> Result<$result_type, DeserializeError> {
                self.decoder.$decoder_name().map_err(DeserializeError::new)
            }
        )+
    };
}
121
122impl<'b> Decoder<'b> {
123 pub fn new(bytes: &'b [u8]) -> Self {
124 Self {
125 decoder: minicbor::Decoder::new(bytes),
126 }
127 }
128
129 pub fn datatype(&self) -> Result<Type, DeserializeError> {
130 self.decoder
131 .datatype()
132 .map(Type::new)
133 .map_err(DeserializeError::new)
134 }
135
136 delegate_method! {
137 skip => skip(());
139 boolean => bool(bool);
141 byte => i8(i8);
143 short => i16(i16);
145 integer => i32(i32);
147 long => i64(i64);
149 float => f32(f32);
151 double => f64(f64);
153 null => null(());
155 list => array(Option<u64>);
157 map => map(Option<u64>);
159 }
160
161 pub fn position(&self) -> usize {
163 self.decoder.position()
164 }
165
166 pub fn set_position(&mut self, pos: usize) {
168 self.decoder.set_position(pos)
169 }
170
171 pub fn str(&mut self) -> Result<Cow<'b, str>, DeserializeError> {
176 let bookmark = self.decoder.position();
177 match self.decoder.str() {
178 Ok(str_value) => Ok(Cow::Borrowed(str_value)),
179 Err(e) if e.is_type_mismatch() => {
180 self.decoder.set_position(bookmark);
183 Ok(Cow::Owned(self.string()?))
184 }
185 Err(e) => Err(DeserializeError::new(e)),
186 }
187 }
188
189 pub fn string(&mut self) -> Result<String, DeserializeError> {
192 let mut iter = self.decoder.str_iter().map_err(DeserializeError::new)?;
193 let head = iter.next();
194
195 let decoded_string = match head {
196 None => String::new(),
197 Some(head) => {
198 let mut combined_chunks = String::from(head.map_err(DeserializeError::new)?);
199 for chunk in iter {
200 combined_chunks.push_str(chunk.map_err(DeserializeError::new)?);
201 }
202 combined_chunks
203 }
204 };
205
206 Ok(decoded_string)
207 }
208
209 pub fn blob(&mut self) -> Result<Blob, DeserializeError> {
212 let iter = self.decoder.bytes_iter().map_err(DeserializeError::new)?;
213 let parts: Vec<&[u8]> = iter
214 .collect::<Result<_, _>>()
215 .map_err(DeserializeError::new)?;
216
217 Ok(if parts.len() == 1 {
218 Blob::new(parts[0]) } else {
220 Blob::new(parts.concat()) })
222 }
223
224 pub fn timestamp(&mut self) -> Result<DateTime, DeserializeError> {
227 let tag = self.decoder.tag().map_err(DeserializeError::new)?;
228 let timestamp_tag = minicbor::data::Tag::from(minicbor::data::IanaTag::Timestamp);
229
230 if tag != timestamp_tag {
231 Err(DeserializeError::new(Error::message(
232 "expected timestamp tag",
233 )))
234 } else {
235 let epoch_seconds = match self.decoder.datatype().map_err(DeserializeError::new)? {
245 minicbor::data::Type::F16
246 | minicbor::data::Type::F32
247 | minicbor::data::Type::F64 => self.decoder.f64().map_err(DeserializeError::new)?,
248 _ => self.decoder.i64().map_err(DeserializeError::new)? as f64,
249 };
250 let mut result = DateTime::from_secs_f64(epoch_seconds);
251 let subsec_nanos = result.subsec_nanos();
252 result.set_subsec_nanos((subsec_nanos / 1_000_000) * 1_000_000);
253 Ok(result)
254 }
255 }
256
257 pub fn big_integer(&mut self) -> Result<BigInteger, DeserializeError> {
264 use num_bigint::BigInt;
265
266 match self.decoder.datatype().map_err(DeserializeError::new)? {
267 minicbor::data::Type::Tag => {
268 let tag = self.decoder.tag().map_err(DeserializeError::new)?;
269 let bytes = self.decoder.bytes().map_err(DeserializeError::new)?;
270 let n = BigInt::from_bytes_be(num_bigint::Sign::Plus, bytes);
271
272 let value = match tag.as_u64() {
273 2 => n,
274 3 => -n - 1, _ => {
276 return Err(DeserializeError::new(Error::message(
277 "expected CBOR tag 2 (positive bignum) or tag 3 (negative bignum)",
278 )));
279 }
280 };
281 value
282 .to_string()
283 .parse()
284 .map_err(|_| DeserializeError::new(Error::message("invalid bignum value")))
285 }
286 minicbor::data::Type::U8
287 | minicbor::data::Type::U16
288 | minicbor::data::Type::U32
289 | minicbor::data::Type::U64 => {
290 let value = self.decoder.u64().map_err(DeserializeError::new)?;
291 value
292 .to_string()
293 .parse()
294 .map_err(|_| DeserializeError::new(Error::message("invalid integer value")))
295 }
296 minicbor::data::Type::I8
297 | minicbor::data::Type::I16
298 | minicbor::data::Type::I32
299 | minicbor::data::Type::I64 => {
300 let value = self.decoder.i64().map_err(DeserializeError::new)?;
301 value
302 .to_string()
303 .parse()
304 .map_err(|_| DeserializeError::new(Error::message("invalid integer value")))
305 }
306 minicbor::data::Type::Int => {
309 let int_val = self.decoder.int().map_err(DeserializeError::new)?;
310 let value: i128 = int_val.into();
311 BigInt::from(value)
312 .to_string()
313 .parse()
314 .map_err(|_| DeserializeError::new(Error::message("invalid integer value")))
315 }
316 _ => Err(DeserializeError::new(Error::message(
317 "expected CBOR integer or bignum tag",
318 ))),
319 }
320 }
321}
322
323#[allow(dead_code)] #[derive(Debug)]
325pub struct ArrayIter<'a, 'b, T> {
326 inner: minicbor::decode::ArrayIter<'a, 'b, T>,
327}
328
329impl<'b, T: minicbor::Decode<'b, ()>> Iterator for ArrayIter<'_, 'b, T> {
330 type Item = Result<T, DeserializeError>;
331
332 fn next(&mut self) -> Option<Self::Item> {
333 self.inner
334 .next()
335 .map(|opt| opt.map_err(DeserializeError::new))
336 }
337}
338
339#[allow(dead_code)] #[derive(Debug)]
341pub struct MapIter<'a, 'b, K, V> {
342 inner: minicbor::decode::MapIter<'a, 'b, K, V>,
343}
344
345impl<'b, K, V> Iterator for MapIter<'_, 'b, K, V>
346where
347 K: minicbor::Decode<'b, ()>,
348 V: minicbor::Decode<'b, ()>,
349{
350 type Item = Result<(K, V), DeserializeError>;
351
352 fn next(&mut self) -> Option<Self::Item> {
353 self.inner
354 .next()
355 .map(|opt| opt.map_err(DeserializeError::new))
356 }
357}
358
359pub fn set_optional<B, F>(builder: B, decoder: &mut Decoder, f: F) -> Result<B, DeserializeError>
360where
361 F: Fn(B, &mut Decoder) -> Result<B, DeserializeError>,
362{
363 match decoder.datatype()? {
364 crate::data::Type::Null => {
365 decoder.null()?;
366 Ok(builder)
367 }
368 _ => f(builder, decoder),
369 }
370}
371
#[cfg(test)]
mod tests {
    use crate::Decoder;
    use aws_smithy_types::date_time::Format;

    #[test]
    fn test_definite_str_is_cow_borrowed() {
        // 0x6a: definite-length text string of 10 bytes, followed by "thisIsAKey".
        let definite_bytes = [
            0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79,
        ];
        let mut decoder = Decoder::new(&definite_bytes);
        let member = decoder.str().expect("could not decode str");
        assert_eq!(member, "thisIsAKey");
        assert!(matches!(member, std::borrow::Cow::Borrowed(_)));
    }

    #[test]
    fn test_indefinite_str_is_cow_owned() {
        // 0x7f opens an indefinite-length text string with the chunks "this" (0x64),
        // "Is" (0x62), "A" (0x61) and "Key" (0x63); 0xff is the break that closes it.
        let indefinite_bytes = [
            0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65,
            0x79, 0xff,
        ];
        let mut decoder = Decoder::new(&indefinite_bytes);
        let member = decoder.str().expect("could not decode str");
        assert_eq!(member, "thisIsAKey");
        assert!(matches!(member, std::borrow::Cow::Owned(_)));
    }

    #[test]
    fn test_empty_str_works() {
        // 0x60: definite-length text string of length 0.
        let bytes = [0x60];
        let mut decoder = Decoder::new(&bytes);
        let member = decoder.str().expect("could not decode empty str");
        assert_eq!(member, "");
    }

    #[test]
    fn test_empty_indefinite_str_works() {
        // 0x7f immediately followed by the 0xff break: an indefinite-length text string
        // with no chunks at all.
        let bytes = [0x7f, 0xff];
        let mut decoder = Decoder::new(&bytes);
        let member = decoder
            .str()
            .expect("could not decode empty indefinite str");
        assert_eq!(member, "");
        assert!(matches!(member, std::borrow::Cow::Owned(_)));
    }

    #[test]
    fn test_str_rejects_non_string() {
        // 0x01 is the unsigned integer 1, not a text string.
        let bytes = [0x01];
        let mut decoder = Decoder::new(&bytes);
        assert!(decoder.str().is_err());
    }

    #[test]
    fn test_empty_blob_works() {
        // 0x40: definite-length byte string of length 0.
        let bytes = [0x40];
        let mut decoder = Decoder::new(&bytes);
        let member = decoder.blob().expect("could not decode an empty blob");
        assert_eq!(member, aws_smithy_types::Blob::new([]));
    }

    #[test]
    fn test_empty_indefinite_blob_works() {
        // 0x5f immediately followed by the 0xff break: an indefinite-length byte string
        // with no chunks at all.
        let bytes = [0x5f, 0xff];
        let mut decoder = Decoder::new(&bytes);
        let member = decoder
            .blob()
            .expect("could not decode an empty indefinite blob");
        assert_eq!(member, aws_smithy_types::Blob::new([]));
    }

    #[test]
    fn test_indefinite_length_blob() {
        // 0x5f opens an indefinite-length byte string with three chunks:
        // "indefinite-byte," (0x50, 16 bytes), " chunked," (0x49, 9 bytes) and
        // " on each comma" (0x4e, 14 bytes); 0xff is the closing break.
        let indefinite_bytes = [
            0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62,
            0x79, 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c,
            0x4e, 0x20, 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d,
            0x61, 0xff,
        ];
        let mut decoder = Decoder::new(&indefinite_bytes);
        let member = decoder.blob().expect("could not decode blob");
        assert_eq!(
            member,
            aws_smithy_types::Blob::new("indefinite-byte, chunked, on each comma".as_bytes())
        );
    }

    #[test]
    fn test_timestamp_should_be_truncated_to_fit_millisecond_precision() {
        // 0xc1: tag 1 (epoch-based date/time) wrapping an f64 (0xfb) of roughly
        // 946845296.123... seconds. The trailing 0xff is left unread.
        let bytes = [
            0xc1, 0xfb, 0x41, 0xcc, 0x37, 0xdb, 0x38, 0x0f, 0xbe, 0x77, 0xff,
        ];
        let mut decoder = Decoder::new(&bytes);
        let timestamp = decoder.timestamp().expect("should decode timestamp");
        assert_eq!(
            timestamp,
            aws_smithy_types::date_time::DateTime::from_str(
                "2000-01-02T20:34:56.123Z",
                Format::DateTime
            )
            .unwrap()
        );
    }

    #[test]
    fn test_timestamp_rejects_non_timestamp_tag() {
        // 0xc0 is tag 0 (date/time string), wrapping an empty text string (0x60).
        let bytes = [0xc0, 0x60];
        let mut decoder = Decoder::new(&bytes);
        assert!(decoder.timestamp().is_err());
    }

    #[test]
    fn big_integer_round_trip_positive() {
        for value in ["0", "1", "23", "256", "65535", "18446744073709551615"] {
            let mut encoder = crate::Encoder::new(Vec::new());
            encoder.big_integer(&value.parse().unwrap());
            let bytes = encoder.into_writer();
            let mut decoder = Decoder::new(&bytes);
            let result = decoder.big_integer().expect("should decode");
            assert_eq!(result.as_ref(), value, "round-trip failed for {value}");
        }
    }

    #[test]
    fn big_integer_round_trip_negative() {
        for value in ["-1", "-42", "-256", "-18446744073709551616"] {
            let mut encoder = crate::Encoder::new(Vec::new());
            encoder.big_integer(&value.parse().unwrap());
            let bytes = encoder.into_writer();
            let mut decoder = Decoder::new(&bytes);
            let result = decoder.big_integer().expect("should decode");
            assert_eq!(result.as_ref(), value, "round-trip failed for {value}");
        }
    }

    #[test]
    fn big_integer_round_trip_large() {
        let large_pos = "123456789012345678901234567890";
        let large_neg = "-123456789012345678901234567890";
        for value in [large_pos, large_neg] {
            let mut encoder = crate::Encoder::new(Vec::new());
            encoder.big_integer(&value.parse().unwrap());
            let bytes = encoder.into_writer();
            let mut decoder = Decoder::new(&bytes);
            let result = decoder.big_integer().expect("should decode");
            assert_eq!(result.as_ref(), value, "round-trip failed for {value}");
        }
    }

    #[test]
    fn big_integer_rfc8949_appendix_a_positive() {
        // Tag 2 (0xc2) wrapping a 9-byte string (0x49) 0x01 0x00 x8, i.e. 2^64.
        let bytes = [
            0xc2, 0x49, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        let mut decoder = Decoder::new(&bytes);
        let result = decoder.big_integer().expect("should decode");
        assert_eq!(result.as_ref(), "18446744073709551616");
    }

    #[test]
    fn big_integer_rfc8949_appendix_a_negative() {
        // Tag 3 (0xc3) wrapping the same 2^64 magnitude, i.e. -1 - 2^64.
        let bytes = [
            0xc3, 0x49, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        let mut decoder = Decoder::new(&bytes);
        let result = decoder.big_integer().expect("should decode");
        assert_eq!(result.as_ref(), "-18446744073709551617");
    }

    #[test]
    fn big_integer_from_plain_cbor_unsigned() {
        let mut enc = minicbor::Encoder::new(Vec::new());
        enc.u64(9999).unwrap();
        let bytes = enc.into_writer();
        let mut decoder = Decoder::new(&bytes);
        let result = decoder.big_integer().expect("should decode plain integer");
        assert_eq!(result.as_ref(), "9999");
    }

    #[test]
    fn big_integer_from_plain_cbor_negative() {
        let mut enc = minicbor::Encoder::new(Vec::new());
        enc.i64(-500).unwrap();
        let bytes = enc.into_writer();
        let mut decoder = Decoder::new(&bytes);
        let result = decoder
            .big_integer()
            .expect("should decode negative plain integer");
        assert_eq!(result.as_ref(), "-500");
    }

    #[test]
    fn big_integer_from_plain_cbor_positive_signed() {
        let mut enc = minicbor::Encoder::new(Vec::new());
        // A non-negative value written through the signed encoder API.
        enc.i64(123).unwrap();
        let bytes = enc.into_writer();
        let mut decoder = Decoder::new(&bytes);
        let result = decoder
            .big_integer()
            .expect("should decode positive plain integer");
        assert_eq!(result.as_ref(), "123");
    }

    #[test]
    fn big_integer_tag3_empty_byte_string() {
        // Tag 3 wrapping an empty byte string: magnitude 0, so the value is -1 - 0 = -1.
        let bytes = [0xc3, 0x40];
        let mut decoder = Decoder::new(&bytes);
        let result = decoder.big_integer().expect("should decode");
        assert_eq!(result.as_ref(), "-1");
    }

    #[test]
    fn big_integer_tag2_empty_byte_string() {
        // Tag 2 wrapping an empty byte string: magnitude 0.
        let bytes = [0xc2, 0x40];
        let mut decoder = Decoder::new(&bytes);
        let result = decoder.big_integer().expect("should decode");
        assert_eq!(result.as_ref(), "0");
    }

    #[test]
    fn big_integer_rejects_invalid_tag() {
        // Tag 4 (decimal fraction) wrapping [-2, 27315], i.e. 273.15; not a bignum.
        let bytes = [0xc4, 0x82, 0x21, 0x19, 0x6a, 0xb3];
        let mut decoder = Decoder::new(&bytes);
        assert!(decoder.big_integer().is_err());
    }

    #[test]
    fn big_integer_decode_major_type_1_exceeding_i64() {
        // 0x3b: major type 1 with an 8-byte argument of u64::MAX, i.e. -1 - (2^64 - 1),
        // which is -2^64 and lies below i64::MIN.
        let bytes = [0x3b, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff];
        let mut decoder = Decoder::new(&bytes);
        let result = decoder
            .big_integer()
            .expect("should decode major type 1 > i64::MAX");
        assert_eq!(result.as_ref(), "-18446744073709551616");
    }

    #[test]
    fn test_timestamp_integer_epoch_seconds() {
        // Tag 1 (0xc1) wrapping the u32 (0x1a) 0x6553f100 = 1_700_000_000 seconds.
        let bytes = [0xc1u8, 0x1a, 0x65, 0x53, 0xf1, 0x00];
        let mut decoder = Decoder::new(&bytes);
        let timestamp = decoder
            .timestamp()
            .expect("should decode integer timestamp");
        assert_eq!(timestamp, aws_smithy_types::DateTime::from_secs(1700000000));
    }

    #[test]
    fn test_set_optional_consumes_null_and_keeps_builder() {
        // 0xf6 is CBOR `null`.
        let bytes = [0xf6];
        let mut decoder = Decoder::new(&bytes);
        let result = super::set_optional(7u8, &mut decoder, |_, _| {
            panic!("closure must not run for a null value")
        })
        .expect("should skip null");
        assert_eq!(result, 7);
        assert_eq!(decoder.position(), 1);
    }

    #[test]
    fn test_set_optional_decodes_present_value() {
        // 0x01 is the unsigned integer 1.
        let bytes = [0x01];
        let mut decoder = Decoder::new(&bytes);
        let result = super::set_optional(None::<i32>, &mut decoder, |_, d| {
            Ok(Some(d.integer()?))
        })
        .expect("should decode present value");
        assert_eq!(result, Some(1));
    }
}