Skip to main content

aws_smithy_cbor/
decode.rs

1/*
2 * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6use std::borrow::Cow;
7
8use aws_smithy_types::{BigInteger, Blob, DateTime};
9use minicbor::decode::Error;
10
11use crate::data::Type;
12
13/// Provides functions for decoding a CBOR object with a known schema.
14///
15/// Although CBOR is a self-describing format, this decoder is tailored for cases where the schema
16/// is known in advance. Therefore, the caller can determine which object key exists at the current
17/// position by calling `str` method, and call the relevant function based on the predetermined schema
18/// for that key. If an unexpected key is encountered, the caller can use the `skip` method to skip
19/// over the element.
20#[derive(Debug, Clone)]
21pub struct Decoder<'b> {
22    decoder: minicbor::Decoder<'b>,
23}
24
25/// When any of the decode methods are called they look for that particular data type at the current
26/// position. If the CBOR data tag does not match the type, a `DeserializeError` is returned.
27#[derive(Debug)]
28pub struct DeserializeError {
29    #[allow(dead_code)]
30    _inner: Error,
31}
32
33impl std::fmt::Display for DeserializeError {
34    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35        self._inner.fmt(f)
36    }
37}
38
39impl std::error::Error for DeserializeError {}
40
41impl DeserializeError {
42    pub(crate) fn new(inner: Error) -> Self {
43        Self { _inner: inner }
44    }
45
46    /// More than one union variant was detected: `unexpected_type` was unexpected.
47    pub fn unexpected_union_variant(unexpected_type: Type, at: usize) -> Self {
48        Self {
49            _inner: Error::type_mismatch(unexpected_type.into_minicbor_type())
50                .with_message("encountered unexpected union variant; expected end of union")
51                .at(at),
52        }
53    }
54
55    /// Unknown union variant was detected. Servers reject unknown union varaints.
56    pub fn unknown_union_variant(variant_name: &str, at: usize) -> Self {
57        Self {
58            _inner: Error::message(format!("encountered unknown union variant {variant_name}"))
59                .at(at),
60        }
61    }
62
63    /// More than one union variant was detected, but we never even got to parse the first one.
64    /// We immediately raise this error when detecting a union serialized as a fixed-length CBOR
65    /// map whose length (specified upfront) is a value different than 1.
66    pub fn mixed_union_variants(at: usize) -> Self {
67        Self {
68            _inner: Error::message(
69                "encountered mixed variants in union; expected a single union variant to be set",
70            )
71            .at(at),
72        }
73    }
74
75    /// Expected end of stream but more data is available.
76    pub fn expected_end_of_stream(at: usize) -> Self {
77        Self {
78            _inner: Error::message("encountered additional data; expected end of stream").at(at),
79        }
80    }
81
82    /// Returns a custom error with an offset.
83    pub fn custom(message: impl Into<Cow<'static, str>>, at: usize) -> Self {
84        Self {
85            _inner: Error::message(message.into()).at(at),
86        }
87    }
88
89    /// An unexpected type was encountered.
90    // We handle this one when decoding sparse collections: we have to expect either a `null` or an
91    // item, so we try decoding both.
92    pub fn is_type_mismatch(&self) -> bool {
93        self._inner.is_type_mismatch()
94    }
95}
96
/// Generates decoder wrapper methods that delegate to a method of the same shape on the inner
/// `minicbor` decoder (`self.decoder`) and convert its error with `DeserializeError::new`.
///
/// Each entry has the form `wrapper_name => minicbor_method(ReturnType);` and may be preceded by
/// attributes, including `///` doc comments. Those attributes are forwarded onto the generated
/// `pub fn`, so the documentation written at the invocation site ends up on the method.
///
/// The expansion consists of methods taking `&mut self`, so this must be invoked inside an `impl`
/// block whose `Self` has a `decoder` field.
///
/// # Example
///
/// ```ignore
/// delegate_method! {
///     /// Reads a boolean at the current position.
///     boolean => bool(bool);
///     /// Reads a null CBOR element at the current position.
///     null => null(());
/// }
/// ```
macro_rules! delegate_method {
    ($($(#[$meta:meta])* $wrapper_name:ident => $decoder_method:ident($result_type:ty);)+) => {
        $(
            // Re-emit the captured attributes; without this, docs and `#[cfg]`s are lost.
            $(#[$meta])*
            pub fn $wrapper_name(&mut self) -> Result<$result_type, DeserializeError> {
                self.decoder.$decoder_method().map_err(DeserializeError::new)
            }
        )+
    };
}
121
122impl<'b> Decoder<'b> {
123    pub fn new(bytes: &'b [u8]) -> Self {
124        Self {
125            decoder: minicbor::Decoder::new(bytes),
126        }
127    }
128
129    pub fn datatype(&self) -> Result<Type, DeserializeError> {
130        self.decoder
131            .datatype()
132            .map(Type::new)
133            .map_err(DeserializeError::new)
134    }
135
136    delegate_method! {
137        /// Skips the current CBOR element.
138        skip => skip(());
139        /// Reads a boolean at the current position.
140        boolean => bool(bool);
141        /// Reads a byte at the current position.
142        byte => i8(i8);
143        /// Reads a short at the current position.
144        short => i16(i16);
145        /// Reads a integer at the current position.
146        integer => i32(i32);
147        /// Reads a long at the current position.
148        long => i64(i64);
149        /// Reads a float at the current position.
150        float => f32(f32);
151        /// Reads a double at the current position.
152        double => f64(f64);
153        /// Reads a null CBOR element at the current position.
154        null => null(());
155        /// Returns the number of elements in a definite list. For indefinite lists it returns a `None`.
156        list => array(Option<u64>);
157        /// Returns the number of elements in a definite map. For indefinite map it returns a `None`.
158        map => map(Option<u64>);
159    }
160
161    /// Returns the current position of the buffer, which will be decoded when any of the methods is called.
162    pub fn position(&self) -> usize {
163        self.decoder.position()
164    }
165
166    /// Set the current decode position.
167    pub fn set_position(&mut self, pos: usize) {
168        self.decoder.set_position(pos)
169    }
170
171    /// Returns a `Cow::Borrowed(&str)` if the element at the current position in the buffer is a definite
172    /// length string. Otherwise, it returns a `Cow::Owned(String)` if the element at the current position is an
173    /// indefinite-length string. An error is returned if the element is neither a definite length nor an
174    /// indefinite-length string.
175    pub fn str(&mut self) -> Result<Cow<'b, str>, DeserializeError> {
176        let bookmark = self.decoder.position();
177        match self.decoder.str() {
178            Ok(str_value) => Ok(Cow::Borrowed(str_value)),
179            Err(e) if e.is_type_mismatch() => {
180                // Move the position back to the start of the CBOR element and then try
181                // decoding it as an indefinite length string.
182                self.decoder.set_position(bookmark);
183                Ok(Cow::Owned(self.string()?))
184            }
185            Err(e) => Err(DeserializeError::new(e)),
186        }
187    }
188
189    /// Allocates and returns a `String` if the element at the current position in the buffer is either a
190    /// definite-length or an indefinite-length string. Otherwise, an error is returned if the element is not a string type.
191    pub fn string(&mut self) -> Result<String, DeserializeError> {
192        let mut iter = self.decoder.str_iter().map_err(DeserializeError::new)?;
193        let head = iter.next();
194
195        let decoded_string = match head {
196            None => String::new(),
197            Some(head) => {
198                let mut combined_chunks = String::from(head.map_err(DeserializeError::new)?);
199                for chunk in iter {
200                    combined_chunks.push_str(chunk.map_err(DeserializeError::new)?);
201                }
202                combined_chunks
203            }
204        };
205
206        Ok(decoded_string)
207    }
208
209    /// Returns a `blob` if the element at the current position in the buffer is a byte string. Otherwise,
210    /// a `DeserializeError` error is returned.
211    pub fn blob(&mut self) -> Result<Blob, DeserializeError> {
212        let iter = self.decoder.bytes_iter().map_err(DeserializeError::new)?;
213        let parts: Vec<&[u8]> = iter
214            .collect::<Result<_, _>>()
215            .map_err(DeserializeError::new)?;
216
217        Ok(if parts.len() == 1 {
218            Blob::new(parts[0]) // Directly convert &[u8] to Blob if there's only one part.
219        } else {
220            Blob::new(parts.concat()) // Concatenate all parts into a single Blob.
221        })
222    }
223
224    /// Returns a `DateTime` if the element at the current position in the buffer is a `timestamp`. Otherwise,
225    /// a `DeserializeError` error is returned.
226    pub fn timestamp(&mut self) -> Result<DateTime, DeserializeError> {
227        let tag = self.decoder.tag().map_err(DeserializeError::new)?;
228        let timestamp_tag = minicbor::data::Tag::from(minicbor::data::IanaTag::Timestamp);
229
230        if tag != timestamp_tag {
231            Err(DeserializeError::new(Error::message(
232                "expected timestamp tag",
233            )))
234        } else {
235            // RFC 8949 §3.4.2: tag 1 content MUST be int OR float.
236            // Values that are more granular than millisecond precision SHOULD be truncated to fit
237            // millisecond precision for epoch-seconds:
238            // https://smithy.io/2.0/spec/protocol-traits.html#timestamp-formats
239            //
240            // Without truncation, the `RpcV2CborDateTimeWithFractionalSeconds` protocol test would
241            // fail since the upstream test expect `123000000` in subsec but the decoded actual
242            // subsec would be `123000025`.
243            // https://github.com/smithy-lang/smithy/blob/6466fe77c65b8a17b219f0b0a60c767915205f95/smithy-protocol-tests/model/rpcv2Cbor/fractional-seconds.smithy#L17
244            let epoch_seconds = match self.decoder.datatype().map_err(DeserializeError::new)? {
245                minicbor::data::Type::F16
246                | minicbor::data::Type::F32
247                | minicbor::data::Type::F64 => self.decoder.f64().map_err(DeserializeError::new)?,
248                _ => self.decoder.i64().map_err(DeserializeError::new)? as f64,
249            };
250            let mut result = DateTime::from_secs_f64(epoch_seconds);
251            let subsec_nanos = result.subsec_nanos();
252            result.set_subsec_nanos((subsec_nanos / 1_000_000) * 1_000_000);
253            Ok(result)
254        }
255    }
256
257    /// Returns a `BigInteger` from either a CBOR tag 2/3 (bignum) or a plain integer.
258    ///
259    /// Per RFC 8949 §3.4.3, tag 2 encodes unsigned bignum `n` and tag 3 encodes
260    /// negative bignum `-1 - n`, where `n` is the unsigned integer from the byte
261    /// string in network byte order. Plain CBOR integers (major types 0 and 1)
262    /// are also accepted per preferred serialization rules.
263    pub fn big_integer(&mut self) -> Result<BigInteger, DeserializeError> {
264        use num_bigint::BigInt;
265
266        match self.decoder.datatype().map_err(DeserializeError::new)? {
267            minicbor::data::Type::Tag => {
268                let tag = self.decoder.tag().map_err(DeserializeError::new)?;
269                let bytes = self.decoder.bytes().map_err(DeserializeError::new)?;
270                let n = BigInt::from_bytes_be(num_bigint::Sign::Plus, bytes);
271
272                let value = match tag.as_u64() {
273                    2 => n,
274                    3 => -n - 1, // tag 3 value = -1 - n
275                    _ => {
276                        return Err(DeserializeError::new(Error::message(
277                            "expected CBOR tag 2 (positive bignum) or tag 3 (negative bignum)",
278                        )));
279                    }
280                };
281                value
282                    .to_string()
283                    .parse()
284                    .map_err(|_| DeserializeError::new(Error::message("invalid bignum value")))
285            }
286            minicbor::data::Type::U8
287            | minicbor::data::Type::U16
288            | minicbor::data::Type::U32
289            | minicbor::data::Type::U64 => {
290                let value = self.decoder.u64().map_err(DeserializeError::new)?;
291                value
292                    .to_string()
293                    .parse()
294                    .map_err(|_| DeserializeError::new(Error::message("invalid integer value")))
295            }
296            minicbor::data::Type::I8
297            | minicbor::data::Type::I16
298            | minicbor::data::Type::I32
299            | minicbor::data::Type::I64 => {
300                let value = self.decoder.i64().map_err(DeserializeError::new)?;
301                value
302                    .to_string()
303                    .parse()
304                    .map_err(|_| DeserializeError::new(Error::message("invalid integer value")))
305            }
306            // Int covers CBOR major type 1 values that exceed i64 range
307            // (argument > i64::MAX, i.e. values from -2^64 to -(i64::MAX+2)).
308            minicbor::data::Type::Int => {
309                let int_val = self.decoder.int().map_err(DeserializeError::new)?;
310                let value: i128 = int_val.into();
311                BigInt::from(value)
312                    .to_string()
313                    .parse()
314                    .map_err(|_| DeserializeError::new(Error::message("invalid integer value")))
315            }
316            _ => Err(DeserializeError::new(Error::message(
317                "expected CBOR integer or bignum tag",
318            ))),
319        }
320    }
321}
322
323#[allow(dead_code)] // to avoid `never constructed` warning
324#[derive(Debug)]
325pub struct ArrayIter<'a, 'b, T> {
326    inner: minicbor::decode::ArrayIter<'a, 'b, T>,
327}
328
329impl<'b, T: minicbor::Decode<'b, ()>> Iterator for ArrayIter<'_, 'b, T> {
330    type Item = Result<T, DeserializeError>;
331
332    fn next(&mut self) -> Option<Self::Item> {
333        self.inner
334            .next()
335            .map(|opt| opt.map_err(DeserializeError::new))
336    }
337}
338
339#[allow(dead_code)] // to avoid `never constructed` warning
340#[derive(Debug)]
341pub struct MapIter<'a, 'b, K, V> {
342    inner: minicbor::decode::MapIter<'a, 'b, K, V>,
343}
344
345impl<'b, K, V> Iterator for MapIter<'_, 'b, K, V>
346where
347    K: minicbor::Decode<'b, ()>,
348    V: minicbor::Decode<'b, ()>,
349{
350    type Item = Result<(K, V), DeserializeError>;
351
352    fn next(&mut self) -> Option<Self::Item> {
353        self.inner
354            .next()
355            .map(|opt| opt.map_err(DeserializeError::new))
356    }
357}
358
359pub fn set_optional<B, F>(builder: B, decoder: &mut Decoder, f: F) -> Result<B, DeserializeError>
360where
361    F: Fn(B, &mut Decoder) -> Result<B, DeserializeError>,
362{
363    match decoder.datatype()? {
364        crate::data::Type::Null => {
365            decoder.null()?;
366            Ok(builder)
367        }
368        _ => f(builder, decoder),
369    }
370}
371
#[cfg(test)]
mod tests {
    use crate::Decoder;
    use aws_smithy_types::date_time::Format;

    /// Encodes `value` with the crate's `Encoder`, decodes it with `Decoder::big_integer`, and
    /// asserts that the textual representation survives the round trip.
    fn assert_big_integer_round_trip(value: &str) {
        let mut encoder = crate::Encoder::new(Vec::new());
        encoder.big_integer(&value.parse().unwrap());
        let bytes = encoder.into_writer();
        let result = Decoder::new(&bytes).big_integer().expect("should decode");
        assert_eq!(result.as_ref(), value, "round-trip failed for {value}");
    }

    /// Decodes `bytes` as a big integer, panicking with `expectation` if decoding fails.
    fn decode_big_integer(bytes: &[u8], expectation: &str) -> aws_smithy_types::BigInteger {
        Decoder::new(bytes).big_integer().expect(expectation)
    }

    /// Produces CBOR bytes by running `encode` against a fresh `minicbor` encoder.
    fn minicbor_bytes(encode: impl FnOnce(&mut minicbor::Encoder<Vec<u8>>)) -> Vec<u8> {
        let mut encoder = minicbor::Encoder::new(Vec::new());
        encode(&mut encoder);
        encoder.into_writer()
    }

    #[test]
    fn test_definite_str_is_cow_borrowed() {
        // Definite length key `thisIsAKey`.
        let definite_bytes = [
            0x6a, 0x74, 0x68, 0x69, 0x73, 0x49, 0x73, 0x41, 0x4b, 0x65, 0x79,
        ];
        let member = Decoder::new(&definite_bytes)
            .str()
            .expect("could not decode str");
        assert_eq!(member, "thisIsAKey");
        assert!(matches!(member, std::borrow::Cow::Borrowed(_)));
    }

    #[test]
    fn test_indefinite_str_is_cow_owned() {
        // Indefinite length key `this`, `Is`, `A` and `Key`.
        let indefinite_bytes = [
            0x7f, 0x64, 0x74, 0x68, 0x69, 0x73, 0x62, 0x49, 0x73, 0x61, 0x41, 0x63, 0x4b, 0x65,
            0x79, 0xff,
        ];
        let member = Decoder::new(&indefinite_bytes)
            .str()
            .expect("could not decode str");
        assert_eq!(member, "thisIsAKey");
        assert!(matches!(member, std::borrow::Cow::Owned(_)));
    }

    #[test]
    fn test_empty_str_works() {
        let bytes = [0x60];
        let member = Decoder::new(&bytes)
            .str()
            .expect("could not decode empty str");
        assert_eq!(member, "");
    }

    #[test]
    fn test_empty_blob_works() {
        let bytes = [0x40];
        let member = Decoder::new(&bytes)
            .blob()
            .expect("could not decode an empty blob");
        assert_eq!(member, aws_smithy_types::Blob::new([]));
    }

    #[test]
    fn test_indefinite_length_blob() {
        // Indefinite length blob containing bytes corresponding to `indefinite-byte, chunked, on each comma`.
        // https://cbor.nemo157.com/#type=hex&value=bf69626c6f6256616c75655f50696e646566696e6974652d627974652c49206368756e6b65642c4e206f6e206561636820636f6d6d61ffff
        let indefinite_bytes = [
            0x5f, 0x50, 0x69, 0x6e, 0x64, 0x65, 0x66, 0x69, 0x6e, 0x69, 0x74, 0x65, 0x2d, 0x62,
            0x79, 0x74, 0x65, 0x2c, 0x49, 0x20, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x65, 0x64, 0x2c,
            0x4e, 0x20, 0x6f, 0x6e, 0x20, 0x65, 0x61, 0x63, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d,
            0x61, 0xff,
        ];
        let expected =
            aws_smithy_types::Blob::new("indefinite-byte, chunked, on each comma".as_bytes());
        let member = Decoder::new(&indefinite_bytes)
            .blob()
            .expect("could not decode blob");
        assert_eq!(member, expected);
    }

    #[test]
    fn test_timestamp_should_be_truncated_to_fit_millisecond_precision() {
        // Input bytes are derived from the `RpcV2CborDateTimeWithFractionalSeconds` protocol test,
        // extracting portion representing a timestamp value.
        let bytes = [
            0xc1, 0xfb, 0x41, 0xcc, 0x37, 0xdb, 0x38, 0x0f, 0xbe, 0x77, 0xff,
        ];
        let expected = aws_smithy_types::date_time::DateTime::from_str(
            "2000-01-02T20:34:56.123Z",
            Format::DateTime,
        )
        .unwrap();
        let timestamp = Decoder::new(&bytes)
            .timestamp()
            .expect("should decode timestamp");
        assert_eq!(timestamp, expected);
    }

    #[test]
    fn big_integer_round_trip_positive() {
        ["0", "1", "23", "256", "65535", "18446744073709551615"]
            .into_iter()
            .for_each(assert_big_integer_round_trip);
    }

    #[test]
    fn big_integer_round_trip_negative() {
        ["-1", "-42", "-256", "-18446744073709551616"]
            .into_iter()
            .for_each(assert_big_integer_round_trip);
    }

    #[test]
    fn big_integer_round_trip_large() {
        let large_pos = "123456789012345678901234567890";
        let large_neg = "-123456789012345678901234567890";
        [large_pos, large_neg]
            .into_iter()
            .for_each(assert_big_integer_round_trip);
    }

    #[test]
    fn big_integer_rfc8949_appendix_a_positive() {
        // RFC 8949 Appendix A: 18446744073709551616 (2^64) = 0xc249010000000000000000
        let bytes = [
            0xc2, 0x49, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        let result = decode_big_integer(&bytes, "should decode");
        assert_eq!(result.as_ref(), "18446744073709551616");
    }

    #[test]
    fn big_integer_rfc8949_appendix_a_negative() {
        // RFC 8949 Appendix A: -18446744073709551617 = 0xc349010000000000000000
        let bytes = [
            0xc3, 0x49, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
        ];
        let result = decode_big_integer(&bytes, "should decode");
        assert_eq!(result.as_ref(), "-18446744073709551617");
    }

    #[test]
    fn big_integer_from_plain_cbor_unsigned() {
        let bytes = minicbor_bytes(|enc| {
            enc.u64(9999).unwrap();
        });
        let result = decode_big_integer(&bytes, "should decode plain integer");
        assert_eq!(result.as_ref(), "9999");
    }

    #[test]
    fn big_integer_from_plain_cbor_negative() {
        let bytes = minicbor_bytes(|enc| {
            enc.i64(-500).unwrap();
        });
        let result = decode_big_integer(&bytes, "should decode negative plain integer");
        assert_eq!(result.as_ref(), "-500");
    }

    #[test]
    fn big_integer_from_plain_cbor_positive_signed() {
        // A positive value such as +123 is encoded as CBOR major type 0 (unsigned)
        // per preferred serialization and must decode back to the same value.
        let bytes = minicbor_bytes(|enc| {
            enc.i64(123).unwrap();
        });
        let result = decode_big_integer(&bytes, "should decode positive plain integer");
        assert_eq!(result.as_ref(), "123");
    }

    #[test]
    fn big_integer_tag3_empty_byte_string() {
        // Tag 3 with empty byte string = -1 - 0 = -1
        let bytes = [0xc3, 0x40]; // tag 3, empty byte string
        let result = decode_big_integer(&bytes, "should decode");
        assert_eq!(result.as_ref(), "-1");
    }

    #[test]
    fn big_integer_tag2_empty_byte_string() {
        // Tag 2 with empty byte string = 0
        let bytes = [0xc2, 0x40]; // tag 2, empty byte string
        let result = decode_big_integer(&bytes, "should decode");
        assert_eq!(result.as_ref(), "0");
    }

    #[test]
    fn big_integer_rejects_invalid_tag() {
        // Tag 4 (decimal fraction) should be rejected.
        let bytes = [0xc4, 0x82, 0x21, 0x19, 0x6a, 0xb3];
        assert!(Decoder::new(&bytes).big_integer().is_err());
    }

    #[test]
    fn big_integer_decode_major_type_1_exceeding_i64() {
        // CBOR major type 1 with argument u64::MAX (0x3b + 8 bytes of 0xff).
        // Value = -1 - u64::MAX = -18446744073709551616.
        // This exercises the minicbor Type::Int path in the decoder.
        let bytes = [0x3b, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff];
        let result = decode_big_integer(&bytes, "should decode major type 1 > i64::MAX");
        assert_eq!(result.as_ref(), "-18446744073709551616");
    }

    #[test]
    fn test_timestamp_integer_epoch_seconds() {
        // RFC 8949 §3.4.2: tag 1 content MUST be int OR float.
        // tag(1) + uint(1700000000) = 0xc1 0x1a 0x65 0x53 0xf1 0x00
        let bytes = [0xc1u8, 0x1a, 0x65, 0x53, 0xf1, 0x00];
        let expected = aws_smithy_types::DateTime::from_secs(1700000000);
        let timestamp = Decoder::new(&bytes)
            .timestamp()
            .expect("should decode integer timestamp");
        assert_eq!(timestamp, expected);
    }
}