use core::future::Future;
use core::sync::atomic::{
Ordering,
compiler_fence,
};
/// Configuration for padding the execution time of a protected call.
///
/// When `enabled`, the `protect*` methods busy-wait after the wrapped call
/// returns until at least `target_duration_ns` has elapsed since just before
/// the call started. When disabled, the wrapped call runs directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TimingProtection {
/// Whether padding is applied at all.
pub enabled: bool,
/// Minimum duration a protected call is padded to, in units of the internal
/// clock: nanoseconds on `std` (non-wasm) and `wasm` builds; on builds with
/// neither, the fallback clock counts clock reads instead of nanoseconds.
pub target_duration_ns: u64,
}
impl Default for TimingProtection {
    /// Protection enabled with a 1 000 ns target duration.
    fn default() -> Self {
        let target_duration_ns = 1_000;
        Self {
            target_duration_ns,
            enabled: true,
        }
    }
}
impl TimingProtection {
    /// Returns the default configuration; equivalent to [`TimingProtection::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Enabled preset that pads each protected call to at least 5 000 ns.
    pub const fn strict() -> Self {
        Self {
            enabled: true,
            target_duration_ns: 5_000,
        }
    }

    /// Disabled preset: protected calls run with no padding at all.
    pub const fn permissive() -> Self {
        Self {
            enabled: false,
            target_duration_ns: 0,
        }
    }

    /// Enabled preset that pads each protected call to at least 1 000 ns.
    pub const fn balanced() -> Self {
        Self {
            enabled: true,
            target_duration_ns: 1_000,
        }
    }

    /// Runs `func` and, when enabled, busy-waits until at least
    /// `target_duration_ns` has elapsed since just before `func` was called.
    ///
    /// Padding is only a lower bound: if `func` already took longer than the
    /// target, its result is returned immediately, so such calls are not
    /// equalized. When disabled, `func` is called directly.
    pub fn protect<F, R>(&self, func: F) -> R
    where
        F: FnOnce() -> R,
    {
        if !self.enabled {
            return func();
        }
        let start = Self::timestamp_ns();
        let result = func();
        self.pad_from(start, result)
    }

    /// Async counterpart of [`TimingProtection::protect`].
    ///
    /// The clock starts before `func` is called and its future awaited, so time
    /// spent suspended counts toward the target. The padding itself is a
    /// synchronous spin inside this future: it occupies the polling thread
    /// (without yielding to the executor) for up to `target_duration_ns`.
    pub async fn protect_async<F, Fut, R>(&self, func: F) -> R
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = R>,
    {
        if !self.enabled {
            return func().await;
        }
        let start = Self::timestamp_ns();
        let result = func().await;
        self.pad_from(start, result)
    }

    /// Like [`TimingProtection::protect`], but also returns the total elapsed
    /// time (including any padding) as read from the internal clock.
    ///
    /// The elapsed time is measured even when protection is disabled.
    pub fn protect_with_timing<F, R>(&self, func: F) -> (R, u64)
    where
        F: FnOnce() -> R,
    {
        let start = Self::timestamp_ns();
        let result = func();
        let result = if self.enabled {
            self.pad_from(start, result)
        } else {
            result
        };
        (result, Self::timestamp_ns().wrapping_sub(start))
    }

    /// Async counterpart of [`TimingProtection::protect_with_timing`]; see
    /// [`TimingProtection::protect_async`] for how the padding interacts with
    /// the executor.
    pub async fn protect_with_timing_async<F, Fut, R>(&self, func: F) -> (R, u64)
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = R>,
    {
        let start = Self::timestamp_ns();
        let result = func().await;
        let result = if self.enabled {
            self.pad_from(start, result)
        } else {
            result
        };
        (result, Self::timestamp_ns().wrapping_sub(start))
    }

    /// Shared tail of every enabled `protect*` path.
    ///
    /// Marks `result` as used and fences, then spins until `target_duration_ns`
    /// has passed since `start`, and finally hands `result` back unchanged.
    #[inline]
    fn pad_from<R>(&self, start: u64, result: R) -> R {
        // black_box + compiler fence are intended to stop the optimizer from
        // moving the work that produced `result` across the padding spin.
        let result = core::hint::black_box(result);
        compiler_fence(Ordering::SeqCst);
        Self::spin_until(start, self.target_duration_ns);
        result
    }

    /// Current reading of the internal clock.
    ///
    /// The source depends on the build configuration:
    /// - `std` feature on a non-wasm target: nanoseconds since the first call
    ///   (a lazily initialised `Instant` epoch);
    /// - `wasm32` with the `wasm` feature: `performance.now()` scaled to
    ///   nanoseconds, falling back to the tick counter when unavailable;
    /// - otherwise: the tick counter, which advances by one per call and is
    ///   therefore not a time source at all.
    #[inline]
    fn timestamp_ns() -> u64 {
        #[cfg(all(feature = "std", not(target_arch = "wasm32")))]
        {
            use std::sync::OnceLock;
            use std::time::Instant;
            static EPOCH: OnceLock<Instant> = OnceLock::new();
            let epoch = EPOCH.get_or_init(Instant::now);
            // The u128 -> u64 truncation only wraps after ~584 years of uptime,
            // and every consumer compares readings with `wrapping_sub`, so a
            // wrap is harmless (saturating here would instead stall `spin_until`).
            epoch.elapsed().as_nanos() as u64
        }
        #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
        {
            Self::wasm_performance_now_ns()
        }
        #[cfg(not(any(
            all(feature = "std", not(target_arch = "wasm32")),
            all(target_arch = "wasm32", feature = "wasm"),
        )))]
        {
            Self::monotonic_tick_counter()
        }
    }

    /// Reads `globalThis.performance.now()` (milliseconds) and converts it to
    /// nanoseconds, falling back to the tick counter whenever `performance` is
    /// missing, is not a `Performance` object, or yields a non-finite or
    /// negative value.
    #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
    #[inline]
    fn wasm_performance_now_ns() -> u64 {
        use wasm_bindgen::JsCast;
        let global = js_sys::global();
        let Ok(perf_val) =
            js_sys::Reflect::get(&global, &wasm_bindgen::JsValue::from_str("performance"))
        else {
            return Self::monotonic_tick_counter();
        };
        if perf_val.is_null() || perf_val.is_undefined() {
            return Self::monotonic_tick_counter();
        }
        let Ok(perf) = perf_val.dyn_into::<web_sys::Performance>() else {
            return Self::monotonic_tick_counter();
        };
        let ms = perf.now();
        if !ms.is_finite() || ms < 0.0 {
            return Self::monotonic_tick_counter();
        }
        (ms * 1_000_000.0) as u64
    }

    /// Fallback "clock": a single process-wide counter incremented by one on
    /// every call. Because it is shared, reads from other threads also advance
    /// it, which shortens any padding spin running concurrently.
    #[cfg_attr(all(feature = "std", not(target_arch = "wasm32")), allow(dead_code))]
    #[inline]
    fn monotonic_tick_counter() -> u64 {
        use core::sync::atomic::AtomicU64;
        static COUNTER: AtomicU64 = AtomicU64::new(0);
        COUNTER.fetch_add(1, Ordering::SeqCst)
    }

    /// Busy-waits until at least `duration_ns` clock units have elapsed since
    /// `start`. Returns immediately if that much time has already passed.
    #[inline(never)]
    fn spin_until(start: u64, duration_ns: u64) {
        while Self::timestamp_ns().wrapping_sub(start) < duration_ns {
            core::hint::spin_loop();
        }
        compiler_fence(Ordering::SeqCst);
    }
}
#[cfg(feature = "std")]
use std::sync::{
Arc,
RwLock,
};
/// Process-wide configuration, read by `get_timing_protection` and written by
/// `set_timing_protection`. With `std` it is created on first access.
/// NOTE(review): the `Arc` adds nothing for a `static` (a plain
/// `OnceLock<RwLock<_>>` would do); kept because the accessors rely on this type.
#[cfg(feature = "std")]
static GLOBAL_TIMING_PROTECTION: std::sync::OnceLock<Arc<RwLock<TimingProtection>>> =
std::sync::OnceLock::new();
/// Without `std`: a spin mutex, lazily initialised to `TimingProtection::default()`.
#[cfg(not(feature = "std"))]
static GLOBAL_TIMING_PROTECTION: once_cell::sync::Lazy<spin::Mutex<TimingProtection>> =
once_cell::sync::Lazy::new(|| spin::Mutex::new(TimingProtection::default()));
/// Returns a copy of the current process-wide [`TimingProtection`] configuration.
///
/// With the `std` feature the configuration is initialised to
/// [`TimingProtection::default`] on first access.
pub fn get_timing_protection() -> TimingProtection {
    #[cfg(feature = "std")]
    {
        let lock = GLOBAL_TIMING_PROTECTION
            .get_or_init(|| Arc::new(RwLock::new(TimingProtection::default())));
        // The guarded value is plain `Copy` data, so poisoning cannot leave it
        // half-updated. Return what is stored rather than silently substituting
        // the default (which would, e.g., re-enable protection after a caller
        // configured `permissive`).
        *lock.read().unwrap_or_else(std::sync::PoisonError::into_inner)
    }
    #[cfg(not(feature = "std"))]
    {
        *GLOBAL_TIMING_PROTECTION.lock()
    }
}
/// Replaces the process-wide [`TimingProtection`] configuration.
///
/// Later calls to [`get_timing_protection`] and the `protect_timing*` helpers
/// observe `protection`.
pub fn set_timing_protection(protection: TimingProtection) {
    #[cfg(feature = "std")]
    {
        // Fetch-or-initialise in one step. A separate `get()` followed by `set()`
        // loses the update whenever another thread initialises the cell (e.g. via
        // `get_timing_protection`) between the two calls, because the failed
        // `set` cannot be retried from here.
        let lock = GLOBAL_TIMING_PROTECTION.get_or_init(|| Arc::new(RwLock::new(protection)));
        // Writing unconditionally is harmless when we just initialised the cell.
        // A poisoned lock cannot hold a torn `Copy` value, so recover instead of
        // dropping the update.
        *lock
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = protection;
    }
    #[cfg(not(feature = "std"))]
    {
        *GLOBAL_TIMING_PROTECTION.lock() = protection;
    }
}
/// Runs `func` under the process-wide configuration.
///
/// The configuration is read once via [`get_timing_protection`] before `func`
/// runs; see [`TimingProtection::protect`] for the padding behavior.
pub fn protect_timing<F, R>(func: F) -> R
where
    F: FnOnce() -> R,
{
    let config = get_timing_protection();
    config.protect(func)
}
/// Async form of [`protect_timing`].
///
/// The process-wide configuration is read when the returned future is first
/// polled; see [`TimingProtection::protect_async`] for the padding behavior.
pub async fn protect_timing_async<F, Fut, R>(func: F) -> R
where
    F: FnOnce() -> Fut,
    Fut: Future<Output = R>,
{
    let config = get_timing_protection();
    config.protect_async(func).await
}
/// Runs `func` under the process-wide configuration and also reports the
/// elapsed clock time, including any padding.
///
/// See [`TimingProtection::protect_with_timing`].
pub fn protect_timing_with_timing<F, R>(func: F) -> (R, u64)
where
    F: FnOnce() -> R,
{
    let config = get_timing_protection();
    config.protect_with_timing(func)
}
/// Async form of [`protect_timing_with_timing`].
///
/// The process-wide configuration is read when the returned future is first
/// polled; see [`TimingProtection::protect_with_timing_async`].
pub async fn protect_timing_with_timing_async<F, Fut, R>(func: F) -> (R, u64)
where
    F: FnOnce() -> Fut,
    Fut: Future<Output = R>,
{
    let config = get_timing_protection();
    config.protect_with_timing_async(func).await
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_timing_protection_defaults() {
        let protection = TimingProtection::default();
        assert!(protection.enabled);
        assert_eq!(protection.target_duration_ns, 1_000);
    }

    #[test]
    fn test_timing_protection_strict() {
        let protection = TimingProtection::strict();
        assert!(protection.enabled);
        assert_eq!(protection.target_duration_ns, 5_000);
    }

    #[test]
    fn test_timing_protection_permissive() {
        let protection = TimingProtection::permissive();
        assert!(!protection.enabled);
        assert_eq!(protection.target_duration_ns, 0);
    }

    #[test]
    fn test_timing_protection_balanced() {
        let protection = TimingProtection::balanced();
        assert!(protection.enabled);
        assert_eq!(protection.target_duration_ns, 1_000);
    }

    #[test]
    fn test_protect() {
        let protection = TimingProtection::new();
        let result = protection.protect(|| 42);
        assert_eq!(result, 42);
    }

    #[test]
    fn test_protect_disabled() {
        let protection = TimingProtection::permissive();
        assert_eq!(protection.protect(|| 42), 42);
        let (result, _elapsed) = protection.protect_with_timing(|| 42);
        assert_eq!(result, 42);
    }

    #[test]
    fn test_protect_with_timing() {
        let protection = TimingProtection::new();
        let (result, elapsed) = protection.protect_with_timing(|| 42);
        assert_eq!(result, 42);
        // Enabled protection spins until the target has elapsed, so the
        // reported time can never fall short of it.
        assert!(elapsed >= protection.target_duration_ns);
    }

    #[test]
    fn test_global_timing_protection() {
        let result = protect_timing(|| 42);
        assert_eq!(result, 42);
    }

    #[test]
    fn test_global_timing_protection_with_timing() {
        let (result, elapsed) = protect_timing_with_timing(|| 42);
        assert_eq!(result, 42);
        assert!(elapsed > 0);
    }

    #[test]
    fn test_global_timing_protection_config() {
        let config = get_timing_protection();
        assert_eq!(config, TimingProtection::default());

        let new_config = TimingProtection::strict();
        set_timing_protection(new_config);
        // The setter must actually take effect, and the global helpers must
        // honour the new target.
        assert_eq!(get_timing_protection(), new_config);
        assert_eq!(protect_timing(|| 42), 42);
        let (result, elapsed) = protect_timing_with_timing(|| 42);
        assert_eq!(result, 42);
        assert!(elapsed >= new_config.target_duration_ns);

        // Restore the shared global so later observers see the default again.
        set_timing_protection(TimingProtection::default());
        assert_eq!(get_timing_protection(), TimingProtection::default());
    }
}