async_spin_sleep/
lib.rs

1//! # Async-spin-sleep
2//!
//! A dedicated timer driver that makes a high-precision sleep function easy to use in various
//! async/await contexts.
5//!
6//! ## Features
7//!
8//! - *`system-clock`* (default): Enable use of system clock as a timer source.
9//!
10use std::{
11    pin::Pin,
12    sync::{
13        atomic::{AtomicIsize, Ordering},
14        Arc,
15    },
16    task::{Context, Poll},
17    time::{Duration, Instant},
18};
19
20use crossbeam::channel;
21use driver::{NodeDesc, WakerNode};
22
23/* -------------------------------------------- Init -------------------------------------------- */
/// Linux timers typically have ~1ms granularity, so a 4ms spin window leaves enough margin.
#[cfg(target_os = "linux")]
const DEFAULT_SCHEDULE_RESOLUTION: Duration = Duration::from_micros(4_000);

/// macOS has been observed to wake up unreliably with a ~3ms window, so a wider 10ms margin is
/// used there.
#[cfg(target_os = "macos")]
const DEFAULT_SCHEDULE_RESOLUTION: Duration = Duration::from_micros(10_000);

/// Fallback for every other platform, including Windows. Windows timer resolution is usually
/// ~16ms; roughly twice that is used to avoid unstable behavior.
#[cfg(not(any(target_os = "linux", target_os = "macos")))]
const DEFAULT_SCHEDULE_RESOLUTION: Duration = Duration::from_micros(33_000);
36
/// A builder for [`Handle`].
///
/// [`Builder::build`] and [`Builder::build_d_ary`] return a tuple of a [`Handle`] and its
/// dedicated driver function. The driver function is recommended to be spawned on a separate
/// thread.
#[derive(Debug, derive_setters::Setters)]
#[setters(prefix = "with_")]
#[non_exhaustive]
pub struct Builder {
    /// Length of the busy-spin window before each deadline: the driver blocks on its channel
    /// until this long before the earliest deadline, then spins (yielding the thread) for the
    /// remainder.
    ///
    /// Setting this to a lower value may decrease CPU usage of the driver, but may also
    /// dangerously increase the chance of missing a wakeup event due to the OS scheduler.
    pub schedule_resolution: Duration,

    /// Aborted nodes whose deadline is far away may remain in the driver's memory for a long
    /// time. When the estimated number of such garbage nodes exceeds this value, the driver
    /// rebuilds its queue without them.
    pub gc_threshold: usize,

    /// Channel capacity. This value is used to initialize the channel that connects the driver
    /// and its handles. When a bounded channel is full, a sleep future that is registering its
    /// timer blocks (inside `poll`) until the driver drains the channel.
    ///
    /// When [`None`] is specified, an unbounded channel will be used.
    #[setters(into)]
    pub channel_capacity: Option<usize>,

    /// Number of thread yields per `try_recv` while the driver busy-waits. This option assumes
    /// that `try_recv` is relatively heavy compared to a yield, so adjust it to an appropriate
    /// value to balance responsiveness and overhead.
    ///
    /// Value will be clamped to at least 1. Does not use `NonZeroUsize`, just only for convenience.
    pub yields_per_spin: usize,

    /// Shared, naive estimate of garbage nodes: each registered sleep future increments it when
    /// dropped, and the driver decrements it as nodes are removed. It may be negative.
    #[setters(skip)]
    gc_counter: Arc<AtomicIsize>,
}
75
76impl Default for Builder {
77    fn default() -> Self {
78        Self {
79            schedule_resolution: DEFAULT_SCHEDULE_RESOLUTION,
80            gc_threshold: 1000,
81            channel_capacity: None,
82            gc_counter: Default::default(),
83            yields_per_spin: 1,
84        }
85    }
86}
87
88impl Builder {
89    /// Build timer driver with optimal configuration.
90    #[must_use = "Never drop the driver instance!"]
91    pub fn build(self) -> (Handle, impl FnOnce()) {
92        self.build_d_ary::<4>()
93    }
94
95    /// Build timer driver with desired D-ary heap configuration.
96    #[must_use = "Never drop the driver instance!"]
97    pub fn build_d_ary<const D: usize>(self) -> (Handle, impl FnOnce()) {
98        let _ = instant::origin(); // Force initialization of the global pivot time
99
100        let (tx, rx) = if let Some(cap) = self.channel_capacity {
101            channel::bounded(cap)
102        } else {
103            channel::unbounded()
104        };
105
106        let handle = Handle { tx: tx.clone(), gc_counter: self.gc_counter.clone() };
107        let driver = move || driver::execute::<D>(self, rx);
108
109        (handle, driver)
110    }
111}
112
113/// A shortcut for `Builder::default().build()`.
114pub fn create() -> (Handle, impl FnOnce()) {
115    Builder::default().build()
116}
117
118/// A shortcut for `Builder::default().build_d_ary<..>()`.
119pub fn create_d_ary<const D: usize>() -> (Handle, impl FnOnce()) {
120    Builder::default().build_d_ary::<D>()
121}
122
123/* ------------------------------------------- Driver ------------------------------------------- */
124mod driver {
125    use std::{
126        sync::{atomic::Ordering, Weak},
127        task::Waker,
128        time::{Duration, Instant},
129    };
130
131    use crossbeam::channel::{self, TryRecvError};
132    use dary_heap::DaryHeap;
133    use educe::Educe;
134
135    use crate::Builder;
136
137    #[derive(Debug)]
138    pub(crate) enum Event {
139        SleepUntil(NodeDesc),
140    }
141
142    pub(crate) fn execute<const D: usize>(this: Builder, rx: channel::Receiver<Event>) {
143        let mut nodes = DaryHeap::<Node, D>::new();
144        let pivot = Instant::now();
145        let to_usec = |x: Instant| x.duration_since(pivot).as_micros() as u64;
146        let resolution_usec = this.schedule_resolution.as_micros() as u64;
147
148        // As each node always increment the `gc_counter` by 1 when dropped, and the worker
149        // decrements the number by 1 when a node is cleared, this value is expected to be a naive
150        // status of 'aborted but not yet handled' node count, i.e. garbage nodes.
151        let gc_counter = this.gc_counter;
152        let yields_per_spin = this.yields_per_spin.max(1);
153
154        'worker: loop {
155            let now_ts = Instant::now();
156            let now = to_usec(now_ts);
157            let mut event = if let Some(node) = nodes.peek() {
158                let remain = node.timeout_usec.saturating_sub(now);
159                if remain > resolution_usec {
160                    let system_sleep_for = remain - resolution_usec;
161                    let timeout = Duration::from_micros(system_sleep_for);
162                    let deadline = now_ts + timeout;
163
164                    let Ok(x) = rx.recv_deadline(deadline).map_err(|e| match e {
165                        channel::RecvTimeoutError::Timeout => (),
166                        channel::RecvTimeoutError::Disconnected => {
167                            // The channel handle was closed while a task was still waiting. As no
168                            // new timer nodes will be added after receiving a `disconnected` state,
169                            // it's safe to pause the current thread until the next node's timeout
170                            // is reached.
171                            std::thread::sleep(deadline.saturating_duration_since(Instant::now()))
172                        }
173                    }) else {
174                        continue;
175                    };
176                    x
177                } else {
178                    let mut yields_counter = 0usize;
179
180                    'busy_wait: loop {
181                        let now = to_usec(Instant::now());
182                        if now >= node.timeout_usec {
183                            let node = nodes.pop().unwrap();
184
185                            if let Some(waker) = node.weak_waker.upgrade() {
186                                waker.value.lock().take().expect("logic error").wake();
187                            }
188
189                            let n_garbage = gc_counter.fetch_sub(1, Ordering::Release);
190                            if n_garbage > this.gc_threshold as isize {
191                                let n_collect = gc(&mut nodes) as _;
192                                gc_counter.fetch_sub(n_collect, Ordering::Release);
193                            }
194
195                            continue 'worker;
196                        } else {
197                            if yields_counter % yields_per_spin == 0 {
198                                match rx.try_recv() {
199                                    Ok(x) => break 'busy_wait x,
200                                    Err(TryRecvError::Disconnected) if nodes.is_empty() => {
201                                        break 'worker
202                                    }
203                                    Err(TryRecvError::Disconnected) | Err(TryRecvError::Empty) => {
204                                        // We still have timer nodes to deal with ..
205                                    }
206                                }
207                            }
208
209                            yields_counter += 1;
210                            std::thread::yield_now();
211                            continue 'busy_wait;
212                        }
213                    }
214                }
215            } else {
216                let Ok(x) = rx.recv() else { break };
217                x
218            };
219
220            if gc_counter.load(Ordering::Acquire) as usize > this.gc_threshold {
221                let n_collect = gc(&mut nodes) as _;
222                gc_counter.fetch_sub(n_collect, Ordering::Release);
223            }
224
225            'flush: loop {
226                match event {
227                    Event::SleepUntil(desc) => nodes
228                        .push(Node { timeout_usec: to_usec(desc.timeout), weak_waker: desc.waker }),
229                };
230
231                event = match rx.try_recv() {
232                    Ok(x) => x,
233                    Err(TryRecvError::Disconnected) if nodes.is_empty() => break 'worker,
234                    Err(TryRecvError::Disconnected) | Err(TryRecvError::Empty) => break 'flush,
235                };
236            }
237        }
238
239        assert!(nodes.is_empty());
240        assert_eq!(gc_counter.load(Ordering::Relaxed), 0);
241    }
242
243    fn gc<const D: usize>(nodes: &mut DaryHeap<Node, D>) -> usize {
244        let fn_retain = |x: &Node| x.weak_waker.upgrade().is_some();
245
246        let prev_len = nodes.len();
247
248        *nodes = {
249            let mut vec = std::mem::take(nodes).into_vec();
250            vec.retain(fn_retain);
251            DaryHeap::from(vec)
252        };
253
254        
255        prev_len - nodes.len()
256    }
257
258    #[derive(Debug, Clone)]
259    pub(crate) struct NodeDesc {
260        pub timeout: Instant,
261        pub waker: Weak<WakerNode>,
262    }
263
264    #[derive(Debug, Clone, Educe)]
265    #[educe(Eq, PartialEq, PartialOrd, Ord)]
266    pub(crate) struct Node {
267        #[educe(PartialOrd(method = "cmp_rev_partial"), Ord(method = "cmp_rev"))]
268        pub timeout_usec: u64,
269
270        #[educe(Eq(ignore), PartialEq(ignore), PartialOrd(ignore), Ord(ignore))]
271        pub weak_waker: Weak<WakerNode>,
272    }
273
274    fn cmp_rev(a: &u64, b: &u64) -> std::cmp::Ordering {
275        b.cmp(a)
276    }
277
278    fn cmp_rev_partial(a: &u64, b: &u64) -> Option<std::cmp::Ordering> {
279        b.partial_cmp(a)
280    }
281
282    #[derive(Debug)]
283    pub(crate) struct WakerNode {
284        //  dsa sads
285        value: parking_lot::Mutex<Option<Waker>>,
286    }
287
288    impl WakerNode {
289        pub fn new(waker: Waker) -> Self {
290            Self { value: parking_lot::Mutex::new(Some(waker)) }
291        }
292
293        pub fn is_expired(&self) -> bool {
294            self.value.lock().is_none()
295        }
296    }
297}
298
299/* ------------------------------------------- Handle ------------------------------------------- */
/// A handle to the timer driver.
///
/// Every clone of the handle, and every sleep future that has not been polled yet, keeps the
/// driver's channel open. The driver exits once all of them are dropped and its queue is
/// drained.
#[derive(Debug, Clone)]
pub struct Handle {
    /// Channel used to register sleep requests with the driver.
    tx: channel::Sender<driver::Event>,
    /// Garbage-node estimate shared with the driver and every sleep future.
    gc_counter: Arc<AtomicIsize>,
}
306
307impl Handle {
308    /// Returns a future that sleeps for the specified duration.
309    ///
310    /// [`SleepFuture`] returns the duration that overly passed the specified duration.
311    pub fn sleep_for(&self, duration: Duration) -> SleepFuture {
312        self.sleep_until(Instant::now() + duration)
313    }
314
315    /// Returns a future that sleeps until the specified instant.
316    ///
317    /// [`SleepFuture`] returns the duration that overly passed the specified instant.
318    pub fn sleep_until(&self, timeout: Instant) -> SleepFuture {
319        SleepFuture {
320            state: SleepState::Pending(self.tx.clone()),
321            timeout,
322            gc_counter: self.gc_counter.clone(),
323        }
324    }
325
326    /// Create an interval controller which wakes up after specified `interval` duration on
327    /// every call to [`util::Interval::wait`]
328    pub fn interval(&self, interval: Duration) -> util::Interval {
329        util::Interval { handle: self.clone(), wakeup_time: Instant::now() + interval, interval }
330    }
331}
332
333pub mod util {
334    use crate::{instant, Report};
335    use std::time::{Duration, Instant};
336
337    /// Interval controller.
338    #[derive(Debug, Clone)]
339    pub struct Interval {
340        pub(crate) handle: super::Handle,
341        pub(crate) wakeup_time: Instant,
342        pub(crate) interval: Duration,
343    }
344
345    impl Interval {
346        /// Wait until next interval.
347        ///
348        /// This function will return [`SleepResult`] which contains the duration that overly passed
349        /// the specified interval. As it internally aligns to the specified interval, it should not
350        /// be drifted over time, in terms of [`Instant`] clock domain.
351        ///
352        /// - `minimum_interval`: Minimum interval to wait. This prevents burst after long
353        ///   inactivity on `Interval` object.
354        pub async fn tick_with_min_interval(&mut self, minimum_interval: Duration) -> Report {
355            assert!(minimum_interval <= self.interval);
356            let Self { handle, wakeup_time: wakeup, interval } = self;
357
358            let result = handle.sleep_until(*wakeup).await;
359            let now = Instant::now();
360            *wakeup += *interval;
361
362            let minimum_next = now + minimum_interval;
363            if minimum_next > *wakeup {
364                // XXX: We use 128 bit integer to avoid overflow in nanosecond domain.
365                //  This is not a perfect solution, but it should be enough for most cases,
366                //  as 'over-sleep' is relatively rare case thus slow path is not a big deal.
367                let interval_ns = interval.as_nanos();
368                let num_ticks = ((minimum_next - *wakeup).as_nanos() - 1) / interval_ns + 1;
369
370                // Set next wakeup to nearest aligned timestamp.
371                *wakeup += Duration::from_nanos((interval_ns * num_ticks) as _);
372            }
373
374            result
375        }
376
377        /// A shortcut for [`tick_with_min_interval`] with `minimum_interval` set to half.
378        pub async fn tick(&mut self) -> Report {
379            self.tick_with_min_interval(self.interval / 2).await
380        }
381
382        /// Reset interval to the specified duration.
383        pub fn set_interval(&mut self, interval: Duration) {
384            assert!(interval > Duration::default());
385            self.wakeup_time -= self.interval;
386            self.wakeup_time += interval;
387            self.interval = interval;
388        }
389
390        pub fn interval(&self) -> Duration {
391            self.interval
392        }
393
394        pub fn wakeup_time(&self) -> Instant {
395            self.wakeup_time
396        }
397
398        /// This method aligns the subsequent tick to a given interval. Following the alignment, the
399        /// timestamp will conform to the specified interval.
400        ///
401        /// Parameters:
402        /// - `now_since_epoch`: A function yielding the current time since the epoch. It's
403        ///   internally converted to an [`Instant`], hence, should return the most recent
404        ///   timestamp.
405        /// - `align_offset_ns`: This parameter can adjust the alignment timing. For example, an
406        ///   offset of 100us applied to a next tick scheduled at 9000us will push the tick to
407        ///   9100us.
408        /// - `initial_interval_tolerance`: This defines the permissible noise level for the initial
409        ///   interval. If set to zero, the actual sleep duration will exceed the interval,
410        ///   potentially causing a tick to be skipped as the actual sleep duration might be twice
411        ///   the interval. It's advisable to set it to 10% of the interval to prevent the
412        ///   `align_clock` command from disrupting the initial interval.
413        ///
414        /// Note: For example, if `now_since_epoch()` gives 8500us and the interval is 1000us, the
415        /// subsequent tick will be adjusted to 9000us, aligning with the interval.
416        pub fn align_with_clock(
417            &mut self,
418            now_since_epoch: impl FnOnce() -> Duration,
419            interval: Option<Duration>, // If none, reuse the previous interval.
420            initial_interval_tolerance: Option<Duration>, // If none, 10% of the interval.
421            align_offset_ns: i64,
422        ) {
423            let prev_trig = self.wakeup_time - self.interval;
424            let dst_now_ns = now_since_epoch().as_nanos() as i64;
425            let inst_now = Instant::now();
426
427            let interval = interval.unwrap_or(self.interval);
428            let interval_ns = interval.as_nanos() as i64;
429            let interval_tolerance =
430                initial_interval_tolerance.unwrap_or(Duration::from_nanos((interval_ns / 10) as _));
431
432            assert!(interval > Duration::default(), "interval must be larger than zero");
433            assert!(interval_tolerance < interval);
434
435            let ticks_to_align = {
436                let mut val = interval_ns - (dst_now_ns % interval_ns) + align_offset_ns;
437                if val < 0 {
438                    val += (val / interval_ns + 1) * interval_ns;
439                }
440                Duration::from_nanos((val % interval_ns) as _)
441            };
442
443            let mut desired_wake_up = inst_now + ticks_to_align;
444            if desired_wake_up < prev_trig + interval - interval_tolerance {
445                desired_wake_up += interval;
446                debug_assert!(desired_wake_up >= prev_trig + interval - interval_tolerance);
447            }
448
449            self.wakeup_time = desired_wake_up;
450            self.interval = interval;
451        }
452
453        /// Shortcut for [`Self::align_clock`] from now.
454        pub fn align_now(
455            &mut self,
456            interval: Option<Duration>,
457            initial_interval_tolerance: Option<Duration>,
458            align_offset_ns: i64,
459        ) {
460            self.align_with_clock(
461                instant::time_from_epoch,
462                interval,
463                initial_interval_tolerance,
464                align_offset_ns,
465            );
466        }
467
468        /// Shortcut for [`Self::align_clock`] with [`std::time::SystemTime`] as the time source.
469        #[cfg(feature = "system-clock")]
470        pub fn align_with_system_clock(
471            &mut self,
472            interval: Option<Duration>,
473            initial_interval_tolerance: Option<Duration>,
474            align_offset_ns: i64,
475        ) {
476            self.align_with_clock(
477                || {
478                    let now = std::time::SystemTime::now();
479                    now.duration_since(std::time::UNIX_EPOCH).unwrap()
480                },
481                interval,
482                initial_interval_tolerance,
483                align_offset_ns,
484            );
485        }
486    }
487}
488
489mod instant {
490    use std::time::Instant;
491
492    pub(crate) fn origin() -> Instant {
493        lazy_static::lazy_static!(
494            static ref PIVOT: Instant = Instant::now();
495        );
496
497        *PIVOT
498    }
499
500    pub(crate) fn time_from_epoch() -> std::time::Duration {
501        origin().elapsed()
502    }
503}
504
505/* ------------------------------------------- Future ------------------------------------------- */
/// Future returned by [`Handle::sleep_for`] and [`Handle::sleep_until`], resolving to a
/// [`Report`].
///
/// The deadline is registered with the driver on the first poll, unless it has already passed.
/// Dropping a registered future before it completes leaves a garbage node, which the driver
/// collects later.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct SleepFuture {
    /// Shared garbage-node estimate; incremented on drop if a node was registered.
    gc_counter: Arc<AtomicIsize>,
    /// Absolute deadline.
    timeout: Instant,
    /// Registration lifecycle.
    state: SleepState,
}

#[cfg(test)]
static_assertions::assert_impl_all!(SleepFuture: Send, Sync, Unpin);
513
514#[cfg(test)]
515static_assertions::assert_impl_all!(SleepFuture: Send, Sync, Unpin);
516
/// Outcome of awaiting a [`SleepFuture`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Report {
    /// Timer has been correctly requested, and woke up normally. Returned value is how long
    /// past the requested deadline the future was completed.
    Completed(Duration),

    /// Timer has not been requested because the deadline had already passed when the future
    /// was polled (this is also returned when a completed future is polled again after its
    /// deadline). Returned value is how long past the deadline the poll happened.
    ExpiredTimer(Duration),

    /// The driver fired the timer, but the future observed a time slightly before the deadline
    /// (likely due to the driver's microsecond rounding). Returned value is how early it woke
    /// up; it is usually hundreds of nanoseconds.
    CompletedEarly(Duration),
}
529
530impl Report {
531    /// Returns how many ticks the timer overslept than required duration.
532    pub fn overslept(&self) -> Duration {
533        match self {
534            Self::Completed(dur) => *dur,
535            Self::ExpiredTimer(dur) => *dur,
536
537            // NOTE: As duration should always be positive value, we can safely return zero here.
538            Self::CompletedEarly(_) => Duration::ZERO,
539        }
540    }
541
542    /// Returns true if the timer woke up earlier than required duration.
543    pub fn is_woke_up_early(&self) -> bool {
544        matches!(self, Self::CompletedEarly(_))
545    }
546
547    /// Trick to make previous code compatible with this version.
548    #[doc(hidden)]
549    pub fn unwrap(self) -> Self {
550        self
551    }
552
553    /// Trick to make previous code compatible with this version.
554    #[doc(hidden)]
555    pub fn ok(self) -> Option<Self> {
556        Some(self)
557    }
558}
559
/// Registration lifecycle of a [`SleepFuture`].
#[derive(Debug)]
enum SleepState {
    /// Not yet registered with the driver; holds the sender used to register.
    Pending(channel::Sender<driver::Event>),
    /// Registered; owns the waker slot that the driver refers to weakly.
    Sleeping(Arc<WakerNode>),
    /// Completed; the waker slot has been released.
    Woken,
}
566
567impl std::future::Future for SleepFuture {
568    type Output = Report;
569
570    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
571        let now = Instant::now();
572
573        if let Some(over) = now.checked_duration_since(self.timeout) {
574            let result = if matches!(self.state, SleepState::Sleeping(_)) {
575                self.state = SleepState::Woken;
576                Report::Completed(over)
577            } else {
578                Report::ExpiredTimer(over)
579            };
580
581            return Poll::Ready(result);
582        }
583
584        if let SleepState::Pending(tx) = &self.state {
585            let waker = Arc::new(WakerNode::new(cx.waker().clone()));
586            let event = driver::Event::SleepUntil(NodeDesc {
587                timeout: self.timeout,
588                waker: Arc::downgrade(&waker),
589            });
590
591            tx.send(event).expect("timer driver instance dropped!");
592            self.state = SleepState::Sleeping(waker);
593        } else if let SleepState::Sleeping(node) = &self.state {
594            // We woke up too early. Check if it is due to broken clock monotonicity.
595            if node.is_expired() {
596                self.state = SleepState::Woken;
597                return Poll::Ready(Report::CompletedEarly(self.timeout - now));
598            } else {
599                // If not, this is a spurious wakeup. We should sleep again.
600                // - XXX: Should we re-register wakeup timer here?
601            }
602        }
603
604        Poll::Pending
605    }
606}
607
608impl Drop for SleepFuture {
609    fn drop(&mut self) {
610        if !matches!(&self.state, SleepState::Pending { .. }) {
611            self.gc_counter.fetch_add(1, Ordering::Release);
612        }
613    }
614}