1use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use std::time::Instant;
6
7use aion_core::{ActivityId, WorkflowId};
8
9use crate::context::HeartbeatRequest;
10use crate::error::WorkerError;
11use crate::protocol::WorkerSession;
12
13#[derive(Clone, Debug, Default)]
20pub struct HeartbeatBookkeeper {
21 inner: Arc<Mutex<HashMap<ActivityExecutionKey, Option<Instant>>>>,
22}
23
24impl HeartbeatBookkeeper {
25 pub fn register(&self, key: ActivityExecutionKey) -> Result<(), WorkerError> {
32 let mut last_heartbeats = self.lock_last_heartbeats()?;
33 last_heartbeats.entry(key).or_insert(None);
34 Ok(())
35 }
36
37 pub fn remove(&self, key: &ActivityExecutionKey) -> Result<(), WorkerError> {
43 let mut last_heartbeats = self.lock_last_heartbeats()?;
44 last_heartbeats.remove(key);
45 Ok(())
46 }
47
48 #[must_use]
51 pub fn last_heartbeat(&self, key: &ActivityExecutionKey) -> Option<Instant> {
52 match self.inner.lock() {
53 Ok(last_heartbeats) => last_heartbeats.get(key).copied().flatten(),
54 Err(poisoned) => poisoned.into_inner().get(key).copied().flatten(),
55 }
56 }
57
58 fn record_sent(&self, key: ActivityExecutionKey, sent_at: Instant) -> Result<(), WorkerError> {
59 let mut last_heartbeats = self.lock_last_heartbeats()?;
60 last_heartbeats.insert(key, Some(sent_at));
61 Ok(())
62 }
63
64 fn lock_last_heartbeats(
65 &self,
66 ) -> Result<
67 std::sync::MutexGuard<'_, HashMap<ActivityExecutionKey, Option<Instant>>>,
68 WorkerError,
69 > {
70 self.inner
71 .lock()
72 .map_err(|_| WorkerError::registration(HeartbeatBookkeeperPoisoned))
73 }
74}
75
76pub async fn send_heartbeat<S>(
84 session: &mut S,
85 bookkeeper: &HeartbeatBookkeeper,
86 request: HeartbeatRequest,
87) -> Result<(), WorkerError>
88where
89 S: WorkerSession,
90{
91 let key = ActivityExecutionKey::new(request.workflow_id.clone(), request.activity_id.clone());
92 session
93 .send_heartbeat(request.workflow_id, request.activity_id, request.detail)
94 .await?;
95 bookkeeper.record_sent(key, Instant::now())
96}
97
98#[derive(Debug, thiserror::Error)]
99#[error("heartbeat bookkeeper mutex was poisoned")]
100struct HeartbeatBookkeeperPoisoned;
101
102#[derive(Clone, Debug, PartialEq, Eq, Hash)]
104pub struct ActivityExecutionKey {
105 pub workflow_id: WorkflowId,
107 pub activity_id: ActivityId,
109}
110
111impl ActivityExecutionKey {
112 #[must_use]
114 pub const fn new(workflow_id: WorkflowId, activity_id: ActivityId) -> Self {
115 Self {
116 workflow_id,
117 activity_id,
118 }
119 }
120}
121
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::time::Duration;

    use aion_core::{ActivityError, ActivityId, ContentType, Payload, WorkflowId};
    use async_trait::async_trait;
    use futures::stream;

    use super::{ActivityExecutionKey, HeartbeatBookkeeper, send_heartbeat};
    use crate::WorkerConfig;
    use crate::context::HeartbeatRequest;
    use crate::error::WorkerError;
    use crate::protocol::{WorkerSession, WorkerTaskStream, validate_activity_handlers};

    /// Reported by a test when a heartbeat timestamp that should exist is missing.
    #[derive(Debug, thiserror::Error)]
    #[error("heartbeat timestamp was not recorded")]
    struct MissingHeartbeatTimestamp;

    /// In-memory session that records every heartbeat it is asked to send.
    #[derive(Default)]
    struct FakeSession {
        heartbeats: Vec<RecordedHeartbeat>,
    }

    /// One heartbeat as observed by `FakeSession`.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct RecordedHeartbeat {
        workflow_id: WorkflowId,
        activity_id: ActivityId,
        detail: Option<Payload>,
    }

    #[async_trait]
    impl WorkerSession for FakeSession {
        async fn handshake(&mut self, _config: &WorkerConfig) -> Result<(), WorkerError> {
            Ok(())
        }

        async fn register(
            &mut self,
            activity_types: Vec<String>,
            available_handlers: &BTreeSet<String>,
        ) -> Result<(), WorkerError> {
            validate_activity_handlers(&activity_types, available_handlers)
        }

        fn receive_tasks(&mut self) -> WorkerTaskStream {
            Box::pin(stream::empty())
        }

        async fn report_result(
            &mut self,
            _workflow_id: WorkflowId,
            _activity_id: ActivityId,
            _result: Payload,
        ) -> Result<(), WorkerError> {
            Ok(())
        }

        async fn report_failure(
            &mut self,
            _workflow_id: WorkflowId,
            _activity_id: ActivityId,
            _failure: ActivityError,
        ) -> Result<(), WorkerError> {
            Ok(())
        }

        async fn send_heartbeat(
            &mut self,
            workflow_id: WorkflowId,
            activity_id: ActivityId,
            progress: Option<Payload>,
        ) -> Result<(), WorkerError> {
            self.heartbeats.push(RecordedHeartbeat {
                workflow_id,
                activity_id,
                detail: progress,
            });
            Ok(())
        }
    }

    /// Builds a heartbeat request for the given execution, cloning the ids.
    fn heartbeat_request(
        workflow_id: &WorkflowId,
        activity_id: &ActivityId,
        detail: Option<Payload>,
    ) -> HeartbeatRequest {
        HeartbeatRequest {
            workflow_id: workflow_id.clone(),
            activity_id: activity_id.clone(),
            detail,
        }
    }

    #[tokio::test]
    async fn sends_explicit_heartbeats_and_preserves_detail() -> Result<(), WorkerError> {
        let workflow_id = WorkflowId::new_v4();
        let activity_id = ActivityId::from_sequence_position(7);
        let key = ActivityExecutionKey::new(workflow_id.clone(), activity_id.clone());
        let detail = Payload::new(ContentType::Json, br#"{"progress":1}"#.to_vec());
        let bookkeeper = HeartbeatBookkeeper::default();
        let mut session = FakeSession::default();

        for _ in 0..2 {
            send_heartbeat(
                &mut session,
                &bookkeeper,
                heartbeat_request(&workflow_id, &activity_id, Some(detail.clone())),
            )
            .await?;
        }

        // Every send must reach the session unchanged, detail included.
        let expected = RecordedHeartbeat {
            workflow_id,
            activity_id,
            detail: Some(detail),
        };
        assert_eq!(session.heartbeats, vec![expected.clone(), expected]);
        // A successful send is also reflected in the bookkeeper.
        assert!(bookkeeper.last_heartbeat(&key).is_some());
        Ok(())
    }

    #[tokio::test]
    async fn last_heartbeat_timestamp_advances_on_each_send() -> Result<(), WorkerError> {
        let workflow_id = WorkflowId::new_v4();
        let activity_id = ActivityId::from_sequence_position(8);
        let key = ActivityExecutionKey::new(workflow_id.clone(), activity_id.clone());
        let bookkeeper = HeartbeatBookkeeper::default();
        let mut session = FakeSession::default();

        send_heartbeat(
            &mut session,
            &bookkeeper,
            heartbeat_request(&workflow_id, &activity_id, None),
        )
        .await?;
        let first = bookkeeper.last_heartbeat(&key);
        // Ensure the monotonic clock observably advances between the two sends.
        tokio::time::sleep(Duration::from_millis(1)).await;
        send_heartbeat(
            &mut session,
            &bookkeeper,
            heartbeat_request(&workflow_id, &activity_id, None),
        )
        .await?;
        let second = bookkeeper.last_heartbeat(&key);

        let (Some(first), Some(second)) = (first, second) else {
            return Err(WorkerError::decode(MissingHeartbeatTimestamp));
        };
        assert!(second > first);
        Ok(())
    }

    #[tokio::test]
    async fn colliding_sequence_positions_track_per_workflow() -> Result<(), WorkerError> {
        let activity_id = ActivityId::from_sequence_position(3);
        let workflow_a = WorkflowId::new_v4();
        let workflow_b = WorkflowId::new_v4();
        let key_a = ActivityExecutionKey::new(workflow_a.clone(), activity_id.clone());
        let key_b = ActivityExecutionKey::new(workflow_b.clone(), activity_id.clone());
        let bookkeeper = HeartbeatBookkeeper::default();
        let mut session = FakeSession::default();

        bookkeeper.register(key_a.clone())?;
        bookkeeper.register(key_b.clone())?;

        send_heartbeat(
            &mut session,
            &bookkeeper,
            heartbeat_request(&workflow_a, &activity_id, None),
        )
        .await?;
        assert!(bookkeeper.last_heartbeat(&key_a).is_some());
        assert!(bookkeeper.last_heartbeat(&key_b).is_none());

        send_heartbeat(
            &mut session,
            &bookkeeper,
            heartbeat_request(&workflow_b, &activity_id, None),
        )
        .await?;
        let Some(b_before_a_completes) = bookkeeper.last_heartbeat(&key_b) else {
            return Err(WorkerError::decode(MissingHeartbeatTimestamp));
        };

        // Completing A must not disturb B, even though they share an activity id.
        bookkeeper.remove(&key_a)?;
        assert!(bookkeeper.last_heartbeat(&key_a).is_none());
        assert_eq!(
            bookkeeper.last_heartbeat(&key_b),
            Some(b_before_a_completes)
        );
        Ok(())
    }

    #[tokio::test]
    async fn register_does_not_reset_recorded_timestamp() -> Result<(), WorkerError> {
        let workflow_id = WorkflowId::new_v4();
        let activity_id = ActivityId::from_sequence_position(9);
        let key = ActivityExecutionKey::new(workflow_id.clone(), activity_id.clone());
        let bookkeeper = HeartbeatBookkeeper::default();
        let mut session = FakeSession::default();

        send_heartbeat(
            &mut session,
            &bookkeeper,
            heartbeat_request(&workflow_id, &activity_id, None),
        )
        .await?;
        let Some(recorded) = bookkeeper.last_heartbeat(&key) else {
            return Err(WorkerError::decode(MissingHeartbeatTimestamp));
        };

        // A late (re-)registration must keep the timestamp already on file.
        bookkeeper.register(key.clone())?;
        assert_eq!(bookkeeper.last_heartbeat(&key), Some(recorded));
        Ok(())
    }

    #[test]
    fn removing_unknown_key_is_a_no_op() -> Result<(), WorkerError> {
        let key = ActivityExecutionKey::new(
            WorkflowId::new_v4(),
            ActivityId::from_sequence_position(1),
        );
        let bookkeeper = HeartbeatBookkeeper::default();

        bookkeeper.remove(&key)?;
        assert!(bookkeeper.last_heartbeat(&key).is_none());
        Ok(())
    }
}