1use crate::errors::{VerifyError, VerifyResult};
36use crate::stats;
37
/// Identifies which of the two input classes a single measurement belongs to.
///
/// A timing test calls the closure under test with each class and compares the two resulting
/// timing distributions; each caller decides which concrete input a class maps to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Class {
    /// The first input class.
    Left,
    /// The second input class.
    Right,
}
46
/// Percentages of the fastest and slowest samples to discard before statistical analysis.
///
/// Values are expressed in percent (`5.0` means 5 %). The default (`0.0` / `0.0`) disables
/// cropping entirely.
#[derive(Debug, Clone, Copy, Default)]
pub struct PercentileCrop {
    /// Percentage of samples removed from the low (fastest) end of the sorted distribution.
    pub low: f64,
    /// Percentage of samples removed from the high (slowest) end of the sorted distribution.
    pub high: f64,
}

impl PercentileCrop {
    /// Crops the same `percent` from both ends of the distribution.
    pub fn symmetric(percent: f64) -> Self {
        Self::asymmetric(percent, percent)
    }

    /// Crops `low` percent from the fast end and `high` percent from the slow end.
    pub fn asymmetric(low: f64, high: f64) -> Self {
        Self { low, high }
    }
}
79
/// Outcome of a two-class timing test, including per-class summary statistics.
#[derive(Debug, Clone)]
pub struct TimingResult {
    /// Human-readable label of the test.
    pub name: String,
    /// Total number of raw measurements across both classes.
    pub samples: usize,
    /// Number of measurements (both classes) that remained after percentile cropping.
    pub samples_after_crop: usize,
    /// Welch t-statistic comparing the two classes.
    pub t_value: f64,
    /// Whether the test passed; `TimingTest` sets this to `|t_value| < threshold`.
    pub passed: bool,
    /// The |t| threshold the test was judged against.
    pub threshold: f64,
    /// Mean duration of the left class, in nanoseconds.
    pub mean_left: f64,
    /// Mean duration of the right class, in nanoseconds.
    pub mean_right: f64,
    /// Sample standard deviation of the left class, in nanoseconds.
    pub std_left: f64,
    /// Sample standard deviation of the right class, in nanoseconds.
    pub std_right: f64,
}

impl TimingResult {
    /// Returns `true` when the test passed, i.e. no timing leak was detected.
    pub fn is_constant_time(&self) -> bool {
        self.passed
    }

    /// Magnitude of the t-statistic, ignoring which class was slower.
    pub fn abs_t_value(&self) -> f64 {
        f64::abs(self.t_value)
    }

    /// Absolute difference of the class means as a percentage of their midpoint.
    ///
    /// Returns `0.0` when the midpoint is zero to avoid dividing by zero.
    pub fn timing_difference_percent(&self) -> f64 {
        let midpoint = (self.mean_left + self.mean_right) / 2.0;
        if midpoint == 0.0 {
            return 0.0;
        }
        (self.mean_left - self.mean_right).abs() / midpoint * 100.0
    }

    /// One-line verdict of the form `name: t=… (threshold=…) - PASS|FAIL - …`.
    pub fn summary(&self) -> String {
        let verdict = if !self.passed {
            "FAIL - TIMING LEAK DETECTED"
        } else {
            "PASS"
        };
        format!(
            "{name}: t={t:.2} (threshold={limit:.1}) - {verdict}",
            name = self.name,
            t = self.t_value,
            limit = self.threshold,
            verdict = verdict,
        )
    }

    /// Multi-line report with sample counts, per-class statistics and the verdict.
    pub fn detailed_report(&self) -> String {
        let verdict = if !self.passed {
            "FAIL (timing leak detected)"
        } else {
            "PASS (constant-time)"
        };
        let lines = [
            self.name.clone(),
            format!(
                "Samples: {} (after crop: {})",
                self.samples, self.samples_after_crop
            ),
            format!(
                "Left class: mean={:.2}ns, std={:.2}ns",
                self.mean_left, self.std_left
            ),
            format!(
                "Right class: mean={:.2}ns, std={:.2}ns",
                self.mean_right, self.std_right
            ),
            format!("Difference: {:.4}%", self.timing_difference_percent()),
            format!(
                "t-statistic: {:.4} (threshold: ±{:.1})",
                self.t_value, self.threshold
            ),
            format!("Result: {}", verdict),
        ];
        lines.join("\n")
    }
}

impl std::fmt::Display for TimingResult {
    /// Formats the result as its one-line [`TimingResult::summary`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.summary())
    }
}
175
/// Streaming mean/variance accumulator based on Welford's online algorithm.
///
/// Samples are folded in one at a time via [`OnlineStats::update`], so no sample buffer is
/// retained. `Default` yields an empty accumulator.
#[derive(Debug, Clone, Default)]
pub struct OnlineStats {
    /// Number of samples folded in so far.
    count: usize,
    /// Running mean of all samples seen.
    mean: f64,
    /// Running sum of squared deviations from the mean (Welford's `M2`).
    m2: f64,
}

impl OnlineStats {
    /// Creates an empty accumulator; identical to `OnlineStats::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Folds the sample `x` into the running mean and `M2`.
    pub fn update(&mut self, x: f64) {
        self.count += 1;
        let n = self.count as f64;
        let before = x - self.mean;
        self.mean += before / n;
        let after = x - self.mean;
        self.m2 += before * after;
    }

    /// Number of samples seen so far.
    pub fn count(&self) -> usize {
        self.count
    }

    /// Mean of the samples seen so far (`0.0` when empty).
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Unbiased sample variance `M2 / (n - 1)`; `0.0` until at least two samples are seen.
    pub fn variance(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            n => self.m2 / (n - 1) as f64,
        }
    }

    /// Sample standard deviation (square root of [`OnlineStats::variance`]).
    pub fn std_dev(&self) -> f64 {
        f64::sqrt(self.variance())
    }
}
225
226pub struct TimingTest {
228 name: String,
229 iterations: usize,
230 warmup: usize,
231 threshold: f64,
232 percentile_crop: PercentileCrop,
233}
234
235impl TimingTest {
236 pub fn new(name: impl Into<String>) -> Self {
238 Self {
239 name: name.into(),
240 iterations: 10_000,
241 warmup: 100,
242 threshold: stats::TIMING_LEAK_THRESHOLD,
243 percentile_crop: PercentileCrop::default(),
244 }
245 }
246
247 pub fn iterations(mut self, n: usize) -> Self {
249 self.iterations = n;
250 self
251 }
252
253 pub fn warmup(mut self, n: usize) -> Self {
255 self.warmup = n;
256 self
257 }
258
259 pub fn threshold(mut self, t: f64) -> Self {
261 self.threshold = t;
262 self
263 }
264
265 pub fn with_percentile_cropping(mut self, percent: f64) -> Self {
270 self.percentile_crop = PercentileCrop::symmetric(percent);
271 self
272 }
273
274 pub fn with_asymmetric_cropping(mut self, low: f64, high: f64) -> Self {
276 self.percentile_crop = PercentileCrop::asymmetric(low, high);
277 self
278 }
279
280 fn crop_samples(&self, samples: &mut Vec<f64>) {
282 if self.percentile_crop.low == 0.0 && self.percentile_crop.high == 0.0 {
283 return;
284 }
285
286 samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
287
288 let n = samples.len();
289 let low_idx = ((n as f64 * self.percentile_crop.low / 100.0) as usize).min(n / 2);
290 let high_idx = n - ((n as f64 * self.percentile_crop.high / 100.0) as usize).min(n / 2);
291
292 *samples = samples[low_idx..high_idx].to_vec();
293 }
294
295 pub fn run<F, R>(self, mut f: F) -> TimingResult
300 where
301 F: FnMut(Class) -> R,
302 {
303 use std::time::Instant;
304
305 for _ in 0..self.warmup {
307 let _ = f(Class::Left);
308 let _ = f(Class::Right);
309 }
310
311 let mut left_times = Vec::with_capacity(self.iterations);
312 let mut right_times = Vec::with_capacity(self.iterations);
313
314 for _ in 0..self.iterations {
316 let start = Instant::now();
318 let _result = std::hint::black_box(f(Class::Left));
319 let elapsed = start.elapsed().as_nanos() as f64;
320 left_times.push(elapsed);
321
322 let start = Instant::now();
324 let _result = std::hint::black_box(f(Class::Right));
325 let elapsed = start.elapsed().as_nanos() as f64;
326 right_times.push(elapsed);
327 }
328
329 let raw_samples = self.iterations * 2;
330
331 self.crop_samples(&mut left_times);
333 self.crop_samples(&mut right_times);
334
335 let samples_after_crop = left_times.len() + right_times.len();
336
337 let mut left_stats = OnlineStats::new();
339 for &t in &left_times {
340 left_stats.update(t);
341 }
342
343 let mut right_stats = OnlineStats::new();
344 for &t in &right_times {
345 right_stats.update(t);
346 }
347
348 let t_value = stats::welch_t_test(&left_times, &right_times);
350 let passed = t_value.abs() < self.threshold;
351
352 TimingResult {
353 name: self.name,
354 samples: raw_samples,
355 samples_after_crop,
356 t_value,
357 passed,
358 threshold: self.threshold,
359 mean_left: left_stats.mean(),
360 mean_right: right_stats.mean(),
361 std_left: left_stats.std_dev(),
362 std_right: right_stats.std_dev(),
363 }
364 }
365
366 pub fn run_online<F, R>(self, mut f: F) -> TimingResult
368 where
369 F: FnMut(Class) -> R,
370 {
371 use std::time::Instant;
372
373 for _ in 0..self.warmup {
375 let _ = f(Class::Left);
376 let _ = f(Class::Right);
377 }
378
379 let mut left_stats = OnlineStats::new();
380 let mut right_stats = OnlineStats::new();
381
382 for _ in 0..self.iterations {
384 let start = Instant::now();
386 let _result = std::hint::black_box(f(Class::Left));
387 let elapsed = start.elapsed().as_nanos() as f64;
388 left_stats.update(elapsed);
389
390 let start = Instant::now();
392 let _result = std::hint::black_box(f(Class::Right));
393 let elapsed = start.elapsed().as_nanos() as f64;
394 right_stats.update(elapsed);
395 }
396
397 let t_value = stats::welch_t_online(&left_stats, &right_stats);
399 let passed = t_value.abs() < self.threshold;
400
401 TimingResult {
402 name: self.name,
403 samples: self.iterations * 2,
404 samples_after_crop: self.iterations * 2, t_value,
406 passed,
407 threshold: self.threshold,
408 mean_left: left_stats.mean(),
409 mean_right: right_stats.mean(),
410 std_left: left_stats.std_dev(),
411 std_right: right_stats.std_dev(),
412 }
413 }
414}
415
416pub fn assert_constant_time<F, R>(name: &str, iterations: usize, f: F) -> VerifyResult<()>
418where
419 F: FnMut(Class) -> R,
420{
421 let result = TimingTest::new(name).iterations(iterations).run(f);
422
423 if result.passed {
424 Ok(())
425 } else {
426 Err(VerifyError::TimingLeakDetected {
427 t_value: result.t_value,
428 threshold: result.threshold,
429 })
430 }
431}
432
433pub mod patterns {
435 use super::*;
436
437 pub fn test_key_comparison<F, R>(name: &str, iterations: usize, mut op: F) -> TimingResult
441 where
442 F: FnMut(&[u8; 32]) -> R,
443 {
444 let zero_key = [0u8; 32];
445 let one_key = [0xFFu8; 32];
446
447 TimingTest::new(name)
448 .iterations(iterations)
449 .run(move |class| {
450 let key = match class {
451 Class::Left => &zero_key,
452 Class::Right => &one_key,
453 };
454 op(key)
455 })
456 }
457
458 pub fn test_early_exit<F>(name: &str, iterations: usize, mut compare: F) -> TimingResult
462 where
463 F: FnMut(&[u8; 32], &[u8; 32]) -> bool,
464 {
465 let correct = [0u8; 32];
466 let mut wrong_first = [0u8; 32];
467 wrong_first[0] = 0xFF;
468 let mut wrong_last = [0u8; 32];
469 wrong_last[31] = 0xFF;
470
471 TimingTest::new(name)
472 .iterations(iterations)
473 .run(move |class| {
474 let wrong = match class {
475 Class::Left => &wrong_first,
476 Class::Right => &wrong_last,
477 };
478 compare(&correct, wrong)
479 })
480 }
481
482 pub fn test_padding_oracle<F, R, E>(
486 name: &str,
487 iterations: usize,
488 mut decrypt: F,
489 ) -> TimingResult
490 where
491 F: FnMut(&[u8]) -> Result<R, E>,
492 {
493 let mut valid_padding = vec![0u8; 48];
495 valid_padding[47] = 0x01;
496
497 let mut invalid_padding = vec![0u8; 48];
499 invalid_padding[47] = 0x11;
500
501 TimingTest::new(name)
502 .iterations(iterations)
503 .run(move |class| {
504 let data = match class {
505 Class::Left => &valid_padding,
506 Class::Right => &invalid_padding,
507 };
508 let _ = decrypt(data);
509 })
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_time_operation() {
        let result = TimingTest::new("constant_add")
            .iterations(1000)
            .run(|class| {
                let a = match class {
                    Class::Left => 0u64,
                    Class::Right => u64::MAX,
                };
                std::hint::black_box(a.wrapping_add(42))
            });

        // Deliberately loose bound: a unit test can only guard against gross leaks, because
        // scheduler and frequency noise on a shared machine inflates |t|.
        assert!(
            result.t_value.abs() < 10.0,
            "t-value too high: {}",
            result.t_value
        );
    }

    #[test]
    fn test_timing_result_display() {
        let result = TimingResult {
            name: "test".into(),
            samples: 1000,
            samples_after_crop: 900,
            t_value: 1.5,
            passed: true,
            threshold: 4.5,
            mean_left: 100.0,
            mean_right: 100.5,
            std_left: 10.0,
            std_right: 10.0,
        };

        assert_eq!(result.to_string(), "test: t=1.50 (threshold=4.5) - PASS");
        assert!(result.detailed_report().contains("100.00ns"));
    }

    #[test]
    fn test_online_stats() {
        let mut acc = OnlineStats::new();
        assert_eq!(acc.count(), 0);
        assert_eq!(acc.variance(), 0.0);

        acc.update(1.0);
        // A single sample has no spread yet.
        assert_eq!(acc.variance(), 0.0);

        acc.update(2.0);
        acc.update(3.0);

        assert_eq!(acc.count(), 3);
        assert!((acc.mean() - 2.0).abs() < 0.001);
        assert!((acc.variance() - 1.0).abs() < 0.001);
        assert!((acc.std_dev() - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_percentile_cropping() {
        let result = TimingTest::new("test")
            .iterations(100)
            .with_percentile_cropping(10.0)
            .run(|_| 42);

        // 100 samples per class; 10 % is dropped from each end of each class.
        assert_eq!(result.samples, 200);
        assert_eq!(result.samples_after_crop, 160);
    }

    #[test]
    fn test_asymmetric_cropping() {
        let result = TimingTest::new("test")
            .iterations(100)
            .with_asymmetric_cropping(0.0, 25.0)
            .run(|_| 42);

        // Only the slowest 25 % of each class's 100 samples are dropped.
        assert_eq!(result.samples, 200);
        assert_eq!(result.samples_after_crop, 150);
    }

    #[test]
    fn test_online_mode() {
        let result = TimingTest::new("online_test")
            .iterations(1000)
            .run_online(|class| match class {
                Class::Left => 1u64,
                Class::Right => 2u64,
            });

        assert_eq!(result.samples, 2000);
        assert_eq!(result.samples_after_crop, result.samples);
    }

    #[test]
    fn test_key_comparison_pattern() {
        let result = patterns::test_key_comparison("test_key", 500, |key| {
            key.iter().fold(0u64, |acc, &b| acc.wrapping_add(b as u64))
        });

        assert!(
            result.t_value.abs() < 20.0,
            "Unexpected timing variation: {}",
            result.t_value
        );
    }
}