commonware_codec/types/
primitives.rs

1//! Implementations of Codec for common types
2
3use crate::{util::at_least, varint, Codec, Error, SizedCodec};
4use bytes::{Buf, BufMut, Bytes};
5use paste::paste;
6
7// Numeric types implementation
8macro_rules! impl_numeric {
9    ($type:ty, $read_method:ident, $write_method:ident) => {
10        impl Codec for $type {
11            #[inline]
12            fn write(&self, buf: &mut impl BufMut) {
13                buf.$write_method(*self);
14            }
15
16            #[inline]
17            fn len_encoded(&self) -> usize {
18                Self::LEN_ENCODED
19            }
20
21            #[inline]
22            fn read(buf: &mut impl Buf) -> Result<Self, Error> {
23                at_least(buf, std::mem::size_of::<$type>())?;
24                Ok(buf.$read_method())
25            }
26        }
27
28        impl SizedCodec for $type {
29            const LEN_ENCODED: usize = std::mem::size_of::<$type>();
30        }
31    };
32}
33
34impl_numeric!(u8, get_u8, put_u8);
35impl_numeric!(u16, get_u16, put_u16);
36impl_numeric!(u32, get_u32, put_u32);
37impl_numeric!(u64, get_u64, put_u64);
38impl_numeric!(u128, get_u128, put_u128);
39impl_numeric!(i8, get_i8, put_i8);
40impl_numeric!(i16, get_i16, put_i16);
41impl_numeric!(i32, get_i32, put_i32);
42impl_numeric!(i64, get_i64, put_i64);
43impl_numeric!(i128, get_i128, put_i128);
44impl_numeric!(f32, get_f32, put_f32);
45impl_numeric!(f64, get_f64, put_f64);
46
47// Bool implementation
48impl Codec for bool {
49    #[inline]
50    fn write(&self, buf: &mut impl BufMut) {
51        buf.put_u8(if *self { 1 } else { 0 });
52    }
53
54    #[inline]
55    fn len_encoded(&self) -> usize {
56        Self::LEN_ENCODED
57    }
58
59    #[inline]
60    fn read(buf: &mut impl Buf) -> Result<Self, Error> {
61        at_least(buf, 1)?;
62        match buf.get_u8() {
63            0 => Ok(false),
64            1 => Ok(true),
65            _ => Err(Error::InvalidBool),
66        }
67    }
68}
69
70impl SizedCodec for bool {
71    const LEN_ENCODED: usize = 1;
72}
73
74// Bytes implementation
75impl Codec for Bytes {
76    #[inline]
77    fn write(&self, buf: &mut impl BufMut) {
78        let len = u32::try_from(self.len()).expect("Bytes length exceeds u32");
79        varint::write(len, buf);
80        buf.put_slice(self);
81    }
82
83    #[inline]
84    fn len_encoded(&self) -> usize {
85        let len = u32::try_from(self.len()).expect("Bytes length exceeds u32");
86        varint::size(len) + self.len()
87    }
88
89    #[inline]
90    fn read(buf: &mut impl Buf) -> Result<Self, Error> {
91        let len32 = varint::read::<u32>(buf)?;
92        let len = usize::try_from(len32).map_err(|_| Error::InvalidVarint)?;
93        at_least(buf, len)?;
94        Ok(buf.copy_to_bytes(len))
95    }
96}
97
98// Constant-size array implementation
99impl<const N: usize> Codec for [u8; N] {
100    #[inline]
101    fn write(&self, buf: &mut impl BufMut) {
102        buf.put(&self[..]);
103    }
104
105    #[inline]
106    fn len_encoded(&self) -> usize {
107        N
108    }
109
110    #[inline]
111    fn read(buf: &mut impl Buf) -> Result<Self, Error> {
112        at_least(buf, N)?;
113        let mut dst = [0; N];
114        buf.copy_to_slice(&mut dst);
115        Ok(dst)
116    }
117}
118
119impl<const N: usize> SizedCodec for [u8; N] {
120    const LEN_ENCODED: usize = N;
121}
122
123// Option implementation
124impl<T: Codec> Codec for Option<T> {
125    #[inline]
126    fn write(&self, buf: &mut impl BufMut) {
127        self.is_some().write(buf);
128        if let Some(inner) = self {
129            inner.write(buf);
130        }
131    }
132
133    #[inline]
134    fn len_encoded(&self) -> usize {
135        match self {
136            Some(inner) => 1 + inner.len_encoded(),
137            None => 1,
138        }
139    }
140
141    #[inline]
142    fn read(buf: &mut impl Buf) -> Result<Self, Error> {
143        if bool::read(buf)? {
144            Ok(Some(T::read(buf)?))
145        } else {
146            Ok(None)
147        }
148    }
149}
150
151// Tuple implementation
152macro_rules! impl_codec_for_tuple {
153    ($($index:literal),*) => {
154        paste! {
155            impl<$( [<T $index>]: Codec ),*> Codec for ( $( [<T $index>], )* ) {
156                fn write(&self, buf: &mut impl BufMut) {
157                    $( self.$index.write(buf); )*
158                }
159
160                fn len_encoded(&self) -> usize {
161                    0 $( + self.$index.len_encoded() )*
162                }
163
164                fn read(buf: &mut impl Buf) -> Result<Self, Error> {
165                    Ok(( $( [<T $index>]::read(buf)?, )* ))
166                }
167            }
168        }
169    };
170}
171
172// Generate implementations for tuple sizes 1 through 12
173impl_codec_for_tuple!(0);
174impl_codec_for_tuple!(0, 1);
175impl_codec_for_tuple!(0, 1, 2);
176impl_codec_for_tuple!(0, 1, 2, 3);
177impl_codec_for_tuple!(0, 1, 2, 3, 4);
178impl_codec_for_tuple!(0, 1, 2, 3, 4, 5);
179impl_codec_for_tuple!(0, 1, 2, 3, 4, 5, 6);
180impl_codec_for_tuple!(0, 1, 2, 3, 4, 5, 6, 7);
181impl_codec_for_tuple!(0, 1, 2, 3, 4, 5, 6, 7, 8);
182impl_codec_for_tuple!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
183impl_codec_for_tuple!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
184impl_codec_for_tuple!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
185
186// Vec implementation
187impl<T: Codec> Codec for Vec<T> {
188    #[inline]
189    fn write(&self, buf: &mut impl BufMut) {
190        let len = u32::try_from(self.len()).expect("Vec length exceeds u32");
191        varint::write(len, buf);
192        for item in self {
193            item.write(buf);
194        }
195    }
196
197    #[inline]
198    fn len_encoded(&self) -> usize {
199        let len = u32::try_from(self.len()).expect("Vec length exceeds u32");
200        varint::size(len) + self.iter().map(Codec::len_encoded).sum::<usize>()
201    }
202
203    #[inline]
204    fn read(buf: &mut impl Buf) -> Result<Self, Error> {
205        let len32 = varint::read::<u32>(buf)?;
206        let len = usize::try_from(len32).map_err(|_| Error::InvalidVarint)?;
207        let mut vec = Vec::with_capacity(len);
208        for _ in 0..len {
209            vec.push(T::read(buf)?);
210        }
211        Ok(vec)
212    }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::{Codec, SizedCodec};
    use bytes::Bytes;

    // Round-trip tests for every fixed-width numeric type (integers and floats):
    // checks the encoded width, both `len_encoded` paths, and `encode_fixed`.
    macro_rules! impl_num_test {
        ($type:ty, $size:expr) => {
            paste! {
                #[test]
                fn [<test_ $type>]() {
                    let expected_len = std::mem::size_of::<$type>();
                    let values: [$type; 5] =
                        [0 as $type, 1 as $type, 42 as $type, <$type>::MAX, <$type>::MIN];
                    for value in values.iter() {
                        let encoded = value.encode();
                        assert_eq!(encoded.len(), expected_len);
                        let decoded = <$type>::decode(encoded).unwrap();
                        assert_eq!(*value, decoded);
                        assert_eq!(Codec::len_encoded(value), expected_len);
                        assert_eq!(SizedCodec::len_encoded(value), expected_len);

                        let fixed: [u8; $size] = value.encode_fixed();
                        assert_eq!(fixed.len(), expected_len);
                        let decoded = <$type>::decode(Bytes::copy_from_slice(&fixed)).unwrap();
                        assert_eq!(*value, decoded);
                    }
                }
            }
        };
    }
    impl_num_test!(u8, 1);
    impl_num_test!(u16, 2);
    impl_num_test!(u32, 4);
    impl_num_test!(u64, 8);
    impl_num_test!(u128, 16);
    impl_num_test!(i8, 1);
    impl_num_test!(i16, 2);
    impl_num_test!(i32, 4);
    impl_num_test!(i64, 8);
    impl_num_test!(i128, 16);
    impl_num_test!(f32, 4);
    impl_num_test!(f64, 8);

    #[test]
    fn test_endianness() {
        // u16
        let encoded = 0x0102u16.encode();
        assert_eq!(encoded, Bytes::from_static(&[0x01, 0x02]));

        // u32
        let encoded = 0x01020304u32.encode();
        assert_eq!(encoded, Bytes::from_static(&[0x01, 0x02, 0x03, 0x04]));

        // f32
        let encoded = 1.0f32.encode();
        assert_eq!(encoded, Bytes::from_static(&[0x3F, 0x80, 0x00, 0x00])); // Big-endian IEEE 754
    }

    #[test]
    fn test_bool() {
        let values = [true, false];
        for value in values.iter() {
            let encoded = value.encode();
            assert_eq!(encoded.len(), 1);
            let decoded = bool::decode(encoded).unwrap();
            assert_eq!(*value, decoded);
            assert_eq!(Codec::len_encoded(value), 1);
            assert_eq!(SizedCodec::len_encoded(value), 1);
        }
    }

    #[test]
    fn test_bool_invalid() {
        // Only 0 and 1 are valid boolean encodings; every other byte is rejected.
        for byte in [2u8, 0x7F, 0xFF] {
            let mut buf = Bytes::copy_from_slice(&[byte]);
            assert!(matches!(
                <bool as Codec>::read(&mut buf),
                Err(Error::InvalidBool)
            ));
        }
    }

    #[test]
    fn test_bytes() {
        let values = [
            Bytes::new(),
            Bytes::from_static(&[1, 2, 3]),
            Bytes::from(vec![0; 300]),
        ];
        for value in values {
            let encoded = value.encode();
            assert_eq!(
                encoded.len(),
                varint::size(value.len() as u64) + value.len()
            );
            let decoded = Bytes::decode(encoded).unwrap();
            assert_eq!(value, decoded);
        }
    }

    #[test]
    fn test_array() {
        let values = [1u8, 2, 3];
        let encoded = values.encode();
        let decoded = <[u8; 3]>::decode(encoded).unwrap();
        assert_eq!(values, decoded);
    }

    #[test]
    fn test_option() {
        let option_values = [Some(42u32), None];
        for value in option_values {
            let encoded = value.encode();
            let decoded = Option::<u32>::decode(encoded).unwrap();
            assert_eq!(value, decoded);
        }
    }

    #[test]
    fn test_option_invalid_tag() {
        // The presence tag is a `bool`, so a tag byte other than 0 or 1 is rejected
        // even when a well-formed inner value follows it.
        let mut buf = Bytes::from_static(&[2, 0, 0, 0, 42]);
        assert!(matches!(
            <Option<u32> as Codec>::read(&mut buf),
            Err(Error::InvalidBool)
        ));
    }

    #[test]
    fn test_option_length() {
        let some = Some(42u32);
        assert_eq!(Codec::len_encoded(&some), 1 + 4);
        assert_eq!(some.encode().len(), 1 + 4);
        let none: Option<u32> = None;
        assert_eq!(Codec::len_encoded(&none), 1);
        assert_eq!(none.encode().len(), 1);
    }

    #[test]
    fn test_tuple() {
        let tuple_values = [(1u16, None), (1u16, Some(2u32))];
        for value in tuple_values {
            let encoded = value.encode();
            let decoded = <(u16, Option<u32>)>::decode(encoded).unwrap();
            assert_eq!(value, decoded);
        }
    }

    #[test]
    fn test_vec() {
        let vec_values = [vec![], vec![1u8], vec![1u8, 2u8, 3u8]];
        for value in vec_values {
            let encoded = value.encode();
            assert_eq!(encoded.len(), value.len() * std::mem::size_of::<u8>() + 1);
            let decoded = Vec::<u8>::decode(encoded).unwrap();
            assert_eq!(value, decoded);
        }
    }

    #[test]
    fn test_truncated_input() {
        // Drops the final byte of a valid encoding. The helper accepts anything
        // byte-like so it does not depend on the concrete type `encode` returns.
        fn truncated(encoded: impl AsRef<[u8]>) -> Bytes {
            let bytes = encoded.as_ref();
            Bytes::copy_from_slice(&bytes[..bytes.len() - 1])
        }

        // Every codec must report an error, not panic, when its input is cut short.
        assert!(<u32 as Codec>::read(&mut truncated(0x01020304u32.encode())).is_err());
        assert!(<bool as Codec>::read(&mut truncated(true.encode())).is_err());
        assert!(<[u8; 3] as Codec>::read(&mut truncated([1u8, 2, 3].encode())).is_err());
        assert!(
            <Bytes as Codec>::read(&mut truncated(Bytes::from_static(&[1, 2, 3]).encode()))
                .is_err()
        );
        assert!(<Option<u32> as Codec>::read(&mut truncated(Some(42u32).encode())).is_err());
        assert!(<(u16, u32) as Codec>::read(&mut truncated((1u16, 2u32).encode())).is_err());
        assert!(<Vec<u16> as Codec>::read(&mut truncated(vec![1u16, 2, 3].encode())).is_err());
    }
}