Skip to main content

commonware_runtime/utils/
mod.rs

1//! Utility functions for interacting with any runtime.
2
3use commonware_utils::sync::{Condvar, Mutex};
4use futures::task::ArcWake;
5use std::{
6    any::Any,
7    future::Future,
8    pin::Pin,
9    sync::Arc,
10    task::{Context, Poll},
11};
12
13commonware_macros::stability_mod!(BETA, pub mod buffer);
14pub mod signal;
15#[cfg(not(target_arch = "wasm32"))]
16pub(crate) mod thread;
17
18mod handle;
19pub use handle::Handle;
20#[commonware_macros::stability(ALPHA)]
21pub(crate) use handle::Panicked;
22pub(crate) use handle::{Aborter, MetricHandle, Panicker};
23
24mod cell;
25pub use cell::Cell as ContextCell;
26
27pub(crate) mod supervision;
28
/// The execution mode of a task.
///
/// Derives [`PartialEq`]/[`Eq`] so callers can compare modes directly (for example with
/// `==` or `assert_eq!`) instead of pattern matching.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Execution {
    /// Task runs on a dedicated thread.
    Dedicated,
    /// Task runs on the shared executor. `true` marks short blocking work that should
    /// use the runtime's blocking-friendly pool.
    Shared(bool),
}

impl Default for Execution {
    /// Defaults to the shared executor, without marking the task as blocking work.
    fn default() -> Self {
        Self::Shared(false)
    }
}
44
/// Yield control back to the runtime exactly once.
///
/// On its first poll the returned future wakes its own task (asking to be polled again) and
/// returns [`Poll::Pending`]; the next poll completes with `()`.
pub async fn reschedule() {
    let mut yielded = false;
    std::future::poll_fn(move |cx| {
        if yielded {
            return Poll::Ready(());
        }
        yielded = true;
        // Request another poll before handing the executor back.
        cx.waker().wake_by_ref();
        Poll::Pending
    })
    .await
}
67
/// Turn a panic payload into a human-readable message.
///
/// `&str` and `String` payloads are returned as their text; any other payload type falls back
/// to its `Debug` rendering as `dyn Any`.
pub(crate) fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    if let Some(text) = err.downcast_ref::<&str>() {
        return (*text).to_string();
    }
    match err.downcast_ref::<String>() {
        Some(text) => text.clone(),
        None => format!("{err:?}"),
    }
}
77
78/// Synchronization primitive that enables a thread to block until a waker delivers a signal.
79pub struct Blocker {
80    /// Tracks whether a wake-up signal has been delivered (even if wait has not started yet).
81    state: Mutex<bool>,
82    /// Condvar used to park and resume the thread when the signal flips to true.
83    cv: Condvar,
84}
85
86impl Blocker {
87    /// Create a new [Blocker].
88    pub fn new() -> Arc<Self> {
89        Arc::new(Self {
90            state: Mutex::new(false),
91            cv: Condvar::new(),
92        })
93    }
94
95    /// Block the current thread until a waker delivers a signal.
96    pub fn wait(&self) {
97        // Use a loop to tolerate spurious wake-ups and only proceed once a real signal arrives.
98        let mut signaled = self.state.lock();
99        while !*signaled {
100            self.cv.wait(&mut signaled);
101        }
102
103        // Reset the flag so subsequent waits park again until the next wake signal.
104        *signaled = false;
105    }
106}
107
108impl ArcWake for Blocker {
109    fn wake_by_ref(arc_self: &Arc<Self>) {
110        // Mark as signaled (and release lock before notifying).
111        {
112            let mut signaled = arc_self.state.lock();
113            *signaled = true;
114        }
115
116        // Notify a single waiter so the blocked thread re-checks the flag.
117        arc_self.cv.notify_one();
118    }
119}
120
#[cfg(test)]
mod tests {
    use super::*;
    use futures::task::waker;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    #[test]
    fn test_blocker_waits_until_wake() {
        let blocker = Blocker::new();
        let started = Arc::new(AtomicBool::new(false));
        let completed = Arc::new(AtomicBool::new(false));

        let thread_blocker = blocker.clone();
        let thread_started = started.clone();
        let thread_completed = completed.clone();
        let handle = std::thread::spawn(move || {
            thread_started.store(true, Ordering::SeqCst);
            thread_blocker.wait();
            thread_completed.store(true, Ordering::SeqCst);
        });

        while !started.load(Ordering::SeqCst) {
            std::thread::yield_now();
        }

        assert!(!completed.load(Ordering::SeqCst));
        waker(blocker).wake();
        handle.join().unwrap();
        assert!(completed.load(Ordering::SeqCst));
    }

    #[test]
    fn test_blocker_handles_pre_wake() {
        let blocker = Blocker::new();
        waker(blocker.clone()).wake();

        let completed = Arc::new(AtomicBool::new(false));
        let thread_blocker = blocker;
        let thread_completed = completed.clone();
        std::thread::spawn(move || {
            thread_blocker.wait();
            thread_completed.store(true, Ordering::SeqCst);
        })
        .join()
        .unwrap();

        assert!(completed.load(Ordering::SeqCst));
    }

    #[test]
    fn test_blocker_reusable_across_signals() {
        let blocker = Blocker::new();
        let completed = Arc::new(AtomicUsize::new(0));

        let thread_blocker = blocker.clone();
        let thread_completed = completed.clone();
        let handle = std::thread::spawn(move || {
            for _ in 0..2 {
                thread_blocker.wait();
                thread_completed.fetch_add(1, Ordering::SeqCst);
            }
        });

        for expected in 1..=2 {
            waker(blocker.clone()).wake();
            while completed.load(Ordering::SeqCst) < expected {
                std::thread::yield_now();
            }
        }

        handle.join().unwrap();
        assert_eq!(completed.load(Ordering::SeqCst), 2);
    }

    /// `reschedule` must return `Pending` exactly once (waking its own task), then complete.
    #[test]
    fn test_reschedule_yields_exactly_once() {
        struct CountingWaker(AtomicUsize);

        impl ArcWake for CountingWaker {
            fn wake_by_ref(arc_self: &Arc<Self>) {
                arc_self.0.fetch_add(1, Ordering::SeqCst);
            }
        }

        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let task_waker = waker(counter.clone());
        let mut cx = Context::from_waker(&task_waker);
        let mut fut = Box::pin(reschedule());

        assert!(fut.as_mut().poll(&mut cx).is_pending());
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);
        assert!(fut.as_mut().poll(&mut cx).is_ready());
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);
    }

    /// Covers the `&str`, `String`, and fallback branches of `extract_panic_message`.
    #[test]
    fn test_extract_panic_message() {
        let borrowed: &'static str = "static message";
        assert_eq!(extract_panic_message(&borrowed), "static message");

        let owned = String::from("owned message");
        assert_eq!(extract_panic_message(&owned), "owned message");

        // Non-string payloads fall back to the `Debug` rendering of `dyn Any`.
        let other = 42u32;
        let rendered = extract_panic_message(&other);
        assert!(rendered.contains("Any"));
        assert!(!rendered.contains("42"));
    }
}