commonware_codec/types/
primitives.rs

1//! Implementations of Codec for primitive types.
2
3use crate::{util::at_least, Config, EncodeSize, Error, FixedSize, Read, ReadExt, Write};
4use bytes::{Buf, BufMut};
5
6// Numeric types implementation
7macro_rules! impl_numeric {
8    ($type:ty, $read_method:ident, $write_method:ident) => {
9        impl Write for $type {
10            #[inline]
11            fn write(&self, buf: &mut impl BufMut) {
12                buf.$write_method(*self);
13            }
14        }
15
16        impl Read for $type {
17            #[inline]
18            fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
19                at_least(buf, std::mem::size_of::<$type>())?;
20                Ok(buf.$read_method())
21            }
22        }
23
24        impl FixedSize for $type {
25            const SIZE: usize = std::mem::size_of::<$type>();
26        }
27    };
28}
29
30impl_numeric!(u8, get_u8, put_u8);
31impl_numeric!(u16, get_u16, put_u16);
32impl_numeric!(u32, get_u32, put_u32);
33impl_numeric!(u64, get_u64, put_u64);
34impl_numeric!(u128, get_u128, put_u128);
35impl_numeric!(i8, get_i8, put_i8);
36impl_numeric!(i16, get_i16, put_i16);
37impl_numeric!(i32, get_i32, put_i32);
38impl_numeric!(i64, get_i64, put_i64);
39impl_numeric!(i128, get_i128, put_i128);
40impl_numeric!(f32, get_f32, put_f32);
41impl_numeric!(f64, get_f64, put_f64);
42
43// Bool implementation
44impl Write for bool {
45    #[inline]
46    fn write(&self, buf: &mut impl BufMut) {
47        buf.put_u8(if *self { 1 } else { 0 });
48    }
49}
50
51impl Read for bool {
52    #[inline]
53    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
54        match u8::read(buf)? {
55            0 => Ok(false),
56            1 => Ok(true),
57            _ => Err(Error::InvalidBool),
58        }
59    }
60}
61
62impl FixedSize for bool {
63    const SIZE: usize = 1;
64}
65
66// Constant-size array implementation
67impl<const N: usize> Write for [u8; N] {
68    #[inline]
69    fn write(&self, buf: &mut impl BufMut) {
70        buf.put(&self[..]);
71    }
72}
73
74impl<const N: usize> Read for [u8; N] {
75    #[inline]
76    fn read_cfg(buf: &mut impl Buf, _: &()) -> Result<Self, Error> {
77        at_least(buf, N)?;
78        let mut dst = [0; N];
79        buf.copy_to_slice(&mut dst);
80        Ok(dst)
81    }
82}
83
84impl<const N: usize> FixedSize for [u8; N] {
85    const SIZE: usize = N;
86}
87
88// Option implementation
89impl<T: Write> Write for Option<T> {
90    #[inline]
91    fn write(&self, buf: &mut impl BufMut) {
92        self.is_some().write(buf);
93        if let Some(inner) = self {
94            inner.write(buf);
95        }
96    }
97}
98
99impl<T: EncodeSize> EncodeSize for Option<T> {
100    #[inline]
101    fn encode_size(&self) -> usize {
102        match self {
103            Some(inner) => 1 + inner.encode_size(),
104            None => 1,
105        }
106    }
107}
108
109impl<Cfg: Config, T: Read<Cfg>> Read<Cfg> for Option<T> {
110    #[inline]
111    fn read_cfg(buf: &mut impl Buf, cfg: &Cfg) -> Result<Self, Error> {
112        if bool::read(buf)? {
113            Ok(Some(T::read_cfg(buf, cfg)?))
114        } else {
115            Ok(None)
116        }
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DecodeExt, Encode, EncodeFixed};
    use bytes::Bytes;
    use paste::paste;

    // Generates a round-trip test for one numeric type (integers and floats alike).
    //
    // `$size` is the expected encoded width of `$type` in bytes; it is needed as a
    // literal so it can be used as the array length for `encode_fixed`.
    macro_rules! impl_num_test {
        ($type:ty, $size:expr) => {
            paste! {
                #[test]
                fn [<test_ $type>]() {
                    let expected_len = std::mem::size_of::<$type>();
                    let values: [$type; 5] =
                        [0 as $type, 1 as $type, 42 as $type, <$type>::MAX, <$type>::MIN];
                    for value in values.iter() {
                        // Variable-size API round trip.
                        let encoded = value.encode();
                        assert_eq!(encoded.len(), expected_len);
                        let decoded = <$type>::decode(encoded).unwrap();
                        assert_eq!(*value, decoded);
                        assert_eq!(value.encode_size(), expected_len);

                        // Fixed-size API round trip.
                        let fixed: [u8; $size] = value.encode_fixed();
                        assert_eq!(fixed.len(), expected_len);
                        let decoded = <$type>::decode(Bytes::copy_from_slice(&fixed)).unwrap();
                        assert_eq!(*value, decoded);
                    }
                }
            }
        };
    }
    impl_num_test!(u8, 1);
    impl_num_test!(u16, 2);
    impl_num_test!(u32, 4);
    impl_num_test!(u64, 8);
    impl_num_test!(u128, 16);
    impl_num_test!(i8, 1);
    impl_num_test!(i16, 2);
    impl_num_test!(i32, 4);
    impl_num_test!(i64, 8);
    impl_num_test!(i128, 16);
    impl_num_test!(f32, 4);
    impl_num_test!(f64, 8);

    #[test]
    fn test_endianness() {
        // u16
        let encoded = 0x0102u16.encode();
        assert_eq!(encoded, Bytes::from_static(&[0x01, 0x02]));

        // u32
        let encoded = 0x01020304u32.encode();
        assert_eq!(encoded, Bytes::from_static(&[0x01, 0x02, 0x03, 0x04]));

        // f32
        let encoded = 1.0f32.encode();
        assert_eq!(encoded, Bytes::from_static(&[0x3F, 0x80, 0x00, 0x00])); // Big-endian IEEE 754
    }

    #[test]
    fn test_bool() {
        let values = [true, false];
        for value in values.iter() {
            let encoded = value.encode();
            assert_eq!(encoded.len(), 1);
            let decoded = bool::decode(encoded).unwrap();
            assert_eq!(*value, decoded);
            assert_eq!(value.encode_size(), 1);
        }
    }

    #[test]
    fn test_bool_wire_format() {
        // `true` is written as 1 and `false` as 0.
        assert_eq!(true.encode(), Bytes::from_static(&[1]));
        assert_eq!(false.encode(), Bytes::from_static(&[0]));
    }

    #[test]
    fn test_bool_invalid() {
        // Any byte other than 0 or 1 must be rejected rather than coerced.
        for byte in [2u8, 0x7F, 0xFF] {
            let result = bool::decode(Bytes::copy_from_slice(&[byte]));
            assert!(matches!(result, Err(Error::InvalidBool)));
        }
    }

    #[test]
    fn test_array() {
        let values = [1u8, 2, 3];
        let encoded = values.encode();
        // The array is written verbatim, without a length prefix.
        assert_eq!(encoded, Bytes::from_static(&[1, 2, 3]));
        let decoded = <[u8; 3]>::decode(encoded).unwrap();
        assert_eq!(values, decoded);
    }

    #[test]
    fn test_option() {
        let option_values = [Some(42u32), None];
        for value in option_values {
            let encoded = value.encode();
            let decoded = Option::<u32>::decode(encoded).unwrap();
            assert_eq!(value, decoded);
        }
    }

    #[test]
    fn test_option_length() {
        let some = Some(42u32);
        assert_eq!(some.encode_size(), 1 + 4);
        assert_eq!(some.encode().len(), 1 + 4);
        let none: Option<u32> = None;
        assert_eq!(none.encode_size(), 1);
        assert_eq!(none.encode().len(), 1);
    }

    #[test]
    fn test_option_wire_format() {
        // Presence flag first, then the big-endian inner value.
        assert_eq!(
            Some(42u32).encode(),
            Bytes::from_static(&[1, 0, 0, 0, 42])
        );
        let none: Option<u32> = None;
        assert_eq!(none.encode(), Bytes::from_static(&[0]));
    }

    #[test]
    fn test_option_invalid_flag() {
        // The presence flag is decoded as a `bool`, so it obeys the same validation.
        let result = Option::<u32>::decode(Bytes::from_static(&[2, 0, 0, 0, 0]));
        assert!(matches!(result, Err(Error::InvalidBool)));
    }

    #[test]
    fn test_insufficient_bytes() {
        // Truncated input must surface as an error, never as a panic.
        assert!(u8::decode(Bytes::new()).is_err());
        assert!(u32::decode(Bytes::from_static(&[0x01, 0x02])).is_err());
        assert!(bool::decode(Bytes::new()).is_err());
        assert!(<[u8; 3]>::decode(Bytes::from_static(&[1, 2])).is_err());
        // Flag says "present" but the inner u32 is cut short.
        assert!(Option::<u32>::decode(Bytes::from_static(&[1, 0, 0])).is_err());
    }
}