Skip to main content

wayk_proto/
container.rs

/// Generates a length-prefixed container newtype together with its conversions and its
/// `Encode`/`Decode` implementations.
///
/// Two forms are accepted:
///
/// * `impl_container! { Name as Vec with SizeTy }` generates an owned `Name<Item>(Vec<Item>)`.
///   It is encoded as the element count (written as a `SizeTy`) followed by every item's own
///   encoding.
/// * `impl_container! { Name as &[u8] with SizeTy }` generates a borrowed `Name<'a>(&'a [u8])`.
///   It is encoded as the byte count (written as a `SizeTy`) followed by the raw bytes; decoding
///   borrows the bytes from the input buffer instead of copying them.
///
/// `SizeTy` is a primitive integer type (this file instantiates it with `u8`, `u16`, `u32` and
/// `u64`) implementing the crate's `Encode` and `Decode` traits. Encoding fails when the length
/// does not fit in `SizeTy`.
macro_rules! impl_container {
    ($ty:ident as Vec with $size_ty:ident) => {
        /// Owned list encoded as an element-count prefix followed by each item's encoding.
        #[derive(PartialEq, Debug, Clone)]
        pub struct $ty<Item>(pub Vec<Item>);

        impl<Item> core::ops::Deref for $ty<Item> {
            type Target = Vec<Item>;

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl<Item> core::ops::DerefMut for $ty<Item> {
            fn deref_mut(&mut self) -> &mut Self::Target {
                &mut self.0
            }
        }

        impl<Item> core::iter::IntoIterator for $ty<Item> {
            type Item = Item;
            type IntoIter = alloc::vec::IntoIter<Self::Item>;

            fn into_iter(self) -> Self::IntoIter {
                self.0.into_iter()
            }
        }

        impl<'a, Item> core::iter::IntoIterator for &'a $ty<Item> {
            type Item = &'a Item;
            type IntoIter = alloc::slice::Iter<'a, Item>;

            fn into_iter(self) -> Self::IntoIter {
                self.0.iter()
            }
        }

        impl<'a, Item> core::iter::IntoIterator for &'a mut $ty<Item> {
            type Item = &'a mut Item;
            type IntoIter = alloc::slice::IterMut<'a, Item>;

            fn into_iter(self) -> Self::IntoIter {
                self.0.iter_mut()
            }
        }

        impl<Item> From<Vec<Item>> for $ty<Item> {
            fn from(v: Vec<Item>) -> Self {
                Self(v)
            }
        }

        // Implemented as `From` (not `Into`): the standard blanket impl still provides
        // `Into<Vec<Item>>`, and `Vec::from(container)` now works too.
        impl<Item> From<$ty<Item>> for Vec<Item> {
            fn from(container: $ty<Item>) -> Self {
                container.0
            }
        }

        impl<Item> PartialEq<Vec<Item>> for $ty<Item>
        where
            Item: PartialEq,
        {
            fn eq(&self, other: &Vec<Item>) -> bool {
                self.0.eq(other)
            }
        }

        impl<Item> crate::serialization::Encode for $ty<Item>
        where
            Item: crate::serialization::Encode + core::fmt::Debug,
        {
            /// Size of the count prefix plus the encoded size of every item.
            fn encoded_len(&self) -> usize {
                self.iter().fold(core::mem::size_of::<$size_ty>(), |acc, item| {
                    acc + item.encoded_len()
                })
            }

            /// Writes the element count as a `$size_ty`, then each item in order.
            ///
            /// # Errors
            /// Fails if the element count does not fit in `$size_ty`, or if writing the count or
            /// any item fails.
            fn encode_into<W: std::io::Write>(
                &self,
                writer: &mut W,
            ) -> core::result::Result<(), $crate::error::ProtoError> {
                use crate::error::*;
                use core::convert::TryFrom;

                // Name the actual size type in the message instead of always claiming `u8`.
                let count = <$size_ty>::try_from(self.len())
                    .map_err(crate::error::ProtoError::from)
                    .chain($crate::error::ProtoErrorKind::Encoding(stringify!($ty)))
                    .or_desc(concat!(
                        "couldn't convert losslessly vec size into ",
                        stringify!($size_ty),
                        " (count)"
                    ))?;
                count.encode_into(writer)?;
                for item in self {
                    item.encode_into(writer)
                        .chain($crate::error::ProtoErrorKind::Encoding(stringify!($ty)))
                        .or_else_desc(|| format!("couldn't encode item {:?}", item))?;
                }
                Ok(())
            }
        }

        impl<'dec, Item> crate::serialization::Decode<'dec> for $ty<Item>
        where
            Item: crate::serialization::Decode<'dec>,
        {
            /// Reads a `$size_ty` element count, then decodes that many items from `cursor`.
            ///
            /// # Errors
            /// Fails if the count or any item cannot be decoded; the error names the failing
            /// item index.
            fn decode_from(cursor: &mut std::io::Cursor<&'dec [u8]>) -> Result<Self, $crate::error::ProtoError> {
                use crate::error::*;

                let count = <$size_ty>::decode_from(cursor)
                    .chain($crate::error::ProtoErrorKind::Decoding(stringify!($ty)))
                    .or_desc("couldn't decode list count")?;
                // `count` is read from the input, so it is deliberately not used to preallocate;
                // collecting through `Result` stops at the first item that fails to decode.
                let items = (0..count)
                    .map(|i| {
                        Item::decode_from(cursor)
                            .chain($crate::error::ProtoErrorKind::Decoding(stringify!($ty)))
                            .or_else_desc(|| format!("couldn't decode item n°{}", i))
                    })
                    .collect::<core::result::Result<Vec<Item>, _>>()?;
                Ok(Self(items))
            }
        }
    };
    ($ty:ident as &[u8] with $size_ty:ident) => {
        /// Borrowed byte string encoded as a byte-count prefix followed by the raw bytes.
        #[derive(PartialEq, Debug, Clone)]
        pub struct $ty<'a>(pub &'a [u8]);

        impl<'a> core::ops::Deref for $ty<'a> {
            type Target = &'a [u8];

            fn deref(&self) -> &Self::Target {
                &self.0
            }
        }

        impl<'a> core::iter::IntoIterator for &'a $ty<'a> {
            type Item = &'a u8;
            type IntoIter = alloc::slice::Iter<'a, u8>;

            fn into_iter(self) -> Self::IntoIter {
                self.0.iter()
            }
        }

        impl<'a> From<&'a [u8]> for $ty<'a> {
            fn from(v: &'a [u8]) -> Self {
                Self(v)
            }
        }

        // Implemented as `From` (not `Into`): the standard blanket impl still provides
        // `Into<&'a [u8]>`.
        impl<'a> From<$ty<'a>> for &'a [u8] {
            fn from(container: $ty<'a>) -> Self {
                container.0
            }
        }

        impl<'a> PartialEq<&'a [u8]> for $ty<'a> {
            fn eq(&self, other: &&'a [u8]) -> bool {
                self.0.eq(*other)
            }
        }

        impl crate::serialization::Encode for $ty<'_> {
            /// Size of the count prefix plus one byte per payload byte.
            fn encoded_len(&self) -> usize {
                core::mem::size_of::<$size_ty>() + core::mem::size_of::<u8>() * self.len()
            }

            /// Writes the byte count as a `$size_ty`, then the raw bytes.
            ///
            /// # Errors
            /// Fails if the length does not fit in `$size_ty` or if writing fails.
            fn encode_into<W: std::io::Write>(&self, writer: &mut W) -> Result<(), $crate::error::ProtoError> {
                use crate::error::*;
                use core::convert::TryFrom;

                let count = <$size_ty>::try_from(self.len())
                    .map_err(ProtoError::from)
                    .chain($crate::error::ProtoErrorKind::Encoding(stringify!($ty)))
                    .or_else_desc(|| {
                        format!(
                            "couldn't convert losslessly slice size into {} (count)",
                            stringify!($size_ty)
                        )
                    })?;
                count.encode_into(writer)?;
                writer.write_all(self.0)?;
                Ok(())
            }
        }

        impl<'dec: 'a, 'a> crate::serialization::Decode<'dec> for $ty<'a> {
            /// Reads a `$size_ty` byte count, then borrows that many bytes from the cursor's
            /// buffer (no copy) and advances the cursor past them.
            ///
            /// # Errors
            /// Fails if the count cannot be decoded or exceeds the bytes left after it.
            fn decode_from(cursor: &mut std::io::Cursor<&'dec [u8]>) -> Result<Self, $crate::error::ProtoError> {
                use crate::error::*;
                use core::convert::TryFrom;

                let count = <$size_ty>::decode_from(cursor)
                    .chain($crate::error::ProtoErrorKind::Decoding(stringify!($ty)))
                    .or_desc("couldn't decode list count")?;

                // Copy the `&'dec [u8]` out of the cursor so the result borrows the input buffer
                // rather than the cursor. A position past the end leaves nothing to read.
                let buffer: &'dec [u8] = *cursor.get_ref();
                let remaining: &'dec [u8] = usize::try_from(cursor.position())
                    .ok()
                    .and_then(|start| buffer.get(start..))
                    .unwrap_or(&[]);

                // Checked conversion: `count as usize` could truncate a wide count on narrow
                // targets and then accept the wrong length.
                let len = match usize::try_from(count) {
                    Ok(len) if len <= remaining.len() => len,
                    _ => {
                        return ProtoError::new(ProtoErrorKind::Decoding(stringify!($ty))).or_else_desc(|| {
                            format!(
                                "couldn't decode list: count ({}) greater than available bytes ({})",
                                count,
                                remaining.len()
                            )
                        });
                    }
                };

                // Consume the payload so that whatever is decoded next starts after it.
                let next_position = cursor.position() + len as u64;
                cursor.set_position(next_position);
                Ok(Self(&remaining[..len]))
            }
        }
    };
}
208
209impl_container! { Vec8  as Vec with u8  }
210impl_container! { Vec16 as Vec with u16 }
211impl_container! { Vec32 as Vec with u32 }
212impl_container! { Vec64 as Vec with u64 }
213
214impl_container! { Bytes8  as &[u8] with u8  }
215impl_container! { Bytes16 as &[u8] with u16 }
216impl_container! { Bytes32 as &[u8] with u32 }
217impl_container! { Bytes64 as &[u8] with u64 }
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::serialization::{Decode, Encode};

    // Little-endian u16 items behind a one-byte count.
    const U16_VEC8: [u8; 7] = [0x03, 0x50, 0x10, 0x0a, 0x09, 0x57, 0x0b];

    #[test]
    fn encode_vec8() {
        let vec = Vec8(vec![0x1050u16, 0x090au16, 0x0b57u16]);
        assert_eq!(vec.encode().unwrap(), &U16_VEC8);
    }

    #[test]
    fn decode_vec8() {
        assert_eq!(
            Vec8::<u16>::decode(&U16_VEC8).unwrap(),
            vec![0x1050u16, 0x090au16, 0x0b57u16]
        );
    }

    // Same items behind a two-byte count.
    const U16_VEC16: [u8; 6] = [0x02, 0x00, 0x50, 0x10, 0x0a, 0x09];

    #[test]
    fn encode_vec16() {
        let vec = Vec16(vec![0x1050u16, 0x090au16]);
        assert_eq!(vec.encoded_len(), U16_VEC16.len());
        assert_eq!(vec.encode().unwrap(), &U16_VEC16);
    }

    #[test]
    fn decode_vec16() {
        assert_eq!(
            Vec16::<u16>::decode(&U16_VEC16).unwrap(),
            vec![0x1050u16, 0x090au16]
        );
    }

    const U16_VEC32: [u8; 10] = [0x03, 0x00, 0x00, 0x00, 0x50, 0x10, 0x0a, 0x09, 0x57, 0x0b];

    #[test]
    fn encode_vec32() {
        let vec = Vec32(vec![0x1050u16, 0x090au16, 0x0b57u16]);
        assert_eq!(vec.encode().unwrap(), &U16_VEC32);
    }

    #[test]
    fn decode_vec32() {
        assert_eq!(
            Vec32::<u16>::decode(&U16_VEC32).unwrap(),
            vec![0x1050u16, 0x090au16, 0x0b57u16]
        );
    }

    // A single item behind an eight-byte count.
    const U16_VEC64: [u8; 10] = [0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x57, 0x0b];

    #[test]
    fn encode_vec64() {
        let vec = Vec64(vec![0x0b57u16]);
        assert_eq!(vec.encoded_len(), U16_VEC64.len());
        assert_eq!(vec.encode().unwrap(), &U16_VEC64);
    }

    #[test]
    fn decode_vec64() {
        assert_eq!(Vec64::<u16>::decode(&U16_VEC64).unwrap(), vec![0x0b57u16]);
    }

    const ENCODED_MSG_WITH_BYTES8: [u8; 13] = [
        0x38, 0xae, 0xf3, // things
        0x06, // count
        0x50, 0x10, 0x0a, 0x09, 0x57, 0x0b, // elements
        0xc3, 0xaf, 0x13, // other things
    ];

    #[test]
    fn encode_bytes8() {
        let slice = Bytes8(&ENCODED_MSG_WITH_BYTES8[4..=9]);
        assert_eq!(slice.encoded_len(), 7);
        assert_eq!(slice.encode().unwrap(), &ENCODED_MSG_WITH_BYTES8[3..=9]);
    }

    #[test]
    fn decode_bytes8() {
        assert_eq!(
            Bytes8::decode(&ENCODED_MSG_WITH_BYTES8[3..]).unwrap(),
            &ENCODED_MSG_WITH_BYTES8[4..=9]
        );
    }

    #[test]
    fn decode_bytes8_rejects_count_larger_than_input() {
        // The prefix announces 6 bytes but only 2 follow it.
        assert!(Bytes8::decode(&ENCODED_MSG_WITH_BYTES8[3..6]).is_err());
    }

    const ENCODED_BYTES16: [u8; 5] = [
        0x02, 0x00, // count
        0xaa, 0xbb, // elements
        0xcc, // other things
    ];

    #[test]
    fn encode_bytes16() {
        let slice = Bytes16(&ENCODED_BYTES16[2..4]);
        assert_eq!(slice.encoded_len(), 4);
        assert_eq!(slice.encode().unwrap(), &ENCODED_BYTES16[..4]);
    }

    #[test]
    fn decode_bytes16() {
        assert_eq!(
            Bytes16::decode(&ENCODED_BYTES16).unwrap(),
            &ENCODED_BYTES16[2..4]
        );
    }

    const ENCODED_MSG_WITH_BYTES32: [u8; 16] = [
        0x38, 0xae, 0xf3, // things
        0x06, 0x00, 0x00, 0x00, // count
        0x50, 0x10, 0x0a, 0x09, 0x57, 0x0b, // elements
        0xc3, 0xaf, 0x13, // other things
    ];

    #[test]
    fn encode_bytes32() {
        let slice = Bytes32(&ENCODED_MSG_WITH_BYTES32[7..=12]);
        assert_eq!(slice.encode().unwrap(), &ENCODED_MSG_WITH_BYTES32[3..=12]);
    }

    #[test]
    fn decode_bytes32() {
        assert_eq!(
            Bytes32::decode(&ENCODED_MSG_WITH_BYTES32[3..]).unwrap(),
            &ENCODED_MSG_WITH_BYTES32[7..=12]
        );
    }
}