Skip to main content

aion_worker/
context.rs

1//! `ActivityContext` heartbeat, cancellation, attempt, and identifier support.
2
3use std::sync::Arc;
4use std::sync::atomic::{AtomicBool, Ordering};
5
6use aion_core::{ActivityId, Payload, WorkflowId};
7use tokio::sync::{Notify, mpsc};
8
9use crate::error::WorkerError;
10
11/// Handler-facing context for one activity execution.
12#[derive(Clone, Debug)]
13pub struct ActivityContext {
14    workflow_id: Option<WorkflowId>,
15    activity_id: ActivityId,
16    attempt: u32,
17    cancellation: Arc<CancellationState>,
18    heartbeat_sender: Option<mpsc::UnboundedSender<HeartbeatRequest>>,
19}
20
21/// Internal handle used by the worker runtime to signal cooperative cancellation.
22#[derive(Clone, Debug)]
23pub struct ActivityCancellationHandle {
24    cancellation: Arc<CancellationState>,
25}
26
27/// Heartbeat request emitted by [`ActivityContext::heartbeat`].
28#[derive(Clone, Debug, PartialEq, Eq)]
29pub struct HeartbeatRequest {
30    /// Workflow owning the activity whose progress is being reported.
31    pub workflow_id: WorkflowId,
32    /// Activity whose progress is being reported.
33    pub activity_id: ActivityId,
34    /// Opaque progress detail supplied by the handler.
35    pub detail: Option<Payload>,
36}
37
38#[derive(Debug)]
39struct CancellationState {
40    cancelled: AtomicBool,
41    notify: Notify,
42}
43
44impl ActivityContext {
45    /// Creates a context and the internal handle that can signal cancellation.
46    #[must_use]
47    pub fn new(activity_id: ActivityId, attempt: u32) -> (Self, ActivityCancellationHandle) {
48        Self::with_heartbeat_sender(activity_id, attempt, None)
49    }
50
51    /// Returns this activity's identifier.
52    #[must_use]
53    pub const fn activity_id(&self) -> &ActivityId {
54        &self.activity_id
55    }
56
57    /// Returns this activity's attempt number.
58    #[must_use]
59    pub const fn attempt(&self) -> u32 {
60        self.attempt
61    }
62
63    /// Emits a cooperative heartbeat request for this activity.
64    ///
65    /// Only explicit handler calls enqueue heartbeats. Contexts created without a
66    /// live heartbeat sender remain no-op contexts for isolated unit tests.
67    ///
68    /// # Errors
69    ///
70    /// Returns [`WorkerError`] when an installed heartbeat seam has been closed
71    /// or when the context lacks the workflow id required by the live session.
72    pub fn heartbeat(&self, detail: Option<Payload>) -> Result<(), WorkerError> {
73        if let Some(sender) = &self.heartbeat_sender {
74            let workflow_id = self.workflow_id.clone().ok_or_else(|| {
75                WorkerError::registration(HeartbeatMissingWorkflow {
76                    activity_id: self.activity_id.clone(),
77                })
78            })?;
79            sender
80                .send(HeartbeatRequest {
81                    workflow_id,
82                    activity_id: self.activity_id.clone(),
83                    detail,
84                })
85                .map_err(|source| WorkerError::registration(HeartbeatSeamClosed { source }))?;
86        }
87        Ok(())
88    }
89
90    /// Returns true once cooperative cancellation has been signalled.
91    #[must_use]
92    pub fn is_cancelled(&self) -> bool {
93        self.cancellation.cancelled.load(Ordering::Acquire)
94    }
95
96    /// Resolves when cooperative cancellation is signalled.
97    pub async fn cancelled(&self) {
98        while !self.is_cancelled() {
99            self.cancellation.notify.notified().await;
100        }
101    }
102
103    pub(crate) fn with_heartbeat_sender(
104        activity_id: ActivityId,
105        attempt: u32,
106        heartbeat_sender: Option<mpsc::UnboundedSender<HeartbeatRequest>>,
107    ) -> (Self, ActivityCancellationHandle) {
108        Self::for_workflow(None, activity_id, attempt, heartbeat_sender)
109    }
110
111    pub(crate) fn for_workflow(
112        workflow_id: Option<WorkflowId>,
113        activity_id: ActivityId,
114        attempt: u32,
115        heartbeat_sender: Option<mpsc::UnboundedSender<HeartbeatRequest>>,
116    ) -> (Self, ActivityCancellationHandle) {
117        let cancellation = Arc::new(CancellationState {
118            cancelled: AtomicBool::new(false),
119            notify: Notify::new(),
120        });
121        let context = Self {
122            workflow_id,
123            activity_id,
124            attempt,
125            cancellation: Arc::clone(&cancellation),
126            heartbeat_sender,
127        };
128        let handle = ActivityCancellationHandle { cancellation };
129        (context, handle)
130    }
131}
132
133impl ActivityCancellationHandle {
134    /// Signals cooperative cancellation to the handler-facing context.
135    pub fn cancel(&self) {
136        let was_cancelled = self.cancellation.cancelled.swap(true, Ordering::AcqRel);
137        if !was_cancelled {
138            self.cancellation.notify.notify_waiters();
139        }
140    }
141}
142
143#[derive(Debug, thiserror::Error)]
144#[error("activity heartbeat seam is closed: {source}")]
145struct HeartbeatSeamClosed {
146    source: mpsc::error::SendError<HeartbeatRequest>,
147}
148
149#[derive(Debug, thiserror::Error)]
150#[error("activity {activity_id} heartbeat is missing workflow id")]
151struct HeartbeatMissingWorkflow {
152    activity_id: ActivityId,
153}
154
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use aion_core::ActivityId;

    use super::ActivityContext;

    /// A fresh context reports the identifier and attempt it was built with,
    /// starts out not cancelled, and observes cancellation once the paired
    /// handle signals it.
    #[tokio::test]
    async fn context_exposes_identity_attempt_and_cancellation_signal() {
        let expected_id = ActivityId::from_sequence_position(42);
        let (ctx, handle) = ActivityContext::new(expected_id.clone(), 3);

        // Identity and attempt are reported exactly as supplied.
        assert_eq!(ctx.activity_id(), &expected_id);
        assert_eq!(ctx.attempt(), 3);
        assert!(!ctx.is_cancelled());

        handle.cancel();

        // The flag flips synchronously, and the async wait resolves promptly.
        assert!(ctx.is_cancelled());
        let wait = tokio::time::timeout(Duration::from_millis(50), ctx.cancelled());
        assert!(wait.await.is_ok());
    }
}