Skip to main content

arcanum_verify/
timing.rs

1//! Timing analysis tools for detecting side-channel leaks.
2//!
3//! Uses statistical methods (inspired by dudect) to detect timing
4//! variations that could leak secret information.
5//!
6//! ## Methodology
7//!
8//! The timing test works by:
9//! 1. Running the same operation on two input classes (e.g., all-zero vs random)
10//! 2. Collecting timing measurements for each class
11//! 3. Using Welch's t-test to detect statistically significant differences
12//!
//! A large |t-value| (at or above the configured threshold) indicates that the
//! operation's running time depends on the input class, which could leak
//! secret information.
15//!
16//! ## Usage
17//!
18//! ```ignore
19//! use arcanum_verify::prelude::*;
20//!
21//! let result = TimingTest::new("my_crypto_operation")
22//!     .iterations(100_000)
23//!     .with_percentile_cropping(5.0) // Remove top/bottom 5%
24//!     .run(|class| {
25//!         let input = match class {
26//!             Class::Left => [0u8; 32],
27//!             Class::Right => random_bytes(),
28//!         };
29//!         crypto_operation(&input)
30//!     });
31//!
32//! assert!(result.is_constant_time());
33//! ```
34
35use crate::errors::{VerifyError, VerifyResult};
36use crate::stats;
37
/// Which of the two input populations a measurement belongs to.
///
/// A timing test feeds both classes to the operation under test and then
/// compares the two resulting timing distributions.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Class {
    /// First input population (e.g. all-zero keys).
    Left,
    /// Second input population (e.g. all-one keys).
    Right,
}
46
/// Configuration for percentile-based outlier removal.
///
/// Both percentages are meant to lie in 0–50 and apply to each class's
/// timing samples after sorting: `low` trims from the fast (small) end and
/// `high` from the slow (large) end. The default (`0.0` / `0.0`) disables
/// cropping entirely.
#[derive(Debug, Clone, Copy, Default, PartialEq)]
pub struct PercentileCrop {
    /// Percentage to remove from the low end (0-50)
    pub low: f64,
    /// Percentage to remove from the high end (0-50)
    pub high: f64,
}

impl PercentileCrop {
    /// Create symmetric cropping (the same percentage from both ends).
    pub fn symmetric(percent: f64) -> Self {
        Self::asymmetric(percent, percent)
    }

    /// Create asymmetric cropping with separate low-end and high-end
    /// percentages.
    pub fn asymmetric(low: f64, high: f64) -> Self {
        Self { low, high }
    }
}
79
/// Outcome of one timing analysis run.
///
/// Carries the sample counts (before and after cropping), the Welch
/// t-statistic with the threshold it was judged against, the pass/fail
/// verdict, and per-class timing means and standard deviations.
#[derive(Clone, Debug)]
pub struct TimingResult {
    /// Test name as given to the builder
    pub name: String,
    /// Raw sample count across both classes
    pub samples: usize,
    /// Sample count across both classes once cropping was applied
    pub samples_after_crop: usize,
    /// Welch's t-statistic comparing the two classes
    pub t_value: f64,
    /// `true` when no leak was detected
    pub passed: bool,
    /// Detection threshold the t-statistic was compared against
    pub threshold: f64,
    /// Mean duration of the left class (nanoseconds)
    pub mean_left: f64,
    /// Mean duration of the right class (nanoseconds)
    pub mean_right: f64,
    /// Standard deviation of the left class
    pub std_left: f64,
    /// Standard deviation of the right class
    pub std_right: f64,
}

impl TimingResult {
    /// Returns the stored verdict: `true` when no timing leak was detected.
    pub fn is_constant_time(&self) -> bool {
        self.passed
    }

    /// Magnitude of the t-statistic, ignoring its sign.
    pub fn abs_t_value(&self) -> f64 {
        self.t_value.abs()
    }

    /// Absolute gap between the two class means, as a percentage of their
    /// average. Yields `0.0` when that average is exactly zero.
    pub fn timing_difference_percent(&self) -> f64 {
        let average = (self.mean_left + self.mean_right) / 2.0;
        if average == 0.0 {
            return 0.0;
        }
        let gap = (self.mean_left - self.mean_right).abs();
        (gap / average) * 100.0
    }

    /// One-line summary: name, t-value, threshold and verdict.
    pub fn summary(&self) -> String {
        let verdict = if self.passed {
            "PASS"
        } else {
            "FAIL - TIMING LEAK DETECTED"
        };
        format!(
            "{}: t={:.2} (threshold={:.1}) - {}",
            self.name, self.t_value, self.threshold, verdict
        )
    }

    /// Multi-line report with sample counts, per-class statistics, the
    /// relative mean difference, the t-statistic and the verdict.
    pub fn detailed_report(&self) -> String {
        let verdict = if self.passed {
            "PASS (constant-time)"
        } else {
            "FAIL (timing leak detected)"
        };
        format!(
            "{name}\n\
             Samples: {raw} (after crop: {cropped})\n\
             Left class:  mean={ml:.2}ns, std={sl:.2}ns\n\
             Right class: mean={mr:.2}ns, std={sr:.2}ns\n\
             Difference: {diff:.4}%\n\
             t-statistic: {t:.4} (threshold: ±{thr:.1})\n\
             Result: {verdict}",
            name = self.name,
            raw = self.samples,
            cropped = self.samples_after_crop,
            ml = self.mean_left,
            sl = self.std_left,
            mr = self.mean_right,
            sr = self.std_right,
            diff = self.timing_difference_percent(),
            t = self.t_value,
            thr = self.threshold,
        )
    }
}

impl std::fmt::Display for TimingResult {
    /// Displays the one-line [`TimingResult::summary`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.summary())
    }
}
175
/// Single-pass mean/variance accumulator based on Welford's algorithm.
///
/// Samples are folded in one at a time with [`OnlineStats::update`]; the
/// running mean and the running sum of squared deviations (`m2`) are
/// updated incrementally, which stays numerically stable without storing
/// the samples.
#[derive(Clone, Debug, Default)]
pub struct OnlineStats {
    count: usize,
    mean: f64,
    m2: f64, // running sum of squared deviations from the current mean
}

impl OnlineStats {
    /// Empty accumulator: no samples, mean and variance both zero.
    pub fn new() -> Self {
        Self {
            count: 0,
            mean: 0.0,
            m2: 0.0,
        }
    }

    /// Fold one sample into the running statistics.
    pub fn update(&mut self, x: f64) {
        self.count += 1;
        let n = self.count as f64;
        // Deviation from the mean before and after it absorbs `x`; their
        // product is Welford's increment to `m2`.
        let before = x - self.mean;
        self.mean += before / n;
        let after = x - self.mean;
        self.m2 += before * after;
    }

    /// Number of samples seen so far.
    pub fn count(&self) -> usize {
        self.count
    }

    /// Running mean (`0.0` before any sample).
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Sample (n - 1) variance; `0.0` with fewer than two samples.
    pub fn variance(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            n => self.m2 / (n - 1) as f64,
        }
    }

    /// Sample standard deviation (square root of [`OnlineStats::variance`]).
    pub fn std_dev(&self) -> f64 {
        f64::sqrt(self.variance())
    }
}
225
226/// Builder for timing tests.
227pub struct TimingTest {
228    name: String,
229    iterations: usize,
230    warmup: usize,
231    threshold: f64,
232    percentile_crop: PercentileCrop,
233}
234
235impl TimingTest {
236    /// Create a new timing test.
237    pub fn new(name: impl Into<String>) -> Self {
238        Self {
239            name: name.into(),
240            iterations: 10_000,
241            warmup: 100,
242            threshold: stats::TIMING_LEAK_THRESHOLD,
243            percentile_crop: PercentileCrop::default(),
244        }
245    }
246
247    /// Set the number of iterations.
248    pub fn iterations(mut self, n: usize) -> Self {
249        self.iterations = n;
250        self
251    }
252
253    /// Set the number of warmup iterations.
254    pub fn warmup(mut self, n: usize) -> Self {
255        self.warmup = n;
256        self
257    }
258
259    /// Set the t-value threshold.
260    pub fn threshold(mut self, t: f64) -> Self {
261        self.threshold = t;
262        self
263    }
264
265    /// Enable symmetric percentile cropping.
266    ///
267    /// Removes the specified percentage of samples from both ends
268    /// of the timing distribution to reduce noise from outliers.
269    pub fn with_percentile_cropping(mut self, percent: f64) -> Self {
270        self.percentile_crop = PercentileCrop::symmetric(percent);
271        self
272    }
273
274    /// Enable asymmetric percentile cropping.
275    pub fn with_asymmetric_cropping(mut self, low: f64, high: f64) -> Self {
276        self.percentile_crop = PercentileCrop::asymmetric(low, high);
277        self
278    }
279
280    /// Crop timing samples based on percentiles.
281    fn crop_samples(&self, samples: &mut Vec<f64>) {
282        if self.percentile_crop.low == 0.0 && self.percentile_crop.high == 0.0 {
283            return;
284        }
285
286        samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
287
288        let n = samples.len();
289        let low_idx = ((n as f64 * self.percentile_crop.low / 100.0) as usize).min(n / 2);
290        let high_idx = n - ((n as f64 * self.percentile_crop.high / 100.0) as usize).min(n / 2);
291
292        *samples = samples[low_idx..high_idx].to_vec();
293    }
294
295    /// Run the timing test.
296    ///
297    /// The function `f` should take a `Class` and return some result.
298    /// Timing measurements are taken for both classes and compared.
299    pub fn run<F, R>(self, mut f: F) -> TimingResult
300    where
301        F: FnMut(Class) -> R,
302    {
303        use std::time::Instant;
304
305        // Warmup phase
306        for _ in 0..self.warmup {
307            let _ = f(Class::Left);
308            let _ = f(Class::Right);
309        }
310
311        let mut left_times = Vec::with_capacity(self.iterations);
312        let mut right_times = Vec::with_capacity(self.iterations);
313
314        // Interleave measurements to reduce systematic bias
315        for _ in 0..self.iterations {
316            // Measure left class
317            let start = Instant::now();
318            let _result = std::hint::black_box(f(Class::Left));
319            let elapsed = start.elapsed().as_nanos() as f64;
320            left_times.push(elapsed);
321
322            // Measure right class
323            let start = Instant::now();
324            let _result = std::hint::black_box(f(Class::Right));
325            let elapsed = start.elapsed().as_nanos() as f64;
326            right_times.push(elapsed);
327        }
328
329        let raw_samples = self.iterations * 2;
330
331        // Apply percentile cropping
332        self.crop_samples(&mut left_times);
333        self.crop_samples(&mut right_times);
334
335        let samples_after_crop = left_times.len() + right_times.len();
336
337        // Compute statistics
338        let mut left_stats = OnlineStats::new();
339        for &t in &left_times {
340            left_stats.update(t);
341        }
342
343        let mut right_stats = OnlineStats::new();
344        for &t in &right_times {
345            right_stats.update(t);
346        }
347
348        // Compute t-statistic
349        let t_value = stats::welch_t_test(&left_times, &right_times);
350        let passed = t_value.abs() < self.threshold;
351
352        TimingResult {
353            name: self.name,
354            samples: raw_samples,
355            samples_after_crop,
356            t_value,
357            passed,
358            threshold: self.threshold,
359            mean_left: left_stats.mean(),
360            mean_right: right_stats.mean(),
361            std_left: left_stats.std_dev(),
362            std_right: right_stats.std_dev(),
363        }
364    }
365
366    /// Run with online statistics (memory-efficient for large iterations).
367    pub fn run_online<F, R>(self, mut f: F) -> TimingResult
368    where
369        F: FnMut(Class) -> R,
370    {
371        use std::time::Instant;
372
373        // Warmup
374        for _ in 0..self.warmup {
375            let _ = f(Class::Left);
376            let _ = f(Class::Right);
377        }
378
379        let mut left_stats = OnlineStats::new();
380        let mut right_stats = OnlineStats::new();
381
382        // Collect measurements with online statistics
383        for _ in 0..self.iterations {
384            // Measure left class
385            let start = Instant::now();
386            let _result = std::hint::black_box(f(Class::Left));
387            let elapsed = start.elapsed().as_nanos() as f64;
388            left_stats.update(elapsed);
389
390            // Measure right class
391            let start = Instant::now();
392            let _result = std::hint::black_box(f(Class::Right));
393            let elapsed = start.elapsed().as_nanos() as f64;
394            right_stats.update(elapsed);
395        }
396
397        // Compute t-statistic using online stats
398        let t_value = stats::welch_t_online(&left_stats, &right_stats);
399        let passed = t_value.abs() < self.threshold;
400
401        TimingResult {
402            name: self.name,
403            samples: self.iterations * 2,
404            samples_after_crop: self.iterations * 2, // No cropping in online mode
405            t_value,
406            passed,
407            threshold: self.threshold,
408            mean_left: left_stats.mean(),
409            mean_right: right_stats.mean(),
410            std_left: left_stats.std_dev(),
411            std_right: right_stats.std_dev(),
412        }
413    }
414}
415
416/// Run a timing test and return an error if a leak is detected.
417pub fn assert_constant_time<F, R>(name: &str, iterations: usize, f: F) -> VerifyResult<()>
418where
419    F: FnMut(Class) -> R,
420{
421    let result = TimingTest::new(name).iterations(iterations).run(f);
422
423    if result.passed {
424        Ok(())
425    } else {
426        Err(VerifyError::TimingLeakDetected {
427            t_value: result.t_value,
428            threshold: result.threshold,
429        })
430    }
431}
432
433/// Common test patterns for cryptographic operations.
434pub mod patterns {
435    use super::*;
436
437    /// Test key comparison for constant-time behavior.
438    ///
439    /// Compares timing of operations on all-zero vs all-one keys.
440    pub fn test_key_comparison<F, R>(name: &str, iterations: usize, mut op: F) -> TimingResult
441    where
442        F: FnMut(&[u8; 32]) -> R,
443    {
444        let zero_key = [0u8; 32];
445        let one_key = [0xFFu8; 32];
446
447        TimingTest::new(name)
448            .iterations(iterations)
449            .run(move |class| {
450                let key = match class {
451                    Class::Left => &zero_key,
452                    Class::Right => &one_key,
453                };
454                op(key)
455            })
456    }
457
458    /// Test early exit behavior (e.g., MAC verification).
459    ///
460    /// Compares timing when the first byte differs vs last byte differs.
461    pub fn test_early_exit<F>(name: &str, iterations: usize, mut compare: F) -> TimingResult
462    where
463        F: FnMut(&[u8; 32], &[u8; 32]) -> bool,
464    {
465        let correct = [0u8; 32];
466        let mut wrong_first = [0u8; 32];
467        wrong_first[0] = 0xFF;
468        let mut wrong_last = [0u8; 32];
469        wrong_last[31] = 0xFF;
470
471        TimingTest::new(name)
472            .iterations(iterations)
473            .run(move |class| {
474                let wrong = match class {
475                    Class::Left => &wrong_first,
476                    Class::Right => &wrong_last,
477                };
478                compare(&correct, wrong)
479            })
480    }
481
482    /// Test padding oracle behavior.
483    ///
484    /// Compares timing for valid vs invalid padding.
485    pub fn test_padding_oracle<F, R, E>(
486        name: &str,
487        iterations: usize,
488        mut decrypt: F,
489    ) -> TimingResult
490    where
491        F: FnMut(&[u8]) -> Result<R, E>,
492    {
493        // Valid PKCS#7 padding (last byte = 1)
494        let mut valid_padding = vec![0u8; 48];
495        valid_padding[47] = 0x01;
496
497        // Invalid padding (last byte = 17, impossible)
498        let mut invalid_padding = vec![0u8; 48];
499        invalid_padding[47] = 0x11;
500
501        TimingTest::new(name)
502            .iterations(iterations)
503            .run(move |class| {
504                let data = match class {
505                    Class::Left => &valid_padding,
506                    Class::Right => &invalid_padding,
507                };
508                let _ = decrypt(data);
509            })
510    }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `TimingResult` with fixed counts/threshold for formatting tests.
    fn sample_result(passed: bool, t_value: f64, mean_left: f64, mean_right: f64) -> TimingResult {
        TimingResult {
            name: "test".into(),
            samples: 1000,
            samples_after_crop: 900,
            t_value,
            passed,
            threshold: 4.5,
            mean_left,
            mean_right,
            std_left: 10.0,
            std_right: 10.0,
        }
    }

    #[test]
    fn test_constant_time_operation() {
        // A truly constant-time operation should pass
        let result = TimingTest::new("constant_add")
            .iterations(1000)
            .run(|class| {
                let a = match class {
                    Class::Left => 0u64,
                    Class::Right => u64::MAX,
                };
                // Simple addition is constant-time
                std::hint::black_box(a.wrapping_add(42))
            });

        // Should pass (t-value close to 0); the bound is deliberately loose
        // because real wall-clock measurements are noisy.
        assert!(
            result.t_value.abs() < 10.0,
            "t-value too high: {}",
            result.t_value
        );
    }

    #[test]
    fn test_timing_result_display() {
        let result = sample_result(true, 1.5, 100.0, 100.5);

        assert!(result.to_string().contains("PASS"));
        assert_eq!(result.to_string(), "test: t=1.50 (threshold=4.5) - PASS");
        assert!(result.detailed_report().contains("100.00ns"));
    }

    #[test]
    fn test_failing_result_display() {
        let result = sample_result(false, -7.25, 100.0, 100.5);

        assert!(!result.is_constant_time());
        assert_eq!(result.abs_t_value(), 7.25);
        assert!(result.summary().ends_with("FAIL - TIMING LEAK DETECTED"));
        assert!(result
            .detailed_report()
            .ends_with("Result: FAIL (timing leak detected)"));
    }

    #[test]
    fn test_timing_difference_percent() {
        let result = sample_result(true, 0.0, 90.0, 110.0);
        assert!((result.timing_difference_percent() - 20.0).abs() < 1e-9);

        // Zero average must not divide by zero.
        let zero = sample_result(true, 0.0, 0.0, 0.0);
        assert_eq!(zero.timing_difference_percent(), 0.0);
    }

    #[test]
    fn test_online_stats() {
        let mut stats = OnlineStats::new();
        stats.update(1.0);
        stats.update(2.0);
        stats.update(3.0);

        assert_eq!(stats.count(), 3);
        assert!((stats.mean() - 2.0).abs() < 0.001);
        assert!((stats.variance() - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_online_stats_empty_and_single() {
        let mut stats = OnlineStats::new();
        assert_eq!(stats.count(), 0);
        assert_eq!(stats.variance(), 0.0);
        assert_eq!(stats.std_dev(), 0.0);

        stats.update(5.0);
        assert_eq!(stats.mean(), 5.0);
        // Sample variance is undefined for one sample; reported as 0.
        assert_eq!(stats.variance(), 0.0);
    }

    #[test]
    fn test_percentile_cropping() {
        let result = TimingTest::new("test")
            .iterations(100)
            .with_percentile_cropping(10.0)
            .run(|_| 42);

        // 100 samples per class; 10% (10 samples) trimmed from each end of
        // each class leaves 80 per class.
        assert_eq!(result.samples, 200);
        assert_eq!(result.samples_after_crop, 160);
    }

    #[test]
    fn test_asymmetric_cropping() {
        let result = TimingTest::new("asym")
            .iterations(100)
            .with_asymmetric_cropping(0.0, 25.0)
            .run(|_| 42);

        // Only the slowest 25 samples of each class are dropped.
        assert_eq!(result.samples, 200);
        assert_eq!(result.samples_after_crop, 150);
    }

    #[test]
    fn test_online_mode() {
        let result =
            TimingTest::new("online_test")
                .iterations(1000)
                .run_online(|class| match class {
                    Class::Left => 1u64,
                    Class::Right => 2u64,
                });

        assert_eq!(result.samples, 2000);
        assert_eq!(result.samples_after_crop, result.samples);
    }

    #[test]
    fn test_key_comparison_pattern() {
        let result = patterns::test_key_comparison("test_key", 500, |key| {
            // Simple sum - should be constant time
            key.iter().fold(0u64, |acc, &b| acc.wrapping_add(b as u64))
        });

        // Key sum is constant-time
        assert!(
            result.t_value.abs() < 20.0,
            "Unexpected timing variation: {}",
            result.t_value
        );
    }
}