Skip to main content

commonware_runtime/utils/
mod.rs

1//! Utility functions for interacting with any runtime.
2
3use futures::task::ArcWake;
4use std::{
5    any::Any,
6    collections::HashSet,
7    future::Future,
8    pin::Pin,
9    sync::{Arc, Condvar, Mutex},
10    task::{Context, Poll},
11};
12
13commonware_macros::stability_mod!(BETA, pub mod buffer);
14pub mod signal;
15
16mod handle;
17pub use handle::Handle;
18#[commonware_macros::stability(ALPHA)]
19pub(crate) use handle::Panicked;
20pub(crate) use handle::{Aborter, MetricHandle, Panicker};
21
22mod cell;
23pub use cell::Cell as ContextCell;
24
25pub(crate) mod supervision;
26
/// How the runtime should execute a spawned task.
#[derive(Copy, Clone, Debug)]
pub enum Execution {
    /// Run the task on a thread dedicated to it.
    Dedicated,
    /// Run the task on the shared executor. A payload of `true` marks short blocking
    /// work that should use the runtime's blocking-friendly pool.
    Shared(bool),
}

/// Unless told otherwise, tasks run on the shared executor and are not marked as
/// blocking work.
impl Default for Execution {
    fn default() -> Self {
        Execution::Shared(false)
    }
}
42
/// Yield control back to the runtime.
///
/// The awaited future returns `Pending` the first time it is polled (after waking its
/// own waker) and `Ready` on the following poll.
pub async fn reschedule() {
    /// One-shot future that remembers whether its first poll is still outstanding.
    struct YieldOnce {
        first_poll: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
            // Every poll after the first completes immediately.
            if !self.first_poll {
                return Poll::Ready(());
            }

            // First poll: record it, wake ourselves, and hand control back.
            self.first_poll = false;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    YieldOnce { first_poll: true }.await
}
65
/// Render a panic payload as a message string.
///
/// A payload that is a `&str` or a `String` is returned as text; any other payload
/// falls back to the `Debug` rendering of the payload reference.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    // Try the string-literal payload first.
    if let Some(text) = err.downcast_ref::<&str>() {
        return (*text).to_string();
    }

    // Then an owned, formatted message; otherwise use the debug representation.
    match err.downcast_ref::<String>() {
        Some(text) => text.clone(),
        None => format!("{err:?}"),
    }
}
75
76/// Async reader–writer lock.
77///
78/// Powered by [async_lock::RwLock], `RwLock` provides both fair writer acquisition
79/// and `try_read` / `try_write` without waiting (without any runtime-specific dependencies).
80///
81/// Usage:
82/// ```rust
83/// use commonware_runtime::{Spawner, Runner, deterministic, RwLock};
84///
85/// let executor = deterministic::Runner::default();
86/// executor.start(|context| async move {
87///     // Create a new RwLock
88///     let lock = RwLock::new(2);
89///
90///     // many concurrent readers
91///     let r1 = lock.read().await;
92///     let r2 = lock.read().await;
93///     assert_eq!(*r1 + *r2, 4);
94///
95///     // exclusive writer
96///     drop((r1, r2));
97///     let mut w = lock.write().await;
98///     *w += 1;
99/// });
100/// ```
101pub struct RwLock<T>(async_lock::RwLock<T>);
102
103/// Shared guard returned by [RwLock::read].
104pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>;
105
106/// Exclusive guard returned by [RwLock::write].
107pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>;
108
109impl<T> RwLock<T> {
110    /// Create a new lock.
111    #[inline]
112    pub const fn new(value: T) -> Self {
113        Self(async_lock::RwLock::new(value))
114    }
115
116    /// Acquire a shared read guard.
117    #[inline]
118    pub async fn read(&self) -> RwLockReadGuard<'_, T> {
119        self.0.read().await
120    }
121
122    /// Acquire an exclusive write guard.
123    #[inline]
124    pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
125        self.0.write().await
126    }
127
128    /// Try to get a read guard without waiting.
129    #[inline]
130    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
131        self.0.try_read()
132    }
133
134    /// Try to get a write guard without waiting.
135    #[inline]
136    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
137        self.0.try_write()
138    }
139
140    /// Get mutable access without locking (requires `&mut self`).
141    #[inline]
142    pub fn get_mut(&mut self) -> &mut T {
143        self.0.get_mut()
144    }
145
146    /// Consume the lock, returning the inner value.
147    #[inline]
148    pub fn into_inner(self) -> T {
149        self.0.into_inner()
150    }
151}
152
/// Synchronization primitive that lets a thread block until a waker delivers a signal.
pub struct Blocker {
    /// `true` once a wake-up signal has been delivered, even if no thread is waiting yet.
    state: Mutex<bool>,
    /// Condvar on which the waiting thread parks until the flag becomes `true`.
    cv: Condvar,
}

impl Blocker {
    /// Create a new, unsignaled [Blocker] behind an [Arc].
    pub fn new() -> Arc<Self> {
        let blocker = Self {
            state: Mutex::new(false),
            cv: Condvar::new(),
        };
        Arc::new(blocker)
    }

    /// Block the current thread until a waker delivers a signal.
    ///
    /// Returns immediately if a signal was delivered before the call. The signal is
    /// consumed, so the next call blocks until another signal arrives.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn wait(&self) {
        // `wait_while` re-checks the flag after every wake-up, so spurious wake-ups
        // never let the caller through without a real signal.
        let guard = self.state.lock().unwrap();
        let mut signaled = self
            .cv
            .wait_while(guard, |delivered| !*delivered)
            .unwrap();

        // Consume the signal so subsequent waits park until the next wake.
        *signaled = false;
    }
}
182
183impl ArcWake for Blocker {
184    fn wake_by_ref(arc_self: &Arc<Self>) {
185        // Mark as signaled (and release lock before notifying).
186        {
187            let mut signaled = arc_self.state.lock().unwrap();
188            *signaled = true;
189        }
190
191        // Notify a single waiter so the blocked thread re-checks the flag.
192        arc_self.cv.notify_one();
193    }
194}
195
#[cfg(any(test, feature = "test-utils"))]
/// Count the running tasks whose name starts with `prefix`.
///
/// The metrics are encoded to text and scanned line by line. A line is counted when it
/// is a `runtime_tasks_running` sample with `kind="Task"`, its value is `1`, and the
/// value of its `name` label starts with `prefix`.
///
/// This is useful for verifying that all child tasks under a given label hierarchy
/// have been properly shut down.
///
/// # Example
///
/// ```rust
/// use commonware_runtime::{Clock, Metrics, Runner, Spawner, deterministic};
/// use commonware_runtime::utils::count_running_tasks;
/// use std::time::Duration;
///
/// let executor = deterministic::Runner::default();
/// executor.start(|context| async move {
///     // Spawn a task under a labeled context
///     let handle = context.with_label("worker").spawn(|ctx| async move {
///         ctx.sleep(Duration::from_secs(100)).await;
///     });
///
///     // Allow the task to start
///     context.sleep(Duration::from_millis(10)).await;
///
///     // Count running tasks with "worker" prefix
///     let count = count_running_tasks(&context, "worker");
///     assert!(count > 0, "worker task should be running");
///
///     // Abort the task
///     handle.abort();
///     let _ = handle.await;
///     context.sleep(Duration::from_millis(10)).await;
///
///     // Verify task is stopped
///     let count = count_running_tasks(&context, "worker");
///     assert_eq!(count, 0, "worker task should be stopped");
/// });
/// ```
pub fn count_running_tasks(metrics: &impl crate::Metrics, prefix: &str) -> usize {
    let encoded = metrics.encode();
    let mut running = 0;
    for line in encoded.lines() {
        // Only samples of the running-tasks metric that describe a task are relevant.
        if !line.starts_with("runtime_tasks_running{") || !line.contains("kind=\"Task\"") {
            continue;
        }

        // A value of 1 means the task is currently running.
        if !line.trim_end().ends_with(" 1") {
            continue;
        }

        // Pull out the text between `name="` and the next quote.
        let Some(after_name) = line.split("name=\"").nth(1) else {
            continue;
        };
        let task_name = after_name.split('"').next().unwrap_or("");
        if task_name.starts_with(prefix) {
            running += 1;
        }
    }
    running
}
251
/// Checks that a label follows the Prometheus metric name format `[a-zA-Z][a-zA-Z0-9_]*`.
///
/// # Panics
///
/// Panics if the label is empty or does not begin with an ASCII letter, or if any
/// later character is outside `[a-zA-Z0-9_]`.
pub fn validate_label(label: &str) {
    // Every accepted character is ASCII, so checking bytes is equivalent to checking
    // chars: any non-ASCII character contributes bytes that fail both predicates.
    let bytes = label.as_bytes();
    assert!(
        bytes.first().is_some_and(u8::is_ascii_alphabetic),
        "label must start with [a-zA-Z]: {label}"
    );
    assert!(
        bytes
            .iter()
            .skip(1)
            .all(|byte| byte.is_ascii_alphanumeric() || *byte == b'_'),
        "label must only contain [a-zA-Z0-9_]: {label}"
    );
}
269
/// Add an attribute to a sorted attribute list, maintaining sorted order via binary search.
///
/// `attributes` must already be sorted by key. If `key` is present its value is
/// overwritten in place; otherwise the pair is inserted at the position that keeps the
/// list sorted.
///
/// Returns `true` if the key was new, `false` if it was a duplicate (value overwritten).
pub fn add_attribute(
    attributes: &mut Vec<(String, String)>,
    key: &str,
    value: impl std::fmt::Display,
) -> bool {
    let value = value.to_string();

    // Compare as `&str` so an owned key is only allocated when it is actually inserted.
    match attributes.binary_search_by(|(existing, _)| existing.as_str().cmp(key)) {
        Ok(pos) => {
            attributes[pos].1 = value;
            false
        }
        Err(pos) => {
            attributes.insert(pos, (key.to_owned(), value));
            true
        }
    }
}
292
/// A writer that deduplicates HELP and TYPE metadata lines during Prometheus encoding.
///
/// When the same metric is registered multiple times with different attribute values
/// (via `sub_registry_with_label`), prometheus_client outputs duplicate HELP/TYPE
/// lines. This writer filters them in a single pass to produce canonical Prometheus format.
///
/// Uses "first wins" semantics: the first HELP/TYPE line seen for a metric name is kept
/// and later ones for the same name are dropped.
#[derive(Default)]
pub struct MetricEncoder {
    /// Completed lines that passed the filter, each terminated by `\n`.
    output: String,
    /// Text of the line currently being assembled (no newline seen yet).
    line_buffer: String,
    /// Metric names for which a HELP line has already been written.
    seen_help: HashSet<String>,
    /// Metric names for which a TYPE line has already been written.
    seen_type: HashSet<String>,
}

impl MetricEncoder {
    /// Create an empty encoder.
    pub fn new() -> Self {
        Self::default()
    }

    /// Finish encoding and return the filtered text.
    ///
    /// A trailing partial line (one not yet terminated by a newline) is filtered like
    /// any other line and, if kept, written followed by a newline.
    pub fn into_string(mut self) -> String {
        if !self.line_buffer.is_empty() {
            self.flush_line();
        }
        self.output
    }

    /// Filter the buffered line, append it to the output if it is kept, and reset the
    /// buffer.
    fn flush_line(&mut self) {
        let line = self.line_buffer.as_str();

        // Pick the set that tracks this kind of metadata line, if it is one.
        let metadata = if let Some(rest) = line.strip_prefix("# HELP ") {
            Some((&mut self.seen_help, rest))
        } else if let Some(rest) = line.strip_prefix("# TYPE ") {
            Some((&mut self.seen_type, rest))
        } else {
            None
        };

        // `insert` returns `false` when the metric name was already recorded, which is
        // exactly when the line is a duplicate.
        let keep = match metadata {
            Some((seen, rest)) => {
                let metric_name = rest.split_whitespace().next().unwrap_or_default();
                seen.insert(metric_name.to_owned())
            }
            None => true,
        };

        if keep {
            self.output.push_str(line);
            self.output.push('\n');
        }
        self.line_buffer.clear();
    }
}

impl std::fmt::Write for MetricEncoder {
    /// Buffer `s`, flushing every line completed by a `\n`.
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        let mut pending = s;
        while let Some((line_end, after)) = pending.split_once('\n') {
            self.line_buffer.push_str(line_end);
            self.flush_line();
            pending = after;
        }

        // Whatever follows the last newline stays buffered until the line is finished.
        self.line_buffer.push_str(pending);
        Ok(())
    }
}
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, Metrics, Runner};
    use commonware_macros::test_traced;
    use futures::task::waker;
    use prometheus_client::metrics::counter::Counter;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// Feed `input` through a fresh [MetricEncoder] in one write and return the result.
    fn encode_dedup(input: &str) -> String {
        use std::fmt::Write;
        let mut encoder = MetricEncoder::new();
        encoder.write_str(input).unwrap();
        encoder.into_string()
    }

    // Empty input stays empty and a lone non-metadata line passes through unchanged.
    #[test]
    fn test_metric_encoder_empty() {
        assert_eq!(encode_dedup(""), "");
        assert_eq!(encode_dedup("# EOF\n"), "# EOF\n");
    }

    // Input without repeated HELP/TYPE lines must be emitted unchanged.
    #[test]
    fn test_metric_encoder_no_duplicates() {
        let input = r#"# HELP foo_total A counter.
# TYPE foo_total counter
foo_total 1
# HELP bar_gauge A gauge.
# TYPE bar_gauge gauge
bar_gauge 42
# EOF
"#;
        let output = encode_dedup(input);
        assert_eq!(output, input);
    }

    // The second HELP/TYPE pair for the same metric is dropped; samples are kept.
    #[test]
    fn test_metric_encoder_with_duplicates() {
        let input = r#"# HELP votes_total vote count.
# TYPE votes_total counter
votes_total{epoch="e5"} 1
# HELP votes_total vote count.
# TYPE votes_total counter
votes_total{epoch="e6"} 2
# EOF
"#;
        let expected = r#"# HELP votes_total vote count.
# TYPE votes_total counter
votes_total{epoch="e5"} 1
votes_total{epoch="e6"} 2
# EOF
"#;
        let output = encode_dedup(input);
        assert_eq!(output, expected);
    }

    // Duplicates are dropped even when another metric's lines sit between them.
    #[test]
    fn test_metric_encoder_multiple_metrics() {
        let input = r#"# HELP a_total First.
# TYPE a_total counter
a_total{tag="x"} 1
# HELP b_total Second.
# TYPE b_total counter
b_total 5
# HELP a_total First.
# TYPE a_total counter
a_total{tag="y"} 2
# EOF
"#;
        let expected = r#"# HELP a_total First.
# TYPE a_total counter
a_total{tag="x"} 1
# HELP b_total Second.
# TYPE b_total counter
b_total 5
a_total{tag="y"} 2
# EOF
"#;
        let output = encode_dedup(input);
        assert_eq!(output, expected);
    }

    // Lines come out in input order; the encoder does not sort by metric name.
    #[test]
    fn test_metric_encoder_preserves_order() {
        let input = r#"# HELP z First alphabetically last.
# TYPE z counter
z_total 1
# HELP a Last alphabetically first.
# TYPE a counter
a_total 2
# EOF
"#;
        let output = encode_dedup(input);
        assert_eq!(output, input);
    }

    // Two read guards can be held together; a write guard is taken once they are dropped.
    #[test_traced]
    fn test_rwlock() {
        let executor = deterministic::Runner::default();
        executor.start(|_| async move {
            // Create a new RwLock
            let lock = RwLock::new(100);

            // many concurrent readers
            let r1 = lock.read().await;
            let r2 = lock.read().await;
            assert_eq!(*r1 + *r2, 200);

            // exclusive writer
            drop((r1, r2)); // all readers must go away
            let mut w = lock.write().await;
            *w += 1;

            // Check the value
            assert_eq!(*w, 101);
        });
    }

    // A thread parked in `wait` must not complete until the waker fires.
    #[test]
    fn test_blocker_waits_until_wake() {
        let blocker = Blocker::new();
        let started = Arc::new(AtomicBool::new(false));
        let completed = Arc::new(AtomicBool::new(false));

        let thread_blocker = blocker.clone();
        let thread_started = started.clone();
        let thread_completed = completed.clone();
        let handle = std::thread::spawn(move || {
            thread_started.store(true, Ordering::SeqCst);
            thread_blocker.wait();
            thread_completed.store(true, Ordering::SeqCst);
        });

        // Spin until the worker thread has at least started running.
        while !started.load(Ordering::SeqCst) {
            std::thread::yield_now();
        }

        // No signal has been sent yet, so the worker cannot have passed `wait`.
        assert!(!completed.load(Ordering::SeqCst));
        waker(blocker).wake();
        handle.join().unwrap();
        assert!(completed.load(Ordering::SeqCst));
    }

    // A wake delivered before `wait` is called must not be lost.
    #[test]
    fn test_blocker_handles_pre_wake() {
        let blocker = Blocker::new();
        waker(blocker.clone()).wake();

        let completed = Arc::new(AtomicBool::new(false));
        let thread_blocker = blocker;
        let thread_completed = completed.clone();
        std::thread::spawn(move || {
            thread_blocker.wait();
            thread_completed.store(true, Ordering::SeqCst);
        })
        .join()
        .unwrap();

        assert!(completed.load(Ordering::SeqCst));
    }

    // The same blocker can be waited on and woken more than once.
    #[test]
    fn test_blocker_reusable_across_signals() {
        let blocker = Blocker::new();
        let completed = Arc::new(AtomicUsize::new(0));

        let thread_blocker = blocker.clone();
        let thread_completed = completed.clone();
        let handle = std::thread::spawn(move || {
            for _ in 0..2 {
                thread_blocker.wait();
                thread_completed.fetch_add(1, Ordering::SeqCst);
            }
        });

        // Send one signal per round and wait for the worker to acknowledge it before
        // sending the next.
        for expected in 1..=2 {
            waker(blocker.clone()).wake();
            while completed.load(Ordering::SeqCst) < expected {
                std::thread::yield_now();
            }
        }

        handle.join().unwrap();
        assert_eq!(completed.load(Ordering::SeqCst), 2);
    }

    // Counts follow tasks as they are spawned and aborted under the "worker" label.
    #[test_traced]
    fn test_count_running_tasks() {
        use crate::{Metrics, Runner, Spawner};
        use futures::future;

        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Initially no tasks with "worker" prefix
            assert_eq!(
                count_running_tasks(&context, "worker"),
                0,
                "no worker tasks initially"
            );

            // Spawn a task under a labeled context that stays running
            let worker_ctx = context.with_label("worker");
            let handle1 = worker_ctx.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });

            // Count running tasks with "worker" prefix
            let count = count_running_tasks(&context, "worker");
            assert_eq!(count, 1, "worker task should be running");

            // Non-matching prefix should return 0
            assert_eq!(
                count_running_tasks(&context, "other"),
                0,
                "no tasks with 'other' prefix"
            );

            // Spawn a nested task (worker_child)
            let handle2 = worker_ctx.with_label("child").spawn(|_| async move {
                future::pending::<()>().await;
            });

            // Count should include both parent and nested tasks
            let count = count_running_tasks(&context, "worker");
            assert_eq!(count, 2, "both worker and worker_child should be counted");

            // Abort parent task
            handle1.abort();
            let _ = handle1.await;

            // Only nested task remains
            let count = count_running_tasks(&context, "worker");
            assert_eq!(count, 1, "only worker_child should remain");

            // Abort nested task
            handle2.abort();
            let _ = handle2.await;

            // All tasks stopped
            assert_eq!(
                count_running_tasks(&context, "worker"),
                0,
                "all worker tasks should be stopped"
            );
        });
    }

    // Registering the same metric name under different labels must not panic.
    #[test_traced]
    fn test_no_duplicate_metrics() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Register metrics under different labels (no duplicates)
            let c1 = Counter::<u64>::default();
            context.with_label("a").register("test", "help", c1);
            let c2 = Counter::<u64>::default();
            context.with_label("b").register("test", "help", c2);
        });
        // Test passes if runtime doesn't panic on shutdown
    }

    // Registering the same metric name twice under one label is expected to panic.
    #[test]
    #[should_panic(expected = "duplicate metric:")]
    fn test_duplicate_metrics_panics() {
        let executor = deterministic::Runner::default();
        executor.start(|context| async move {
            // Register metrics with the same label, causing duplicates
            let c1 = Counter::<u64>::default();
            context.with_label("a").register("test", "help", c1);
            let c2 = Counter::<u64>::default();
            context.with_label("a").register("test", "help", c2);
        });
    }
}
635}