commonware_runtime/utils/
mod.rs1#[cfg(test)]
4use crate::{Runner, Spawner};
5#[cfg(test)]
6use futures::stream::{FuturesUnordered, StreamExt};
7use futures::task::ArcWake;
8use std::{
9 any::Any,
10 future::Future,
11 pin::Pin,
12 sync::{Arc, Condvar, Mutex},
13 task::{Context, Poll},
14};
15
16pub mod buffer;
17pub mod signal;
18
19mod handle;
20pub use handle::Handle;
21pub(crate) use handle::{Aborter, MetricHandle, Panicked, Panicker};
22
23mod cell;
24pub use cell::Cell as ContextCell;
25
26pub(crate) mod supervision;
27
/// How a spawned task should be executed.
///
/// NOTE(review): the precise scheduling semantics of each variant are implemented by the
/// runtimes that consume this type; confirm against them.
#[derive(Copy, Clone, Debug)]
pub enum Execution {
    /// Execute the task in dedicated fashion rather than on the shared executor.
    Dedicated,
    /// Execute the task on the shared executor. The meaning of the flag is defined by the
    /// consuming runtime — TODO confirm.
    Shared(bool),
}

impl Default for Execution {
    /// Tasks default to shared execution with the flag cleared.
    fn default() -> Self {
        Execution::Shared(false)
    }
}
43
/// Yields control back to the executor exactly once.
///
/// On its first poll the inner future wakes its own waker and returns `Pending`, so the
/// executor puts the task back in its queue; the next poll completes immediately.
pub async fn reschedule() {
    /// Future that is pending on its first poll and ready on every later poll.
    struct YieldOnce {
        polled: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            // `YieldOnce` holds only a `bool`, so it is `Unpin` and can be accessed mutably.
            let this = self.get_mut();
            if this.polled {
                return Poll::Ready(());
            }
            this.polled = true;
            // Wake before returning `Pending` so the executor polls this task again.
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    YieldOnce { polled: false }.await;
}
66
/// Converts a panic payload into a human-readable message.
///
/// `&str` and `String` payloads are returned verbatim; any other payload type falls back
/// to its `Debug` representation.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    // `panic!` payloads are typically `&'static str` (literal message) or `String` (formatted).
    if let Some(literal) = err.downcast_ref::<&str>() {
        return (*literal).to_owned();
    }
    match err.downcast_ref::<String>() {
        Some(formatted) => formatted.clone(),
        None => format!("{err:?}"),
    }
}
76
77pub struct RwLock<T>(async_lock::RwLock<T>);
103
104pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>;
106
107pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>;
109
110impl<T> RwLock<T> {
111 #[inline]
113 pub const fn new(value: T) -> Self {
114 Self(async_lock::RwLock::new(value))
115 }
116
117 #[inline]
119 pub async fn read(&self) -> RwLockReadGuard<'_, T> {
120 self.0.read().await
121 }
122
123 #[inline]
125 pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
126 self.0.write().await
127 }
128
129 #[inline]
131 pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
132 self.0.try_read()
133 }
134
135 #[inline]
137 pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
138 self.0.try_write()
139 }
140
141 #[inline]
143 pub fn get_mut(&mut self) -> &mut T {
144 self.0.get_mut()
145 }
146
147 #[inline]
149 pub fn into_inner(self) -> T {
150 self.0.into_inner()
151 }
152}
153
154pub struct Blocker {
156 state: Mutex<bool>,
158 cv: Condvar,
160}
161
162impl Blocker {
163 pub fn new() -> Arc<Self> {
165 Arc::new(Self {
166 state: Mutex::new(false),
167 cv: Condvar::new(),
168 })
169 }
170
171 pub fn wait(&self) {
173 let mut signaled = self.state.lock().unwrap();
175 while !*signaled {
176 signaled = self.cv.wait(signaled).unwrap();
177 }
178
179 *signaled = false;
181 }
182}
183
184impl ArcWake for Blocker {
185 fn wake_by_ref(arc_self: &Arc<Self>) {
186 {
188 let mut signaled = arc_self.state.lock().unwrap();
189 *signaled = true;
190 }
191
192 arc_self.cv.notify_one();
194 }
195}
196
/// Counts running tasks whose `name` label starts with `prefix`.
///
/// Scans the encoded metrics for `runtime_tasks_running{...}` samples labelled
/// `kind="Task"` and sums their values. The name match is a prefix match, so `"worker"`
/// also counts tasks labelled `worker_child`.
///
/// Values are summed rather than requiring each sample to equal exactly `1`. A series whose
/// gauge exceeds one still contributes every task it reports. Samples whose value is not a
/// non-negative integer are ignored.
#[cfg(any(test, feature = "test-utils"))]
pub fn count_running_tasks(metrics: &impl crate::Metrics, prefix: &str) -> usize {
    /// Returns the running-task count carried by `line`, or `None` if `line` is not a
    /// matching `runtime_tasks_running` task sample.
    fn running_in_sample(line: &str, prefix: &str) -> Option<usize> {
        let rest = line.strip_prefix("runtime_tasks_running{")?;
        // The label set ends at the last `}`; the sample value follows it.
        let (labels, value) = rest.rsplit_once('}')?;
        if !labels.contains("kind=\"Task\"") {
            return None;
        }
        let name = labels.split("name=\"").nth(1)?.split('"').next()?;
        if !name.starts_with(prefix) {
            return None;
        }
        // The value is the first token after the labels (any trailing token is ignored).
        value.split_whitespace().next()?.parse().ok()
    }

    metrics
        .encode()
        .lines()
        .filter_map(|line| running_in_sample(line, prefix))
        .sum()
}
252
/// Asserts that `label` is a valid metric/context label.
///
/// A valid label starts with an ASCII letter, followed only by ASCII letters, digits, or
/// underscores.
///
/// # Panics
///
/// Panics (with the offending label in the message) if either rule is violated, including
/// when `label` is empty.
pub fn validate_label(label: &str) {
    let starts_with_letter = matches!(label.chars().next(), Some(first) if first.is_ascii_alphabetic());
    assert!(starts_with_letter, "label must start with [a-zA-Z]: {label}");

    let tail_is_valid = label
        .chars()
        .skip(1)
        .all(|ch| ch == '_' || ch.is_ascii_alphanumeric());
    assert!(tail_is_valid, "label must only contain [a-zA-Z0-9_]: {label}");
}
270
/// Test helper: yields to the executor five times, then returns its id `i` unchanged.
#[cfg(test)]
async fn task(i: usize) -> usize {
    let mut yields_left = 5;
    while yields_left > 0 {
        reschedule().await;
        yields_left -= 1;
    }
    i
}
278
/// Spawns `tasks` instances of [`task`] (ids `0..tasks`) on the given deterministic runner
/// and waits for all of them.
///
/// Returns the runtime auditor's state after every task finished, together with the task
/// outputs in completion order.
///
/// # Panics
///
/// Panics if any spawned task returns an error, or if the number of collected outputs does
/// not equal `tasks`.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: crate::deterministic::Runner) -> (String, Vec<usize>) {
    runner.start(|context| async move {
        let mut handles = FuturesUnordered::new();
        // `0..tasks` rather than `0..=tasks - 1`: the latter underflows when `tasks == 0`.
        for i in 0..tasks {
            handles.push(context.clone().spawn(move |_| task(i)));
        }

        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        (context.auditor().state(), outputs)
    })
}
297
#[cfg(test)]
mod tests {
    use super::*;
    use crate::deterministic;
    use commonware_macros::test_traced;
    use futures::task::waker;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// Readers share the lock, a writer is exclusive, and the `try_*` variants respect both.
    #[test_traced]
    fn test_rwlock() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            let lock = RwLock::new(100);

            // Multiple readers can hold the lock at the same time.
            let r1 = lock.read().await;
            let r2 = lock.read().await;
            assert_eq!(*r1 + *r2, 200);
            // A writer cannot get in while readers are active.
            assert!(lock.try_write().is_none());

            // Release both readers before taking the write lock.
            drop((r1, r2));
            let mut w = lock.write().await;
            *w += 1;
            assert_eq!(*w, 101);
            // A reader cannot get in while the writer is active.
            assert!(lock.try_read().is_none());
            drop(w);

            assert_eq!(lock.into_inner(), 101);
        });
    }

    /// `wait` blocks until another thread delivers a wake.
    #[test]
    fn test_blocker_waits_until_wake() {
        let blocker = Blocker::new();
        let started = Arc::new(AtomicBool::new(false));
        let completed = Arc::new(AtomicBool::new(false));

        let thread_blocker = blocker.clone();
        let thread_started = started.clone();
        let thread_completed = completed.clone();
        let handle = std::thread::spawn(move || {
            thread_started.store(true, Ordering::SeqCst);
            thread_blocker.wait();
            thread_completed.store(true, Ordering::SeqCst);
        });

        // Spin until the worker thread is about to block.
        while !started.load(Ordering::SeqCst) {
            std::thread::yield_now();
        }

        // No wake has been sent yet, so the worker cannot have finished.
        assert!(!completed.load(Ordering::SeqCst));
        waker(blocker).wake();
        handle.join().unwrap();
        assert!(completed.load(Ordering::SeqCst));
    }

    /// A wake delivered before `wait` is called is not lost.
    #[test]
    fn test_blocker_handles_pre_wake() {
        let blocker = Blocker::new();
        waker(blocker.clone()).wake();

        let completed = Arc::new(AtomicBool::new(false));
        let thread_blocker = blocker;
        let thread_completed = completed.clone();
        std::thread::spawn(move || {
            thread_blocker.wait();
            thread_completed.store(true, Ordering::SeqCst);
        })
        .join()
        .unwrap();

        assert!(completed.load(Ordering::SeqCst));
    }

    /// Each wake releases exactly one `wait`, so the same blocker can be reused.
    #[test]
    fn test_blocker_reusable_across_signals() {
        let blocker = Blocker::new();
        let completed = Arc::new(AtomicUsize::new(0));

        let thread_blocker = blocker.clone();
        let thread_completed = completed.clone();
        let handle = std::thread::spawn(move || {
            for _ in 0..2 {
                thread_blocker.wait();
                thread_completed.fetch_add(1, Ordering::SeqCst);
            }
        });

        // Send one wake at a time and wait for it to be consumed before sending the next.
        for expected in 1..=2 {
            waker(blocker.clone()).wake();
            while completed.load(Ordering::SeqCst) < expected {
                std::thread::yield_now();
            }
        }

        handle.join().unwrap();
        assert_eq!(completed.load(Ordering::SeqCst), 2);
    }

    /// `count_running_tasks` tracks spawns and aborts, matching names by prefix.
    #[test_traced]
    fn test_count_running_tasks() {
        use crate::{Metrics, Runner, Spawner};
        use futures::future;

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            assert_eq!(
                count_running_tasks(&context, "worker"),
                0,
                "no worker tasks initially"
            );

            let worker_ctx = context.with_label("worker");
            let handle1 = worker_ctx.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            let count = count_running_tasks(&context, "worker");
            assert_eq!(count, 1, "worker task should be running");

            assert_eq!(
                count_running_tasks(&context, "other"),
                0,
                "no tasks with 'other' prefix"
            );

            // A child label extends the parent's name, so it still matches the prefix.
            let handle2 = worker_ctx.with_label("child").spawn(|_| async move {
                future::pending::<()>().await;
            });

            let count = count_running_tasks(&context, "worker");
            assert_eq!(count, 2, "both worker and worker_child should be counted");

            handle1.abort();
            let _ = handle1.await;

            let count = count_running_tasks(&context, "worker");
            assert_eq!(count, 1, "only worker_child should remain");

            handle2.abort();
            let _ = handle2.await;

            assert_eq!(
                count_running_tasks(&context, "worker"),
                0,
                "all worker tasks should be stopped"
            );
        });
    }
}