s2n-codec 0.77.0

Internal crate used by s2n-quic
Documentation
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

use core::{cmp::Ordering, fmt, hash::Hash};
pub use zerocopy::*;

#[cfg(feature = "generator")]
use bolero_generator::prelude::*;

/// Define a codec implementation for a zerocopy value that implements
/// `FromBytes`, `IntoBytes`, and `Unaligned`.
///
/// The generated decoders reinterpret bytes of a `DecoderBuffer` /
/// `DecoderBufferMut` in place as `&$name` / `&mut $name`, and the generated
/// encoders write the raw bytes of the value. The trait bounds on each impl are
/// what make those reinterpretations sound:
///
/// * `FromBytes` — every byte pattern is a valid `$name`.
/// * `Unaligned` — `$name` has an alignment of 1, so a reference created at any
///   offset of the buffer is properly aligned.
/// * `IntoBytes` — `$name` has no padding, so its bytes may be viewed as `[u8]`
///   (encoding), and values written through a decoded `&mut $name` keep the
///   underlying byte buffer fully initialized.
///
/// The by-value decoders copy out of the decoded reference, so `$name` must
/// also be `Copy`.
///
/// Because `$name` is a concrete type, a missing trait is rejected at compile
/// time at the macro call site.
#[macro_export]
macro_rules! zerocopy_value_codec {
    ($name:ident) => {
        impl<'a> $crate::DecoderValue<'a> for $name
        where
            $name: $crate::zerocopy::FromBytes + $crate::zerocopy::Unaligned,
        {
            #[inline]
            fn decode(buffer: $crate::DecoderBuffer<'a>) -> $crate::DecoderBufferResult<'a, Self> {
                // Decode a reference in place, then copy the value out.
                let (value, buffer) = <&'a $name as $crate::DecoderValue>::decode(buffer)?;
                Ok((*value, buffer))
            }
        }

        impl<'a> $crate::DecoderValue<'a> for &'a $name
        where
            $name: $crate::zerocopy::FromBytes + $crate::zerocopy::Unaligned,
        {
            #[inline]
            fn decode(buffer: $crate::DecoderBuffer<'a>) -> $crate::DecoderBufferResult<'a, Self> {
                let (value, buffer) = buffer.decode_slice(::core::mem::size_of::<$name>())?;
                let value = value.into_less_safe_slice();
                let value = unsafe {
                    // SAFETY: the slice was split off with a length of
                    // `size_of::<$name>()`; `$name: Unaligned` means any address
                    // is suitably aligned; `$name: FromBytes` means every byte
                    // pattern is a valid value.
                    &*(value as *const _ as *const $name)
                };
                Ok((value, buffer.into()))
            }
        }

        impl<'a> $crate::DecoderValueMut<'a> for $name
        where
            $name: $crate::zerocopy::FromBytes
                + $crate::zerocopy::IntoBytes
                + $crate::zerocopy::Unaligned,
        {
            #[inline]
            fn decode_mut(
                buffer: $crate::DecoderBufferMut<'a>,
            ) -> $crate::DecoderBufferMutResult<'a, Self> {
                // Decode a reference in place, then copy the value out.
                let (value, buffer) = <&'a $name as $crate::DecoderValueMut>::decode_mut(buffer)?;
                Ok((*value, buffer))
            }
        }

        impl<'a> $crate::DecoderValueMut<'a> for &'a $name
        where
            $name: $crate::zerocopy::FromBytes
                + $crate::zerocopy::IntoBytes
                + $crate::zerocopy::Unaligned,
        {
            #[inline]
            fn decode_mut(
                buffer: $crate::DecoderBufferMut<'a>,
            ) -> $crate::DecoderBufferMutResult<'a, Self> {
                // Reuse the `&mut` decoder and downgrade the result to a shared reference.
                let (value, buffer) =
                    <&'a mut $name as $crate::DecoderValueMut>::decode_mut(buffer)?;
                Ok((value, buffer))
            }
        }

        impl<'a> $crate::DecoderValueMut<'a> for &'a mut $name
        where
            $name: $crate::zerocopy::FromBytes
                + $crate::zerocopy::IntoBytes
                + $crate::zerocopy::Unaligned,
        {
            #[inline]
            fn decode_mut(
                buffer: $crate::DecoderBufferMut<'a>,
            ) -> $crate::DecoderBufferMutResult<'a, Self> {
                let (value, buffer) = buffer.decode_slice(::core::mem::size_of::<$name>())?;
                let value = value.into_less_safe_slice();
                let value = unsafe {
                    // SAFETY: the slice was split off with a length of
                    // `size_of::<$name>()` and is uniquely borrowed for `'a`;
                    // `Unaligned` guarantees alignment, `FromBytes` guarantees
                    // any existing bytes form a valid value, and `IntoBytes`
                    // guarantees values written back leave no uninitialized
                    // (padding) bytes in the `[u8]` buffer.
                    &mut *(value as *mut _ as *mut $name)
                };

                Ok((value, buffer.into()))
            }
        }

        impl $crate::EncoderValue for $name
        where
            $name: $crate::zerocopy::IntoBytes,
        {
            #[inline]
            fn encoding_size(&self) -> usize {
                ::core::mem::size_of::<$name>()
            }

            #[inline]
            fn encoding_size_for_encoder<E: $crate::Encoder>(&self, _encoder: &E) -> usize {
                ::core::mem::size_of::<$name>()
            }

            #[inline]
            fn encode<E: $crate::Encoder>(&self, encoder: &mut E) {
                let bytes = unsafe {
                    // SAFETY: `$name: IntoBytes` means all `size_of::<$name>()`
                    // bytes of `self` are initialized and may be read as `u8`.
                    ::core::slice::from_raw_parts(
                        self as *const $name as *const u8,
                        ::core::mem::size_of::<$name>(),
                    )
                };
                encoder.write_slice(bytes);
            }
        }

        impl<'a> $crate::EncoderValue for &'a $name
        where
            $name: $crate::zerocopy::IntoBytes,
        {
            #[inline]
            fn encoding_size(&self) -> usize {
                ::core::mem::size_of::<$name>()
            }

            #[inline]
            fn encoding_size_for_encoder<E: $crate::Encoder>(&self, _encoder: &E) -> usize {
                ::core::mem::size_of::<$name>()
            }

            #[inline]
            fn encode<E: $crate::Encoder>(&self, encoder: &mut E) {
                let bytes = unsafe {
                    // SAFETY: `$name: IntoBytes` means all `size_of::<$name>()`
                    // bytes of the referent are initialized and may be read as `u8`.
                    ::core::slice::from_raw_parts(
                        *self as *const $name as *const u8,
                        ::core::mem::size_of::<$name>(),
                    )
                };
                encoder.write_slice(bytes);
            }
        }

        impl<'a> $crate::EncoderValue for &'a mut $name
        where
            $name: $crate::zerocopy::IntoBytes,
        {
            #[inline]
            fn encoding_size(&self) -> usize {
                ::core::mem::size_of::<$name>()
            }

            #[inline]
            fn encoding_size_for_encoder<E: $crate::Encoder>(&self, _encoder: &E) -> usize {
                ::core::mem::size_of::<$name>()
            }

            #[inline]
            fn encode<E: $crate::Encoder>(&self, encoder: &mut E) {
                let bytes = unsafe {
                    // SAFETY: `$name: IntoBytes` means all `size_of::<$name>()`
                    // bytes of the referent are initialized and may be read as `u8`.
                    ::core::slice::from_raw_parts(
                        *self as *const $name as *const u8,
                        ::core::mem::size_of::<$name>(),
                    )
                };
                encoder.write_slice(bytes);
            }
        }
    };
}

// The `zerocopy` crate provides byte-order-aware integer types that can be
// referenced directly inside a byte buffer, independent of the host's
// endianness. This macro wraps the network-endian (big-endian) flavor of one of
// those types in a newtype and implements conversion, comparison, formatting,
// generator, and codec traits for it.
macro_rules! zerocopy_network_integer {
    ($native:ident, $name:ident) => {
        #[doc = concat!(
            "A `",
            stringify!($native),
            "` stored in network (big-endian) byte order."
        )]
        #[derive(
            Clone,
            Copy,
            Default,
            PartialEq,
            PartialOrd,
            Ord,
            Hash,
            Eq,
            Immutable,
            $crate::zerocopy::FromBytes,
            $crate::zerocopy::IntoBytes,
            $crate::zerocopy::Unaligned,
        )]
        #[repr(C)]
        pub struct $name(::zerocopy::byteorder::$name<NetworkEndian>);

        impl $name {
            /// The value `0`.
            pub const ZERO: Self = Self(::zerocopy::byteorder::$name::ZERO);

            /// Creates a value from a native-endian integer.
            #[inline(always)]
            pub const fn new(value: $native) -> Self {
                let zerocopy_int = ::zerocopy::byteorder::$name::<NetworkEndian>::new(value);
                $name(zerocopy_int)
            }

            /// Returns the value as a native-endian integer.
            #[inline(always)]
            pub fn get(&self) -> $native {
                // `get_be` yields the raw big-endian representation; convert it
                // to the host's byte order.
                $native::from_be(self.get_be())
            }

            /// Returns the stored big-endian bytes reinterpreted as a native
            /// integer, without any byte swapping.
            #[inline(always)]
            pub fn get_be(&self) -> $native {
                // Safe replacement for a raw-pointer read: both sides are exactly
                // `size_of::<$native>()` bytes, so the copy never panics.
                let mut bytes = [0u8; ::core::mem::size_of::<$native>()];
                bytes.copy_from_slice(self.0.as_bytes());
                $native::from_ne_bytes(bytes)
            }

            /// Stores a native-endian integer, converting it to big-endian.
            #[inline(always)]
            pub fn set(&mut self, value: $native) {
                self.0.as_mut_bytes().copy_from_slice(&value.to_be_bytes());
            }

            /// Stores an integer that is already in big-endian representation,
            /// without any byte swapping.
            #[inline(always)]
            pub fn set_be(&mut self, value: $native) {
                self.0.as_mut_bytes().copy_from_slice(&value.to_ne_bytes());
            }
        }

        impl PartialEq<$native> for $name {
            #[inline]
            fn eq(&self, other: &$native) -> bool {
                self.get() == *other
            }
        }

        impl PartialOrd<$native> for $name {
            #[inline]
            fn partial_cmp(&self, other: &$native) -> Option<Ordering> {
                Some(self.get().cmp(other))
            }
        }

        impl fmt::Debug for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                // Delegate so formatter options (width, fill, `{:x?}`, ...) are honored.
                fmt::Debug::fmt(&self.get(), formatter)
            }
        }

        impl fmt::Display for $name {
            fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
                // Delegate so formatter options (width, fill, alignment, ...) are honored.
                fmt::Display::fmt(&self.get(), formatter)
            }
        }

        impl From<$native> for $name {
            #[inline]
            fn from(value: $native) -> Self {
                Self::new(value)
            }
        }

        impl From<$name> for $native {
            #[inline]
            fn from(v: $name) -> $native {
                v.get()
            }
        }

        #[cfg(feature = "generator")]
        impl TypeGenerator for $name {
            fn generate<D: bolero_generator::Driver>(driver: &mut D) -> Option<Self> {
                Some(Self::new(driver.produce()?))
            }
        }

        #[cfg(kani)]
        impl kani::Arbitrary for $name {
            fn any() -> Self {
                Self::new(kani::any())
            }
        }

        zerocopy_value_codec!($name);
    };
}

// Unsigned network-endian integer wrappers.
zerocopy_network_integer!(u16, U16);
zerocopy_network_integer!(u32, U32);
zerocopy_network_integer!(u64, U64);
zerocopy_network_integer!(u128, U128);

// Signed network-endian integer wrappers.
zerocopy_network_integer!(i16, I16);
zerocopy_network_integer!(i32, I32);
zerocopy_network_integer!(i64, I64);
zerocopy_network_integer!(i128, I128);

/// Decodes a UDP-style header through the `zerocopy_value_codec!` impls,
/// round-trips it through the by-value, shared-reference and mutable-reference
/// codecs, and checks that every field decoded from big-endian bytes.
#[test]
fn zerocopy_struct_test() {
    use crate::DecoderBuffer;

    #[derive(Copy, Clone, Debug, PartialEq, PartialOrd, FromBytes, IntoBytes, Unaligned)]
    #[repr(C)]
    struct UdpHeader {
        source_port: U16,
        destination_port: U16,
        payload_len: U16,
        checksum: U16,
    }

    zerocopy_value_codec!(UdpHeader);

    // Four consecutive big-endian u16 fields holding 1, 2, 3 and 4.
    let wire: [u8; 8] = [0, 1, 0, 2, 0, 3, 0, 4];
    let (mut header, _rest): (UdpHeader, _) = DecoderBuffer::new(&wire).decode().unwrap();

    ensure_codec_round_trip_value!(UdpHeader, header).unwrap();
    ensure_codec_round_trip_value!(&UdpHeader, &header).unwrap();
    ensure_codec_round_trip_value_mut!(&mut UdpHeader, &mut header).unwrap();

    let expectations = [
        (header.source_port, 1u16),
        (header.destination_port, 2u16),
        (header.payload_len, 3u16),
        (header.checksum, 4u16),
    ];
    for (actual, expected) in expectations {
        assert_eq!(actual, expected);
    }
}