commonware_runtime/utils/
mod.rs1use futures::task::ArcWake;
4use std::{
5 any::Any,
6 collections::HashSet,
7 future::Future,
8 pin::Pin,
9 sync::{Arc, Condvar, Mutex},
10 task::{Context, Poll},
11};
12
13commonware_macros::stability_mod!(BETA, pub mod buffer);
14pub mod signal;
15
16mod handle;
17pub use handle::Handle;
18#[commonware_macros::stability(ALPHA)]
19pub(crate) use handle::Panicked;
20pub(crate) use handle::{Aborter, MetricHandle, Panicker};
21
22mod cell;
23pub use cell::Cell as ContextCell;
24
25pub(crate) mod supervision;
26
/// How a spawned task is executed.
///
/// The default is `Shared(false)`.
#[derive(Copy, Clone, Debug)]
pub enum Execution {
    /// Run the task on a dedicated thread (judging by the name; confirm against the
    /// runtime's spawn implementation).
    Dedicated,
    /// Run the task on the runtime's shared executor. The flag's meaning is defined by
    /// the runtime implementations. NOTE(review): presumably it marks blocking work;
    /// confirm before relying on it.
    Shared(bool),
}

impl Default for Execution {
    /// Shared execution with the flag cleared.
    fn default() -> Self {
        Execution::Shared(false)
    }
}
42
/// Yield once to the executor, giving it a chance to poll other tasks before this one
/// resumes.
///
/// The inner future returns `Pending` on its first poll (after waking the current task
/// so it will be polled again) and `Ready` on the following poll.
pub async fn reschedule() {
    /// One-shot yield: pending on the first poll, ready afterwards.
    struct YieldNow {
        polled: bool,
    }

    impl Future for YieldNow {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            if !self.polled {
                self.polled = true;
                // Re-queue this task before returning `Pending`; otherwise nothing
                // would ever poll it again.
                cx.waker().wake_by_ref();
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    YieldNow { polled: false }.await
}
65
/// Best-effort conversion of a panic payload into a readable message.
///
/// `panic!` produces either a `&'static str` (literal message) or a `String`
/// (formatted message) payload. Both are returned verbatim. Any other payload type
/// falls back to its `Debug` representation.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    if let Some(literal) = err.downcast_ref::<&str>() {
        return (*literal).to_string();
    }
    if let Some(formatted) = err.downcast_ref::<String>() {
        return formatted.clone();
    }
    format!("{err:?}")
}
75
76pub struct RwLock<T>(async_lock::RwLock<T>);
102
103pub type RwLockReadGuard<'a, T> = async_lock::RwLockReadGuard<'a, T>;
105
106pub type RwLockWriteGuard<'a, T> = async_lock::RwLockWriteGuard<'a, T>;
108
109impl<T> RwLock<T> {
110 #[inline]
112 pub const fn new(value: T) -> Self {
113 Self(async_lock::RwLock::new(value))
114 }
115
116 #[inline]
118 pub async fn read(&self) -> RwLockReadGuard<'_, T> {
119 self.0.read().await
120 }
121
122 #[inline]
124 pub async fn write(&self) -> RwLockWriteGuard<'_, T> {
125 self.0.write().await
126 }
127
128 #[inline]
130 pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
131 self.0.try_read()
132 }
133
134 #[inline]
136 pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
137 self.0.try_write()
138 }
139
140 #[inline]
142 pub fn get_mut(&mut self) -> &mut T {
143 self.0.get_mut()
144 }
145
146 #[inline]
148 pub fn into_inner(self) -> T {
149 self.0.into_inner()
150 }
151}
152
/// Lets a thread block until it is woken (see its `ArcWake` implementation).
///
/// A wake that arrives before [`Blocker::wait`] is remembered, so `wait` returns
/// immediately. Each `wait` consumes the pending signal. Because the signal is a single
/// flag, several wakes before one `wait` collapse into one.
pub struct Blocker {
    /// `true` while a wake is pending that no `wait` has consumed yet.
    state: Mutex<bool>,
    /// Notified on wake to rouse a thread blocked in `wait`.
    cv: Condvar,
}

impl Blocker {
    /// Create a blocker with no pending signal, ready to be shared across threads.
    pub fn new() -> Arc<Self> {
        let blocker = Self {
            cv: Condvar::new(),
            state: Mutex::new(false),
        };
        Arc::new(blocker)
    }

    /// Block the current thread until a signal is pending, then consume it.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn wait(&self) {
        let guard = self.state.lock().unwrap();
        // `wait_while` re-checks the flag after every wakeup, absorbing spurious ones.
        let mut pending = self.cv.wait_while(guard, |signaled| !*signaled).unwrap();
        *pending = false;
    }
}
182
183impl ArcWake for Blocker {
184 fn wake_by_ref(arc_self: &Arc<Self>) {
185 {
187 let mut signaled = arc_self.state.lock().unwrap();
188 *signaled = true;
189 }
190
191 arc_self.cv.notify_one();
193 }
194}
195
/// Count the running tasks whose name starts with `prefix`.
///
/// Encodes `metrics` and scans the `runtime_tasks_running{...}` lines with
/// `kind="Task"`. For every line whose `name` label starts with `prefix`, the gauge value
/// at the end of the line is added to the total.
///
/// Values are summed rather than lines counted. Tasks spawned under the same label share
/// one gauge line, so two concurrently running same-named tasks appear as a single line
/// with value `2`. Counting only lines equal to `1` would report such tasks as stopped.
/// Lines whose value does not parse as a non-negative integer contribute nothing.
#[cfg(any(test, feature = "test-utils"))]
pub fn count_running_tasks(metrics: &impl crate::Metrics, prefix: &str) -> usize {
    let encoded = metrics.encode();
    encoded
        .lines()
        .filter(|line| {
            line.starts_with("runtime_tasks_running{") && line.contains("kind=\"Task\"")
        })
        .filter(|line| {
            line.split_once("name=\"")
                .map(|(_, rest)| rest.split('"').next().unwrap_or(""))
                .is_some_and(|name| name.starts_with(prefix))
        })
        .filter_map(|line| {
            // The sample value is the last space-separated token on the line.
            let (_, value) = line.trim_end().rsplit_once(' ')?;
            value.parse::<usize>().ok()
        })
        .sum()
}
251
/// Check that `label` is a valid metric/context label.
///
/// A valid label starts with an ASCII letter and continues with only ASCII letters,
/// digits, or underscores.
///
/// # Panics
///
/// Panics if `label` is empty, does not start with `[a-zA-Z]`, or contains any character
/// outside `[a-zA-Z0-9_]`.
pub fn validate_label(label: &str) {
    let head = label.chars().next();
    assert!(
        matches!(head, Some(c) if c.is_ascii_alphabetic()),
        "label must start with [a-zA-Z]: {label}"
    );

    let tail_ok = label
        .chars()
        .skip(1)
        .all(|c| c == '_' || c.is_ascii_alphanumeric());
    assert!(tail_ok, "label must only contain [a-zA-Z0-9_]: {label}");
}
269
/// Insert or overwrite the attribute `key` in `attributes`.
///
/// `attributes` must already be sorted by key in ascending order. The lookup is a
/// binary search, and new entries are inserted at the position that keeps the vector
/// sorted.
///
/// Returns `true` if a new attribute was inserted and `false` if an existing
/// attribute's value was replaced.
pub fn add_attribute(
    attributes: &mut Vec<(String, String)>,
    key: &str,
    value: impl std::fmt::Display,
) -> bool {
    let value = value.to_string();
    // Compare as `&str` so that looking up (and updating) a key does not allocate a copy
    // of it; the key is only allocated when a new entry is inserted.
    match attributes.binary_search_by(|(existing, _)| existing.as_str().cmp(key)) {
        Ok(index) => {
            attributes[index].1 = value;
            false
        }
        Err(index) => {
            attributes.insert(index, (key.to_owned(), value));
            true
        }
    }
}
292
/// A [`std::fmt::Write`] sink that removes repeated metric metadata from encoded
/// metrics, such as when one metric family is encoded more than once.
///
/// Input is split on `\n` and processed one line at a time. A `# HELP`, `# TYPE`, or
/// `# UNIT` line is emitted only the first time its metric name (the first
/// whitespace-separated token after the directive) is seen for that directive. Later
/// duplicates are dropped. Every other line is passed through unchanged and in order.
///
/// Text after the last `\n` is buffered until [`MetricEncoder::into_string`], which
/// flushes it with a trailing newline.
pub struct MetricEncoder {
    /// Deduplicated output produced so far.
    output: String,
    /// The current line, not yet terminated by `\n`.
    line_buffer: String,
    /// Metric names whose `# HELP` line has already been emitted.
    seen_help: HashSet<String>,
    /// Metric names whose `# TYPE` line has already been emitted.
    seen_type: HashSet<String>,
    /// Metric names whose `# UNIT` line has already been emitted.
    seen_unit: HashSet<String>,
}

impl MetricEncoder {
    /// Create an empty encoder.
    pub fn new() -> Self {
        Self {
            output: String::new(),
            line_buffer: String::new(),
            seen_help: HashSet::new(),
            seen_type: HashSet::new(),
            seen_unit: HashSet::new(),
        }
    }

    /// Finish encoding and return the deduplicated text.
    ///
    /// Any buffered partial line is flushed first, with a trailing newline.
    pub fn into_string(mut self) -> String {
        if !self.line_buffer.is_empty() {
            self.flush_line();
        }
        self.output
    }

    /// Record a metadata line for the metric named at the start of `rest`.
    ///
    /// Returns `true` on the first sighting of that name in `seen` (the line should be
    /// written) and `false` for repeats. Only allocates for names not yet seen.
    fn first_sighting(seen: &mut HashSet<String>, rest: &str) -> bool {
        let name = rest.split_whitespace().next().unwrap_or("");
        if seen.contains(name) {
            false
        } else {
            seen.insert(name.to_owned())
        }
    }

    /// Emit the buffered line (unless it is duplicate metadata) and clear the buffer.
    fn flush_line(&mut self) {
        let line = &self.line_buffer;
        let should_write = if let Some(rest) = line.strip_prefix("# HELP ") {
            Self::first_sighting(&mut self.seen_help, rest)
        } else if let Some(rest) = line.strip_prefix("# TYPE ") {
            Self::first_sighting(&mut self.seen_type, rest)
        } else if let Some(rest) = line.strip_prefix("# UNIT ") {
            Self::first_sighting(&mut self.seen_unit, rest)
        } else {
            true
        };
        if should_write {
            self.output.push_str(line);
            self.output.push('\n');
        }
        self.line_buffer.clear();
    }
}

impl Default for MetricEncoder {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Write for MetricEncoder {
    /// Buffer `s`, flushing each completed (`\n`-terminated) line. Never fails.
    fn write_str(&mut self, s: &str) -> std::fmt::Result {
        let mut remaining = s;
        while let Some((line, rest)) = remaining.split_once('\n') {
            self.line_buffer.push_str(line);
            self.flush_line();
            remaining = rest;
        }
        // Keep the unterminated tail for the next write (or `into_string`).
        self.line_buffer.push_str(remaining);
        Ok(())
    }
}
362
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{deterministic, Metrics, Runner};
    use commonware_macros::test_traced;
    use futures::task::waker;
    use prometheus_client::metrics::counter::Counter;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// Run `input` through a fresh [`MetricEncoder`] in one write and return the result.
    fn encode_dedup(input: &str) -> String {
        use std::fmt::Write;
        let mut sink = MetricEncoder::default();
        sink.write_str(input)
            .expect("writing to MetricEncoder never fails");
        sink.into_string()
    }

    /// Assert that `input` has no duplicate metadata and is therefore emitted unchanged.
    fn assert_passthrough(input: &str) {
        assert_eq!(encode_dedup(input), input);
    }

    #[test]
    fn test_metric_encoder_empty() {
        for &(input, expected) in &[("", ""), ("# EOF\n", "# EOF\n")] {
            assert_eq!(encode_dedup(input), expected);
        }
    }

    #[test]
    fn test_metric_encoder_no_duplicates() {
        assert_passthrough(
            r#"# HELP foo_total A counter.
# TYPE foo_total counter
foo_total 1
# HELP bar_gauge A gauge.
# TYPE bar_gauge gauge
bar_gauge 42
# EOF
"#,
        );
    }

    #[test]
    fn test_metric_encoder_with_duplicates() {
        let input = r#"# HELP votes_total vote count.
# TYPE votes_total counter
votes_total{epoch="e5"} 1
# HELP votes_total vote count.
# TYPE votes_total counter
votes_total{epoch="e6"} 2
# EOF
"#;
        // The second HELP/TYPE pair is dropped; both samples survive.
        let expected = r#"# HELP votes_total vote count.
# TYPE votes_total counter
votes_total{epoch="e5"} 1
votes_total{epoch="e6"} 2
# EOF
"#;
        assert_eq!(encode_dedup(input), expected);
    }

    #[test]
    fn test_metric_encoder_multiple_metrics() {
        let input = r#"# HELP a_total First.
# TYPE a_total counter
a_total{tag="x"} 1
# HELP b_total Second.
# TYPE b_total counter
b_total 5
# HELP a_total First.
# TYPE a_total counter
a_total{tag="y"} 2
# EOF
"#;
        // Only the repeated `a_total` metadata is removed; line order is kept.
        let expected = r#"# HELP a_total First.
# TYPE a_total counter
a_total{tag="x"} 1
# HELP b_total Second.
# TYPE b_total counter
b_total 5
a_total{tag="y"} 2
# EOF
"#;
        assert_eq!(encode_dedup(input), expected);
    }

    #[test]
    fn test_metric_encoder_preserves_order() {
        // The encoder must not sort families by name.
        assert_passthrough(
            r#"# HELP z First alphabetically last.
# TYPE z counter
z_total 1
# HELP a Last alphabetically first.
# TYPE a counter
a_total 2
# EOF
"#,
        );
    }

    #[test_traced]
    fn test_rwlock() {
        let runner = deterministic::Runner::default();
        runner.start(|_| async move {
            let lock = RwLock::new(100);

            // Two readers may hold the lock at the same time.
            let (first, second) = (lock.read().await, lock.read().await);
            assert_eq!(*first + *second, 200);

            // Once both readers are released, a writer gets exclusive access.
            drop(first);
            drop(second);
            let mut writer = lock.write().await;
            *writer += 1;
            assert_eq!(*writer, 101);
        });
    }

    #[test]
    fn test_blocker_waits_until_wake() {
        let blocker = Blocker::new();
        let started = Arc::new(AtomicBool::new(false));
        let completed = Arc::new(AtomicBool::new(false));

        let handle = {
            let blocker = Arc::clone(&blocker);
            let started = Arc::clone(&started);
            let completed = Arc::clone(&completed);
            std::thread::spawn(move || {
                started.store(true, Ordering::SeqCst);
                blocker.wait();
                completed.store(true, Ordering::SeqCst);
            })
        };

        // Spin until the waiter thread is about to block.
        while !started.load(Ordering::SeqCst) {
            std::thread::yield_now();
        }

        // No wake has been sent yet, so the waiter cannot have finished.
        assert!(!completed.load(Ordering::SeqCst));
        waker(blocker).wake();
        handle.join().unwrap();
        assert!(completed.load(Ordering::SeqCst));
    }

    #[test]
    fn test_blocker_handles_pre_wake() {
        let blocker = Blocker::new();
        // Signal before anyone waits; the signal must be remembered.
        waker(Arc::clone(&blocker)).wake();

        let completed = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&completed);
        let waiter = std::thread::spawn(move || {
            blocker.wait();
            flag.store(true, Ordering::SeqCst);
        });
        waiter.join().unwrap();

        assert!(completed.load(Ordering::SeqCst));
    }

    #[test]
    fn test_blocker_reusable_across_signals() {
        const ROUNDS: usize = 2;
        let blocker = Blocker::new();
        let completed = Arc::new(AtomicUsize::new(0));

        let handle = {
            let blocker = Arc::clone(&blocker);
            let completed = Arc::clone(&completed);
            std::thread::spawn(move || {
                (0..ROUNDS).for_each(|_| {
                    blocker.wait();
                    completed.fetch_add(1, Ordering::SeqCst);
                });
            })
        };

        for round in 1..=ROUNDS {
            waker(Arc::clone(&blocker)).wake();
            // Let the waiter consume this signal before sending the next one, so the
            // two wakes are not collapsed into a single pending flag.
            while completed.load(Ordering::SeqCst) < round {
                std::thread::yield_now();
            }
        }

        handle.join().unwrap();
        assert_eq!(completed.load(Ordering::SeqCst), ROUNDS);
    }

    #[test_traced]
    fn test_count_running_tasks() {
        use crate::{Metrics, Runner, Spawner};
        use futures::future;

        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            assert_eq!(
                count_running_tasks(&context, "worker"),
                0,
                "no worker tasks initially"
            );

            let worker = context.with_label("worker");
            let parent = worker.clone().spawn(|_| async move {
                future::pending::<()>().await;
            });
            assert_eq!(
                count_running_tasks(&context, "worker"),
                1,
                "worker task should be running"
            );
            assert_eq!(
                count_running_tasks(&context, "other"),
                0,
                "no tasks with 'other' prefix"
            );

            let child = worker.with_label("child").spawn(|_| async move {
                future::pending::<()>().await;
            });
            assert_eq!(
                count_running_tasks(&context, "worker"),
                2,
                "both worker and worker_child should be counted"
            );

            parent.abort();
            let _ = parent.await;
            assert_eq!(
                count_running_tasks(&context, "worker"),
                1,
                "only worker_child should remain"
            );

            child.abort();
            let _ = child.await;
            assert_eq!(
                count_running_tasks(&context, "worker"),
                0,
                "all worker tasks should be stopped"
            );
        });
    }

    #[test_traced]
    fn test_no_duplicate_metrics() {
        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            // The same metric name under different labels must not collide.
            for label in ["a", "b"] {
                context
                    .with_label(label)
                    .register("test", "help", Counter::<u64>::default());
            }
        });
    }

    #[test]
    #[should_panic(expected = "duplicate metric:")]
    fn test_duplicate_metrics_panics() {
        let runner = deterministic::Runner::default();
        runner.start(|context| async move {
            // Registering the same name twice under the same label must panic.
            for _ in 0..2 {
                context
                    .with_label("a")
                    .register("test", "help", Counter::<u64>::default());
            }
        });
    }
}