use crate::errors::Result;
use ring::rand::SecureRandom;
use std::time::Duration;
use subtle::ConstantTimeEq;
/// Compares two byte slices for equality without short-circuiting on the first differing byte.
///
/// Slices of different lengths are rejected immediately, so timing may reveal whether the
/// lengths match but not where the contents differ. Equal-length slices are compared with
/// [`subtle::ConstantTimeEq`].
pub fn constant_time_compare(a: &[u8], b: &[u8]) -> bool {
    let same_len = a.len() == b.len();
    same_len && bool::from(a.ct_eq(b))
}
/// Sleeps for `base_delay_ms` plus a random jitter in `[0, max_random_ms)` milliseconds.
///
/// The jitter is drawn from the operating system's CSPRNG via `ring::rand::SystemRandom`.
/// When `max_random_ms` is `0`, no jitter is added and the RNG is not consulted. The
/// previous code computed `x % 0` in that case and panicked.
///
/// # Panics
///
/// Panics if the system random number generator fails to produce bytes.
pub async fn random_delay(base_delay_ms: u64, max_random_ms: u64) {
    let random_ms = if max_random_ms == 0 {
        0
    } else {
        let rng = ring::rand::SystemRandom::new();
        let mut buf = [0u8; 8];
        rng.fill(&mut buf).expect("system RNG failure");
        // Modulo bias is at most max_random_ms / 2^64, which is negligible for ms-scale jitter.
        u64::from_le_bytes(buf) % max_random_ms
    };
    // Each addend is at most u64::MAX ms (~1.8e16 s), so the sum cannot overflow `Duration`.
    let total_delay = Duration::from_millis(base_delay_ms) + Duration::from_millis(random_ms);
    tokio::time::sleep(total_delay).await;
}
/// Compares two strings by their UTF-8 bytes using [`constant_time_compare`].
pub fn constant_time_string_compare(a: &str, b: &str) -> bool {
    let (lhs, rhs) = (a.as_bytes(), b.as_bytes());
    constant_time_compare(lhs, rhs)
}
/// Runs `operation` and pads its wall-clock time up to at least `min_duration_ms`.
///
/// The operation's result, success or error, is returned unchanged, but only after any
/// padding sleep. Operations that already took at least the minimum are not delayed further.
///
/// # Errors
///
/// Returns whatever error `operation` produced.
pub async fn timing_safe_operation<T, F, Fut>(operation: F, min_duration_ms: u64) -> Result<T>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = Result<T>>,
{
    let floor = Duration::from_millis(min_duration_ms);
    let began_at = std::time::Instant::now();
    let outcome = operation().await;

    // Sleep only when a strictly positive remainder exists (i.e. elapsed < floor).
    if let Some(pad) = floor
        .checked_sub(began_at.elapsed())
        .filter(|remaining| !remaining.is_zero())
    {
        tokio::time::sleep(pad).await;
    }

    outcome
}
/// Runs `operation` behind the timing mitigations used for RSA operations.
///
/// First calls [`random_delay`] with a 1 ms base and a jitter bound of 5 ms. Then it runs
/// `operation` through [`timing_safe_operation`], so the operation phase lasts at least 10 ms.
///
/// # Errors
///
/// Returns whatever error `operation` produced.
pub async fn rsa_operation_protected<T, F, Fut>(operation: F) -> Result<T>
where
    F: FnOnce() -> Fut,
    Fut: std::future::Future<Output = Result<T>>,
{
    const JITTER_BASE_MS: u64 = 1;
    const JITTER_MAX_MS: u64 = 5;
    const MIN_DURATION_MS: u64 = 10;

    random_delay(JITTER_BASE_MS, JITTER_MAX_MS).await;
    timing_safe_operation(operation, MIN_DURATION_MS).await
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_constant_time_compare() {
        let a = b"hello";
        let b = b"hello";
        let c = b"world";
        assert!(constant_time_compare(a, b));
        assert!(!constant_time_compare(a, c));
        assert!(!constant_time_compare(a, b"hi"));
        // Empty inputs: equal when both are empty, unequal otherwise.
        assert!(constant_time_compare(b"", b""));
        assert!(!constant_time_compare(b"", a));
    }

    #[test]
    fn test_constant_time_string_compare() {
        assert!(constant_time_string_compare("hello", "hello"));
        assert!(!constant_time_string_compare("hello", "world"));
        assert!(!constant_time_string_compare("hello", "hi"));
        assert!(constant_time_string_compare("", ""));
    }

    #[tokio::test]
    async fn test_random_delay() {
        let start = std::time::Instant::now();
        random_delay(0, 5).await;
        // Jitter is below 5 ms; the generous bound absorbs scheduler/timer slack.
        assert!(start.elapsed() < Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_timing_safe_operation() {
        let start = std::time::Instant::now();
        let result = timing_safe_operation(
            || async { Ok::<_, crate::errors::AuthError>("success") },
            50,
        )
        .await;
        let elapsed = start.elapsed();
        assert!(result.is_ok());
        assert!(elapsed >= Duration::from_millis(50));
    }

    #[tokio::test]
    async fn test_rsa_operation_protected_enforces_minimum() {
        let start = std::time::Instant::now();
        let result =
            rsa_operation_protected(|| async { Ok::<_, crate::errors::AuthError>(42u32) }).await;
        let elapsed = start.elapsed();
        assert_eq!(result.ok(), Some(42u32));
        assert!(elapsed >= Duration::from_millis(10));
    }
}