1use crate::{util::at_least, varint, Codec, Error, SizedCodec};
4use bytes::{Buf, BufMut, Bytes};
5use paste::paste;
6
/// Implements [`Codec`] and [`SizedCodec`] for a fixed-width numeric type.
///
/// * `$type` - the numeric type being implemented.
/// * `$get`  - the [`Buf`] accessor that reads one `$type` (e.g. `get_u32`).
/// * `$put`  - the [`BufMut`] accessor that writes one `$type` (e.g. `put_u32`).
///
/// The encoded width is exactly `size_of::<$type>()` bytes.
macro_rules! impl_numeric {
    ($type:ty, $get:ident, $put:ident) => {
        impl SizedCodec for $type {
            const LEN_ENCODED: usize = std::mem::size_of::<$type>();
        }

        impl Codec for $type {
            #[inline]
            fn write(&self, buf: &mut impl BufMut) {
                let value: $type = *self;
                buf.$put(value);
            }

            #[inline]
            fn len_encoded(&self) -> usize {
                <Self as SizedCodec>::LEN_ENCODED
            }

            /// Fails (via `at_least`) if fewer than `LEN_ENCODED` bytes remain.
            #[inline]
            fn read(buf: &mut impl Buf) -> Result<Self, Error> {
                at_least(buf, <Self as SizedCodec>::LEN_ENCODED)?;
                Ok(buf.$get())
            }
        }
    };
}
33
34impl_numeric!(u8, get_u8, put_u8);
35impl_numeric!(u16, get_u16, put_u16);
36impl_numeric!(u32, get_u32, put_u32);
37impl_numeric!(u64, get_u64, put_u64);
38impl_numeric!(u128, get_u128, put_u128);
39impl_numeric!(i8, get_i8, put_i8);
40impl_numeric!(i16, get_i16, put_i16);
41impl_numeric!(i32, get_i32, put_i32);
42impl_numeric!(i64, get_i64, put_i64);
43impl_numeric!(i128, get_i128, put_i128);
44impl_numeric!(f32, get_f32, put_f32);
45impl_numeric!(f64, get_f64, put_f64);
46
47impl Codec for bool {
49 #[inline]
50 fn write(&self, buf: &mut impl BufMut) {
51 buf.put_u8(if *self { 1 } else { 0 });
52 }
53
54 #[inline]
55 fn len_encoded(&self) -> usize {
56 Self::LEN_ENCODED
57 }
58
59 #[inline]
60 fn read(buf: &mut impl Buf) -> Result<Self, Error> {
61 at_least(buf, 1)?;
62 match buf.get_u8() {
63 0 => Ok(false),
64 1 => Ok(true),
65 _ => Err(Error::InvalidBool),
66 }
67 }
68}
69
70impl SizedCodec for bool {
71 const LEN_ENCODED: usize = 1;
72}
73
74impl Codec for Bytes {
76 #[inline]
77 fn write(&self, buf: &mut impl BufMut) {
78 let len = u32::try_from(self.len()).expect("Bytes length exceeds u32");
79 varint::write(len, buf);
80 buf.put_slice(self);
81 }
82
83 #[inline]
84 fn len_encoded(&self) -> usize {
85 let len = u32::try_from(self.len()).expect("Bytes length exceeds u32");
86 varint::size(len) + self.len()
87 }
88
89 #[inline]
90 fn read(buf: &mut impl Buf) -> Result<Self, Error> {
91 let len32 = varint::read::<u32>(buf)?;
92 let len = usize::try_from(len32).map_err(|_| Error::InvalidVarint)?;
93 at_least(buf, len)?;
94 Ok(buf.copy_to_bytes(len))
95 }
96}
97
98impl<const N: usize> Codec for [u8; N] {
100 #[inline]
101 fn write(&self, buf: &mut impl BufMut) {
102 buf.put(&self[..]);
103 }
104
105 #[inline]
106 fn len_encoded(&self) -> usize {
107 N
108 }
109
110 #[inline]
111 fn read(buf: &mut impl Buf) -> Result<Self, Error> {
112 at_least(buf, N)?;
113 let mut dst = [0; N];
114 buf.copy_to_slice(&mut dst);
115 Ok(dst)
116 }
117}
118
119impl<const N: usize> SizedCodec for [u8; N] {
120 const LEN_ENCODED: usize = N;
121}
122
123impl<T: Codec> Codec for Option<T> {
125 #[inline]
126 fn write(&self, buf: &mut impl BufMut) {
127 self.is_some().write(buf);
128 if let Some(inner) = self {
129 inner.write(buf);
130 }
131 }
132
133 #[inline]
134 fn len_encoded(&self) -> usize {
135 match self {
136 Some(inner) => 1 + inner.len_encoded(),
137 None => 1,
138 }
139 }
140
141 #[inline]
142 fn read(buf: &mut impl Buf) -> Result<Self, Error> {
143 if bool::read(buf)? {
144 Ok(Some(T::read(buf)?))
145 } else {
146 Ok(None)
147 }
148 }
149}
150
/// Implements [`Codec`] for the tuple whose field indices are listed.
///
/// Each field is encoded back-to-back in index order with no framing between
/// fields. `impl_codec_for_tuple!(0, 1)` implements it for `(T0, T1)`.
macro_rules! impl_codec_for_tuple {
    ($($index:literal),*) => {
        paste! {
            impl<$( [<T $index>]: Codec ),*> Codec for ( $( [<T $index>], )* ) {
                fn write(&self, buf: &mut impl BufMut) {
                    let ( $( [<field $index>], )* ) = self;
                    $( [<field $index>].write(buf); )*
                }

                fn len_encoded(&self) -> usize {
                    let mut total = 0;
                    $( total += self.$index.len_encoded(); )*
                    total
                }

                fn read(buf: &mut impl Buf) -> Result<Self, Error> {
                    // Fields must be decoded in the same order they were written.
                    $( let [<field $index>] = [<T $index>]::read(buf)?; )*
                    Ok(( $( [<field $index>], )* ))
                }
            }
        }
    };
}
171
172impl_codec_for_tuple!(0);
174impl_codec_for_tuple!(0, 1);
175impl_codec_for_tuple!(0, 1, 2);
176impl_codec_for_tuple!(0, 1, 2, 3);
177impl_codec_for_tuple!(0, 1, 2, 3, 4);
178impl_codec_for_tuple!(0, 1, 2, 3, 4, 5);
179impl_codec_for_tuple!(0, 1, 2, 3, 4, 5, 6);
180impl_codec_for_tuple!(0, 1, 2, 3, 4, 5, 6, 7);
181impl_codec_for_tuple!(0, 1, 2, 3, 4, 5, 6, 7, 8);
182impl_codec_for_tuple!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
183impl_codec_for_tuple!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
184impl_codec_for_tuple!(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11);
185
186impl<T: Codec> Codec for Vec<T> {
188 #[inline]
189 fn write(&self, buf: &mut impl BufMut) {
190 let len = u32::try_from(self.len()).expect("Vec length exceeds u32");
191 varint::write(len, buf);
192 for item in self {
193 item.write(buf);
194 }
195 }
196
197 #[inline]
198 fn len_encoded(&self) -> usize {
199 let len = u32::try_from(self.len()).expect("Vec length exceeds u32");
200 varint::size(len) + self.iter().map(Codec::len_encoded).sum::<usize>()
201 }
202
203 #[inline]
204 fn read(buf: &mut impl Buf) -> Result<Self, Error> {
205 let len32 = varint::read::<u32>(buf)?;
206 let len = usize::try_from(len32).map_err(|_| Error::InvalidVarint)?;
207 let mut vec = Vec::with_capacity(len);
208 for _ in 0..len {
209 vec.push(T::read(buf)?);
210 }
211 Ok(vec)
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::{Codec, SizedCodec};
    use bytes::Bytes;

    /// Generates `test_<type>`: a round-trip test for a fixed-width numeric
    /// type covering both the `encode` and `encode_fixed` paths.
    ///
    /// `$size` must equal `size_of::<$type>()` (asserted in the test); it is
    /// passed separately because it is used as the array length for `encode_fixed`.
    macro_rules! impl_num_test {
        ($type:ty, $size:expr) => {
            paste! {
                #[test]
                fn [<test_ $type>]() {
                    let expected_len = std::mem::size_of::<$type>();
                    let values: [$type; 5] =
                        [0 as $type, 1 as $type, 42 as $type, <$type>::MAX, <$type>::MIN];
                    for value in values.iter() {
                        let encoded = value.encode();
                        assert_eq!(encoded.len(), expected_len);
                        let decoded = <$type>::decode(encoded).unwrap();
                        assert_eq!(*value, decoded);
                        assert_eq!(Codec::len_encoded(value), expected_len);
                        assert_eq!(SizedCodec::len_encoded(value), expected_len);

                        let fixed: [u8; $size] = value.encode_fixed();
                        assert_eq!(fixed.len(), expected_len);
                        let decoded = <$type>::decode(Bytes::copy_from_slice(&fixed)).unwrap();
                        assert_eq!(*value, decoded);
                    }
                }
            }
        };
    }
    impl_num_test!(u8, 1);
    impl_num_test!(u16, 2);
    impl_num_test!(u32, 4);
    impl_num_test!(u64, 8);
    impl_num_test!(u128, 16);
    impl_num_test!(i8, 1);
    impl_num_test!(i16, 2);
    impl_num_test!(i32, 4);
    impl_num_test!(i64, 8);
    impl_num_test!(i128, 16);
    impl_num_test!(f32, 4);
    impl_num_test!(f64, 8);

    #[test]
    fn test_endianness() {
        // Multi-byte integers are written most-significant byte first.
        let encoded = 0x0102u16.encode();
        assert_eq!(encoded, Bytes::from_static(&[0x01, 0x02]));

        let encoded = 0x01020304u32.encode();
        assert_eq!(encoded, Bytes::from_static(&[0x01, 0x02, 0x03, 0x04]));

        // 1.0f32 has the IEEE-754 single-precision bit pattern 0x3F800000.
        let encoded = 1.0f32.encode();
        assert_eq!(encoded, Bytes::from_static(&[0x3F, 0x80, 0x00, 0x00]));
    }

    #[test]
    fn test_bool() {
        let values = [true, false];
        for value in values.iter() {
            let encoded = value.encode();
            assert_eq!(encoded.len(), 1);
            let decoded = bool::decode(encoded).unwrap();
            assert_eq!(*value, decoded);
            assert_eq!(Codec::len_encoded(value), 1);
            assert_eq!(SizedCodec::len_encoded(value), 1);
        }
    }

    #[test]
    fn test_invalid_bool() {
        // Only 0x00 and 0x01 are valid encodings of a bool.
        for byte in [2u8, 0x80, 0xFF] {
            let result = bool::decode(Bytes::copy_from_slice(&[byte]));
            assert!(matches!(result, Err(Error::InvalidBool)));
        }
    }

    #[test]
    fn test_bytes() {
        let values = [
            Bytes::new(),
            Bytes::from_static(&[1, 2, 3]),
            Bytes::from(vec![0; 300]),
        ];
        for value in values {
            let encoded = value.encode();
            assert_eq!(
                encoded.len(),
                varint::size(value.len() as u64) + value.len()
            );
            assert_eq!(Codec::len_encoded(&value), encoded.len());
            let decoded = Bytes::decode(encoded).unwrap();
            assert_eq!(value, decoded);
        }
    }

    #[test]
    fn test_array() {
        let values = [1u8, 2, 3];
        let encoded = values.encode();
        let decoded = <[u8; 3]>::decode(encoded).unwrap();
        assert_eq!(values, decoded);
    }

    #[test]
    fn test_option() {
        let option_values = [Some(42u32), None];
        for value in option_values {
            let encoded = value.encode();
            let decoded = Option::<u32>::decode(encoded).unwrap();
            assert_eq!(value, decoded);
        }
    }

    #[test]
    fn test_option_length() {
        let some = Some(42u32);
        assert_eq!(Codec::len_encoded(&some), 1 + 4);
        assert_eq!(some.encode().len(), 1 + 4);
        let none: Option<u32> = None;
        assert_eq!(Codec::len_encoded(&none), 1);
        assert_eq!(none.encode().len(), 1);
    }

    #[test]
    fn test_option_invalid_tag() {
        // The presence flag is a bool, so anything but 0x00/0x01 is rejected.
        let result = Option::<u32>::decode(Bytes::from_static(&[2, 0, 0, 0, 42]));
        assert!(matches!(result, Err(Error::InvalidBool)));
    }

    #[test]
    fn test_tuple() {
        let tuple_values = [(1u16, None), (1u16, Some(2u32))];
        for value in tuple_values {
            let encoded = value.encode();
            assert_eq!(Codec::len_encoded(&value), encoded.len());
            let decoded = <(u16, Option<u32>)>::decode(encoded).unwrap();
            assert_eq!(value, decoded);
        }
    }

    #[test]
    fn test_vec() {
        let vec_values = [vec![], vec![1u8], vec![1u8, 2u8, 3u8]];
        for value in vec_values {
            let encoded = value.encode();
            assert_eq!(encoded.len(), value.len() * std::mem::size_of::<u8>() + 1);
            assert_eq!(Codec::len_encoded(&value), encoded.len());
            let decoded = Vec::<u8>::decode(encoded).unwrap();
            assert_eq!(value, decoded);
        }
    }

    #[test]
    fn test_truncated_input() {
        // Dropping the final byte of any valid encoding must yield an error, not a panic.
        assert!(bool::decode(Bytes::new()).is_err());

        let mut encoded = 0x0102_0304u32.encode();
        encoded.truncate(encoded.len() - 1);
        assert!(u32::decode(encoded).is_err());

        let mut encoded = [1u8, 2, 3, 4].encode();
        encoded.truncate(encoded.len() - 1);
        assert!(<[u8; 4]>::decode(encoded).is_err());

        let mut encoded = Bytes::from_static(&[1, 2, 3]).encode();
        encoded.truncate(encoded.len() - 1);
        assert!(Bytes::decode(encoded).is_err());

        let mut encoded = vec![1u8, 2, 3].encode();
        encoded.truncate(encoded.len() - 1);
        assert!(Vec::<u8>::decode(encoded).is_err());

        let mut encoded = Some(42u32).encode();
        encoded.truncate(encoded.len() - 1);
        assert!(Option::<u32>::decode(encoded).is_err());
    }
}