1use std::sync::Arc;
4use std::sync::atomic::{AtomicBool, Ordering};
5
6use aion_core::{ActivityId, Payload, WorkflowId};
7use tokio::sync::{Notify, mpsc};
8
9use crate::error::WorkerError;
10
11#[derive(Clone, Debug)]
13pub struct ActivityContext {
14 workflow_id: Option<WorkflowId>,
15 activity_id: ActivityId,
16 attempt: u32,
17 cancellation: Arc<CancellationState>,
18 heartbeat_sender: Option<mpsc::UnboundedSender<HeartbeatRequest>>,
19}
20
21#[derive(Clone, Debug)]
23pub struct ActivityCancellationHandle {
24 cancellation: Arc<CancellationState>,
25}
26
27#[derive(Clone, Debug, PartialEq, Eq)]
29pub struct HeartbeatRequest {
30 pub workflow_id: WorkflowId,
32 pub activity_id: ActivityId,
34 pub detail: Option<Payload>,
36}
37
38#[derive(Debug)]
39struct CancellationState {
40 cancelled: AtomicBool,
41 notify: Notify,
42}
43
44impl ActivityContext {
45 #[must_use]
47 pub fn new(activity_id: ActivityId, attempt: u32) -> (Self, ActivityCancellationHandle) {
48 Self::with_heartbeat_sender(activity_id, attempt, None)
49 }
50
51 #[must_use]
53 pub const fn activity_id(&self) -> &ActivityId {
54 &self.activity_id
55 }
56
57 #[must_use]
59 pub const fn attempt(&self) -> u32 {
60 self.attempt
61 }
62
63 pub fn heartbeat(&self, detail: Option<Payload>) -> Result<(), WorkerError> {
73 if let Some(sender) = &self.heartbeat_sender {
74 let workflow_id = self.workflow_id.clone().ok_or_else(|| {
75 WorkerError::registration(HeartbeatMissingWorkflow {
76 activity_id: self.activity_id.clone(),
77 })
78 })?;
79 sender
80 .send(HeartbeatRequest {
81 workflow_id,
82 activity_id: self.activity_id.clone(),
83 detail,
84 })
85 .map_err(|source| WorkerError::registration(HeartbeatSeamClosed { source }))?;
86 }
87 Ok(())
88 }
89
90 #[must_use]
92 pub fn is_cancelled(&self) -> bool {
93 self.cancellation.cancelled.load(Ordering::Acquire)
94 }
95
96 pub async fn cancelled(&self) {
98 while !self.is_cancelled() {
99 self.cancellation.notify.notified().await;
100 }
101 }
102
103 pub(crate) fn with_heartbeat_sender(
104 activity_id: ActivityId,
105 attempt: u32,
106 heartbeat_sender: Option<mpsc::UnboundedSender<HeartbeatRequest>>,
107 ) -> (Self, ActivityCancellationHandle) {
108 Self::for_workflow(None, activity_id, attempt, heartbeat_sender)
109 }
110
111 pub(crate) fn for_workflow(
112 workflow_id: Option<WorkflowId>,
113 activity_id: ActivityId,
114 attempt: u32,
115 heartbeat_sender: Option<mpsc::UnboundedSender<HeartbeatRequest>>,
116 ) -> (Self, ActivityCancellationHandle) {
117 let cancellation = Arc::new(CancellationState {
118 cancelled: AtomicBool::new(false),
119 notify: Notify::new(),
120 });
121 let context = Self {
122 workflow_id,
123 activity_id,
124 attempt,
125 cancellation: Arc::clone(&cancellation),
126 heartbeat_sender,
127 };
128 let handle = ActivityCancellationHandle { cancellation };
129 (context, handle)
130 }
131}
132
133impl ActivityCancellationHandle {
134 pub fn cancel(&self) {
136 let was_cancelled = self.cancellation.cancelled.swap(true, Ordering::AcqRel);
137 if !was_cancelled {
138 self.cancellation.notify.notify_waiters();
139 }
140 }
141}
142
143#[derive(Debug, thiserror::Error)]
144#[error("activity heartbeat seam is closed: {source}")]
145struct HeartbeatSeamClosed {
146 source: mpsc::error::SendError<HeartbeatRequest>,
147}
148
149#[derive(Debug, thiserror::Error)]
150#[error("activity {activity_id} heartbeat is missing workflow id")]
151struct HeartbeatMissingWorkflow {
152 activity_id: ActivityId,
153}
154
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use aion_core::ActivityId;
    use tokio::sync::mpsc;

    use super::ActivityContext;

    #[tokio::test]
    async fn context_exposes_identity_attempt_and_cancellation_signal() {
        let activity_id = ActivityId::from_sequence_position(42);
        let (context, cancellation) = ActivityContext::new(activity_id.clone(), 3);

        assert_eq!(context.activity_id(), &activity_id);
        assert_eq!(context.attempt(), 3);
        assert!(!context.is_cancelled());

        cancellation.cancel();

        assert!(context.is_cancelled());
        let cancelled = tokio::time::timeout(Duration::from_millis(50), context.cancelled()).await;
        assert!(cancelled.is_ok());
    }

    /// A task already parked in `cancelled()` must be woken by a later `cancel()`.
    #[tokio::test]
    async fn cancel_wakes_a_task_already_waiting() {
        let (context, cancellation) =
            ActivityContext::new(ActivityId::from_sequence_position(7), 1);
        let waiter = tokio::spawn(async move { context.cancelled().await });

        // Give the spawned task a chance to reach its await point before cancelling.
        tokio::task::yield_now().await;
        cancellation.cancel();

        let joined = tokio::time::timeout(Duration::from_secs(1), waiter).await;
        assert!(matches!(joined, Ok(Ok(()))));
    }

    #[test]
    fn cancel_is_idempotent() {
        let (context, cancellation) =
            ActivityContext::new(ActivityId::from_sequence_position(1), 1);
        cancellation.cancel();
        cancellation.cancel();
        assert!(context.is_cancelled());
    }

    #[test]
    fn heartbeat_without_sender_is_a_no_op() {
        let (context, _cancellation) =
            ActivityContext::new(ActivityId::from_sequence_position(1), 1);
        assert!(context.heartbeat(None).is_ok());
    }

    #[test]
    fn heartbeat_with_sender_but_no_workflow_fails_without_sending() {
        let (sender, mut receiver) = mpsc::unbounded_channel();
        let (context, _cancellation) = ActivityContext::with_heartbeat_sender(
            ActivityId::from_sequence_position(1),
            1,
            Some(sender),
        );

        assert!(context.heartbeat(None).is_err());
        assert!(receiver.try_recv().is_err());
    }
}
178}