1use std::sync::{Arc, Mutex, OnceLock};
7use std::thread::{self, JoinHandle};
8
9use tokio::sync::{mpsc, oneshot};
10use uuid::Uuid;
11
12use durable_execution_sdk::{
13 CheckpointResponse, DurableError, DurableServiceClient, ErrorObject, GetOperationsResponse,
14 OperationUpdate,
15};
16
17use super::execution_manager::ExecutionManager;
18use super::types::{
19 ApiType, CheckpointWorkerParams, CompleteInvocationRequest, SendCallbackFailureRequest,
20 SendCallbackHeartbeatRequest, SendCallbackSuccessRequest, StartDurableExecutionRequest,
21 StartInvocationRequest, WorkerApiRequest, WorkerApiResponse, WorkerCommand, WorkerCommandType,
22 WorkerResponse, WorkerResponseType,
23};
24use crate::error::TestError;
25
/// State owned by the worker thread; every request handler receives it as `&mut`.
struct CheckpointServerState {
    /// Tracks executions and the per-execution checkpoint managers.
    execution_manager: ExecutionManager,
    /// Configuration the worker was started with (e.g. `checkpoint_delay`).
    params: CheckpointWorkerParams,
}
31
/// Messages sent from a `CheckpointWorkerManager` handle to its worker thread.
enum WorkerMessage {
    /// An API command plus the one-shot channel the worker replies on.
    Command(WorkerCommand, oneshot::Sender<WorkerResponse>),
    /// Makes the worker loop exit after the messages queued before it.
    Shutdown,
}
39
/// Handle to a background thread that runs the checkpoint-server logic.
///
/// All server state lives on the worker thread. This handle only forwards
/// requests over a bounded channel and awaits one reply per request.
pub struct CheckpointWorkerManager {
    /// Sending half of the command queue consumed by the worker thread.
    command_tx: mpsc::Sender<WorkerMessage>,
    /// Worker thread handle; `None` once it has been taken for joining.
    worker_handle: Option<JoinHandle<()>>,
    /// Params the worker was started with; compared by `get_instance`.
    params: CheckpointWorkerParams,
}
49
/// Process-wide singleton slot.
///
/// Holds `None` until `get_instance` creates the worker, and again after
/// `reset_instance_for_testing` clears it.
static INSTANCE: OnceLock<Mutex<Option<Arc<CheckpointWorkerManager>>>> = OnceLock::new();
52
53impl CheckpointWorkerManager {
54 pub fn get_instance(
60 params: Option<CheckpointWorkerParams>,
61 ) -> Result<Arc<CheckpointWorkerManager>, TestError> {
62 let mutex = INSTANCE.get_or_init(|| Mutex::new(None));
63 let mut guard = mutex.lock().map_err(|e| {
64 TestError::CheckpointServerError(format!("Failed to lock instance: {}", e))
65 })?;
66
67 if let Some(ref instance) = *guard {
68 if let Some(ref requested) = params {
70 if *requested != instance.params {
71 return Err(TestError::CheckpointServerError(format!(
72 "CheckpointWorkerManager singleton already exists with different params. \
73 Existing: {:?}, Requested: {:?}. \
74 Call reset_instance_for_testing() first to reconfigure.",
75 instance.params, requested
76 )));
77 }
78 }
79 return Ok(Arc::clone(instance));
80 }
81
82 let params = params.unwrap_or_default();
83 let manager = Self::new(params)?;
84 let arc = Arc::new(manager);
85 *guard = Some(Arc::clone(&arc));
86 Ok(arc)
87 }
88
89 pub fn reset_instance_for_testing() {
95 if let Some(mutex) = INSTANCE.get() {
96 if let Ok(mut guard) = mutex.lock() {
97 if let Some(instance) = guard.take() {
98 if let Ok(manager) = Arc::try_unwrap(instance) {
100 if tokio::runtime::Handle::try_current().is_err() {
103 let _ = manager.shutdown_sync();
105 }
106 }
109 }
110 }
111 }
112 }
113
114 fn new(params: CheckpointWorkerParams) -> Result<Self, TestError> {
116 let (command_tx, command_rx) = mpsc::channel::<WorkerMessage>(100);
117
118 let worker_params = params.clone();
119 let worker_handle = thread::spawn(move || {
120 Self::run_worker(command_rx, worker_params);
121 });
122
123 Ok(Self {
124 command_tx,
125 worker_handle: Some(worker_handle),
126 params,
127 })
128 }
129
130 fn run_worker(mut command_rx: mpsc::Receiver<WorkerMessage>, params: CheckpointWorkerParams) {
132 let rt = tokio::runtime::Builder::new_current_thread()
134 .enable_all()
135 .build()
136 .expect("Failed to create tokio runtime for worker");
137
138 rt.block_on(async {
139 let mut state = CheckpointServerState {
140 execution_manager: ExecutionManager::new(),
141 params,
142 };
143
144 while let Some(message) = command_rx.recv().await {
145 match message {
146 WorkerMessage::Command(command, response_tx) => {
147 let response = Self::process_command(&mut state, command).await;
148 let _ = response_tx.send(response);
149 }
150 WorkerMessage::Shutdown => {
151 break;
152 }
153 }
154 }
155 });
156 }
157
158 async fn process_command(
160 state: &mut CheckpointServerState,
161 command: WorkerCommand,
162 ) -> WorkerResponse {
163 if let Some(delay_ms) = state.params.checkpoint_delay {
165 tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
166 }
167
168 let api_response = match command.data.api_type {
169 ApiType::StartDurableExecution => Self::handle_start_execution(state, &command.data),
170 ApiType::StartInvocation => Self::handle_start_invocation(state, &command.data),
171 ApiType::CompleteInvocation => Self::handle_complete_invocation(state, &command.data),
172 ApiType::CheckpointDurableExecutionState => {
173 Self::handle_checkpoint(state, &command.data)
174 }
175 ApiType::GetDurableExecutionState => Self::handle_get_state(state, &command.data),
176 ApiType::SendDurableExecutionCallbackSuccess => {
177 Self::handle_callback_success(state, &command.data)
178 }
179 ApiType::SendDurableExecutionCallbackFailure => {
180 Self::handle_callback_failure(state, &command.data)
181 }
182 ApiType::SendDurableExecutionCallbackHeartbeat => {
183 Self::handle_callback_heartbeat(state, &command.data)
184 }
185 ApiType::UpdateCheckpointData => {
186 Self::handle_update_checkpoint_data(state, &command.data)
187 }
188 ApiType::GetNodeJsHistoryEvents => {
189 Self::handle_get_nodejs_history_events(state, &command.data)
190 }
191 _ => WorkerApiResponse::error(
192 command.data.api_type,
193 command.data.request_id.clone(),
194 format!("Unsupported API type: {:?}", command.data.api_type),
195 ),
196 };
197
198 WorkerResponse {
199 response_type: WorkerResponseType::ApiResponse,
200 data: api_response,
201 }
202 }
203
204 fn handle_start_execution(
206 state: &mut CheckpointServerState,
207 request: &WorkerApiRequest,
208 ) -> WorkerApiResponse {
209 let parsed: Result<StartDurableExecutionRequest, _> =
210 serde_json::from_str(&request.payload);
211
212 match parsed {
213 Ok(req) => {
214 let result = state.execution_manager.start_execution_from_request(req);
215 match serde_json::to_string(&result) {
216 Ok(payload) => WorkerApiResponse::success(
217 request.api_type,
218 request.request_id.clone(),
219 payload,
220 ),
221 Err(e) => WorkerApiResponse::error(
222 request.api_type,
223 request.request_id.clone(),
224 format!("Failed to serialize response: {}", e),
225 ),
226 }
227 }
228 Err(e) => WorkerApiResponse::error(
229 request.api_type,
230 request.request_id.clone(),
231 format!("Failed to parse request: {}", e),
232 ),
233 }
234 }
235
236 fn handle_start_invocation(
238 state: &mut CheckpointServerState,
239 request: &WorkerApiRequest,
240 ) -> WorkerApiResponse {
241 let parsed: Result<StartInvocationRequest, _> = serde_json::from_str(&request.payload);
242
243 match parsed {
244 Ok(req) => match state.execution_manager.start_invocation(req) {
245 Ok(result) => match serde_json::to_string(&result) {
246 Ok(payload) => WorkerApiResponse::success(
247 request.api_type,
248 request.request_id.clone(),
249 payload,
250 ),
251 Err(e) => WorkerApiResponse::error(
252 request.api_type,
253 request.request_id.clone(),
254 format!("Failed to serialize response: {}", e),
255 ),
256 },
257 Err(e) => WorkerApiResponse::error(
258 request.api_type,
259 request.request_id.clone(),
260 format!("{}", e),
261 ),
262 },
263 Err(e) => WorkerApiResponse::error(
264 request.api_type,
265 request.request_id.clone(),
266 format!("Failed to parse request: {}", e),
267 ),
268 }
269 }
270
271 fn handle_complete_invocation(
273 state: &mut CheckpointServerState,
274 request: &WorkerApiRequest,
275 ) -> WorkerApiResponse {
276 let parsed: Result<CompleteInvocationRequest, _> = serde_json::from_str(&request.payload);
277
278 match parsed {
279 Ok(req) => match state.execution_manager.complete_invocation(req) {
280 Ok(result) => match serde_json::to_string(&result) {
281 Ok(payload) => WorkerApiResponse::success(
282 request.api_type,
283 request.request_id.clone(),
284 payload,
285 ),
286 Err(e) => WorkerApiResponse::error(
287 request.api_type,
288 request.request_id.clone(),
289 format!("Failed to serialize response: {}", e),
290 ),
291 },
292 Err(e) => WorkerApiResponse::error(
293 request.api_type,
294 request.request_id.clone(),
295 format!("{}", e),
296 ),
297 },
298 Err(e) => WorkerApiResponse::error(
299 request.api_type,
300 request.request_id.clone(),
301 format!("Failed to parse request: {}", e),
302 ),
303 }
304 }
305
306 fn handle_checkpoint(
308 state: &mut CheckpointServerState,
309 request: &WorkerApiRequest,
310 ) -> WorkerApiResponse {
311 let parsed: Result<CheckpointDurableExecutionStateRequest, _> =
312 serde_json::from_str(&request.payload);
313
314 match parsed {
315 Ok(req) => {
316 match state
318 .execution_manager
319 .get_checkpoints_by_token_mut(&req.checkpoint_token)
320 {
321 Ok(Some(checkpoint_manager)) => {
322 match checkpoint_manager.process_checkpoint(req.operations) {
323 Ok(dirty_ops) => {
324 let response = serde_json::json!({
326 "CheckpointToken": req.checkpoint_token,
327 "NewExecutionState": {
328 "Operations": dirty_ops
329 }
330 });
331 match serde_json::to_string(&response) {
332 Ok(payload) => WorkerApiResponse::success(
333 request.api_type,
334 request.request_id.clone(),
335 payload,
336 ),
337 Err(e) => WorkerApiResponse::error(
338 request.api_type,
339 request.request_id.clone(),
340 format!("Failed to serialize response: {}", e),
341 ),
342 }
343 }
344 Err(e) => WorkerApiResponse::error(
345 request.api_type,
346 request.request_id.clone(),
347 format!("Checkpoint processing failed: {}", e),
348 ),
349 }
350 }
351 Ok(None) => WorkerApiResponse::error(
352 request.api_type,
353 request.request_id.clone(),
354 "Execution not found for checkpoint token".to_string(),
355 ),
356 Err(e) => WorkerApiResponse::error(
357 request.api_type,
358 request.request_id.clone(),
359 format!("Invalid checkpoint token: {}", e),
360 ),
361 }
362 }
363 Err(e) => WorkerApiResponse::error(
364 request.api_type,
365 request.request_id.clone(),
366 format!("Failed to parse request: {}", e),
367 ),
368 }
369 }
370
371 fn handle_get_state(
373 state: &mut CheckpointServerState,
374 request: &WorkerApiRequest,
375 ) -> WorkerApiResponse {
376 let parsed: Result<GetDurableExecutionStateRequest, _> =
377 serde_json::from_str(&request.payload);
378
379 match parsed {
380 Ok(req) => {
381 let execution_id = &req.durable_execution_arn;
383 match state
384 .execution_manager
385 .get_checkpoints_by_execution(execution_id)
386 {
387 Some(checkpoint_manager) => {
388 let operations = checkpoint_manager.get_state();
389 let response = GetOperationsResponse {
390 operations,
391 next_marker: None,
392 };
393 match serde_json::to_string(&response) {
394 Ok(payload) => WorkerApiResponse::success(
395 request.api_type,
396 request.request_id.clone(),
397 payload,
398 ),
399 Err(e) => WorkerApiResponse::error(
400 request.api_type,
401 request.request_id.clone(),
402 format!("Failed to serialize response: {}", e),
403 ),
404 }
405 }
406 None => WorkerApiResponse::error(
407 request.api_type,
408 request.request_id.clone(),
409 format!("Execution not found: {}", execution_id),
410 ),
411 }
412 }
413 Err(e) => WorkerApiResponse::error(
414 request.api_type,
415 request.request_id.clone(),
416 format!("Failed to parse request: {}", e),
417 ),
418 }
419 }
420
421 fn handle_callback_success(
423 state: &mut CheckpointServerState,
424 request: &WorkerApiRequest,
425 ) -> WorkerApiResponse {
426 let parsed: Result<SendCallbackSuccessRequest, _> = serde_json::from_str(&request.payload);
427
428 match parsed {
429 Ok(req) => {
430 let execution_ids: Vec<String> = state
432 .execution_manager
433 .get_execution_ids()
434 .into_iter()
435 .cloned()
436 .collect();
437
438 for execution_id in execution_ids {
440 if let Some(checkpoint_manager) = state
441 .execution_manager
442 .get_checkpoints_by_execution_mut(&execution_id)
443 {
444 if checkpoint_manager
445 .callback_manager()
446 .get_callback_state(&req.callback_id)
447 .is_some()
448 {
449 match checkpoint_manager
450 .callback_manager_mut()
451 .send_success(&req.callback_id, &req.result)
452 {
453 Ok(()) => {
454 checkpoint_manager.complete_callback_operation(
457 &req.callback_id,
458 Some(req.result.clone()),
459 None,
460 );
461 return WorkerApiResponse::success(
462 request.api_type,
463 request.request_id.clone(),
464 "{}".to_string(),
465 );
466 }
467 Err(e) => {
468 return WorkerApiResponse::error(
469 request.api_type,
470 request.request_id.clone(),
471 format!("{}", e),
472 );
473 }
474 }
475 }
476 }
477 }
478 WorkerApiResponse::error(
479 request.api_type,
480 request.request_id.clone(),
481 format!("Callback not found: {}", req.callback_id),
482 )
483 }
484 Err(e) => WorkerApiResponse::error(
485 request.api_type,
486 request.request_id.clone(),
487 format!("Failed to parse request: {}", e),
488 ),
489 }
490 }
491
492 fn handle_callback_failure(
494 state: &mut CheckpointServerState,
495 request: &WorkerApiRequest,
496 ) -> WorkerApiResponse {
497 let parsed: Result<SendCallbackFailureRequest, _> = serde_json::from_str(&request.payload);
498
499 match parsed {
500 Ok(req) => {
501 let execution_ids: Vec<String> = state
503 .execution_manager
504 .get_execution_ids()
505 .into_iter()
506 .cloned()
507 .collect();
508
509 for execution_id in execution_ids {
511 if let Some(checkpoint_manager) = state
512 .execution_manager
513 .get_checkpoints_by_execution_mut(&execution_id)
514 {
515 if checkpoint_manager
516 .callback_manager()
517 .get_callback_state(&req.callback_id)
518 .is_some()
519 {
520 match checkpoint_manager
521 .callback_manager_mut()
522 .send_failure(&req.callback_id, &req.error)
523 {
524 Ok(()) => {
525 checkpoint_manager.complete_callback_operation(
528 &req.callback_id,
529 None,
530 Some(req.error.clone()),
531 );
532 return WorkerApiResponse::success(
533 request.api_type,
534 request.request_id.clone(),
535 "{}".to_string(),
536 );
537 }
538 Err(e) => {
539 return WorkerApiResponse::error(
540 request.api_type,
541 request.request_id.clone(),
542 format!("{}", e),
543 );
544 }
545 }
546 }
547 }
548 }
549 WorkerApiResponse::error(
550 request.api_type,
551 request.request_id.clone(),
552 format!("Callback not found: {}", req.callback_id),
553 )
554 }
555 Err(e) => WorkerApiResponse::error(
556 request.api_type,
557 request.request_id.clone(),
558 format!("Failed to parse request: {}", e),
559 ),
560 }
561 }
562
563 fn handle_callback_heartbeat(
565 state: &mut CheckpointServerState,
566 request: &WorkerApiRequest,
567 ) -> WorkerApiResponse {
568 let parsed: Result<SendCallbackHeartbeatRequest, _> =
569 serde_json::from_str(&request.payload);
570
571 match parsed {
572 Ok(req) => {
573 let execution_ids: Vec<String> = state
575 .execution_manager
576 .get_execution_ids()
577 .into_iter()
578 .cloned()
579 .collect();
580
581 for execution_id in execution_ids {
583 if let Some(checkpoint_manager) = state
584 .execution_manager
585 .get_checkpoints_by_execution_mut(&execution_id)
586 {
587 if checkpoint_manager
588 .callback_manager()
589 .get_callback_state(&req.callback_id)
590 .is_some()
591 {
592 match checkpoint_manager
593 .callback_manager_mut()
594 .send_heartbeat(&req.callback_id)
595 {
596 Ok(()) => {
597 return WorkerApiResponse::success(
598 request.api_type,
599 request.request_id.clone(),
600 "{}".to_string(),
601 );
602 }
603 Err(e) => {
604 return WorkerApiResponse::error(
605 request.api_type,
606 request.request_id.clone(),
607 format!("{}", e),
608 );
609 }
610 }
611 }
612 }
613 }
614 WorkerApiResponse::error(
615 request.api_type,
616 request.request_id.clone(),
617 format!("Callback not found: {}", req.callback_id),
618 )
619 }
620 Err(e) => WorkerApiResponse::error(
621 request.api_type,
622 request.request_id.clone(),
623 format!("Failed to parse request: {}", e),
624 ),
625 }
626 }
627
628 fn handle_update_checkpoint_data(
634 state: &mut CheckpointServerState,
635 request: &WorkerApiRequest,
636 ) -> WorkerApiResponse {
637 let parsed: Result<super::types::UpdateCheckpointDataRequest, _> =
638 serde_json::from_str(&request.payload);
639
640 match parsed {
641 Ok(req) => {
642 match state
644 .execution_manager
645 .get_checkpoints_by_execution_mut(&req.execution_id)
646 {
647 Some(checkpoint_manager) => {
648 checkpoint_manager
650 .update_operation_data(&req.operation_id, req.operation_data);
651 WorkerApiResponse::success(
652 request.api_type,
653 request.request_id.clone(),
654 "{}".to_string(),
655 )
656 }
657 None => WorkerApiResponse::error(
658 request.api_type,
659 request.request_id.clone(),
660 format!("Execution not found: {}", req.execution_id),
661 ),
662 }
663 }
664 Err(e) => WorkerApiResponse::error(
665 request.api_type,
666 request.request_id.clone(),
667 format!("Failed to parse request: {}", e),
668 ),
669 }
670 }
671
672 fn handle_get_nodejs_history_events(
678 state: &mut CheckpointServerState,
679 request: &WorkerApiRequest,
680 ) -> WorkerApiResponse {
681 let parsed: Result<super::types::GetNodeJsHistoryEventsRequest, _> =
682 serde_json::from_str(&request.payload);
683
684 match parsed {
685 Ok(req) => {
686 match state
688 .execution_manager
689 .get_checkpoints_by_execution(&req.execution_id)
690 {
691 Some(checkpoint_manager) => {
692 let events = checkpoint_manager.get_nodejs_history_events();
694 match serde_json::to_string(&events) {
695 Ok(payload) => WorkerApiResponse::success(
696 request.api_type,
697 request.request_id.clone(),
698 payload,
699 ),
700 Err(e) => WorkerApiResponse::error(
701 request.api_type,
702 request.request_id.clone(),
703 format!("Failed to serialize response: {}", e),
704 ),
705 }
706 }
707 None => WorkerApiResponse::error(
708 request.api_type,
709 request.request_id.clone(),
710 format!("Execution not found: {}", req.execution_id),
711 ),
712 }
713 }
714 Err(e) => WorkerApiResponse::error(
715 request.api_type,
716 request.request_id.clone(),
717 format!("Failed to parse request: {}", e),
718 ),
719 }
720 }
721
722 pub async fn send_api_request(
724 &self,
725 api_type: ApiType,
726 payload: String,
727 ) -> Result<WorkerApiResponse, TestError> {
728 let request_id = Uuid::new_v4().to_string();
729 let command = WorkerCommand {
730 command_type: WorkerCommandType::ApiRequest,
731 data: WorkerApiRequest {
732 api_type,
733 request_id: request_id.clone(),
734 payload,
735 },
736 };
737
738 let (response_tx, response_rx) = oneshot::channel();
739
740 self.command_tx
741 .send(WorkerMessage::Command(command, response_tx))
742 .await
743 .map_err(|e| {
744 TestError::CheckpointCommunicationError(format!("Failed to send command: {}", e))
745 })?;
746
747 let response = response_rx.await.map_err(|e| {
748 TestError::CheckpointCommunicationError(format!("Failed to receive response: {}", e))
749 })?;
750
751 Ok(response.data)
752 }
753
754 pub async fn shutdown(mut self) -> Result<(), TestError> {
756 let _ = self.command_tx.send(WorkerMessage::Shutdown).await;
758
759 if let Some(handle) = self.worker_handle.take() {
761 handle.join().map_err(|_| {
762 TestError::CheckpointServerError("Worker thread panicked".to_string())
763 })?;
764 }
765
766 Ok(())
767 }
768
769 pub fn shutdown_sync(mut self) -> Result<(), TestError> {
771 if let Ok(rt) = tokio::runtime::Builder::new_current_thread()
773 .enable_all()
774 .build()
775 {
776 let _ = rt.block_on(self.command_tx.send(WorkerMessage::Shutdown));
777 }
778
779 if let Some(handle) = self.worker_handle.take() {
781 handle.join().map_err(|_| {
782 TestError::CheckpointServerError("Worker thread panicked".to_string())
783 })?;
784 }
785
786 Ok(())
787 }
788
789 pub fn params(&self) -> &CheckpointWorkerParams {
791 &self.params
792 }
793}
794
795use super::types::{CheckpointDurableExecutionStateRequest, GetDurableExecutionStateRequest};
797
798use async_trait::async_trait;
800
801#[async_trait]
802impl DurableServiceClient for CheckpointWorkerManager {
803 async fn checkpoint(
804 &self,
805 durable_execution_arn: &str,
806 checkpoint_token: &str,
807 operations: Vec<OperationUpdate>,
808 ) -> Result<CheckpointResponse, DurableError> {
809 let request = CheckpointDurableExecutionStateRequest {
810 durable_execution_arn: durable_execution_arn.to_string(),
811 checkpoint_token: checkpoint_token.to_string(),
812 operations,
813 };
814
815 let payload = serde_json::to_string(&request)
816 .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
817
818 let response = self
819 .send_api_request(ApiType::CheckpointDurableExecutionState, payload)
820 .await
821 .map_err(|e| {
822 DurableError::checkpoint_retriable(format!("Communication error: {}", e))
823 })?;
824
825 if let Some(error) = response.error {
826 return Err(DurableError::checkpoint_retriable(error));
827 }
828
829 let payload = response
830 .payload
831 .ok_or_else(|| DurableError::checkpoint_retriable("Empty response payload"))?;
832
833 serde_json::from_str(&payload)
834 .map_err(|e| DurableError::validation(format!("Failed to parse response: {}", e)))
835 }
836
837 async fn get_operations(
838 &self,
839 durable_execution_arn: &str,
840 _next_marker: &str,
841 ) -> Result<GetOperationsResponse, DurableError> {
842 let request = GetDurableExecutionStateRequest {
843 durable_execution_arn: durable_execution_arn.to_string(),
844 };
845
846 let payload = serde_json::to_string(&request)
847 .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
848
849 let response = self
850 .send_api_request(ApiType::GetDurableExecutionState, payload)
851 .await
852 .map_err(|e| {
853 DurableError::checkpoint_retriable(format!("Communication error: {}", e))
854 })?;
855
856 if let Some(error) = response.error {
857 return Err(DurableError::checkpoint_retriable(error));
858 }
859
860 let payload = response
861 .payload
862 .ok_or_else(|| DurableError::checkpoint_retriable("Empty response payload"))?;
863
864 serde_json::from_str(&payload)
865 .map_err(|e| DurableError::validation(format!("Failed to parse response: {}", e)))
866 }
867}
868
869impl CheckpointWorkerManager {
871 pub async fn send_callback_success(
873 &self,
874 callback_id: &str,
875 result: &str,
876 ) -> Result<(), DurableError> {
877 let request = SendCallbackSuccessRequest {
878 callback_id: callback_id.to_string(),
879 result: result.to_string(),
880 };
881
882 let payload = serde_json::to_string(&request)
883 .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
884
885 let response = self
886 .send_api_request(ApiType::SendDurableExecutionCallbackSuccess, payload)
887 .await
888 .map_err(|e| {
889 DurableError::checkpoint_retriable(format!("Communication error: {}", e))
890 })?;
891
892 if let Some(error) = response.error {
893 return Err(DurableError::execution(error));
894 }
895
896 Ok(())
897 }
898
899 pub async fn send_callback_failure(
901 &self,
902 callback_id: &str,
903 error: &ErrorObject,
904 ) -> Result<(), DurableError> {
905 let request = SendCallbackFailureRequest {
906 callback_id: callback_id.to_string(),
907 error: error.clone(),
908 };
909
910 let payload = serde_json::to_string(&request)
911 .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
912
913 let response = self
914 .send_api_request(ApiType::SendDurableExecutionCallbackFailure, payload)
915 .await
916 .map_err(|e| {
917 DurableError::checkpoint_retriable(format!("Communication error: {}", e))
918 })?;
919
920 if let Some(error) = response.error {
921 return Err(DurableError::execution(error));
922 }
923
924 Ok(())
925 }
926
927 pub async fn send_callback_heartbeat(&self, callback_id: &str) -> Result<(), DurableError> {
929 let request = SendCallbackHeartbeatRequest {
930 callback_id: callback_id.to_string(),
931 };
932
933 let payload = serde_json::to_string(&request)
934 .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
935
936 let response = self
937 .send_api_request(ApiType::SendDurableExecutionCallbackHeartbeat, payload)
938 .await
939 .map_err(|e| {
940 DurableError::checkpoint_retriable(format!("Communication error: {}", e))
941 })?;
942
943 if let Some(error) = response.error {
944 return Err(DurableError::execution(error));
945 }
946
947 Ok(())
948 }
949
950 pub async fn get_nodejs_history_events(
964 &self,
965 execution_id: &str,
966 ) -> Result<Vec<super::nodejs_event_types::NodeJsHistoryEvent>, DurableError> {
967 use super::types::GetNodeJsHistoryEventsRequest;
968
969 let request = GetNodeJsHistoryEventsRequest {
970 execution_id: execution_id.to_string(),
971 };
972
973 let payload = serde_json::to_string(&request)
974 .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
975
976 let response = self
977 .send_api_request(ApiType::GetNodeJsHistoryEvents, payload)
978 .await
979 .map_err(|e| {
980 DurableError::checkpoint_retriable(format!("Communication error: {}", e))
981 })?;
982
983 if let Some(error) = response.error {
984 return Err(DurableError::checkpoint_retriable(error));
985 }
986
987 let payload = response
988 .payload
989 .ok_or_else(|| DurableError::checkpoint_retriable("Empty response payload"))?;
990
991 serde_json::from_str(&payload)
992 .map_err(|e| DurableError::validation(format!("Failed to parse response: {}", e)))
993 }
994}
995
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkpoint_server::InvocationResult;

    /// Serializes the tests in this module.
    ///
    /// Each test resets and reuses the process-wide `INSTANCE` singleton, and the
    /// test harness runs tests in parallel by default. Without this lock, one
    /// test's `reset_instance_for_testing()` can replace the singleton between
    /// another test's two `get_instance` calls, which makes `Arc::ptr_eq` fail
    /// intermittently.
    ///
    /// A tokio mutex is used because the guard is held across `.await` points.
    ///
    /// NOTE(review): tests in other modules that touch the singleton are not
    /// covered by this lock.
    async fn serial_guard() -> tokio::sync::MutexGuard<'static, ()> {
        static LOCK: std::sync::OnceLock<tokio::sync::Mutex<()>> = std::sync::OnceLock::new();
        LOCK.get_or_init(|| tokio::sync::Mutex::new(())).lock().await
    }

    #[tokio::test]
    async fn test_get_instance() {
        let _serial = serial_guard().await;
        CheckpointWorkerManager::reset_instance_for_testing();

        let instance = CheckpointWorkerManager::get_instance(None).unwrap();
        assert!(instance.params().checkpoint_delay.is_none());

        // A second lookup must return the very same worker.
        let instance2 = CheckpointWorkerManager::get_instance(None).unwrap();
        assert!(Arc::ptr_eq(&instance, &instance2));

        CheckpointWorkerManager::reset_instance_for_testing();
    }

    #[tokio::test]
    async fn test_start_execution() {
        let _serial = serial_guard().await;
        CheckpointWorkerManager::reset_instance_for_testing();

        let manager = CheckpointWorkerManager::get_instance(None).unwrap();

        let request = StartDurableExecutionRequest {
            invocation_id: "inv-1".to_string(),
            payload: Some(r#"{"test": true}"#.to_string()),
        };

        let payload = serde_json::to_string(&request).unwrap();
        let response = manager
            .send_api_request(ApiType::StartDurableExecution, payload)
            .await
            .unwrap();

        assert!(!response.is_error());
        assert!(response.payload.is_some());

        let result: InvocationResult = serde_json::from_str(&response.payload.unwrap()).unwrap();
        assert!(!result.execution_id.is_empty());
        assert_eq!(result.invocation_id, "inv-1");

        CheckpointWorkerManager::reset_instance_for_testing();
    }

    #[tokio::test]
    async fn test_checkpoint_workflow() {
        let _serial = serial_guard().await;
        CheckpointWorkerManager::reset_instance_for_testing();

        let manager = CheckpointWorkerManager::get_instance(None).unwrap();

        // Start an execution to obtain an execution id and a checkpoint token.
        let start_request = StartDurableExecutionRequest {
            invocation_id: "inv-1".to_string(),
            payload: Some("{}".to_string()),
        };
        let payload = serde_json::to_string(&start_request).unwrap();
        let response = manager
            .send_api_request(ApiType::StartDurableExecution, payload)
            .await
            .unwrap();

        let result: InvocationResult = serde_json::from_str(&response.payload.unwrap()).unwrap();

        // Checkpoint a single step start against that execution.
        let checkpoint_request = CheckpointDurableExecutionStateRequest {
            durable_execution_arn: result.execution_id.clone(),
            checkpoint_token: result.checkpoint_token.clone(),
            operations: vec![OperationUpdate::start(
                "op-1",
                durable_execution_sdk::OperationType::Step,
            )
            .with_name("test-step")],
        };

        let payload = serde_json::to_string(&checkpoint_request).unwrap();
        let response = manager
            .send_api_request(ApiType::CheckpointDurableExecutionState, payload)
            .await
            .unwrap();

        assert!(!response.is_error());

        CheckpointWorkerManager::reset_instance_for_testing();
    }
}