Skip to main content

s2n_codec/
zerocopy.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use core::{cmp::Ordering, fmt, hash::Hash};
5pub use zerocopy::*;
6
7#[cfg(feature = "generator")]
8use bolero_generator::prelude::*;
9
/// Define a codec implementation for a zerocopy value that implements
/// `FromBytes`, `IntoBytes`, and `Unaligned`.
///
/// Generates `DecoderValue`, `DecoderValueMut`, and `EncoderValue` impls for
/// `$name`, `&$name`, and `&mut $name`:
///
/// * Decoding takes `size_of::<$name>()` bytes from the buffer via
///   `decode_slice` and reinterprets them in place as a `$name`. The by-value
///   impls copy out of that reference, so `$name` must also be `Copy`.
/// * Encoding writes the value's in-memory bytes to the encoder unchanged.
///
/// The decoder impls require `$name: Unaligned` in addition to `FromBytes`.
/// The byte buffer makes no alignment promise, so turning its pointer into a
/// `&$name` is only sound when `$name` has an alignment of 1. The mutable
/// decoders also require `IntoBytes`, so that writes through the returned
/// `&mut $name` never leave uninitialized padding bytes in the buffer.
/// Invoking this macro on a type that lacks these bounds is a compile error.
///
/// NOTE(review): `Immutable` is not required. Types with interior mutability
/// (`UnsafeCell`) should not be passed to this macro.
#[macro_export]
macro_rules! zerocopy_value_codec {
    ($name:ident) => {
        impl<'a> $crate::DecoderValue<'a> for $name
        where
            $name: $crate::zerocopy::FromBytes + $crate::zerocopy::Unaligned,
        {
            #[inline]
            fn decode(buffer: $crate::DecoderBuffer<'a>) -> $crate::DecoderBufferResult<'a, Self> {
                let (value, buffer) = <&'a $name as $crate::DecoderValue>::decode(buffer)?;
                Ok((*value, buffer))
            }
        }

        impl<'a> $crate::DecoderValue<'a> for &'a $name
        where
            $name: $crate::zerocopy::FromBytes + $crate::zerocopy::Unaligned,
        {
            #[inline]
            fn decode(buffer: $crate::DecoderBuffer<'a>) -> $crate::DecoderBufferResult<'a, Self> {
                let (value, buffer) = buffer.decode_slice(::core::mem::size_of::<$name>())?;
                let value = value.into_less_safe_slice();
                // SAFETY: `value` holds `size_of::<$name>()` initialized bytes
                // borrowed for `'a`. `FromBytes` makes every bit pattern a valid
                // `$name`, and `Unaligned` guarantees `$name` has an alignment of
                // 1, so the (arbitrarily aligned) buffer pointer is a valid
                // `*const $name`.
                let value = unsafe { &*(value as *const _ as *const $name) };
                Ok((value, buffer.into()))
            }
        }

        impl<'a> $crate::DecoderValueMut<'a> for $name
        where
            $name: $crate::zerocopy::FromBytes
                + $crate::zerocopy::IntoBytes
                + $crate::zerocopy::Unaligned,
        {
            #[inline]
            fn decode_mut(
                buffer: $crate::DecoderBufferMut<'a>,
            ) -> $crate::DecoderBufferMutResult<'a, Self> {
                let (value, buffer) = <&'a $name as $crate::DecoderValueMut>::decode_mut(buffer)?;
                Ok((*value, buffer))
            }
        }

        impl<'a> $crate::DecoderValueMut<'a> for &'a $name
        where
            $name: $crate::zerocopy::FromBytes
                + $crate::zerocopy::IntoBytes
                + $crate::zerocopy::Unaligned,
        {
            #[inline]
            fn decode_mut(
                buffer: $crate::DecoderBufferMut<'a>,
            ) -> $crate::DecoderBufferMutResult<'a, Self> {
                // Decode mutably, then downgrade the exclusive borrow to a shared one.
                let (value, buffer) =
                    <&'a mut $name as $crate::DecoderValueMut>::decode_mut(buffer)?;
                Ok((value, buffer))
            }
        }

        impl<'a> $crate::DecoderValueMut<'a> for &'a mut $name
        where
            $name: $crate::zerocopy::FromBytes
                + $crate::zerocopy::IntoBytes
                + $crate::zerocopy::Unaligned,
        {
            #[inline]
            fn decode_mut(
                buffer: $crate::DecoderBufferMut<'a>,
            ) -> $crate::DecoderBufferMutResult<'a, Self> {
                let (value, buffer) = buffer.decode_slice(::core::mem::size_of::<$name>())?;
                let value = value.into_less_safe_slice();
                // SAFETY: `value` is an exclusive borrow of `size_of::<$name>()`
                // initialized bytes for `'a`. `FromBytes` makes every bit pattern
                // a valid `$name`; `IntoBytes` means a `$name` has no padding, so
                // writes through the reference keep every byte initialized; and
                // `Unaligned` guarantees an alignment of 1.
                let value = unsafe { &mut *(value as *mut _ as *mut $name) };

                Ok((value, buffer.into()))
            }
        }

        impl $crate::EncoderValue for $name
        where
            $name: $crate::zerocopy::IntoBytes,
        {
            #[inline]
            fn encoding_size(&self) -> usize {
                ::core::mem::size_of::<$name>()
            }

            #[inline]
            fn encoding_size_for_encoder<E: $crate::Encoder>(&self, _encoder: &E) -> usize {
                ::core::mem::size_of::<$name>()
            }

            #[inline]
            fn encode<E: $crate::Encoder>(&self, encoder: &mut E) {
                // SAFETY: `IntoBytes` guarantees `$name` has no padding, so all
                // `size_of::<$name>()` bytes behind `self` are initialized, and a
                // `u8` slice has no alignment requirement beyond 1.
                let bytes = unsafe {
                    ::core::slice::from_raw_parts(
                        self as *const $name as *const u8,
                        ::core::mem::size_of::<$name>(),
                    )
                };
                encoder.write_slice(bytes);
            }
        }

        impl<'a> $crate::EncoderValue for &'a $name
        where
            $name: $crate::zerocopy::IntoBytes,
        {
            #[inline]
            fn encoding_size(&self) -> usize {
                ::core::mem::size_of::<$name>()
            }

            #[inline]
            fn encoding_size_for_encoder<E: $crate::Encoder>(&self, _encoder: &E) -> usize {
                ::core::mem::size_of::<$name>()
            }

            #[inline]
            fn encode<E: $crate::Encoder>(&self, encoder: &mut E) {
                // Reuse the by-value impl rather than duplicating its unsafe block.
                <$name as $crate::EncoderValue>::encode(&**self, encoder)
            }
        }

        impl<'a> $crate::EncoderValue for &'a mut $name
        where
            $name: $crate::zerocopy::IntoBytes,
        {
            #[inline]
            fn encoding_size(&self) -> usize {
                ::core::mem::size_of::<$name>()
            }

            #[inline]
            fn encoding_size_for_encoder<E: $crate::Encoder>(&self, _encoder: &E) -> usize {
                ::core::mem::size_of::<$name>()
            }

            #[inline]
            fn encode<E: $crate::Encoder>(&self, encoder: &mut E) {
                // Reuse the by-value impl rather than duplicating its unsafe block.
                <$name as $crate::EncoderValue>::encode(&**self, encoder)
            }
        }
    };
}
170
// Newtype wrappers around the `zerocopy` crate's network-endian integers.
//
// `zerocopy::byteorder` integers keep their value in a fixed byte order, so
// they can be referenced directly inside a byte buffer regardless of the
// host's endianness. For each `($native, $name)` pair this macro defines a
// `#[repr(C)]` newtype `$name` around `zerocopy::byteorder::$name<NetworkEndian>`
// and gives it host-order accessors, conversions to and from the native
// integer, comparisons against the native integer, formatting, generator
// impls, and the `zerocopy_value_codec!` codec impls.
macro_rules! zerocopy_network_integer {
    ($native:ident, $name:ident) => {
        #[derive(
            Clone,
            Copy,
            Default,
            PartialEq,
            PartialOrd,
            Ord,
            Hash,
            Eq,
            Immutable,
            $crate::zerocopy::FromBytes,
            $crate::zerocopy::IntoBytes,
            $crate::zerocopy::Unaligned,
        )]
        #[repr(C)]
        pub struct $name(::zerocopy::byteorder::$name<NetworkEndian>);

        impl $name {
            /// The value zero.
            pub const ZERO: Self = Self(::zerocopy::byteorder::$name::ZERO);

            /// Wraps a host-order integer, storing it in network byte order.
            #[inline(always)]
            pub const fn new(value: $native) -> Self {
                Self(::zerocopy::byteorder::$name::<NetworkEndian>::new(value))
            }

            /// Returns the stored value in host byte order.
            #[inline(always)]
            pub fn get(&self) -> $native {
                // `from_be` performs the same swap as `to_be` (none on
                // big-endian hosts), turning the raw network-order bit
                // pattern returned by `get_be` into a host-order value.
                $native::from_be(self.get_be())
            }

            /// Returns the stored bytes reinterpreted natively, without any
            /// byte swapping (i.e. the raw network-order bit pattern).
            #[inline(always)]
            pub fn get_be(&self) -> $native {
                let bytes = self.0.as_bytes();
                // SAFETY: the wrapped zerocopy integer is exactly
                // `size_of::<$native>()` bytes long, and byte arrays have an
                // alignment of 1, so reading the array out of it is valid.
                let array = unsafe {
                    *(bytes.as_ptr() as *const [u8; ::core::mem::size_of::<$native>()])
                };
                $native::from_ne_bytes(array)
            }

            /// Stores a host-order value, converting it to network byte order.
            #[inline(always)]
            pub fn set(&mut self, value: $native) {
                let encoded = value.to_be_bytes();
                self.0.as_mut_bytes().copy_from_slice(&encoded);
            }

            /// Stores the value's native bytes unchanged; the caller supplies a
            /// value whose in-memory representation is already network order.
            #[inline(always)]
            pub fn set_be(&mut self, value: $native) {
                let raw = value.to_ne_bytes();
                self.0.as_mut_bytes().copy_from_slice(&raw);
            }
        }

        impl PartialEq<$native> for $name {
            #[inline]
            fn eq(&self, other: &$native) -> bool {
                self.get() == *other
            }
        }

        impl PartialOrd<$native> for $name {
            #[inline]
            fn partial_cmp(&self, other: &$native) -> Option<Ordering> {
                let ordering: Ordering = self.get().cmp(other);
                Some(ordering)
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                write!(formatter, "{}", self.get())
            }
        }

        impl fmt::Debug for $name {
            // Debug output is the same as Display: the host-order value.
            fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                fmt::Display::fmt(self, formatter)
            }
        }

        impl From<$native> for $name {
            #[inline]
            fn from(value: $native) -> Self {
                Self::new(value)
            }
        }

        impl From<$name> for $native {
            #[inline]
            fn from(wrapped: $name) -> $native {
                wrapped.get()
            }
        }

        #[cfg(feature = "generator")]
        impl TypeGenerator for $name {
            fn generate<D: bolero_generator::Driver>(driver: &mut D) -> Option<Self> {
                let value: $native = driver.produce()?;
                Some(Self::new(value))
            }
        }

        #[cfg(kani)]
        impl kani::Arbitrary for $name {
            fn any() -> Self {
                let value: $native = kani::any();
                Self::new(value)
            }
        }

        zerocopy_value_codec!($name);
    };
}
285
286zerocopy_network_integer!(i16, I16);
287zerocopy_network_integer!(u16, U16);
288zerocopy_network_integer!(i32, I32);
289zerocopy_network_integer!(u32, U32);
290zerocopy_network_integer!(i64, I64);
291zerocopy_network_integer!(u64, U64);
292zerocopy_network_integer!(i128, I128);
293zerocopy_network_integer!(u128, U128);
294
#[test]
fn zerocopy_struct_test() {
    use crate::DecoderBuffer;

    /// A UDP header laid out exactly as it appears on the wire.
    #[derive(Copy, Clone, Debug, PartialEq, PartialOrd, FromBytes, IntoBytes, Unaligned)]
    #[repr(C)]
    struct UdpHeader {
        source_port: U16,
        destination_port: U16,
        payload_len: U16,
        checksum: U16,
    }

    zerocopy_value_codec!(UdpHeader);

    // Four big-endian `u16` fields holding 1, 2, 3 and 4.
    let wire = [0u8, 1, 0, 2, 0, 3, 0, 4];
    let (mut header, _): (UdpHeader, _) = DecoderBuffer::new(&wire).decode().unwrap();

    ensure_codec_round_trip_value!(UdpHeader, header).unwrap();
    ensure_codec_round_trip_value!(&UdpHeader, &header).unwrap();
    ensure_codec_round_trip_value_mut!(&mut UdpHeader, &mut header).unwrap();

    let fields = [
        (header.source_port, 1u16),
        (header.destination_port, 2u16),
        (header.payload_len, 3u16),
        (header.checksum, 4u16),
    ];
    for (actual, expected) in fields {
        assert_eq!(actual, expected);
    }
}