bytecheck_test/
lib.rs

1#![deny(
2    rust_2018_compatibility,
3    rust_2018_idioms,
4    future_incompatible,
5    nonstandard_style,
6    unused,
7    clippy::all
8)]
9#![cfg_attr(not(feature = "std"), no_std)]
10
11#[cfg(not(feature = "std"))]
12extern crate alloc;
13
/// Tests for bytecheck's built-in and derived `CheckBytes` implementations.
#[cfg(test)]
mod tests {
    use bytecheck::CheckBytes;

    /// A `char` stored as a little-endian `u32`, whatever the target's native
    /// byte order is.
    ///
    /// The raw-byte tests below spell chars out as little-endian byte
    /// sequences; this wrapper lets those sequences validate identically on
    /// little- and big-endian targets.
    #[derive(Debug)]
    #[repr(transparent)]
    struct CharLE(u32);

    impl From<char> for CharLE {
        fn from(c: char) -> Self {
            // `to_le` is a no-op on little-endian targets and a byte swap on
            // big-endian ones.
            Self((c as u32).to_le())
        }
    }

    impl<C: ?Sized> CheckBytes<C> for CharLE {
        type Error = <char as CheckBytes<C>>::Error;

        unsafe fn check_bytes<'a>(
            value: *const Self,
            context: &mut C,
        ) -> Result<&'a Self, Self::Error> {
            // Decode the stored little-endian value into a native-order `u32`
            // local and validate that as a `char`. Both the read through
            // `(*value).0` (`CharLE` is `repr(transparent)` over `u32`) and the
            // local are suitably aligned for `char`; reinterpreting a
            // `[u8; 4]` buffer (alignment 1) as a `char` would not be.
            let native = u32::from_le((*value).0);
            char::check_bytes((&native as *const u32).cast(), context)?;
            Ok(&*value)
        }
    }

    /// Runs `T`'s `CheckBytes` impl on an existing, well-formed value and
    /// panics if validation fails.
    fn check_as_bytes<T: CheckBytes<C>, C: ?Sized>(value: &T, context: &mut C) {
        // SAFETY: `value` is a live reference, so the pointer is non-null,
        // aligned and points to an initialized `T`.
        unsafe { T::check_bytes(value as *const T, context).unwrap() };
    }

    /// Byte buffer with 16-byte alignment, so pointers into it can be cast to
    /// any of the types checked in these tests.
    #[repr(C, align(16))]
    struct Aligned<const N: usize>([u8; N]);

    // Produces a `*const u8` to a 16-byte-aligned buffer holding the given
    // byte literals. The buffer is a temporary, so treat the pointer as valid
    // only within the enclosing statement (i.e. pass it straight to
    // `check_bytes`). Accepts the list with or without a trailing comma.
    macro_rules! bytes {
        ($($byte:literal,)*) => {
            (&Aligned([$($byte,)*]).0 as &[u8]).as_ptr()
        };
        ($($byte:literal),*) => {
            bytes!($($byte,)*)
        };
    }

    #[test]
    fn test_tuples() {
        check_as_bytes(&(42u32, true, 'x'), &mut ());
        check_as_bytes(&(true,), &mut ());

        // These byte patterns assume the tuple keeps declaration order: the
        // u32 at offset 0, the bool at offset 4 (followed by padding) and the
        // CharLE at offset 8; the last four bytes are filler. `repr(Rust)`
        // layout is unspecified, so these may need updating if the compiler
        // changes it.
        //
        // SAFETY: every pointer comes from `bytes!`, which yields a 16-byte,
        // 16-aligned buffer that lives until the end of each call statement.
        unsafe {
            // Valid: zero u32, `true`, 'x'.
            <(u32, bool, CharLE)>::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 1u8, 255u8, 255u8, 255u8, 0x78u8, 0u8, 0u8, 0u8, 255u8,
                    255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap();
            // Valid: any u32 bit pattern is accepted.
            <(u32, bool, CharLE)>::check_bytes(
                bytes![
                    42u8, 16u8, 20u8, 3u8, 1u8, 255u8, 255u8, 255u8, 0x78u8, 0u8, 0u8, 0u8, 255u8,
                    255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap();
            // Invalid char: 0xD800 is a surrogate code point.
            <(u32, bool, CharLE)>::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 1u8, 255u8, 255u8, 255u8, 0x00u8, 0xd8u8, 0u8, 0u8, 255u8,
                    255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap_err();
            // Invalid char: 0x110000 is above the maximum code point.
            <(u32, bool, CharLE)>::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 1u8, 255u8, 255u8, 255u8, 0x00u8, 0x00u8, 0x11u8, 0u8,
                    255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap_err();
            // Valid: `false`.
            <(u32, bool, CharLE)>::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0u8, 255u8, 255u8, 255u8, 0x78u8, 0u8, 0u8, 0u8, 255u8,
                    255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap();
            // Invalid bool: 2.
            <(u32, bool, CharLE)>::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 2u8, 255u8, 255u8, 255u8, 0x78u8, 0u8, 0u8, 0u8, 255u8,
                    255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap_err();
        }
    }

    #[test]
    fn test_arrays() {
        check_as_bytes(&[true, false, true, false], &mut ());
        check_as_bytes(&[false, true], &mut ());

        // SAFETY: every pointer comes from `bytes!`, which yields a 16-aligned
        // buffer at least as large as `[bool; 4]` that lives until the end of
        // each call statement.
        unsafe {
            <[bool; 4]>::check_bytes(bytes![1u8, 0u8, 1u8, 0u8].cast(), &mut ()).unwrap();
            // A single invalid element anywhere in the array is rejected.
            <[bool; 4]>::check_bytes(bytes![1u8, 2u8, 1u8, 0u8].cast(), &mut ()).unwrap_err();
            <[bool; 4]>::check_bytes(bytes![2u8, 0u8, 1u8, 0u8].cast(), &mut ()).unwrap_err();
            <[bool; 4]>::check_bytes(bytes![1u8, 0u8, 1u8, 2u8].cast(), &mut ()).unwrap_err();
            // Bytes past the end of the array are not inspected.
            <[bool; 4]>::check_bytes(bytes![1u8, 0u8, 1u8, 0u8, 2u8].cast(), &mut ()).unwrap();
        }
    }

    #[test]
    fn test_unit_struct() {
        #[derive(CheckBytes)]
        struct Test;

        check_as_bytes(&Test, &mut ());
    }

    #[test]
    fn test_tuple_struct() {
        #[derive(CheckBytes, Debug)]
        struct Test(u32, bool, CharLE);

        let value = Test(42, true, 'x'.into());

        check_as_bytes(&value, &mut ());

        // These byte patterns assume the compiler reorders the fields
        // (declared as u32, bool, CharLE) to: u32 at offset 0, CharLE at
        // offset 4, bool at offset 8 (followed by padding); the last four
        // bytes are filler. `repr(Rust)` layout is unspecified, so these may
        // need updating if the compiler changes it.
        //
        // SAFETY: every pointer comes from `bytes!`, which yields a 16-byte,
        // 16-aligned buffer that lives until the end of each call statement.
        unsafe {
            // Valid: zero u32, 'x', `true`.
            Test::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x78u8, 0u8, 0u8, 0u8, 1u8, 255u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap();
            // Valid: any u32 bit pattern is accepted.
            Test::check_bytes(
                bytes![
                    42u8, 16u8, 20u8, 3u8, 0x78u8, 0u8, 0u8, 0u8, 1u8, 255u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap();
            // Invalid char: 0xD800 is a surrogate code point.
            Test::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x00u8, 0xd8u8, 0u8, 0u8, 1u8, 255u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap_err();
            // Invalid char: 0x110000 is above the maximum code point.
            Test::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x00u8, 0x00u8, 0x11u8, 0u8, 1u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap_err();
            // Valid: `false`.
            Test::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x78u8, 0u8, 0u8, 0u8, 0u8, 255u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap();
            // Invalid bool: 2.
            Test::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x78u8, 0u8, 0u8, 0u8, 2u8, 255u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap_err();
        }
    }

    #[test]
    fn test_struct() {
        #[derive(CheckBytes, Debug)]
        struct Test {
            a: u32,
            b: bool,
            c: CharLE,
        }

        let value = Test {
            a: 42,
            b: true,
            c: 'x'.into(),
        };

        check_as_bytes(&value, &mut ());

        // These byte patterns assume the compiler reorders the fields
        // (declared as a: u32, b: bool, c: CharLE) to: `a` at offset 0, `c` at
        // offset 4, `b` at offset 8 (followed by padding); the last four bytes
        // are filler. `repr(Rust)` layout is unspecified, so these may need
        // updating if the compiler changes it.
        //
        // SAFETY: every pointer comes from `bytes!`, which yields a 16-byte,
        // 16-aligned buffer that lives until the end of each call statement.
        unsafe {
            // Valid: zero `a`, 'x', `true`.
            Test::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x78u8, 0u8, 0u8, 0u8, 1u8, 255u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap();
            // Valid: any u32 bit pattern is accepted for `a`.
            Test::check_bytes(
                bytes![
                    42u8, 16u8, 20u8, 3u8, 0x78u8, 0u8, 0u8, 0u8, 1u8, 255u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap();
            // Invalid `c`: 0xD800 is a surrogate code point.
            Test::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x00u8, 0xd8u8, 0u8, 0u8, 1u8, 255u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap_err();
            // Invalid `c`: 0x110000 is above the maximum code point.
            Test::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x00u8, 0x00u8, 0x11u8, 0u8, 1u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap_err();
            // Valid: `b` is `false`.
            Test::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x78u8, 0u8, 0u8, 0u8, 0u8, 255u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap();
            // Invalid `b`: 2.
            Test::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 0x78u8, 0u8, 0u8, 0u8, 2u8, 255u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap_err();
        }
    }

    #[test]
    fn test_generic_struct() {
        #[derive(CheckBytes, Debug)]
        struct Test<T> {
            a: u32,
            b: T,
        }

        let value = Test { a: 42, b: true };

        check_as_bytes(&value, &mut ());

        // These byte patterns assume `a` at offset 0 and `b` (a bool) at
        // offset 4, followed by padding. `repr(Rust)` layout is unspecified.
        //
        // SAFETY: every pointer comes from `bytes!`, which yields an 8-byte,
        // 16-aligned buffer that lives until the end of each call statement.
        unsafe {
            Test::<bool>::check_bytes(
                bytes![0u8, 0u8, 0u8, 0u8, 1u8, 255u8, 255u8, 255u8,].cast(),
                &mut (),
            )
            .unwrap();
            Test::<bool>::check_bytes(
                bytes![12u8, 34u8, 56u8, 78u8, 1u8, 255u8, 255u8, 255u8,].cast(),
                &mut (),
            )
            .unwrap();
            Test::<bool>::check_bytes(
                bytes![0u8, 0u8, 0u8, 0u8, 0u8, 255u8, 255u8, 255u8,].cast(),
                &mut (),
            )
            .unwrap();
            // Invalid bool: 2.
            Test::<bool>::check_bytes(
                bytes![0u8, 0u8, 0u8, 0u8, 2u8, 255u8, 255u8, 255u8,].cast(),
                &mut (),
            )
            .unwrap_err();
        }
    }

    #[test]
    fn test_enum() {
        #[derive(CheckBytes, Debug)]
        #[repr(u8)]
        enum Test {
            A(u32, bool, CharLE),
            #[allow(dead_code)]
            B {
                a: u32,
                b: bool,
                c: CharLE,
            },
            C,
        }

        let value = Test::A(42, true, 'x'.into());

        check_as_bytes(&value, &mut ());

        let value = Test::B {
            a: 42,
            b: true,
            c: 'x'.into(),
        };

        check_as_bytes(&value, &mut ());

        let value = Test::C;

        check_as_bytes(&value, &mut ());

        // `#[repr(u8)]` on a field-carrying enum lays out each variant like a
        // `repr(C)` struct that begins with the `u8` tag, so these offsets are
        // well defined: tag at 0, u32 at 4, bool at 8, CharLE at 12.
        //
        // SAFETY: every pointer comes from `bytes!`, which yields a 16-byte,
        // 16-aligned buffer that lives until the end of each call statement.
        unsafe {
            // Tag 0 (`A`) with valid fields.
            Test::check_bytes(
                bytes![
                    0u8, 0u8, 0u8, 0u8, 12u8, 34u8, 56u8, 78u8, 1u8, 255u8, 255u8, 255u8, 120u8,
                    0u8, 0u8, 0u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap();
            // Tag 1 (`B`) with valid fields.
            Test::check_bytes(
                bytes![
                    1u8, 0u8, 0u8, 0u8, 12u8, 34u8, 56u8, 78u8, 1u8, 255u8, 255u8, 255u8, 120u8,
                    0u8, 0u8, 0u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap();
            // Tag 2 (`C`) has no fields, so the remaining bytes are ignored.
            Test::check_bytes(
                bytes![
                    2u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap();
            // Tag 3 does not name any variant.
            Test::check_bytes(
                bytes![
                    3u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8, 255u8,
                    255u8, 255u8, 255u8, 255u8, 255u8,
                ]
                .cast(),
                &mut (),
            )
            .unwrap_err();
        }
    }

    #[test]
    fn test_explicit_enum_values() {
        #[derive(CheckBytes, Debug)]
        #[repr(u8)]
        enum Test {
            A,
            B = 100,
            C,
            D = 200,
            E,
        }

        // The valid discriminants (0, 100, 101, 200, 201) are exercised here.
        check_as_bytes(&Test::A, &mut ());
        check_as_bytes(&Test::B, &mut ());
        check_as_bytes(&Test::C, &mut ());
        check_as_bytes(&Test::D, &mut ());
        check_as_bytes(&Test::E, &mut ());

        // SAFETY: every pointer comes from `bytes!`, which yields a one-byte
        // buffer (the size of a `repr(u8)` fieldless enum) that lives until
        // the end of each call statement.
        unsafe {
            // Values just outside each run of valid discriminants are rejected.
            Test::check_bytes(bytes![1u8].cast(), &mut ()).unwrap_err();
            Test::check_bytes(bytes![99u8].cast(), &mut ()).unwrap_err();
            Test::check_bytes(bytes![102u8].cast(), &mut ()).unwrap_err();
            Test::check_bytes(bytes![199u8].cast(), &mut ()).unwrap_err();
            Test::check_bytes(bytes![202u8].cast(), &mut ()).unwrap_err();
            Test::check_bytes(bytes![255u8].cast(), &mut ()).unwrap_err();
        }
    }

    #[test]
    fn test_context() {
        use core::{convert::Infallible, fmt};

        #[derive(Debug)]
        #[repr(transparent)]
        struct Test(i32);

        // Context carrying the value a `Test` is expected to hold.
        struct TestContext(pub i32);

        #[derive(Debug)]
        struct TestError {
            expected: i32,
            found: i32,
        }

        impl fmt::Display for TestError {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(
                    f,
                    "mismatched value: expected {}, found {}",
                    self.expected, self.found
                )
            }
        }

        #[cfg(feature = "std")]
        impl std::error::Error for TestError {}

        // Lets `?` lift `i32`'s `Infallible` check error into `TestError`.
        impl From<Infallible> for TestError {
            fn from(never: Infallible) -> Self {
                // `Infallible` has no values, so the empty match is exhaustive
                // and no `unsafe` is needed.
                match never {}
            }
        }

        // Accepts a `Test` only if it holds the value stored in the context.
        impl CheckBytes<TestContext> for Test {
            type Error = TestError;

            unsafe fn check_bytes<'a>(
                value: *const Test,
                context: &mut TestContext,
            ) -> Result<&'a Self, Self::Error> {
                let found = *i32::check_bytes(value.cast(), context)?;
                let expected = context.0;
                if expected == found {
                    Ok(&*value)
                } else {
                    Err(TestError { expected, found })
                }
            }
        }

        // Native-endian `i32` values, so the test behaves the same on
        // little- and big-endian targets (a raw `[42, 0, 0, 0]` buffer only
        // decodes to 42 on little-endian ones).
        let matching: i32 = 42;
        let mismatched: i32 = 41;

        // SAFETY: both pointers come from live, aligned `i32` locals, and
        // `Test` is `repr(transparent)` over `i32`.
        unsafe {
            Test::check_bytes((&matching as *const i32).cast(), &mut TestContext(42)).unwrap();
            Test::check_bytes((&mismatched as *const i32).cast(), &mut TestContext(42))
                .unwrap_err();
        }

        // The derived impl must forward the custom context to the field.
        #[repr(transparent)]
        #[derive(CheckBytes, Debug)]
        struct TestContainer(Test);

        // SAFETY: as above; `TestContainer` is `repr(transparent)` over `Test`.
        unsafe {
            TestContainer::check_bytes((&matching as *const i32).cast(), &mut TestContext(42))
                .unwrap();
            TestContainer::check_bytes((&mismatched as *const i32).cast(), &mut TestContext(42))
                .unwrap_err();
        }
    }

    #[test]
    fn test_unsized() {
        // SAFETY: both pointers are derived from live references to
        // initialized data.
        unsafe {
            <[i32] as CheckBytes<()>>::check_bytes(
                &[1, 2, 3, 4] as &[i32] as *const [i32],
                &mut (),
            )
            .unwrap();
            <str as CheckBytes<()>>::check_bytes("hello world" as *const str, &mut ()).unwrap();
        }
    }

    #[test]
    fn test_recursive() {
        // Minimal owning-pointer stand-in whose check follows `inner`.
        struct MyBox<T: ?Sized> {
            inner: *const T,
        }

        // The `Default` bound is not used by the body; it forces the derived
        // `Node` impl below to carry an explicit `__C: Default` bound (via the
        // `bound` attribute), since `#[omit_bounds]` drops the field's bound.
        // (`Default` already implies `Sized`, so no `?Sized` is needed.)
        impl<T: CheckBytes<C>, C: Default> CheckBytes<C> for MyBox<T> {
            type Error = T::Error;

            unsafe fn check_bytes<'a>(
                value: *const Self,
                context: &mut C,
            ) -> Result<&'a Self, Self::Error> {
                T::check_bytes((*value).inner, context)?;
                Ok(&*value)
            }
        }

        #[derive(CheckBytes)]
        #[check_bytes(bound = "__C: Default")]
        #[repr(u8)]
        enum Node {
            Nil,
            Cons(#[omit_bounds] MyBox<Node>),
        }

        let nil = Node::Nil;
        let cons = Node::Cons(MyBox {
            inner: &nil as *const Node,
        });

        // SAFETY: `cons` is a live, initialized `Node`, and its `inner`
        // pointer refers to `nil`, which outlives this call.
        unsafe {
            Node::check_bytes(&cons, &mut ()).unwrap();
        }
    }
}