1use std::sync::{Arc, Mutex, OnceLock};
7use std::thread::{self, JoinHandle};
8
9use tokio::sync::{mpsc, oneshot};
10use uuid::Uuid;
11
12use durable_execution_sdk::{
13 CheckpointResponse, DurableError, DurableServiceClient, ErrorObject, GetOperationsResponse,
14 OperationUpdate,
15};
16
17use super::execution_manager::ExecutionManager;
18use super::types::{
19 ApiType, CheckpointWorkerParams, CompleteInvocationRequest, SendCallbackFailureRequest,
20 SendCallbackHeartbeatRequest, SendCallbackSuccessRequest, StartDurableExecutionRequest,
21 StartInvocationRequest, WorkerApiRequest, WorkerApiResponse, WorkerCommand, WorkerCommandType,
22 WorkerResponse, WorkerResponseType,
23};
24use crate::error::TestError;
25
26struct CheckpointServerState {
28 execution_manager: ExecutionManager,
29 params: CheckpointWorkerParams,
30}
31
32enum WorkerMessage {
34 Command(WorkerCommand, oneshot::Sender<WorkerResponse>),
36 Shutdown,
38}
39
40pub struct CheckpointWorkerManager {
42 command_tx: mpsc::Sender<WorkerMessage>,
44 worker_handle: Option<JoinHandle<()>>,
46 params: CheckpointWorkerParams,
48}
49
50static INSTANCE: OnceLock<Mutex<Option<Arc<CheckpointWorkerManager>>>> = OnceLock::new();
52
53impl CheckpointWorkerManager {
54 pub fn get_instance(
56 params: Option<CheckpointWorkerParams>,
57 ) -> Result<Arc<CheckpointWorkerManager>, TestError> {
58 let mutex = INSTANCE.get_or_init(|| Mutex::new(None));
59 let mut guard = mutex.lock().map_err(|e| {
60 TestError::CheckpointServerError(format!("Failed to lock instance: {}", e))
61 })?;
62
63 if let Some(ref instance) = *guard {
64 return Ok(Arc::clone(instance));
65 }
66
67 let params = params.unwrap_or_default();
68 let manager = Self::new(params)?;
69 let arc = Arc::new(manager);
70 *guard = Some(Arc::clone(&arc));
71 Ok(arc)
72 }
73
74 pub fn reset_instance_for_testing() {
80 if let Some(mutex) = INSTANCE.get() {
81 if let Ok(mut guard) = mutex.lock() {
82 if let Some(instance) = guard.take() {
83 if let Ok(manager) = Arc::try_unwrap(instance) {
85 if tokio::runtime::Handle::try_current().is_err() {
88 let _ = manager.shutdown_sync();
90 }
91 }
94 }
95 }
96 }
97 }
98
99 fn new(params: CheckpointWorkerParams) -> Result<Self, TestError> {
101 let (command_tx, command_rx) = mpsc::channel::<WorkerMessage>(100);
102
103 let worker_params = params.clone();
104 let worker_handle = thread::spawn(move || {
105 Self::run_worker(command_rx, worker_params);
106 });
107
108 Ok(Self {
109 command_tx,
110 worker_handle: Some(worker_handle),
111 params,
112 })
113 }
114
115 fn run_worker(mut command_rx: mpsc::Receiver<WorkerMessage>, params: CheckpointWorkerParams) {
117 let rt = tokio::runtime::Builder::new_current_thread()
119 .enable_all()
120 .build()
121 .expect("Failed to create tokio runtime for worker");
122
123 rt.block_on(async {
124 let mut state = CheckpointServerState {
125 execution_manager: ExecutionManager::new(),
126 params,
127 };
128
129 while let Some(message) = command_rx.recv().await {
130 match message {
131 WorkerMessage::Command(command, response_tx) => {
132 let response = Self::process_command(&mut state, command).await;
133 let _ = response_tx.send(response);
134 }
135 WorkerMessage::Shutdown => {
136 break;
137 }
138 }
139 }
140 });
141 }
142
143 async fn process_command(
145 state: &mut CheckpointServerState,
146 command: WorkerCommand,
147 ) -> WorkerResponse {
148 if let Some(delay_ms) = state.params.checkpoint_delay {
150 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
151 }
152
153 let api_response = match command.data.api_type {
154 ApiType::StartDurableExecution => Self::handle_start_execution(state, &command.data),
155 ApiType::StartInvocation => Self::handle_start_invocation(state, &command.data),
156 ApiType::CompleteInvocation => Self::handle_complete_invocation(state, &command.data),
157 ApiType::CheckpointDurableExecutionState => {
158 Self::handle_checkpoint(state, &command.data)
159 }
160 ApiType::GetDurableExecutionState => Self::handle_get_state(state, &command.data),
161 ApiType::SendDurableExecutionCallbackSuccess => {
162 Self::handle_callback_success(state, &command.data)
163 }
164 ApiType::SendDurableExecutionCallbackFailure => {
165 Self::handle_callback_failure(state, &command.data)
166 }
167 ApiType::SendDurableExecutionCallbackHeartbeat => {
168 Self::handle_callback_heartbeat(state, &command.data)
169 }
170 ApiType::UpdateCheckpointData => {
171 Self::handle_update_checkpoint_data(state, &command.data)
172 }
173 ApiType::GetNodeJsHistoryEvents => {
174 Self::handle_get_nodejs_history_events(state, &command.data)
175 }
176 _ => WorkerApiResponse::error(
177 command.data.api_type,
178 command.data.request_id.clone(),
179 format!("Unsupported API type: {:?}", command.data.api_type),
180 ),
181 };
182
183 WorkerResponse {
184 response_type: WorkerResponseType::ApiResponse,
185 data: api_response,
186 }
187 }
188
189 fn handle_start_execution(
191 state: &mut CheckpointServerState,
192 request: &WorkerApiRequest,
193 ) -> WorkerApiResponse {
194 let parsed: Result<StartDurableExecutionRequest, _> =
195 serde_json::from_str(&request.payload);
196
197 match parsed {
198 Ok(req) => {
199 let result = state.execution_manager.start_execution_from_request(req);
200 match serde_json::to_string(&result) {
201 Ok(payload) => WorkerApiResponse::success(
202 request.api_type,
203 request.request_id.clone(),
204 payload,
205 ),
206 Err(e) => WorkerApiResponse::error(
207 request.api_type,
208 request.request_id.clone(),
209 format!("Failed to serialize response: {}", e),
210 ),
211 }
212 }
213 Err(e) => WorkerApiResponse::error(
214 request.api_type,
215 request.request_id.clone(),
216 format!("Failed to parse request: {}", e),
217 ),
218 }
219 }
220
221 fn handle_start_invocation(
223 state: &mut CheckpointServerState,
224 request: &WorkerApiRequest,
225 ) -> WorkerApiResponse {
226 let parsed: Result<StartInvocationRequest, _> = serde_json::from_str(&request.payload);
227
228 match parsed {
229 Ok(req) => match state.execution_manager.start_invocation(req) {
230 Ok(result) => match serde_json::to_string(&result) {
231 Ok(payload) => WorkerApiResponse::success(
232 request.api_type,
233 request.request_id.clone(),
234 payload,
235 ),
236 Err(e) => WorkerApiResponse::error(
237 request.api_type,
238 request.request_id.clone(),
239 format!("Failed to serialize response: {}", e),
240 ),
241 },
242 Err(e) => WorkerApiResponse::error(
243 request.api_type,
244 request.request_id.clone(),
245 format!("{}", e),
246 ),
247 },
248 Err(e) => WorkerApiResponse::error(
249 request.api_type,
250 request.request_id.clone(),
251 format!("Failed to parse request: {}", e),
252 ),
253 }
254 }
255
256 fn handle_complete_invocation(
258 state: &mut CheckpointServerState,
259 request: &WorkerApiRequest,
260 ) -> WorkerApiResponse {
261 let parsed: Result<CompleteInvocationRequest, _> = serde_json::from_str(&request.payload);
262
263 match parsed {
264 Ok(req) => match state.execution_manager.complete_invocation(req) {
265 Ok(result) => match serde_json::to_string(&result) {
266 Ok(payload) => WorkerApiResponse::success(
267 request.api_type,
268 request.request_id.clone(),
269 payload,
270 ),
271 Err(e) => WorkerApiResponse::error(
272 request.api_type,
273 request.request_id.clone(),
274 format!("Failed to serialize response: {}", e),
275 ),
276 },
277 Err(e) => WorkerApiResponse::error(
278 request.api_type,
279 request.request_id.clone(),
280 format!("{}", e),
281 ),
282 },
283 Err(e) => WorkerApiResponse::error(
284 request.api_type,
285 request.request_id.clone(),
286 format!("Failed to parse request: {}", e),
287 ),
288 }
289 }
290
291 fn handle_checkpoint(
293 state: &mut CheckpointServerState,
294 request: &WorkerApiRequest,
295 ) -> WorkerApiResponse {
296 let parsed: Result<CheckpointDurableExecutionStateRequest, _> =
297 serde_json::from_str(&request.payload);
298
299 match parsed {
300 Ok(req) => {
301 match state
303 .execution_manager
304 .get_checkpoints_by_token_mut(&req.checkpoint_token)
305 {
306 Ok(Some(checkpoint_manager)) => {
307 match checkpoint_manager.process_checkpoint(req.operations) {
308 Ok(dirty_ops) => {
309 let response = serde_json::json!({
311 "CheckpointToken": req.checkpoint_token,
312 "NewExecutionState": {
313 "Operations": dirty_ops
314 }
315 });
316 match serde_json::to_string(&response) {
317 Ok(payload) => WorkerApiResponse::success(
318 request.api_type,
319 request.request_id.clone(),
320 payload,
321 ),
322 Err(e) => WorkerApiResponse::error(
323 request.api_type,
324 request.request_id.clone(),
325 format!("Failed to serialize response: {}", e),
326 ),
327 }
328 }
329 Err(e) => WorkerApiResponse::error(
330 request.api_type,
331 request.request_id.clone(),
332 format!("Checkpoint processing failed: {}", e),
333 ),
334 }
335 }
336 Ok(None) => WorkerApiResponse::error(
337 request.api_type,
338 request.request_id.clone(),
339 "Execution not found for checkpoint token".to_string(),
340 ),
341 Err(e) => WorkerApiResponse::error(
342 request.api_type,
343 request.request_id.clone(),
344 format!("Invalid checkpoint token: {}", e),
345 ),
346 }
347 }
348 Err(e) => WorkerApiResponse::error(
349 request.api_type,
350 request.request_id.clone(),
351 format!("Failed to parse request: {}", e),
352 ),
353 }
354 }
355
356 fn handle_get_state(
358 state: &mut CheckpointServerState,
359 request: &WorkerApiRequest,
360 ) -> WorkerApiResponse {
361 let parsed: Result<GetDurableExecutionStateRequest, _> =
362 serde_json::from_str(&request.payload);
363
364 match parsed {
365 Ok(req) => {
366 let execution_id = &req.durable_execution_arn;
368 match state
369 .execution_manager
370 .get_checkpoints_by_execution(execution_id)
371 {
372 Some(checkpoint_manager) => {
373 let operations = checkpoint_manager.get_state();
374 let response = GetOperationsResponse {
375 operations,
376 next_marker: None,
377 };
378 match serde_json::to_string(&response) {
379 Ok(payload) => WorkerApiResponse::success(
380 request.api_type,
381 request.request_id.clone(),
382 payload,
383 ),
384 Err(e) => WorkerApiResponse::error(
385 request.api_type,
386 request.request_id.clone(),
387 format!("Failed to serialize response: {}", e),
388 ),
389 }
390 }
391 None => WorkerApiResponse::error(
392 request.api_type,
393 request.request_id.clone(),
394 format!("Execution not found: {}", execution_id),
395 ),
396 }
397 }
398 Err(e) => WorkerApiResponse::error(
399 request.api_type,
400 request.request_id.clone(),
401 format!("Failed to parse request: {}", e),
402 ),
403 }
404 }
405
406 fn handle_callback_success(
408 state: &mut CheckpointServerState,
409 request: &WorkerApiRequest,
410 ) -> WorkerApiResponse {
411 let parsed: Result<SendCallbackSuccessRequest, _> = serde_json::from_str(&request.payload);
412
413 match parsed {
414 Ok(req) => {
415 let execution_ids: Vec<String> = state
417 .execution_manager
418 .get_execution_ids()
419 .into_iter()
420 .cloned()
421 .collect();
422
423 for execution_id in execution_ids {
425 if let Some(checkpoint_manager) = state
426 .execution_manager
427 .get_checkpoints_by_execution_mut(&execution_id)
428 {
429 if checkpoint_manager
430 .callback_manager()
431 .get_callback_state(&req.callback_id)
432 .is_some()
433 {
434 match checkpoint_manager
435 .callback_manager_mut()
436 .send_success(&req.callback_id, &req.result)
437 {
438 Ok(()) => {
439 checkpoint_manager.complete_callback_operation(
442 &req.callback_id,
443 Some(req.result.clone()),
444 None,
445 );
446 return WorkerApiResponse::success(
447 request.api_type,
448 request.request_id.clone(),
449 "{}".to_string(),
450 );
451 }
452 Err(e) => {
453 return WorkerApiResponse::error(
454 request.api_type,
455 request.request_id.clone(),
456 format!("{}", e),
457 );
458 }
459 }
460 }
461 }
462 }
463 WorkerApiResponse::error(
464 request.api_type,
465 request.request_id.clone(),
466 format!("Callback not found: {}", req.callback_id),
467 )
468 }
469 Err(e) => WorkerApiResponse::error(
470 request.api_type,
471 request.request_id.clone(),
472 format!("Failed to parse request: {}", e),
473 ),
474 }
475 }
476
477 fn handle_callback_failure(
479 state: &mut CheckpointServerState,
480 request: &WorkerApiRequest,
481 ) -> WorkerApiResponse {
482 let parsed: Result<SendCallbackFailureRequest, _> = serde_json::from_str(&request.payload);
483
484 match parsed {
485 Ok(req) => {
486 let execution_ids: Vec<String> = state
488 .execution_manager
489 .get_execution_ids()
490 .into_iter()
491 .cloned()
492 .collect();
493
494 for execution_id in execution_ids {
496 if let Some(checkpoint_manager) = state
497 .execution_manager
498 .get_checkpoints_by_execution_mut(&execution_id)
499 {
500 if checkpoint_manager
501 .callback_manager()
502 .get_callback_state(&req.callback_id)
503 .is_some()
504 {
505 match checkpoint_manager
506 .callback_manager_mut()
507 .send_failure(&req.callback_id, &req.error)
508 {
509 Ok(()) => {
510 checkpoint_manager.complete_callback_operation(
513 &req.callback_id,
514 None,
515 Some(req.error.clone()),
516 );
517 return WorkerApiResponse::success(
518 request.api_type,
519 request.request_id.clone(),
520 "{}".to_string(),
521 );
522 }
523 Err(e) => {
524 return WorkerApiResponse::error(
525 request.api_type,
526 request.request_id.clone(),
527 format!("{}", e),
528 );
529 }
530 }
531 }
532 }
533 }
534 WorkerApiResponse::error(
535 request.api_type,
536 request.request_id.clone(),
537 format!("Callback not found: {}", req.callback_id),
538 )
539 }
540 Err(e) => WorkerApiResponse::error(
541 request.api_type,
542 request.request_id.clone(),
543 format!("Failed to parse request: {}", e),
544 ),
545 }
546 }
547
548 fn handle_callback_heartbeat(
550 state: &mut CheckpointServerState,
551 request: &WorkerApiRequest,
552 ) -> WorkerApiResponse {
553 let parsed: Result<SendCallbackHeartbeatRequest, _> =
554 serde_json::from_str(&request.payload);
555
556 match parsed {
557 Ok(req) => {
558 let execution_ids: Vec<String> = state
560 .execution_manager
561 .get_execution_ids()
562 .into_iter()
563 .cloned()
564 .collect();
565
566 for execution_id in execution_ids {
568 if let Some(checkpoint_manager) = state
569 .execution_manager
570 .get_checkpoints_by_execution_mut(&execution_id)
571 {
572 if checkpoint_manager
573 .callback_manager()
574 .get_callback_state(&req.callback_id)
575 .is_some()
576 {
577 match checkpoint_manager
578 .callback_manager_mut()
579 .send_heartbeat(&req.callback_id)
580 {
581 Ok(()) => {
582 return WorkerApiResponse::success(
583 request.api_type,
584 request.request_id.clone(),
585 "{}".to_string(),
586 );
587 }
588 Err(e) => {
589 return WorkerApiResponse::error(
590 request.api_type,
591 request.request_id.clone(),
592 format!("{}", e),
593 );
594 }
595 }
596 }
597 }
598 }
599 WorkerApiResponse::error(
600 request.api_type,
601 request.request_id.clone(),
602 format!("Callback not found: {}", req.callback_id),
603 )
604 }
605 Err(e) => WorkerApiResponse::error(
606 request.api_type,
607 request.request_id.clone(),
608 format!("Failed to parse request: {}", e),
609 ),
610 }
611 }
612
613 fn handle_update_checkpoint_data(
619 state: &mut CheckpointServerState,
620 request: &WorkerApiRequest,
621 ) -> WorkerApiResponse {
622 let parsed: Result<super::types::UpdateCheckpointDataRequest, _> =
623 serde_json::from_str(&request.payload);
624
625 match parsed {
626 Ok(req) => {
627 match state
629 .execution_manager
630 .get_checkpoints_by_execution_mut(&req.execution_id)
631 {
632 Some(checkpoint_manager) => {
633 checkpoint_manager
635 .update_operation_data(&req.operation_id, req.operation_data);
636 WorkerApiResponse::success(
637 request.api_type,
638 request.request_id.clone(),
639 "{}".to_string(),
640 )
641 }
642 None => WorkerApiResponse::error(
643 request.api_type,
644 request.request_id.clone(),
645 format!("Execution not found: {}", req.execution_id),
646 ),
647 }
648 }
649 Err(e) => WorkerApiResponse::error(
650 request.api_type,
651 request.request_id.clone(),
652 format!("Failed to parse request: {}", e),
653 ),
654 }
655 }
656
657 fn handle_get_nodejs_history_events(
663 state: &mut CheckpointServerState,
664 request: &WorkerApiRequest,
665 ) -> WorkerApiResponse {
666 let parsed: Result<super::types::GetNodeJsHistoryEventsRequest, _> =
667 serde_json::from_str(&request.payload);
668
669 match parsed {
670 Ok(req) => {
671 match state
673 .execution_manager
674 .get_checkpoints_by_execution(&req.execution_id)
675 {
676 Some(checkpoint_manager) => {
677 let events = checkpoint_manager.get_nodejs_history_events();
679 match serde_json::to_string(&events) {
680 Ok(payload) => WorkerApiResponse::success(
681 request.api_type,
682 request.request_id.clone(),
683 payload,
684 ),
685 Err(e) => WorkerApiResponse::error(
686 request.api_type,
687 request.request_id.clone(),
688 format!("Failed to serialize response: {}", e),
689 ),
690 }
691 }
692 None => WorkerApiResponse::error(
693 request.api_type,
694 request.request_id.clone(),
695 format!("Execution not found: {}", req.execution_id),
696 ),
697 }
698 }
699 Err(e) => WorkerApiResponse::error(
700 request.api_type,
701 request.request_id.clone(),
702 format!("Failed to parse request: {}", e),
703 ),
704 }
705 }
706
707 pub async fn send_api_request(
709 &self,
710 api_type: ApiType,
711 payload: String,
712 ) -> Result<WorkerApiResponse, TestError> {
713 let request_id = Uuid::new_v4().to_string();
714 let command = WorkerCommand {
715 command_type: WorkerCommandType::ApiRequest,
716 data: WorkerApiRequest {
717 api_type,
718 request_id: request_id.clone(),
719 payload,
720 },
721 };
722
723 let (response_tx, response_rx) = oneshot::channel();
724
725 self.command_tx
726 .send(WorkerMessage::Command(command, response_tx))
727 .await
728 .map_err(|e| {
729 TestError::CheckpointCommunicationError(format!("Failed to send command: {}", e))
730 })?;
731
732 let response = response_rx.await.map_err(|e| {
733 TestError::CheckpointCommunicationError(format!("Failed to receive response: {}", e))
734 })?;
735
736 Ok(response.data)
737 }
738
739 pub async fn shutdown(mut self) -> Result<(), TestError> {
741 let _ = self.command_tx.send(WorkerMessage::Shutdown).await;
743
744 if let Some(handle) = self.worker_handle.take() {
746 handle.join().map_err(|_| {
747 TestError::CheckpointServerError("Worker thread panicked".to_string())
748 })?;
749 }
750
751 Ok(())
752 }
753
754 pub fn shutdown_sync(mut self) -> Result<(), TestError> {
756 if let Ok(rt) = tokio::runtime::Builder::new_current_thread()
758 .enable_all()
759 .build()
760 {
761 let _ = rt.block_on(self.command_tx.send(WorkerMessage::Shutdown));
762 }
763
764 if let Some(handle) = self.worker_handle.take() {
766 handle.join().map_err(|_| {
767 TestError::CheckpointServerError("Worker thread panicked".to_string())
768 })?;
769 }
770
771 Ok(())
772 }
773
774 pub fn params(&self) -> &CheckpointWorkerParams {
776 &self.params
777 }
778}
779
780use super::types::{CheckpointDurableExecutionStateRequest, GetDurableExecutionStateRequest};
782
783use async_trait::async_trait;
785
786#[async_trait]
787impl DurableServiceClient for CheckpointWorkerManager {
788 async fn checkpoint(
789 &self,
790 durable_execution_arn: &str,
791 checkpoint_token: &str,
792 operations: Vec<OperationUpdate>,
793 ) -> Result<CheckpointResponse, DurableError> {
794 let request = CheckpointDurableExecutionStateRequest {
795 durable_execution_arn: durable_execution_arn.to_string(),
796 checkpoint_token: checkpoint_token.to_string(),
797 operations,
798 };
799
800 let payload = serde_json::to_string(&request)
801 .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
802
803 let response = self
804 .send_api_request(ApiType::CheckpointDurableExecutionState, payload)
805 .await
806 .map_err(|e| {
807 DurableError::checkpoint_retriable(format!("Communication error: {}", e))
808 })?;
809
810 if let Some(error) = response.error {
811 return Err(DurableError::checkpoint_retriable(error));
812 }
813
814 let payload = response
815 .payload
816 .ok_or_else(|| DurableError::checkpoint_retriable("Empty response payload"))?;
817
818 serde_json::from_str(&payload)
819 .map_err(|e| DurableError::validation(format!("Failed to parse response: {}", e)))
820 }
821
822 async fn get_operations(
823 &self,
824 durable_execution_arn: &str,
825 _next_marker: &str,
826 ) -> Result<GetOperationsResponse, DurableError> {
827 let request = GetDurableExecutionStateRequest {
828 durable_execution_arn: durable_execution_arn.to_string(),
829 };
830
831 let payload = serde_json::to_string(&request)
832 .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
833
834 let response = self
835 .send_api_request(ApiType::GetDurableExecutionState, payload)
836 .await
837 .map_err(|e| {
838 DurableError::checkpoint_retriable(format!("Communication error: {}", e))
839 })?;
840
841 if let Some(error) = response.error {
842 return Err(DurableError::checkpoint_retriable(error));
843 }
844
845 let payload = response
846 .payload
847 .ok_or_else(|| DurableError::checkpoint_retriable("Empty response payload"))?;
848
849 serde_json::from_str(&payload)
850 .map_err(|e| DurableError::validation(format!("Failed to parse response: {}", e)))
851 }
852}
853
854impl CheckpointWorkerManager {
856 pub async fn send_callback_success(
858 &self,
859 callback_id: &str,
860 result: &str,
861 ) -> Result<(), DurableError> {
862 let request = SendCallbackSuccessRequest {
863 callback_id: callback_id.to_string(),
864 result: result.to_string(),
865 };
866
867 let payload = serde_json::to_string(&request)
868 .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
869
870 let response = self
871 .send_api_request(ApiType::SendDurableExecutionCallbackSuccess, payload)
872 .await
873 .map_err(|e| {
874 DurableError::checkpoint_retriable(format!("Communication error: {}", e))
875 })?;
876
877 if let Some(error) = response.error {
878 return Err(DurableError::execution(error));
879 }
880
881 Ok(())
882 }
883
884 pub async fn send_callback_failure(
886 &self,
887 callback_id: &str,
888 error: &ErrorObject,
889 ) -> Result<(), DurableError> {
890 let request = SendCallbackFailureRequest {
891 callback_id: callback_id.to_string(),
892 error: error.clone(),
893 };
894
895 let payload = serde_json::to_string(&request)
896 .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
897
898 let response = self
899 .send_api_request(ApiType::SendDurableExecutionCallbackFailure, payload)
900 .await
901 .map_err(|e| {
902 DurableError::checkpoint_retriable(format!("Communication error: {}", e))
903 })?;
904
905 if let Some(error) = response.error {
906 return Err(DurableError::execution(error));
907 }
908
909 Ok(())
910 }
911
912 pub async fn send_callback_heartbeat(&self, callback_id: &str) -> Result<(), DurableError> {
914 let request = SendCallbackHeartbeatRequest {
915 callback_id: callback_id.to_string(),
916 };
917
918 let payload = serde_json::to_string(&request)
919 .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
920
921 let response = self
922 .send_api_request(ApiType::SendDurableExecutionCallbackHeartbeat, payload)
923 .await
924 .map_err(|e| {
925 DurableError::checkpoint_retriable(format!("Communication error: {}", e))
926 })?;
927
928 if let Some(error) = response.error {
929 return Err(DurableError::execution(error));
930 }
931
932 Ok(())
933 }
934
935 pub async fn get_nodejs_history_events(
949 &self,
950 execution_id: &str,
951 ) -> Result<Vec<super::nodejs_event_types::NodeJsHistoryEvent>, DurableError> {
952 use super::types::GetNodeJsHistoryEventsRequest;
953
954 let request = GetNodeJsHistoryEventsRequest {
955 execution_id: execution_id.to_string(),
956 };
957
958 let payload = serde_json::to_string(&request)
959 .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
960
961 let response = self
962 .send_api_request(ApiType::GetNodeJsHistoryEvents, payload)
963 .await
964 .map_err(|e| {
965 DurableError::checkpoint_retriable(format!("Communication error: {}", e))
966 })?;
967
968 if let Some(error) = response.error {
969 return Err(DurableError::checkpoint_retriable(error));
970 }
971
972 let payload = response
973 .payload
974 .ok_or_else(|| DurableError::checkpoint_retriable("Empty response payload"))?;
975
976 serde_json::from_str(&payload)
977 .map_err(|e| DurableError::validation(format!("Failed to parse response: {}", e)))
978 }
979}
980
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkpoint_server::InvocationResult;

    /// Creates a manager owned solely by the calling test.
    ///
    /// The test harness runs tests in parallel, and every
    /// `reset_instance_for_testing` call replaces the process-wide singleton.
    /// If the workflow tests also reset it, one of those resets could land
    /// between the two `get_instance` calls in `test_get_instance` and break its
    /// `ptr_eq` assertion. Tests that only need *a* working worker therefore use
    /// a private manager, and only `test_get_instance` touches the global.
    fn dedicated_manager() -> CheckpointWorkerManager {
        CheckpointWorkerManager::new(CheckpointWorkerParams::default())
            .expect("failed to start checkpoint worker")
    }

    #[tokio::test]
    async fn test_get_instance() {
        // NOTE(review): other test modules that reset the singleton concurrently
        // could still interfere with this test.
        CheckpointWorkerManager::reset_instance_for_testing();

        let instance = CheckpointWorkerManager::get_instance(None).unwrap();
        assert!(instance.params().checkpoint_delay.is_none());

        // A second call hands back the already-created shared instance.
        let instance2 = CheckpointWorkerManager::get_instance(None).unwrap();
        assert!(Arc::ptr_eq(&instance, &instance2));

        CheckpointWorkerManager::reset_instance_for_testing();
    }

    #[tokio::test]
    async fn test_start_execution() {
        let manager = dedicated_manager();

        let request = StartDurableExecutionRequest {
            invocation_id: "inv-1".to_string(),
            payload: Some(r#"{"test": true}"#.to_string()),
        };

        let payload = serde_json::to_string(&request).unwrap();
        let response = manager
            .send_api_request(ApiType::StartDurableExecution, payload)
            .await
            .unwrap();

        assert!(!response.is_error());
        assert!(response.payload.is_some());

        let result: InvocationResult = serde_json::from_str(&response.payload.unwrap()).unwrap();
        assert!(!result.execution_id.is_empty());
        assert_eq!(result.invocation_id, "inv-1");
    }

    #[tokio::test]
    async fn test_checkpoint_workflow() {
        let manager = dedicated_manager();

        let start_request = StartDurableExecutionRequest {
            invocation_id: "inv-1".to_string(),
            payload: Some("{}".to_string()),
        };
        let payload = serde_json::to_string(&start_request).unwrap();
        let response = manager
            .send_api_request(ApiType::StartDurableExecution, payload)
            .await
            .unwrap();

        let result: InvocationResult = serde_json::from_str(&response.payload.unwrap()).unwrap();

        let checkpoint_request = CheckpointDurableExecutionStateRequest {
            durable_execution_arn: result.execution_id.clone(),
            checkpoint_token: result.checkpoint_token.clone(),
            operations: vec![OperationUpdate::start(
                "op-1",
                durable_execution_sdk::OperationType::Step,
            )
            .with_name("test-step")],
        };

        let payload = serde_json::to_string(&checkpoint_request).unwrap();
        let response = manager
            .send_api_request(ApiType::CheckpointDurableExecutionState, payload)
            .await
            .unwrap();

        assert!(!response.is_error());
    }
}