lib_q_aead/security/
timing.rs1use core::future::Future;
33use core::sync::atomic::{
34 Ordering,
35 compiler_fence,
36};
37
/// Settings for padding how long a protected operation appears to take.
///
/// The `protect*` methods run a closure and, when `enabled`, busy-wait until
/// `target_duration_ns` clock units have passed since just before the closure
/// started.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct TimingProtection {
    /// When `false`, the `protect*` methods run the closure without any padding.
    pub enabled: bool,
    /// Minimum duration to pad up to.
    ///
    /// This is nanoseconds when a real clock is compiled in. On the fallback
    /// tick-counter clock, each unit is one clock read rather than a nanosecond.
    pub target_duration_ns: u64,
}
51
52impl Default for TimingProtection {
53 fn default() -> Self {
54 Self {
55 enabled: true,
56 target_duration_ns: 1_000, }
58 }
59}
60
61impl TimingProtection {
62 pub fn new() -> Self {
64 Self::default()
65 }
66
67 pub fn strict() -> Self {
69 Self {
70 enabled: true,
71 target_duration_ns: 5_000,
72 }
73 }
74
75 pub fn permissive() -> Self {
77 Self {
78 enabled: false,
79 target_duration_ns: 0,
80 }
81 }
82
83 pub fn balanced() -> Self {
85 Self {
86 enabled: true,
87 target_duration_ns: 1_000,
88 }
89 }
90
91 pub fn protect<F, R>(&self, func: F) -> R
94 where
95 F: FnOnce() -> R,
96 {
97 if !self.enabled {
98 return func();
99 }
100
101 let start = Self::timestamp_ns();
102 let result = func();
103 let result = core::hint::black_box(result);
104 compiler_fence(Ordering::SeqCst);
105
106 Self::spin_until(start, self.target_duration_ns);
107 result
108 }
109
110 pub async fn protect_async<F, Fut, R>(&self, func: F) -> R
112 where
113 F: FnOnce() -> Fut,
114 Fut: Future<Output = R>,
115 {
116 if !self.enabled {
117 return func().await;
118 }
119
120 let start = Self::timestamp_ns();
121 let result = func().await;
122 let result = core::hint::black_box(result);
123 compiler_fence(Ordering::SeqCst);
124
125 Self::spin_until(start, self.target_duration_ns);
126 result
127 }
128
129 pub fn protect_with_timing<F, R>(&self, func: F) -> (R, u64)
137 where
138 F: FnOnce() -> R,
139 {
140 let start = Self::timestamp_ns();
141
142 if !self.enabled {
143 let result = func();
144 let elapsed = Self::timestamp_ns().wrapping_sub(start);
145 return (result, elapsed);
146 }
147
148 let result = func();
149 let result = core::hint::black_box(result);
150 compiler_fence(Ordering::SeqCst);
151
152 Self::spin_until(start, self.target_duration_ns);
153
154 let elapsed = Self::timestamp_ns().wrapping_sub(start);
155 (result, elapsed)
156 }
157
158 pub async fn protect_with_timing_async<F, Fut, R>(&self, func: F) -> (R, u64)
163 where
164 F: FnOnce() -> Fut,
165 Fut: Future<Output = R>,
166 {
167 let start = Self::timestamp_ns();
168
169 if !self.enabled {
170 let result = func().await;
171 let elapsed = Self::timestamp_ns().wrapping_sub(start);
172 return (result, elapsed);
173 }
174
175 let result = func().await;
176 let result = core::hint::black_box(result);
177 compiler_fence(Ordering::SeqCst);
178
179 Self::spin_until(start, self.target_duration_ns);
180
181 let elapsed = Self::timestamp_ns().wrapping_sub(start);
182 (result, elapsed)
183 }
184
185 #[inline]
192 fn timestamp_ns() -> u64 {
193 #[cfg(all(feature = "std", not(target_arch = "wasm32")))]
194 {
195 use std::sync::OnceLock;
196 use std::time::Instant;
197 static EPOCH: OnceLock<Instant> = OnceLock::new();
198 let epoch = EPOCH.get_or_init(Instant::now);
199 epoch.elapsed().as_nanos() as u64
200 }
201
202 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
203 {
204 Self::wasm_performance_now_ns()
205 }
206
207 #[cfg(not(any(
208 all(feature = "std", not(target_arch = "wasm32")),
209 all(target_arch = "wasm32", feature = "wasm"),
210 )))]
211 {
212 Self::monotonic_tick_counter()
213 }
214 }
215
216 #[cfg(all(target_arch = "wasm32", feature = "wasm"))]
221 #[inline]
222 fn wasm_performance_now_ns() -> u64 {
223 use wasm_bindgen::JsCast;
224
225 let global = js_sys::global();
226 let Ok(perf_val) =
227 js_sys::Reflect::get(&global, &wasm_bindgen::JsValue::from_str("performance"))
228 else {
229 return Self::monotonic_tick_counter();
230 };
231 if perf_val.is_null() || perf_val.is_undefined() {
232 return Self::monotonic_tick_counter();
233 }
234 let Ok(perf) = perf_val.dyn_into::<web_sys::Performance>() else {
235 return Self::monotonic_tick_counter();
236 };
237 let ms = perf.now();
238 if !ms.is_finite() || ms < 0.0 {
239 return Self::monotonic_tick_counter();
240 }
241 (ms * 1_000_000.0) as u64
242 }
243
244 #[cfg_attr(all(feature = "std", not(target_arch = "wasm32")), allow(dead_code))]
248 #[inline]
249 fn monotonic_tick_counter() -> u64 {
250 use core::sync::atomic::AtomicU64;
251 static COUNTER: AtomicU64 = AtomicU64::new(0);
252 COUNTER.fetch_add(1, Ordering::SeqCst)
253 }
254
255 #[inline(never)]
260 fn spin_until(start: u64, duration_ns: u64) {
261 while Self::timestamp_ns().wrapping_sub(start) < duration_ns {
262 core::hint::spin_loop();
263 }
264 compiler_fence(Ordering::SeqCst);
265 }
266}
267
268#[cfg(feature = "std")]
273use std::sync::{
274 Arc,
275 RwLock,
276};
277
278#[cfg(feature = "std")]
279static GLOBAL_TIMING_PROTECTION: std::sync::OnceLock<Arc<RwLock<TimingProtection>>> =
280 std::sync::OnceLock::new();
281#[cfg(not(feature = "std"))]
282static GLOBAL_TIMING_PROTECTION: once_cell::sync::Lazy<spin::Mutex<TimingProtection>> =
283 once_cell::sync::Lazy::new(|| spin::Mutex::new(TimingProtection::default()));
284
285pub fn get_timing_protection() -> TimingProtection {
287 #[cfg(feature = "std")]
288 {
289 GLOBAL_TIMING_PROTECTION
290 .get_or_init(|| Arc::new(RwLock::new(TimingProtection::default())))
291 .read()
292 .map(|guard| *guard)
293 .unwrap_or_else(|_| TimingProtection::default())
294 }
295 #[cfg(not(feature = "std"))]
296 {
297 *GLOBAL_TIMING_PROTECTION.lock()
298 }
299}
300
301pub fn set_timing_protection(protection: TimingProtection) {
303 #[cfg(feature = "std")]
304 {
305 if let Some(global_protection) = GLOBAL_TIMING_PROTECTION.get() {
306 if let Ok(mut global) = global_protection.write() {
307 *global = protection;
308 }
309 } else {
310 let _ = GLOBAL_TIMING_PROTECTION.set(Arc::new(RwLock::new(protection)));
311 }
312 }
313 #[cfg(not(feature = "std"))]
314 {
315 *GLOBAL_TIMING_PROTECTION.lock() = protection;
316 }
317}
318
319pub fn protect_timing<F, R>(func: F) -> R
321where
322 F: FnOnce() -> R,
323{
324 get_timing_protection().protect(func)
325}
326
327pub async fn protect_timing_async<F, Fut, R>(func: F) -> R
329where
330 F: FnOnce() -> Fut,
331 Fut: Future<Output = R>,
332{
333 get_timing_protection().protect_async(func).await
334}
335
336pub fn protect_timing_with_timing<F, R>(func: F) -> (R, u64)
340where
341 F: FnOnce() -> R,
342{
343 get_timing_protection().protect_with_timing(func)
344}
345
346pub async fn protect_timing_with_timing_async<F, Fut, R>(func: F) -> (R, u64)
348where
349 F: FnOnce() -> Fut,
350 Fut: Future<Output = R>,
351{
352 get_timing_protection()
353 .protect_with_timing_async(func)
354 .await
355}
356
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_timing_protection_defaults() {
        let protection = TimingProtection::default();
        assert!(protection.enabled);
        assert_eq!(protection.target_duration_ns, 1_000);
        assert_eq!(TimingProtection::new(), protection);
    }

    #[test]
    fn test_timing_protection_strict() {
        let protection = TimingProtection::strict();
        assert!(protection.enabled);
        assert_eq!(protection.target_duration_ns, 5_000);
    }

    #[test]
    fn test_timing_protection_permissive() {
        let protection = TimingProtection::permissive();
        assert!(!protection.enabled);
        assert_eq!(protection.target_duration_ns, 0);
    }

    #[test]
    fn test_timing_protection_balanced() {
        let protection = TimingProtection::balanced();
        assert!(protection.enabled);
        assert_eq!(protection.target_duration_ns, 1_000);
    }

    #[test]
    fn test_protect() {
        let protection = TimingProtection::new();
        let result = protection.protect(|| 42);
        assert_eq!(result, 42);
    }

    #[test]
    fn test_protect_disabled() {
        let protection = TimingProtection::permissive();
        assert_eq!(protection.protect(|| 42), 42);
        let (result, _elapsed) = protection.protect_with_timing(|| 42);
        assert_eq!(result, 42);
    }

    #[test]
    fn test_protect_with_timing() {
        let protection = TimingProtection::new();
        let (result, elapsed) = protection.protect_with_timing(|| 42);
        assert_eq!(result, 42);
        // The padding spin only exits once the target has elapsed, so the reported
        // duration can never be shorter than the target.
        assert!(elapsed >= protection.target_duration_ns);
    }

    #[test]
    fn test_global_timing_protection() {
        let result = protect_timing(|| 42);
        assert_eq!(result, 42);
    }

    #[test]
    fn test_global_timing_protection_with_timing() {
        let (result, elapsed) = protect_timing_with_timing(|| 42);
        assert_eq!(result, 42);
        assert!(elapsed > 0);
    }

    #[test]
    fn test_global_timing_protection_config() {
        let config = get_timing_protection();
        assert_eq!(config, TimingProtection::default());

        let new_config = TimingProtection::strict();
        set_timing_protection(new_config);
        // The setter must actually take effect.
        assert_eq!(get_timing_protection(), new_config);

        assert_eq!(protect_timing(|| 42), 42);
        let (result, elapsed) = protect_timing_with_timing(|| 42);
        assert_eq!(result, 42);
        assert!(elapsed >= new_config.target_duration_ns);

        // Restore the default so other tests sharing the global observe a known configuration.
        set_timing_protection(TimingProtection::default());
        assert_eq!(get_timing_protection(), TimingProtection::default());
    }
}