use sgx_types::marker::{BytewiseEquality};
use core::mem;
use alloc::slice;
/// Equality comparison over the raw bytes of two values, intended for use
/// where the comparison time should not depend on the compared contents
/// (e.g. when checking secrets).
///
/// `T` defaults to `Self`, so `a.consttime_memeq(&b)` compares two values of
/// the same type.
pub trait ConsttimeMemEq<T: BytewiseEquality + ?Sized = Self> {
    /// Returns `true` when `self` and `rhs` contain identical bytes.
    fn consttime_memeq(&self, rhs: &T) -> bool;

    /// Returns `true` when `self` and `rhs` differ; this is simply the
    /// negation of `consttime_memeq`.
    fn consttime_memne(&self, rhs: &T) -> bool {
        let equal = self.consttime_memeq(rhs);
        !equal
    }
}
/// Byte-wise comparison of two slices of the same element type.
impl<T> ConsttimeMemEq<[T]> for [T]
where
    T: Eq + BytewiseEquality,
{
    /// Returns `true` when both slices have the same length and identical bytes.
    ///
    /// The length check and the same-address check short-circuit; only the
    /// element bytes themselves are compared through `consttime_memequal`.
    fn consttime_memeq(&self, other: &[T]) -> bool {
        if self.len() != other.len() {
            // Lengths are not treated as secret.
            false
        } else if self.as_ptr() == other.as_ptr() {
            // Same start address and same length: the very same memory.
            true
        } else {
            let byte_len = mem::size_of_val(self);
            let lhs = self.as_ptr() as *const u8;
            let rhs = other.as_ptr() as *const u8;
            // SAFETY: both slices have the same element type and length, so
            // each one spans `byte_len` readable bytes starting at its pointer.
            unsafe { consttime_memequal(lhs, rhs, byte_len) != 0 }
        }
    }
}
/// Byte-wise comparison of two sized values of the same type.
impl<T> ConsttimeMemEq<T> for T
where
    T: Eq + BytewiseEquality,
{
    /// Returns `true` when `self` and `other` have identical in-memory bytes.
    ///
    /// Zero-sized values have no bytes to compare and are always equal.
    fn consttime_memeq(&self, other: &T) -> bool {
        let byte_len = mem::size_of_val(self);
        if byte_len == 0 {
            true
        } else {
            let lhs = self as *const T as *const u8;
            let rhs = other as *const T as *const u8;
            // SAFETY: `self` and `other` are live references to values of the
            // same type `T`, so each is readable for `size_of_val(self)` bytes.
            let verdict = unsafe { consttime_memequal(lhs, rhs, byte_len) };
            verdict != 0
        }
    }
}
/// Compares `l` bytes at `b1` and `b2` without branching on their contents.
///
/// Returns `1` when all `l` bytes are equal and `0` otherwise, mirroring the
/// C `consttime_memequal` interface. Every byte pair is visited regardless of
/// where (or whether) the buffers differ.
///
/// # Safety
///
/// When `l > 0`, both `b1` and `b2` must be non-null and valid for reads of
/// `l` bytes for the duration of the call. When `l == 0` neither pointer is
/// dereferenced, so they may be null or dangling.
unsafe fn consttime_memequal(b1: *const u8,
                             b2: *const u8,
                             l: usize) -> i32 {
    // The length is public, so branching on it leaks nothing. Returning here
    // also avoids `slice::from_raw_parts` on possibly-null pointers, which is
    // undefined behaviour even for a zero length.
    if l == 0 {
        return 1;
    }
    // SAFETY: the caller guarantees `b1` and `b2` are readable for `l` bytes.
    let p1 = slice::from_raw_parts(b1, l);
    let p2 = slice::from_raw_parts(b2, l);
    // OR together the XOR of every byte pair: zero iff all pairs match.
    // Result stays within [0, 256) because each term is a `u8`.
    let res = p1
        .iter()
        .zip(p2.iter())
        .fold(0u32, |acc, (x, y)| acc | u32::from(x ^ y));
    // Map 0 -> 1 and [1, 256) -> 0 using only arithmetic (no `!res`, which
    // some compilers lower to a branch). For `res == 0` the subtraction must
    // wrap to `u32::MAX`; plain `-` would panic on overflow in debug builds.
    (1 & (res.wrapping_sub(1) >> 8)) as i32
}