Skip to main content

durable_execution_sdk_testing/checkpoint_server/
worker_manager.rs

//! Checkpoint worker manager that owns the checkpoint server thread.
//!
//! This module implements the `CheckpointWorkerManager`, which manages the lifecycle
//! of the checkpoint server thread, matching the Node.js SDK's CheckpointWorkerManager.
5
6use std::sync::{Arc, Mutex, OnceLock};
7use std::thread::{self, JoinHandle};
8
9use tokio::sync::{mpsc, oneshot};
10use uuid::Uuid;
11
12use durable_execution_sdk::{
13    CheckpointResponse, DurableError, DurableServiceClient, ErrorObject, GetOperationsResponse,
14    OperationUpdate,
15};
16
17use super::execution_manager::ExecutionManager;
18use super::types::{
19    ApiType, CheckpointWorkerParams, CompleteInvocationRequest, SendCallbackFailureRequest,
20    SendCallbackHeartbeatRequest, SendCallbackSuccessRequest, StartDurableExecutionRequest,
21    StartInvocationRequest, WorkerApiRequest, WorkerApiResponse, WorkerCommand, WorkerCommandType,
22    WorkerResponse, WorkerResponseType,
23};
24use crate::error::TestError;
25
26/// Internal state for the checkpoint server.
27struct CheckpointServerState {
28    execution_manager: ExecutionManager,
29    params: CheckpointWorkerParams,
30}
31
32/// Message sent to the worker thread.
33enum WorkerMessage {
34    /// A command to process
35    Command(WorkerCommand, oneshot::Sender<WorkerResponse>),
36    /// Shutdown signal
37    Shutdown,
38}
39
40/// Manages the checkpoint server thread lifecycle.
41pub struct CheckpointWorkerManager {
42    /// Sender for commands to the worker
43    command_tx: mpsc::Sender<WorkerMessage>,
44    /// Handle to the worker thread (for joining on shutdown)
45    worker_handle: Option<JoinHandle<()>>,
46    /// Configuration parameters
47    params: CheckpointWorkerParams,
48}
49
50/// Global singleton instance
51static INSTANCE: OnceLock<Mutex<Option<Arc<CheckpointWorkerManager>>>> = OnceLock::new();
52
53impl CheckpointWorkerManager {
54    /// Get or create the singleton instance.
55    ///
56    /// If the singleton already exists and `params` is `Some` with values that differ
57    /// from the existing instance's configuration, an error is returned. Pass `None`
58    /// to accept whatever configuration the existing instance has.
59    pub fn get_instance(
60        params: Option<CheckpointWorkerParams>,
61    ) -> Result<Arc<CheckpointWorkerManager>, TestError> {
62        let mutex = INSTANCE.get_or_init(|| Mutex::new(None));
63        let mut guard = mutex.lock().map_err(|e| {
64            TestError::CheckpointServerError(format!("Failed to lock instance: {}", e))
65        })?;
66
67        if let Some(ref instance) = *guard {
68            // Validate that requested params match the existing instance
69            if let Some(ref requested) = params {
70                if *requested != instance.params {
71                    return Err(TestError::CheckpointServerError(format!(
72                        "CheckpointWorkerManager singleton already exists with different params. \
73                         Existing: {:?}, Requested: {:?}. \
74                         Call reset_instance_for_testing() first to reconfigure.",
75                        instance.params, requested
76                    )));
77                }
78            }
79            return Ok(Arc::clone(instance));
80        }
81
82        let params = params.unwrap_or_default();
83        let manager = Self::new(params)?;
84        let arc = Arc::new(manager);
85        *guard = Some(Arc::clone(&arc));
86        Ok(arc)
87    }
88
89    /// Reset the singleton instance (for testing).
90    ///
91    /// Note: This method should be called from a non-async context or at the start
92    /// of a test before any async operations. If called from within an async context,
93    /// it will skip the graceful shutdown and just clear the instance.
94    pub fn reset_instance_for_testing() {
95        if let Some(mutex) = INSTANCE.get() {
96            if let Ok(mut guard) = mutex.lock() {
97                if let Some(instance) = guard.take() {
98                    // Try to shutdown gracefully, but don't panic if we can't
99                    if let Ok(manager) = Arc::try_unwrap(instance) {
100                        // Check if we're in an async context by trying to get the current runtime
101                        // If we are, we can't use shutdown_sync, so just drop the manager
102                        if tokio::runtime::Handle::try_current().is_err() {
103                            // Not in an async context, safe to use shutdown_sync
104                            let _ = manager.shutdown_sync();
105                        }
106                        // If in async context, the manager will be dropped and the thread
107                        // will eventually terminate when the channel is closed
108                    }
109                }
110            }
111        }
112    }
113
114    /// Create a new checkpoint worker manager.
115    fn new(params: CheckpointWorkerParams) -> Result<Self, TestError> {
116        let (command_tx, command_rx) = mpsc::channel::<WorkerMessage>(100);
117
118        let worker_params = params.clone();
119        let worker_handle = thread::spawn(move || {
120            Self::run_worker(command_rx, worker_params);
121        });
122
123        Ok(Self {
124            command_tx,
125            worker_handle: Some(worker_handle),
126            params,
127        })
128    }
129
130    /// Run the worker thread.
131    fn run_worker(mut command_rx: mpsc::Receiver<WorkerMessage>, params: CheckpointWorkerParams) {
132        // Create a tokio runtime for the worker thread
133        let rt = tokio::runtime::Builder::new_current_thread()
134            .enable_all()
135            .build()
136            .expect("Failed to create tokio runtime for worker");
137
138        rt.block_on(async {
139            let mut state = CheckpointServerState {
140                execution_manager: ExecutionManager::new(),
141                params,
142            };
143
144            while let Some(message) = command_rx.recv().await {
145                match message {
146                    WorkerMessage::Command(command, response_tx) => {
147                        let response = Self::process_command(&mut state, command).await;
148                        let _ = response_tx.send(response);
149                    }
150                    WorkerMessage::Shutdown => {
151                        break;
152                    }
153                }
154            }
155        });
156    }
157
158    /// Process a command and return a response.
159    async fn process_command(
160        state: &mut CheckpointServerState,
161        command: WorkerCommand,
162    ) -> WorkerResponse {
163        // Simulate checkpoint delay if configured
164        if let Some(delay_ms) = state.params.checkpoint_delay {
165            tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
166        }
167
168        let api_response = match command.data.api_type {
169            ApiType::StartDurableExecution => Self::handle_start_execution(state, &command.data),
170            ApiType::StartInvocation => Self::handle_start_invocation(state, &command.data),
171            ApiType::CompleteInvocation => Self::handle_complete_invocation(state, &command.data),
172            ApiType::CheckpointDurableExecutionState => {
173                Self::handle_checkpoint(state, &command.data)
174            }
175            ApiType::GetDurableExecutionState => Self::handle_get_state(state, &command.data),
176            ApiType::SendDurableExecutionCallbackSuccess => {
177                Self::handle_callback_success(state, &command.data)
178            }
179            ApiType::SendDurableExecutionCallbackFailure => {
180                Self::handle_callback_failure(state, &command.data)
181            }
182            ApiType::SendDurableExecutionCallbackHeartbeat => {
183                Self::handle_callback_heartbeat(state, &command.data)
184            }
185            ApiType::UpdateCheckpointData => {
186                Self::handle_update_checkpoint_data(state, &command.data)
187            }
188            ApiType::GetNodeJsHistoryEvents => {
189                Self::handle_get_nodejs_history_events(state, &command.data)
190            }
191            _ => WorkerApiResponse::error(
192                command.data.api_type,
193                command.data.request_id.clone(),
194                format!("Unsupported API type: {:?}", command.data.api_type),
195            ),
196        };
197
198        WorkerResponse {
199            response_type: WorkerResponseType::ApiResponse,
200            data: api_response,
201        }
202    }
203
204    /// Handle StartDurableExecution request.
205    fn handle_start_execution(
206        state: &mut CheckpointServerState,
207        request: &WorkerApiRequest,
208    ) -> WorkerApiResponse {
209        let parsed: Result<StartDurableExecutionRequest, _> =
210            serde_json::from_str(&request.payload);
211
212        match parsed {
213            Ok(req) => {
214                let result = state.execution_manager.start_execution_from_request(req);
215                match serde_json::to_string(&result) {
216                    Ok(payload) => WorkerApiResponse::success(
217                        request.api_type,
218                        request.request_id.clone(),
219                        payload,
220                    ),
221                    Err(e) => WorkerApiResponse::error(
222                        request.api_type,
223                        request.request_id.clone(),
224                        format!("Failed to serialize response: {}", e),
225                    ),
226                }
227            }
228            Err(e) => WorkerApiResponse::error(
229                request.api_type,
230                request.request_id.clone(),
231                format!("Failed to parse request: {}", e),
232            ),
233        }
234    }
235
236    /// Handle StartInvocation request.
237    fn handle_start_invocation(
238        state: &mut CheckpointServerState,
239        request: &WorkerApiRequest,
240    ) -> WorkerApiResponse {
241        let parsed: Result<StartInvocationRequest, _> = serde_json::from_str(&request.payload);
242
243        match parsed {
244            Ok(req) => match state.execution_manager.start_invocation(req) {
245                Ok(result) => match serde_json::to_string(&result) {
246                    Ok(payload) => WorkerApiResponse::success(
247                        request.api_type,
248                        request.request_id.clone(),
249                        payload,
250                    ),
251                    Err(e) => WorkerApiResponse::error(
252                        request.api_type,
253                        request.request_id.clone(),
254                        format!("Failed to serialize response: {}", e),
255                    ),
256                },
257                Err(e) => WorkerApiResponse::error(
258                    request.api_type,
259                    request.request_id.clone(),
260                    format!("{}", e),
261                ),
262            },
263            Err(e) => WorkerApiResponse::error(
264                request.api_type,
265                request.request_id.clone(),
266                format!("Failed to parse request: {}", e),
267            ),
268        }
269    }
270
271    /// Handle CompleteInvocation request.
272    fn handle_complete_invocation(
273        state: &mut CheckpointServerState,
274        request: &WorkerApiRequest,
275    ) -> WorkerApiResponse {
276        let parsed: Result<CompleteInvocationRequest, _> = serde_json::from_str(&request.payload);
277
278        match parsed {
279            Ok(req) => match state.execution_manager.complete_invocation(req) {
280                Ok(result) => match serde_json::to_string(&result) {
281                    Ok(payload) => WorkerApiResponse::success(
282                        request.api_type,
283                        request.request_id.clone(),
284                        payload,
285                    ),
286                    Err(e) => WorkerApiResponse::error(
287                        request.api_type,
288                        request.request_id.clone(),
289                        format!("Failed to serialize response: {}", e),
290                    ),
291                },
292                Err(e) => WorkerApiResponse::error(
293                    request.api_type,
294                    request.request_id.clone(),
295                    format!("{}", e),
296                ),
297            },
298            Err(e) => WorkerApiResponse::error(
299                request.api_type,
300                request.request_id.clone(),
301                format!("Failed to parse request: {}", e),
302            ),
303        }
304    }
305
306    /// Handle CheckpointDurableExecutionState request.
307    fn handle_checkpoint(
308        state: &mut CheckpointServerState,
309        request: &WorkerApiRequest,
310    ) -> WorkerApiResponse {
311        let parsed: Result<CheckpointDurableExecutionStateRequest, _> =
312            serde_json::from_str(&request.payload);
313
314        match parsed {
315            Ok(req) => {
316                // Get checkpoint manager by token
317                match state
318                    .execution_manager
319                    .get_checkpoints_by_token_mut(&req.checkpoint_token)
320                {
321                    Ok(Some(checkpoint_manager)) => {
322                        match checkpoint_manager.process_checkpoint(req.operations) {
323                            Ok(dirty_ops) => {
324                                // Build response matching SDK's CheckpointResponse structure
325                                let response = serde_json::json!({
326                                    "CheckpointToken": req.checkpoint_token,
327                                    "NewExecutionState": {
328                                        "Operations": dirty_ops
329                                    }
330                                });
331                                match serde_json::to_string(&response) {
332                                    Ok(payload) => WorkerApiResponse::success(
333                                        request.api_type,
334                                        request.request_id.clone(),
335                                        payload,
336                                    ),
337                                    Err(e) => WorkerApiResponse::error(
338                                        request.api_type,
339                                        request.request_id.clone(),
340                                        format!("Failed to serialize response: {}", e),
341                                    ),
342                                }
343                            }
344                            Err(e) => WorkerApiResponse::error(
345                                request.api_type,
346                                request.request_id.clone(),
347                                format!("Checkpoint processing failed: {}", e),
348                            ),
349                        }
350                    }
351                    Ok(None) => WorkerApiResponse::error(
352                        request.api_type,
353                        request.request_id.clone(),
354                        "Execution not found for checkpoint token".to_string(),
355                    ),
356                    Err(e) => WorkerApiResponse::error(
357                        request.api_type,
358                        request.request_id.clone(),
359                        format!("Invalid checkpoint token: {}", e),
360                    ),
361                }
362            }
363            Err(e) => WorkerApiResponse::error(
364                request.api_type,
365                request.request_id.clone(),
366                format!("Failed to parse request: {}", e),
367            ),
368        }
369    }
370
371    /// Handle GetDurableExecutionState request.
372    fn handle_get_state(
373        state: &mut CheckpointServerState,
374        request: &WorkerApiRequest,
375    ) -> WorkerApiResponse {
376        let parsed: Result<GetDurableExecutionStateRequest, _> =
377            serde_json::from_str(&request.payload);
378
379        match parsed {
380            Ok(req) => {
381                // Extract execution ID from ARN (simplified - just use the ARN as ID for testing)
382                let execution_id = &req.durable_execution_arn;
383                match state
384                    .execution_manager
385                    .get_checkpoints_by_execution(execution_id)
386                {
387                    Some(checkpoint_manager) => {
388                        let operations = checkpoint_manager.get_state();
389                        let response = GetOperationsResponse {
390                            operations,
391                            next_marker: None,
392                        };
393                        match serde_json::to_string(&response) {
394                            Ok(payload) => WorkerApiResponse::success(
395                                request.api_type,
396                                request.request_id.clone(),
397                                payload,
398                            ),
399                            Err(e) => WorkerApiResponse::error(
400                                request.api_type,
401                                request.request_id.clone(),
402                                format!("Failed to serialize response: {}", e),
403                            ),
404                        }
405                    }
406                    None => WorkerApiResponse::error(
407                        request.api_type,
408                        request.request_id.clone(),
409                        format!("Execution not found: {}", execution_id),
410                    ),
411                }
412            }
413            Err(e) => WorkerApiResponse::error(
414                request.api_type,
415                request.request_id.clone(),
416                format!("Failed to parse request: {}", e),
417            ),
418        }
419    }
420
421    /// Handle SendDurableExecutionCallbackSuccess request.
422    fn handle_callback_success(
423        state: &mut CheckpointServerState,
424        request: &WorkerApiRequest,
425    ) -> WorkerApiResponse {
426        let parsed: Result<SendCallbackSuccessRequest, _> = serde_json::from_str(&request.payload);
427
428        match parsed {
429            Ok(req) => {
430                // Collect execution IDs first to avoid borrow conflict
431                let execution_ids: Vec<String> = state
432                    .execution_manager
433                    .get_execution_ids()
434                    .into_iter()
435                    .cloned()
436                    .collect();
437
438                // Find the execution containing this callback
439                for execution_id in execution_ids {
440                    if let Some(checkpoint_manager) = state
441                        .execution_manager
442                        .get_checkpoints_by_execution_mut(&execution_id)
443                    {
444                        if checkpoint_manager
445                            .callback_manager()
446                            .get_callback_state(&req.callback_id)
447                            .is_some()
448                        {
449                            match checkpoint_manager
450                                .callback_manager_mut()
451                                .send_success(&req.callback_id, &req.result)
452                            {
453                                Ok(()) => {
454                                    // Also update the callback operation status to Succeeded
455                                    // so the orchestrator can detect completion and re-invoke.
456                                    checkpoint_manager.complete_callback_operation(
457                                        &req.callback_id,
458                                        Some(req.result.clone()),
459                                        None,
460                                    );
461                                    return WorkerApiResponse::success(
462                                        request.api_type,
463                                        request.request_id.clone(),
464                                        "{}".to_string(),
465                                    );
466                                }
467                                Err(e) => {
468                                    return WorkerApiResponse::error(
469                                        request.api_type,
470                                        request.request_id.clone(),
471                                        format!("{}", e),
472                                    );
473                                }
474                            }
475                        }
476                    }
477                }
478                WorkerApiResponse::error(
479                    request.api_type,
480                    request.request_id.clone(),
481                    format!("Callback not found: {}", req.callback_id),
482                )
483            }
484            Err(e) => WorkerApiResponse::error(
485                request.api_type,
486                request.request_id.clone(),
487                format!("Failed to parse request: {}", e),
488            ),
489        }
490    }
491
492    /// Handle SendDurableExecutionCallbackFailure request.
493    fn handle_callback_failure(
494        state: &mut CheckpointServerState,
495        request: &WorkerApiRequest,
496    ) -> WorkerApiResponse {
497        let parsed: Result<SendCallbackFailureRequest, _> = serde_json::from_str(&request.payload);
498
499        match parsed {
500            Ok(req) => {
501                // Collect execution IDs first to avoid borrow conflict
502                let execution_ids: Vec<String> = state
503                    .execution_manager
504                    .get_execution_ids()
505                    .into_iter()
506                    .cloned()
507                    .collect();
508
509                // Find the execution containing this callback
510                for execution_id in execution_ids {
511                    if let Some(checkpoint_manager) = state
512                        .execution_manager
513                        .get_checkpoints_by_execution_mut(&execution_id)
514                    {
515                        if checkpoint_manager
516                            .callback_manager()
517                            .get_callback_state(&req.callback_id)
518                            .is_some()
519                        {
520                            match checkpoint_manager
521                                .callback_manager_mut()
522                                .send_failure(&req.callback_id, &req.error)
523                            {
524                                Ok(()) => {
525                                    // Also update the callback operation status to Failed
526                                    // so the orchestrator can detect completion and re-invoke.
527                                    checkpoint_manager.complete_callback_operation(
528                                        &req.callback_id,
529                                        None,
530                                        Some(req.error.clone()),
531                                    );
532                                    return WorkerApiResponse::success(
533                                        request.api_type,
534                                        request.request_id.clone(),
535                                        "{}".to_string(),
536                                    );
537                                }
538                                Err(e) => {
539                                    return WorkerApiResponse::error(
540                                        request.api_type,
541                                        request.request_id.clone(),
542                                        format!("{}", e),
543                                    );
544                                }
545                            }
546                        }
547                    }
548                }
549                WorkerApiResponse::error(
550                    request.api_type,
551                    request.request_id.clone(),
552                    format!("Callback not found: {}", req.callback_id),
553                )
554            }
555            Err(e) => WorkerApiResponse::error(
556                request.api_type,
557                request.request_id.clone(),
558                format!("Failed to parse request: {}", e),
559            ),
560        }
561    }
562
563    /// Handle SendDurableExecutionCallbackHeartbeat request.
564    fn handle_callback_heartbeat(
565        state: &mut CheckpointServerState,
566        request: &WorkerApiRequest,
567    ) -> WorkerApiResponse {
568        let parsed: Result<SendCallbackHeartbeatRequest, _> =
569            serde_json::from_str(&request.payload);
570
571        match parsed {
572            Ok(req) => {
573                // Collect execution IDs first to avoid borrow conflict
574                let execution_ids: Vec<String> = state
575                    .execution_manager
576                    .get_execution_ids()
577                    .into_iter()
578                    .cloned()
579                    .collect();
580
581                // Find the execution containing this callback
582                for execution_id in execution_ids {
583                    if let Some(checkpoint_manager) = state
584                        .execution_manager
585                        .get_checkpoints_by_execution_mut(&execution_id)
586                    {
587                        if checkpoint_manager
588                            .callback_manager()
589                            .get_callback_state(&req.callback_id)
590                            .is_some()
591                        {
592                            match checkpoint_manager
593                                .callback_manager_mut()
594                                .send_heartbeat(&req.callback_id)
595                            {
596                                Ok(()) => {
597                                    return WorkerApiResponse::success(
598                                        request.api_type,
599                                        request.request_id.clone(),
600                                        "{}".to_string(),
601                                    );
602                                }
603                                Err(e) => {
604                                    return WorkerApiResponse::error(
605                                        request.api_type,
606                                        request.request_id.clone(),
607                                        format!("{}", e),
608                                    );
609                                }
610                            }
611                        }
612                    }
613                }
614                WorkerApiResponse::error(
615                    request.api_type,
616                    request.request_id.clone(),
617                    format!("Callback not found: {}", req.callback_id),
618                )
619            }
620            Err(e) => WorkerApiResponse::error(
621                request.api_type,
622                request.request_id.clone(),
623                format!("Failed to parse request: {}", e),
624            ),
625        }
626    }
627
628    /// Handle UpdateCheckpointData request.
629    ///
630    /// This method updates the state of a specific operation in the checkpoint server.
631    /// It's used by the orchestrator to mark wait operations as SUCCEEDED after time
632    /// has been advanced in time-skipping mode.
633    fn handle_update_checkpoint_data(
634        state: &mut CheckpointServerState,
635        request: &WorkerApiRequest,
636    ) -> WorkerApiResponse {
637        let parsed: Result<super::types::UpdateCheckpointDataRequest, _> =
638            serde_json::from_str(&request.payload);
639
640        match parsed {
641            Ok(req) => {
642                // Get checkpoint manager by execution ID
643                match state
644                    .execution_manager
645                    .get_checkpoints_by_execution_mut(&req.execution_id)
646                {
647                    Some(checkpoint_manager) => {
648                        // Update the operation data
649                        checkpoint_manager
650                            .update_operation_data(&req.operation_id, req.operation_data);
651                        WorkerApiResponse::success(
652                            request.api_type,
653                            request.request_id.clone(),
654                            "{}".to_string(),
655                        )
656                    }
657                    None => WorkerApiResponse::error(
658                        request.api_type,
659                        request.request_id.clone(),
660                        format!("Execution not found: {}", req.execution_id),
661                    ),
662                }
663            }
664            Err(e) => WorkerApiResponse::error(
665                request.api_type,
666                request.request_id.clone(),
667                format!("Failed to parse request: {}", e),
668            ),
669        }
670    }
671
672    /// Handle GetNodeJsHistoryEvents request.
673    ///
674    /// This method retrieves the Node.js-compatible history events for an execution.
675    /// These events match the Node.js SDK's event history format exactly,
676    /// enabling cross-SDK history comparison.
677    fn handle_get_nodejs_history_events(
678        state: &mut CheckpointServerState,
679        request: &WorkerApiRequest,
680    ) -> WorkerApiResponse {
681        let parsed: Result<super::types::GetNodeJsHistoryEventsRequest, _> =
682            serde_json::from_str(&request.payload);
683
684        match parsed {
685            Ok(req) => {
686                // Get checkpoint manager by execution ID
687                match state
688                    .execution_manager
689                    .get_checkpoints_by_execution(&req.execution_id)
690                {
691                    Some(checkpoint_manager) => {
692                        // Get Node.js history events
693                        let events = checkpoint_manager.get_nodejs_history_events();
694                        match serde_json::to_string(&events) {
695                            Ok(payload) => WorkerApiResponse::success(
696                                request.api_type,
697                                request.request_id.clone(),
698                                payload,
699                            ),
700                            Err(e) => WorkerApiResponse::error(
701                                request.api_type,
702                                request.request_id.clone(),
703                                format!("Failed to serialize response: {}", e),
704                            ),
705                        }
706                    }
707                    None => WorkerApiResponse::error(
708                        request.api_type,
709                        request.request_id.clone(),
710                        format!("Execution not found: {}", req.execution_id),
711                    ),
712                }
713            }
714            Err(e) => WorkerApiResponse::error(
715                request.api_type,
716                request.request_id.clone(),
717                format!("Failed to parse request: {}", e),
718            ),
719        }
720    }
721
722    /// Send an API request to the checkpoint server and wait for response.
723    pub async fn send_api_request(
724        &self,
725        api_type: ApiType,
726        payload: String,
727    ) -> Result<WorkerApiResponse, TestError> {
728        let request_id = Uuid::new_v4().to_string();
729        let command = WorkerCommand {
730            command_type: WorkerCommandType::ApiRequest,
731            data: WorkerApiRequest {
732                api_type,
733                request_id: request_id.clone(),
734                payload,
735            },
736        };
737
738        let (response_tx, response_rx) = oneshot::channel();
739
740        self.command_tx
741            .send(WorkerMessage::Command(command, response_tx))
742            .await
743            .map_err(|e| {
744                TestError::CheckpointCommunicationError(format!("Failed to send command: {}", e))
745            })?;
746
747        let response = response_rx.await.map_err(|e| {
748            TestError::CheckpointCommunicationError(format!("Failed to receive response: {}", e))
749        })?;
750
751        Ok(response.data)
752    }
753
754    /// Gracefully shut down the checkpoint server.
755    pub async fn shutdown(mut self) -> Result<(), TestError> {
756        // Send shutdown signal
757        let _ = self.command_tx.send(WorkerMessage::Shutdown).await;
758
759        // Wait for worker thread to finish
760        if let Some(handle) = self.worker_handle.take() {
761            handle.join().map_err(|_| {
762                TestError::CheckpointServerError("Worker thread panicked".to_string())
763            })?;
764        }
765
766        Ok(())
767    }
768
769    /// Synchronous shutdown (for use in Drop or non-async contexts).
770    pub fn shutdown_sync(mut self) -> Result<(), TestError> {
771        // Create a runtime to send the shutdown signal
772        if let Ok(rt) = tokio::runtime::Builder::new_current_thread()
773            .enable_all()
774            .build()
775        {
776            let _ = rt.block_on(self.command_tx.send(WorkerMessage::Shutdown));
777        }
778
779        // Wait for worker thread to finish
780        if let Some(handle) = self.worker_handle.take() {
781            handle.join().map_err(|_| {
782                TestError::CheckpointServerError("Worker thread panicked".to_string())
783            })?;
784        }
785
786        Ok(())
787    }
788
/// Get the configuration parameters.
///
/// Returns a shared reference to the `CheckpointWorkerParams` stored on
/// this manager.
pub fn params(&self) -> &CheckpointWorkerParams {
    &self.params
}
793}
794
795// Additional types needed for checkpoint requests
796use super::types::{CheckpointDurableExecutionStateRequest, GetDurableExecutionStateRequest};
797
798// Implement DurableServiceClient trait for CheckpointWorkerManager
799use async_trait::async_trait;
800
801#[async_trait]
802impl DurableServiceClient for CheckpointWorkerManager {
803    async fn checkpoint(
804        &self,
805        durable_execution_arn: &str,
806        checkpoint_token: &str,
807        operations: Vec<OperationUpdate>,
808    ) -> Result<CheckpointResponse, DurableError> {
809        let request = CheckpointDurableExecutionStateRequest {
810            durable_execution_arn: durable_execution_arn.to_string(),
811            checkpoint_token: checkpoint_token.to_string(),
812            operations,
813        };
814
815        let payload = serde_json::to_string(&request)
816            .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
817
818        let response = self
819            .send_api_request(ApiType::CheckpointDurableExecutionState, payload)
820            .await
821            .map_err(|e| {
822                DurableError::checkpoint_retriable(format!("Communication error: {}", e))
823            })?;
824
825        if let Some(error) = response.error {
826            return Err(DurableError::checkpoint_retriable(error));
827        }
828
829        let payload = response
830            .payload
831            .ok_or_else(|| DurableError::checkpoint_retriable("Empty response payload"))?;
832
833        serde_json::from_str(&payload)
834            .map_err(|e| DurableError::validation(format!("Failed to parse response: {}", e)))
835    }
836
837    async fn get_operations(
838        &self,
839        durable_execution_arn: &str,
840        _next_marker: &str,
841    ) -> Result<GetOperationsResponse, DurableError> {
842        let request = GetDurableExecutionStateRequest {
843            durable_execution_arn: durable_execution_arn.to_string(),
844        };
845
846        let payload = serde_json::to_string(&request)
847            .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
848
849        let response = self
850            .send_api_request(ApiType::GetDurableExecutionState, payload)
851            .await
852            .map_err(|e| {
853                DurableError::checkpoint_retriable(format!("Communication error: {}", e))
854            })?;
855
856        if let Some(error) = response.error {
857            return Err(DurableError::checkpoint_retriable(error));
858        }
859
860        let payload = response
861            .payload
862            .ok_or_else(|| DurableError::checkpoint_retriable("Empty response payload"))?;
863
864        serde_json::from_str(&payload)
865            .map_err(|e| DurableError::validation(format!("Failed to parse response: {}", e)))
866    }
867}
868
869// Callback methods (not part of DurableServiceClient trait)
870impl CheckpointWorkerManager {
871    /// Send a callback success response.
872    pub async fn send_callback_success(
873        &self,
874        callback_id: &str,
875        result: &str,
876    ) -> Result<(), DurableError> {
877        let request = SendCallbackSuccessRequest {
878            callback_id: callback_id.to_string(),
879            result: result.to_string(),
880        };
881
882        let payload = serde_json::to_string(&request)
883            .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
884
885        let response = self
886            .send_api_request(ApiType::SendDurableExecutionCallbackSuccess, payload)
887            .await
888            .map_err(|e| {
889                DurableError::checkpoint_retriable(format!("Communication error: {}", e))
890            })?;
891
892        if let Some(error) = response.error {
893            return Err(DurableError::execution(error));
894        }
895
896        Ok(())
897    }
898
899    /// Send a callback failure response.
900    pub async fn send_callback_failure(
901        &self,
902        callback_id: &str,
903        error: &ErrorObject,
904    ) -> Result<(), DurableError> {
905        let request = SendCallbackFailureRequest {
906            callback_id: callback_id.to_string(),
907            error: error.clone(),
908        };
909
910        let payload = serde_json::to_string(&request)
911            .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
912
913        let response = self
914            .send_api_request(ApiType::SendDurableExecutionCallbackFailure, payload)
915            .await
916            .map_err(|e| {
917                DurableError::checkpoint_retriable(format!("Communication error: {}", e))
918            })?;
919
920        if let Some(error) = response.error {
921            return Err(DurableError::execution(error));
922        }
923
924        Ok(())
925    }
926
927    /// Send a callback heartbeat.
928    pub async fn send_callback_heartbeat(&self, callback_id: &str) -> Result<(), DurableError> {
929        let request = SendCallbackHeartbeatRequest {
930            callback_id: callback_id.to_string(),
931        };
932
933        let payload = serde_json::to_string(&request)
934            .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
935
936        let response = self
937            .send_api_request(ApiType::SendDurableExecutionCallbackHeartbeat, payload)
938            .await
939            .map_err(|e| {
940                DurableError::checkpoint_retriable(format!("Communication error: {}", e))
941            })?;
942
943        if let Some(error) = response.error {
944            return Err(DurableError::execution(error));
945        }
946
947        Ok(())
948    }
949
950    /// Get Node.js-compatible history events for an execution.
951    ///
952    /// Returns events in the Node.js SDK compatible format, suitable for
953    /// cross-SDK history comparison. These events use PascalCase field names
954    /// and include detailed event-specific information.
955    ///
956    /// # Arguments
957    ///
958    /// * `execution_id` - The execution ID to get history events for
959    ///
960    /// # Returns
961    ///
962    /// A vector of Node.js-compatible history events.
963    pub async fn get_nodejs_history_events(
964        &self,
965        execution_id: &str,
966    ) -> Result<Vec<super::nodejs_event_types::NodeJsHistoryEvent>, DurableError> {
967        use super::types::GetNodeJsHistoryEventsRequest;
968
969        let request = GetNodeJsHistoryEventsRequest {
970            execution_id: execution_id.to_string(),
971        };
972
973        let payload = serde_json::to_string(&request)
974            .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
975
976        let response = self
977            .send_api_request(ApiType::GetNodeJsHistoryEvents, payload)
978            .await
979            .map_err(|e| {
980                DurableError::checkpoint_retriable(format!("Communication error: {}", e))
981            })?;
982
983        if let Some(error) = response.error {
984            return Err(DurableError::checkpoint_retriable(error));
985        }
986
987        let payload = response
988            .payload
989            .ok_or_else(|| DurableError::checkpoint_retriable("Empty response payload"))?;
990
991        serde_json::from_str(&payload)
992            .map_err(|e| DurableError::validation(format!("Failed to parse response: {}", e)))
993    }
994}
995
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkpoint_server::InvocationResult;

    /// Repeated `get_instance` calls must hand back the same shared manager.
    #[tokio::test]
    async fn test_get_instance() {
        CheckpointWorkerManager::reset_instance_for_testing();

        let first = CheckpointWorkerManager::get_instance(None).unwrap();
        assert!(first.params().checkpoint_delay.is_none());

        // A second lookup must resolve to the very same allocation.
        let second = CheckpointWorkerManager::get_instance(None).unwrap();
        assert!(Arc::ptr_eq(&first, &second));

        CheckpointWorkerManager::reset_instance_for_testing();
    }

    /// Starting an execution yields a non-empty execution ID and echoes the invocation ID.
    #[tokio::test]
    async fn test_start_execution() {
        CheckpointWorkerManager::reset_instance_for_testing();

        let manager = CheckpointWorkerManager::get_instance(None).unwrap();

        let body = serde_json::to_string(&StartDurableExecutionRequest {
            invocation_id: "inv-1".to_string(),
            payload: Some(r#"{"test": true}"#.to_string()),
        })
        .unwrap();

        let reply = manager
            .send_api_request(ApiType::StartDurableExecution, body)
            .await
            .unwrap();

        assert!(!reply.is_error());
        assert!(reply.payload.is_some());

        let started: InvocationResult = serde_json::from_str(&reply.payload.unwrap()).unwrap();
        assert!(!started.execution_id.is_empty());
        assert_eq!(started.invocation_id, "inv-1");

        CheckpointWorkerManager::reset_instance_for_testing();
    }

    /// A checkpoint sent against a freshly started execution is accepted without error.
    #[tokio::test]
    async fn test_checkpoint_workflow() {
        CheckpointWorkerManager::reset_instance_for_testing();

        let manager = CheckpointWorkerManager::get_instance(None).unwrap();

        // Step 1: start an execution to obtain an execution ID and checkpoint token.
        let start_body = serde_json::to_string(&StartDurableExecutionRequest {
            invocation_id: "inv-1".to_string(),
            payload: Some("{}".to_string()),
        })
        .unwrap();
        let start_reply = manager
            .send_api_request(ApiType::StartDurableExecution, start_body)
            .await
            .unwrap();

        let started: InvocationResult =
            serde_json::from_str(&start_reply.payload.unwrap()).unwrap();

        // Step 2: checkpoint a single step start, built with the OperationUpdate::start helper.
        let step_start =
            OperationUpdate::start("op-1", durable_execution_sdk::OperationType::Step)
                .with_name("test-step");
        let checkpoint_body = serde_json::to_string(&CheckpointDurableExecutionStateRequest {
            durable_execution_arn: started.execution_id.clone(),
            checkpoint_token: started.checkpoint_token.clone(),
            operations: vec![step_start],
        })
        .unwrap();

        let checkpoint_reply = manager
            .send_api_request(ApiType::CheckpointDurableExecutionState, checkpoint_body)
            .await
            .unwrap();

        assert!(!checkpoint_reply.is_error());

        CheckpointWorkerManager::reset_instance_for_testing();
    }
}