// mem_cmp/mem_eq.rs

1use core::mem::{size_of, size_of_val, transmute_copy};
2use ext::*;
3
/// Trait for equality comparisons performed over bytes directly.
///
/// Only [`mem_eq`](MemEq::mem_eq) has to be implemented; the provided
/// [`mem_neq`](MemEq::mem_neq) is simply its negation.
pub trait MemEq<Rhs: ?Sized = Self> {
    /// Tests whether `self` and `other` are equal in memory.
    #[must_use]
    fn mem_eq(&self, other: &Rhs) -> bool;

    /// Tests whether `self` and `other` are not equal in memory.
    ///
    /// Always the logical negation of [`mem_eq`](MemEq::mem_eq).
    #[must_use]
    #[inline]
    fn mem_neq(&self, other: &Rhs) -> bool {
        let equal = self.mem_eq(other);
        !equal
    }
}
15
/// A 128-bit comparison lane made of two `u64` halves.
///
/// With the `simd` feature it is a `#[repr(simd)]` vector whose equality goes
/// through `simd::u32x4`; without it, it is a plain pair of `u64`s using the
/// derived `PartialEq`.
#[derive(Copy, Clone)]
#[cfg_attr(feature = "simd", repr(simd))]
#[cfg_attr(not(feature = "simd"), derive(PartialEq))]
struct U128(u64, u64);

#[cfg(feature = "simd")]
impl PartialEq for U128 {
    #[inline(always)]
    fn eq(&self, other: &Self) -> bool {
        use core::mem::transmute;
        use simd::u32x4;
        // SAFETY: `transmute` only compiles for same-size types; `U128`
        // (two `u64`s) is reinterpreted as four `u32` lanes.
        let (lhs, rhs): (u32x4, u32x4) = unsafe { (transmute(*self), transmute(*other)) };
        // Lane-wise comparison; equal only if every lane matches.
        lhs.eq(rhs).all()
    }
}
34
35#[cfg(feature = "avx")]
36#[derive(Copy, Clone)]
37#[repr(simd)]
38struct U256(u64, u64, u64, u64);
39
40#[cfg(feature = "avx")]
41impl PartialEq for U256 {
42    #[inline(always)]
43    fn eq(&self, other: &Self) -> bool {
44        use core::mem::transmute;
45        use simd::x86::avx::u8x32;
46        unsafe {
47            let x: u8x32 = transmute(*self);
48            let y: u8x32 = transmute(*other);
49            x.eq(y).all()
50        }
51    }
52}
53
54#[cfg(not(feature = "avx"))]
55type U256 = (U128, U128);
56
57type U512 = (U256, U256);
58
// Bit-copies the values behind the references `$x` and `$y` into two `$t`s
// and compares the copies with `$t`'s `==`.
//
// Callers must guarantee that both referents are at least
// `size_of::<$t>()` bytes long, because `transmute_copy` reads that many
// bytes from each source.
macro_rules! from_type {
    ($t:ty, $x:expr, $y:expr) => {{
        // SAFETY: upheld by the caller (see the size requirement above).
        let (lhs, rhs): ($t, $t) = unsafe { (transmute_copy($x), transmute_copy($y)) };
        lhs == rhs
    }};
}
68
69impl<T, U> MemEq<U> for T {
70    #[inline]
71    fn mem_eq(&self, other: &U) -> bool {
72        let size = size_of::<T>();
73        if size != size_of::<U>() { return false; }
74
75        macro_rules! impl_eq {
76            ($($t:ty),+; simd: $($s:ty),+ $(,)*) => {
77                $(if size == size_of::<$t>() {
78                    from_type!($t, self, other)
79                } else)+ $(if cfg!(feature = "simd") && size == size_of::<$s>() {
80                    from_type!($s, self, other)
81                } else)+ {
82                    unsafe { _memcmp(self, other, 1) == 0 }
83                }
84            }
85        }
86        impl_eq! {
87            u8, u16, u32, u64,
88            U128, (U128, u64),
89            U256, (U256, u64), (U256, U128), (U256, U128, u64),
90            U512, (U512, u64), (U512, U256), (U512, U256, u64);
91
92            // These types are only used when simd is enabled
93            simd: (U512, U512), (U512, U512, U512), (U512, U512, U512, U512)
94        }
95    }
96}
97
98#[inline(always)]
99fn _mem_eq<T: ?Sized, U: ?Sized>(a: &T, b: &U) -> bool {
100    let size = size_of_val(a);
101    size == size_of_val(b) && unsafe {
102        let x = a as *const _ as _;
103        let y = b as *const _ as _;
104        (x as usize) == (y as usize) || memcmp(x, y, size) == 0
105    }
106}
107
/// Fallback for any pair of (possibly unsized) types, enabled only with the
/// `specialization` feature.
///
/// Delegates to `_mem_eq`, which compares `size_of_val` bytes of each operand.
/// The method is `default` so that more specific impls, such as the sized
/// blanket impl in this module, take precedence.
#[cfg(feature = "specialization")]
impl<T: ?Sized, U: ?Sized> MemEq<U> for T {
    #[inline]
    default fn mem_eq(&self, other: &U) -> bool {
        _mem_eq::<T, U>(self, other)
    }
}
115
116#[cfg(not(feature = "specialization"))]
117impl<T, U> MemEq<[U]> for [T] {
118    #[inline]
119    fn mem_eq(&self, other: &[U]) -> bool {
120        _mem_eq(self, other)
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn equal_bytes() {
        let buf = [0u8; 128];
        assert!(buf.mem_eq(&buf));

        let x = [0u8; 1];
        let y = 0u8;
        assert!(x.mem_eq(&y));

        assert!(buf.mem_neq(&x));
        assert!(buf.mem_neq(&y));
    }

    /// Every size served by a fixed-width fast path, plus a few that fall back
    /// to `_memcmp`, must notice a difference in the first or the last byte.
    /// The last byte is where a comparison type with trailing padding would
    /// stop looking.
    #[test]
    fn detects_first_and_last_byte_differences() {
        macro_rules! check_sizes {
            ($($n:expr),+) => {$({
                let a = [0u8; $n];
                let same = [0u8; $n];
                let mut first = [0u8; $n];
                first[0] = 1;
                let mut last = [0u8; $n];
                last[$n - 1] = 1;

                assert!(a.mem_eq(&same), "size {}", $n);
                assert!(a.mem_neq(&first), "size {}", $n);
                assert!(a.mem_neq(&last), "size {}", $n);
            })+};
        }
        check_sizes!(1, 2, 3, 4, 7, 8, 16, 24, 32, 40, 48, 56, 64);
        check_sizes!(72, 80, 96, 104, 112, 128, 192, 256);
    }
}