use core::mem::size_of;
use core::ops::BitXor;
use core::ptr::read_unaligned;
use crate::with_dit;
// `Word` is the chunk type in which `constant_time_eq_impl` compares the bulk of its
// inputs. The `cfg` predicates below are intended to select exactly one alias per target.
/// Comparison word for 64-bit targets, except RISC-V 64 without the
/// `unaligned-scalar-mem` target feature.
#[cfg(all(
    target_pointer_width = "64",
    any(not(target_arch = "riscv64"), target_feature = "unaligned-scalar-mem")
))]
pub(crate) type Word = u64;
/// Comparison word for 32-bit targets, except RISC-V 32 without the
/// `unaligned-scalar-mem` target feature.
#[cfg(all(
    target_pointer_width = "32",
    any(not(target_arch = "riscv32"), target_feature = "unaligned-scalar-mem")
))]
pub(crate) type Word = u32;
/// Comparison word for 16-bit targets.
#[cfg(target_pointer_width = "16")]
pub(crate) type Word = u16;
/// Fallback for any other pointer width: compare in pointer-sized chunks.
#[cfg(not(any(
    target_pointer_width = "64",
    target_pointer_width = "32",
    target_pointer_width = "16"
)))]
pub(crate) type Word = usize;
/// RISC-V without `unaligned-scalar-mem`: compare one byte at a time, so every load is a
/// single-byte load. NOTE(review): presumably chosen because misaligned multi-byte scalar
/// accesses may be slow or emulated on such cores — confirm.
#[cfg(all(
    any(target_arch = "riscv64", target_arch = "riscv32"),
    not(target_feature = "unaligned-scalar-mem")
))]
pub(crate) type Word = u8;
/// Returns `value` unchanged while hiding it from the optimizer.
///
/// The value is routed through an `asm!` block whose template is only an assembler
/// comment. The block therefore contributes no instructions of its own, but the compiler
/// cannot see that the output equals the input. `constant_time_eq_impl` wraps every
/// intermediate result in this call. Presumably this keeps the optimizer from reasoning
/// about the accumulated value (e.g. turning the loop into an early-exit comparison).
///
/// The asm options declare that the block:
/// - has no side effects, and its output depends only on its input (`pure`);
/// - touches no memory (`nomem`);
/// - leaves the flags untouched (`preserves_flags`);
/// - does not use the stack (`nostack`).
///
/// This variant is used on the listed architectures when not running under Miri. Other
/// configurations use the `core::hint::black_box` fallback.
#[cfg(all(
    not(miri),
    any(
        target_arch = "x86",
        target_arch = "x86_64",
        target_arch = "arm",
        target_arch = "aarch64",
        target_arch = "arm64ec",
        target_arch = "riscv32",
        target_arch = "riscv64",
        target_arch = "loongarch64",
        target_arch = "s390x",
        target_arch = "powerpc",
        target_arch = "powerpc64",
    )
))]
#[must_use]
#[inline(always)]
fn optimizer_hide(mut value: Word) -> Word {
    // SAFETY: the template is only a comment. The block reads and writes `value` in a
    // register and does nothing else, which matches the declared
    // `nomem`/`nostack`/`preserves_flags` options.
    unsafe {
        core::arch::asm!(
            "/* {0} */",
            inlateout(reg) value,
            options(pure, nomem, preserves_flags, nostack)
        );
    }
    value
}
/// Fallback optimizer barrier: returns `value` unchanged after passing it through
/// `core::hint::black_box`.
///
/// Selected under Miri and on architectures that lack the inline-asm variant.
/// `black_box` is documented only as a best-effort hint. The function is also marked
/// `#[inline(never)]`, presumably so that the call boundary itself hides the value.
#[cfg(any(
    miri,
    not(any(
        target_arch = "x86",
        target_arch = "x86_64",
        target_arch = "arm",
        target_arch = "aarch64",
        target_arch = "arm64ec",
        target_arch = "riscv32",
        target_arch = "riscv64",
        target_arch = "loongarch64",
        target_arch = "s390x",
        target_arch = "powerpc",
        target_arch = "powerpc64",
    ))
))]
#[must_use]
#[inline(never)]
fn optimizer_hide(value: Word) -> Word {
    use core::hint::black_box as opaque;
    opaque(value)
}
/// Reads a `T` from the bytes of `src`, which may have any alignment.
///
/// The bytes are reinterpreted in native byte order, exactly as [`read_unaligned`] does.
///
/// # Panics
///
/// Panics if `src.len()` is not exactly `size_of::<T>()`.
///
/// # Safety
///
/// Every bit pattern of `size_of::<T>()` bytes must be a valid value of `T`. The primitive
/// unsigned integers used in this module satisfy this. The length requirement is checked
/// at runtime, so it is not part of the caller's obligation.
#[must_use]
#[inline(always)]
unsafe fn read_unaligned_from_slice<T>(src: &[u8]) -> T {
    assert_eq!(src.len(), size_of::<T>());
    // SAFETY: the assertion above guarantees that `src` covers exactly `size_of::<T>()`
    // initialized, readable bytes. `read_unaligned` imposes no alignment requirement, and
    // the caller guarantees that any bit pattern is a valid `T`.
    unsafe { read_unaligned(src.as_ptr().cast::<T>()) }
}
/// Shared implementation of [`constant_time_eq`] and [`constant_time_eq_n`].
///
/// Returns `true` only if all of the following hold:
/// - `a` and `b` have the same length;
/// - every byte pair is equal;
/// - the caller-supplied accumulator seed `tmp` is zero.
///
/// The XOR of each compared chunk is OR-ed into `tmp`, so any differing byte (or a
/// non-zero seed) leaves it non-zero.
///
/// A length mismatch returns `false` immediately, so the running time may depend on the
/// lengths. For equal lengths there is no early exit: every byte is processed, and every
/// intermediate value is passed through `optimizer_hide`.
#[must_use]
#[inline(always)]
pub(crate) fn constant_time_eq_impl(mut a: &[u8], mut b: &[u8], mut tmp: Word) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // No-op, because the lengths are equal here. NOTE(review): presumably kept so the
    // optimizer can elide bounds checks on `b` — confirm.
    b = &b[..a.len()];
    // Nothing to compare: the result is decided by the seed alone.
    if a.is_empty() {
        return tmp == 0;
    }
    // Reads one `T` from the front of each slice and advances both slices past it. Returns
    // the XOR of the two values, which is zero iff the chunks are bitwise equal.
    // Panics (via slicing) if either slice is shorter than `size_of::<T>()`; every caller
    // below checks the remaining length first.
    // Safety: same contract as `read_unaligned_from_slice`, i.e. every bit pattern must be
    // a valid `T`. The calls below only use primitive unsigned integers.
    #[must_use]
    #[inline(always)]
    unsafe fn cmp_step<T: BitXor<Output = T>>(a: &mut &[u8], b: &mut &[u8]) -> T {
        let tmpa = unsafe { read_unaligned_from_slice::<T>(&a[..size_of::<T>()]) };
        let tmpb = unsafe { read_unaligned_from_slice::<T>(&b[..size_of::<T>()]) };
        *a = &a[size_of::<T>()..];
        *b = &b[size_of::<T>()..];
        tmpa ^ tmpb
    }
    // Bulk of the input: whole `Word`-sized chunks.
    while a.len() >= size_of::<Word>() {
        let cmp = optimizer_hide(unsafe { cmp_step::<Word>(&mut a, &mut b) });
        tmp = optimizer_hide(tmp | cmp);
    }
    // Tail: fewer than `size_of::<Word>()` bytes remain at this point, so:
    // - only the steps below that are narrower than `Word` can execute, and their
    //   `as Word` casts are lossless;
    // - the `u128` loop cannot run whenever `Word` is at most 16 bytes wide;
    // - taking 8, 4, 2, then 1 byte(s), at most once each, covers any remainder
    //   shorter than 16 bytes.
    while a.len() >= size_of::<u128>() {
        let cmp = optimizer_hide(unsafe { cmp_step::<u128>(&mut a, &mut b) } as Word);
        tmp = optimizer_hide(tmp | cmp);
    }
    if a.len() >= size_of::<u64>() {
        let cmp = optimizer_hide(unsafe { cmp_step::<u64>(&mut a, &mut b) } as Word);
        tmp = optimizer_hide(tmp | cmp);
    }
    if a.len() >= size_of::<u32>() {
        let cmp = optimizer_hide(unsafe { cmp_step::<u32>(&mut a, &mut b) } as Word);
        tmp = optimizer_hide(tmp | cmp);
    }
    if a.len() >= size_of::<u16>() {
        let cmp = optimizer_hide(unsafe { cmp_step::<u16>(&mut a, &mut b) } as Word);
        tmp = optimizer_hide(tmp | cmp);
    }
    if a.len() >= size_of::<u8>() {
        let cmp = optimizer_hide(unsafe { cmp_step::<u8>(&mut a, &mut b) } as Word);
        tmp = optimizer_hide(tmp | cmp);
    }
    // Every byte has been folded into `tmp`; zero means no difference was seen.
    tmp == 0
}
#[must_use]
pub fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
with_dit(|| constant_time_eq_impl(a, b, 0))
}
/// Compares two byte arrays of the same length `N` for equality without branching on their
/// contents.
///
/// Both arguments have length `N`, so the length check in the shared implementation always
/// passes and all `N` bytes are examined. Two empty arrays (`N == 0`) compare equal. Like
/// `constant_time_eq`, the comparison runs inside `with_dit`.
#[must_use]
pub fn constant_time_eq_n<const N: usize>(a: &[u8; N], b: &[u8; N]) -> bool {
    let lhs: &[u8] = a;
    let rhs: &[u8] = b;
    with_dit(|| constant_time_eq_impl(lhs, rhs, 0))
}
// Unit tests. The only test is compiled solely when the `count_instructions_test`
// feature is enabled.
#[cfg(test)]
mod tests {
    // Brings `std` into scope for the feature-gated test. NOTE(review): presumably the
    // crate is otherwise `no_std`, since this module imports only from `core` — confirm.
    #[cfg(feature = "count_instructions_test")]
    extern crate std;
    /// Checks that `optimizer_hide` is not reduced to a plain identity. Summing four hidden
    /// constants must execute more instructions than summing the same constants passed
    /// through an `#[inline(always)]` identity function.
    ///
    /// NOTE(review): this relies on the external `count_instructions` crate. Judging by the
    /// usage here, it runs the closure, invokes the callback once per executed instruction,
    /// and returns the closure's result wrapped in `io::Result` — confirm against its docs.
    #[cfg(feature = "count_instructions_test")]
    #[test]
    fn count_optimizer_hide_instructions() -> std::io::Result<()> {
        use super::{Word, optimizer_hide};
        use count_instructions::count_instructions;
        // Instruction count for 1 + 2 + 3 + 4, with every operand hidden from the optimizer.
        fn count() -> std::io::Result<usize> {
            let mut count = 0;
            assert_eq!(
                10 as Word,
                count_instructions(
                    || optimizer_hide(1)
                        + optimizer_hide(2)
                        + optimizer_hide(3)
                        + optimizer_hide(4),
                    |_| count += 1
                )?
            );
            Ok(count)
        }
        // Baseline: the same sum through a transparent identity, which the compiler can
        // presumably fold to a constant.
        fn count_optimized() -> std::io::Result<usize> {
            #[inline(always)]
            fn inline_identity(value: Word) -> Word {
                value
            }
            let mut count = 0;
            assert_eq!(
                10 as Word,
                count_instructions(
                    || inline_identity(1)
                        + inline_identity(2)
                        + inline_identity(3)
                        + inline_identity(4),
                    |_| count += 1
                )?
            );
            Ok(count)
        }
        // The hidden version must not have been folded down to the baseline.
        assert!(count()? > count_optimized()?);
        Ok(())
    }
}