bytecheck/
lib.rs

1//! # bytecheck
2//!
3//! bytecheck is a memory validation framework for Rust.
4//!
5//! For some types, creating an invalid value immediately results in undefined
6//! behavior. This can cause some issues when trying to validate potentially
7//! invalid bytes, as just casting the bytes to your type can technically cause
8//! errors. This makes it difficult to write validation routines, because until
9//! you're certain that the bytes represent valid values you cannot cast them.
10//!
11//! bytecheck provides a framework for performing these byte-level validations
12//! and implements checks for basic types along with a derive macro to implement
13//! validation for custom structs and enums.
14//!
15//! ## Design
16//!
17//! [`CheckBytes`] is at the heart of bytecheck, and does the heavy lifting of
18//! verifying that some bytes represent a valid type. Implementing it can be
19//! done manually or automatically with the [derive macro](macro@CheckBytes).
20//!
21//! ## Layout stability
22//!
23//! The layouts of types may change between compiler versions, or even different
24//! compilations. To guarantee stable type layout between compilations, structs,
25//! enums, and unions can be annotated with `#[repr(C)]`, and enums specifically
26//! can be annotated with `#[repr(int)]` or `#[repr(C, int)]` as well. See
27//! [the reference's page on type layout][reference] for more details.
28//!
29//! [reference]: https://doc.rust-lang.org/reference/type-layout.html
30//!
31//! ## Features
32//!
33//! - `derive`: Re-exports the macros from `bytecheck_derive`. Enabled by
34//!   default.
35//! - `simdutf8`: Uses the `simdutf8` crate to validate UTF-8 strings. Enabled
36//!   by default.
37//!
38//! ### Crates
39//!
40//! Bytecheck provides integrations for some common crates by default. In the
41//! future, crates should depend on bytecheck and provide their own integration.
42//!
43//! - [`uuid-1`](https://docs.rs/uuid/1)
44//!
45//! ## Example
46#![doc = include_str!("../example.md")]
47#![deny(
48    future_incompatible,
49    missing_docs,
50    nonstandard_style,
51    unsafe_op_in_unsafe_fn,
52    unused,
53    warnings,
54    clippy::all,
55    clippy::missing_safety_doc,
56    clippy::undocumented_unsafe_blocks,
57    rustdoc::broken_intra_doc_links,
58    rustdoc::missing_crate_level_docs
59)]
60#![no_std]
61#![cfg_attr(all(docsrs, not(doctest)), feature(doc_cfg, doc_auto_cfg))]
62
63// Support for various common crates. These are primarily to get users off the
64// ground and build some momentum.
65
66// These are NOT PLANNED to remain in bytecheck for the final release. Much like
67// serde, these implementations should be moved into their respective crates
68// over time. Before adding support for another crate, please consider getting
69// bytecheck support in the crate instead.
70
71#[cfg(feature = "uuid-1")]
72mod uuid;
73
74#[cfg(not(feature = "simdutf8"))]
75use core::str::from_utf8;
76#[cfg(target_has_atomic = "8")]
77use core::sync::atomic::{AtomicBool, AtomicI8, AtomicU8};
78#[cfg(target_has_atomic = "16")]
79use core::sync::atomic::{AtomicI16, AtomicU16};
80#[cfg(target_has_atomic = "32")]
81use core::sync::atomic::{AtomicI32, AtomicU32};
82#[cfg(target_has_atomic = "64")]
83use core::sync::atomic::{AtomicI64, AtomicU64};
84use core::{
85    cell::{Cell, UnsafeCell},
86    error::Error,
87    ffi::CStr,
88    fmt,
89    marker::{PhantomData, PhantomPinned},
90    mem::ManuallyDrop,
91    num::{
92        NonZeroI128, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI8,
93        NonZeroU128, NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8,
94    },
95    ops, ptr,
96};
97
98pub use bytecheck_derive::CheckBytes;
99pub use rancor;
100use rancor::{fail, Fallible, ResultExt as _, Source, Strategy, Trace};
101#[cfg(feature = "simdutf8")]
102use simdutf8::basic::from_utf8;
103
104/// A type that can check whether a pointer points to a valid value.
105///
106/// `CheckBytes` can be derived with [`CheckBytes`](macro@CheckBytes) or
107/// implemented manually for custom behavior.
108///
109/// # Safety
110///
111/// `check_bytes` must only return `Ok` if `value` points to a valid instance of
112/// `Self`. Because `value` must always be properly aligned for `Self` and point
113/// to enough bytes to represent the type, this implies that `value` may be
114/// dereferenced safely.
115///
116/// # Example
117///
118/// ```
119/// use core::{error::Error, fmt};
120///
121/// use bytecheck::CheckBytes;
122/// use rancor::{fail, Fallible, Source};
123///
124/// #[repr(C, align(4))]
125/// pub struct NonMaxU32(u32);
126///
127/// unsafe impl<C: Fallible + ?Sized> CheckBytes<C> for NonMaxU32
128/// where
129///     C::Error: Source,
130/// {
131///     unsafe fn check_bytes(
132///         value: *const Self,
133///         context: &mut C,
134///     ) -> Result<(), C::Error> {
135///         #[derive(Debug)]
136///         struct NonMaxCheckError;
137///
138///         impl fmt::Display for NonMaxCheckError {
139///             fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140///                 write!(f, "non-max u32 was set to u32::MAX")
141///             }
142///         }
143///
144///         impl Error for NonMaxCheckError {}
145///
146///         let value = unsafe { value.read() };
147///         if value.0 == u32::MAX {
148///             fail!(NonMaxCheckError);
149///         }
150///
151///         Ok(())
152///     }
153/// }
154/// ```
155///
156/// See [`Verify`] for an example which uses less unsafe.
157pub unsafe trait CheckBytes<C: Fallible + ?Sized> {
158    /// Checks whether the given pointer points to a valid value within the
159    /// given context.
160    ///
161    /// # Safety
162    ///
163    /// The passed pointer must be aligned and point to enough initialized bytes
164    /// to represent the type.
165    unsafe fn check_bytes(
166        value: *const Self,
167        context: &mut C,
168    ) -> Result<(), C::Error>;
169}
170
171/// A type that can check whether its invariants are upheld.
172///
173/// When using [the derive](macro@CheckBytes), adding `#[bytecheck(verify)]`
174/// allows implementing `Verify` for the derived type. [`Verify::verify`] will
175/// be called after the type is checked and all fields are known to be valid.
176///
177/// # Safety
178///
179/// - `verify` must only return `Ok` if all of the invariants of this type are
180///   upheld by `self`.
181/// - `verify` may not assume that its type invariants are upheld by the given
182///   `self` (the invariants of each field are guaranteed to be upheld).
183///
184/// # Example
185///
186/// ```
187/// use core::{error::Error, fmt};
188///
189/// use bytecheck::{CheckBytes, Verify};
190/// use rancor::{fail, Fallible, Source};
191///
192/// #[derive(CheckBytes)]
193/// #[bytecheck(verify)]
194/// #[repr(C, align(4))]
195/// pub struct NonMaxU32(u32);
196///
197/// unsafe impl<C: Fallible + ?Sized> Verify<C> for NonMaxU32
198/// where
199///     C::Error: Source,
200/// {
201///     fn verify(&self, context: &mut C) -> Result<(), C::Error> {
202///         #[derive(Debug)]
203///         struct NonMaxCheckError;
204///
205///         impl fmt::Display for NonMaxCheckError {
206///             fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
207///                 write!(f, "non-max u32 was set to u32::MAX")
208///             }
209///         }
210///
211///         impl Error for NonMaxCheckError {}
212///
213///         if self.0 == u32::MAX {
214///             fail!(NonMaxCheckError);
215///         }
216///
217///         Ok(())
218///     }
219/// }
220/// ```
221pub unsafe trait Verify<C: Fallible + ?Sized> {
222    /// Checks whether the invariants of this type are upheld by `self`.
223    fn verify(&self, context: &mut C) -> Result<(), C::Error>;
224}
225
226/// Checks whether the given pointer points to a valid value.
227///
228/// # Safety
229///
230/// The passed pointer must be aligned and point to enough initialized bytes to
231/// represent the type.
232///
233/// # Example
234///
235/// ```
236/// use bytecheck::check_bytes;
237/// use rancor::Failure;
238///
239/// unsafe {
240///     // 0 and 1 are valid values for bools
241///     check_bytes::<bool, Failure>((&0u8 as *const u8).cast()).unwrap();
242///     check_bytes::<bool, Failure>((&1u8 as *const u8).cast()).unwrap();
243///
244///     // 2 is not a valid value
245///     check_bytes::<bool, Failure>((&2u8 as *const u8).cast()).unwrap_err();
246/// }
247/// ```
248#[inline]
249pub unsafe fn check_bytes<T, E>(value: *const T) -> Result<(), E>
250where
251    T: CheckBytes<Strategy<(), E>> + ?Sized,
252{
253    // SAFETY: The safety conditions of `check_bytes_with_context` are the same
254    // as the safety conditions of this function.
255    unsafe { check_bytes_with_context(value, &mut ()) }
256}
257
258/// Checks whether the given pointer points to a valid value within the given
259/// context.
260///
261/// # Safety
262///
263/// The passed pointer must be aligned and point to enough initialized bytes to
264/// represent the type.
265///
266/// # Example
267///
268/// ```
269/// use core::{error::Error, fmt};
270///
271/// use bytecheck::{check_bytes_with_context, CheckBytes, Verify};
272/// use rancor::{fail, Failure, Fallible, Source, Strategy};
273///
274/// trait Context {
275///     fn is_allowed(&self, value: u8) -> bool;
276/// }
277///
278/// impl<T: Context + ?Sized, E> Context for Strategy<T, E> {
279///     fn is_allowed(&self, value: u8) -> bool {
280///         T::is_allowed(self, value)
281///     }
282/// }
283///
284/// struct Allowed(u8);
285///
286/// impl Context for Allowed {
287///     fn is_allowed(&self, value: u8) -> bool {
288///         value == self.0
289///     }
290/// }
291///
292/// #[derive(CheckBytes)]
293/// #[bytecheck(verify)]
294/// #[repr(C)]
295/// pub struct ContextualByte(u8);
296///
297/// unsafe impl<C: Context + Fallible + ?Sized> Verify<C> for ContextualByte
298/// where
299///     C::Error: Source,
300/// {
301///     fn verify(&self, context: &mut C) -> Result<(), C::Error> {
302///         #[derive(Debug)]
303///         struct InvalidByte(u8);
304///
305///         impl fmt::Display for InvalidByte {
306///             fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
307///                 write!(f, "invalid contextual byte: {}", self.0)
308///             }
309///         }
310///
311///         impl Error for InvalidByte {}
312///
313///         if !context.is_allowed(self.0) {
314///             fail!(InvalidByte(self.0));
315///         }
316///
317///         Ok(())
318///     }
319/// }
320///
321/// let value = 45u8;
322/// unsafe {
323///     // Checking passes when the context allows byte 45
324///     check_bytes_with_context::<ContextualByte, _, Failure>(
325///         (&value as *const u8).cast(),
326///         &mut Allowed(45),
327///     )
328///     .unwrap();
329///
330///     // Checking fails when the conteext does not allow byte 45
331///     check_bytes_with_context::<ContextualByte, _, Failure>(
332///         (&value as *const u8).cast(),
333///         &mut Allowed(0),
334///     )
335///     .unwrap_err();
336/// }
337/// ```
338pub unsafe fn check_bytes_with_context<T, C, E>(
339    value: *const T,
340    context: &mut C,
341) -> Result<(), E>
342where
343    T: CheckBytes<Strategy<C, E>> + ?Sized,
344{
345    // SAFETY: The safety conditions of `check_bytes` are the same as the safety
346    // conditions of this function.
347    unsafe { CheckBytes::check_bytes(value, Strategy::wrap(context)) }
348}
349
350macro_rules! impl_primitive {
351    ($type:ty) => {
352        // SAFETY: All bit patterns are valid for these primitive types.
353        unsafe impl<C: Fallible + ?Sized> CheckBytes<C> for $type {
354            #[inline]
355            unsafe fn check_bytes(
356                _: *const Self,
357                _: &mut C,
358            ) -> Result<(), C::Error> {
359                Ok(())
360            }
361        }
362    };
363}
364
365macro_rules! impl_primitives {
366    ($($type:ty),* $(,)?) => {
367        $(
368            impl_primitive!($type);
369        )*
370    }
371}
372
373impl_primitives! {
374    (),
375    i8, i16, i32, i64, i128,
376    u8, u16, u32, u64, u128,
377    f32, f64,
378}
379#[cfg(target_has_atomic = "8")]
380impl_primitives!(AtomicI8, AtomicU8);
381#[cfg(target_has_atomic = "16")]
382impl_primitives!(AtomicI16, AtomicU16);
383#[cfg(target_has_atomic = "32")]
384impl_primitives!(AtomicI32, AtomicU32);
385#[cfg(target_has_atomic = "64")]
386impl_primitives!(AtomicI64, AtomicU64);
387
388// SAFETY: `PhantomData` is a zero-sized type and so all bit patterns are valid.
389unsafe impl<T: ?Sized, C: Fallible + ?Sized> CheckBytes<C> for PhantomData<T> {
390    #[inline]
391    unsafe fn check_bytes(_: *const Self, _: &mut C) -> Result<(), C::Error> {
392        Ok(())
393    }
394}
395
396// SAFETY: `PhantomPinned` is a zero-sized type and so all bit patterns are
397// valid.
398unsafe impl<C: Fallible + ?Sized> CheckBytes<C> for PhantomPinned {
399    #[inline]
400    unsafe fn check_bytes(_: *const Self, _: &mut C) -> Result<(), C::Error> {
401        Ok(())
402    }
403}
404
405// SAFETY: `ManuallyDrop<T>` is a `#[repr(transparent)]` wrapper around a `T`,
406// and so `value` points to a valid `ManuallyDrop<T>` if it also points to a
407// valid `T`.
408unsafe impl<T, C> CheckBytes<C> for ManuallyDrop<T>
409where
410    T: CheckBytes<C> + ?Sized,
411    C: Fallible + ?Sized,
412    C::Error: Trace,
413{
414    #[inline]
415    unsafe fn check_bytes(
416        value: *const Self,
417        c: &mut C,
418    ) -> Result<(), C::Error> {
419        // SAFETY: Because `ManuallyDrop<T>` is `#[repr(transparent)]`, a
420        // pointer to a `ManuallyDrop<T>` is guaranteed to be the same as a
421        // pointer to `T`. We can't call `.cast()` here because `T` may be
422        // an unsized type.
423        let inner_ptr =
424            unsafe { core::mem::transmute::<*const Self, *const T>(value) };
425        // SAFETY: The caller has guaranteed that `value` is aligned for
426        // `ManuallyDrop<T>` and points to enough bytes to represent
427        // `ManuallyDrop<T>`. Since `ManuallyDrop<T>` is `#[repr(transparent)]`,
428        // `inner_ptr` is also aligned for `T` and points to enough bytes to
429        // represent it.
430        unsafe {
431            T::check_bytes(inner_ptr, c)
432                .trace("while checking inner value of `ManuallyDrop`")
433        }
434    }
435}
436
437// SAFETY: `UnsafeCell<T>` has the same memory layout as `T`, and so `value`
438// points to a valid `UnsafeCell<T>` if it also points to a valid `T`.
439unsafe impl<T, C> CheckBytes<C> for UnsafeCell<T>
440where
441    T: CheckBytes<C> + ?Sized,
442    C: Fallible + ?Sized,
443    C::Error: Trace,
444{
445    #[inline]
446    unsafe fn check_bytes(
447        value: *const Self,
448        c: &mut C,
449    ) -> Result<(), C::Error> {
450        // SAFETY: Because `UnsafeCell<T>` has the same memory layout as
451        // `T`, a pointer to an `UnsafeCell<T>` is guaranteed to be the same
452        // as a pointer to `T`. We can't call `.cast()` here because `T` may
453        // be an unsized type.
454        let inner_ptr =
455            unsafe { core::mem::transmute::<*const Self, *const T>(value) };
456        // SAFETY: The caller has guaranteed that `value` is aligned for
457        // `UnsafeCell<T>` and points to enough bytes to represent
458        // `UnsafeCell<T>`. Since `UnsafeCell<T>` has the same layout `T`,
459        // `inner_ptr` is also aligned for `T` and points to enough bytes to
460        // represent it.
461        unsafe {
462            T::check_bytes(inner_ptr, c)
463                .trace("while checking inner value of `UnsafeCell`")
464        }
465    }
466}
467
468// SAFETY: `Cell<T>` has the same memory layout as `UnsafeCell<T>` (and
469// therefore `T` itself), and so `value` points to a valid `UnsafeCell<T>` if it
470// also points to a valid `T`.
471unsafe impl<T, C> CheckBytes<C> for Cell<T>
472where
473    T: CheckBytes<C> + ?Sized,
474    C: Fallible + ?Sized,
475    C::Error: Trace,
476{
477    #[inline]
478    unsafe fn check_bytes(
479        value: *const Self,
480        c: &mut C,
481    ) -> Result<(), C::Error> {
482        // SAFETY: Because `Cell<T>` has the same memory layout as
483        // `UnsafeCell<T>` (and therefore `T` itself), a pointer to a
484        // `Cell<T>` is guaranteed to be the same as a pointer to `T`. We
485        // can't call `.cast()` here because `T` may be an unsized type.
486        let inner_ptr =
487            unsafe { core::mem::transmute::<*const Self, *const T>(value) };
488        // SAFETY: The caller has guaranteed that `value` is aligned for
489        // `Cell<T>` and points to enough bytes to represent `Cell<T>`. Since
490        // `Cell<T>` has the same layout as `UnsafeCell<T>` ( and therefore `T`
491        // itself), `inner_ptr` is also aligned for `T` and points to enough
492        // bytes to represent it.
493        unsafe {
494            T::check_bytes(inner_ptr, c)
495                .trace("while checking inner value of `Cell`")
496        }
497    }
498}
499
500#[derive(Debug)]
501struct BoolCheckError {
502    byte: u8,
503}
504
505impl fmt::Display for BoolCheckError {
506    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
507        write!(
508            f,
509            "bool set to invalid byte {}, expected either 0 or 1",
510            self.byte,
511        )
512    }
513}
514
515impl Error for BoolCheckError {}
516
517// SAFETY: A bool is a one byte value that must either be 0 or 1. `check_bytes`
518// only returns `Ok` if `value` is 0 or 1.
519unsafe impl<C> CheckBytes<C> for bool
520where
521    C: Fallible + ?Sized,
522    C::Error: Source,
523{
524    #[inline]
525    unsafe fn check_bytes(
526        value: *const Self,
527        _: &mut C,
528    ) -> Result<(), C::Error> {
529        // SAFETY: `value` is a pointer to a `bool`, which has a size and
530        // alignment of one. `u8` also has a size and alignment of one, and all
531        // bit patterns are valid for `u8`. So we can cast `value` to a `u8`
532        // pointer and read from it.
533        let byte = unsafe { *value.cast::<u8>() };
534        match byte {
535            0 | 1 => Ok(()),
536            _ => fail!(BoolCheckError { byte }),
537        }
538    }
539}
540
541#[cfg(target_has_atomic = "8")]
542// SAFETY: `AtomicBool` has the same ABI as `bool`, so if `value` points to a
543// valid `bool` then it also points to a valid `AtomicBool`.
544unsafe impl<C> CheckBytes<C> for AtomicBool
545where
546    C: Fallible + ?Sized,
547    C::Error: Source,
548{
549    #[inline]
550    unsafe fn check_bytes(
551        value: *const Self,
552        context: &mut C,
553    ) -> Result<(), C::Error> {
554        // SAFETY: `AtomicBool` has the same ABI as `bool`, so a pointer that is
555        // aligned for `AtomicBool` and points to enough bytes for `AtomicBool`
556        // is also aligned for `bool` and points to enough bytes for `bool`.
557        unsafe { bool::check_bytes(value.cast(), context) }
558    }
559}
560
561// SAFETY: If `char::try_from` succeeds with the pointed-to-value, then it must
562// be a valid value for `char`.
563unsafe impl<C> CheckBytes<C> for char
564where
565    C: Fallible + ?Sized,
566    C::Error: Source,
567{
568    #[inline]
569    unsafe fn check_bytes(ptr: *const Self, _: &mut C) -> Result<(), C::Error> {
570        // SAFETY: `char` and `u32` are both four bytes, but we're not
571        // guaranteed that they have the same alignment. Using `read_unaligned`
572        // ensures that we can read a `u32` regardless and try to convert it to
573        // a `char`.
574        let value = unsafe { ptr.cast::<u32>().read_unaligned() };
575        char::try_from(value).into_error()?;
576        Ok(())
577    }
578}
579
580#[derive(Debug)]
581struct TupleIndexContext {
582    index: usize,
583}
584
585impl fmt::Display for TupleIndexContext {
586    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
587        write!(f, "while checking index {} of tuple", self.index)
588    }
589}
590
591macro_rules! impl_tuple {
592    ($($type:ident $index:tt),*) => {
593        // SAFETY: A tuple is valid if all of its elements are valid, and
594        // `check_bytes` only returns `Ok` when all of the elements validated
595        // successfully.
596        unsafe impl<$($type,)* C> CheckBytes<C> for ($($type,)*)
597        where
598            $($type: CheckBytes<C>,)*
599            C: Fallible + ?Sized,
600            C::Error: Trace,
601        {
602            #[inline]
603            #[allow(clippy::unneeded_wildcard_pattern)]
604            unsafe fn check_bytes(
605                value: *const Self,
606                context: &mut C,
607            ) -> Result<(), C::Error> {
608                $(
609                    // SAFETY: The caller has guaranteed that `value` points to
610                    // enough bytes for this tuple and is properly aligned, so
611                    // we can create pointers to each element and check them.
612                    unsafe {
613                        <$type>::check_bytes(
614                            ptr::addr_of!((*value).$index),
615                            context,
616                        ).with_trace(|| TupleIndexContext { index: $index })?;
617                    }
618                )*
619                Ok(())
620            }
621        }
622    }
623}
624
625impl_tuple!(T0 0);
626impl_tuple!(T0 0, T1 1);
627impl_tuple!(T0 0, T1 1, T2 2);
628impl_tuple!(T0 0, T1 1, T2 2, T3 3);
629impl_tuple!(T0 0, T1 1, T2 2, T3 3, T4 4);
630impl_tuple!(T0 0, T1 1, T2 2, T3 3, T4 4, T5 5);
631impl_tuple!(T0 0, T1 1, T2 2, T3 3, T4 4, T5 5, T6 6);
632impl_tuple!(T0 0, T1 1, T2 2, T3 3, T4 4, T5 5, T6 6, T7 7);
633impl_tuple!(T0 0, T1 1, T2 2, T3 3, T4 4, T5 5, T6 6, T7 7, T8 8);
634impl_tuple!(T0 0, T1 1, T2 2, T3 3, T4 4, T5 5, T6 6, T7 7, T8 8, T9 9);
635impl_tuple!(T0 0, T1 1, T2 2, T3 3, T4 4, T5 5, T6 6, T7 7, T8 8, T9 9, T10 10);
636impl_tuple!(
637    T0 0, T1 1, T2 2, T3 3, T4 4, T5 5, T6 6, T7 7, T8 8, T9 9, T10 10, T11 11
638);
639impl_tuple!(
640    T0 0, T1 1, T2 2, T3 3, T4 4, T5 5, T6 6, T7 7, T8 8, T9 9, T10 10, T11 11,
641    T12 12
642);
643
644#[derive(Debug)]
645struct ArrayCheckContext {
646    index: usize,
647}
648
649impl fmt::Display for ArrayCheckContext {
650    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
651        write!(f, "while checking index '{}' of array", self.index)
652    }
653}
654
655// SAFETY: `check_bytes` only returns `Ok` if each element of the array is
656// valid. If each element of the array is valid then the whole array is also
657// valid.
658unsafe impl<T, const N: usize, C> CheckBytes<C> for [T; N]
659where
660    T: CheckBytes<C>,
661    C: Fallible + ?Sized,
662    C::Error: Trace,
663{
664    #[inline]
665    unsafe fn check_bytes(
666        value: *const Self,
667        context: &mut C,
668    ) -> Result<(), C::Error> {
669        let base = value.cast::<T>();
670        for index in 0..N {
671            // SAFETY: The caller has guaranteed that `value` points to enough
672            // bytes for this array and is properly aligned, so we can create
673            // pointers to each element and check them.
674            unsafe {
675                T::check_bytes(base.add(index), context)
676                    .with_trace(|| ArrayCheckContext { index })?;
677            }
678        }
679        Ok(())
680    }
681}
682
683#[derive(Debug)]
684struct SliceCheckContext {
685    index: usize,
686}
687
688impl fmt::Display for SliceCheckContext {
689    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
690        write!(f, "while checking index '{}' of slice", self.index)
691    }
692}
693
694// SAFETY: `check_bytes` only returns `Ok` if each element of the slice is
695// valid. If each element of the slice is valid then the whole slice is also
696// valid.
697unsafe impl<T, C> CheckBytes<C> for [T]
698where
699    T: CheckBytes<C>,
700    C: Fallible + ?Sized,
701    C::Error: Trace,
702{
703    #[inline]
704    unsafe fn check_bytes(
705        value: *const Self,
706        context: &mut C,
707    ) -> Result<(), C::Error> {
708        let (data_address, len) = ptr_meta::to_raw_parts(value);
709        let base = data_address.cast::<T>();
710        for index in 0..len {
711            // SAFETY: The caller has guaranteed that `value` points to enough
712            // bytes for this slice and is properly aligned, so we can create
713            // pointers to each element and check them.
714            unsafe {
715                T::check_bytes(base.add(index), context)
716                    .with_trace(|| SliceCheckContext { index })?;
717            }
718        }
719        Ok(())
720    }
721}
722
723// SAFETY: `check_bytes` only returns `Ok` if the bytes pointed to by `str` are
724// valid UTF-8. If they are valid UTF-8 then the overall `str` is also valid.
725unsafe impl<C> CheckBytes<C> for str
726where
727    C: Fallible + ?Sized,
728    C::Error: Source,
729{
730    #[inline]
731    unsafe fn check_bytes(
732        value: *const Self,
733        _: &mut C,
734    ) -> Result<(), C::Error> {
735        #[derive(Debug)]
736        struct Utf8Error;
737
738        impl fmt::Display for Utf8Error {
739            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
740                write!(f, "invalid UTF-8")
741            }
742        }
743
744        impl Error for Utf8Error {}
745
746        let slice_ptr = value as *const [u8];
747        // SAFETY: The caller has guaranteed that `value` is properly-aligned
748        // and points to enough bytes for its `str`. Because a `u8` slice has
749        // the same layout as a `str`, we can dereference it for UTF-8
750        // validation.
751        let slice = unsafe { &*slice_ptr };
752
753        // Checking whether a byte slice is ASCII is much faster than checking
754        // whether it is valid UTF-8. Falling back to a full UTF-8 check
755        // afterward nets a performance improvement for the average case.
756        if !slice.is_ascii() {
757            from_utf8(slice).map_err(|_| Utf8Error).into_error()?;
758        }
759
760        Ok(())
761    }
762}
763
764// SAFETY: `check_bytes` only returns `Ok` when the bytes constitute a valid
765// `CStr` per `CStr::from_bytes_with_nul`.
766unsafe impl<C> CheckBytes<C> for CStr
767where
768    C: Fallible + ?Sized,
769    C::Error: Source,
770{
771    #[inline]
772    unsafe fn check_bytes(
773        value: *const Self,
774        _: &mut C,
775    ) -> Result<(), C::Error> {
776        let slice_ptr = value as *const [u8];
777        // SAFETY: The caller has guaranteed that `value` is properly-aligned
778        // and points to enough bytes for its `CStr`. Because a `u8` slice has
779        // the same layout as a `CStr`, we can dereference it for validation.
780        let slice = unsafe { &*slice_ptr };
781        CStr::from_bytes_with_nul(slice).into_error()?;
782        Ok(())
783    }
784}
785
786// Generic contexts used by the derive.
787
788/// Context for errors resulting from invalid structs.
789///
790/// This context is used by the derive macro to trace which field of a struct
791/// failed validation.
792#[derive(Debug)]
793pub struct StructCheckContext {
794    /// The name of the struct with an invalid field.
795    pub struct_name: &'static str,
796    /// The name of the struct field that was invalid.
797    pub field_name: &'static str,
798}
799
800impl fmt::Display for StructCheckContext {
801    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
802        write!(
803            f,
804            "while checking field '{}' of struct '{}'",
805            self.field_name, self.struct_name
806        )
807    }
808}
809
810/// Context for errors resulting from invalid tuple structs.
811///
812/// This context is used by the derive macro to trace which field of a tuple
813/// struct failed validation.
814#[derive(Debug)]
815pub struct TupleStructCheckContext {
816    /// The name of the tuple struct with an invalid field.
817    pub tuple_struct_name: &'static str,
818    /// The index of the tuple struct field that was invalid.
819    pub field_index: usize,
820}
821
822impl fmt::Display for TupleStructCheckContext {
823    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
824        write!(
825            f,
826            "while checking field index {} of tuple struct '{}'",
827            self.field_index, self.tuple_struct_name,
828        )
829    }
830}
831
832/// An error resulting from an invalid enum tag.
833///
834/// This context is used by the derive macro to trace what the invalid
835/// discriminant for an enum is.
836#[derive(Debug)]
837pub struct InvalidEnumDiscriminantError<T> {
838    /// The name of the enum with an invalid discriminant.
839    pub enum_name: &'static str,
840    /// The invalid value of the enum discriminant.
841    pub invalid_discriminant: T,
842}
843
844impl<T: fmt::Display> fmt::Display for InvalidEnumDiscriminantError<T> {
845    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
846        write!(
847            f,
848            "invalid discriminant '{}' for enum '{}'",
849            self.invalid_discriminant, self.enum_name
850        )
851    }
852}
853
854impl<T> Error for InvalidEnumDiscriminantError<T> where
855    T: fmt::Debug + fmt::Display
856{
857}
858
859/// Context for errors resulting from checking enum variants with named fields.
860///
861/// This context is used by the derive macro to trace which field of an enum
862/// variant failed validation.
863#[derive(Debug)]
864pub struct NamedEnumVariantCheckContext {
865    /// The name of the enum with an invalid variant.
866    pub enum_name: &'static str,
867    /// The name of the variant that was invalid.
868    pub variant_name: &'static str,
869    /// The name of the field that was invalid.
870    pub field_name: &'static str,
871}
872
873impl fmt::Display for NamedEnumVariantCheckContext {
874    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
875        write!(
876            f,
877            "while checking field '{}' of variant '{}' of enum '{}'",
878            self.field_name, self.variant_name, self.enum_name,
879        )
880    }
881}
882
883/// Context for errors resulting from checking enum variants with unnamed
884/// fields.
885///
886/// This context is used by the derive macro to trace which field of an enum
887/// variant failed validation.
888#[derive(Debug)]
889pub struct UnnamedEnumVariantCheckContext {
890    /// The name of the enum with an invalid variant.
891    pub enum_name: &'static str,
892    /// The name of the variant that was invalid.
893    pub variant_name: &'static str,
894    /// The name of the field that was invalid.
895    pub field_index: usize,
896}
897
898impl fmt::Display for UnnamedEnumVariantCheckContext {
899    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
900        write!(
901            f,
902            "while checking field index {} of variant '{}' of enum '{}'",
903            self.field_index, self.variant_name, self.enum_name,
904        )
905    }
906}
907
908// Range types
909
910// SAFETY: A `Range<T>` is valid if its `start` and `end` are both valid, and
911// `check_bytes` only returns `Ok` when both `start` and `end` are valid. Note
912// that `Range` does not require `start` be less than `end`.
913unsafe impl<T, C> CheckBytes<C> for ops::Range<T>
914where
915    T: CheckBytes<C>,
916    C: Fallible + ?Sized,
917    C::Error: Trace,
918{
919    #[inline]
920    unsafe fn check_bytes(
921        value: *const Self,
922        context: &mut C,
923    ) -> Result<(), C::Error> {
924        // SAFETY: The caller has guaranteed that `value` is aligned for a
925        // `Range<T>` and points to enough initialized bytes for one, so a
926        // pointer projected to the `start` field will be properly aligned for
927        // a `T` and point to enough initialized bytes for one too.
928        unsafe {
929            T::check_bytes(ptr::addr_of!((*value).start), context).with_trace(
930                || StructCheckContext {
931                    struct_name: "Range",
932                    field_name: "start",
933                },
934            )?;
935        }
936        // SAFETY: Same reasoning as above, but for `end`.
937        unsafe {
938            T::check_bytes(ptr::addr_of!((*value).end), context).with_trace(
939                || StructCheckContext {
940                    struct_name: "Range",
941                    field_name: "end",
942                },
943            )?;
944        }
945        Ok(())
946    }
947}
948
949// SAFETY: A `RangeFrom<T>` is valid if its `start` is valid, and `check_bytes`
950// only returns `Ok` when its `start` is valid.
951unsafe impl<T, C> CheckBytes<C> for ops::RangeFrom<T>
952where
953    T: CheckBytes<C>,
954    C: Fallible + ?Sized,
955    C::Error: Trace,
956{
957    #[inline]
958    unsafe fn check_bytes(
959        value: *const Self,
960        context: &mut C,
961    ) -> Result<(), C::Error> {
962        // SAFETY: The caller has guaranteed that `value` is aligned for a
963        // `RangeFrom<T>` and points to enough initialized bytes for one, so a
964        // pointer projected to the `start` field will be properly aligned for
965        // a `T` and point to enough initialized bytes for one too.
966        unsafe {
967            T::check_bytes(ptr::addr_of!((*value).start), context).with_trace(
968                || StructCheckContext {
969                    struct_name: "RangeFrom",
970                    field_name: "start",
971                },
972            )?;
973        }
974        Ok(())
975    }
976}
977
978// SAFETY: `RangeFull` is a ZST and so every pointer to one is valid.
979unsafe impl<C: Fallible + ?Sized> CheckBytes<C> for ops::RangeFull {
980    #[inline]
981    unsafe fn check_bytes(_: *const Self, _: &mut C) -> Result<(), C::Error> {
982        Ok(())
983    }
984}
985
986// SAFETY: A `RangeTo<T>` is valid if its `end` is valid, and `check_bytes` only
987// returns `Ok` when its `end` is valid.
988unsafe impl<T, C> CheckBytes<C> for ops::RangeTo<T>
989where
990    T: CheckBytes<C>,
991    C: Fallible + ?Sized,
992    C::Error: Trace,
993{
994    #[inline]
995    unsafe fn check_bytes(
996        value: *const Self,
997        context: &mut C,
998    ) -> Result<(), C::Error> {
999        // SAFETY: The caller has guaranteed that `value` is aligned for a
1000        // `RangeTo<T>` and points to enough initialized bytes for one, so a
1001        // pointer projected to the `end` field will be properly aligned for
1002        // a `T` and point to enough initialized bytes for one too.
1003        unsafe {
1004            T::check_bytes(ptr::addr_of!((*value).end), context).with_trace(
1005                || StructCheckContext {
1006                    struct_name: "RangeTo",
1007                    field_name: "end",
1008                },
1009            )?;
1010        }
1011        Ok(())
1012    }
1013}
1014
1015// SAFETY: A `RangeToInclusive<T>` is valid if its `end` is valid, and
1016// `check_bytes` only returns `Ok` when its `end` is valid.
1017unsafe impl<T, C> CheckBytes<C> for ops::RangeToInclusive<T>
1018where
1019    T: CheckBytes<C>,
1020    C: Fallible + ?Sized,
1021    C::Error: Trace,
1022{
1023    #[inline]
1024    unsafe fn check_bytes(
1025        value: *const Self,
1026        context: &mut C,
1027    ) -> Result<(), C::Error> {
1028        // SAFETY: The caller has guaranteed that `value` is aligned for a
1029        // `RangeToInclusive<T>` and points to enough initialized bytes for one,
1030        // so a pointer projected to the `end` field will be properly aligned
1031        // for a `T` and point to enough initialized bytes for one too.
1032        unsafe {
1033            T::check_bytes(ptr::addr_of!((*value).end), context).with_trace(
1034                || StructCheckContext {
1035                    struct_name: "RangeToInclusive",
1036                    field_name: "end",
1037                },
1038            )?;
1039        }
1040        Ok(())
1041    }
1042}
1043
/// Error reported when a non-zero integer type is found to contain zero.
#[derive(Debug)]
struct NonZeroCheckError;

impl fmt::Display for NonZeroCheckError {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("nonzero integer is zero")
    }
}

impl Error for NonZeroCheckError {}
1055
1056macro_rules! impl_nonzero {
1057    ($nonzero:ident, $underlying:ident) => {
1058        // SAFETY: `check_bytes` only returns `Ok` when `value` is not zero, the
1059        // only validity condition for non-zero integer types.
1060        unsafe impl<C> CheckBytes<C> for $nonzero
1061        where
1062            C: Fallible + ?Sized,
1063            C::Error: Source,
1064        {
1065            #[inline]
1066            unsafe fn check_bytes(
1067                value: *const Self,
1068                _: &mut C,
1069            ) -> Result<(), C::Error> {
1070                // SAFETY: Non-zero integer types are guaranteed to have the
1071                // same ABI as their corresponding integer types. Those integers
1072                // have no validity requirements, so we can cast and dereference
1073                // value to check if it is equal to zero.
1074                if unsafe { *value.cast::<$underlying>() } == 0 {
1075                    fail!(NonZeroCheckError);
1076                } else {
1077                    Ok(())
1078                }
1079            }
1080        }
1081    };
1082}
1083
1084impl_nonzero!(NonZeroI8, i8);
1085impl_nonzero!(NonZeroI16, i16);
1086impl_nonzero!(NonZeroI32, i32);
1087impl_nonzero!(NonZeroI64, i64);
1088impl_nonzero!(NonZeroI128, i128);
1089impl_nonzero!(NonZeroU8, u8);
1090impl_nonzero!(NonZeroU16, u16);
1091impl_nonzero!(NonZeroU32, u32);
1092impl_nonzero!(NonZeroU64, u64);
1093impl_nonzero!(NonZeroU128, u128);
1094
#[cfg(test)]
#[macro_use]
mod tests {
    use core::ffi::CStr;

    use rancor::{Failure, Fallible, Infallible, Source, Strategy};

    use crate::{check_bytes, check_bytes_with_context, CheckBytes, Verify};

    /// A `char` stored as a little-endian `u32`, regardless of the target's
    /// native byte order, so the byte-level fixtures below are portable.
    #[derive(Debug)]
    #[repr(transparent)]
    struct CharLE(u32);

    impl From<char> for CharLE {
        fn from(c: char) -> Self {
            // `to_le` is a no-op on little-endian targets and a byte swap on
            // big-endian ones.
            Self((c as u32).to_le())
        }
    }

    unsafe impl<C> CheckBytes<C> for CharLE
    where
        C: Fallible + ?Sized,
        C::Error: Source,
    {
        unsafe fn check_bytes(
            value: *const Self,
            context: &mut C,
        ) -> Result<(), C::Error> {
            // SAFETY: The caller guarantees `value` is aligned for and points
            // to an initialized `CharLE`, which is `repr(transparent)` over a
            // `u32`; every bit pattern is a valid `u32`.
            let native = u32::from_le(unsafe { (*value).0 });
            // SAFETY: `native` is an initialized local `u32`, and `char` has
            // the same size and alignment as `u32`, so a pointer to `native`
            // is aligned and points to enough bytes for a `char`. Decoding
            // into a `u32` (rather than a `[u8; 4]`, which is only 1-aligned)
            // keeps this pointer correctly aligned on big-endian targets.
            unsafe { char::check_bytes((&native as *const u32).cast(), context) }
        }
    }

    #[repr(C, align(16))]
    struct Aligned<const N: usize>(pub [u8; N]);

    macro_rules! bytes {
        ($($byte:literal),* $(,)?) => {
            (&$crate::tests::Aligned([$($byte,)*]).0 as &[u8]).as_ptr()
        }
    }

    #[test]
    fn test_tuples() {
        unsafe {
            check_bytes::<_, Failure>(&(42u32, true, 'x')).unwrap();
        }
        unsafe {
            check_bytes::<_, Failure>(&(true,)).unwrap();
        }

        unsafe {
            // These tests assume the tuple is packed (u32, bool, char)
            check_bytes::<(u32, bool, CharLE), Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 1u8, 255u8, 255u8, 255u8, 0x78u8, 0u8,
                    0u8, 0u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap();
            check_bytes::<(u32, bool, CharLE), Failure>(
                bytes![
                    42u8, 16u8, 20u8, 3u8, 1u8, 255u8, 255u8, 255u8, 0x78u8,
                    0u8, 0u8, 0u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap();
            check_bytes::<(u32, bool, CharLE), Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 1u8, 255u8, 255u8, 255u8, 0x00u8,
                    0xd8u8, 0u8, 0u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap_err();
            check_bytes::<(u32, bool, CharLE), Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 1u8, 255u8, 255u8, 255u8, 0x00u8,
                    0x00u8, 0x11u8, 0u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap_err();
            check_bytes::<(u32, bool, CharLE), Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0u8, 255u8, 255u8, 255u8, 0x78u8, 0u8,
                    0u8, 0u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap();
            check_bytes::<(u32, bool, CharLE), Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 2u8, 255u8, 255u8, 255u8, 0x78u8, 0u8,
                    0u8, 0u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap_err();
        }
    }

    #[test]
    fn test_arrays() {
        unsafe {
            check_bytes::<_, Failure>(&[true, false, true, false]).unwrap();
        }
        unsafe {
            check_bytes::<_, Failure>(&[false, true]).unwrap();
        }

        unsafe {
            check_bytes::<[bool; 4], Failure>(
                bytes![1u8, 0u8, 1u8, 0u8].cast(),
            )
            .unwrap();
            check_bytes::<[bool; 4], Failure>(
                bytes![1u8, 2u8, 1u8, 0u8].cast(),
            )
            .unwrap_err();
            check_bytes::<[bool; 4], Failure>(
                bytes![2u8, 0u8, 1u8, 0u8].cast(),
            )
            .unwrap_err();
            check_bytes::<[bool; 4], Failure>(
                bytes![1u8, 0u8, 1u8, 2u8].cast(),
            )
            .unwrap_err();
            check_bytes::<[bool; 4], Failure>(
                bytes![1u8, 0u8, 1u8, 0u8, 2u8].cast(),
            )
            .unwrap();
        }
    }

    #[test]
    fn test_unsized() {
        unsafe {
            check_bytes::<[i32], Infallible>(
                &[1, 2, 3, 4] as &[i32] as *const [i32]
            )
            .unwrap();
            check_bytes::<str, Failure>("hello world" as *const str).unwrap();
        }
    }

    #[test]
    fn test_c_str() {
        macro_rules! test_cases {
            ($($bytes:expr, $pat:pat,)*) => {
                $(
                    let bytes = $bytes;
                    let c_str = ::ptr_meta::from_raw_parts(
                        bytes.as_ptr().cast(),
                        bytes.len(),
                    );
                    assert!(matches!(
                        check_bytes::<CStr, Failure>(c_str),
                        $pat,
                    ));
                )*
            }
        }

        unsafe {
            test_cases! {
                b"hello world\0", Ok(_),
                b"hello world", Err(_),
                b"", Err(_),
                [0xc3u8, 0x28u8, 0x00u8], Ok(_),
                [0xc3u8, 0x28u8, 0x00u8, 0xc3u8, 0x28u8, 0x00u8], Err(_),
            }
        }
    }

    #[test]
    fn test_unit_struct() {
        #[derive(CheckBytes)]
        #[bytecheck(crate)]
        struct Test;

        unsafe {
            check_bytes::<_, Infallible>(&Test).unwrap();
        }
    }

    #[test]
    fn test_tuple_struct() {
        #[derive(CheckBytes, Debug)]
        #[bytecheck(crate)]
        struct Test(u32, bool, CharLE);

        let value = Test(42, true, 'x'.into());

        unsafe {
            check_bytes::<_, Failure>(&value).unwrap();
        }

        unsafe {
            // These tests assume the struct is packed (u32, char, bool)
            check_bytes::<Test, Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x78u8, 0u8, 0u8, 0u8, 1u8, 255u8,
                    255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap();
            check_bytes::<Test, Failure>(
                bytes![
                    42u8, 16u8, 20u8, 3u8, 0x78u8, 0u8, 0u8, 0u8, 1u8, 255u8,
                    255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap();
            check_bytes::<Test, Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x00u8, 0xd8u8, 0u8, 0u8, 1u8, 255u8,
                    255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap_err();
            check_bytes::<Test, Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x00u8, 0x00u8, 0x11u8, 0u8, 1u8,
                    255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap_err();
            check_bytes::<Test, Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x78u8, 0u8, 0u8, 0u8, 0u8, 255u8,
                    255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap();
            check_bytes::<Test, Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x78u8, 0u8, 0u8, 0u8, 2u8, 255u8,
                    255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap_err();
        }
    }

    #[test]
    fn test_struct() {
        #[derive(CheckBytes, Debug)]
        #[bytecheck(crate)]
        struct Test {
            a: u32,
            b: bool,
            c: CharLE,
        }

        let value = Test {
            a: 42,
            b: true,
            c: 'x'.into(),
        };

        unsafe {
            check_bytes::<_, Failure>(&value).unwrap();
        }

        unsafe {
            // These tests assume the struct is packed (u32, char, bool)
            check_bytes::<Test, Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x78u8, 0u8, 0u8, 0u8, 1u8, 255u8,
                    255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap();
            check_bytes::<Test, Failure>(
                bytes![
                    42u8, 16u8, 20u8, 3u8, 0x78u8, 0u8, 0u8, 0u8, 1u8, 255u8,
                    255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap();
            check_bytes::<Test, Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x00u8, 0xd8u8, 0u8, 0u8, 1u8, 255u8,
                    255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap_err();
            check_bytes::<Test, Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x00u8, 0x00u8, 0x11u8, 0u8, 1u8,
                    255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap_err();
            check_bytes::<Test, Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x78u8, 0u8, 0u8, 0u8, 0u8, 255u8,
                    255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap();
            check_bytes::<Test, Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x78u8, 0u8, 0u8, 0u8, 2u8, 255u8,
                    255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap_err();
        }
    }

    #[test]
    fn test_generic_struct() {
        #[derive(CheckBytes, Debug)]
        #[bytecheck(crate)]
        struct Test<T> {
            a: u32,
            b: T,
        }

        let value = Test { a: 42, b: true };

        unsafe {
            check_bytes::<_, Failure>(&value).unwrap();
        }

        unsafe {
            check_bytes::<Test<bool>, Failure>(
                bytes![0u8, 0u8, 0u8, 0u8, 1u8, 255u8, 255u8, 255u8].cast(),
            )
            .unwrap();
            check_bytes::<Test<bool>, Failure>(
                bytes![12u8, 34u8, 56u8, 78u8, 1u8, 255u8, 255u8, 255u8].cast(),
            )
            .unwrap();
            check_bytes::<Test<bool>, Failure>(
                bytes![0u8, 0u8, 0u8, 0u8, 0u8, 255u8, 255u8, 255u8].cast(),
            )
            .unwrap();
            check_bytes::<Test<bool>, Failure>(
                bytes![0u8, 0u8, 0u8, 0u8, 2u8, 255u8, 255u8, 255u8].cast(),
            )
            .unwrap_err();
        }
    }

    #[test]
    fn test_enum() {
        #[allow(dead_code)]
        #[derive(CheckBytes, Debug)]
        #[bytecheck(crate)]
        #[repr(u8)]
        enum Test {
            A(u32, bool, CharLE),
            B { a: u32, b: bool, c: CharLE },
            C,
        }

        let value = Test::A(42, true, 'x'.into());

        unsafe {
            check_bytes::<_, Failure>(&value).unwrap();
        }

        let value = Test::B {
            a: 42,
            b: true,
            c: 'x'.into(),
        };

        unsafe {
            check_bytes::<_, Failure>(&value).unwrap();
        }

        let value = Test::C;

        unsafe {
            check_bytes::<_, Failure>(&value).unwrap();
        }

        unsafe {
            check_bytes::<Test, Failure>(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 12u8, 34u8, 56u8, 78u8, 1u8, 255u8,
                    255u8, 255u8, 120u8, 0u8, 0u8, 0u8,
                ]
                .cast(),
            )
            .unwrap();
            check_bytes::<Test, Failure>(
                bytes![
                    1u8, 0u8, 0u8, 0u8, 12u8, 34u8, 56u8, 78u8, 1u8, 255u8,
                    255u8, 255u8, 120u8, 0u8, 0u8, 0u8,
                ]
                .cast(),
            )
            .unwrap();
            check_bytes::<Test, Failure>(
                bytes![
                    2u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8, 255u8, 25u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap();
            check_bytes::<Test, Failure>(
                bytes![
                    3u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8, 255u8, 25u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
            )
            .unwrap_err();
        }
    }

    #[test]
    fn test_explicit_enum_values() {
        #[derive(CheckBytes, Debug)]
        #[bytecheck(crate)]
        #[repr(u8)]
        enum Test {
            A,
            B = 100,
            C,
            D = 200,
            E,
        }

        unsafe {
            check_bytes::<_, Failure>(&Test::A).unwrap();
        }
        unsafe {
            check_bytes::<_, Failure>(&Test::B).unwrap();
        }
        unsafe {
            check_bytes::<_, Failure>(&Test::C).unwrap();
        }
        unsafe {
            check_bytes::<_, Failure>(&Test::D).unwrap();
        }
        unsafe {
            check_bytes::<_, Failure>(&Test::E).unwrap();
        }

        unsafe {
            check_bytes::<Test, Failure>(bytes![1u8].cast()).unwrap_err();
            check_bytes::<Test, Failure>(bytes![99u8].cast()).unwrap_err();
            check_bytes::<Test, Failure>(bytes![102u8].cast()).unwrap_err();
            check_bytes::<Test, Failure>(bytes![199u8].cast()).unwrap_err();
            check_bytes::<Test, Failure>(bytes![202u8].cast()).unwrap_err();
            check_bytes::<Test, Failure>(bytes![255u8].cast()).unwrap_err();
        }
    }

    #[test]
    fn test_recursive() {
        struct MyBox<T: ?Sized> {
            inner: *const T,
        }

        unsafe impl<T, C> CheckBytes<C> for MyBox<T>
        where
            T: CheckBytes<C>,
            C: Fallible + ?Sized,
        {
            unsafe fn check_bytes(
                value: *const Self,
                context: &mut C,
            ) -> Result<(), C::Error> {
                unsafe { T::check_bytes((*value).inner, context) }
            }
        }

        #[allow(dead_code)]
        #[derive(CheckBytes)]
        #[bytecheck(crate)]
        #[repr(u8)]
        enum Node {
            Nil,
            Cons(#[bytecheck(omit_bounds)] MyBox<Node>),
        }

        unsafe {
            let nil = Node::Nil;
            let cons = Node::Cons(MyBox {
                inner: &nil as *const Node,
            });
            check_bytes::<Node, Failure>(&cons).unwrap();
        }
    }

    #[test]
    fn test_explicit_crate_root() {
        mod bytecheck {}
        mod m {
            pub use crate as bc;
        }

        #[derive(CheckBytes)]
        #[bytecheck(crate = m::bc)]
        struct Test;

        unsafe {
            check_bytes::<_, Infallible>(&Test).unwrap();
        }

        #[derive(CheckBytes)]
        #[bytecheck(crate = crate)]
        struct Test2;

        unsafe {
            check_bytes::<_, Infallible>(&Test2).unwrap();
        }
    }

    trait MyContext {
        fn set_value(&mut self, value: i32);
    }

    impl<T: MyContext, E> MyContext for Strategy<T, E> {
        fn set_value(&mut self, value: i32) {
            T::set_value(self, value)
        }
    }

    struct FooContext {
        value: i32,
    }

    impl MyContext for FooContext {
        fn set_value(&mut self, value: i32) {
            self.value = value;
        }
    }

    #[test]
    fn test_derive_verify_unit_struct() {
        unsafe impl<C: Fallible + MyContext + ?Sized> Verify<C> for UnitStruct {
            fn verify(&self, context: &mut C) -> Result<(), C::Error> {
                context.set_value(1);
                Ok(())
            }
        }

        #[derive(CheckBytes)]
        #[bytecheck(crate, verify)]
        struct UnitStruct;

        let mut context = FooContext { value: 0 };
        unsafe {
            check_bytes_with_context::<_, _, Infallible>(
                &UnitStruct,
                &mut context,
            )
            .unwrap();
        }

        assert_eq!(context.value, 1);
    }

    #[test]
    fn test_derive_verify_struct() {
        unsafe impl<C: Fallible + MyContext + ?Sized> Verify<C> for Struct {
            fn verify(&self, context: &mut C) -> Result<(), C::Error> {
                context.set_value(self.value);
                Ok(())
            }
        }

        #[derive(CheckBytes)]
        #[bytecheck(crate, verify)]
        struct Struct {
            value: i32,
        }

        let mut context = FooContext { value: 0 };
        unsafe {
            check_bytes_with_context::<_, _, Infallible>(
                &Struct { value: 4 },
                &mut context,
            )
            .unwrap();
        }

        assert_eq!(context.value, 4);
    }

    #[test]
    fn test_derive_verify_tuple_struct() {
        unsafe impl<C> Verify<C> for TupleStruct
        where
            C: Fallible + MyContext + ?Sized,
        {
            fn verify(&self, context: &mut C) -> Result<(), C::Error> {
                context.set_value(self.0);
                Ok(())
            }
        }

        #[derive(CheckBytes)]
        #[bytecheck(crate, verify)]
        struct TupleStruct(i32);

        let mut context = FooContext { value: 0 };
        unsafe {
            check_bytes_with_context::<_, _, Infallible>(
                &TupleStruct(10),
                &mut context,
            )
            .unwrap();
        }

        assert_eq!(context.value, 10);
    }

    #[test]
    fn test_derive_verify_enum() {
        unsafe impl<C: Fallible + MyContext + ?Sized> Verify<C> for Enum {
            fn verify(&self, context: &mut C) -> Result<(), C::Error> {
                match self {
                    Enum::A => context.set_value(2),
                    Enum::B(value) => context.set_value(*value),
                    Enum::C { value } => context.set_value(*value),
                }
                Ok(())
            }
        }

        #[derive(CheckBytes)]
        #[bytecheck(crate, verify)]
        #[repr(u8)]
        enum Enum {
            A,
            B(i32),
            C { value: i32 },
        }

        // Unit variant
        let mut context = FooContext { value: 0 };
        unsafe {
            check_bytes_with_context::<_, _, Failure>(&Enum::A, &mut context)
                .unwrap();
        }

        assert_eq!(context.value, 2);

        // Tuple variant
        let mut context = FooContext { value: 0 };
        unsafe {
            check_bytes_with_context::<_, _, Failure>(
                &Enum::B(5),
                &mut context,
            )
            .unwrap();
        }

        assert_eq!(context.value, 5);

        // Struct variant
        let mut context = FooContext { value: 0 };
        unsafe {
            check_bytes_with_context::<_, _, Failure>(
                &Enum::C { value: 7 },
                &mut context,
            )
            .unwrap();
        }

        assert_eq!(context.value, 7);
    }
}