commonware_runtime/utils/
mod.rs

1//! Utility functions for interacting with any runtime.
2
3#[cfg(test)]
4use crate::{Runner, Spawner};
5#[cfg(test)]
6use futures::stream::{FuturesUnordered, StreamExt};
7use futures::task::ArcWake;
8use std::{
9    any::Any,
10    future::Future,
11    pin::Pin,
12    sync::{Arc, Condvar, Mutex},
13    task::{Context, Poll},
14};
15
16pub mod buffer;
17pub mod signal;
18
19mod handle;
20pub use handle::Handle;
21pub(crate) use handle::{Aborter, MetricHandle, Panicked, Panicker};
22
23mod cell;
24pub use cell::Cell as ContextCell;
25
26pub(crate) mod supervision;
27
/// The execution mode of a task.
///
/// The [`Default`] mode is `Execution::Shared(false)`.
///
/// Modes can be compared with `==` and used as hash-map keys, which makes it
/// straightforward to assert on or branch over a configured mode.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Execution {
    /// Task runs on a dedicated thread.
    Dedicated,
    /// Task runs on the shared executor. `true` marks short blocking work that should
    /// use the runtime's blocking-friendly pool.
    Shared(bool),
}

impl Default for Execution {
    /// Returns `Execution::Shared(false)`: the shared executor, without the
    /// blocking-work marker.
    fn default() -> Self {
        Self::Shared(false)
    }
}
43
/// Yield control back to the runtime.
///
/// The returned future reports `Pending` exactly once. Before doing so it wakes
/// its own waker, so whoever polls it is told right away that it can be polled
/// again. The next poll completes with `()`.
pub async fn reschedule() {
    /// One-shot future; the flag records whether the single yield already happened.
    struct YieldOnce(bool);

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // `replace` hands back the previous flag while marking the yield as done.
            if std::mem::replace(&mut self.0, true) {
                return Poll::Ready(());
            }

            // First poll: request another poll immediately, then step aside.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    YieldOnce(false).await
}
66
/// Turn a type-erased payload into a human-readable message.
///
/// A `&str` or `String` payload is returned as its text. Any other payload type
/// falls back to the `Debug` rendering of the `dyn Any` value.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    // Try the borrowed-string payload first.
    if let Some(text) = err.downcast_ref::<&str>() {
        return (*text).to_string();
    }

    // Then the owned-string payload; anything else only has its Debug form.
    match err.downcast_ref::<String>() {
        Some(owned) => owned.clone(),
        None => format!("{err:?}"),
    }
}
76
77/// Async reader–writer lock.
78///
79/// Powered by [async_lock::RwLock], `RwLock` provides both fair writer acquisition
80/// and `try_read` / `try_write` without waiting (without any runtime-specific dependencies).
81///
82/// Usage:
83/// ```rust
84/// use commonware_runtime::{Spawner, Runner, deterministic, RwLock};
85///
86/// let executor = deterministic::Runner::default();
87/// executor.start(|context| async move {
88///     // Create a new RwLock
89///     let lock = RwLock::new(2);
90///
91///     // many concurrent readers
92///     let r1 = lock.read().await;
93///     let r2 = lock.read().await;
94///     assert_eq!(*r1 + *r2, 4);
95///
96///     // exclusive writer
97///     drop((r1, r2));
98///     let mut w = lock.write().await;
99///     *w += 1;
100/// });
101/// ```
102pub struct RwLock<T>(async_lock::RwLock<T>);
103
104/// Shared guard returned by [RwLock::read].
105pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>;
106
107/// Exclusive guard returned by [RwLock::write].
108pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>;
109
110impl<T> RwLock<T> {
111    /// Create a new lock.
112    #[inline]
113    pub const fn new(value: T) -> Self {
114        Self(async_lock::RwLock::new(value))
115    }
116
117    /// Acquire a shared read guard.
118    #[inline]
119    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
120        self.0.read().await
121    }
122
123    /// Acquire an exclusive write guard.
124    #[inline]
125    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
126        self.0.write().await
127    }
128
129    /// Try to get a read guard without waiting.
130    #[inline]
131    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
132        self.0.try_read()
133    }
134
135    /// Try to get a write guard without waiting.
136    #[inline]
137    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
138        self.0.try_write()
139    }
140
141    /// Get mutable access without locking (requires `&mut self`).
142    #[inline]
143    pub fn get_mut(&mut self) -> &mut T {
144        self.0.get_mut()
145    }
146
147    /// Consume the lock, returning the inner value.
148    #[inline]
149    pub fn into_inner(self) -> T {
150        self.0.into_inner()
151    }
152}
153
/// Synchronization primitive that lets a thread park until a waker delivers a signal.
pub struct Blocker {
    /// `true` once a wake-up signal has been delivered, even when nobody is waiting yet.
    state: Mutex<bool>,
    /// Condvar on which the waiting thread parks until the flag becomes `true`.
    cv: Condvar,
}

impl Blocker {
    /// Create a new, unsignaled [Blocker] behind an [Arc].
    pub fn new() -> Arc<Self> {
        let blocker = Self {
            state: Mutex::new(false),
            cv: Condvar::new(),
        };
        Arc::new(blocker)
    }

    /// Block the current thread until a waker delivers a signal.
    ///
    /// Returns immediately if the signal was delivered before the call. The signal
    /// is consumed on return, so the next call parks again until a new one arrives.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn wait(&self) {
        // `wait_while` re-checks the flag after every wake-up, so spurious wake-ups
        // never let the caller through without a real signal.
        let guard = self.state.lock().unwrap();
        let mut signaled = self
            .cv
            .wait_while(guard, |delivered| !*delivered)
            .unwrap();

        // Consume the signal so the next wait parks until a fresh wake.
        *signaled = false;
    }
}
183
184impl ArcWake for Blocker {
185    fn wake_by_ref(arc_self: &Arc<Self>) {
186        // Mark as signaled (and release lock before notifying).
187        {
188            let mut signaled = arc_self.state.lock().unwrap();
189            *signaled = true;
190        }
191
192        // Notify a single waiter so the blocked thread re-checks the flag.
193        arc_self.cv.notify_one();
194    }
195}
196
#[cfg(any(test, feature = "test-utils"))]
/// Count the running tasks whose name begins with `prefix`.
///
/// The metrics are encoded to text and scanned line by line. A line is counted
/// when it is a `runtime_tasks_running{...}` sample of kind `Task`, its value is
/// `1` (running), and its `name` label starts with `prefix`.
///
/// This is useful for verifying that all child tasks under a given label hierarchy
/// have been properly shut down.
///
/// # Example
///
/// ```rust
/// use commonware_runtime::{Clock, Metrics, Runner, Spawner, deterministic};
/// use commonware_runtime::utils::count_running_tasks;
/// use std::time::Duration;
///
/// let executor = deterministic::Runner::default();
/// executor.start(|context| async move {
///     // Spawn a task under a labeled context
///     let handle = context.with_label("worker").spawn(|ctx| async move {
///         ctx.sleep(Duration::from_secs(100)).await;
///     });
///
///     // Allow the task to start
///     context.sleep(Duration::from_millis(10)).await;
///
///     // Count running tasks with "worker" prefix
///     let count = count_running_tasks(&context, "worker");
///     assert!(count > 0, "worker task should be running");
///
///     // Abort the task
///     handle.abort();
///     let _ = handle.await;
///     context.sleep(Duration::from_millis(10)).await;
///
///     // Verify task is stopped
///     let count = count_running_tasks(&context, "worker");
///     assert_eq!(count, 0, "worker task should be stopped");
/// });
/// ```
pub fn count_running_tasks(metrics: &impl crate::Metrics, prefix: &str) -> usize {
    /// Whether one encoded metric line describes a running task named `prefix...`.
    fn is_running_task(line: &str, prefix: &str) -> bool {
        // Only the running-tasks gauge is of interest.
        if !line.starts_with("runtime_tasks_running{") {
            return false;
        }
        // Restrict to samples labeled as tasks.
        if !line.contains("kind=\"Task\"") {
            return false;
        }
        // A sample value of 1 marks the task as currently running.
        if !line.trim_end().ends_with(" 1") {
            return false;
        }
        // The task name is the text between `name="` and the next quote.
        match line.split("name=\"").nth(1) {
            Some(rest) => rest.split('"').next().unwrap_or("").starts_with(prefix),
            None => false,
        }
    }

    let snapshot = metrics.encode();
    let mut running = 0;
    for line in snapshot.lines() {
        if is_running_task(line, prefix) {
            running += 1;
        }
    }
    running
}
252
/// Checks that a label follows the Prometheus metric name format `[a-zA-Z][a-zA-Z0-9_]*`.
///
/// # Panics
///
/// Panics if the label is empty, if its first character is not an ASCII letter,
/// or if any later character is outside `[a-zA-Z0-9_]`.
pub fn validate_label(label: &str) {
    // An ASCII letter is always a single byte, so inspecting bytes is equivalent
    // to inspecting characters here: every byte of a non-ASCII character fails
    // both checks.
    let leading_ok = matches!(label.as_bytes().first(), Some(first) if first.is_ascii_alphabetic());
    assert!(leading_ok, "label must start with [a-zA-Z]: {label}");

    let rest_ok = label
        .bytes()
        .skip(1)
        .all(|byte| byte == b'_' || byte.is_ascii_alphanumeric());
    assert!(rest_ok, "label must only contain [a-zA-Z0-9_]: {label}");
}
270
/// Test helper: yields to the runtime five times, then returns `i` unchanged.
#[cfg(test)]
async fn task(i: usize) -> usize {
    let mut yields_left = 5;
    while yields_left > 0 {
        reschedule().await;
        yields_left -= 1;
    }
    i
}
278
/// Test helper: spawn `tasks` instances of [`task`] on `runner` and gather their results.
///
/// Each task is spawned from a clone of the context and returns its own index.
/// Results are collected in the order the spawned handles complete.
///
/// Returns the value of `context.auditor().state()` taken after all tasks have
/// finished, together with the collected outputs. `tasks == 0` is valid and
/// yields an empty output vector.
///
/// # Panics
///
/// Panics if any spawned task's handle resolves to an error, or if the number
/// of collected outputs differs from `tasks`.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: crate::deterministic::Runner) -> (String, Vec<usize>) {
    runner.start(|context| async move {
        // Randomly schedule tasks.
        //
        // A half-open range is used on purpose: `0..=tasks - 1` underflows when
        // `tasks` is zero (a panic in debug builds, a wrap to `usize::MAX` in
        // release builds), whereas `0..tasks` is simply empty.
        let mut handles = FuturesUnordered::new();
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        // Collect output order
        let mut outputs = Vec::new();
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        (context.auditor().state(), outputs)
    })
}
297
// Unit tests for the utilities in this module: the async `RwLock` wrapper, the
// thread-parking `Blocker`, and the `count_running_tasks` metrics helper.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deterministic;
    use commonware_macros::test_traced;
    use futures::task::waker;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// Two read guards can be held at once, and a write guard can be taken and
    /// used to mutate the value once both readers are dropped.
    #[test_traced]
    fn test_rwlock() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Create a new RwLock
            let lock = RwLock::new(100);

            // many concurrent readers
            let r1 = lock.read().await;
            let r2 = lock.read().await;
            assert_eq!(*r1 + *r2, 200);

            // exclusive writer
            drop((r1, r2)); // all readers must go away
            let mut w = lock.write().await;
            *w += 1;

            // Check the value
            assert_eq!(*w, 101);
        });
    }

    /// A thread calling `wait` does not get past it until the blocker is woken.
    #[test]
    fn test_blocker_waits_until_wake() {
        let blocker = Blocker::new();
        let started = Arc::new(AtomicBool::new(false));
        let completed = Arc::new(AtomicBool::new(false));

        // The waiter announces that it is running, parks on the blocker, and
        // only then marks itself as completed.
        let thread_blocker = blocker.clone();
        let thread_started = started.clone();
        let thread_completed = completed.clone();
        let handle = std::thread::spawn(move || {
            thread_started.store(true, Ordering::SeqCst);
            thread_blocker.wait();
            thread_completed.store(true, Ordering::SeqCst);
        });

        // Spin until the waiter thread is known to be running.
        while !started.load(Ordering::SeqCst) {
            std::thread::yield_now();
        }

        // No signal has been sent yet, so the waiter cannot have passed `wait`.
        assert!(!completed.load(Ordering::SeqCst));
        waker(blocker).wake();
        handle.join().unwrap();
        assert!(completed.load(Ordering::SeqCst));
    }

    /// A wake delivered before `wait` is called is remembered, so the later
    /// `wait` returns instead of blocking forever.
    #[test]
    fn test_blocker_handles_pre_wake() {
        let blocker = Blocker::new();
        waker(blocker.clone()).wake();

        let completed = Arc::new(AtomicBool::new(false));
        let thread_blocker = blocker;
        let thread_completed = completed.clone();
        std::thread::spawn(move || {
            thread_blocker.wait();
            thread_completed.store(true, Ordering::SeqCst);
        })
        .join()
        .unwrap();

        assert!(completed.load(Ordering::SeqCst));
    }

    /// The same blocker can be waited on repeatedly: each wake releases exactly
    /// one `wait` call.
    #[test]
    fn test_blocker_reusable_across_signals() {
        let blocker = Blocker::new();
        let completed = Arc::new(AtomicUsize::new(0));

        // The waiter performs two consecutive waits, counting each one it passes.
        let thread_blocker = blocker.clone();
        let thread_completed = completed.clone();
        let handle = std::thread::spawn(move || {
            for _ in 0..2 {
                thread_blocker.wait();
                thread_completed.fetch_add(1, Ordering::SeqCst);
            }
        });

        // Send one wake at a time and spin until the waiter has counted it
        // before sending the next one.
        for expected in 1..=2 {
            waker(blocker.clone()).wake();
            while completed.load(Ordering::SeqCst) < expected {
                std::thread::yield_now();
            }
        }

        handle.join().unwrap();
        assert_eq!(completed.load(Ordering::SeqCst), 2);
    }

    /// `count_running_tasks` follows tasks spawned under a label (and a nested
    /// label) as they are started and aborted.
    #[test_traced]
    fn test_count_running_tasks() {
        use crate::{Metrics, Runner, Spawner};
        use futures::future;

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Initially no tasks with "worker" prefix
            assert_eq!(
                count_running_tasks(&context, "worker"),
                0,
                "no worker tasks initially"
            );

            // Spawn a task under a labeled context that stays running
            let worker_ctx = context.with_label("worker");
            let handle1 = worker_ctx.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            // Count running tasks with "worker" prefix
            let count = count_running_tasks(&context, "worker");
            assert_eq!(count, 1, "worker task should be running");

            // Non-matching prefix should return 0
            assert_eq!(
                count_running_tasks(&context, "other"),
                0,
                "no tasks with 'other' prefix"
            );

            // Spawn a nested task (worker_child)
            let handle2 = worker_ctx.with_label("child").spawn(|_| async move {
                future::pending::<()>().await;
            });

            // Count should include both parent and nested tasks
            let count = count_running_tasks(&context, "worker");
            assert_eq!(count, 2, "both worker and worker_child should be counted");

            // Abort parent task
            handle1.abort();
            let _ = handle1.await;

            // Only nested task remains
            let count = count_running_tasks(&context, "worker");
            assert_eq!(count, 1, "only worker_child should remain");

            // Abort nested task
            handle2.abort();
            let _ = handle2.await;

            // All tasks stopped
            assert_eq!(
                count_running_tasks(&context, "worker"),
                0,
                "all worker tasks should be stopped"
            );
        });
    }
}
454        });
455    }
456}