chie_crypto/
cache_timing.rs

//! Cache-timing attack mitigations.
//!
//! This module provides utilities to mitigate cache-timing side-channel attacks
//! in cryptographic implementations.
//!
//! # Features
//!
//! - **Constant-time table lookups**: Table lookups that don't leak information via cache
//! - **Cache-oblivious algorithms**: Data access patterns independent of cache parameters
//! - **Prefetching strategies**: Reduce timing variation by prefetching data
//! - **Constant-time selection**: Select values without branching or variable memory access
//!
//! # Example
//!
//! ```rust
//! use chie_crypto::cache_timing::ConstantTimeLookup;
//!
//! // Create a lookup table
//! let table = [0u8, 1, 2, 3, 4, 5, 6, 7];
//! let lookup = ConstantTimeLookup::new(&table);
//!
//! // Constant-time lookup (doesn't leak index via cache)
//! let value = lookup.get(3);
//! assert_eq!(value, 3);
//! ```

use thiserror::Error;

/// Cache-timing mitigation errors
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum CacheTimingError {
    /// Index out of bounds
    #[error("Index {index} out of bounds for table of size {size}")]
    IndexOutOfBounds { index: usize, size: usize },

    /// Invalid table size
    #[error("Invalid table size: {0}")]
    InvalidTableSize(String),
}

/// Result type for cache-timing operations
pub type CacheTimingResult<T> = Result<T, CacheTimingError>;

/// Constant-time table lookup
///
/// Performs table lookups in constant time by accessing all elements,
/// preventing cache-timing attacks that could reveal the lookup index.
pub struct ConstantTimeLookup<T> {
    table: Vec<T>,
}

impl<T: Clone + Default> ConstantTimeLookup<T> {
    /// Create a new constant-time lookup table
    pub fn new(data: &[T]) -> Self {
        Self {
            table: data.to_vec(),
        }
    }

    /// Perform a constant-time lookup
    ///
    /// This function accesses all table elements so that the memory access
    /// pattern does not depend on `index`. Time complexity: O(n) where n is
    /// the table size.
    ///
    /// # Arguments
    ///
    /// * `index` - Index to look up (indices >= table size yield the default value)
    ///
    /// # Returns
    ///
    /// The value at the given index, or `T::default()` if the index is out of bounds.
    ///
    /// Note: for a generic `T` the per-element selection goes through
    /// `conditional_select`, which branches on the comparison mask; use
    /// `ByteLookup` when strictly branch-free selection of bytes is required.
    pub fn get(&self, index: usize) -> T {
        let mut result = T::default();

        // Visit every element; the loop shape is identical for every index.
        for (i, item) in self.table.iter().enumerate() {
            // mask is 1 when i == index, 0 otherwise
            let mask = constant_time_eq_usize(i, index);
            result = conditional_select(&result, item, mask);
        }

        result
    }

    /// Get table size
    pub fn len(&self) -> usize {
        self.table.len()
    }

    /// Check if table is empty
    pub fn is_empty(&self) -> bool {
        self.table.is_empty()
    }
}

/// Constant-time equality check for usize
///
/// Returns 1 if the values are equal and 0 otherwise; callers widen this into
/// a full byte or word mask as needed. Uses bitwise operations to avoid branching.
#[inline]
fn constant_time_eq_usize(a: usize, b: usize) -> usize {
    // XOR gives 0 iff the inputs are equal
    let diff = a ^ b;

    // `diff | diff.wrapping_neg()` has its top bit set iff diff != 0, so the
    // shift yields 1 for "not equal" and 0 for "equal". Using the bit width of
    // usize keeps this portable across 32- and 64-bit targets.
    let nonzero = (diff | diff.wrapping_neg()) >> (usize::BITS - 1);

    // Invert: 1 for "equal", 0 for "not equal"
    nonzero ^ 1
}

/// Conditional select between two values
///
/// Returns `true_val` if `condition` is non-zero, otherwise `false_val`.
///
/// For a generic `Clone` type the values cannot be combined with bit masks, so
/// this helper branches on the already-computed comparison mask; it is a
/// best-effort mitigation rather than a strictly branch-free primitive. For
/// byte tables, `ByteLookup` performs fully masked selection.
#[inline]
fn conditional_select<T: Clone>(false_val: &T, true_val: &T, condition: usize) -> T {
    if condition != 0 {
        true_val.clone()
    } else {
        false_val.clone()
    }
}
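
// For primitive integer types a genuinely branch-free selection is possible by
// widening the 0/1 comparison result into an all-zeros / all-ones mask. The
// helper below is an illustrative sketch of that technique for `u64` (it is
// not called by `ConstantTimeLookup`); `ByteLookup::get` applies the same idea
// to bytes.
#[allow(dead_code)]
#[inline]
fn conditional_select_u64(false_val: u64, true_val: u64, condition: u64) -> u64 {
    // condition 0 -> mask 0x0000...0000, condition 1 -> mask 0xFFFF...FFFF
    let mask = condition.wrapping_neg();
    (true_val & mask) | (false_val & !mask)
}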

/// Constant-time byte array lookup
///
/// Specialized version for byte arrays with better performance.
pub struct ByteLookup {
    table: Vec<u8>,
}

impl ByteLookup {
    /// Create a new byte lookup table
    pub fn new(data: &[u8]) -> Self {
        Self {
            table: data.to_vec(),
        }
    }

    /// Perform constant-time byte lookup
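    ///
    /// Every table entry is read and the requested one is masked in, so the
    /// memory access pattern is the same for every index, and out-of-bounds
    /// indices simply yield 0. For example:
    ///
    /// ```rust
    /// use chie_crypto::cache_timing::ByteLookup;
    ///
    /// // A small S-box-style table.
    /// let sbox = ByteLookup::new(&[0x63, 0x7c, 0x77, 0x7b]);
    /// assert_eq!(sbox.get(2), 0x77);
    /// // Out-of-bounds indices yield 0.
    /// assert_eq!(sbox.get(10), 0x00);
    /// ```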
    pub fn get(&self, index: usize) -> u8 {
        let mut result = 0u8;

        for (i, &byte) in self.table.iter().enumerate() {
            let mask = constant_time_eq_usize(i, index);
            // Expand mask to full byte (0x00 or 0xFF)
            let byte_mask = (mask as u8).wrapping_neg();
            result |= byte & byte_mask;
        }

        result
    }

    /// Get table size
    pub fn len(&self) -> usize {
        self.table.len()
    }

    /// Check if table is empty
    pub fn is_empty(&self) -> bool {
        self.table.is_empty()
    }
}

/// Constant-time memory comparison
///
/// Compares two slices in constant time, preventing timing attacks that could
/// reveal where the first difference occurs. If the lengths differ the function
/// returns `false` immediately; slice lengths are treated as public information.
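///
/// Typical use is comparing authentication tags or other secret-dependent byte strings:
///
/// ```rust
/// use chie_crypto::cache_timing::constant_time_memcmp;
///
/// let expected_tag = [0xde_u8, 0xad, 0xbe, 0xef];
/// let received_tag = [0xde_u8, 0xad, 0xbe, 0xef];
/// assert!(constant_time_memcmp(&expected_tag, &received_tag));
/// assert!(!constant_time_memcmp(&expected_tag, &[0u8; 4]));
/// ```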
pub fn constant_time_memcmp(a: &[u8], b: &[u8]) -> bool {
    if a.len() != b.len() {
        return false;
    }

    let mut diff = 0u8;
    for (x, y) in a.iter().zip(b.iter()) {
        diff |= x ^ y;
    }

    diff == 0
}

/// Conditional swap
///
/// Swaps `a` and `b` if `condition` is true, otherwise leaves them unchanged.
/// For a generic `Clone` type this necessarily branches on `condition`, so treat
/// it as a convenience helper rather than a strictly branch-free primitive.
pub fn conditional_swap<T: Clone>(a: &mut T, b: &mut T, condition: bool) {
    if condition {
        let temp = a.clone();
        *a = b.clone();
        *b = temp;
    }
}
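
// A branch-free swap is possible for primitive integers by XOR-masking, the
// pattern used in constant-time ladder implementations. Illustrative sketch
// only; the generic `conditional_swap` above keeps the simpler `Clone`-based form.
#[allow(dead_code)]
#[inline]
fn conditional_swap_u64(a: &mut u64, b: &mut u64, condition: bool) {
    // condition false -> mask 0x0000...0000, condition true -> mask 0xFFFF...FFFF
    let mask = (condition as u64).wrapping_neg();
    let t = mask & (*a ^ *b);
    *a ^= t;
    *b ^= t;
}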

/// Prefetch memory locations to reduce timing variation
///
/// This is a hint to the CPU to prefetch data into cache.
/// May help reduce timing variation in subsequent accesses.
///
/// # Safety
///
/// The caller must ensure that `addr` is a valid, aligned pointer
/// to initialized memory of type `T`.
///
/// Note: This is a best-effort hint. On stable Rust, this uses a volatile
/// read to trigger a cache line load. For true prefetch intrinsics, use nightly Rust.
#[inline]
pub unsafe fn prefetch_read<T>(addr: *const T) {
    // Volatile read to trigger a cache line load. This is not as efficient as
    // a true prefetch, but works on stable Rust.
    // SAFETY: The caller guarantees that addr is valid and aligned.
    unsafe {
        // Forget the bit-copy so `Drop` never runs on a duplicated value.
        std::mem::forget(std::ptr::read_volatile(addr));
    }

    // Compiler fence to prevent reordering
    std::sync::atomic::compiler_fence(std::sync::atomic::Ordering::SeqCst);
}

/// Prefetch multiple memory locations
///
/// Prefetches an array of pointers to reduce cache-miss timing variations.
///
/// # Safety
///
/// The caller must ensure that all pointers in `addrs` are valid, aligned
/// pointers to initialized memory of type `T`.
pub unsafe fn prefetch_array<T>(addrs: &[*const T]) {
    for &addr in addrs {
        // SAFETY: The caller guarantees that all pointers are valid and aligned
        unsafe {
            prefetch_read(addr);
        }
    }
}
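
// Example of one way these hints can be used: walking a byte table in
// cache-line-sized strides warms every line a subsequent lookup could touch,
// so a later secret-indexed access is less likely to reveal its index through
// a cache miss. Illustrative sketch only; the 64-byte stride is an assumption
// about the target's cache-line size.
#[allow(dead_code)]
fn warm_table(table: &[u8]) {
    for offset in (0..table.len()).step_by(64) {
        // SAFETY: offset < table.len(), so the pointer stays within the slice.
        unsafe { prefetch_read(table.as_ptr().add(offset)) };
    }
}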

/// Cache-line aligned buffer
///
/// Ensures data is aligned to cache line boundaries to reduce false sharing
/// and improve cache utilization.
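///
/// The alignment guarantee can be checked directly:
///
/// ```rust
/// use chie_crypto::cache_timing::CacheAligned;
///
/// let buf = CacheAligned::new([0u8; 32]);
/// // `repr(align(64))` places the wrapper on a cache-line boundary.
/// assert_eq!(&buf as *const _ as usize % 64, 0);
/// ```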
#[repr(align(64))] // Common cache line size
#[derive(Clone)]
pub struct CacheAligned<T> {
    data: T,
}

impl<T> CacheAligned<T> {
    /// Create a new cache-aligned value
    pub fn new(data: T) -> Self {
        Self { data }
    }

    /// Get a reference to the data
    pub fn get(&self) -> &T {
        &self.data
    }

    /// Get a mutable reference to the data
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.data
    }

    /// Consume and return the inner data
    pub fn into_inner(self) -> T {
        self.data
    }
}

/// Constant-time array index clamping
///
/// Clamps an index to the valid range without branching.
/// Returns the index if in bounds, otherwise returns the maximum valid index.
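///
/// For example:
///
/// ```rust
/// use chie_crypto::cache_timing::constant_time_clamp_index;
///
/// assert_eq!(constant_time_clamp_index(3, 7), 3);
/// assert_eq!(constant_time_clamp_index(12, 7), 7);
/// ```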
pub fn constant_time_clamp_index(index: usize, max_index: usize) -> usize {
    // Branchless clamp: `overflow` is 1 when the index is out of range, 0 otherwise.
    let overflow = (index > max_index) as usize;
    // Subtract (index - max_index) only when out of range, landing exactly on max_index.
    let clamped = index.wrapping_sub(overflow.wrapping_mul(index.wrapping_sub(max_index)));
    // Defensive cap; redundant given the arithmetic above, but keeps the result in range.
    clamped.min(max_index)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_time_lookup() {
        let table = [10u8, 20, 30, 40, 50];
        let lookup = ConstantTimeLookup::new(&table);

        assert_eq!(lookup.get(0), 10);
        assert_eq!(lookup.get(2), 30);
        assert_eq!(lookup.get(4), 50);
        assert_eq!(lookup.len(), 5);
    }

    #[test]
    fn test_constant_time_lookup_out_of_bounds() {
        let table = [10u8, 20, 30];
        let lookup = ConstantTimeLookup::new(&table);

        // Out of bounds returns default (0)
        assert_eq!(lookup.get(10), 0);
    }

    #[test]
    fn test_byte_lookup() {
        let table = vec![0xFF, 0xAA, 0x55, 0x00];
        let lookup = ByteLookup::new(&table);

        assert_eq!(lookup.get(0), 0xFF);
        assert_eq!(lookup.get(1), 0xAA);
        assert_eq!(lookup.get(2), 0x55);
        assert_eq!(lookup.get(3), 0x00);
    }

    #[test]
    fn test_constant_time_memcmp() {
        let a = [1u8, 2, 3, 4, 5];
        let b = [1u8, 2, 3, 4, 5];
        let c = [1u8, 2, 3, 4, 6];

        assert!(constant_time_memcmp(&a, &b));
        assert!(!constant_time_memcmp(&a, &c));
    }

    #[test]
    fn test_constant_time_memcmp_different_lengths() {
        let a = [1u8, 2, 3];
        let b = [1u8, 2];

        assert!(!constant_time_memcmp(&a, &b));
    }

    #[test]
    fn test_conditional_swap() {
        let mut a = 10u32;
        let mut b = 20u32;

        conditional_swap(&mut a, &mut b, true);
        assert_eq!(a, 20);
        assert_eq!(b, 10);

        conditional_swap(&mut a, &mut b, false);
        assert_eq!(a, 20);
        assert_eq!(b, 10);
    }

    #[test]
    fn test_cache_aligned() {
        let aligned = CacheAligned::new(42u64);
        assert_eq!(*aligned.get(), 42);

        let mut aligned_mut = CacheAligned::new(100u32);
        *aligned_mut.get_mut() = 200;
        assert_eq!(*aligned_mut.get(), 200);

        assert_eq!(aligned_mut.into_inner(), 200);
    }

    #[test]
    fn test_constant_time_eq_usize() {
        assert_eq!(constant_time_eq_usize(5, 5), 1);
        assert_eq!(constant_time_eq_usize(5, 6), 0);
        assert_eq!(constant_time_eq_usize(0, 0), 1);
    }

    #[test]
    fn test_constant_time_clamp_index() {
        assert_eq!(constant_time_clamp_index(3, 10), 3);
        assert_eq!(constant_time_clamp_index(15, 10), 10);
        assert_eq!(constant_time_clamp_index(0, 10), 0);
        assert_eq!(constant_time_clamp_index(10, 10), 10);
    }

    #[test]
    fn test_prefetch_operations() {
        let data = [1u8, 2, 3, 4, 5];

        // Just test that these don't crash
        unsafe {
            prefetch_read(data.as_ptr());

            let ptrs = vec![data.as_ptr(), data[1..].as_ptr()];
            prefetch_array(&ptrs);
        }
    }

    #[test]
    fn test_byte_lookup_empty() {
        let lookup = ByteLookup::new(&[]);
        assert!(lookup.is_empty());
        assert_eq!(lookup.len(), 0);
    }

    #[test]
    fn test_constant_time_lookup_string() {
        let table = vec!["hello".to_string(), "world".to_string(), "test".to_string()];
        let lookup = ConstantTimeLookup::new(&table);

        assert_eq!(lookup.get(0), "hello");
        assert_eq!(lookup.get(1), "world");
        assert_eq!(lookup.get(2), "test");
    }
}