1use serde::{Deserialize, Serialize};
38use std::time::Instant;
39
/// Outcome of timing one operation over many generated inputs.
///
/// Produced by [`SideChannelAnalyzer::analyze_timing`]; all fields are public, so it can
/// also be constructed directly (e.g. in tests).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SideChannelAnalysis {
    /// Name of the [`TimingTest`] that produced this analysis.
    pub test_name: String,
    /// Number of timed invocations of the operation.
    pub num_samples: usize,
    /// Summary statistics over the per-invocation timings (nanoseconds when produced by
    /// the analyzer).
    pub timing_stats: TimingStatistics,
    /// Whether the timings' coefficient of variation fell below the analyzer's
    /// constant-time threshold.
    pub is_constant_time: bool,
    /// Pearson correlation between the first byte of each input and its timing
    /// (`0.0` when either series is constant).
    pub input_timing_correlation: f64,
    /// Issues flagged by the analyzer, in detection order.
    pub vulnerabilities: Vec<Vulnerability>,
    /// Combined leakage indicator in `[0, 1]`: the average of `cv / 0.5` and
    /// `|correlation| / 0.5`, each capped at `1.0`.
    pub leakage_score: f64,
}
58
59impl SideChannelAnalysis {
60 pub fn is_timing_safe(&self) -> bool {
62 self.is_constant_time
63 && self.input_timing_correlation.abs() < 0.1
64 && self.leakage_score < 0.2
65 }
66
67 pub fn get_vulnerabilities(&self) -> &[Vulnerability] {
69 &self.vulnerabilities
70 }
71
72 pub fn max_severity(&self) -> VulnerabilitySeverity {
74 self.vulnerabilities
75 .iter()
76 .map(|v| match v {
77 Vulnerability::DataDependentTiming(s)
78 | Vulnerability::HighTimingVariance(s)
79 | Vulnerability::InputTimingCorrelation(s)
80 | Vulnerability::CacheTimingLeak(s) => *s,
81 })
82 .max()
83 .unwrap_or(VulnerabilitySeverity::Low)
84 }
85}
86
/// Summary statistics over a set of timing measurements.
///
/// Computed by [`TimingStatistics::from_measurements`]; every field is zero for an empty
/// input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TimingStatistics {
    /// Arithmetic mean of the measurements.
    pub mean: f64,
    /// Median; the average of the two middle values when the count is even.
    pub median: f64,
    /// Population standard deviation (divides by `n`, not `n - 1`).
    pub std_dev: f64,
    /// `std_dev / mean`, or `0.0` when the mean is not positive.
    pub coefficient_of_variation: f64,
    /// Smallest measurement.
    pub min: u64,
    /// Largest measurement.
    pub max: u64,
    /// `max - min`.
    pub range: u64,
}
105
106impl TimingStatistics {
107 pub fn from_measurements(mut timings: Vec<u64>) -> Self {
109 if timings.is_empty() {
110 return Self {
111 mean: 0.0,
112 median: 0.0,
113 std_dev: 0.0,
114 coefficient_of_variation: 0.0,
115 min: 0,
116 max: 0,
117 range: 0,
118 };
119 }
120
121 timings.sort_unstable();
122 let min = timings[0];
123 let max = timings[timings.len() - 1];
124 let range = max - min;
125
126 let mean = timings.iter().sum::<u64>() as f64 / timings.len() as f64;
127 let median = if timings.len() % 2 == 0 {
128 (timings[timings.len() / 2 - 1] + timings[timings.len() / 2]) as f64 / 2.0
129 } else {
130 timings[timings.len() / 2] as f64
131 };
132
133 let variance = timings
134 .iter()
135 .map(|&t| (t as f64 - mean).powi(2))
136 .sum::<f64>()
137 / timings.len() as f64;
138 let std_dev = variance.sqrt();
139 let coefficient_of_variation = if mean > 0.0 { std_dev / mean } else { 0.0 };
140
141 Self {
142 mean,
143 median,
144 std_dev,
145 coefficient_of_variation,
146 min,
147 max,
148 range,
149 }
150 }
151}
152
/// A class of timing side-channel finding, tagged with its severity.
///
/// The derived `Ord` compares the variant first (declaration order) and then the
/// severity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum Vulnerability {
    /// The timings' coefficient of variation was not below the analyzer's
    /// constant-time threshold.
    DataDependentTiming(VulnerabilitySeverity),
    /// The timings' coefficient of variation exceeded `0.1`.
    HighTimingVariance(VulnerabilitySeverity),
    /// The absolute input/timing correlation exceeded the analyzer's correlation
    /// threshold.
    InputTimingCorrelation(VulnerabilitySeverity),
    /// Cache-timing leak. Not emitted by `SideChannelAnalyzer::analyze_timing`;
    /// NOTE(review): presumably reserved for another detector — confirm.
    CacheTimingLeak(VulnerabilitySeverity),
}
165
/// Severity attached to a [`Vulnerability`].
///
/// The derived `Ord` follows declaration order: `Low < Medium < High < Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub enum VulnerabilitySeverity {
    /// Lowest level.
    ///
    /// Also what `SideChannelAnalysis::max_severity` returns when nothing was recorded.
    /// `analyze_timing` never assigns it.
    Low,
    /// Moderate finding.
    Medium,
    /// Serious finding.
    High,
    /// Most severe level.
    Critical,
}
178
/// A named timing experiment: how many samples to take and how to generate the input
/// passed to the operation for each sample.
pub struct TimingTest {
    /// Label copied into [`SideChannelAnalysis::test_name`].
    name: String,
    /// Number of times the operation under test is invoked and timed.
    num_samples: usize,
    /// Produces one input per sample; by default 32 random bytes (see
    /// [`TimingTest::new`]).
    input_generator: Box<dyn Fn() -> Vec<u8>>,
}
188
189impl TimingTest {
190 pub fn new(name: &str, num_samples: usize) -> Self {
192 Self {
193 name: name.to_string(),
194 num_samples,
195 input_generator: Box::new(|| {
196 use rand::RngCore;
197 let mut rng = rand::thread_rng();
198 let mut data = vec![0u8; 32];
199 rng.fill_bytes(&mut data);
200 data
201 }),
202 }
203 }
204
205 pub fn with_input_generator<F>(mut self, generator: F) -> Self
207 where
208 F: Fn() -> Vec<u8> + 'static,
209 {
210 self.input_generator = Box::new(generator);
211 self
212 }
213
214 pub fn name(&self) -> &str {
216 &self.name
217 }
218
219 pub fn num_samples(&self) -> usize {
221 self.num_samples
222 }
223
224 pub fn generate_input(&self) -> Vec<u8> {
226 (self.input_generator)()
227 }
228}
229
/// Times an operation over many generated inputs and flags timing behavior that may
/// depend on those inputs.
pub struct SideChannelAnalyzer {
    /// Coefficient-of-variation cutoff: a run whose CV is below this is considered
    /// constant time.
    constant_time_threshold: f64,
    /// Absolute input/timing correlation above which `InputTimingCorrelation` is
    /// reported.
    correlation_threshold: f64,
}
237
238impl Default for SideChannelAnalyzer {
239 fn default() -> Self {
240 Self::new()
241 }
242}
243
244impl SideChannelAnalyzer {
245 pub fn new() -> Self {
247 Self {
248 constant_time_threshold: 0.05, correlation_threshold: 0.15, }
251 }
252
253 pub fn with_constant_time_threshold(mut self, threshold: f64) -> Self {
255 self.constant_time_threshold = threshold;
256 self
257 }
258
259 pub fn with_correlation_threshold(mut self, threshold: f64) -> Self {
261 self.correlation_threshold = threshold;
262 self
263 }
264
265 pub fn analyze_timing<F>(&self, test: TimingTest, mut operation: F) -> SideChannelAnalysis
267 where
268 F: FnMut(&[u8]),
269 {
270 let mut timings = Vec::with_capacity(test.num_samples());
271 let mut inputs = Vec::with_capacity(test.num_samples());
272
273 for _ in 0..test.num_samples() {
275 let input = test.generate_input();
276 let start = Instant::now();
277 operation(&input);
278 let elapsed = start.elapsed();
279 timings.push(elapsed.as_nanos() as u64);
280 inputs.push(input);
281 }
282
283 let timing_stats = TimingStatistics::from_measurements(timings.clone());
284
285 let is_constant_time = timing_stats.coefficient_of_variation < self.constant_time_threshold;
287
288 let input_timing_correlation = self.calculate_correlation(&inputs, &timings);
290
291 let mut vulnerabilities = Vec::new();
293
294 if !is_constant_time {
295 let severity = if timing_stats.coefficient_of_variation > 0.2 {
296 VulnerabilitySeverity::Critical
297 } else if timing_stats.coefficient_of_variation > 0.1 {
298 VulnerabilitySeverity::High
299 } else {
300 VulnerabilitySeverity::Medium
301 };
302 vulnerabilities.push(Vulnerability::DataDependentTiming(severity));
303 }
304
305 if timing_stats.coefficient_of_variation > 0.1 {
306 let severity = if timing_stats.coefficient_of_variation > 0.3 {
307 VulnerabilitySeverity::High
308 } else {
309 VulnerabilitySeverity::Medium
310 };
311 vulnerabilities.push(Vulnerability::HighTimingVariance(severity));
312 }
313
314 if input_timing_correlation.abs() > self.correlation_threshold {
315 let severity = if input_timing_correlation.abs() > 0.5 {
316 VulnerabilitySeverity::Critical
317 } else if input_timing_correlation.abs() > 0.3 {
318 VulnerabilitySeverity::High
319 } else {
320 VulnerabilitySeverity::Medium
321 };
322 vulnerabilities.push(Vulnerability::InputTimingCorrelation(severity));
323 }
324
325 let leakage_score = self.calculate_leakage_score(&timing_stats, input_timing_correlation);
327
328 SideChannelAnalysis {
329 test_name: test.name().to_string(),
330 num_samples: test.num_samples(),
331 timing_stats,
332 is_constant_time,
333 input_timing_correlation,
334 vulnerabilities,
335 leakage_score,
336 }
337 }
338
339 fn calculate_correlation(&self, inputs: &[Vec<u8>], timings: &[u64]) -> f64 {
341 if inputs.is_empty() || inputs.len() != timings.len() {
342 return 0.0;
343 }
344
345 let input_values: Vec<f64> = inputs.iter().map(|inp| inp[0] as f64).collect();
347 let timing_values: Vec<f64> = timings.iter().map(|&t| t as f64).collect();
348
349 pearson_correlation(&input_values, &timing_values)
350 }
351
352 fn calculate_leakage_score(&self, stats: &TimingStatistics, correlation: f64) -> f64 {
354 let cv_score = (stats.coefficient_of_variation / 0.5).min(1.0);
356 let corr_score = (correlation.abs() / 0.5).min(1.0);
357
358 (cv_score + corr_score) / 2.0
359 }
360}
361
/// Pearson correlation coefficient of two equal-length samples.
///
/// Returns `0.0` when the slices are empty, differ in length, or either one has zero
/// variance. The last case yields `0.0` rather than the NaN a direct division would
/// produce.
fn pearson_correlation(x: &[f64], y: &[f64]) -> f64 {
    let len = x.len();
    if len == 0 || len != y.len() {
        return 0.0;
    }

    let count = len as f64;
    let centre_x = x.iter().sum::<f64>() / count;
    let centre_y = y.iter().sum::<f64>() / count;

    // Accumulate the covariance numerator and both sums of squared deviations in one
    // pass, in index order.
    let (covariance, spread_x, spread_y) = x.iter().zip(y).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(cov, sx, sy), (&xi, &yi)| {
            let dx = xi - centre_x;
            let dy = yi - centre_y;
            (cov + dx * dy, sx + dx * dx, sy + dy * dy)
        },
    );

    if spread_x == 0.0 || spread_y == 0.0 {
        0.0
    } else {
        covariance / (spread_x * spread_y).sqrt()
    }
}
390
// Unit tests for the statistics, correlation, test-configuration and analyzer types.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_timing_statistics() {
        // Sum is 808 over 8 samples, so the mean is exactly 101.
        let timings = vec![100, 105, 102, 98, 101, 99, 103, 100];
        let stats = TimingStatistics::from_measurements(timings);

        assert_eq!(stats.min, 98);
        assert_eq!(stats.max, 105);
        assert_eq!(stats.range, 7);
        assert!((stats.mean - 101.0).abs() < 0.5);
        assert!(stats.std_dev > 0.0);
    }

    #[test]
    fn test_timing_statistics_empty() {
        // An empty input short-circuits to all-zero statistics.
        let stats = TimingStatistics::from_measurements(vec![]);
        assert_eq!(stats.mean, 0.0);
        assert_eq!(stats.median, 0.0);
        assert_eq!(stats.std_dev, 0.0);
    }

    #[test]
    fn test_pearson_correlation_perfect_positive() {
        // y = 2x: a perfect positive linear relationship.
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![2.0, 4.0, 6.0, 8.0, 10.0];
        let corr = pearson_correlation(&x, &y);
        assert!((corr - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_pearson_correlation_perfect_negative() {
        // y decreases linearly as x increases.
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![10.0, 8.0, 6.0, 4.0, 2.0];
        let corr = pearson_correlation(&x, &y);
        assert!((corr + 1.0).abs() < 0.01);
    }

    #[test]
    fn test_pearson_correlation_no_correlation() {
        // A constant y has zero variance, which pearson_correlation maps to 0.0
        // instead of NaN.
        let x = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let y = vec![3.0, 3.0, 3.0, 3.0, 3.0];
        let corr = pearson_correlation(&x, &y);
        assert_eq!(corr, 0.0);
    }

    #[test]
    fn test_timing_test_creation() {
        let test = TimingTest::new("test", 100);
        assert_eq!(test.name(), "test");
        assert_eq!(test.num_samples(), 100);
    }

    #[test]
    fn test_timing_test_input_generation() {
        let test = TimingTest::new("test", 10);
        let input1 = test.generate_input();
        let input2 = test.generate_input();

        assert_eq!(input1.len(), 32);
        assert_eq!(input2.len(), 32);
        // The default generator yields 32 random bytes; two identical draws are
        // astronomically unlikely (1 in 2^256).
        assert_ne!(input1, input2);
    }

    #[test]
    fn test_timing_test_custom_generator() {
        // A custom generator fully replaces the default random one.
        let test = TimingTest::new("test", 10).with_input_generator(|| vec![0u8; 16]);
        let input = test.generate_input();
        assert_eq!(input.len(), 16);
        assert_eq!(input, vec![0u8; 16]);
    }

    #[test]
    fn test_analyzer_constant_time_operation() {
        let analyzer = SideChannelAnalyzer::new();
        let test = TimingTest::new("constant_op", 50);

        let results = analyzer.analyze_timing(test, |_data| {
            // Ignores its input, so timing should not depend on the data.
            std::hint::black_box(42);
        });

        // NOTE(review): is_constant_time is not asserted, presumably because clock and
        // scheduler noise can push the CV over the 0.05 default — confirm intent.
        assert_eq!(results.test_name, "constant_op");
        assert_eq!(results.num_samples, 50);
        assert!(results.timing_stats.mean > 0.0);
        // Loose bound: only a strong spurious correlation should fail this.
        assert!(results.input_timing_correlation.abs() < 0.5);
    }

    #[test]
    fn test_analyzer_data_dependent_timing() {
        let analyzer = SideChannelAnalyzer::new();
        let test = TimingTest::new("data_dependent_op", 100);

        let results = analyzer.analyze_timing(test, |data| {
            // Work scales with the first input byte: 0 to 2550 loop iterations.
            let iterations = data[0] as usize * 10;
            for _ in 0..iterations {
                std::hint::black_box(42);
            }
        });

        // NOTE(review): relies on real wall-clock measurements, so the outcome may be
        // sensitive to machine load.
        assert!(!results.is_constant_time);
        assert!(!results.vulnerabilities.is_empty());
    }

    #[test]
    fn test_side_channel_analysis_timing_safe() {
        // Hand-built analysis that satisfies all three is_timing_safe conditions.
        let analysis = SideChannelAnalysis {
            test_name: "test".to_string(),
            num_samples: 100,
            timing_stats: TimingStatistics {
                mean: 1000.0,
                median: 1000.0,
                std_dev: 10.0,
                coefficient_of_variation: 0.01,
                min: 990,
                max: 1010,
                range: 20,
            },
            is_constant_time: true,
            input_timing_correlation: 0.05,
            vulnerabilities: vec![],
            leakage_score: 0.05,
        };

        assert!(analysis.is_timing_safe());
    }

    #[test]
    fn test_side_channel_analysis_timing_unsafe() {
        // Fails every is_timing_safe condition.
        let analysis = SideChannelAnalysis {
            test_name: "test".to_string(),
            num_samples: 100,
            timing_stats: TimingStatistics {
                mean: 1000.0,
                median: 1000.0,
                std_dev: 200.0,
                coefficient_of_variation: 0.2,
                min: 500,
                max: 1500,
                range: 1000,
            },
            is_constant_time: false,
            input_timing_correlation: 0.5,
            vulnerabilities: vec![Vulnerability::DataDependentTiming(
                VulnerabilitySeverity::High,
            )],
            leakage_score: 0.6,
        };

        assert!(!analysis.is_timing_safe());
    }

    #[test]
    fn test_vulnerability_severity_ordering() {
        // The derived Ord follows declaration order.
        assert!(VulnerabilitySeverity::Low < VulnerabilitySeverity::Medium);
        assert!(VulnerabilitySeverity::Medium < VulnerabilitySeverity::High);
        assert!(VulnerabilitySeverity::High < VulnerabilitySeverity::Critical);
    }

    #[test]
    fn test_max_severity() {
        // max_severity considers the severity inside every variant.
        let analysis = SideChannelAnalysis {
            test_name: "test".to_string(),
            num_samples: 100,
            timing_stats: TimingStatistics::from_measurements(vec![100]),
            is_constant_time: false,
            input_timing_correlation: 0.0,
            vulnerabilities: vec![
                Vulnerability::DataDependentTiming(VulnerabilitySeverity::Medium),
                Vulnerability::HighTimingVariance(VulnerabilitySeverity::Critical),
            ],
            leakage_score: 0.5,
        };

        assert_eq!(analysis.max_severity(), VulnerabilitySeverity::Critical);
    }

    #[test]
    fn test_analyzer_custom_thresholds() {
        // Builders override the defaults (0.05 / 0.15).
        let analyzer = SideChannelAnalyzer::new()
            .with_constant_time_threshold(0.1)
            .with_correlation_threshold(0.2);

        assert_eq!(analyzer.constant_time_threshold, 0.1);
        assert_eq!(analyzer.correlation_threshold, 0.2);
    }
}