Skip to main content

durable_execution_sdk_testing/checkpoint_server/
callback_manager.rs

//! Callback manager for tracking the lifecycle of callbacks.
//!
//! This module implements the `CallbackManager`, which tracks each callback from
//! registration through heartbeats and timeouts to completion, matching the Node.js
//! SDK's callback manager.
5
6use std::collections::HashMap;
7use std::time::{Duration, Instant};
8
9use durable_execution_sdk::ErrorObject;
10
11use crate::error::TestError;
12
/// Final outcome recorded for a callback once it stops being pending.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
pub enum CompleteCallbackStatus {
    /// The callback was completed with a success result.
    Success,
    /// The callback was completed with an error.
    Failure,
    /// The callback's timeout elapsed before it was completed.
    TimedOut,
}
23
24/// Internal state of a callback.
25#[derive(Debug, Clone)]
26pub struct CallbackState {
27    /// The callback ID
28    pub callback_id: String,
29    /// Optional timeout duration
30    pub timeout: Option<Duration>,
31    /// When the callback was registered
32    pub registered_at: Instant,
33    /// Last heartbeat time
34    pub last_heartbeat: Instant,
35    /// Completion status (None if still pending)
36    pub completion_status: Option<CompleteCallbackStatus>,
37    /// Result if completed successfully
38    pub result: Option<String>,
39    /// Error if completed with failure
40    pub error: Option<ErrorObject>,
41}
42
43impl CallbackState {
44    /// Create a new callback state.
45    pub fn new(callback_id: String, timeout: Option<Duration>) -> Self {
46        let now = Instant::now();
47        Self {
48            callback_id,
49            timeout,
50            registered_at: now,
51            last_heartbeat: now,
52            completion_status: None,
53            result: None,
54            error: None,
55        }
56    }
57
58    /// Check if this callback is completed.
59    pub fn is_completed(&self) -> bool {
60        self.completion_status.is_some()
61    }
62
63    /// Check if this callback has timed out.
64    pub fn is_timed_out(&self) -> bool {
65        if let Some(timeout) = self.timeout {
66            self.last_heartbeat.elapsed() > timeout
67        } else {
68            false
69        }
70    }
71}
72
73/// Manages callback lifecycle including timeouts and heartbeats.
74#[derive(Debug, Default)]
75pub struct CallbackManager {
76    /// The execution ID this manager belongs to
77    execution_id: String,
78    /// Map of callback ID to callback state
79    callbacks: HashMap<String, CallbackState>,
80}
81
82impl CallbackManager {
83    /// Create a new callback manager.
84    pub fn new(execution_id: &str) -> Self {
85        Self {
86            execution_id: execution_id.to_string(),
87            callbacks: HashMap::new(),
88        }
89    }
90
91    /// Register a new callback.
92    pub fn register_callback(
93        &mut self,
94        callback_id: &str,
95        timeout: Option<Duration>,
96    ) -> Result<(), TestError> {
97        if self.callbacks.contains_key(callback_id) {
98            return Err(TestError::CallbackAlreadyCompleted(format!(
99                "Callback {} already registered",
100                callback_id
101            )));
102        }
103
104        let state = CallbackState::new(callback_id.to_string(), timeout);
105        self.callbacks.insert(callback_id.to_string(), state);
106        Ok(())
107    }
108
109    /// Send callback success.
110    pub fn send_success(&mut self, callback_id: &str, result: &str) -> Result<(), TestError> {
111        let state = self
112            .callbacks
113            .get_mut(callback_id)
114            .ok_or_else(|| TestError::CallbackNotFound(callback_id.to_string()))?;
115
116        if state.is_completed() {
117            return Err(TestError::CallbackAlreadyCompleted(callback_id.to_string()));
118        }
119
120        state.completion_status = Some(CompleteCallbackStatus::Success);
121        state.result = Some(result.to_string());
122        Ok(())
123    }
124
125    /// Send callback failure.
126    pub fn send_failure(
127        &mut self,
128        callback_id: &str,
129        error: &ErrorObject,
130    ) -> Result<(), TestError> {
131        let state = self
132            .callbacks
133            .get_mut(callback_id)
134            .ok_or_else(|| TestError::CallbackNotFound(callback_id.to_string()))?;
135
136        if state.is_completed() {
137            return Err(TestError::CallbackAlreadyCompleted(callback_id.to_string()));
138        }
139
140        state.completion_status = Some(CompleteCallbackStatus::Failure);
141        state.error = Some(error.clone());
142        Ok(())
143    }
144
145    /// Send callback heartbeat.
146    pub fn send_heartbeat(&mut self, callback_id: &str) -> Result<(), TestError> {
147        let state = self
148            .callbacks
149            .get_mut(callback_id)
150            .ok_or_else(|| TestError::CallbackNotFound(callback_id.to_string()))?;
151
152        if state.is_completed() {
153            return Err(TestError::CallbackAlreadyCompleted(callback_id.to_string()));
154        }
155
156        state.last_heartbeat = Instant::now();
157        Ok(())
158    }
159
160    /// Check for timed out callbacks and mark them as timed out.
161    /// Returns the IDs of callbacks that timed out.
162    pub fn check_timeouts(&mut self) -> Vec<String> {
163        let mut timed_out = Vec::new();
164
165        for (id, state) in self.callbacks.iter_mut() {
166            if !state.is_completed() && state.is_timed_out() {
167                state.completion_status = Some(CompleteCallbackStatus::TimedOut);
168                timed_out.push(id.clone());
169            }
170        }
171
172        timed_out
173    }
174
175    /// Get callback status.
176    pub fn get_callback_status(&self, callback_id: &str) -> Option<CompleteCallbackStatus> {
177        self.callbacks
178            .get(callback_id)
179            .and_then(|s| s.completion_status)
180    }
181
182    /// Get callback state.
183    pub fn get_callback_state(&self, callback_id: &str) -> Option<&CallbackState> {
184        self.callbacks.get(callback_id)
185    }
186
187    /// Get the execution ID.
188    pub fn execution_id(&self) -> &str {
189        &self.execution_id
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_register_callback() {
        let mut manager = CallbackManager::new("exec-1");
        manager.register_callback("cb-1", None).unwrap();

        let state = manager.get_callback_state("cb-1").unwrap();
        assert_eq!(state.callback_id, "cb-1");
        assert!(!state.is_completed());
        assert_eq!(manager.get_callback_status("cb-1"), None);
        assert_eq!(manager.execution_id(), "exec-1");
    }

    #[test]
    fn test_register_duplicate_callback_fails() {
        let mut manager = CallbackManager::new("exec-1");
        manager.register_callback("cb-1", None).unwrap();

        let result = manager.register_callback("cb-1", None);
        assert!(matches!(result, Err(TestError::CallbackAlreadyCompleted(_))));
    }

    #[test]
    fn test_send_success() {
        let mut manager = CallbackManager::new("exec-1");
        manager.register_callback("cb-1", None).unwrap();
        manager.send_success("cb-1", r#"{"result": "ok"}"#).unwrap();

        let state = manager.get_callback_state("cb-1").unwrap();
        assert!(state.is_completed());
        assert_eq!(
            state.completion_status,
            Some(CompleteCallbackStatus::Success)
        );
        assert_eq!(state.result, Some(r#"{"result": "ok"}"#.to_string()));
        assert_eq!(
            manager.get_callback_status("cb-1"),
            Some(CompleteCallbackStatus::Success)
        );
    }

    #[test]
    fn test_send_failure() {
        let mut manager = CallbackManager::new("exec-1");
        manager.register_callback("cb-1", None).unwrap();

        let error = ErrorObject::new("TestError", "Something went wrong");
        manager.send_failure("cb-1", &error).unwrap();

        let state = manager.get_callback_state("cb-1").unwrap();
        assert!(state.is_completed());
        assert_eq!(
            state.completion_status,
            Some(CompleteCallbackStatus::Failure)
        );
        assert!(state.error.is_some());
        assert!(state.result.is_none());
    }

    #[test]
    fn test_send_heartbeat() {
        let mut manager = CallbackManager::new("exec-1");
        manager
            .register_callback("cb-1", Some(Duration::from_secs(60)))
            .unwrap();

        // Let the clock advance so the new heartbeat is strictly later than registration.
        std::thread::sleep(Duration::from_millis(10));

        let before = manager.get_callback_state("cb-1").unwrap().last_heartbeat;
        manager.send_heartbeat("cb-1").unwrap();
        let after = manager.get_callback_state("cb-1").unwrap().last_heartbeat;

        assert!(after > before);
    }

    #[test]
    fn test_double_complete_fails() {
        let mut manager = CallbackManager::new("exec-1");
        manager.register_callback("cb-1", None).unwrap();
        manager.send_success("cb-1", "result").unwrap();

        let result = manager.send_success("cb-1", "another result");
        assert!(matches!(result, Err(TestError::CallbackAlreadyCompleted(_))));
    }

    #[test]
    fn test_failure_after_success_fails() {
        let mut manager = CallbackManager::new("exec-1");
        manager.register_callback("cb-1", None).unwrap();
        manager.send_success("cb-1", "result").unwrap();

        let error = ErrorObject::new("TestError", "too late");
        let result = manager.send_failure("cb-1", &error);
        assert!(matches!(result, Err(TestError::CallbackAlreadyCompleted(_))));

        // The original outcome is preserved.
        let state = manager.get_callback_state("cb-1").unwrap();
        assert_eq!(
            state.completion_status,
            Some(CompleteCallbackStatus::Success)
        );
        assert!(state.error.is_none());
    }

    #[test]
    fn test_heartbeat_after_completion_fails() {
        let mut manager = CallbackManager::new("exec-1");
        manager.register_callback("cb-1", None).unwrap();
        manager.send_success("cb-1", "result").unwrap();

        let result = manager.send_heartbeat("cb-1");
        assert!(matches!(result, Err(TestError::CallbackAlreadyCompleted(_))));
    }

    #[test]
    fn test_check_timeouts_marks_only_expired_callbacks() {
        let mut manager = CallbackManager::new("exec-1");
        manager
            .register_callback("expired-b", Some(Duration::from_millis(0)))
            .unwrap();
        manager
            .register_callback("expired-a", Some(Duration::from_millis(0)))
            .unwrap();
        manager.register_callback("no-timeout", None).unwrap();
        manager
            .register_callback("long-timeout", Some(Duration::from_secs(60)))
            .unwrap();

        // Ensure a measurable amount of time has passed since registration.
        std::thread::sleep(Duration::from_millis(1));

        let mut timed_out = manager.check_timeouts();
        timed_out.sort();
        assert_eq!(
            timed_out,
            vec!["expired-a".to_string(), "expired-b".to_string()]
        );
        assert_eq!(
            manager.get_callback_status("expired-a"),
            Some(CompleteCallbackStatus::TimedOut)
        );
        assert_eq!(
            manager.get_callback_status("expired-b"),
            Some(CompleteCallbackStatus::TimedOut)
        );
        assert_eq!(manager.get_callback_status("no-timeout"), None);
        assert_eq!(manager.get_callback_status("long-timeout"), None);

        // Timeouts that were already recorded are not reported a second time.
        assert!(manager.check_timeouts().is_empty());
    }

    #[test]
    fn test_send_success_after_recorded_timeout_fails() {
        let mut manager = CallbackManager::new("exec-1");
        manager
            .register_callback("cb-1", Some(Duration::from_millis(0)))
            .unwrap();
        std::thread::sleep(Duration::from_millis(1));
        manager.check_timeouts();

        let result = manager.send_success("cb-1", "late");
        assert!(matches!(result, Err(TestError::CallbackAlreadyCompleted(_))));
        assert_eq!(
            manager.get_callback_status("cb-1"),
            Some(CompleteCallbackStatus::TimedOut)
        );
    }

    #[test]
    fn test_callback_not_found() {
        let mut manager = CallbackManager::new("exec-1");

        let result = manager.send_success("nonexistent", "result");
        assert!(matches!(result, Err(TestError::CallbackNotFound(_))));

        let error = ErrorObject::new("TestError", "Something went wrong");
        let result = manager.send_failure("nonexistent", &error);
        assert!(matches!(result, Err(TestError::CallbackNotFound(_))));

        let result = manager.send_heartbeat("nonexistent");
        assert!(matches!(result, Err(TestError::CallbackNotFound(_))));

        assert!(manager.get_callback_state("nonexistent").is_none());
        assert_eq!(manager.get_callback_status("nonexistent"), None);
    }
}