1use std::fmt::Display;
16
17use minicbor::decode;
18
19use crate::cbor;
20
21pub mod lazy;
22
23pub fn decode_break<'d>(d: &mut cbor::Decoder<'d>, len: Option<u64>) -> Result<bool, cbor::decode::Error> {
27 if d.datatype()? == cbor::data::Type::Break {
28 if len.is_some() {
30 return Err(cbor::decode::Error::type_mismatch(cbor::data::Type::Break));
31 }
32
33 d.skip()?;
34
35 return Ok(true);
36 }
37
38 Ok(false)
39}
40
41pub fn tee<'d, A>(
43 d: &mut cbor::Decoder<'d>,
44 decoder: impl FnOnce(&mut cbor::Decoder<'d>) -> Result<A, cbor::decode::Error>,
45) -> Result<(A, &'d [u8]), cbor::decode::Error> {
46 let original_bytes = d.input();
47 let start = d.position();
48 let a = decoder(d)?;
49 let end = d.position();
50 Ok((a, &original_bytes[start..end]))
51}
52
53pub fn heterogeneous_array<'d, A>(
61 d: &mut cbor::Decoder<'d>,
62 elems: impl FnOnce(
63 &mut cbor::Decoder<'d>,
64 &dyn Fn(u64) -> Result<(), cbor::decode::Error>,
65 ) -> Result<A, cbor::decode::Error>,
66) -> Result<A, cbor::decode::Error> {
67 let len = d.array()?;
68
69 match len {
70 None => {
71 let result = elems(d, &|_| Ok(()))?;
72 decode_break(d, len)?;
73 Ok(result)
74 }
75 Some(len) => elems(
76 d,
77 &(move |expected_len| {
78 if len != expected_len {
79 return Err(cbor::decode::Error::message(format!(
80 "CBOR array length mismatch: expected {} got {}",
81 expected_len, len
82 )));
83 }
84
85 Ok(())
86 }),
87 ),
88 }
89}
90
91pub fn check_tagged_array_length(label: usize, actual: Option<u64>, expected: u64) -> Result<(), decode::Error> {
101 if actual != Some(expected) {
102 Err(decode::Error::message(format!("expected array length {expected} for label {label}, got: {actual:?}")))
103 } else {
104 Ok(())
105 }
106}
107
108pub fn heterogeneous_map<K, S>(
134 d: &mut cbor::Decoder<'_>,
135 mut state: S,
136 decode_key: impl Fn(&mut cbor::Decoder<'_>) -> Result<K, cbor::decode::Error>,
137 mut decode_value: impl FnMut(&mut cbor::Decoder<'_>, &mut S, K) -> Result<(), cbor::decode::Error>,
138) -> Result<S, cbor::decode::Error> {
139 let len = d.map()?;
140
141 let mut n = 0;
142 while len.is_none() || Some(n) < len {
143 if decode_break(d, len)? {
144 break;
145 }
146
147 let k = decode_key(d)?;
148 decode_value(d, &mut state, k)?;
149
150 n += 1;
151 }
152
153 Ok(state)
154}
155
156pub fn missing_field<C: ?Sized, A>(field_tag: u8) -> cbor::decode::Error {
159 let msg = format!(
160 "missing <{}> at field .{field_tag} in <{}> CBOR map",
161 std::any::type_name::<A>(),
162 std::any::type_name::<C>(),
163 );
164 cbor::decode::Error::message(msg)
165}
166
167pub fn unexpected_field<C: ?Sized, A>(field_tag: impl Display) -> Result<A, cbor::decode::Error> {
170 Err(cbor::decode::Error::message(format!(
171 "unexpected field .{field_tag} in <{}> CBOR map",
172 std::any::type_name::<C>(),
173 )))
174}
175
#[cfg(test)]
mod tests {
    use std::fmt::Debug;

    use crate::{
        cbor, from_cbor, from_cbor_no_leftovers, heterogeneous_array, heterogeneous_map, missing_field,
        tests::{AsDefinite, AsIndefinite, AsMap, foo::Foo},
        to_cbor, unexpected_field,
    };

    /// Assert that `bytes` decode, with no leftover input, to exactly `left`.
    fn assert_ok<T: Eq + Debug + for<'d> cbor::decode::Decode<'d, ()>>(left: T, bytes: &[u8]) {
        assert_eq!(Ok(left), from_cbor_no_leftovers::<T>(bytes).map_err(|e| e.to_string()));
    }

    /// Assert that decoding `bytes` as `T` fails with a message containing `msg`.
    fn assert_err<T: Debug + for<'d> cbor::decode::Decode<'d, ()>>(msg: &str, bytes: &[u8]) {
        match from_cbor_no_leftovers::<T>(bytes).map_err(|e| e.to_string()) {
            Err(e) => assert!(e.contains(msg), "{e}"),
            Ok(ok) => panic!("expected error but got {:#?}", ok),
        }
    }

    const FIXTURE: Foo = Foo { field0: 14, field1: 42 };

    mod heterogeneous_array_tests {
        use super::*;

        #[test]
        fn happy_case() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            impl<'d, C> cbor::decode::Decode<'d, C> for TestCase<Foo> {
                fn decode(d: &mut cbor::Decoder<'d>, ctx: &mut C) -> Result<Self, cbor::decode::Error> {
                    heterogeneous_array(d, |d, assert_len| {
                        assert_len(2)?;
                        Ok(TestCase(Foo { field0: d.decode_with(ctx)?, field1: d.decode_with(ctx)? }))
                    })
                }
            }

            assert_ok(TestCase(FIXTURE), &to_cbor(&AsDefinite(&FIXTURE)));
            assert_ok(TestCase(FIXTURE), &to_cbor(&AsIndefinite(&FIXTURE)));
        }

        #[test]
        fn smaller_definite_length() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            impl<'d, C> cbor::decode::Decode<'d, C> for TestCase<Foo> {
                fn decode(d: &mut cbor::Decoder<'d>, ctx: &mut C) -> Result<Self, cbor::decode::Error> {
                    heterogeneous_array(d, |d, assert_len| {
                        assert_len(1)?;
                        Ok(TestCase(Foo { field0: d.decode_with(ctx)?, field1: d.decode_with(ctx)? }))
                    })
                }
            }

            assert_err::<TestCase<Foo>>("array length mismatch", &to_cbor(&AsDefinite(&FIXTURE)));
        }

        #[test]
        fn larger_definite_length() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            impl<'d, C> cbor::decode::Decode<'d, C> for TestCase<Foo> {
                fn decode(d: &mut cbor::Decoder<'d>, ctx: &mut C) -> Result<Self, cbor::decode::Error> {
                    heterogeneous_array(d, |d, assert_len| {
                        assert_len(3)?;
                        Ok(TestCase(Foo { field0: d.decode_with(ctx)?, field1: d.decode_with(ctx)? }))
                    })
                }
            }

            assert_err::<TestCase<Foo>>("array length mismatch", &to_cbor(&AsDefinite(&FIXTURE)));
        }

        #[test]
        fn incomplete_indefinite() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            // An indefinite-length array whose closing break is never written.
            impl<C> cbor::encode::Encode<C> for TestCase<&Foo> {
                fn encode<W: cbor::encode::Write>(
                    &self,
                    e: &mut cbor::Encoder<W>,
                    ctx: &mut C,
                ) -> Result<(), cbor::encode::Error<W::Error>> {
                    e.begin_array()?;
                    e.encode_with(self.0.field0, ctx)?;
                    e.encode_with(self.0.field1, ctx)?;
                    Ok(())
                }
            }

            let bytes = to_cbor(&TestCase(&FIXTURE));

            assert!(from_cbor::<AsDefinite<Foo>>(&bytes).is_none());
            assert!(from_cbor::<AsIndefinite<Foo>>(&bytes).is_none());
        }
    }

    mod heterogeneous_map_tests {
        use super::*;

        /// Decodes `Foo` from a map, requiring both fields to be present.
        #[derive(Debug, PartialEq, Eq)]
        struct NoMissingFields<A>(A);
        impl<'d, C> cbor::decode::Decode<'d, C> for NoMissingFields<Foo> {
            fn decode(d: &mut cbor::Decoder<'d>, ctx: &mut C) -> Result<Self, cbor::decode::Error> {
                let (field0, field1) = heterogeneous_map(
                    d,
                    (None::<u8>, None::<u8>),
                    |d| d.u8(),
                    |d, state, field| {
                        match field {
                            0 => state.0 = d.decode_with(ctx)?,
                            1 => state.1 = d.decode_with(ctx)?,
                            _ => return unexpected_field::<Foo, _>(field),
                        }
                        Ok(())
                    },
                )?;

                Ok(NoMissingFields(Foo {
                    field0: field0.ok_or_else(|| missing_field::<Foo, u8>(0))?,
                    field1: field1.ok_or_else(|| missing_field::<Foo, u8>(1))?,
                }))
            }
        }

        /// Decodes `Foo` from a map, falling back to 14 / 42 for absent fields.
        #[derive(Debug, PartialEq, Eq)]
        struct WithDefaultValues<A>(A);
        impl<'d, C> cbor::decode::Decode<'d, C> for WithDefaultValues<Foo> {
            fn decode(d: &mut cbor::Decoder<'d>, ctx: &mut C) -> Result<Self, cbor::decode::Error> {
                let (field0, field1) = heterogeneous_map(
                    d,
                    (14_u8, 42_u8),
                    |d| d.u8(),
                    |d, state, field| {
                        match field {
                            0 => state.0 = d.decode_with(ctx)?,
                            1 => state.1 = d.decode_with(ctx)?,
                            _ => return unexpected_field::<Foo, _>(field),
                        }
                        Ok(())
                    },
                )?;

                Ok(WithDefaultValues(Foo { field0, field1 }))
            }
        }

        #[test]
        fn no_optional_fields_no_missing_fields() {
            assert_ok(NoMissingFields(FIXTURE), &to_cbor(&AsIndefinite(AsMap(&FIXTURE))));

            assert_ok(NoMissingFields(FIXTURE), &to_cbor(&AsDefinite(AsMap(&FIXTURE))));
        }

        #[test]
        fn out_of_order_fields() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            // Field 1 is written before field 0.
            impl<C> cbor::encode::Encode<C> for TestCase<&Foo> {
                fn encode<W: cbor::encode::Write>(
                    &self,
                    e: &mut cbor::Encoder<W>,
                    ctx: &mut C,
                ) -> Result<(), cbor::encode::Error<W::Error>> {
                    e.map(2)?;
                    e.encode_with(1_u8, ctx)?;
                    e.encode_with(self.0.field1, ctx)?;
                    e.encode_with(0_u8, ctx)?;
                    e.encode_with(self.0.field0, ctx)?;
                    Ok(())
                }
            }

            assert_ok(NoMissingFields(FIXTURE), &to_cbor(&TestCase(&FIXTURE)));
        }

        #[test]
        fn optional_fields_no_missing_fields() {
            assert_ok(WithDefaultValues(FIXTURE), &to_cbor(&AsIndefinite(AsMap(&FIXTURE))));

            assert_ok(WithDefaultValues(FIXTURE), &to_cbor(&AsDefinite(AsMap(&FIXTURE))));
        }

        #[test]
        fn one_field_missing() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            // Indefinite-length map carrying only field 0.
            impl<C> cbor::encode::Encode<C> for TestCase<AsIndefinite<&Foo>> {
                fn encode<W: cbor::encode::Write>(
                    &self,
                    e: &mut cbor::Encoder<W>,
                    ctx: &mut C,
                ) -> Result<(), cbor::encode::Error<W::Error>> {
                    e.begin_map()?;
                    e.encode_with(0_u8, ctx)?;
                    e.encode_with(self.0.0.field0, ctx)?;
                    e.end()?;
                    Ok(())
                }
            }

            // Definite-length map carrying only field 1.
            impl<C> cbor::encode::Encode<C> for TestCase<AsDefinite<&Foo>> {
                fn encode<W: cbor::encode::Write>(
                    &self,
                    e: &mut cbor::Encoder<W>,
                    ctx: &mut C,
                ) -> Result<(), cbor::encode::Error<W::Error>> {
                    e.map(1)?;
                    e.encode_with(1_u8, ctx)?;
                    e.encode_with(self.0.0.field1, ctx)?;
                    Ok(())
                }
            }

            assert_err::<NoMissingFields<Foo>>("missing <u8> at field .1", &to_cbor(&TestCase(AsIndefinite(&FIXTURE))));

            assert_ok(WithDefaultValues(FIXTURE), &to_cbor(&TestCase(AsIndefinite(&FIXTURE))));

            assert_err::<NoMissingFields<Foo>>("missing <u8> at field .0", &to_cbor(&TestCase(AsDefinite(&FIXTURE))));

            assert_ok(WithDefaultValues(FIXTURE), &to_cbor(&TestCase(AsDefinite(&FIXTURE))));
        }

        #[test]
        fn rogue_break() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            // A map that declares two entries but holds one, followed by a stray break.
            impl<C> cbor::encode::Encode<C> for TestCase<&Foo> {
                fn encode<W: cbor::encode::Write>(
                    &self,
                    e: &mut cbor::Encoder<W>,
                    ctx: &mut C,
                ) -> Result<(), cbor::encode::Error<W::Error>> {
                    e.map(2)?;
                    e.encode_with(0_u8, ctx)?;
                    e.encode_with(self.0.field0, ctx)?;
                    e.end()?;
                    Ok(())
                }
            }

            assert_err::<WithDefaultValues<Foo>>("unexpected type break", &to_cbor(&TestCase(&FIXTURE)));
        }

        #[test]
        fn unexpected_field_tag() {
            #[derive(Debug, PartialEq, Eq)]
            struct TestCase<A>(A);

            // The second entry uses key 14, which neither decoder recognises.
            impl<C> cbor::encode::Encode<C> for TestCase<&Foo> {
                fn encode<W: cbor::encode::Write>(
                    &self,
                    e: &mut cbor::Encoder<W>,
                    ctx: &mut C,
                ) -> Result<(), cbor::encode::Error<W::Error>> {
                    e.map(2)?;
                    e.encode_with(0_u8, ctx)?;
                    e.encode_with(self.0.field0, ctx)?;
                    e.encode_with(14_u8, ctx)?;
                    e.encode_with(self.0.field0, ctx)?;
                    Ok(())
                }
            }

            assert_err::<WithDefaultValues<Foo>>("unexpected field .14", &to_cbor(&TestCase(&FIXTURE)));
        }
    }
}