Skip to main content

amaru_minicbor_extra/
lib.rs

1// Copyright 2025 PRAGMA
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use std::cell::RefCell;
16
17pub use decode::*;
18use minicbor::{self as cbor, data::Tag};
19mod decode;
20
21/// The IANA Tag 258: <https://github.com/input-output-hk/cbor-sets-spec/blob/master/CBOR_SETS.md>
22pub static TAG_SET_258: Tag = Tag::new(258);
23
24/// The IANA Tag 259: <https://github.com/shanewholloway/js-cbor-codec/blob/master/docs/CBOR-259-spec--explicit-maps.md>
25pub static TAG_MAP_259: Tag = Tag::new(259);
26
27/// Encode any serialisable value `T` into bytes.
28pub fn to_cbor<T: cbor::Encode<()>>(value: &T) -> Vec<u8> {
29    thread_local! {
30        static BUFFER: RefCell<Vec<u8>> = const { RefCell::new(Vec::new()) };
31    }
32    BUFFER.with_borrow_mut(|buffer| {
33        #[expect(clippy::expect_used)]
34        cbor::encode(value, &mut *buffer).expect("serialization should not fail");
35        let ret = buffer.as_slice().to_vec();
36        buffer.clear();
37        ret
38    })
39}
40
41/// Decode raw bytes into a structured type `T`, assuming no context.
42pub fn from_cbor<T: for<'d> cbor::Decode<'d, ()>>(bytes: &[u8]) -> Option<T> {
43    cbor::decode(bytes).ok()
44}
45
46/// Decode a CBOR input, ensuring that there are no bytes leftovers once decoded. This is handy to
47/// test standalone decoders and ensures that they entirely consume their inputs.
48pub fn from_cbor_no_leftovers<T: for<'d> cbor::Decode<'d, ()>>(bytes: &[u8]) -> Result<T, cbor::decode::Error> {
49    cbor::decode(bytes).map(|NoLeftovers(inner)| inner)
50}
51
52/// Decode a CBOR input, ensuring that there are no bytes leftovers once decoded. This is handy to
53/// test standalone decoders and ensures that they entirely consume their inputs.
54pub fn from_cbor_no_leftovers_with<C, T: for<'d> cbor::Decode<'d, C>>(
55    bytes: &[u8],
56    ctx: &mut C,
57) -> Result<T, cbor::decode::Error> {
58    cbor::decode_with(bytes, ctx).map(|NoLeftovers(inner)| inner)
59}
60
61/// Decode a tagged value, expecting the given tag. For a lenient version, see allow_tag.
62pub fn expect_tag(d: &mut cbor::Decoder<'_>, expected: impl Into<Tag>) -> Result<(), cbor::decode::Error> {
63    let tag: Tag = d.tag()?;
64    let expected: Tag = expected.into();
65
66    if tag != expected {
67        return Err(cbor::decode::Error::message(format!("invalid CBOR tag: got {tag}, expected {expected}")));
68    };
69
70    Ok(())
71}
72
73pub fn allow_tag(d: &mut cbor::Decoder<'_>, expected: Tag) -> Result<(), cbor::decode::Error> {
74    if d.datatype()? == cbor::data::Type::Tag {
75        let tag = d.tag()?;
76        if tag != expected {
77            return Err(cbor::decode::Error::message(format!("invalid CBOR tag: expected {expected} got {tag}")));
78        }
79    }
80
81    Ok(())
82}
83
84#[repr(transparent)]
85struct NoLeftovers<A>(A);
86
87impl<'a, C, A: cbor::Decode<'a, C>> cbor::decode::Decode<'a, C> for NoLeftovers<A> {
88    fn decode(d: &mut cbor::Decoder<'a>, ctx: &mut C) -> Result<Self, cbor::decode::Error> {
89        let inner = d.decode_with(ctx)?;
90
91        if !d.datatype().is_err_and(|e| e.is_end_of_input()) {
92            return Err(cbor::decode::Error::message(format!(
93                "leftovers bytes after decoding after position {}",
94                d.position()
95            )));
96        }
97
98        Ok(NoLeftovers(inner))
99    }
100}
101
#[cfg(test)]
mod tests {
    use foo::Foo;

    use crate::{cbor, from_cbor, from_cbor_no_leftovers, to_cbor};

    /// `from_cbor` ignores input left unread by a decoder, whereas `from_cbor_no_leftovers` must
    /// reject it.
    #[test]
    fn from_cbor_no_leftovers_catches_trailing_breaks() {
        #[derive(Debug, PartialEq, Eq)]
        struct TestCase<A>(A);

        // Deliberately incomplete decoder: it opens the array but never consumes the trailing
        // break character of an indefinite-length encoding.
        impl<'d, C> cbor::decode::Decode<'d, C> for TestCase<Foo> {
            fn decode(d: &mut cbor::Decoder<'d>, ctx: &mut C) -> Result<Self, cbor::decode::Error> {
                let _length = d.array()?;
                let field0 = d.decode_with(ctx)?;
                let field1 = d.decode_with(ctx)?;
                Ok(TestCase(Foo { field0, field1 }))
            }
        }

        let expected = Foo { field0: 14, field1: 42 };
        let encoded = to_cbor(&AsIndefinite(&expected));

        // Lenient decoding succeeds despite the unread break byte...
        assert_eq!(from_cbor::<TestCase<Foo>>(&encoded), Some(TestCase(expected)));
        // ...while the strict variant reports the leftover.
        assert!(from_cbor_no_leftovers::<TestCase<Foo>>(&encoded).is_err());
    }

    /// Adapter selecting an indefinite-length CBOR container (terminated by a break).
    pub(crate) struct AsIndefinite<A>(pub(crate) A);

    /// Adapter selecting a definite-length CBOR container (length-prefixed).
    pub(crate) struct AsDefinite<A>(pub(crate) A);

    /// Adapter selecting a map layout (integer keys `0`, `1`, ...) instead of an array layout.
    pub(crate) struct AsMap<A>(pub(crate) A);

    pub(crate) mod foo {
        use minicbor as cbor;

        use super::{AsDefinite, AsIndefinite, AsMap};

        /// Two-field fixture used to exercise the different container encodings.
        #[derive(Debug, Clone, Copy, PartialEq, Eq)]
        pub(crate) struct Foo {
            pub(crate) field0: u8,
            pub(crate) field1: u8,
        }

        // Encodes as the indefinite-length array `[_ field0, field1]`.
        impl<C> cbor::encode::Encode<C> for AsIndefinite<&Foo> {
            fn encode<W: cbor::encode::Write>(
                &self,
                e: &mut cbor::Encoder<W>,
                ctx: &mut C,
            ) -> Result<(), cbor::encode::Error<W::Error>> {
                let AsIndefinite(foo) = self;
                e.begin_array()?.encode_with(foo.field0, ctx)?.encode_with(foo.field1, ctx)?.end()?;
                Ok(())
            }
        }

        // Encodes as the indefinite-length map `{_ 0: field0, 1: field1}`.
        impl<C> cbor::encode::Encode<C> for AsIndefinite<AsMap<&Foo>> {
            fn encode<W: cbor::encode::Write>(
                &self,
                e: &mut cbor::Encoder<W>,
                ctx: &mut C,
            ) -> Result<(), cbor::encode::Error<W::Error>> {
                let AsIndefinite(AsMap(foo)) = self;
                e.begin_map()?
                    .encode_with(0_u8, ctx)?
                    .encode_with(foo.field0, ctx)?
                    .encode_with(1_u8, ctx)?
                    .encode_with(foo.field1, ctx)?
                    .end()?;
                Ok(())
            }
        }

        // Encodes as the definite-length array `[field0, field1]`.
        impl<C> cbor::encode::Encode<C> for AsDefinite<&Foo> {
            fn encode<W: cbor::encode::Write>(
                &self,
                e: &mut cbor::Encoder<W>,
                ctx: &mut C,
            ) -> Result<(), cbor::encode::Error<W::Error>> {
                let AsDefinite(foo) = self;
                e.array(2)?.encode_with(foo.field0, ctx)?.encode_with(foo.field1, ctx)?;
                Ok(())
            }
        }

        // Encodes as the definite-length map `{0: field0, 1: field1}`.
        impl<C> cbor::encode::Encode<C> for AsDefinite<AsMap<&Foo>> {
            fn encode<W: cbor::encode::Write>(
                &self,
                e: &mut cbor::Encoder<W>,
                ctx: &mut C,
            ) -> Result<(), cbor::encode::Error<W::Error>> {
                let AsDefinite(AsMap(foo)) = self;
                e.map(2)?
                    .encode_with(0_u8, ctx)?
                    .encode_with(foo.field0, ctx)?
                    .encode_with(1_u8, ctx)?
                    .encode_with(foo.field1, ctx)?;
                Ok(())
            }
        }

        // Decodes a two-element array, rejecting anything that is not definite-length 2.
        impl<'d, C> cbor::decode::Decode<'d, C> for AsDefinite<Foo> {
            fn decode(d: &mut cbor::Decoder<'d>, ctx: &mut C) -> Result<Self, cbor::decode::Error> {
                let len = d.array()?;
                let field0 = d.decode_with(ctx)?;
                let field1 = d.decode_with(ctx)?;
                if len != Some(2) {
                    return Err(cbor::decode::Error::message("invalid or missing definite array length"));
                }
                Ok(AsDefinite(Foo { field0, field1 }))
            }
        }

        // Decodes an indefinite-length array of two elements and consumes its closing break.
        impl<'d, C> cbor::decode::Decode<'d, C> for AsIndefinite<Foo> {
            fn decode(d: &mut cbor::Decoder<'d>, ctx: &mut C) -> Result<Self, cbor::decode::Error> {
                let is_indefinite = d.array()?.is_none();
                let field0 = d.decode_with(ctx)?;
                let field1 = d.decode_with(ctx)?;
                // The break is only probed for indefinite arrays (short-circuit on `!is_indefinite`).
                if !is_indefinite || d.datatype()? != cbor::data::Type::Break {
                    return Err(cbor::decode::Error::message("missing indefinite array break"));
                }
                d.skip()?;
                Ok(AsIndefinite(Foo { field0, field1 }))
            }
        }
    }
}