Skip to main content

amaru_minicbor_extra/
decode.rs

// Copyright 2025 PRAGMA
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
14
15use std::fmt::Display;
16
17use minicbor::decode;
18
19use crate::cbor;
20
21pub mod lazy;
22
23// Misc
24// ----------------------------------------------------------------------------
25
26pub fn decode_break<'d>(d: &mut cbor::Decoder<'d>, len: Option<u64>) -> Result<bool, cbor::decode::Error> {
27    if d.datatype()? == cbor::data::Type::Break {
28        // NOTE: If we encounter a rogue Break while decoding a definite map, that's an error.
29        if len.is_some() {
30            return Err(cbor::decode::Error::type_mismatch(cbor::data::Type::Break));
31        }
32
33        d.skip()?;
34
35        return Ok(true);
36    }
37
38    Ok(false)
39}
40
41/// Decode a chunk, but retain a reference to the decoded bytes.
42pub fn tee<'d, A>(
43    d: &mut cbor::Decoder<'d>,
44    decoder: impl FnOnce(&mut cbor::Decoder<'d>) -> Result<A, cbor::decode::Error>,
45) -> Result<(A, &'d [u8]), cbor::decode::Error> {
46    let original_bytes = d.input();
47    let start = d.position();
48    let a = decoder(d)?;
49    let end = d.position();
50    Ok((a, &original_bytes[start..end]))
51}
52
53// Array
54// ----------------------------------------------------------------------------
55
56/// Decode any heterogeneous CBOR array, irrespective of whether they're indefinite or definite.
57///
58/// FIXME: Allow callers to check that the length is not static, but simply matches what is
59/// advertised; e.g. using `Option<u64>` as a callback.
60pub fn heterogeneous_array<'d, A>(
61    d: &mut cbor::Decoder<'d>,
62    elems: impl FnOnce(
63        &mut cbor::Decoder<'d>,
64        &dyn Fn(u64) -> Result<(), cbor::decode::Error>,
65    ) -> Result<A, cbor::decode::Error>,
66) -> Result<A, cbor::decode::Error> {
67    let len = d.array()?;
68
69    match len {
70        None => {
71            let result = elems(d, &|_| Ok(()))?;
72            decode_break(d, len)?;
73            Ok(result)
74        }
75        Some(len) => elems(
76            d,
77            &(move |expected_len| {
78                if len != expected_len {
79                    return Err(cbor::decode::Error::message(format!(
80                        "CBOR array length mismatch: expected {} got {}",
81                        expected_len, len
82                    )));
83                }
84
85                Ok(())
86            }),
87        ),
88    }
89}
90
91/// This function checks the size of an array containing a tagged value.
92/// The `label` parameter is used to identify which variant is being checked.
93///
94/// FIXME: suspicious check_tagged_array_length
95///
96/// This function is a code smell and seems to indicate that we are manually decoding def
97/// array somewhere, instead of using the heterogeneous_array above to also deal indef arrays.
98/// There might be a good reason why this function exists; I haven't checked, but leaving a note
99/// for later to check.
100pub fn check_tagged_array_length(label: usize, actual: Option<u64>, expected: u64) -> Result<(), decode::Error> {
101    if actual != Some(expected) {
102        Err(decode::Error::message(format!("expected array length {expected} for label {label}, got: {actual:?}")))
103    } else {
104        Ok(())
105    }
106}
107
108// Map
109// ----------------------------------------------------------------------------
110
111/// Decode any heterogeneous CBOR map, irrespective of whether they're indefinite or definite.
112///
113/// A good choice for `S` is generally to pick a tuple of `Option` for each field item
114/// that needs decoding. For example:
115///
116/// ```rs
117/// let (address, value, datum, script) = decode_map(
118///     d,
119///     (None, None, MemoizedDatum::None, None),
120///     |d| d.u8(),
121///     |d, state, field| {
122///         match field {
123///             0 => state.0 = Some(decode_address(d.bytes()?),
124///             1 => state.1 = Some(d.decode()?),
125///             2 => state.2 = decode_datum()?,
126///             3 => state.3 = decode_reference_script()?,
127///             _ => return unexpected_field::<Output, _>(field),
128///         }
129///         Ok(())
130///     },
131/// )?;
132/// ```
133pub fn heterogeneous_map<K, S>(
134    d: &mut cbor::Decoder<'_>,
135    mut state: S,
136    decode_key: impl Fn(&mut cbor::Decoder<'_>) -> Result<K, cbor::decode::Error>,
137    mut decode_value: impl FnMut(&mut cbor::Decoder<'_>, &mut S, K) -> Result<(), cbor::decode::Error>,
138) -> Result<S, cbor::decode::Error> {
139    let len = d.map()?;
140
141    let mut n = 0;
142    while len.is_none() || Some(n) < len {
143        if decode_break(d, len)? {
144            break;
145        }
146
147        let k = decode_key(d)?;
148        decode_value(d, &mut state, k)?;
149
150        n += 1;
151    }
152
153    Ok(state)
154}
155
156/// Yield a `PartialDecoder` that fails with a comprehensible error message when an expected field
157/// is missing from the map.
158pub fn missing_field<C: ?Sized, A>(field_tag: u8) -> cbor::decode::Error {
159    let msg = format!(
160        "missing <{}> at field .{field_tag} in <{}> CBOR map",
161        std::any::type_name::<A>(),
162        std::any::type_name::<C>(),
163    );
164    cbor::decode::Error::message(msg)
165}
166
167/// Yield a `Result<_, decode::Error>` that always fails with a comprehensible error message when a
168/// map key is unexpected.
169pub fn unexpected_field<C: ?Sized, A>(field_tag: impl Display) -> Result<A, cbor::decode::Error> {
170    Err(cbor::decode::Error::message(format!(
171        "unexpected field .{field_tag} in <{}> CBOR map",
172        std::any::type_name::<C>(),
173    )))
174}
175
176// Tests
177// ----------------------------------------------------------------------------
178
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use crate::{
        cbor, from_cbor, from_cbor_no_leftovers, heterogeneous_array, heterogeneous_map, missing_field,
        tests::{AsDefinite, AsIndefinite, AsMap, foo::Foo},
        to_cbor, unexpected_field,
    };

    /// Assert that `bytes` decode, with no leftovers, to exactly `left`.
    fn assert_ok<T: Eq + Debug + for<'d> cbor::decode::Decode<'d, ()>>(left: T, bytes: &[u8]) {
        assert_eq!(Ok(left), from_cbor_no_leftovers::<T>(bytes).map_err(|e| e.to_string()));
    }

    /// Assert that decoding `bytes` as `T` fails with an error message containing `msg`.
    fn assert_err<T: Debug + for<'d> cbor::decode::Decode<'d, ()>>(msg: &str, bytes: &[u8]) {
        match from_cbor_no_leftovers::<T>(bytes).map_err(|e| e.to_string()) {
            Err(e) => assert!(e.contains(msg), "{e}"),
            Ok(ok) => panic!("expected error but got {:#?}", ok),
        }
    }

    const FIXTURE: Foo = Foo { field0: 14, field1: 42 };

    mod heterogeneous_array_tests {
        use super::*;

        #[test]
        fn happy_case() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            // A flexible decoder that can ingest both definite and indefinite arrays.
            impl<'d, C> cbor::decode::Decode<'d, C> for TestCase<Foo> {
                fn decode(d: &mut cbor::Decoder<'d>, ctx: &mut C) -> Result<Self, cbor::decode::Error> {
                    heterogeneous_array(d, |d, assert_len| {
                        assert_len(2)?;
                        Ok(TestCase(Foo { field0: d.decode_with(ctx)?, field1: d.decode_with(ctx)? }))
                    })
                }
            }

            assert_ok(TestCase(FIXTURE), &to_cbor(&AsDefinite(&FIXTURE)));
            assert_ok(TestCase(FIXTURE), &to_cbor(&AsIndefinite(&FIXTURE)));
        }

        #[test]
        fn smaller_definite_length() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            // A decoder which expects fewer elements than actually supplied.
            impl<'d, C> cbor::decode::Decode<'d, C> for TestCase<Foo> {
                fn decode(d: &mut cbor::Decoder<'d>, ctx: &mut C) -> Result<Self, cbor::decode::Error> {
                    heterogeneous_array(d, |d, assert_len| {
                        assert_len(1)?;
                        Ok(TestCase(Foo { field0: d.decode_with(ctx)?, field1: d.decode_with(ctx)? }))
                    })
                }
            }

            assert_err::<TestCase<Foo>>("array length mismatch", &to_cbor(&AsDefinite(&FIXTURE)));
        }

        #[test]
        fn larger_definite_length() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            // A decoder which expects more elements than actually supplied.
            impl<'d, C> cbor::decode::Decode<'d, C> for TestCase<Foo> {
                fn decode(d: &mut cbor::Decoder<'d>, ctx: &mut C) -> Result<Self, cbor::decode::Error> {
                    heterogeneous_array(d, |d, assert_len| {
                        assert_len(3)?;
                        Ok(TestCase(Foo { field0: d.decode_with(ctx)?, field1: d.decode_with(ctx)? }))
                    })
                }
            }

            assert_err::<TestCase<Foo>>("array length mismatch", &to_cbor(&AsDefinite(&FIXTURE)));
        }

        #[test]
        fn incomplete_indefinite() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            // An incomplete encoder, which skips the final break on indefinite arrays.
            impl<C> cbor::encode::Encode<C> for TestCase<&Foo> {
                fn encode<W: cbor::encode::Write>(
                    &self,
                    e: &mut cbor::Encoder<W>,
                    ctx: &mut C,
                ) -> Result<(), cbor::encode::Error<W::Error>> {
                    e.begin_array()?;
                    e.encode_with(self.0.field0, ctx)?;
                    e.encode_with(self.0.field1, ctx)?;
                    Ok(())
                }
            }

            let bytes = to_cbor(&TestCase(&FIXTURE));

            assert!(from_cbor::<AsDefinite<Foo>>(&bytes).is_none());
            assert!(from_cbor::<AsIndefinite<Foo>>(&bytes).is_none());
        }
    }

    mod heterogeneous_map_tests {
        use super::*;

        /// A decoder for `Foo` that interprets it as a map, and fails in case of a missing field.
        #[derive(Debug, PartialEq, Eq)]
        struct NoMissingFields<A>(A);
        impl<'d, C> cbor::decode::Decode<'d, C> for NoMissingFields<Foo> {
            fn decode(d: &mut cbor::Decoder<'d>, ctx: &mut C) -> Result<Self, cbor::decode::Error> {
                let (field0, field1) = heterogeneous_map(
                    d,
                    (None::<u8>, None::<u8>),
                    |d| d.u8(),
                    |d, state, field| {
                        match field {
                            0 => state.0 = d.decode_with(ctx)?,
                            1 => state.1 = d.decode_with(ctx)?,
                            _ => return unexpected_field::<Foo, _>(field),
                        }
                        Ok(())
                    },
                )?;

                Ok(NoMissingFields(Foo {
                    field0: field0.ok_or_else(|| missing_field::<Foo, u8>(0))?,
                    field1: field1.ok_or_else(|| missing_field::<Foo, u8>(1))?,
                }))
            }
        }

        /// A decoder for `Foo` that interprets it as a map, but allows fields to be missing.
        #[derive(Debug, PartialEq, Eq)]
        struct WithDefaultValues<A>(A);
        impl<'d, C> cbor::decode::Decode<'d, C> for WithDefaultValues<Foo> {
            fn decode(d: &mut cbor::Decoder<'d>, ctx: &mut C) -> Result<Self, cbor::decode::Error> {
                let (field0, field1) = heterogeneous_map(
                    d,
                    (14_u8, 42_u8),
                    |d| d.u8(),
                    |d, state, field| {
                        match field {
                            0 => state.0 = d.decode_with(ctx)?,
                            1 => state.1 = d.decode_with(ctx)?,
                            _ => return unexpected_field::<Foo, _>(field),
                        }
                        Ok(())
                    },
                )?;

                Ok(WithDefaultValues(Foo { field0, field1 }))
            }
        }

        #[test]
        fn no_optional_fields_no_missing_fields() {
            assert_ok(NoMissingFields(FIXTURE), &to_cbor(&AsIndefinite(AsMap(&FIXTURE))));

            assert_ok(NoMissingFields(FIXTURE), &to_cbor(&AsDefinite(AsMap(&FIXTURE))));
        }

        #[test]
        fn out_of_order_fields() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            // A valid encoder, which writes the fields of a definite map in reverse order.
            impl<C> cbor::encode::Encode<C> for TestCase<&Foo> {
                fn encode<W: cbor::encode::Write>(
                    &self,
                    e: &mut cbor::Encoder<W>,
                    ctx: &mut C,
                ) -> Result<(), cbor::encode::Error<W::Error>> {
                    e.map(2)?;
                    e.encode_with(1_u8, ctx)?;
                    e.encode_with(self.0.field1, ctx)?;
                    e.encode_with(0_u8, ctx)?;
                    e.encode_with(self.0.field0, ctx)?;
                    Ok(())
                }
            }

            assert_ok(NoMissingFields(FIXTURE), &to_cbor(&TestCase(&FIXTURE)));
        }

        #[test]
        fn optional_fields_no_missing_fields() {
            assert_ok(WithDefaultValues(FIXTURE), &to_cbor(&AsIndefinite(AsMap(&FIXTURE))));

            assert_ok(WithDefaultValues(FIXTURE), &to_cbor(&AsDefinite(AsMap(&FIXTURE))));
        }

        #[test]
        fn one_field_missing() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            // A definite map which only carries field 0.
            impl<C> cbor::encode::Encode<C> for TestCase<AsDefinite<&Foo>> {
                fn encode<W: cbor::encode::Write>(
                    &self,
                    e: &mut cbor::Encoder<W>,
                    ctx: &mut C,
                ) -> Result<(), cbor::encode::Error<W::Error>> {
                    e.map(1)?;
                    e.encode_with(0_u8, ctx)?;
                    e.encode_with(self.0.0.field0, ctx)?;
                    Ok(())
                }
            }

            // An indefinite map which only carries field 1.
            impl<C> cbor::encode::Encode<C> for TestCase<AsIndefinite<&Foo>> {
                fn encode<W: cbor::encode::Write>(
                    &self,
                    e: &mut cbor::Encoder<W>,
                    ctx: &mut C,
                ) -> Result<(), cbor::encode::Error<W::Error>> {
                    e.begin_map()?;
                    e.encode_with(1_u8, ctx)?;
                    e.encode_with(self.0.0.field1, ctx)?;
                    e.end()?;
                    Ok(())
                }
            }

            assert_err::<NoMissingFields<Foo>>("missing <u8> at field .1", &to_cbor(&TestCase(AsDefinite(&FIXTURE))));

            assert_ok(WithDefaultValues(FIXTURE), &to_cbor(&TestCase(AsDefinite(&FIXTURE))));

            assert_err::<NoMissingFields<Foo>>("missing <u8> at field .0", &to_cbor(&TestCase(AsIndefinite(&FIXTURE))));

            assert_ok(WithDefaultValues(FIXTURE), &to_cbor(&TestCase(AsIndefinite(&FIXTURE))));
        }

        #[test]
        fn rogue_break() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            // An invalid encoder, which puts a break where the second entry of a definite map
            // should be.
            impl<C> cbor::encode::Encode<C> for TestCase<&Foo> {
                fn encode<W: cbor::encode::Write>(
                    &self,
                    e: &mut cbor::Encoder<W>,
                    ctx: &mut C,
                ) -> Result<(), cbor::encode::Error<W::Error>> {
                    e.map(2)?;
                    e.encode_with(0_u8, ctx)?;
                    e.encode_with(self.0.field0, ctx)?;
                    e.end()?;
                    Ok(())
                }
            }

            assert_err::<WithDefaultValues<Foo>>("unexpected type break", &to_cbor(&TestCase(&FIXTURE)));
        }

        #[test]
        fn unexpected_field_tag() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            // An invalid encoder, which writes an unknown field tag (14) in a definite map.
            impl<C> cbor::encode::Encode<C> for TestCase<&Foo> {
                fn encode<W: cbor::encode::Write>(
                    &self,
                    e: &mut cbor::Encoder<W>,
                    ctx: &mut C,
                ) -> Result<(), cbor::encode::Error<W::Error>> {
                    e.map(2)?;
                    e.encode_with(0_u8, ctx)?;
                    e.encode_with(self.0.field0, ctx)?;
                    e.encode_with(14_u8, ctx)?;
                    e.encode_with(self.0.field0, ctx)?;
                    Ok(())
                }
            }

            assert_err::<WithDefaultValues<Foo>>("unexpected field .14", &to_cbor(&TestCase(&FIXTURE)));
        }
    }
}