1use thiserror::Error;
28
/// Error type for the cache-timing utilities in this module.
///
/// NOTE(review): none of the functions visible in this file construct these
/// variants (the lookups return a default value instead of an error when out of
/// bounds); presumably they are used by callers elsewhere — confirm.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum CacheTimingError {
    /// An `index` was not valid for a table holding `size` elements.
    #[error("Index {index} out of bounds for table of size {size}")]
    IndexOutOfBounds { index: usize, size: usize },

    /// A table size was rejected; the `String` is a free-form message that is
    /// interpolated into the error text.
    #[error("Invalid table size: {0}")]
    InvalidTableSize(String),
}

/// Shorthand for a `Result` whose error type is [`CacheTimingError`].
pub type CacheTimingResult<T> = Result<T, CacheTimingError>;
43
44pub struct ConstantTimeLookup<T> {
49 table: Vec<T>,
50}
51
52impl<T: Clone + Default> ConstantTimeLookup<T> {
53 pub fn new(data: &[T]) -> Self {
55 Self {
56 table: data.to_vec(),
57 }
58 }
59
60 pub fn get(&self, index: usize) -> T {
73 let mut result = T::default();
74
75 for (i, item) in self.table.iter().enumerate() {
77 let mask = constant_time_eq_usize(i, index);
79 result = conditional_select(&result, item, mask);
80 }
81
82 result
83 }
84
85 pub fn len(&self) -> usize {
87 self.table.len()
88 }
89
90 pub fn is_empty(&self) -> bool {
92 self.table.is_empty()
93 }
94}
95
/// Returns `1` if `a == b` and `0` otherwise, computed with bit arithmetic rather than a branch.
///
/// For any non-zero `x`, `x | x.wrapping_neg()` has its most significant bit set, and for
/// `x == 0` it is zero. Shifting that bit down gives `1` exactly when the inputs differ,
/// and XOR with `1` turns it into an equality flag.
///
/// The shift amount is derived from `usize::BITS`, so this is correct on every target
/// width. A hard-coded `>> 32` fold overflows on 32-bit targets and panics in debug builds.
#[inline]
fn constant_time_eq_usize(a: usize, b: usize) -> usize {
    let diff = a ^ b;
    let differs = (diff | diff.wrapping_neg()) >> (usize::BITS - 1);
    differs ^ 1
}
117
/// Returns a clone of `true_val` when `condition` is non-zero, otherwise a clone of `false_val`.
///
/// The source is chosen by indexing a two-element array with the boolean flag instead of
/// an `if`/`else`, so the code does not branch on `condition`. Only a reference is
/// selected, and the same `clone` call runs either way.
///
/// This is best-effort. The compiler is free to reintroduce a branch, and the cost of
/// `clone` itself may depend on the selected value.
#[inline]
fn conditional_select<T: Clone>(false_val: &T, true_val: &T, condition: usize) -> T {
    let candidates = [false_val, true_val];
    candidates[usize::from(condition != 0)].clone()
}
130
131pub struct ByteLookup {
135 table: Vec<u8>,
136}
137
138impl ByteLookup {
139 pub fn new(data: &[u8]) -> Self {
141 Self {
142 table: data.to_vec(),
143 }
144 }
145
146 pub fn get(&self, index: usize) -> u8 {
148 let mut result = 0u8;
149
150 for (i, &byte) in self.table.iter().enumerate() {
151 let mask = constant_time_eq_usize(i, index);
152 let byte_mask = (mask as u8).wrapping_neg();
154 result |= byte & byte_mask;
155 }
156
157 result
158 }
159
160 pub fn len(&self) -> usize {
162 self.table.len()
163 }
164
165 pub fn is_empty(&self) -> bool {
167 self.table.is_empty()
168 }
169}
170
/// Compares two byte slices for equality without stopping at the first mismatch.
///
/// Returns `false` immediately when the lengths differ, so the length is treated as
/// public information. For equal-length inputs every byte pair is XORed and ORed into
/// an accumulator. The amount of work therefore depends only on the length, not on
/// where (or whether) the slices differ.
pub fn constant_time_memcmp(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    let diff = a.iter().zip(b).fold(0u8, |acc, (x, y)| acc | (x ^ y));

    // Hide the accumulator from the optimizer so it cannot fuse the final `== 0` test
    // into the loop as an early exit. `black_box` is documented as best-effort only.
    std::hint::black_box(diff) == 0
}
187
/// Exchanges `*a` and `*b` when `condition` is `true`; leaves both untouched otherwise.
///
/// The exchange uses [`std::mem::swap`], which moves the values in place without
/// cloning or allocating. `T` therefore no longer needs to implement `Clone`.
///
/// Note: this branches on `condition`, so it is not constant-time with respect to it.
pub fn conditional_swap<T>(a: &mut T, b: &mut T, condition: bool) {
    if condition {
        std::mem::swap(a, b);
    }
}
199
/// Touches the memory occupied by `*addr` so that it is likely to be cache-resident afterwards.
///
/// One byte is read with a volatile load from every 64-byte stride of the value, plus its
/// last byte. Assuming 64-byte cache lines, this accesses every line the value overlaps.
///
/// The reads go through `MaybeUninit<u8>`, so no value of type `T` is ever materialised
/// and `T`'s destructor never runs. Reading a whole `T` by value and discarding it would
/// drop a bitwise duplicate, double-freeing owned resources such as a `String`'s buffer.
///
/// A trailing compiler fence stops the compiler from reordering surrounding memory
/// accesses across this call. It does not emit a CPU barrier.
///
/// # Safety
///
/// - `addr` must be non-null and valid for reads of `size_of::<T>()` bytes for the
///   duration of the call.
/// - The memory must not be written concurrently by another thread.
/// - `addr` need not be aligned, and the bytes may be uninitialised (e.g. padding).
/// - For zero-sized `T` no memory is read.
#[inline]
pub unsafe fn prefetch_read<T>(addr: *const T) {
    // Assumed cache-line size; it is not detected at runtime.
    const STRIDE: usize = 64;

    let size = std::mem::size_of::<T>();
    let bytes = addr.cast::<std::mem::MaybeUninit<u8>>();

    // `checked_sub` yields no extra offset for zero-sized `T`, so nothing is read then.
    for offset in (0..size).step_by(STRIDE).chain(size.checked_sub(1)) {
        // SAFETY: `offset < size`, and the caller guarantees `addr` is valid for reads of
        // `size` bytes, so `bytes.add(offset)` stays within that range.
        // `MaybeUninit<u8>` has alignment 1, may hold uninitialised bytes, and has no
        // destructor, so reading and discarding it is sound.
        let _ = unsafe { std::ptr::read_volatile(bytes.add(offset)) };
    }

    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
}
224
225pub unsafe fn prefetch_array<T>(addrs: &[*const T]) {
234 for &addr in addrs {
235 unsafe {
237 prefetch_read(addr);
238 }
239 }
240}
241
/// Wrapper that places its contents on a 64-byte boundary.
///
/// `#[repr(align(64))]` raises the wrapper's alignment to 64, and its size is rounded up
/// to a multiple of 64. Two `CacheAligned` values stored next to each other therefore
/// never share a 64-byte block. 64 bytes is a common cache-line size; the actual line
/// size is not detected.
#[repr(align(64))]
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct CacheAligned<T> {
    data: T,
}

impl<T> CacheAligned<T> {
    /// Wraps `data`.
    pub fn new(data: T) -> Self {
        Self { data }
    }

    /// Borrows the wrapped value.
    pub fn get(&self) -> &T {
        &self.data
    }

    /// Mutably borrows the wrapped value.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.data
    }

    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> T {
        self.data
    }
}

impl<T> From<T> for CacheAligned<T> {
    /// Equivalent to [`CacheAligned::new`].
    fn from(data: T) -> Self {
        Self::new(data)
    }
}
273
/// Clamps `index` to at most `max_index` using a mask blend instead of a branch.
///
/// Returns `index` when `index <= max_index`, otherwise `max_index`. The comparison
/// result is widened to an all-ones (out of range) or all-zeros (in range) mask, which
/// chooses bitwise between the two candidates.
///
/// The result is already bounded by `max_index`, so no trailing `min` is needed.
/// A `min` would add a second comparison that the compiler may lower to a branch.
pub fn constant_time_clamp_index(index: usize, max_index: usize) -> usize {
    let mask = usize::from(index > max_index).wrapping_neg();
    (index & !mask) | (max_index & mask)
}
284
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_time_lookup() {
        let lookup = ConstantTimeLookup::new(&[10u8, 20, 30, 40, 50]);

        for (index, expected) in [(0, 10), (2, 30), (4, 50)] {
            assert_eq!(lookup.get(index), expected);
        }
        assert_eq!(lookup.len(), 5);
    }

    #[test]
    fn test_constant_time_lookup_out_of_bounds() {
        // Indices past the end yield the type's default value.
        assert_eq!(ConstantTimeLookup::new(&[10u8, 20, 30]).get(10), 0);
    }

    #[test]
    fn test_byte_lookup() {
        let bytes = [0xFFu8, 0xAA, 0x55, 0x00];
        let lookup = ByteLookup::new(&bytes);

        for (index, &expected) in bytes.iter().enumerate() {
            assert_eq!(lookup.get(index), expected);
        }
    }

    #[test]
    fn test_constant_time_memcmp() {
        let reference = [1u8, 2, 3, 4, 5];

        assert!(constant_time_memcmp(&reference, &[1, 2, 3, 4, 5]));
        assert!(!constant_time_memcmp(&reference, &[1, 2, 3, 4, 6]));
    }

    #[test]
    fn test_constant_time_memcmp_different_lengths() {
        assert!(!constant_time_memcmp(&[1u8, 2, 3], &[1u8, 2]));
    }

    #[test]
    fn test_conditional_swap() {
        let (mut left, mut right) = (10u32, 20u32);

        conditional_swap(&mut left, &mut right, true);
        assert_eq!((left, right), (20, 10));

        conditional_swap(&mut left, &mut right, false);
        assert_eq!((left, right), (20, 10));
    }

    #[test]
    fn test_cache_aligned() {
        let aligned = CacheAligned::new(42u64);
        assert_eq!(aligned.get(), &42);

        let mut wrapped = CacheAligned::new(100u32);
        *wrapped.get_mut() = 200;
        assert_eq!(wrapped.get(), &200);
        assert_eq!(wrapped.into_inner(), 200);
    }

    #[test]
    fn test_constant_time_eq_usize() {
        for (a, b, expected) in [(5, 5, 1), (5, 6, 0), (0, 0, 1)] {
            assert_eq!(constant_time_eq_usize(a, b), expected);
        }
    }

    #[test]
    fn test_constant_time_clamp_index() {
        let cases = [(3, 10, 3), (15, 10, 10), (0, 10, 0), (10, 10, 10)];
        for (index, max_index, expected) in cases {
            assert_eq!(constant_time_clamp_index(index, max_index), expected);
        }
    }

    #[test]
    fn test_prefetch_operations() {
        let data = [1u8, 2, 3, 4, 5];
        let ptrs = [data.as_ptr(), data[1..].as_ptr()];

        // Smoke test: prefetching valid pointers must not crash.
        unsafe {
            prefetch_read(data.as_ptr());
            prefetch_array(&ptrs);
        }
    }

    #[test]
    fn test_byte_lookup_empty() {
        let lookup = ByteLookup::new(&[]);
        assert!(lookup.is_empty());
        assert_eq!(lookup.len(), 0);
    }

    #[test]
    fn test_constant_time_lookup_string() {
        let expected = ["hello", "world", "test"];
        let lookup = ConstantTimeLookup::new(&expected.map(String::from));

        for (index, word) in expected.iter().enumerate() {
            assert_eq!(lookup.get(index), *word);
        }
    }
}