Skip to main content

aion_worker/protocol/
heartbeat.rs

//! Heartbeat frame sending plus local, observability-only heartbeat liveness bookkeeping.
2
3use std::collections::HashMap;
4use std::sync::{Arc, Mutex};
5use std::time::Instant;
6
7use aion_core::{ActivityId, WorkflowId};
8
9use crate::context::HeartbeatRequest;
10use crate::error::WorkerError;
11use crate::protocol::WorkerSession;
12
13/// In-memory liveness view for explicitly emitted activity heartbeats.
14///
15/// This bookkeeper is observability-only. It records the last successful local
16/// send time for in-flight activities, but the SDK never enforces heartbeat
17/// timeouts or fails activities for missing heartbeats; timeout ownership stays
18/// with the engine.
19#[derive(Clone, Debug, Default)]
20pub struct HeartbeatBookkeeper {
21    inner: Arc<Mutex<HashMap<ActivityExecutionKey, Option<Instant>>>>,
22}
23
24impl HeartbeatBookkeeper {
25    /// Marks an activity execution as in flight without recording a heartbeat
26    /// yet.
27    ///
28    /// # Errors
29    ///
30    /// Returns [`WorkerError`] if the in-memory bookkeeping mutex is poisoned.
31    pub fn register(&self, key: ActivityExecutionKey) -> Result<(), WorkerError> {
32        let mut last_heartbeats = self.lock_last_heartbeats()?;
33        last_heartbeats.entry(key).or_insert(None);
34        Ok(())
35    }
36
37    /// Removes bookkeeping for a completed activity execution.
38    ///
39    /// # Errors
40    ///
41    /// Returns [`WorkerError`] if the in-memory bookkeeping mutex is poisoned.
42    pub fn remove(&self, key: &ActivityExecutionKey) -> Result<(), WorkerError> {
43        let mut last_heartbeats = self.lock_last_heartbeats()?;
44        last_heartbeats.remove(key);
45        Ok(())
46    }
47
48    /// Returns the last successful local heartbeat send instant for an
49    /// activity execution.
50    #[must_use]
51    pub fn last_heartbeat(&self, key: &ActivityExecutionKey) -> Option<Instant> {
52        match self.inner.lock() {
53            Ok(last_heartbeats) => last_heartbeats.get(key).copied().flatten(),
54            Err(poisoned) => poisoned.into_inner().get(key).copied().flatten(),
55        }
56    }
57
58    fn record_sent(&self, key: ActivityExecutionKey, sent_at: Instant) -> Result<(), WorkerError> {
59        let mut last_heartbeats = self.lock_last_heartbeats()?;
60        last_heartbeats.insert(key, Some(sent_at));
61        Ok(())
62    }
63
64    fn lock_last_heartbeats(
65        &self,
66    ) -> Result<
67        std::sync::MutexGuard<'_, HashMap<ActivityExecutionKey, Option<Instant>>>,
68        WorkerError,
69    > {
70        self.inner
71            .lock()
72            .map_err(|_| WorkerError::registration(HeartbeatBookkeeperPoisoned))
73    }
74}
75
76/// Sends one explicit heartbeat request and updates local liveness bookkeeping
77/// after the transport accepts the frame.
78///
79/// # Errors
80///
81/// Returns [`WorkerError`] when the session send fails or bookkeeping cannot be
82/// updated.
83pub async fn send_heartbeat<S>(
84    session: &mut S,
85    bookkeeper: &HeartbeatBookkeeper,
86    request: HeartbeatRequest,
87) -> Result<(), WorkerError>
88where
89    S: WorkerSession,
90{
91    let key = ActivityExecutionKey::new(request.workflow_id.clone(), request.activity_id.clone());
92    session
93        .send_heartbeat(request.workflow_id, request.activity_id, request.detail)
94        .await?;
95    bookkeeper.record_sent(key, Instant::now())
96}
97
98#[derive(Debug, thiserror::Error)]
99#[error("heartbeat bookkeeper mutex was poisoned")]
100struct HeartbeatBookkeeperPoisoned;
101
102/// Key identifying one in-flight activity execution.
103#[derive(Clone, Debug, PartialEq, Eq, Hash)]
104pub struct ActivityExecutionKey {
105    /// Owning workflow id.
106    pub workflow_id: WorkflowId,
107    /// Activity id within the workflow.
108    pub activity_id: ActivityId,
109}
110
111impl ActivityExecutionKey {
112    /// Creates a key for an in-flight activity execution.
113    #[must_use]
114    pub const fn new(workflow_id: WorkflowId, activity_id: ActivityId) -> Self {
115        Self {
116            workflow_id,
117            activity_id,
118        }
119    }
120}
121
#[cfg(test)]
mod tests {
    use std::collections::BTreeSet;
    use std::time::Duration;

    use aion_core::{ActivityError, ActivityId, ContentType, Payload, WorkflowId};
    use async_trait::async_trait;
    use futures::stream;

    use super::{ActivityExecutionKey, HeartbeatBookkeeper, send_heartbeat};
    use crate::WorkerConfig;
    use crate::context::HeartbeatRequest;
    use crate::error::WorkerError;
    use crate::protocol::{WorkerSession, WorkerTaskStream, validate_activity_handlers};

    /// Returned (wrapped in a [`WorkerError`]) to fail a test when the
    /// bookkeeper has no timestamp where one was expected.
    #[derive(Debug, thiserror::Error)]
    #[error("heartbeat timestamp was not recorded")]
    struct MissingHeartbeatTimestamp;

    /// Test double for [`WorkerSession`] that records heartbeat frames in send
    /// order and accepts every other call without doing anything.
    #[derive(Default)]
    struct FakeSession {
        heartbeats: Vec<RecordedHeartbeat>,
    }

    /// A heartbeat frame as captured by `FakeSession`.
    #[derive(Clone, Debug, PartialEq, Eq)]
    struct RecordedHeartbeat {
        workflow_id: WorkflowId,
        activity_id: ActivityId,
        detail: Option<Payload>,
    }

    #[async_trait]
    impl WorkerSession for FakeSession {
        async fn handshake(&mut self, config: &WorkerConfig) -> Result<(), WorkerError> {
            // The fake needs nothing from the config; consuming an owned copy
            // marks the parameter as used.
            drop(config.clone());
            Ok(())
        }

        async fn register(
            &mut self,
            activity_types: Vec<String>,
            available_handlers: &BTreeSet<String>,
        ) -> Result<(), WorkerError> {
            validate_activity_handlers(&activity_types, available_handlers)
        }

        fn receive_tasks(&mut self) -> WorkerTaskStream {
            // No tasks are ever delivered to the worker under test.
            Box::pin(stream::empty())
        }

        async fn report_result(
            &mut self,
            workflow_id: WorkflowId,
            activity_id: ActivityId,
            result: Payload,
        ) -> Result<(), WorkerError> {
            drop((workflow_id, activity_id, result));
            Ok(())
        }

        async fn report_failure(
            &mut self,
            workflow_id: WorkflowId,
            activity_id: ActivityId,
            failure: ActivityError,
        ) -> Result<(), WorkerError> {
            drop((workflow_id, activity_id, failure));
            Ok(())
        }

        async fn send_heartbeat(
            &mut self,
            workflow_id: WorkflowId,
            activity_id: ActivityId,
            progress: Option<Payload>,
        ) -> Result<(), WorkerError> {
            self.heartbeats.push(RecordedHeartbeat {
                workflow_id,
                activity_id,
                detail: progress,
            });
            Ok(())
        }
    }

    #[tokio::test]
    async fn sends_explicit_heartbeats_and_preserves_detail() -> Result<(), WorkerError> {
        let workflow_id = WorkflowId::new_v4();
        let activity_id = ActivityId::from_sequence_position(7);
        let key = ActivityExecutionKey::new(workflow_id.clone(), activity_id.clone());
        let detail = Payload::new(ContentType::Json, br#"{"progress":1}"#.to_vec());
        let bookkeeper = HeartbeatBookkeeper::default();
        let mut session = FakeSession::default();

        send_heartbeat(
            &mut session,
            &bookkeeper,
            HeartbeatRequest {
                workflow_id: workflow_id.clone(),
                activity_id: activity_id.clone(),
                detail: Some(detail.clone()),
            },
        )
        .await?;
        send_heartbeat(
            &mut session,
            &bookkeeper,
            HeartbeatRequest {
                workflow_id: workflow_id.clone(),
                activity_id: activity_id.clone(),
                detail: Some(detail.clone()),
            },
        )
        .await?;

        // Both frames reach the transport in order with their detail intact.
        assert_eq!(
            session.heartbeats,
            vec![
                RecordedHeartbeat {
                    workflow_id: workflow_id.clone(),
                    activity_id: activity_id.clone(),
                    detail: Some(detail.clone()),
                },
                RecordedHeartbeat {
                    workflow_id,
                    activity_id,
                    detail: Some(detail),
                },
            ]
        );
        // Accepted sends are also stamped in the bookkeeper for this execution.
        assert!(bookkeeper.last_heartbeat(&key).is_some());
        Ok(())
    }

    #[tokio::test]
    async fn last_heartbeat_timestamp_advances_on_each_send() -> Result<(), WorkerError> {
        let workflow_id = WorkflowId::new_v4();
        let activity_id = ActivityId::from_sequence_position(8);
        let key = ActivityExecutionKey::new(workflow_id.clone(), activity_id.clone());
        let bookkeeper = HeartbeatBookkeeper::default();
        let mut session = FakeSession::default();

        send_heartbeat(
            &mut session,
            &bookkeeper,
            HeartbeatRequest {
                workflow_id: workflow_id.clone(),
                activity_id: activity_id.clone(),
                detail: None,
            },
        )
        .await?;
        let first = bookkeeper.last_heartbeat(&key);
        // Let the monotonic clock advance so the second stamp is strictly later.
        tokio::time::sleep(Duration::from_millis(1)).await;
        send_heartbeat(
            &mut session,
            &bookkeeper,
            HeartbeatRequest {
                workflow_id,
                activity_id: activity_id.clone(),
                detail: None,
            },
        )
        .await?;
        let second = bookkeeper.last_heartbeat(&key);

        let (Some(first), Some(second)) = (first, second) else {
            return Err(WorkerError::decode(MissingHeartbeatTimestamp));
        };
        assert!(second > first);
        Ok(())
    }

    #[tokio::test]
    async fn colliding_sequence_positions_track_per_workflow() -> Result<(), WorkerError> {
        let activity_id = ActivityId::from_sequence_position(3);
        let workflow_a = WorkflowId::new_v4();
        let workflow_b = WorkflowId::new_v4();
        let key_a = ActivityExecutionKey::new(workflow_a.clone(), activity_id.clone());
        let key_b = ActivityExecutionKey::new(workflow_b.clone(), activity_id.clone());
        let bookkeeper = HeartbeatBookkeeper::default();
        let mut session = FakeSession::default();

        bookkeeper.register(key_a.clone())?;
        bookkeeper.register(key_b.clone())?;

        // record_sent for workflow A never touches workflow B's timestamp.
        send_heartbeat(
            &mut session,
            &bookkeeper,
            HeartbeatRequest {
                workflow_id: workflow_a,
                activity_id: activity_id.clone(),
                detail: None,
            },
        )
        .await?;
        assert!(bookkeeper.last_heartbeat(&key_a).is_some());
        assert!(bookkeeper.last_heartbeat(&key_b).is_none());

        send_heartbeat(
            &mut session,
            &bookkeeper,
            HeartbeatRequest {
                workflow_id: workflow_b,
                activity_id,
                detail: None,
            },
        )
        .await?;
        let b_before_a_completes = bookkeeper.last_heartbeat(&key_b);
        let Some(b_before_a_completes) = b_before_a_completes else {
            return Err(WorkerError::decode(MissingHeartbeatTimestamp));
        };

        // Completing workflow A's activity leaves workflow B's entry intact.
        bookkeeper.remove(&key_a)?;
        assert!(bookkeeper.last_heartbeat(&key_a).is_none());
        assert_eq!(
            bookkeeper.last_heartbeat(&key_b),
            Some(b_before_a_completes)
        );
        Ok(())
    }

    #[tokio::test]
    async fn register_keeps_an_already_recorded_timestamp() -> Result<(), WorkerError> {
        let workflow_id = WorkflowId::new_v4();
        let activity_id = ActivityId::from_sequence_position(9);
        let key = ActivityExecutionKey::new(workflow_id.clone(), activity_id.clone());
        let bookkeeper = HeartbeatBookkeeper::default();
        let mut session = FakeSession::default();

        send_heartbeat(
            &mut session,
            &bookkeeper,
            HeartbeatRequest {
                workflow_id,
                activity_id,
                detail: None,
            },
        )
        .await?;
        let Some(recorded) = bookkeeper.last_heartbeat(&key) else {
            return Err(WorkerError::decode(MissingHeartbeatTimestamp));
        };

        // Re-registering an execution must not reset its heartbeat to `None`.
        bookkeeper.register(key.clone())?;
        assert_eq!(bookkeeper.last_heartbeat(&key), Some(recorded));
        Ok(())
    }
}