durable_execution_sdk_testing/checkpoint_server/
callback_manager.rs1use std::collections::HashMap;
7use std::time::{Duration, Instant};
8
9use durable_execution_sdk::ErrorObject;
10
11use crate::error::TestError;
12
/// Terminal outcome recorded for a callback once it stops being pending.
///
/// A callback with no status yet is represented by `None` in
/// [`CallbackState::completion_status`]; there is deliberately no "pending"
/// variant here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CompleteCallbackStatus {
    /// Set by [`CallbackManager::send_success`]; a result payload is stored alongside.
    Success,
    /// Set by [`CallbackManager::send_failure`]; an error object is stored alongside.
    Failure,
    /// Set by [`CallbackManager::check_timeouts`] when the heartbeat deadline has passed.
    TimedOut,
}
23
/// Everything tracked for a single registered callback.
///
/// All fields are public, so callers holding a `CallbackState` can inspect
/// (or, with mutable access, adjust) them directly.
#[derive(Debug, Clone)]
pub struct CallbackState {
    /// Identifier of the callback; also used as its key inside `CallbackManager`.
    pub callback_id: String,
    /// Maximum time allowed since `last_heartbeat` before the callback counts
    /// as timed out. `None` means the callback never times out.
    pub timeout: Option<Duration>,
    /// Instant captured when the state was created by `CallbackState::new`.
    pub registered_at: Instant,
    /// Instant of the most recent heartbeat. Starts equal to `registered_at`
    /// and is refreshed by `CallbackManager::send_heartbeat`.
    pub last_heartbeat: Instant,
    /// Terminal outcome, or `None` while the callback is still pending.
    pub completion_status: Option<CompleteCallbackStatus>,
    /// Result payload stored by `CallbackManager::send_success`, kept exactly
    /// as the string that was passed in.
    pub result: Option<String>,
    /// Error stored by `CallbackManager::send_failure`.
    pub error: Option<ErrorObject>,
}
42
43impl CallbackState {
44 pub fn new(callback_id: String, timeout: Option<Duration>) -> Self {
46 let now = Instant::now();
47 Self {
48 callback_id,
49 timeout,
50 registered_at: now,
51 last_heartbeat: now,
52 completion_status: None,
53 result: None,
54 error: None,
55 }
56 }
57
58 pub fn is_completed(&self) -> bool {
60 self.completion_status.is_some()
61 }
62
63 pub fn is_timed_out(&self) -> bool {
65 if let Some(timeout) = self.timeout {
66 self.last_heartbeat.elapsed() > timeout
67 } else {
68 false
69 }
70 }
71}
72
/// In-memory registry of callbacks and their outcomes.
///
/// The derived `Default` produces a manager with an empty execution id and no
/// callbacks; use `CallbackManager::new` to set the execution id.
#[derive(Debug, Default)]
pub struct CallbackManager {
    /// Execution identifier supplied to `new` and returned by `execution_id()`.
    execution_id: String,
    /// Registered callbacks, keyed by callback id.
    callbacks: HashMap<String, CallbackState>,
}
81
82impl CallbackManager {
83 pub fn new(execution_id: &str) -> Self {
85 Self {
86 execution_id: execution_id.to_string(),
87 callbacks: HashMap::new(),
88 }
89 }
90
91 pub fn register_callback(
93 &mut self,
94 callback_id: &str,
95 timeout: Option<Duration>,
96 ) -> Result<(), TestError> {
97 if self.callbacks.contains_key(callback_id) {
98 return Err(TestError::CallbackAlreadyCompleted(format!(
99 "Callback {} already registered",
100 callback_id
101 )));
102 }
103
104 let state = CallbackState::new(callback_id.to_string(), timeout);
105 self.callbacks.insert(callback_id.to_string(), state);
106 Ok(())
107 }
108
109 pub fn send_success(&mut self, callback_id: &str, result: &str) -> Result<(), TestError> {
111 let state = self
112 .callbacks
113 .get_mut(callback_id)
114 .ok_or_else(|| TestError::CallbackNotFound(callback_id.to_string()))?;
115
116 if state.is_completed() {
117 return Err(TestError::CallbackAlreadyCompleted(callback_id.to_string()));
118 }
119
120 state.completion_status = Some(CompleteCallbackStatus::Success);
121 state.result = Some(result.to_string());
122 Ok(())
123 }
124
125 pub fn send_failure(
127 &mut self,
128 callback_id: &str,
129 error: &ErrorObject,
130 ) -> Result<(), TestError> {
131 let state = self
132 .callbacks
133 .get_mut(callback_id)
134 .ok_or_else(|| TestError::CallbackNotFound(callback_id.to_string()))?;
135
136 if state.is_completed() {
137 return Err(TestError::CallbackAlreadyCompleted(callback_id.to_string()));
138 }
139
140 state.completion_status = Some(CompleteCallbackStatus::Failure);
141 state.error = Some(error.clone());
142 Ok(())
143 }
144
145 pub fn send_heartbeat(&mut self, callback_id: &str) -> Result<(), TestError> {
147 let state = self
148 .callbacks
149 .get_mut(callback_id)
150 .ok_or_else(|| TestError::CallbackNotFound(callback_id.to_string()))?;
151
152 if state.is_completed() {
153 return Err(TestError::CallbackAlreadyCompleted(callback_id.to_string()));
154 }
155
156 state.last_heartbeat = Instant::now();
157 Ok(())
158 }
159
160 pub fn check_timeouts(&mut self) -> Vec<String> {
163 let mut timed_out = Vec::new();
164
165 for (id, state) in self.callbacks.iter_mut() {
166 if !state.is_completed() && state.is_timed_out() {
167 state.completion_status = Some(CompleteCallbackStatus::TimedOut);
168 timed_out.push(id.clone());
169 }
170 }
171
172 timed_out
173 }
174
175 pub fn get_callback_status(&self, callback_id: &str) -> Option<CompleteCallbackStatus> {
177 self.callbacks
178 .get(callback_id)
179 .and_then(|s| s.completion_status)
180 }
181
182 pub fn get_callback_state(&self, callback_id: &str) -> Option<&CallbackState> {
184 self.callbacks.get(callback_id)
185 }
186
187 pub fn execution_id(&self) -> &str {
189 &self.execution_id
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;

    /// Execution id shared by every test manager.
    const EXECUTION_ID: &str = "exec-1";
    /// Callback id shared by every test that registers a callback.
    const CALLBACK_ID: &str = "cb-1";

    /// Builds a manager that already has `CALLBACK_ID` registered with `timeout`.
    fn manager_with_callback(timeout: Option<Duration>) -> CallbackManager {
        let mut manager = CallbackManager::new(EXECUTION_ID);
        manager.register_callback(CALLBACK_ID, timeout).unwrap();
        manager
    }

    /// A newly registered callback is retrievable and still pending.
    #[test]
    fn test_register_callback() {
        let manager = manager_with_callback(None);

        let registered = manager.get_callback_state(CALLBACK_ID).unwrap();
        assert_eq!(registered.callback_id, "cb-1");
        assert!(!registered.is_completed());
    }

    /// Registering the same id twice is an error.
    #[test]
    fn test_register_duplicate_callback_fails() {
        let mut manager = manager_with_callback(None);

        let second_attempt = manager.register_callback(CALLBACK_ID, None);
        assert!(second_attempt.is_err());
    }

    /// A successful completion records the status and the payload.
    #[test]
    fn test_send_success() {
        let payload = r#"{"result": "ok"}"#;
        let mut manager = manager_with_callback(None);
        manager.send_success(CALLBACK_ID, payload).unwrap();

        let completed = manager.get_callback_state(CALLBACK_ID).unwrap();
        assert!(completed.is_completed());
        assert_eq!(
            completed.completion_status,
            Some(CompleteCallbackStatus::Success)
        );
        assert_eq!(completed.result, Some(payload.to_string()));
    }

    /// A failed completion records the status and keeps the error.
    #[test]
    fn test_send_failure() {
        let mut manager = manager_with_callback(None);

        let failure = ErrorObject::new("TestError", "Something went wrong");
        manager.send_failure(CALLBACK_ID, &failure).unwrap();

        let completed = manager.get_callback_state(CALLBACK_ID).unwrap();
        assert!(completed.is_completed());
        assert_eq!(
            completed.completion_status,
            Some(CompleteCallbackStatus::Failure)
        );
        assert!(completed.error.is_some());
    }

    /// A heartbeat moves `last_heartbeat` forward.
    #[test]
    fn test_send_heartbeat() {
        let mut manager = manager_with_callback(Some(Duration::from_secs(60)));

        // Let the clock advance so the refreshed instant is strictly later.
        std::thread::sleep(Duration::from_millis(10));

        let heartbeat_of = |m: &CallbackManager| m.get_callback_state(CALLBACK_ID).unwrap().last_heartbeat;

        let before = heartbeat_of(&manager);
        manager.send_heartbeat(CALLBACK_ID).unwrap();
        let after = heartbeat_of(&manager);

        assert!(after > before);
    }

    /// A callback cannot be completed twice.
    #[test]
    fn test_double_complete_fails() {
        let mut manager = manager_with_callback(None);
        manager.send_success(CALLBACK_ID, "result").unwrap();

        let second_completion = manager.send_success(CALLBACK_ID, "another result");
        assert!(second_completion.is_err());
    }

    /// Completing an id that was never registered reports `CallbackNotFound`.
    #[test]
    fn test_callback_not_found() {
        let mut manager = CallbackManager::new(EXECUTION_ID);
        let outcome = manager.send_success("nonexistent", "result");
        assert!(matches!(outcome, Err(TestError::CallbackNotFound(_))));
    }
}