use crate::errors::{VerifyError, VerifyResult};
use crate::stats;
/// Identifies which of the two input classes a timed invocation belongs to.
///
/// A timing test calls the operation under test once per class in every
/// iteration and then compares the two resulting timing distributions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Class {
    /// First input class.
    Left,
    /// Second input class.
    Right,
}
/// How much of each end of a sorted timing-sample set to discard, in percent
/// (e.g. `5.0` means five percent of the samples).
#[derive(Debug, Clone, Copy)]
pub struct PercentileCrop {
    /// Percentage trimmed from the low (fastest) end of the sorted samples.
    pub low: f64,
    /// Percentage trimmed from the high (slowest) end of the sorted samples.
    pub high: f64,
}

/// The default crop removes nothing from either end.
impl Default for PercentileCrop {
    fn default() -> Self {
        Self::asymmetric(0.0, 0.0)
    }
}

impl PercentileCrop {
    /// Trims the same `percent` from both the fast and the slow end.
    pub fn symmetric(percent: f64) -> Self {
        Self::asymmetric(percent, percent)
    }

    /// Trims `low` percent from the fast end and `high` percent from the slow end.
    pub fn asymmetric(low: f64, high: f64) -> Self {
        PercentileCrop { low, high }
    }
}
/// Outcome of a two-class timing test.
///
/// When produced by `TimingTest`, durations are in nanoseconds and `t_value`
/// is a Welch t-statistic comparing the `Left` and `Right` timings.
#[derive(Debug, Clone)]
pub struct TimingResult {
    /// Name given to the test.
    pub name: String,
    /// Number of timed calls across both classes before any cropping.
    pub samples: usize,
    /// Number of samples (both classes) that remained after percentile cropping.
    pub samples_after_crop: usize,
    /// The t-statistic comparing the two classes.
    pub t_value: f64,
    /// Verdict of the test: `true` means no leak was flagged.
    pub passed: bool,
    /// The `|t|` threshold the verdict was judged against.
    pub threshold: f64,
    /// Mean duration of `Left` calls.
    pub mean_left: f64,
    /// Mean duration of `Right` calls.
    pub mean_right: f64,
    /// Sample standard deviation of `Left` durations.
    pub std_left: f64,
    /// Sample standard deviation of `Right` durations.
    pub std_right: f64,
}

impl TimingResult {
    /// Returns the stored verdict; `true` means no timing leak was flagged.
    pub fn is_constant_time(&self) -> bool {
        self.passed
    }

    /// Magnitude of the t-statistic.
    pub fn abs_t_value(&self) -> f64 {
        self.t_value.abs()
    }

    /// Absolute gap between the class means as a percentage of their average.
    ///
    /// Returns `0.0` when the average is exactly zero, to avoid dividing by it.
    pub fn timing_difference_percent(&self) -> f64 {
        let average = (self.mean_left + self.mean_right) / 2.0;
        if average == 0.0 {
            return 0.0;
        }
        let gap = (self.mean_left - self.mean_right).abs();
        gap / average * 100.0
    }

    /// One-line verdict: name, t-statistic, threshold and PASS/FAIL.
    pub fn summary(&self) -> String {
        let verdict = match self.passed {
            true => "PASS",
            false => "FAIL - TIMING LEAK DETECTED",
        };
        format!(
            "{}: t={:.2} (threshold={:.1}) - {}",
            self.name, self.t_value, self.threshold, verdict
        )
    }

    /// Multi-line report with sample counts, per-class statistics, the relative
    /// mean difference, the t-statistic and the verdict.
    pub fn detailed_report(&self) -> String {
        let verdict = if self.passed {
            "PASS (constant-time)"
        } else {
            "FAIL (timing leak detected)"
        };
        let lines = [
            self.name.clone(),
            format!(
                "Samples: {} (after crop: {})",
                self.samples, self.samples_after_crop
            ),
            format!(
                "Left class: mean={:.2}ns, std={:.2}ns",
                self.mean_left, self.std_left
            ),
            format!(
                "Right class: mean={:.2}ns, std={:.2}ns",
                self.mean_right, self.std_right
            ),
            format!("Difference: {:.4}%", self.timing_difference_percent()),
            format!(
                "t-statistic: {:.4} (threshold: ±{:.1})",
                self.t_value, self.threshold
            ),
            format!("Result: {}", verdict),
        ];
        lines.join("\n")
    }
}

/// Displays the one-line [`TimingResult::summary`].
impl std::fmt::Display for TimingResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(&self.summary())
    }
}
/// Running count/mean/variance accumulator (Welford's online algorithm).
///
/// Uses constant memory regardless of how many values are fed in.
#[derive(Debug, Clone, Default)]
pub struct OnlineStats {
    /// Number of observations seen so far.
    count: usize,
    /// Running mean of the observations.
    mean: f64,
    /// Running sum of squared deviations from the mean.
    m2: f64,
}

impl OnlineStats {
    /// Creates an empty accumulator (same as `OnlineStats::default()`).
    pub fn new() -> Self {
        Self {
            count: 0,
            mean: 0.0,
            m2: 0.0,
        }
    }

    /// Folds one observation `x` into the running statistics.
    pub fn update(&mut self, x: f64) {
        self.count += 1;
        let n = self.count as f64;
        // Deviation from the mean before and after it absorbs `x`; their product
        // is Welford's increment to the sum of squares.
        let before = x - self.mean;
        self.mean += before / n;
        let after = x - self.mean;
        self.m2 += before * after;
    }

    /// Number of observations seen so far.
    pub fn count(&self) -> usize {
        self.count
    }

    /// Running mean (`0.0` before any observation).
    pub fn mean(&self) -> f64 {
        self.mean
    }

    /// Sample variance (divides by `count - 1`); `0.0` with fewer than two values.
    pub fn variance(&self) -> f64 {
        match self.count {
            0 | 1 => 0.0,
            n => self.m2 / (n - 1) as f64,
        }
    }

    /// Sample standard deviation, the square root of [`variance`](Self::variance).
    pub fn std_dev(&self) -> f64 {
        self.variance().sqrt()
    }
}
pub struct TimingTest {
name: String,
iterations: usize,
warmup: usize,
threshold: f64,
percentile_crop: PercentileCrop,
}
impl TimingTest {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
iterations: 10_000,
warmup: 100,
threshold: stats::TIMING_LEAK_THRESHOLD,
percentile_crop: PercentileCrop::default(),
}
}
pub fn iterations(mut self, n: usize) -> Self {
self.iterations = n;
self
}
pub fn warmup(mut self, n: usize) -> Self {
self.warmup = n;
self
}
pub fn threshold(mut self, t: f64) -> Self {
self.threshold = t;
self
}
pub fn with_percentile_cropping(mut self, percent: f64) -> Self {
self.percentile_crop = PercentileCrop::symmetric(percent);
self
}
pub fn with_asymmetric_cropping(mut self, low: f64, high: f64) -> Self {
self.percentile_crop = PercentileCrop::asymmetric(low, high);
self
}
fn crop_samples(&self, samples: &mut Vec<f64>) {
if self.percentile_crop.low == 0.0 && self.percentile_crop.high == 0.0 {
return;
}
samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let n = samples.len();
let low_idx = ((n as f64 * self.percentile_crop.low / 100.0) as usize).min(n / 2);
let high_idx = n - ((n as f64 * self.percentile_crop.high / 100.0) as usize).min(n / 2);
*samples = samples[low_idx..high_idx].to_vec();
}
pub fn run<F, R>(self, mut f: F) -> TimingResult
where
F: FnMut(Class) -> R,
{
use std::time::Instant;
for _ in 0..self.warmup {
let _ = f(Class::Left);
let _ = f(Class::Right);
}
let mut left_times = Vec::with_capacity(self.iterations);
let mut right_times = Vec::with_capacity(self.iterations);
for _ in 0..self.iterations {
let start = Instant::now();
let _result = std::hint::black_box(f(Class::Left));
let elapsed = start.elapsed().as_nanos() as f64;
left_times.push(elapsed);
let start = Instant::now();
let _result = std::hint::black_box(f(Class::Right));
let elapsed = start.elapsed().as_nanos() as f64;
right_times.push(elapsed);
}
let raw_samples = self.iterations * 2;
self.crop_samples(&mut left_times);
self.crop_samples(&mut right_times);
let samples_after_crop = left_times.len() + right_times.len();
let mut left_stats = OnlineStats::new();
for &t in &left_times {
left_stats.update(t);
}
let mut right_stats = OnlineStats::new();
for &t in &right_times {
right_stats.update(t);
}
let t_value = stats::welch_t_test(&left_times, &right_times);
let passed = t_value.abs() < self.threshold;
TimingResult {
name: self.name,
samples: raw_samples,
samples_after_crop,
t_value,
passed,
threshold: self.threshold,
mean_left: left_stats.mean(),
mean_right: right_stats.mean(),
std_left: left_stats.std_dev(),
std_right: right_stats.std_dev(),
}
}
pub fn run_online<F, R>(self, mut f: F) -> TimingResult
where
F: FnMut(Class) -> R,
{
use std::time::Instant;
for _ in 0..self.warmup {
let _ = f(Class::Left);
let _ = f(Class::Right);
}
let mut left_stats = OnlineStats::new();
let mut right_stats = OnlineStats::new();
for _ in 0..self.iterations {
let start = Instant::now();
let _result = std::hint::black_box(f(Class::Left));
let elapsed = start.elapsed().as_nanos() as f64;
left_stats.update(elapsed);
let start = Instant::now();
let _result = std::hint::black_box(f(Class::Right));
let elapsed = start.elapsed().as_nanos() as f64;
right_stats.update(elapsed);
}
let t_value = stats::welch_t_online(&left_stats, &right_stats);
let passed = t_value.abs() < self.threshold;
TimingResult {
name: self.name,
samples: self.iterations * 2,
samples_after_crop: self.iterations * 2, t_value,
passed,
threshold: self.threshold,
mean_left: left_stats.mean(),
mean_right: right_stats.mean(),
std_left: left_stats.std_dev(),
std_right: right_stats.std_dev(),
}
}
}
/// Runs a [`TimingTest`] named `name` for `iterations` iterations (all other
/// settings at their defaults) and turns a failed verdict into an error.
///
/// # Errors
///
/// Returns `VerifyError::TimingLeakDetected` carrying the measured t-statistic
/// and the threshold when the test does not pass.
pub fn assert_constant_time<F, R>(name: &str, iterations: usize, f: F) -> VerifyResult<()>
where
    F: FnMut(Class) -> R,
{
    let outcome = TimingTest::new(name).iterations(iterations).run(f);
    if !outcome.passed {
        return Err(VerifyError::TimingLeakDetected {
            t_value: outcome.t_value,
            threshold: outcome.threshold,
        });
    }
    Ok(())
}
/// Ready-made two-class timing tests for common side-channel patterns.
///
/// All inputs are fixed arrays. They are passed through `std::hint::black_box`
/// so the optimizer cannot pre-compute the operation under test for each class.
pub mod patterns {
    use super::*;

    /// Times `op` on an all-zero 32-byte key (`Left`) versus an all-`0xFF` key
    /// (`Right`).
    pub fn test_key_comparison<F, R>(name: &str, iterations: usize, mut op: F) -> TimingResult
    where
        F: FnMut(&[u8; 32]) -> R,
    {
        let zero_key = [0u8; 32];
        let one_key = [0xFFu8; 32];
        TimingTest::new(name)
            .iterations(iterations)
            .run(move |class| {
                let key = match class {
                    Class::Left => &zero_key,
                    Class::Right => &one_key,
                };
                // Hide the constant key contents from the optimizer.
                op(std::hint::black_box(key))
            })
    }

    /// Times `compare(correct, wrong)` where `correct` is all zeros and `wrong`
    /// differs from it only in the first byte (`Left`) or only in the last byte
    /// (`Right`).
    ///
    /// A comparison that returns at the first mismatching byte would presumably
    /// take longer for `Right`, producing a large `|t|`.
    pub fn test_early_exit<F>(name: &str, iterations: usize, mut compare: F) -> TimingResult
    where
        F: FnMut(&[u8; 32], &[u8; 32]) -> bool,
    {
        let correct = [0u8; 32];
        let mut wrong_first = [0u8; 32];
        wrong_first[0] = 0xFF;
        let mut wrong_last = [0u8; 32];
        wrong_last[31] = 0xFF;
        TimingTest::new(name)
            .iterations(iterations)
            .run(move |class| {
                let wrong = match class {
                    Class::Left => &wrong_first,
                    Class::Right => &wrong_last,
                };
                compare(std::hint::black_box(&correct), std::hint::black_box(wrong))
            })
    }

    /// Times `decrypt` on a 48-byte buffer whose last byte is `0x01` (`Left`,
    /// named "valid padding") versus `0x11` (`Right`, "invalid padding").
    ///
    /// NOTE(review): the valid/invalid split presumably assumes PKCS#7-style
    /// padding with a block size of at most 16 bytes (so a pad length of 17 is
    /// invalid) — confirm against the cipher under test.
    pub fn test_padding_oracle<F, R, E>(
        name: &str,
        iterations: usize,
        mut decrypt: F,
    ) -> TimingResult
    where
        F: FnMut(&[u8]) -> Result<R, E>,
    {
        let mut valid_padding = [0u8; 48];
        valid_padding[47] = 0x01;
        let mut invalid_padding = [0u8; 48];
        invalid_padding[47] = 0x11;
        TimingTest::new(name)
            .iterations(iterations)
            .run(move |class| {
                let data: &[u8] = match class {
                    Class::Left => &valid_padding,
                    Class::Right => &invalid_padding,
                };
                // Keep the decryption result observable so the call cannot be
                // discarded as dead code.
                let _ = std::hint::black_box(decrypt(std::hint::black_box(data)));
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `wrapping_add` on two very different operands should not yield a large
    /// t-statistic. The bound of 10.0 is loose, presumably to tolerate
    /// scheduler noise on shared test machines.
    #[test]
    fn test_constant_time_operation() {
        let result = TimingTest::new("constant_add").iterations(1000).run(|class| {
            let operand = if class == Class::Left { 0u64 } else { u64::MAX };
            std::hint::black_box(operand.wrapping_add(42))
        });
        assert!(
            result.t_value.abs() < 10.0,
            "t-value too high: {}",
            result.t_value
        );
    }

    /// The one-line summary shows the verdict and the report formats means in ns.
    #[test]
    fn test_timing_result_display() {
        let result = TimingResult {
            name: String::from("test"),
            samples: 1000,
            samples_after_crop: 900,
            t_value: 1.5,
            passed: true,
            threshold: 4.5,
            mean_left: 100.0,
            mean_right: 100.5,
            std_left: 10.0,
            std_right: 10.0,
        };
        assert!(result.to_string().contains("PASS"));
        assert!(result.detailed_report().contains("100.00ns"));
    }

    /// Mean and sample variance of 1, 2, 3 are 2 and 1.
    #[test]
    fn test_online_stats() {
        let mut stats = OnlineStats::new();
        for x in [1.0, 2.0, 3.0] {
            stats.update(x);
        }
        assert_eq!(stats.count(), 3);
        assert!((stats.mean() - 2.0).abs() < 0.001);
        assert!((stats.variance() - 1.0).abs() < 0.001);
    }

    /// Cropping 10 % from each end must leave fewer samples than were taken.
    #[test]
    fn test_percentile_cropping() {
        let result = TimingTest::new("test")
            .with_percentile_cropping(10.0)
            .iterations(100)
            .run(|_| 42);
        assert!(result.samples_after_crop < result.samples);
    }

    /// Online mode never crops, so both sample counts agree.
    #[test]
    fn test_online_mode() {
        let result = TimingTest::new("online_test")
            .iterations(1000)
            .run_online(|class| if class == Class::Left { 1u64 } else { 2u64 });
        assert!(result.samples_after_crop == result.samples);
    }

    /// Summing key bytes touches every byte regardless of value, so the
    /// key-comparison pattern should stay within a loose bound.
    #[test]
    fn test_key_comparison_pattern() {
        let result = patterns::test_key_comparison("test_key", 500, |key| {
            key.iter()
                .map(|&b| u64::from(b))
                .fold(0u64, u64::wrapping_add)
        });
        assert!(
            result.t_value.abs() < 20.0,
            "Unexpected timing variation: {}",
            result.t_value
        );
    }
}