use serde::{Deserialize, Serialize};
use std::time::Instant;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SideChannelAnalysis {
pub test_name: String,
pub num_samples: usize,
pub timing_stats: TimingStatistics,
pub is_constant_time: bool,
pub input_timing_correlation: f64,
pub vulnerabilities: Vec<Vulnerability>,
pub leakage_score: f64,
}
impl SideChannelAnalysis {
pub fn is_timing_safe(&self) -> bool {
self.is_constant_time
&& self.input_timing_correlation.abs() < 0.1
&& self.leakage_score < 0.2
}
pub fn get_vulnerabilities(&self) -> &[Vulnerability] {
&self.vulnerabilities
}
pub fn max_severity(&self) -> VulnerabilitySeverity {
self.vulnerabilities
.iter()
.map(|v| match v {
Vulnerability::DataDependentTiming(s)
| Vulnerability::HighTimingVariance(s)
| Vulnerability::InputTimingCorrelation(s)
| Vulnerability::CacheTimingLeak(s) => *s,
})
.max()
.unwrap_or(VulnerabilitySeverity::Low)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimingStatistics {
pub mean: f64,
pub median: f64,
pub std_dev: f64,
pub coefficient_of_variation: f64,
pub min: u64,
pub max: u64,
pub range: u64,
}
impl TimingStatistics {
pub fn from_measurements(mut timings: Vec<u64>) -> Self {
if timings.is_empty() {
return Self {
mean: 0.0,
median: 0.0,
std_dev: 0.0,
coefficient_of_variation: 0.0,
min: 0,
max: 0,
range: 0,
};
}
timings.sort_unstable();
let min = timings[0];
let max = timings[timings.len() - 1];
let range = max - min;
let mean = timings.iter().sum::<u64>() as f64 / timings.len() as f64;
let median = if timings.len() % 2 == 0 {
(timings[timings.len() / 2 - 1] + timings[timings.len() / 2]) as f64 / 2.0
} else {
timings[timings.len() / 2] as f64
};
let variance = timings
.iter()
.map(|&t| (t as f64 - mean).powi(2))
.sum::<f64>()
/ timings.len() as f64;
let std_dev = variance.sqrt();
let coefficient_of_variation = if mean > 0.0 { std_dev / mean } else { 0.0 };
Self {
mean,
median,
std_dev,
coefficient_of_variation,
min,
max,
range,
}
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Vulnerability {
DataDependentTiming(VulnerabilitySeverity),
HighTimingVariance(VulnerabilitySeverity),
InputTimingCorrelation(VulnerabilitySeverity),
CacheTimingLeak(VulnerabilitySeverity),
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum VulnerabilitySeverity {
Low,
Medium,
High,
Critical,
}
pub struct TimingTest {
name: String,
num_samples: usize,
input_generator: Box<dyn Fn() -> Vec<u8>>,
}
impl TimingTest {
pub fn new(name: &str, num_samples: usize) -> Self {
Self {
name: name.to_string(),
num_samples,
input_generator: Box::new(|| {
use rand::Rng as _;
let mut rng = rand::rng();
let mut data = vec![0u8; 32];
rng.fill_bytes(&mut data);
data
}),
}
}
pub fn with_input_generator<F>(mut self, generator: F) -> Self
where
F: Fn() -> Vec<u8> + 'static,
{
self.input_generator = Box::new(generator);
self
}
pub fn name(&self) -> &str {
&self.name
}
pub fn num_samples(&self) -> usize {
self.num_samples
}
pub fn generate_input(&self) -> Vec<u8> {
(self.input_generator)()
}
}
pub struct SideChannelAnalyzer {
constant_time_threshold: f64,
correlation_threshold: f64,
}
impl Default for SideChannelAnalyzer {
fn default() -> Self {
Self::new()
}
}
impl SideChannelAnalyzer {
pub fn new() -> Self {
Self {
constant_time_threshold: 0.05, correlation_threshold: 0.15, }
}
pub fn with_constant_time_threshold(mut self, threshold: f64) -> Self {
self.constant_time_threshold = threshold;
self
}
pub fn with_correlation_threshold(mut self, threshold: f64) -> Self {
self.correlation_threshold = threshold;
self
}
pub fn analyze_timing<F>(&self, test: TimingTest, mut operation: F) -> SideChannelAnalysis
where
F: FnMut(&[u8]),
{
let mut timings = Vec::with_capacity(test.num_samples());
let mut inputs = Vec::with_capacity(test.num_samples());
for _ in 0..test.num_samples() {
let input = test.generate_input();
let start = Instant::now();
operation(&input);
let elapsed = start.elapsed();
timings.push(elapsed.as_nanos() as u64);
inputs.push(input);
}
let timing_stats = TimingStatistics::from_measurements(timings.clone());
let is_constant_time = timing_stats.coefficient_of_variation < self.constant_time_threshold;
let input_timing_correlation = self.calculate_correlation(&inputs, &timings);
let mut vulnerabilities = Vec::new();
if !is_constant_time {
let severity = if timing_stats.coefficient_of_variation > 0.2 {
VulnerabilitySeverity::Critical
} else if timing_stats.coefficient_of_variation > 0.1 {
VulnerabilitySeverity::High
} else {
VulnerabilitySeverity::Medium
};
vulnerabilities.push(Vulnerability::DataDependentTiming(severity));
}
if timing_stats.coefficient_of_variation > 0.1 {
let severity = if timing_stats.coefficient_of_variation > 0.3 {
VulnerabilitySeverity::High
} else {
VulnerabilitySeverity::Medium
};
vulnerabilities.push(Vulnerability::HighTimingVariance(severity));
}
if input_timing_correlation.abs() > self.correlation_threshold {
let severity = if input_timing_correlation.abs() > 0.5 {
VulnerabilitySeverity::Critical
} else if input_timing_correlation.abs() > 0.3 {
VulnerabilitySeverity::High
} else {
VulnerabilitySeverity::Medium
};
vulnerabilities.push(Vulnerability::InputTimingCorrelation(severity));
}
let leakage_score = self.calculate_leakage_score(&timing_stats, input_timing_correlation);
SideChannelAnalysis {
test_name: test.name().to_string(),
num_samples: test.num_samples(),
timing_stats,
is_constant_time,
input_timing_correlation,
vulnerabilities,
leakage_score,
}
}
fn calculate_correlation(&self, inputs: &[Vec<u8>], timings: &[u64]) -> f64 {
if inputs.is_empty() || inputs.len() != timings.len() {
return 0.0;
}
let input_values: Vec<f64> = inputs.iter().map(|inp| inp[0] as f64).collect();
let timing_values: Vec<f64> = timings.iter().map(|&t| t as f64).collect();
pearson_correlation(&input_values, &timing_values)
}
fn calculate_leakage_score(&self, stats: &TimingStatistics, correlation: f64) -> f64 {
let cv_score = (stats.coefficient_of_variation / 0.5).min(1.0);
let corr_score = (correlation.abs() / 0.5).min(1.0);
(cv_score + corr_score) / 2.0
}
}
fn pearson_correlation(x: &[f64], y: &[f64]) -> f64 {
if x.len() != y.len() || x.is_empty() {
return 0.0;
}
let n = x.len() as f64;
let mean_x = x.iter().sum::<f64>() / n;
let mean_y = y.iter().sum::<f64>() / n;
let mut numerator = 0.0;
let mut sum_sq_x = 0.0;
let mut sum_sq_y = 0.0;
for i in 0..x.len() {
let diff_x = x[i] - mean_x;
let diff_y = y[i] - mean_y;
numerator += diff_x * diff_y;
sum_sq_x += diff_x * diff_x;
sum_sq_y += diff_y * diff_y;
}
if sum_sq_x == 0.0 || sum_sq_y == 0.0 {
return 0.0;
}
numerator / (sum_sq_x * sum_sq_y).sqrt()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_timing_statistics() {
let timings = vec![100, 105, 102, 98, 101, 99, 103, 100];
let stats = TimingStatistics::from_measurements(timings);
assert_eq!(stats.min, 98);
assert_eq!(stats.max, 105);
assert_eq!(stats.range, 7);
assert!((stats.mean - 101.0).abs() < 0.5);
assert!(stats.std_dev > 0.0);
}
#[test]
fn test_timing_statistics_empty() {
let stats = TimingStatistics::from_measurements(vec![]);
assert_eq!(stats.mean, 0.0);
assert_eq!(stats.median, 0.0);
assert_eq!(stats.std_dev, 0.0);
}
#[test]
fn test_pearson_correlation_perfect_positive() {
let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
let corr = pearson_correlation(&x, &y);
assert!((corr - 1.0).abs() < 0.01);
}
#[test]
fn test_pearson_correlation_perfect_negative() {
let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let y = vec![10.0, 8.0, 6.0, 4.0, 2.0];
let corr = pearson_correlation(&x, &y);
assert!((corr + 1.0).abs() < 0.01);
}
#[test]
fn test_pearson_correlation_no_correlation() {
let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
let y = vec![3.0, 3.0, 3.0, 3.0, 3.0];
let corr = pearson_correlation(&x, &y);
assert_eq!(corr, 0.0);
}
#[test]
fn test_timing_test_creation() {
let test = TimingTest::new("test", 100);
assert_eq!(test.name(), "test");
assert_eq!(test.num_samples(), 100);
}
#[test]
fn test_timing_test_input_generation() {
let test = TimingTest::new("test", 10);
let input1 = test.generate_input();
let input2 = test.generate_input();
assert_eq!(input1.len(), 32);
assert_eq!(input2.len(), 32);
assert_ne!(input1, input2);
}
#[test]
fn test_timing_test_custom_generator() {
let test = TimingTest::new("test", 10).with_input_generator(|| vec![0u8; 16]);
let input = test.generate_input();
assert_eq!(input.len(), 16);
assert_eq!(input, vec![0u8; 16]);
}
#[test]
fn test_analyzer_constant_time_operation() {
let analyzer = SideChannelAnalyzer::new();
let test = TimingTest::new("constant_op", 50);
let results = analyzer.analyze_timing(test, |_data| {
std::hint::black_box(42);
});
assert_eq!(results.test_name, "constant_op");
assert_eq!(results.num_samples, 50);
assert!(results.timing_stats.mean > 0.0);
assert!(results.input_timing_correlation.abs() < 0.5);
}
#[test]
fn test_analyzer_data_dependent_timing() {
let analyzer = SideChannelAnalyzer::new();
let test = TimingTest::new("data_dependent_op", 100);
let results = analyzer.analyze_timing(test, |data| {
let iterations = data[0] as usize * 10;
for _ in 0..iterations {
std::hint::black_box(42);
}
});
assert!(!results.is_constant_time);
assert!(!results.vulnerabilities.is_empty());
}
#[test]
fn test_side_channel_analysis_timing_safe() {
let analysis = SideChannelAnalysis {
test_name: "test".to_string(),
num_samples: 100,
timing_stats: TimingStatistics {
mean: 1000.0,
median: 1000.0,
std_dev: 10.0,
coefficient_of_variation: 0.01,
min: 990,
max: 1010,
range: 20,
},
is_constant_time: true,
input_timing_correlation: 0.05,
vulnerabilities: vec![],
leakage_score: 0.05,
};
assert!(analysis.is_timing_safe());
}
#[test]
fn test_side_channel_analysis_timing_unsafe() {
let analysis = SideChannelAnalysis {
test_name: "test".to_string(),
num_samples: 100,
timing_stats: TimingStatistics {
mean: 1000.0,
median: 1000.0,
std_dev: 200.0,
coefficient_of_variation: 0.2,
min: 500,
max: 1500,
range: 1000,
},
is_constant_time: false,
input_timing_correlation: 0.5,
vulnerabilities: vec![Vulnerability::DataDependentTiming(
VulnerabilitySeverity::High,
)],
leakage_score: 0.6,
};
assert!(!analysis.is_timing_safe());
}
#[test]
fn test_vulnerability_severity_ordering() {
assert!(VulnerabilitySeverity::Low < VulnerabilitySeverity::Medium);
assert!(VulnerabilitySeverity::Medium < VulnerabilitySeverity::High);
assert!(VulnerabilitySeverity::High < VulnerabilitySeverity::Critical);
}
#[test]
fn test_max_severity() {
let analysis = SideChannelAnalysis {
test_name: "test".to_string(),
num_samples: 100,
timing_stats: TimingStatistics::from_measurements(vec![100]),
is_constant_time: false,
input_timing_correlation: 0.0,
vulnerabilities: vec![
Vulnerability::DataDependentTiming(VulnerabilitySeverity::Medium),
Vulnerability::HighTimingVariance(VulnerabilitySeverity::Critical),
],
leakage_score: 0.5,
};
assert_eq!(analysis.max_severity(), VulnerabilitySeverity::Critical);
}
#[test]
fn test_analyzer_custom_thresholds() {
let analyzer = SideChannelAnalyzer::new()
.with_constant_time_threshold(0.1)
.with_correlation_threshold(0.2);
assert_eq!(analyzer.constant_time_threshold, 0.1);
assert_eq!(analyzer.correlation_threshold, 0.2);
}
}