minicbor_embedded_io/
reader.rs

1use core::marker::PhantomData;
2
3use embedded_io_async::Read;
4use minicbor::{data::Type, decode, Decode, Decoder};
5
/// The CBOR "break" stop code, used here to detect the end of indefinite-length arrays and maps.
const BREAK: u8 = 0xff;
7
/// Errors produced while reading and decoding CBOR from an async byte source.
#[derive(Debug)]
pub enum Error {
    /// The source ended while more bytes were required, e.g. in the middle of an item or
    /// before an indefinite-length container was terminated.
    UnexpectedEof,
    /// The reader's buffer is full but still does not hold one complete item.
    BufferTooSmall,
    /// The underlying reader failed; only its error kind is retained.
    Io(embedded_io_async::ErrorKind),
    /// Decoding failed for a reason other than running out of buffered input.
    Decode(decode::Error),
    /// Reserving memory for decoded items failed.
    #[cfg(feature = "alloc")]
    TryReserveError,
}
17
/// Compact `defmt` formatting for [`Error`].
///
/// Only the variant name of a wrapped [`decode::Error`] is logged; I/O errors are logged
/// through their `Debug` representation.
#[cfg(feature = "defmt")]
impl defmt::Format for Error {
    fn format(&self, f: defmt::Formatter) {
        // Every variant is spelled out on purpose: delegating to `defmt::Format::format` for
        // `Error` itself would call straight back into this function and recurse forever.
        match self {
            Error::UnexpectedEof => defmt::write!(f, "UnexpectedEof"),
            Error::BufferTooSmall => defmt::write!(f, "BufferTooSmall"),
            Error::Io(kind) => defmt::write!(f, "Io({})", defmt::Debug2Format(kind)),
            Error::Decode(_) => defmt::write!(f, "Decode"),
            #[cfg(feature = "alloc")]
            Error::TryReserveError => defmt::write!(f, "TryReserveError"),
        }
    }
}
27
28impl<T: embedded_io_async::Error> From<T> for Error {
29    fn from(value: T) -> Self {
30        Error::Io(value.kind())
31    }
32}
33
/// Callback interface for consuming a CBOR array element by element.
///
/// `C` is a user-provided context, passed through from [`CborReader::array_with`].
pub trait CborArrayReader<C> {
    /// Called once the array header has been decoded, before any element is read.
    ///
    /// `len` is `Some(n)` for a definite-length array and `None` for an indefinite-length one.
    fn read_begin_array(&mut self, len: Option<u64>, ctx: &mut C) -> Result<(), Error>;

    /// Read one array element from `reader`.
    ///
    /// [`CborReader::array_with`] calls this once per element (`len` times, or until the break
    /// code for indefinite-length arrays), so an implementation is expected to consume exactly
    /// one CBOR item from `reader`.
    async fn read_array_item<'b, R: Read>(
        &mut self,
        reader: &mut CborReader<'b, R>,
        ctx: &mut C,
    ) -> Result<(), Error>;
}
43
/// Callback interface for consuming a CBOR map entry by entry.
///
/// `C` is a user-provided context, passed through from [`CborReader::map_with`].
pub trait CborMapReader<C> {
    /// Called once the map header has been decoded, before any entry is read.
    ///
    /// `len` is the length from a definite-length map header, or `None` for an
    /// indefinite-length map.
    fn read_begin_map(&mut self, len: Option<u64>, ctx: &mut C) -> Result<(), Error>;

    /// Read one map entry from `reader`.
    ///
    /// [`CborReader::map_with`] calls this `len` times for a definite-length map, or until the
    /// break code for an indefinite-length one.
    /// NOTE(review): this presumes `len` counts key/value pairs (as the CBOR map header does),
    /// so each call should consume one key and its value — confirm against `decode::info::Size`.
    async fn read_map_item<'b, R: Read>(
        &mut self,
        reader: &mut CborReader<'b, R>,
        ctx: &mut C,
    ) -> Result<(), Error>;
}
53
/// Incremental CBOR reader on top of an async byte source.
///
/// Bytes from `source` are staged in a caller-provided buffer. Only `buf[decoded..read]`
/// holds bytes that have been read but not yet decoded.
///
/// The `R: Read` bound lives on the impl blocks that need it rather than on the type itself.
#[derive(Debug)]
pub struct CborReader<'b, R> {
    /// The underlying byte source.
    source: R,
    /// Staging buffer; it must be able to hold at least one complete CBOR item.
    buf: &'b mut [u8],
    /// Number of bytes at the start of `buf` that hold data read from `source`.
    read: usize,
    /// Number of bytes at the start of `buf` that have already been decoded.
    decoded: usize,
}
64
65#[derive(Debug)]
66struct ArrayHeader(pub Option<u64>);
67
68impl<'b, C> Decode<'b, C> for ArrayHeader {
69    fn decode(d: &mut Decoder<'b>, _ctx: &mut C) -> Result<Self, decode::Error> {
70        let pos = d.position();
71        let ty = d.datatype()?;
72        match ty {
73            Type::Array => {}
74            Type::ArrayIndef => {}
75            ty => {
76                return Err(decode::Error::type_mismatch(ty)
77                    .with_message("expected array")
78                    .at(pos))
79            }
80        }
81
82        let buf = d.input();
83        let available = &buf[pos..];
84        if available.is_empty() {
85            return Err(decode::Error::end_of_input());
86        }
87
88        let head = decode::info::Size::head(available[0])?;
89        if available.len() < head {
90            return Err(decode::Error::end_of_input());
91        }
92
93        // Advance the decoder to the beginning of the array items
94        d.set_position(pos + head);
95
96        match decode::info::Size::tail(&available[..head])? {
97            decode::info::Size::Head => Ok(Self(Some(0))),
98            decode::info::Size::Bytes(_) => Err(decode::Error::type_mismatch(ty)),
99            decode::info::Size::Items(len) => Ok(Self(Some(len))),
100            decode::info::Size::Indef => Ok(Self(None)),
101        }
102    }
103}
104
105#[derive(Debug)]
106struct MapHeader(pub Option<u64>);
107
108impl<'b, C> Decode<'b, C> for MapHeader {
109    fn decode(d: &mut Decoder<'b>, _ctx: &mut C) -> Result<Self, decode::Error> {
110        let pos = d.position();
111        let ty = d.datatype()?;
112        match ty {
113            Type::Map => {}
114            Type::MapIndef => {}
115            ty => {
116                return Err(decode::Error::type_mismatch(ty)
117                    .with_message("expected map")
118                    .at(pos))
119            }
120        }
121
122        let buf = d.input();
123        let available = &buf[pos..];
124        if available.is_empty() {
125            return Err(decode::Error::end_of_input());
126        }
127
128        let head = decode::info::Size::head(available[0])?;
129        if available.len() < head {
130            return Err(decode::Error::end_of_input());
131        }
132
133        // Advance the decoder to the beginning of the array items
134        d.set_position(pos + head);
135
136        match decode::info::Size::tail(&available[..head])? {
137            decode::info::Size::Head => Ok(Self(Some(0))),
138            decode::info::Size::Bytes(_) => Err(decode::Error::type_mismatch(ty)),
139            decode::info::Size::Items(len) => Ok(Self(Some(len))),
140            decode::info::Size::Indef => Ok(Self(None)),
141        }
142    }
143}
144
145impl<'b, R: Read> CborReader<'b, R> {
146    /// Create a new reader
147    ///
148    /// The provided `buf` must be sufficiently large to contain what corresponds
149    /// to one decode item.
150    pub fn new(source: R, buf: &'b mut [u8]) -> Self {
151        Self {
152            source,
153            buf,
154            read: 0,
155            decoded: 0,
156        }
157    }
158
159    /// Read an array using a [`CborArrayReader`].
160    pub async fn array<AR: CborArrayReader<()>>(
161        &mut self,
162        array_reader: &mut AR,
163    ) -> Result<usize, Error>
164    where
165        Self: Sized,
166    {
167        self.array_with(array_reader, &mut ()).await
168    }
169
170    /// Read an array using a [`CborArrayReader`] accepting a user provided decoding context.
171    pub async fn array_with<C, AR: CborArrayReader<C>>(
172        &mut self,
173        array_reader: &mut AR,
174        ctx: &mut C,
175    ) -> Result<usize, Error>
176    where
177        Self: Sized,
178    {
179        let mut count = 0;
180        if let Some(header) = self.read::<ArrayHeader>().await? {
181            let len = header.0;
182            array_reader.read_begin_array(len, ctx)?;
183            if let Some(len) = len {
184                for _ in 0..len {
185                    array_reader.read_array_item(self, ctx).await?;
186                }
187                count = len as usize;
188            } else {
189                while self.peek().await?.ok_or(Error::UnexpectedEof)? != BREAK {
190                    array_reader.read_array_item(self, ctx).await?;
191                    count += 1;
192                }
193            }
194        }
195
196        Ok(count)
197    }
198
199    /// Read a map using a [`CborMapReader`].
200    pub async fn map<MR: CborMapReader<()>>(&mut self, map_reader: &mut MR) -> Result<usize, Error>
201    where
202        Self: Sized,
203    {
204        self.map_with(map_reader, &mut ()).await
205    }
206
207    /// Read a map using a [`CborMapReader`] accepting a user provided decoding context.
208    pub async fn map_with<C, MR: CborMapReader<C>>(
209        &mut self,
210        map_reader: &mut MR,
211        ctx: &mut C,
212    ) -> Result<usize, Error>
213    where
214        Self: Sized,
215    {
216        let mut count = 0;
217        if let Some(header) = self.read::<MapHeader>().await? {
218            let len = header.0;
219            map_reader.read_begin_map(len, ctx)?;
220            if let Some(len) = len {
221                for _ in 0..len {
222                    map_reader.read_map_item(self, ctx).await?;
223                }
224                count = len as usize;
225            } else {
226                while self.peek().await?.ok_or(Error::UnexpectedEof)? != BREAK {
227                    map_reader.read_map_item(self, ctx).await?;
228                    count += 1;
229                }
230            }
231        }
232
233        Ok(count)
234    }
235
236    /// Read the next CBOR value and decode it
237    pub async fn read<T>(&mut self) -> Result<Option<T>, Error>
238    where
239        for<'a> T: Decode<'a, ()>,
240    {
241        self.read_with(&mut ()).await
242    }
243
244    /// Like [`CborReader::read`] but accepting a user provided decoding context.
245    pub async fn read_with<C, T>(&mut self, ctx: &mut C) -> Result<Option<T>, Error>
246    where
247        for<'a> T: Decode<'a, C>,
248    {
249        loop {
250            if self.decoded == 0 && self.read_to_buf().await? == 0 {
251                return Ok(None);
252            }
253
254            // Read an item from the buffer
255            let bytes = &self.buf[self.decoded..self.read];
256
257            #[cfg(feature = "defmt")]
258            defmt::trace!("Decoder item bytes: {:02x}", bytes);
259
260            let mut decoder = Decoder::new(bytes);
261            let decoded: Option<T> = Self::try_decode_with(&mut decoder, ctx)?;
262            if decoded.is_some() {
263                self.decoded += decoder.position();
264                return Ok(decoded);
265            } else if self.decoded == 0 && self.read == self.buf.len() {
266                return Err(Error::BufferTooSmall);
267            }
268
269            // Remove the decoded values from the buffer by moving the
270            // remaining, unused bytes in the buffer to the beginning
271            self.buf.copy_within(self.decoded..self.read, 0);
272            self.read -= self.decoded;
273            self.decoded = 0;
274        }
275    }
276
277    /// Peek the next byte in the buffer
278    async fn peek(&mut self) -> Result<Option<u8>, Error> {
279        if self.decoded == 0 && self.read_to_buf().await? == 0 {
280            return Ok(None);
281        }
282
283        Ok(Some(self.buf[self.decoded]))
284    }
285
286    async fn read_to_buf(&mut self) -> Result<usize, Error> {
287        let len = self.source.read(&mut self.buf[self.read..]).await?;
288        if len == 0 {
289            return if self.read == 0 {
290                Ok(0)
291            } else {
292                Err(Error::UnexpectedEof)
293            };
294        }
295
296        self.read += len;
297        Ok(len)
298    }
299
300    /// Try and decode an item from the decoder.
301    /// Ignore end-of-input error as that, for now, signifies that we need to read more bytes
302    /// from the underlying reader.
303    fn try_decode_with<'a, C, T: Decode<'a, C>>(
304        decoder: &mut Decoder<'a>,
305        ctx: &mut C,
306    ) -> Result<Option<T>, Error> {
307        match decoder.decode_with(ctx) {
308            Ok(decoded) => Ok(Some(decoded)),
309            Err(e) if e.is_end_of_input() => Ok(None),
310            Err(e) => Err(Error::Decode(e)),
311        }
312    }
313}
314
/// Collects a CBOR array into a `Vec`, reserving capacity up front for definite-length arrays.
///
/// Allocation failures surface as [`Error::TryReserveError`] instead of aborting.
#[cfg(all(feature = "alloc", not(feature = "allocator_api")))]
impl<T> CborArrayReader<()> for alloc::vec::Vec<T>
where
    for<'b> T: Decode<'b, ()>,
{
    fn read_begin_array(&mut self, len: Option<u64>, _ctx: &mut ()) -> Result<(), Error> {
        if let Some(len) = len {
            // A length that does not fit in `usize` could never be allocated anyway.
            // NOTE: `len` comes from the input, so a hostile header can request a large
            // reservation; it fails with `TryReserveError` when memory is unavailable.
            let len = usize::try_from(len).map_err(|_| Error::TryReserveError)?;
            self.try_reserve_exact(len)
                .map_err(|_| Error::TryReserveError)?;
        }

        Ok(())
    }

    async fn read_array_item<'b, R: Read>(
        &mut self,
        reader: &mut CborReader<'b, R>,
        _ctx: &mut (),
    ) -> Result<(), Error> {
        // The array promised another element, so a clean end of input here is premature.
        let item = reader.read::<T>().await?.ok_or(Error::UnexpectedEof)?;
        self.try_reserve(1).map_err(|_| Error::TryReserveError)?;
        self.push(item);

        Ok(())
    }
}
342
/// Collects a CBOR array into a `Vec` with a custom allocator, reserving capacity up front for
/// definite-length arrays.
///
/// Allocation failures surface as [`Error::TryReserveError`] instead of aborting.
#[cfg(feature = "allocator_api")]
impl<T, A: core::alloc::Allocator> CborArrayReader<()> for alloc::vec::Vec<T, A>
where
    for<'b> T: Decode<'b, ()>,
{
    fn read_begin_array(&mut self, len: Option<u64>, _ctx: &mut ()) -> Result<(), Error> {
        if let Some(len) = len {
            // A length that does not fit in `usize` could never be allocated anyway.
            // NOTE: `len` comes from the input, so a hostile header can request a large
            // reservation; it fails with `TryReserveError` when memory is unavailable.
            let len = usize::try_from(len).map_err(|_| Error::TryReserveError)?;
            self.try_reserve_exact(len)
                .map_err(|_| Error::TryReserveError)?;
        }

        Ok(())
    }

    async fn read_array_item<'b, R: Read>(
        &mut self,
        reader: &mut CborReader<'b, R>,
        _ctx: &mut (),
    ) -> Result<(), Error> {
        // The array promised another element, so a clean end of input here is premature.
        let item = reader.read::<T>().await?.ok_or(Error::UnexpectedEof)?;
        self.try_reserve(1).map_err(|_| Error::TryReserveError)?;
        self.push(item);

        Ok(())
    }
}
370
371pub struct MapEntryReader<T: for<'b> MapEntryDecode<'b>> {
372    entry_decode: PhantomData<T>,
373}
374
375impl<T: for<'b> MapEntryDecode<'b>> MapEntryReader<T> {
376    pub fn new() -> Self {
377        Self {
378            entry_decode: PhantomData,
379        }
380    }
381}
382
383impl<T: for<'b> MapEntryDecode<'b>> CborMapReader<T> for MapEntryReader<T> {
384    fn read_begin_map(&mut self, _len: Option<u64>, _ctx: &mut T) -> Result<(), Error> {
385        Ok(())
386    }
387
388    async fn read_map_item<'b, R: Read>(
389        &mut self,
390        reader: &mut CborReader<'b, R>,
391        ctx: &mut T,
392    ) -> Result<(), Error> {
393        reader.read_with::<T, Self>(ctx).await?;
394        Ok(())
395    }
396}
397
398impl<'d, T: for<'b> MapEntryDecode<'b>> Decode<'d, T> for MapEntryReader<T> {
399    fn decode(d: &mut Decoder<'d>, ctx: &mut T) -> Result<Self, decode::Error> {
400        T::decode_entry(ctx, d)?;
401        Ok(Self::new())
402    }
403}
404
/// Per-entry map decoding, implemented by the context type used with [`MapEntryReader`].
pub trait MapEntryDecode<'b> {
    /// Decode one map entry from `d`; decoded data can only be kept in `self`.
    ///
    /// If this returns an end-of-input error, `CborReader` reads more bytes and calls it again
    /// for the same entry. Avoid committing partial state to `self` before the whole entry has
    /// been decoded.
    /// NOTE(review): presumably an implementation decodes both the key and the value of one
    /// entry — confirm against implementors.
    fn decode_entry(&mut self, d: &mut Decoder<'b>) -> Result<(), decode::Error>;
}
408
#[cfg(test)]
mod tests {
    #[cfg(feature = "alloc")]
    use embassy_sync::{
        blocking_mutex::raw::CriticalSectionRawMutex,
        pipe::{Pipe, Writer},
    };

    use super::*;

    #[test]
    fn can_decode_small_array_header() {
        // Given
        let cbor: [u8; 5] = [0xf4, 0x83, 0x01, 0x02, 0x03];
        let mut d = Decoder::new(&cbor);
        d.set_position(1); // Skip first byte in buffer

        // When
        let header = ArrayHeader::decode(&mut d, &mut ()).unwrap();

        // Then
        assert_eq!(Some(3), header.0);
        assert_eq!(2, d.position());
    }

    #[test]
    fn can_decode_large_array_header() {
        // Given
        let cbor: [u8; 28] = [
            0xf4, 0x98, 0x18, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
            0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x18,
        ];
        let mut d = Decoder::new(&cbor);
        d.set_position(1); // Skip first byte in buffer

        // When
        let header = ArrayHeader::decode(&mut d, &mut ()).unwrap();

        // Then
        assert_eq!(Some(24), header.0);
        assert_eq!(3, d.position());
    }

    #[tokio::test]
    async fn can_read_small_array_manually() {
        let mut buf = [0; 16];
        let cbor: [u8; 5] = [0xf4, 0x83, 0x01, 0x02, 0x03];
        let mut reader = CborReader::new(cbor.as_slice(), &mut buf);
        // Something before the array
        assert!(!reader.read::<bool>().await.unwrap().unwrap());
        assert_eq!(
            3,
            reader
                .read::<ArrayHeader>()
                .await
                .unwrap()
                .unwrap()
                .0
                .unwrap()
        );

        for expected in 1..=3u8 {
            assert_eq!(expected, reader.read::<u8>().await.unwrap().unwrap());
        }
        assert!(reader.read::<u8>().await.unwrap().is_none());
    }

    #[tokio::test]
    async fn can_read_large_array_manually() {
        let mut buf = [0; 16];
        let cbor: [u8; 28] = [
            0xf4, 0x98, 0x18, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b,
            0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x18,
        ];
        let mut reader = CborReader::new(cbor.as_slice(), &mut buf);
        // Something before the array
        assert!(!reader.read::<bool>().await.unwrap().unwrap());
        assert_eq!(
            24,
            reader
                .read::<ArrayHeader>()
                .await
                .unwrap()
                .unwrap()
                .0
                .unwrap()
        );

        for i in 1..=24 {
            assert_eq!(i, reader.read::<u8>().await.unwrap().unwrap());
        }
        assert!(reader.read::<u8>().await.unwrap().is_none());
    }

    #[cfg(feature = "alloc")]
    #[tokio::test]
    async fn can_read_with_vec() {
        let mut buf = [0; 16];
        let cbor: [u8; 4] = [0x83, 0x01, 0x02, 0x03];
        let mut reader = CborReader::new(cbor.as_slice(), &mut buf);

        let mut vec = Vec::new();
        reader.array(&mut vec).await.unwrap();

        assert_eq!(&[1, 2, 3], vec.as_slice());
    }

    /// Array reader that appends every `u8` element to the `Vec<u8>` context.
    struct TestArrayReader;

    impl CborArrayReader<Vec<u8>> for TestArrayReader {
        fn read_begin_array(&mut self, len: Option<u64>, ctx: &mut Vec<u8>) -> Result<(), Error> {
            if let Some(len) = len {
                ctx.reserve_exact(len as usize);
            }

            Ok(())
        }

        async fn read_array_item<'b, R: Read>(
            &mut self,
            reader: &mut CborReader<'b, R>,
            ctx: &mut Vec<u8>,
        ) -> Result<(), Error> {
            // An element was promised, so running out of input is an error.
            let item = reader.read::<u8>().await?.ok_or(Error::UnexpectedEof)?;
            ctx.push(item);

            Ok(())
        }
    }

    #[tokio::test]
    async fn can_read_fixed_array() {
        let mut buf = [0; 16];
        let cbor: [u8; 4] = [0x83, 0x01, 0x02, 0x03];
        let mut reader = CborReader::new(cbor.as_slice(), &mut buf);

        let mut array_reader = TestArrayReader;
        let mut ctx = Vec::new();
        reader
            .array_with(&mut array_reader, &mut ctx)
            .await
            .unwrap();

        assert_eq!(&[1, 2, 3], ctx.as_slice());
    }

    #[tokio::test]
    async fn can_read_inf_array() {
        let mut buf = [0; 16];
        let cbor: [u8; 5] = [0x9F, 0x01, 0x02, 0x03, 0xFF];
        let mut reader = CborReader::new(cbor.as_slice(), &mut buf);

        let mut array_reader = TestArrayReader;
        let mut ctx = Vec::new();
        reader
            .array_with(&mut array_reader, &mut ctx)
            .await
            .unwrap();

        assert_eq!(&[1, 2, 3], ctx.as_slice());
    }

    #[cfg(feature = "alloc")]
    #[tokio::test]
    async fn can_read_fixed_array_fuzz() {
        for chunk_size in 1..=10 {
            can_read_fixed_array_fuzz_case(chunk_size).await;
        }
    }

    /// Stream a large array through a small pipe in `chunk_size` pieces and check every item.
    #[cfg(feature = "alloc")]
    async fn can_read_fixed_array_fuzz_case(chunk_size: usize) {
        use embedded_io_async::Write;

        // Given
        const ITEM: &str = "wmbus-XXXXXXXXXXXXXXXX";
        const LEN: usize = 950;
        let strings = vec![ITEM; LEN];
        let cbor = minicbor::to_vec(strings.as_slice()).unwrap();

        // The pipe must be `'static` for the spawned writer task. Leaking a fresh pipe per case
        // avoids both `static mut` aliasing and state carried over between cases.
        let pipe: &'static mut Pipe<CriticalSectionRawMutex, 20> =
            Box::leak(Box::new(Pipe::new()));
        let (reader, writer) = pipe.split();

        // When
        let deserialize = body_reader(reader);
        let write = tokio::task::spawn(ingest(writer, cbor, chunk_size));

        let (deserialized, written) = tokio::join!(deserialize, write);
        written.expect("writer task panicked");

        // Then
        assert_eq!(LEN, deserialized.len());
        for item in deserialized {
            assert_eq!(ITEM, &item.0);
        }

        async fn body_reader(reader: impl Read) -> Vec<ArrayItem> {
            let mut cbor_item_buf = [0; 1 + 22]; // text(22) "wmbus-XXXXXXXXXXXXXXXX"
            let mut reader = CborReader::new(reader, &mut cbor_item_buf);
            let mut entries = Vec::new();
            reader.array(&mut entries).await.unwrap();
            entries
        }

        async fn ingest(
            mut writer: Writer<'_, CriticalSectionRawMutex, 20>,
            cbor: Vec<u8>,
            chunk_size: usize,
        ) {
            for chunk in cbor.chunks(chunk_size) {
                writer.write_all(chunk).await.unwrap();
            }
        }

        struct ArrayItem(String);

        impl<'b> Decode<'b, ()> for ArrayItem {
            fn decode(d: &mut Decoder<'b>, _ctx: &mut ()) -> Result<Self, decode::Error> {
                let text = d.str()?;
                assert_eq!(ITEM, text);
                Ok(ArrayItem(text.to_string()))
            }
        }
    }
}