Skip to main content

dag_executor/
context.rs

//! Shared execution context passed to every task.
//!
//! A [`Context`] is created once per run and shared (via `Arc`) across all
//! tasks. It provides three things tasks commonly need:
//!
//! * a key/value **blackboard** of JSON values (with typed reads) for passing
//!   data between tasks,
//! * **cooperative cancellation** (e.g. for graceful shutdown),
//! * a lightweight **event bus** used by [`crate::tasks::EventDrivenTask`].

10use crate::utils::Config;
11use parking_lot::RwLock;
12use std::collections::HashMap;
13use std::sync::atomic::{AtomicBool, Ordering};
14use std::sync::Arc;
15use tokio::sync::Notify;
16
17/// A cooperative cancellation token.
18///
19/// Cloning yields another handle to the *same* underlying flag, so cancelling
20/// any clone cancels them all. Tasks should poll [`CancelToken::is_cancelled`]
21/// at safe points or `await` [`CancelToken::cancelled`].
22#[derive(Clone, Default)]
23pub struct CancelToken {
24    inner: Arc<CancelInner>,
25}
26
27#[derive(Default)]
28struct CancelInner {
29    cancelled: AtomicBool,
30    notify: Notify,
31}
32
33impl CancelToken {
34    /// Create a fresh, un-cancelled token.
35    pub fn new() -> Self {
36        Self::default()
37    }
38
39    /// Signal cancellation and wake all waiters.
40    pub fn cancel(&self) {
41        self.inner.cancelled.store(true, Ordering::SeqCst);
42        self.inner.notify.notify_waiters();
43    }
44
45    /// Whether cancellation has been requested.
46    pub fn is_cancelled(&self) -> bool {
47        self.inner.cancelled.load(Ordering::SeqCst)
48    }
49
50    /// Resolve once cancellation is requested. Returns immediately if already
51    /// cancelled.
52    pub async fn cancelled(&self) {
53        if self.is_cancelled() {
54            return;
55        }
56        // Re-check after registering for the notification to avoid a lost wakeup.
57        let notified = self.inner.notify.notified();
58        if self.is_cancelled() {
59            return;
60        }
61        notified.await;
62    }
63}
64
65/// Per-run shared state handed to tasks.
66pub struct Context {
67    /// Immutable run configuration.
68    pub config: Arc<Config>,
69    /// Unique id for this run.
70    pub run_id: String,
71    blackboard: RwLock<HashMap<String, serde_json::Value>>,
72    events: RwLock<HashMap<String, Arc<Notify>>>,
73    cancel: CancelToken,
74}
75
76impl Context {
77    /// Build a context for a new run.
78    pub fn new(config: Arc<Config>) -> Self {
79        Context {
80            config,
81            run_id: uuid::Uuid::new_v4().to_string(),
82            blackboard: RwLock::new(HashMap::new()),
83            events: RwLock::new(HashMap::new()),
84            cancel: CancelToken::new(),
85        }
86    }
87
88    /// A context with default configuration — handy in tests.
89    pub fn for_tests() -> Arc<Self> {
90        Arc::new(Context::new(Arc::new(Config::default())))
91    }
92
93    /// The cancellation token for this run.
94    pub fn cancel_token(&self) -> &CancelToken {
95        &self.cancel
96    }
97
98    /// Request cancellation of the whole run.
99    pub fn cancel(&self) {
100        self.cancel.cancel();
101    }
102
103    /// Whether the run has been cancelled.
104    pub fn is_cancelled(&self) -> bool {
105        self.cancel.is_cancelled()
106    }
107
108    // ----- blackboard -----
109
110    /// Store a value on the blackboard, overwriting any existing entry.
111    pub fn set(&self, key: impl Into<String>, value: serde_json::Value) {
112        self.blackboard.write().insert(key.into(), value);
113    }
114
115    /// Fetch (and clone) a value from the blackboard.
116    pub fn get(&self, key: &str) -> Option<serde_json::Value> {
117        self.blackboard.read().get(key).cloned()
118    }
119
120    /// Deserialize a blackboard value into a concrete type.
121    pub fn get_as<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
122        self.get(key).and_then(|v| serde_json::from_value(v).ok())
123    }
124
125    // ----- event bus -----
126
127    fn event_handle(&self, name: &str) -> Arc<Notify> {
128        if let Some(n) = self.events.read().get(name) {
129            return n.clone();
130        }
131        let mut w = self.events.write();
132        w.entry(name.to_string())
133            .or_insert_with(|| Arc::new(Notify::new()))
134            .clone()
135    }
136
137    /// Emit a named event, waking any tasks waiting on it.
138    pub fn emit(&self, name: &str) {
139        self.event_handle(name).notify_waiters();
140    }
141
142    /// Wait for a named event to be emitted.
143    ///
144    /// Resolves early (returning `false`) if the run is cancelled first;
145    /// returns `true` when the event actually fires.
146    pub async fn wait_for(&self, name: &str) -> bool {
147        let notify = self.event_handle(name);
148        let notified = notify.notified();
149        tokio::select! {
150            _ = notified => true,
151            _ = self.cancel.cancelled() => false,
152        }
153    }
154}