commonware_runtime/utils.rs
1//! Utility functions for interacting with any runtime.
2
3#[cfg(test)]
4use crate::Runner;
5use crate::{Error, Metrics, Spawner};
6#[cfg(test)]
7use futures::stream::{FuturesUnordered, StreamExt};
8use futures::{
9 channel::oneshot,
10 future::Shared,
11 stream::{AbortHandle, Abortable},
12 FutureExt,
13};
14use prometheus_client::metrics::gauge::Gauge;
15use rayon::{ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
16use std::{
17 any::Any,
18 future::Future,
19 panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
20 pin::Pin,
21 sync::{Arc, Once},
22 task::{Context, Poll},
23};
24use tracing::error;
25
/// Yield control back to the runtime.
///
/// The returned future is `Pending` on its first poll (after waking itself so the
/// executor polls it again) and `Ready` on every poll after that. This gives other
/// tasks a chance to run before the caller continues.
pub async fn reschedule() {
    /// Future that suspends exactly once.
    struct YieldOnce {
        done: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            // `YieldOnce` only holds a `bool`, so it is `Unpin` and can be accessed mutably.
            let this = self.get_mut();
            if !this.done {
                this.done = true;
                // Ask to be polled again before reporting that we are not ready.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    YieldOnce { done: false }.await
}
48
/// Render a panic payload as a human-readable string.
///
/// `panic!` payloads are usually either a `&'static str` (literal message) or a
/// `String` (formatted message). Any other payload type falls back to its `Debug`
/// representation.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    err.downcast_ref::<&str>()
        .map(|msg| (*msg).to_owned())
        .or_else(|| err.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| format!("{:?}", err))
}
58
59/// Handle to a spawned task.
60pub struct Handle<T>
61where
62 T: Send + 'static,
63{
64 aborter: Option<AbortHandle>,
65 receiver: oneshot::Receiver<Result<T, Error>>,
66
67 running: Gauge,
68 once: Arc<Once>,
69}
70
71impl<T> Handle<T>
72where
73 T: Send + 'static,
74{
75 pub(crate) fn init<F>(
76 f: F,
77 running: Gauge,
78 catch_panic: bool,
79 ) -> (impl Future<Output = ()>, Self)
80 where
81 F: Future<Output = T> + Send + 'static,
82 {
83 // Increment running counter
84 running.inc();
85
86 // Initialize channels to handle result/abort
87 let once = Arc::new(Once::new());
88 let (sender, receiver) = oneshot::channel();
89 let (aborter, abort_registration) = AbortHandle::new_pair();
90
91 // Wrap the future to handle panics
92 let wrapped = {
93 let once = once.clone();
94 let running = running.clone();
95 async move {
96 // Run future
97 let result = AssertUnwindSafe(f).catch_unwind().await;
98
99 // Decrement running counter
100 once.call_once(|| {
101 running.dec();
102 });
103
104 // Handle result
105 let result = match result {
106 Ok(result) => Ok(result),
107 Err(err) => {
108 if !catch_panic {
109 resume_unwind(err);
110 }
111 let err = extract_panic_message(&*err);
112 error!(?err, "task panicked");
113 Err(Error::Exited)
114 }
115 };
116 let _ = sender.send(result);
117 }
118 };
119
120 // Make the future abortable
121 let abortable = Abortable::new(wrapped, abort_registration);
122 (
123 abortable.map(|_| ()),
124 Self {
125 aborter: Some(aborter),
126 receiver,
127
128 running,
129 once,
130 },
131 )
132 }
133
134 pub(crate) fn init_blocking<F>(f: F, running: Gauge, catch_panic: bool) -> (impl FnOnce(), Self)
135 where
136 F: FnOnce() -> T + Send + 'static,
137 {
138 // Increment the running tasks gauge
139 running.inc();
140
141 // Initialize channel to handle result
142 let once = Arc::new(Once::new());
143 let (sender, receiver) = oneshot::channel();
144
145 // Wrap the closure with panic handling
146 let f = {
147 let once = once.clone();
148 let running = running.clone();
149 move || {
150 // Run blocking task
151 let result = catch_unwind(AssertUnwindSafe(f));
152
153 // Decrement running counter
154 once.call_once(|| {
155 running.dec();
156 });
157
158 // Handle result
159 let result = match result {
160 Ok(value) => Ok(value),
161 Err(err) => {
162 if !catch_panic {
163 resume_unwind(err);
164 }
165 let err = extract_panic_message(&*err);
166 error!(?err, "blocking task panicked");
167 Err(Error::Exited)
168 }
169 };
170 let _ = sender.send(result);
171 }
172 };
173
174 // Return the task and handle
175 (
176 f,
177 Self {
178 aborter: None,
179 receiver,
180
181 running,
182 once,
183 },
184 )
185 }
186
187 /// Abort the task (if not blocking).
188 pub fn abort(&self) {
189 // Get aborter and abort
190 let Some(aborter) = &self.aborter else {
191 return;
192 };
193 aborter.abort();
194
195 // Decrement running counter
196 self.once.call_once(|| {
197 self.running.dec();
198 });
199 }
200}
201
202impl<T> Future for Handle<T>
203where
204 T: Send + 'static,
205{
206 type Output = Result<T, Error>;
207
208 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
209 match Pin::new(&mut self.receiver).poll(cx) {
210 Poll::Ready(Ok(Ok(value))) => {
211 self.once.call_once(|| {
212 self.running.dec();
213 });
214 Poll::Ready(Ok(value))
215 }
216 Poll::Ready(Ok(Err(err))) => {
217 self.once.call_once(|| {
218 self.running.dec();
219 });
220 Poll::Ready(Err(err))
221 }
222 Poll::Ready(Err(_)) => {
223 self.once.call_once(|| {
224 self.running.dec();
225 });
226 Poll::Ready(Err(Error::Closed))
227 }
228 Poll::Pending => Poll::Pending,
229 }
230 }
231}
232
233/// A one-time broadcast that can be awaited by many tasks. It is often used for
234/// coordinating shutdown across many tasks.
235///
236/// To minimize the overhead of tracking outstanding signals (which only return once),
237/// it is recommended to wait on a reference to it (i.e. `&mut signal`) instead of
238/// cloning it multiple times in a given task (i.e. in each iteration of a loop).
239pub type Signal = Shared<oneshot::Receiver<i32>>;
240
241/// Coordinates a one-time signal across many tasks.
242///
243/// # Example
244///
245/// ## Basic Usage
246///
247/// ```rust
248/// use commonware_runtime::{Spawner, Runner, Signaler, deterministic};
249///
250/// let executor = deterministic::Runner::default();
251/// executor.start(|context| async move {
252/// // Setup signaler and get future
253/// let (mut signaler, signal) = Signaler::new();
254///
255/// // Signal shutdown
256/// signaler.signal(2);
257///
258/// // Wait for shutdown in task
259/// let sig = signal.await.unwrap();
260/// println!("Received signal: {}", sig);
261/// });
262/// ```
263///
264/// ## Advanced Usage
265///
266/// While `Futures::Shared` is efficient, there is still meaningful overhead
267/// to cloning it (i.e. in each iteration of a loop). To avoid
268/// a performance regression from introducing `Signaler`, it is recommended
269/// to wait on a reference to `Signal` (i.e. `&mut signal`).
270///
271/// ```rust
272/// use commonware_macros::select;
273/// use commonware_runtime::{Clock, Spawner, Runner, Signaler, deterministic, Metrics};
274/// use futures::channel::oneshot;
275/// use std::time::Duration;
276///
277/// let executor = deterministic::Runner::default();
278/// executor.start(|context| async move {
279/// // Setup signaler and get future
280/// let (mut signaler, mut signal) = Signaler::new();
281///
282/// // Loop on the signal until resolved
283/// let (tx, rx) = oneshot::channel();
284/// context.with_label("waiter").spawn(|context| async move {
285/// loop {
286/// // Wait for signal or sleep
287/// select! {
288/// sig = &mut signal => {
289/// println!("Received signal: {}", sig.unwrap());
290/// break;
291/// },
292/// _ = context.sleep(Duration::from_secs(1)) => {},
293/// };
294/// }
295/// let _ = tx.send(());
296/// });
297///
298/// // Send signal
299/// signaler.signal(9);
300///
301/// // Wait for task
302/// rx.await.expect("shutdown signaled");
303/// });
304/// ```
305pub struct Signaler {
306 tx: Option<oneshot::Sender<i32>>,
307}
308
309impl Signaler {
310 /// Create a new `Signaler`.
311 ///
312 /// Returns a `Signaler` and a `Signal` that will resolve when `signal` is called.
313 pub fn new() -> (Self, Signal) {
314 let (tx, rx) = oneshot::channel();
315 (Self { tx: Some(tx) }, rx.shared())
316 }
317
318 /// Resolve the `Signal` for all waiters (if not already resolved).
319 pub fn signal(&mut self, value: i32) {
320 if let Some(stop_tx) = self.tx.take() {
321 let _ = stop_tx.send(value);
322 }
323 }
324}
325
326/// Creates a [rayon]-compatible thread pool with [Spawner::spawn_blocking].
327///
328/// # Arguments
329/// - `context`: The runtime context implementing the [Spawner] trait.
330/// - `concurrency`: The number of tasks to execute concurrently in the pool.
331///
332/// # Returns
333/// A `Result` containing the configured [rayon::ThreadPool] or a [rayon::ThreadPoolBuildError] if the pool cannot be built.
334pub fn create_pool<S: Spawner + Metrics>(
335 context: S,
336 concurrency: usize,
337) -> Result<ThreadPool, ThreadPoolBuildError> {
338 ThreadPoolBuilder::new()
339 .num_threads(concurrency)
340 .spawn_handler(move |thread| {
341 context
342 .with_label("rayon-thread")
343 .spawn_blocking(move || thread.run());
344 Ok(())
345 })
346 .build()
347}
348
349/// Async reader–writer lock.
350///
351/// Powered by [async_lock::RwLock], `RwLock` provides both fair writer acquisition
352/// and `try_read` / `try_write` without waiting (without any runtime-specific dependencies).
353///
354/// Usage:
355/// ```rust
356/// use commonware_runtime::{Spawner, Runner, Signaler, deterministic, RwLock};
357///
358/// let executor = deterministic::Runner::default();
359/// executor.start(|context| async move {
360/// // Create a new RwLock
361/// let lock = RwLock::new(2);
362///
363/// // many concurrent readers
364/// let r1 = lock.read().await;
365/// let r2 = lock.read().await;
366/// assert_eq!(*r1 + *r2, 4);
367///
368/// // exclusive writer
369/// drop((r1, r2));
370/// let mut w = lock.write().await;
371/// *w += 1;
372/// });
373/// ```
374pub struct RwLock<T>(async_lock::RwLock<T>);
375
376/// Shared guard returned by [`RwLock::read`].
377pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>;
378
379/// Exclusive guard returned by [`RwLock::write`].
380pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>;
381
382impl<T> RwLock<T> {
383 /// Create a new lock.
384 #[inline]
385 pub const fn new(value: T) -> Self {
386 Self(async_lock::RwLock::new(value))
387 }
388
389 /// Acquire a shared read guard.
390 #[inline]
391 pub async fn read(&self) -> RwLockReadGuard<'_, T> {
392 self.0.read().await
393 }
394
395 /// Acquire an exclusive write guard.
396 #[inline]
397 pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
398 self.0.write().await
399 }
400
401 /// Try to get a read guard without waiting.
402 #[inline]
403 pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
404 self.0.try_read()
405 }
406
407 /// Try to get a write guard without waiting.
408 #[inline]
409 pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
410 self.0.try_write()
411 }
412
413 /// Get mutable access without locking (requires `&mut self`).
414 #[inline]
415 pub fn get_mut(&mut self) -> &mut T {
416 self.0.get_mut()
417 }
418
419 /// Consume the lock, returning the inner value.
420 #[inline]
421 pub fn into_inner(self) -> T {
422 self.0.into_inner()
423 }
424}
425
/// Test helper: yield to the scheduler five times, then return `i`.
#[cfg(test)]
async fn task(i: usize) -> usize {
    let mut remaining = 5;
    while remaining > 0 {
        reschedule().await;
        remaining -= 1;
    }
    i
}
433
/// Test helper: spawn `tasks` instances of `task` on `runner` and record the
/// order in which they complete.
///
/// Returns the auditor's state string and the task indices in completion order.
/// `tasks == 0` is allowed and yields an empty vector.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: crate::deterministic::Runner) -> (String, Vec<usize>) {
    runner.start(|context| async move {
        // Randomly schedule tasks. `0..tasks` (not `0..=tasks - 1`) avoids an
        // underflow panic when `tasks` is zero.
        let mut handles = FuturesUnordered::new();
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        // Collect output order
        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        (context.auditor().state(), outputs)
    })
}
452
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, tokio, Metrics};
    use commonware_macros::test_traced;
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    /// A pool built by `create_pool` can run rayon parallel iterators.
    #[test_traced]
    fn test_create_pool() {
        let runner = tokio::Runner::default();
        runner.start(|context| async move {
            // Four worker threads, all spawned through the runtime.
            let pool = create_pool(context.with_label("pool"), 4).unwrap();

            // Sum 0..10000 in parallel and compare against the closed form.
            let numbers: Vec<i32> = (0..10000).collect();
            let expected: i32 = 10000 * 9999 / 2;
            pool.install(|| {
                let total: i32 = numbers.par_iter().sum();
                assert_eq!(total, expected);
            });
        });
    }

    /// Readers can share the lock; a writer gets it once all readers are gone.
    #[test_traced]
    fn test_rwlock() {
        let runner = deterministic::Runner::default();
        runner.start(|_| async move {
            let lock = RwLock::new(100);

            // Two simultaneous shared guards.
            let first = lock.read().await;
            let second = lock.read().await;
            assert_eq!(*first + *second, 200);

            // Release every reader before asking for the writer.
            drop(first);
            drop(second);
            let mut writer = lock.write().await;
            *writer += 1;

            assert_eq!(*writer, 101);
        });
    }
}