use thiserror::Error;
/// Error type declared for this module's cache-timing helpers.
///
/// NOTE(review): none of the functions visible in this file construct or return these
/// variants; confirm whether code elsewhere in the crate uses them.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum CacheTimingError {
/// An index fell outside a table holding `size` entries.
#[error("Index {index} out of bounds for table of size {size}")]
IndexOutOfBounds { index: usize, size: usize },
/// A table size was rejected; the `String` is included verbatim in the message
/// (presumably a human-readable reason — confirm with callers).
#[error("Invalid table size: {0}")]
InvalidTableSize(String),
}
/// Convenience alias for results whose error type is [`CacheTimingError`].
pub type CacheTimingResult<T> = Result<T, CacheTimingError>;
/// A lookup table whose [`get`](Self::get) visits every entry regardless of the
/// requested index instead of indexing the backing vector directly.
pub struct ConstantTimeLookup<T> {
    table: Vec<T>,
}

impl<T: Clone + Default> ConstantTimeLookup<T> {
    /// Builds a table holding a clone of every element of `data`, in order.
    pub fn new(data: &[T]) -> Self {
        let table = Vec::from(data);
        Self { table }
    }

    /// Returns a clone of the entry at `index`, or `T::default()` when `index` is out
    /// of bounds.
    ///
    /// Every entry is passed through the selection step, so the number of iterations
    /// depends only on the table length, not on `index`.
    ///
    /// NOTE(review): `conditional_select` chooses between its arguments with a branch,
    /// and `T::clone` may take value-dependent time (e.g. for heap-backed types), so
    /// this is only as timing-safe as those operations are for the concrete `T`.
    pub fn get(&self, index: usize) -> T {
        self.table
            .iter()
            .enumerate()
            .fold(T::default(), |selected, (position, candidate)| {
                // 1 exactly when this is the requested slot, 0 otherwise.
                let is_target = constant_time_eq_usize(position, index);
                conditional_select(&selected, candidate, is_target)
            })
    }

    /// Number of entries in the table.
    pub fn len(&self) -> usize {
        self.table.len()
    }

    /// `true` when the table holds no entries.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Returns `1` if `a == b` and `0` otherwise, using bitwise arithmetic instead of a
/// comparison operator.
///
/// The shift amount is derived from `usize::BITS`, so this works for every pointer
/// width. A hard-coded `>> 32` overflows on targets where `usize` is 32 or 16 bits wide.
#[inline]
fn constant_time_eq_usize(a: usize, b: usize) -> usize {
    let diff = a ^ b;
    // For any nonzero `x`, at least one of `x` and `x.wrapping_neg()` has its top bit
    // set; for `x == 0` both are zero. The top bit of `x | -x` is therefore `x != 0`.
    let nonzero = (diff | diff.wrapping_neg()) >> (usize::BITS - 1);
    nonzero ^ 1
}
/// Returns a clone of `true_val` when `condition` is nonzero, otherwise a clone of
/// `false_val`.
///
/// NOTE(review): the choice is made with a `match` on `condition`, which is a
/// data-dependent branch, and `T::clone` may itself take value-dependent time; this is
/// not a constant-time select when `condition` is secret.
#[inline]
fn conditional_select<T: Clone>(false_val: &T, true_val: &T, condition: usize) -> T {
    let chosen = match condition {
        0 => false_val,
        _ => true_val,
    };
    chosen.clone()
}
/// A byte table whose [`get`](Self::get) reads every entry and combines them with bit
/// masks instead of indexing the backing vector directly.
pub struct ByteLookup {
    table: Vec<u8>,
}

impl ByteLookup {
    /// Copies `data` into a new table.
    pub fn new(data: &[u8]) -> Self {
        let table = Vec::from(data);
        Self { table }
    }

    /// Returns the byte at `index`, or `0` when `index` is out of bounds.
    ///
    /// Every byte is read and ANDed with a mask that is `0xFF` at the requested position
    /// and `0x00` elsewhere; the masked bytes are ORed together.
    pub fn get(&self, index: usize) -> u8 {
        self.table
            .iter()
            .enumerate()
            .fold(0u8, |acc, (position, &byte)| {
                // Equality flag 1 -> 0xFF, 0 -> 0x00.
                let keep = 0u8.wrapping_sub(constant_time_eq_usize(position, index) as u8);
                acc | (byte & keep)
            })
    }

    /// Number of bytes in the table.
    pub fn len(&self) -> usize {
        self.table.len()
    }

    /// `true` when the table holds no bytes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
/// Returns `true` when `a` and `b` hold identical bytes.
///
/// Slice lengths are treated as public: unequal lengths return `false` immediately.
/// For equal lengths, every byte pair is XORed and the differences are ORed together,
/// so no comparison stops at the first mismatching byte.
pub fn constant_time_memcmp(a: &[u8], b: &[u8]) -> bool {
    if a.len() == b.len() {
        let accumulated = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));
        accumulated == 0
    } else {
        false
    }
}
/// Swaps `*a` and `*b` when `condition` is `true`; leaves both unchanged otherwise.
///
/// Both values are cloned and both destinations are overwritten on every call, so the
/// amount of work does not depend on `condition`. Skipping all work when `condition` is
/// `false` would make the condition observable through timing. The ordering is chosen
/// with `slice::swap` on a two-element array rather than with an `if`.
///
/// NOTE(review): this removes the source-level branch, but neither the optimizer nor
/// `T::clone` is guaranteed to run in constant time. For secret-dependent swaps of
/// primitive data, a bitmask-based swap is stronger.
pub fn conditional_swap<T: Clone>(a: &mut T, b: &mut T, condition: bool) {
    let mut pair = [a.clone(), b.clone()];
    // Index 1 when swapping, 0 otherwise. `swap(0, 0)` leaves the array unchanged but
    // goes through the same code path as `swap(0, 1)`.
    pair.swap(0, usize::from(condition));
    let [first, second] = pair;
    *a = first;
    *b = second;
}
/// Touches the memory at `addr` with one volatile byte read, then issues a compiler
/// fence.
///
/// This is an ordinary load, not a CPU prefetch instruction; it presumably exists to
/// pull the containing cache line in ahead of a later access.
///
/// Only a single byte is read, as `MaybeUninit<u8>`. Reading a whole `T` by value
/// (`read_volatile::<T>`) would create an owned bitwise duplicate that is dropped
/// immediately, running `T`'s destructor on memory the caller still owns. For
/// `String`, `Vec` or `Box` that is a double free. `MaybeUninit<u8>` also tolerates
/// padding or uninitialized bytes and has no alignment requirement beyond 1.
/// Zero-sized types occupy no memory, so nothing is read for them.
///
/// # Safety
///
/// If `T` is not zero-sized, `addr` must be non-null and valid for reads of at least
/// one byte (for example, it points into a live `T`).
#[inline]
pub unsafe fn prefetch_read<T>(addr: *const T) {
    if std::mem::size_of::<T>() != 0 {
        // SAFETY: the caller guarantees `addr` is non-null and valid for a one-byte
        // read. `MaybeUninit<u8>` accepts any bit pattern (including uninit) and has
        // no destructor, so nothing owned by the caller is duplicated or dropped.
        unsafe {
            let _ = std::ptr::read_volatile(addr.cast::<std::mem::MaybeUninit<u8>>());
        }
    }
    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
}
/// Calls [`prefetch_read`] on every pointer in `addrs`, in order.
///
/// # Safety
///
/// Every pointer in `addrs` must satisfy the safety requirements of
/// [`prefetch_read`] for the pointee type `T`.
pub unsafe fn prefetch_array<T>(addrs: &[*const T]) {
    addrs.iter().copied().for_each(|ptr| {
        // SAFETY: forwarded from this function's contract — the caller guarantees
        // each pointer is valid for `prefetch_read`.
        unsafe { prefetch_read(ptr) }
    });
}
/// Wrapper that places its contents on a 64-byte boundary.
///
/// `#[repr(align(64))]` makes the wrapper's alignment 64. Rust sizes are always a
/// multiple of alignment, so adjacent `CacheAligned` values (e.g. in an array) never
/// share a 64-byte-aligned block. NOTE(review): 64 bytes is assumed to match the
/// target's cache-line size — confirm for the platforms this runs on.
#[repr(align(64))]
#[derive(Clone, Debug, Default)]
pub struct CacheAligned<T> {
    data: T,
}

impl<T> CacheAligned<T> {
    /// Wraps `data`.
    pub fn new(data: T) -> Self {
        Self { data }
    }

    /// Borrows the wrapped value.
    pub fn get(&self) -> &T {
        &self.data
    }

    /// Mutably borrows the wrapped value.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.data
    }

    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> T {
        self.data
    }
}
/// Returns `index` when `index <= max_index`, otherwise `max_index`.
///
/// The comparison result is widened into an all-ones / all-zeros mask, and the two
/// candidates are combined with bitwise operations rather than an `if` or `min`.
///
/// NOTE(review): Rust does not guarantee the optimizer keeps this branch-free.
pub fn constant_time_clamp_index(index: usize, max_index: usize) -> usize {
    // 1 when `index` exceeds the bound, 0 otherwise.
    let overflow = usize::from(index > max_index);
    // All ones when clamping is needed, all zeros otherwise.
    let mask = overflow.wrapping_neg();
    // No trailing `min`: the masked select already yields a value <= `max_index`, and a
    // `min` would reintroduce the compare-and-select this function exists to avoid.
    (max_index & mask) | (index & !mask)
}
#[cfg(test)]
mod tests {
    // Unit tests for the lookup, comparison, swap, alignment and clamping helpers.
    // Fixed-size arrays are used instead of `vec![]` where only a borrow is needed
    // (clippy::useless_vec).
    use super::*;

    #[test]
    fn test_constant_time_lookup() {
        let table = [10u8, 20, 30, 40, 50];
        let lookup = ConstantTimeLookup::new(&table);
        assert_eq!(lookup.get(0), 10);
        assert_eq!(lookup.get(2), 30);
        assert_eq!(lookup.get(4), 50);
        assert_eq!(lookup.len(), 5);
    }

    #[test]
    fn test_constant_time_lookup_out_of_bounds() {
        let table = [10u8, 20, 30];
        let lookup = ConstantTimeLookup::new(&table);
        // Out-of-range indices yield `T::default()` rather than panicking.
        assert_eq!(lookup.get(10), 0);
        assert_eq!(lookup.get(usize::MAX), 0);
    }

    #[test]
    fn test_constant_time_lookup_empty() {
        let lookup = ConstantTimeLookup::<u32>::new(&[]);
        assert!(lookup.is_empty());
        assert_eq!(lookup.len(), 0);
        assert_eq!(lookup.get(0), 0);
    }

    #[test]
    fn test_byte_lookup() {
        let table = [0xFFu8, 0xAA, 0x55, 0x00];
        let lookup = ByteLookup::new(&table);
        assert_eq!(lookup.get(0), 0xFF);
        assert_eq!(lookup.get(1), 0xAA);
        assert_eq!(lookup.get(2), 0x55);
        assert_eq!(lookup.get(3), 0x00);
        // Out-of-range indices select nothing and yield 0.
        assert_eq!(lookup.get(4), 0);
        assert_eq!(lookup.get(usize::MAX), 0);
    }

    #[test]
    fn test_constant_time_memcmp() {
        let a = [1u8, 2, 3, 4, 5];
        let b = [1u8, 2, 3, 4, 5];
        let c = [1u8, 2, 3, 4, 6];
        assert!(constant_time_memcmp(&a, &b));
        assert!(!constant_time_memcmp(&a, &c));
        assert!(constant_time_memcmp(&[], &[]));
    }

    #[test]
    fn test_constant_time_memcmp_different_lengths() {
        let a = [1u8, 2, 3];
        let b = [1u8, 2];
        assert!(!constant_time_memcmp(&a, &b));
    }

    #[test]
    fn test_conditional_swap() {
        let mut a = 10u32;
        let mut b = 20u32;
        conditional_swap(&mut a, &mut b, true);
        assert_eq!(a, 20);
        assert_eq!(b, 10);
        conditional_swap(&mut a, &mut b, false);
        assert_eq!(a, 20);
        assert_eq!(b, 10);
    }

    #[test]
    fn test_conditional_swap_non_copy() {
        let mut a = String::from("left");
        let mut b = String::from("right");
        conditional_swap(&mut a, &mut b, true);
        assert_eq!(a, "right");
        assert_eq!(b, "left");
        conditional_swap(&mut a, &mut b, false);
        assert_eq!(a, "right");
        assert_eq!(b, "left");
    }

    #[test]
    fn test_cache_aligned() {
        let aligned = CacheAligned::new(42u64);
        assert_eq!(*aligned.get(), 42);
        let mut aligned_mut = CacheAligned::new(100u32);
        *aligned_mut.get_mut() = 200;
        assert_eq!(*aligned_mut.get(), 200);
        assert_eq!(aligned_mut.into_inner(), 200);
        assert_eq!(std::mem::align_of::<CacheAligned<u8>>(), 64);
    }

    #[test]
    fn test_constant_time_eq_usize() {
        assert_eq!(constant_time_eq_usize(5, 5), 1);
        assert_eq!(constant_time_eq_usize(5, 6), 0);
        assert_eq!(constant_time_eq_usize(0, 0), 1);
        assert_eq!(constant_time_eq_usize(usize::MAX, usize::MAX), 1);
        assert_eq!(constant_time_eq_usize(0, usize::MAX), 0);
        // A difference in any single bit, including the highest, must be detected.
        for shift in 0..usize::BITS {
            assert_eq!(constant_time_eq_usize(0, 1usize << shift), 0);
        }
    }

    #[test]
    fn test_constant_time_clamp_index() {
        assert_eq!(constant_time_clamp_index(3, 10), 3);
        assert_eq!(constant_time_clamp_index(15, 10), 10);
        assert_eq!(constant_time_clamp_index(0, 10), 0);
        assert_eq!(constant_time_clamp_index(10, 10), 10);
        assert_eq!(constant_time_clamp_index(usize::MAX, 10), 10);
        assert_eq!(constant_time_clamp_index(usize::MAX, usize::MAX), usize::MAX);
    }

    #[test]
    fn test_prefetch_operations() {
        let data = [1u8, 2, 3, 4, 5];
        unsafe {
            prefetch_read(data.as_ptr());
            let ptrs = [data.as_ptr(), data[1..].as_ptr()];
            prefetch_array(&ptrs);
        }
    }

    #[test]
    fn test_byte_lookup_empty() {
        let lookup = ByteLookup::new(&[]);
        assert!(lookup.is_empty());
        assert_eq!(lookup.len(), 0);
        assert_eq!(lookup.get(0), 0);
    }

    #[test]
    fn test_constant_time_lookup_string() {
        let table = ["hello".to_string(), "world".to_string(), "test".to_string()];
        let lookup = ConstantTimeLookup::new(&table);
        assert_eq!(lookup.get(0), "hello");
        assert_eq!(lookup.get(1), "world");
        assert_eq!(lookup.get(2), "test");
        // Out of range falls back to `String::default()`.
        assert_eq!(lookup.get(3), "");
    }
}