1use std::time::{Duration, Instant};
48use thiserror::Error;
49
50#[derive(Debug, Error, Clone, PartialEq)]
52pub enum CtAuditError {
53 #[error("Insufficient samples: need at least {needed}, got {actual}")]
55 InsufficientSamples { needed: usize, actual: usize },
56
57 #[error(
59 "Timing leak detected: coefficient of variation {cv:.4} exceeds threshold {threshold:.4}"
60 )]
61 TimingLeakDetected { cv: f64, threshold: f64 },
62}
63
64pub type CtAuditResult<T> = Result<T, CtAuditError>;
66
/// Summary statistics over a set of timing samples; every `_ns` field is in nanoseconds.
#[derive(Clone, Debug)]
pub struct TimingStatistics {
    /// Number of samples summarised.
    pub count: usize,
    /// Smallest sample.
    pub min_ns: u64,
    /// Largest sample.
    pub max_ns: u64,
    /// Arithmetic mean of the samples.
    pub mean_ns: f64,
    /// Median sample; for an even count, the floored average of the two middle samples.
    pub median_ns: u64,
    /// Standard deviation (as computed by `OperationBenchmark::statistics`, which divides
    /// by `count`, i.e. the population form).
    pub std_dev_ns: f64,
    /// `std_dev_ns / mean_ns`, or `0.0` when the mean is zero.
    pub coefficient_of_variation: f64,
}
85
86impl TimingStatistics {
87 pub fn is_constant_time(&self, threshold: f64) -> bool {
92 self.coefficient_of_variation < threshold
93 }
94
95 pub fn z_score(&self, value_ns: u64) -> f64 {
97 if self.std_dev_ns == 0.0 {
98 return 0.0;
99 }
100 (value_ns as f64 - self.mean_ns) / self.std_dev_ns
101 }
102}
103
/// A named series of timing samples, one nanosecond count per measured call.
#[derive(Debug)]
pub struct OperationBenchmark {
    /// Label identifying this benchmark.
    name: String,
    /// Recorded samples in nanoseconds, in the order they were taken.
    measurements: Vec<u64>,
    // Only ever written (by the constructor) and never read back; the actual
    // preallocation lives in `measurements`' own capacity. The allow silences
    // the unused-field lint for that reason.
    #[allow(dead_code)]
    capacity: usize,
}
112
113impl OperationBenchmark {
114 pub fn new(name: impl Into<String>, capacity: usize) -> Self {
116 Self {
117 name: name.into(),
118 measurements: Vec::with_capacity(capacity),
119 capacity,
120 }
121 }
122
123 pub fn measure<F, R>(&mut self, op: F) -> R
125 where
126 F: FnOnce() -> R,
127 {
128 let start = Instant::now();
129 let result = op();
130 let elapsed = start.elapsed();
131
132 self.measurements.push(elapsed.as_nanos() as u64);
133 result
134 }
135
136 pub fn measure_n<F>(&mut self, n: usize, mut op: F)
138 where
139 F: FnMut(),
140 {
141 for _ in 0..n {
142 self.measure(&mut op);
143 }
144 }
145
146 pub fn measurements(&self) -> &[u64] {
148 &self.measurements
149 }
150
151 pub fn statistics(&self) -> CtAuditResult<TimingStatistics> {
153 if self.measurements.is_empty() {
154 return Err(CtAuditError::InsufficientSamples {
155 needed: 1,
156 actual: 0,
157 });
158 }
159
160 let mut sorted = self.measurements.clone();
161 sorted.sort_unstable();
162
163 let count = sorted.len();
164 let min_ns = sorted[0];
165 let max_ns = sorted[count - 1];
166
167 let sum: u64 = sorted.iter().sum();
169 let mean_ns = sum as f64 / count as f64;
170
171 let median_ns = if count % 2 == 0 {
173 (sorted[count / 2 - 1] + sorted[count / 2]) / 2
174 } else {
175 sorted[count / 2]
176 };
177
178 let variance: f64 = sorted
180 .iter()
181 .map(|&x| {
182 let diff = x as f64 - mean_ns;
183 diff * diff
184 })
185 .sum::<f64>()
186 / count as f64;
187 let std_dev_ns = variance.sqrt();
188
189 let coefficient_of_variation = if mean_ns > 0.0 {
191 std_dev_ns / mean_ns
192 } else {
193 0.0
194 };
195
196 Ok(TimingStatistics {
197 count,
198 min_ns,
199 max_ns,
200 mean_ns,
201 median_ns,
202 std_dev_ns,
203 coefficient_of_variation,
204 })
205 }
206
207 pub fn is_constant_time(&self, threshold: f64) -> CtAuditResult<bool> {
209 let stats = self.statistics()?;
210 if stats.coefficient_of_variation > threshold {
211 return Err(CtAuditError::TimingLeakDetected {
212 cv: stats.coefficient_of_variation,
213 threshold,
214 });
215 }
216 Ok(true)
217 }
218
219 pub fn reset(&mut self) {
221 self.measurements.clear();
222 }
223
224 pub fn name(&self) -> &str {
226 &self.name
227 }
228}
229
/// Runs warm-up passes and timing measurements for operations under audit.
#[derive(Debug, Clone)]
pub struct CtAuditor {
    /// Label used for the benchmarks this auditor creates.
    name: String,
    /// Number of unmeasured calls made to an operation before timing it.
    warmup_iterations: usize,
}
235
236impl CtAuditor {
237 pub fn new(name: impl Into<String>, warmup_iterations: usize) -> Self {
244 Self {
245 name: name.into(),
246 warmup_iterations,
247 }
248 }
249
250 fn warmup<F>(&self, mut op: F)
252 where
253 F: FnMut(),
254 {
255 for _ in 0..self.warmup_iterations {
256 op();
257 }
258 }
259
260 pub fn audit<F>(&self, iterations: usize, mut op: F) -> CtAuditResult<TimingStatistics>
273 where
274 F: FnMut(),
275 {
276 self.warmup(&mut op);
278
279 let mut bench = OperationBenchmark::new(&self.name, iterations);
281 bench.measure_n(iterations, op);
282
283 bench.statistics()
284 }
285
286 pub fn compare<F, G>(
300 &self,
301 iterations: usize,
302 mut op_a: F,
303 mut op_b: G,
304 ) -> CtAuditResult<(TimingStatistics, TimingStatistics)>
305 where
306 F: FnMut(),
307 G: FnMut(),
308 {
309 self.warmup(&mut op_a);
311 self.warmup(&mut op_b);
312
313 let mut bench_a = OperationBenchmark::new(format!("{}_input_a", self.name), iterations);
315 bench_a.measure_n(iterations, &mut op_a);
316
317 let mut bench_b = OperationBenchmark::new(format!("{}_input_b", self.name), iterations);
319 bench_b.measure_n(iterations, &mut op_b);
320
321 Ok((bench_a.statistics()?, bench_b.statistics()?))
322 }
323
324 pub fn detect_leak<F, G>(
329 &self,
330 iterations: usize,
331 op_a: F,
332 op_b: G,
333 threshold: f64,
334 ) -> CtAuditResult<bool>
335 where
336 F: FnMut(),
337 G: FnMut(),
338 {
339 let (stats_a, stats_b) = self.compare(iterations, op_a, op_b)?;
340
341 let mean_diff = (stats_a.mean_ns - stats_b.mean_ns).abs();
343 let mean_avg = (stats_a.mean_ns + stats_b.mean_ns) / 2.0;
344 let relative_diff = if mean_avg > 0.0 {
345 mean_diff / mean_avg
346 } else {
347 0.0
348 };
349
350 Ok(relative_diff > threshold)
351 }
352}
353
/// Runs `op` exactly once and returns its result together with the
/// wall-clock time the call took.
pub fn measure_once<F, R>(op: F) -> (R, Duration)
where
    F: FnOnce() -> R,
{
    let t0 = Instant::now();
    let value = op();
    (value, t0.elapsed())
}
364
/// Runs `op` `n` times, timing each call separately, and returns the mean
/// duration of one call.
///
/// Returns [`Duration::ZERO`] when `n == 0` instead of panicking on a
/// division by zero. Any `n` is accepted; it is never truncated to `u32`.
pub fn measure_average<F>(n: usize, mut op: F) -> Duration
where
    F: FnMut(),
{
    if n == 0 {
        return Duration::ZERO;
    }

    let mut total = Duration::ZERO;
    for _ in 0..n {
        let start = Instant::now();
        op();
        total += start.elapsed();
    }

    // Divide in u128 nanoseconds. `total / n as u32` would truncate large `n`
    // (and divide by zero at n == 2^32). Clamp so the u64 conversion cannot wrap.
    let avg_nanos = total.as_nanos() / n as u128;
    Duration::from_nanos(avg_nanos.min(u128::from(u64::MAX)) as u64)
}
378
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a benchmark pre-populated with fixed samples so its statistics
    /// are deterministic.
    fn bench_with(samples: &[u64]) -> OperationBenchmark {
        let mut bench = OperationBenchmark::new("fixed", samples.len());
        bench.measurements = samples.to_vec();
        bench
    }

    #[test]
    fn test_benchmark_basic() {
        let mut bench = OperationBenchmark::new("test", 100);

        for i in 0..100 {
            bench.measure(|| {
                std::hint::black_box(i * 2);
            });
        }

        assert_eq!(bench.measurements().len(), 100);
        assert_eq!(bench.name(), "test");
    }

    #[test]
    fn test_statistics_calculation() {
        let mut bench = OperationBenchmark::new("test", 10);

        for _ in 0..10 {
            bench.measure(|| {
                let mut sum = 0u64;
                for i in 0..100 {
                    sum = sum.wrapping_add(i);
                }
                std::hint::black_box(sum);
            });
        }

        let stats = bench.statistics().unwrap();
        assert_eq!(stats.count, 10);
        assert!(stats.min_ns <= stats.median_ns && stats.median_ns <= stats.max_ns);
        assert!(stats.mean_ns >= stats.min_ns as f64 && stats.mean_ns <= stats.max_ns as f64);
        assert!(stats.std_dev_ns >= 0.0);
        assert!(stats.coefficient_of_variation >= 0.0);
    }

    #[test]
    #[ignore = "timing-sensitive; run manually on a quiet machine"]
    fn test_constant_time_check() {
        let mut bench = OperationBenchmark::new("constant_op", 10000);

        let data = [0u8; 256];
        for _ in 0..10000 {
            bench.measure(|| {
                let mut sum = 0u64;
                for &byte in &data {
                    sum = sum.wrapping_add(byte as u64).wrapping_mul(3);
                }
                std::hint::black_box(sum);
            });
        }

        assert!(bench.is_constant_time(5.0).is_ok());
    }

    #[test]
    fn test_auditor_basic() {
        let auditor = CtAuditor::new("test_operation", 10);

        let stats = auditor
            .audit(100, || {
                std::hint::black_box(42);
            })
            .unwrap();

        assert_eq!(stats.count, 100);
        // The timer may be coarser than this trivial operation on some
        // platforms, so every sample can legitimately read 0 ns. Only check
        // internal consistency rather than asserting a positive mean.
        assert!(stats.min_ns <= stats.max_ns);
        assert!(stats.mean_ns >= 0.0);
    }

    #[test]
    fn test_compare_operations() {
        let auditor = CtAuditor::new("compare_test", 10);

        let (stats_a, stats_b) = auditor
            .compare(
                50,
                || {
                    std::hint::black_box(42);
                },
                || {
                    std::hint::black_box(43);
                },
            )
            .unwrap();

        assert_eq!(stats_a.count, 50);
        assert_eq!(stats_b.count, 50);
    }

    #[test]
    fn test_compare_zero_iterations_is_error() {
        let auditor = CtAuditor::new("empty", 2);
        let result = auditor.compare(0, || {}, || {});
        assert_eq!(
            result.unwrap_err(),
            CtAuditError::InsufficientSamples {
                needed: 1,
                actual: 0
            }
        );
    }

    #[test]
    fn test_measure_once() {
        let (result, duration) = measure_once(|| {
            let mut sum = 0u64;
            for i in 0..1000 {
                sum = std::hint::black_box(sum.wrapping_add(i));
            }
            std::hint::black_box(sum);
            4
        });
        assert_eq!(result, 4);
        assert!(duration < std::time::Duration::from_secs(5));
    }

    #[test]
    fn test_measure_average() {
        let mut calls = 0;
        let avg = measure_average(10, || {
            calls += 1;
            std::hint::black_box(42);
        });
        // Asserting `avg > 0` is flaky on coarse timers; check the call count
        // and a generous upper bound instead.
        assert_eq!(calls, 10);
        assert!(avg < std::time::Duration::from_secs(1));
    }

    #[test]
    fn test_z_score() {
        let stats = bench_with(&[100, 110, 120, 130, 140]).statistics().unwrap();

        assert!((stats.mean_ns - 120.0).abs() < 1e-9);
        assert_eq!(stats.median_ns, 120);
        // Population variance: (400 + 100 + 0 + 100 + 400) / 5 = 200.
        assert!((stats.std_dev_ns - 200f64.sqrt()).abs() < 1e-9);
        assert!(stats.z_score(120).abs() < 1e-9);
        assert!((stats.z_score(140) - 2f64.sqrt()).abs() < 1e-9);
    }

    #[test]
    fn test_z_score_zero_std_dev() {
        let stats = bench_with(&[50, 50, 50]).statistics().unwrap();
        assert_eq!(stats.std_dev_ns, 0.0);
        assert_eq!(stats.z_score(1000), 0.0);
    }

    #[test]
    fn test_even_count_median() {
        let stats = bench_with(&[130, 100, 120, 110]).statistics().unwrap();
        assert_eq!(stats.median_ns, 115);
        assert_eq!(stats.min_ns, 100);
        assert_eq!(stats.max_ns, 130);
    }

    #[test]
    fn test_benchmark_reset() {
        let mut bench = OperationBenchmark::new("test", 10);
        bench.measure(|| {});
        assert_eq!(bench.measurements().len(), 1);

        bench.reset();
        assert_eq!(bench.measurements().len(), 0);
    }

    #[test]
    fn test_insufficient_samples() {
        let bench = OperationBenchmark::new("test", 10);
        assert_eq!(
            bench.statistics().unwrap_err(),
            CtAuditError::InsufficientSamples {
                needed: 1,
                actual: 0
            }
        );
    }

    #[test]
    fn test_benchmark_constant_time_result() {
        assert_eq!(bench_with(&[7, 7, 7]).is_constant_time(0.1), Ok(true));
        assert!(matches!(
            bench_with(&[1, 1000]).is_constant_time(0.1),
            Err(CtAuditError::TimingLeakDetected { .. })
        ));
    }

    #[test]
    fn test_is_constant_time_pass() {
        let stats = TimingStatistics {
            count: 100,
            min_ns: 90,
            max_ns: 110,
            mean_ns: 100.0,
            median_ns: 100,
            std_dev_ns: 3.0,
            coefficient_of_variation: 0.03,
        };

        assert!(stats.is_constant_time(0.05));
    }

    #[test]
    fn test_is_constant_time_fail() {
        let stats = TimingStatistics {
            count: 100,
            min_ns: 50,
            max_ns: 150,
            mean_ns: 100.0,
            median_ns: 100,
            std_dev_ns: 25.0,
            coefficient_of_variation: 0.25,
        };

        assert!(!stats.is_constant_time(0.05));
    }

    #[test]
    #[ignore = "timing-sensitive; run manually on a quiet machine"]
    fn test_detect_leak_none() {
        let auditor = CtAuditor::new("leak_test", 5);

        let has_leak = auditor
            .detect_leak(
                50,
                || {
                    std::hint::black_box(42);
                },
                || {
                    std::hint::black_box(43);
                },
                0.5,
            )
            .unwrap();

        assert!(!has_leak);
    }
}