1use core::{num::NonZeroU32, time::Duration};
33
34use spin::Mutex;
35
36use super::LimitError;
37
38#[cfg(test)]
39use std::sync::{Mutex as StdMutex, MutexGuard as StdMutexGuard};
40
/// Limit settings for an [`ExecutionTimer`].
#[derive(Debug)]
#[derive(Clone, Copy)]
#[derive(PartialEq, Eq)]
pub struct ExecutionTimerConfig {
    /// Maximum permitted elapsed time; only an elapsed time strictly greater
    /// than this value is reported as exceeded.
    pub limit: Duration,
    /// Number of work units [`ExecutionTimer::tick`] accumulates before it
    /// reads the clock again.
    pub check_interval: NonZeroU32,
}
56
/// Tracks elapsed time against an optional [`ExecutionTimerConfig`] limit,
/// reading the clock only once every `check_interval` units of reported work.
#[derive(Debug)]
pub struct ExecutionTimer {
    /// Limit settings; `None` disables every check.
    config: Option<ExecutionTimerConfig>,
    /// Timestamp passed to `start`; `None` before starting and after `reset`.
    start: Option<Duration>,
    /// Work units reported through `tick` since the last clock check.
    accumulated_units: u32,
    /// Elapsed time measured by the most recent clock check, or the value
    /// supplied to `resume_from_elapsed`.
    last_elapsed: Duration,
}
65
/// A provider of monotonic timestamps, expressed as a `Duration` measured from
/// an origin chosen by the implementation.
///
/// Implementations must be shareable across threads, because a registered
/// source is stored behind a global lock.
pub trait TimeSource
where
    Self: Send + Sync,
{
    /// Returns the current reading, or `None` if the source has none to offer.
    fn now(&self) -> Option<Duration>;
}
71
/// Default [`TimeSource`] for `std` builds, backed by [`std::time::Instant`].
///
/// Readings are measured from a process-wide anchor captured on the first call,
/// so the very first reading is close to zero.
#[cfg(feature = "std")]
#[derive(Debug)]
struct StdTimeSource;

#[cfg(feature = "std")]
impl StdTimeSource {
    /// Returns the stateless source; `const` so it can initialise a `static`.
    const fn new() -> Self {
        StdTimeSource
    }
}

#[cfg(feature = "std")]
impl TimeSource for StdTimeSource {
    fn now(&self) -> Option<Duration> {
        use std::{sync::OnceLock, time::Instant};

        // One anchor shared by every caller, so all readings use the same origin.
        static ANCHOR: OnceLock<Instant> = OnceLock::new();
        let origin = *ANCHOR.get_or_init(Instant::now);
        Some(Instant::now().saturating_duration_since(origin))
    }
}

/// Process-wide instance read by `monotonic_now` in `std` builds.
#[cfg(feature = "std")]
static STD_TIME_SOURCE: StdTimeSource = StdTimeSource::new();
96
/// Registered time-source override, consulted first by `monotonic_now` in test
/// builds and in builds without the `std` feature. Written by `set_time_source`;
/// the tests in this file also replace and restore it directly. Guarded by the
/// `spin::Mutex` imported at the top of the file.
#[cfg(any(test, not(feature = "std")))]
static TIME_SOURCE_OVERRIDE: Mutex<Option<&'static dyn TimeSource>> = Mutex::new(None);
99
/// Error returned by `set_time_source` when an override is already registered.
#[cfg(any(test, not(feature = "std")))]
#[derive(Debug, Eq, PartialEq, Copy, Clone)]
pub enum TimeSourceRegistrationError {
    /// A source was already registered; the existing one was left in place.
    AlreadySet,
}

#[cfg(any(test, not(feature = "std")))]
impl core::fmt::Display for TimeSourceRegistrationError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let message = match *self {
            TimeSourceRegistrationError::AlreadySet => "time source already configured",
        };
        f.write_str(message)
    }
}

/// Lets the error be boxed as a `dyn Error` and propagated with `?`.
#[cfg(any(test, not(feature = "std")))]
impl core::error::Error for TimeSourceRegistrationError {}
117
/// Process-wide fallback timer configuration. Written by
/// `set_fallback_execution_timer_config`, read by
/// `fallback_execution_timer_config`, and initially `None`.
static FALLBACK_EXECUTION_TIMER_CONFIG: Mutex<Option<ExecutionTimerConfig>> = Mutex::new(None);
119
/// Lock that serialises tests touching shared global state in this module
/// (for example, the time-source override).
#[cfg(test)]
static LIMITS_TEST_LOCK: StdMutex<()> = StdMutex::new(());

/// Acquires [`LIMITS_TEST_LOCK`] for the duration of a test.
///
/// If a previous holder panicked and poisoned the lock, the guard is recovered
/// instead of propagating the poison, so one failing test does not cascade.
#[cfg(test)]
pub fn acquire_limits_test_lock() -> StdMutexGuard<'static, ()> {
    match LIMITS_TEST_LOCK.lock() {
        Ok(guard) => guard,
        Err(poisoned) => poisoned.into_inner(),
    }
}
129
130pub fn monotonic_now() -> Option<Duration> {
132 #[cfg(any(test, not(feature = "std")))]
133 if let Some(source) = {
136 let guard = TIME_SOURCE_OVERRIDE.lock();
137 *guard
138 } {
139 if let Some(duration) = source.now() {
140 return Some(duration);
141 }
142 }
143
144 #[cfg(feature = "std")]
145 {
146 STD_TIME_SOURCE.now()
147 }
148
149 #[cfg(not(feature = "std"))]
150 {
151 None
152 }
153}
154
155#[cfg(any(test, not(feature = "std")))]
156pub fn set_time_source(source: &'static dyn TimeSource) -> Result<(), TimeSourceRegistrationError> {
157 let mut slot = TIME_SOURCE_OVERRIDE.lock();
158 if slot.is_some() {
159 Err(TimeSourceRegistrationError::AlreadySet)
160 } else {
161 *slot = Some(source);
162 Ok(())
163 }
164}
165
166pub fn set_fallback_execution_timer_config(config: Option<ExecutionTimerConfig>) {
185 *FALLBACK_EXECUTION_TIMER_CONFIG.lock() = config;
186}
187
188pub fn fallback_execution_timer_config() -> Option<ExecutionTimerConfig> {
199 let guard = FALLBACK_EXECUTION_TIMER_CONFIG.lock();
200 guard.as_ref().copied()
201}
202
203impl ExecutionTimer {
204 pub const fn new(config: Option<ExecutionTimerConfig>) -> Self {
206 Self {
207 config,
208 start: None,
209 accumulated_units: 0,
210 last_elapsed: Duration::ZERO,
211 }
212 }
213
214 pub const fn reset(&mut self) {
216 self.start = None;
217 self.accumulated_units = 0;
218 self.last_elapsed = Duration::ZERO;
219 }
220
221 pub const fn start(&mut self, now: Duration) {
223 self.start = Some(now);
224 self.accumulated_units = 0;
225 self.last_elapsed = Duration::ZERO;
226 }
227
228 pub const fn config(&self) -> Option<ExecutionTimerConfig> {
230 self.config
231 }
232
233 pub const fn limit(&self) -> Option<Duration> {
235 match self.config {
236 Some(config) => Some(config.limit),
237 None => None,
238 }
239 }
240
241 pub const fn last_elapsed(&self) -> Duration {
243 self.last_elapsed
244 }
245
246 pub fn tick(&mut self, work_units: u32, now: Duration) -> Result<(), LimitError> {
248 let Some(config) = self.config else {
249 return Ok(());
250 };
251 self.accumulated_units = self.accumulated_units.saturating_add(work_units);
252 if self.accumulated_units < config.check_interval.get() {
253 return Ok(());
254 }
255
256 let interval = config.check_interval.get();
258 self.accumulated_units %= interval;
259 self.check_now(now)
260 }
261
262 pub fn check_now(&mut self, now: Duration) -> Result<(), LimitError> {
264 let Some(config) = self.config else {
265 return Ok(());
266 };
267 let Some(start) = self.start else {
268 return Ok(());
269 };
270
271 let elapsed = now.checked_sub(start).unwrap_or(Duration::ZERO);
272 self.last_elapsed = elapsed;
273 if elapsed > config.limit {
274 return Err(LimitError::TimeLimitExceeded {
275 elapsed,
276 limit: config.limit,
277 });
278 }
279 Ok(())
280 }
281
282 pub fn elapsed(&self, now: Duration) -> Option<Duration> {
284 let start = self.start?;
285 Some(now.checked_sub(start).unwrap_or(Duration::ZERO))
286 }
287
288 pub const fn resume_from_elapsed(&mut self, now: Duration, elapsed: Duration) {
291 if self.config.is_none() {
292 return;
293 }
294
295 self.start = Some(now.saturating_sub(elapsed));
296 self.last_elapsed = elapsed;
297 }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use core::{
        num::NonZeroU32,
        sync::atomic::{AtomicU64, Ordering},
        time::Duration,
    };

    /// Builds a check interval. Panics on zero so a typo in a test cannot
    /// silently become an interval of 1.
    fn nz(value: u32) -> NonZeroU32 {
        NonZeroU32::new(value).expect("test check interval must be non-zero")
    }

    #[test]
    fn tick_defers_checks_until_interval_is_reached() {
        let mut timer = ExecutionTimer::new(Some(ExecutionTimerConfig {
            limit: Duration::from_millis(100),
            check_interval: nz(4),
        }));

        timer.start(Duration::from_millis(0));

        for step in 1..4u64 {
            let now = Duration::from_millis(step * 10);
            let result = timer.tick(1, now);
            assert_eq!(result, Ok(()), "tick before reaching interval must succeed");
            assert_eq!(timer.last_elapsed(), Duration::ZERO);
        }

        let result = timer.tick(1, Duration::from_millis(40));
        assert_eq!(result, Ok(()), "tick at interval boundary must succeed");
        assert_eq!(timer.last_elapsed(), Duration::from_millis(40));
    }

    #[test]
    fn check_now_reports_limit_exceeded() {
        let mut timer = ExecutionTimer::new(Some(ExecutionTimerConfig {
            limit: Duration::from_millis(25),
            check_interval: nz(1),
        }));

        timer.start(Duration::from_millis(0));
        assert_eq!(
            timer.tick(1, Duration::from_millis(10)),
            Ok(()),
            "tick before limit breach must succeed"
        );

        let result = timer.check_now(Duration::from_millis(30));
        assert_eq!(
            result,
            Err(LimitError::TimeLimitExceeded {
                elapsed: Duration::from_millis(30),
                limit: Duration::from_millis(25),
            })
        );
        assert_eq!(timer.last_elapsed(), Duration::from_millis(30));
    }

    #[test]
    fn tick_reports_limit_exceeded() {
        let mut timer = ExecutionTimer::new(Some(ExecutionTimerConfig {
            limit: Duration::from_millis(30),
            check_interval: nz(2),
        }));

        timer.start(Duration::from_millis(0));
        assert_eq!(
            timer.tick(1, Duration::from_millis(10)),
            Ok(()),
            "initial tick must succeed"
        );

        let result = timer.tick(1, Duration::from_millis(35));
        assert_eq!(
            result,
            Err(LimitError::TimeLimitExceeded {
                elapsed: Duration::from_millis(35),
                limit: Duration::from_millis(30),
            })
        );
        assert_eq!(timer.last_elapsed(), Duration::from_millis(35));
    }

    #[test]
    fn tick_before_start_is_noop() {
        let mut timer = ExecutionTimer::new(Some(ExecutionTimerConfig {
            limit: Duration::from_secs(1),
            check_interval: nz(1),
        }));

        let result = timer.tick(1, Duration::from_millis(100));
        assert_eq!(result, Ok(()), "tick before start should be ignored");
        assert_eq!(timer.last_elapsed(), Duration::ZERO);
        assert!(timer.elapsed(Duration::from_millis(200)).is_none());
    }

    #[test]
    fn check_now_allows_elapsed_equal_to_limit() {
        let mut timer = ExecutionTimer::new(Some(ExecutionTimerConfig {
            limit: Duration::from_millis(50),
            check_interval: nz(1),
        }));

        timer.start(Duration::from_millis(0));
        assert_eq!(
            timer.tick(1, Duration::from_millis(30)),
            Ok(()),
            "tick prior to equality check must succeed"
        );
        let result = timer.check_now(Duration::from_millis(50));
        assert_eq!(result, Ok(()), "elapsed equal to limit must not fail");
        assert_eq!(timer.last_elapsed(), Duration::from_millis(50));
    }

    #[test]
    fn tick_is_noop_when_limit_disabled() {
        let mut timer = ExecutionTimer::new(None);

        timer.start(Duration::from_millis(0));

        for millis in 1..=8u64 {
            assert_eq!(
                timer.tick(1, Duration::from_millis(millis)),
                Ok(()),
                "ticks with disabled limit must succeed"
            );
        }

        assert_eq!(timer.last_elapsed(), Duration::ZERO);
    }

    #[test]
    fn check_now_is_noop_when_limit_disabled() {
        let mut timer = ExecutionTimer::new(None);
        timer.start(Duration::ZERO);
        assert_eq!(timer.check_now(Duration::from_secs(1)), Ok(()));
        assert_eq!(timer.last_elapsed(), Duration::ZERO);
    }

    #[test]
    fn check_now_is_noop_before_start() {
        // Limits are enabled with a tiny limit, so only the "not started" guard
        // can make this check succeed.
        let mut timer = ExecutionTimer::new(Some(ExecutionTimerConfig {
            limit: Duration::from_millis(1),
            check_interval: nz(1),
        }));
        let result = timer.check_now(Duration::from_secs(1));
        assert_eq!(result, Ok(()), "check before start must be ignored");
        assert_eq!(timer.last_elapsed(), Duration::ZERO);
        assert!(timer.elapsed(Duration::from_secs(2)).is_none());
    }

    #[test]
    fn elapsed_reports_offset_from_start() {
        let mut timer = ExecutionTimer::new(None);
        timer.start(Duration::from_millis(5));
        let elapsed = timer.elapsed(Duration::from_millis(20));
        assert_eq!(elapsed, Some(Duration::from_millis(15)));
    }

    #[test]
    fn monotonic_now_uses_override_when_present() {
        static TEST_TIME: AtomicU64 = AtomicU64::new(0);

        struct TestSource;

        impl TimeSource for TestSource {
            fn now(&self) -> Option<Duration> {
                Some(Duration::from_nanos(TEST_TIME.load(Ordering::Relaxed)))
            }
        }

        /// Restores the previous override when dropped, so a failing assertion
        /// cannot leak `TestSource` into tests that run afterwards.
        struct RestoreOverride(Option<&'static dyn TimeSource>);

        impl Drop for RestoreOverride {
            fn drop(&mut self) {
                *super::TIME_SOURCE_OVERRIDE.lock() = self.0;
            }
        }

        static SOURCE: TestSource = TestSource;

        // Declared before `_restore`, so it is dropped after it: the override
        // is restored while this test still holds the suite lock.
        let _suite_guard = super::acquire_limits_test_lock();
        let _restore = RestoreOverride(super::TIME_SOURCE_OVERRIDE.lock().replace(&SOURCE));

        TEST_TIME.store(123_000_000, Ordering::Relaxed);
        assert_eq!(monotonic_now(), Some(Duration::from_nanos(123_000_000)));

        TEST_TIME.store(456_000_000, Ordering::Relaxed);
        assert_eq!(monotonic_now(), Some(Duration::from_nanos(456_000_000)));
    }
}