Skip to main content

durable_execution_sdk_testing/checkpoint_server/
worker_manager.rs

1//! Checkpoint worker manager for managing the checkpoint server thread.
2//!
3//! This module implements the CheckpointWorkerManager which manages the lifecycle
4//! of the checkpoint server thread, matching the Node.js SDK's CheckpointWorkerManager.
5
6use std::sync::{Arc, Mutex, OnceLock};
7use std::thread::{self, JoinHandle};
8
9use tokio::sync::{mpsc, oneshot};
10use uuid::Uuid;
11
12use durable_execution_sdk::{
13    CheckpointResponse, DurableError, DurableServiceClient, ErrorObject, GetOperationsResponse,
14    OperationUpdate,
15};
16
17use super::execution_manager::ExecutionManager;
18use super::types::{
19    ApiType, CheckpointWorkerParams, CompleteInvocationRequest, SendCallbackFailureRequest,
20    SendCallbackHeartbeatRequest, SendCallbackSuccessRequest, StartDurableExecutionRequest,
21    StartInvocationRequest, WorkerApiRequest, WorkerApiResponse, WorkerCommand, WorkerCommandType,
22    WorkerResponse, WorkerResponseType,
23};
24use crate::error::TestError;
25
26/// Internal state for the checkpoint server.
27struct CheckpointServerState {
28    execution_manager: ExecutionManager,
29    params: CheckpointWorkerParams,
30}
31
32/// Message sent to the worker thread.
33enum WorkerMessage {
34    /// A command to process
35    Command(WorkerCommand, oneshot::Sender<WorkerResponse>),
36    /// Shutdown signal
37    Shutdown,
38}
39
40/// Manages the checkpoint server thread lifecycle.
41pub struct CheckpointWorkerManager {
42    /// Sender for commands to the worker
43    command_tx: mpsc::Sender<WorkerMessage>,
44    /// Handle to the worker thread (for joining on shutdown)
45    worker_handle: Option<JoinHandle<()>>,
46    /// Configuration parameters
47    params: CheckpointWorkerParams,
48}
49
50/// Global singleton instance
51static INSTANCE: OnceLock<Mutex<Option<Arc<CheckpointWorkerManager>>>> = OnceLock::new();
52
53impl CheckpointWorkerManager {
54    /// Get or create the singleton instance.
55    pub fn get_instance(
56        params: Option<CheckpointWorkerParams>,
57    ) -> Result<Arc<CheckpointWorkerManager>, TestError> {
58        let mutex = INSTANCE.get_or_init(|| Mutex::new(None));
59        let mut guard = mutex.lock().map_err(|e| {
60            TestError::CheckpointServerError(format!("Failed to lock instance: {}", e))
61        })?;
62
63        if let Some(ref instance) = *guard {
64            return Ok(Arc::clone(instance));
65        }
66
67        let params = params.unwrap_or_default();
68        let manager = Self::new(params)?;
69        let arc = Arc::new(manager);
70        *guard = Some(Arc::clone(&arc));
71        Ok(arc)
72    }
73
74    /// Reset the singleton instance (for testing).
75    ///
76    /// Note: This method should be called from a non-async context or at the start
77    /// of a test before any async operations. If called from within an async context,
78    /// it will skip the graceful shutdown and just clear the instance.
79    pub fn reset_instance_for_testing() {
80        if let Some(mutex) = INSTANCE.get() {
81            if let Ok(mut guard) = mutex.lock() {
82                if let Some(instance) = guard.take() {
83                    // Try to shutdown gracefully, but don't panic if we can't
84                    if let Ok(manager) = Arc::try_unwrap(instance) {
85                        // Check if we're in an async context by trying to get the current runtime
86                        // If we are, we can't use shutdown_sync, so just drop the manager
87                        if tokio::runtime::Handle::try_current().is_err() {
88                            // Not in an async context, safe to use shutdown_sync
89                            let _ = manager.shutdown_sync();
90                        }
91                        // If in async context, the manager will be dropped and the thread
92                        // will eventually terminate when the channel is closed
93                    }
94                }
95            }
96        }
97    }
98
99    /// Create a new checkpoint worker manager.
100    fn new(params: CheckpointWorkerParams) -> Result<Self, TestError> {
101        let (command_tx, command_rx) = mpsc::channel::<WorkerMessage>(100);
102
103        let worker_params = params.clone();
104        let worker_handle = thread::spawn(move || {
105            Self::run_worker(command_rx, worker_params);
106        });
107
108        Ok(Self {
109            command_tx,
110            worker_handle: Some(worker_handle),
111            params,
112        })
113    }
114
115    /// Run the worker thread.
116    fn run_worker(mut command_rx: mpsc::Receiver<WorkerMessage>, params: CheckpointWorkerParams) {
117        // Create a tokio runtime for the worker thread
118        let rt = tokio::runtime::Builder::new_current_thread()
119            .enable_all()
120            .build()
121            .expect("Failed to create tokio runtime for worker");
122
123        rt.block_on(async {
124            let mut state = CheckpointServerState {
125                execution_manager: ExecutionManager::new(),
126                params,
127            };
128
129            while let Some(message) = command_rx.recv().await {
130                match message {
131                    WorkerMessage::Command(command, response_tx) => {
132                        let response = Self::process_command(&mut state, command).await;
133                        let _ = response_tx.send(response);
134                    }
135                    WorkerMessage::Shutdown => {
136                        break;
137                    }
138                }
139            }
140        });
141    }
142
143    /// Process a command and return a response.
144    async fn process_command(
145        state: &mut CheckpointServerState,
146        command: WorkerCommand,
147    ) -> WorkerResponse {
148        // Simulate checkpoint delay if configured
149        if let Some(delay_ms) = state.params.checkpoint_delay {
150            tokio::time::sleep(tokio::time::Duration::from_millis(delay_ms)).await;
151        }
152
153        let api_response = match command.data.api_type {
154            ApiType::StartDurableExecution => Self::handle_start_execution(state, &command.data),
155            ApiType::StartInvocation => Self::handle_start_invocation(state, &command.data),
156            ApiType::CompleteInvocation => Self::handle_complete_invocation(state, &command.data),
157            ApiType::CheckpointDurableExecutionState => {
158                Self::handle_checkpoint(state, &command.data)
159            }
160            ApiType::GetDurableExecutionState => Self::handle_get_state(state, &command.data),
161            ApiType::SendDurableExecutionCallbackSuccess => {
162                Self::handle_callback_success(state, &command.data)
163            }
164            ApiType::SendDurableExecutionCallbackFailure => {
165                Self::handle_callback_failure(state, &command.data)
166            }
167            ApiType::SendDurableExecutionCallbackHeartbeat => {
168                Self::handle_callback_heartbeat(state, &command.data)
169            }
170            ApiType::UpdateCheckpointData => {
171                Self::handle_update_checkpoint_data(state, &command.data)
172            }
173            ApiType::GetNodeJsHistoryEvents => {
174                Self::handle_get_nodejs_history_events(state, &command.data)
175            }
176            _ => WorkerApiResponse::error(
177                command.data.api_type,
178                command.data.request_id.clone(),
179                format!("Unsupported API type: {:?}", command.data.api_type),
180            ),
181        };
182
183        WorkerResponse {
184            response_type: WorkerResponseType::ApiResponse,
185            data: api_response,
186        }
187    }
188
189    /// Handle StartDurableExecution request.
190    fn handle_start_execution(
191        state: &mut CheckpointServerState,
192        request: &WorkerApiRequest,
193    ) -> WorkerApiResponse {
194        let parsed: Result<StartDurableExecutionRequest, _> =
195            serde_json::from_str(&request.payload);
196
197        match parsed {
198            Ok(req) => {
199                let result = state.execution_manager.start_execution_from_request(req);
200                match serde_json::to_string(&result) {
201                    Ok(payload) => WorkerApiResponse::success(
202                        request.api_type,
203                        request.request_id.clone(),
204                        payload,
205                    ),
206                    Err(e) => WorkerApiResponse::error(
207                        request.api_type,
208                        request.request_id.clone(),
209                        format!("Failed to serialize response: {}", e),
210                    ),
211                }
212            }
213            Err(e) => WorkerApiResponse::error(
214                request.api_type,
215                request.request_id.clone(),
216                format!("Failed to parse request: {}", e),
217            ),
218        }
219    }
220
221    /// Handle StartInvocation request.
222    fn handle_start_invocation(
223        state: &mut CheckpointServerState,
224        request: &WorkerApiRequest,
225    ) -> WorkerApiResponse {
226        let parsed: Result<StartInvocationRequest, _> = serde_json::from_str(&request.payload);
227
228        match parsed {
229            Ok(req) => match state.execution_manager.start_invocation(req) {
230                Ok(result) => match serde_json::to_string(&result) {
231                    Ok(payload) => WorkerApiResponse::success(
232                        request.api_type,
233                        request.request_id.clone(),
234                        payload,
235                    ),
236                    Err(e) => WorkerApiResponse::error(
237                        request.api_type,
238                        request.request_id.clone(),
239                        format!("Failed to serialize response: {}", e),
240                    ),
241                },
242                Err(e) => WorkerApiResponse::error(
243                    request.api_type,
244                    request.request_id.clone(),
245                    format!("{}", e),
246                ),
247            },
248            Err(e) => WorkerApiResponse::error(
249                request.api_type,
250                request.request_id.clone(),
251                format!("Failed to parse request: {}", e),
252            ),
253        }
254    }
255
256    /// Handle CompleteInvocation request.
257    fn handle_complete_invocation(
258        state: &mut CheckpointServerState,
259        request: &WorkerApiRequest,
260    ) -> WorkerApiResponse {
261        let parsed: Result<CompleteInvocationRequest, _> = serde_json::from_str(&request.payload);
262
263        match parsed {
264            Ok(req) => match state.execution_manager.complete_invocation(req) {
265                Ok(result) => match serde_json::to_string(&result) {
266                    Ok(payload) => WorkerApiResponse::success(
267                        request.api_type,
268                        request.request_id.clone(),
269                        payload,
270                    ),
271                    Err(e) => WorkerApiResponse::error(
272                        request.api_type,
273                        request.request_id.clone(),
274                        format!("Failed to serialize response: {}", e),
275                    ),
276                },
277                Err(e) => WorkerApiResponse::error(
278                    request.api_type,
279                    request.request_id.clone(),
280                    format!("{}", e),
281                ),
282            },
283            Err(e) => WorkerApiResponse::error(
284                request.api_type,
285                request.request_id.clone(),
286                format!("Failed to parse request: {}", e),
287            ),
288        }
289    }
290
291    /// Handle CheckpointDurableExecutionState request.
292    fn handle_checkpoint(
293        state: &mut CheckpointServerState,
294        request: &WorkerApiRequest,
295    ) -> WorkerApiResponse {
296        let parsed: Result<CheckpointDurableExecutionStateRequest, _> =
297            serde_json::from_str(&request.payload);
298
299        match parsed {
300            Ok(req) => {
301                // Get checkpoint manager by token
302                match state
303                    .execution_manager
304                    .get_checkpoints_by_token_mut(&req.checkpoint_token)
305                {
306                    Ok(Some(checkpoint_manager)) => {
307                        match checkpoint_manager.process_checkpoint(req.operations) {
308                            Ok(dirty_ops) => {
309                                // Build response matching SDK's CheckpointResponse structure
310                                let response = serde_json::json!({
311                                    "CheckpointToken": req.checkpoint_token,
312                                    "NewExecutionState": {
313                                        "Operations": dirty_ops
314                                    }
315                                });
316                                match serde_json::to_string(&response) {
317                                    Ok(payload) => WorkerApiResponse::success(
318                                        request.api_type,
319                                        request.request_id.clone(),
320                                        payload,
321                                    ),
322                                    Err(e) => WorkerApiResponse::error(
323                                        request.api_type,
324                                        request.request_id.clone(),
325                                        format!("Failed to serialize response: {}", e),
326                                    ),
327                                }
328                            }
329                            Err(e) => WorkerApiResponse::error(
330                                request.api_type,
331                                request.request_id.clone(),
332                                format!("Checkpoint processing failed: {}", e),
333                            ),
334                        }
335                    }
336                    Ok(None) => WorkerApiResponse::error(
337                        request.api_type,
338                        request.request_id.clone(),
339                        "Execution not found for checkpoint token".to_string(),
340                    ),
341                    Err(e) => WorkerApiResponse::error(
342                        request.api_type,
343                        request.request_id.clone(),
344                        format!("Invalid checkpoint token: {}", e),
345                    ),
346                }
347            }
348            Err(e) => WorkerApiResponse::error(
349                request.api_type,
350                request.request_id.clone(),
351                format!("Failed to parse request: {}", e),
352            ),
353        }
354    }
355
356    /// Handle GetDurableExecutionState request.
357    fn handle_get_state(
358        state: &mut CheckpointServerState,
359        request: &WorkerApiRequest,
360    ) -> WorkerApiResponse {
361        let parsed: Result<GetDurableExecutionStateRequest, _> =
362            serde_json::from_str(&request.payload);
363
364        match parsed {
365            Ok(req) => {
366                // Extract execution ID from ARN (simplified - just use the ARN as ID for testing)
367                let execution_id = &req.durable_execution_arn;
368                match state
369                    .execution_manager
370                    .get_checkpoints_by_execution(execution_id)
371                {
372                    Some(checkpoint_manager) => {
373                        let operations = checkpoint_manager.get_state();
374                        let response = GetOperationsResponse {
375                            operations,
376                            next_marker: None,
377                        };
378                        match serde_json::to_string(&response) {
379                            Ok(payload) => WorkerApiResponse::success(
380                                request.api_type,
381                                request.request_id.clone(),
382                                payload,
383                            ),
384                            Err(e) => WorkerApiResponse::error(
385                                request.api_type,
386                                request.request_id.clone(),
387                                format!("Failed to serialize response: {}", e),
388                            ),
389                        }
390                    }
391                    None => WorkerApiResponse::error(
392                        request.api_type,
393                        request.request_id.clone(),
394                        format!("Execution not found: {}", execution_id),
395                    ),
396                }
397            }
398            Err(e) => WorkerApiResponse::error(
399                request.api_type,
400                request.request_id.clone(),
401                format!("Failed to parse request: {}", e),
402            ),
403        }
404    }
405
406    /// Handle SendDurableExecutionCallbackSuccess request.
407    fn handle_callback_success(
408        state: &mut CheckpointServerState,
409        request: &WorkerApiRequest,
410    ) -> WorkerApiResponse {
411        let parsed: Result<SendCallbackSuccessRequest, _> = serde_json::from_str(&request.payload);
412
413        match parsed {
414            Ok(req) => {
415                // Collect execution IDs first to avoid borrow conflict
416                let execution_ids: Vec<String> = state
417                    .execution_manager
418                    .get_execution_ids()
419                    .into_iter()
420                    .cloned()
421                    .collect();
422
423                // Find the execution containing this callback
424                for execution_id in execution_ids {
425                    if let Some(checkpoint_manager) = state
426                        .execution_manager
427                        .get_checkpoints_by_execution_mut(&execution_id)
428                    {
429                        if checkpoint_manager
430                            .callback_manager()
431                            .get_callback_state(&req.callback_id)
432                            .is_some()
433                        {
434                            match checkpoint_manager
435                                .callback_manager_mut()
436                                .send_success(&req.callback_id, &req.result)
437                            {
438                                Ok(()) => {
439                                    // Also update the callback operation status to Succeeded
440                                    // so the orchestrator can detect completion and re-invoke.
441                                    checkpoint_manager.complete_callback_operation(
442                                        &req.callback_id,
443                                        Some(req.result.clone()),
444                                        None,
445                                    );
446                                    return WorkerApiResponse::success(
447                                        request.api_type,
448                                        request.request_id.clone(),
449                                        "{}".to_string(),
450                                    );
451                                }
452                                Err(e) => {
453                                    return WorkerApiResponse::error(
454                                        request.api_type,
455                                        request.request_id.clone(),
456                                        format!("{}", e),
457                                    );
458                                }
459                            }
460                        }
461                    }
462                }
463                WorkerApiResponse::error(
464                    request.api_type,
465                    request.request_id.clone(),
466                    format!("Callback not found: {}", req.callback_id),
467                )
468            }
469            Err(e) => WorkerApiResponse::error(
470                request.api_type,
471                request.request_id.clone(),
472                format!("Failed to parse request: {}", e),
473            ),
474        }
475    }
476
477    /// Handle SendDurableExecutionCallbackFailure request.
478    fn handle_callback_failure(
479        state: &mut CheckpointServerState,
480        request: &WorkerApiRequest,
481    ) -> WorkerApiResponse {
482        let parsed: Result<SendCallbackFailureRequest, _> = serde_json::from_str(&request.payload);
483
484        match parsed {
485            Ok(req) => {
486                // Collect execution IDs first to avoid borrow conflict
487                let execution_ids: Vec<String> = state
488                    .execution_manager
489                    .get_execution_ids()
490                    .into_iter()
491                    .cloned()
492                    .collect();
493
494                // Find the execution containing this callback
495                for execution_id in execution_ids {
496                    if let Some(checkpoint_manager) = state
497                        .execution_manager
498                        .get_checkpoints_by_execution_mut(&execution_id)
499                    {
500                        if checkpoint_manager
501                            .callback_manager()
502                            .get_callback_state(&req.callback_id)
503                            .is_some()
504                        {
505                            match checkpoint_manager
506                                .callback_manager_mut()
507                                .send_failure(&req.callback_id, &req.error)
508                            {
509                                Ok(()) => {
510                                    // Also update the callback operation status to Failed
511                                    // so the orchestrator can detect completion and re-invoke.
512                                    checkpoint_manager.complete_callback_operation(
513                                        &req.callback_id,
514                                        None,
515                                        Some(req.error.clone()),
516                                    );
517                                    return WorkerApiResponse::success(
518                                        request.api_type,
519                                        request.request_id.clone(),
520                                        "{}".to_string(),
521                                    );
522                                }
523                                Err(e) => {
524                                    return WorkerApiResponse::error(
525                                        request.api_type,
526                                        request.request_id.clone(),
527                                        format!("{}", e),
528                                    );
529                                }
530                            }
531                        }
532                    }
533                }
534                WorkerApiResponse::error(
535                    request.api_type,
536                    request.request_id.clone(),
537                    format!("Callback not found: {}", req.callback_id),
538                )
539            }
540            Err(e) => WorkerApiResponse::error(
541                request.api_type,
542                request.request_id.clone(),
543                format!("Failed to parse request: {}", e),
544            ),
545        }
546    }
547
548    /// Handle SendDurableExecutionCallbackHeartbeat request.
549    fn handle_callback_heartbeat(
550        state: &mut CheckpointServerState,
551        request: &WorkerApiRequest,
552    ) -> WorkerApiResponse {
553        let parsed: Result<SendCallbackHeartbeatRequest, _> =
554            serde_json::from_str(&request.payload);
555
556        match parsed {
557            Ok(req) => {
558                // Collect execution IDs first to avoid borrow conflict
559                let execution_ids: Vec<String> = state
560                    .execution_manager
561                    .get_execution_ids()
562                    .into_iter()
563                    .cloned()
564                    .collect();
565
566                // Find the execution containing this callback
567                for execution_id in execution_ids {
568                    if let Some(checkpoint_manager) = state
569                        .execution_manager
570                        .get_checkpoints_by_execution_mut(&execution_id)
571                    {
572                        if checkpoint_manager
573                            .callback_manager()
574                            .get_callback_state(&req.callback_id)
575                            .is_some()
576                        {
577                            match checkpoint_manager
578                                .callback_manager_mut()
579                                .send_heartbeat(&req.callback_id)
580                            {
581                                Ok(()) => {
582                                    return WorkerApiResponse::success(
583                                        request.api_type,
584                                        request.request_id.clone(),
585                                        "{}".to_string(),
586                                    );
587                                }
588                                Err(e) => {
589                                    return WorkerApiResponse::error(
590                                        request.api_type,
591                                        request.request_id.clone(),
592                                        format!("{}", e),
593                                    );
594                                }
595                            }
596                        }
597                    }
598                }
599                WorkerApiResponse::error(
600                    request.api_type,
601                    request.request_id.clone(),
602                    format!("Callback not found: {}", req.callback_id),
603                )
604            }
605            Err(e) => WorkerApiResponse::error(
606                request.api_type,
607                request.request_id.clone(),
608                format!("Failed to parse request: {}", e),
609            ),
610        }
611    }
612
613    /// Handle UpdateCheckpointData request.
614    ///
615    /// This method updates the state of a specific operation in the checkpoint server.
616    /// It's used by the orchestrator to mark wait operations as SUCCEEDED after time
617    /// has been advanced in time-skipping mode.
618    fn handle_update_checkpoint_data(
619        state: &mut CheckpointServerState,
620        request: &WorkerApiRequest,
621    ) -> WorkerApiResponse {
622        let parsed: Result<super::types::UpdateCheckpointDataRequest, _> =
623            serde_json::from_str(&request.payload);
624
625        match parsed {
626            Ok(req) => {
627                // Get checkpoint manager by execution ID
628                match state
629                    .execution_manager
630                    .get_checkpoints_by_execution_mut(&req.execution_id)
631                {
632                    Some(checkpoint_manager) => {
633                        // Update the operation data
634                        checkpoint_manager
635                            .update_operation_data(&req.operation_id, req.operation_data);
636                        WorkerApiResponse::success(
637                            request.api_type,
638                            request.request_id.clone(),
639                            "{}".to_string(),
640                        )
641                    }
642                    None => WorkerApiResponse::error(
643                        request.api_type,
644                        request.request_id.clone(),
645                        format!("Execution not found: {}", req.execution_id),
646                    ),
647                }
648            }
649            Err(e) => WorkerApiResponse::error(
650                request.api_type,
651                request.request_id.clone(),
652                format!("Failed to parse request: {}", e),
653            ),
654        }
655    }
656
657    /// Handle GetNodeJsHistoryEvents request.
658    ///
659    /// This method retrieves the Node.js-compatible history events for an execution.
660    /// These events match the Node.js SDK's event history format exactly,
661    /// enabling cross-SDK history comparison.
662    fn handle_get_nodejs_history_events(
663        state: &mut CheckpointServerState,
664        request: &WorkerApiRequest,
665    ) -> WorkerApiResponse {
666        let parsed: Result<super::types::GetNodeJsHistoryEventsRequest, _> =
667            serde_json::from_str(&request.payload);
668
669        match parsed {
670            Ok(req) => {
671                // Get checkpoint manager by execution ID
672                match state
673                    .execution_manager
674                    .get_checkpoints_by_execution(&req.execution_id)
675                {
676                    Some(checkpoint_manager) => {
677                        // Get Node.js history events
678                        let events = checkpoint_manager.get_nodejs_history_events();
679                        match serde_json::to_string(&events) {
680                            Ok(payload) => WorkerApiResponse::success(
681                                request.api_type,
682                                request.request_id.clone(),
683                                payload,
684                            ),
685                            Err(e) => WorkerApiResponse::error(
686                                request.api_type,
687                                request.request_id.clone(),
688                                format!("Failed to serialize response: {}", e),
689                            ),
690                        }
691                    }
692                    None => WorkerApiResponse::error(
693                        request.api_type,
694                        request.request_id.clone(),
695                        format!("Execution not found: {}", req.execution_id),
696                    ),
697                }
698            }
699            Err(e) => WorkerApiResponse::error(
700                request.api_type,
701                request.request_id.clone(),
702                format!("Failed to parse request: {}", e),
703            ),
704        }
705    }
706
707    /// Send an API request to the checkpoint server and wait for response.
708    pub async fn send_api_request(
709        &self,
710        api_type: ApiType,
711        payload: String,
712    ) -> Result<WorkerApiResponse, TestError> {
713        let request_id = Uuid::new_v4().to_string();
714        let command = WorkerCommand {
715            command_type: WorkerCommandType::ApiRequest,
716            data: WorkerApiRequest {
717                api_type,
718                request_id: request_id.clone(),
719                payload,
720            },
721        };
722
723        let (response_tx, response_rx) = oneshot::channel();
724
725        self.command_tx
726            .send(WorkerMessage::Command(command, response_tx))
727            .await
728            .map_err(|e| {
729                TestError::CheckpointCommunicationError(format!("Failed to send command: {}", e))
730            })?;
731
732        let response = response_rx.await.map_err(|e| {
733            TestError::CheckpointCommunicationError(format!("Failed to receive response: {}", e))
734        })?;
735
736        Ok(response.data)
737    }
738
739    /// Gracefully shut down the checkpoint server.
740    pub async fn shutdown(mut self) -> Result<(), TestError> {
741        // Send shutdown signal
742        let _ = self.command_tx.send(WorkerMessage::Shutdown).await;
743
744        // Wait for worker thread to finish
745        if let Some(handle) = self.worker_handle.take() {
746            handle.join().map_err(|_| {
747                TestError::CheckpointServerError("Worker thread panicked".to_string())
748            })?;
749        }
750
751        Ok(())
752    }
753
754    /// Synchronous shutdown (for use in Drop or non-async contexts).
755    pub fn shutdown_sync(mut self) -> Result<(), TestError> {
756        // Create a runtime to send the shutdown signal
757        if let Ok(rt) = tokio::runtime::Builder::new_current_thread()
758            .enable_all()
759            .build()
760        {
761            let _ = rt.block_on(self.command_tx.send(WorkerMessage::Shutdown));
762        }
763
764        // Wait for worker thread to finish
765        if let Some(handle) = self.worker_handle.take() {
766            handle.join().map_err(|_| {
767                TestError::CheckpointServerError("Worker thread panicked".to_string())
768            })?;
769        }
770
771        Ok(())
772    }
773
774    /// Get the configuration parameters.
775    pub fn params(&self) -> &CheckpointWorkerParams {
776        &self.params
777    }
778}
779
780// Additional types needed for checkpoint requests
781use super::types::{CheckpointDurableExecutionStateRequest, GetDurableExecutionStateRequest};
782
783// Implement DurableServiceClient trait for CheckpointWorkerManager
784use async_trait::async_trait;
785
786#[async_trait]
787impl DurableServiceClient for CheckpointWorkerManager {
788    async fn checkpoint(
789        &self,
790        durable_execution_arn: &str,
791        checkpoint_token: &str,
792        operations: Vec<OperationUpdate>,
793    ) -> Result<CheckpointResponse, DurableError> {
794        let request = CheckpointDurableExecutionStateRequest {
795            durable_execution_arn: durable_execution_arn.to_string(),
796            checkpoint_token: checkpoint_token.to_string(),
797            operations,
798        };
799
800        let payload = serde_json::to_string(&request)
801            .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
802
803        let response = self
804            .send_api_request(ApiType::CheckpointDurableExecutionState, payload)
805            .await
806            .map_err(|e| {
807                DurableError::checkpoint_retriable(format!("Communication error: {}", e))
808            })?;
809
810        if let Some(error) = response.error {
811            return Err(DurableError::checkpoint_retriable(error));
812        }
813
814        let payload = response
815            .payload
816            .ok_or_else(|| DurableError::checkpoint_retriable("Empty response payload"))?;
817
818        serde_json::from_str(&payload)
819            .map_err(|e| DurableError::validation(format!("Failed to parse response: {}", e)))
820    }
821
822    async fn get_operations(
823        &self,
824        durable_execution_arn: &str,
825        _next_marker: &str,
826    ) -> Result<GetOperationsResponse, DurableError> {
827        let request = GetDurableExecutionStateRequest {
828            durable_execution_arn: durable_execution_arn.to_string(),
829        };
830
831        let payload = serde_json::to_string(&request)
832            .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
833
834        let response = self
835            .send_api_request(ApiType::GetDurableExecutionState, payload)
836            .await
837            .map_err(|e| {
838                DurableError::checkpoint_retriable(format!("Communication error: {}", e))
839            })?;
840
841        if let Some(error) = response.error {
842            return Err(DurableError::checkpoint_retriable(error));
843        }
844
845        let payload = response
846            .payload
847            .ok_or_else(|| DurableError::checkpoint_retriable("Empty response payload"))?;
848
849        serde_json::from_str(&payload)
850            .map_err(|e| DurableError::validation(format!("Failed to parse response: {}", e)))
851    }
852}
853
// Callback and Node.js history methods (not part of the DurableServiceClient trait)
855impl CheckpointWorkerManager {
856    /// Send a callback success response.
857    pub async fn send_callback_success(
858        &self,
859        callback_id: &str,
860        result: &str,
861    ) -> Result<(), DurableError> {
862        let request = SendCallbackSuccessRequest {
863            callback_id: callback_id.to_string(),
864            result: result.to_string(),
865        };
866
867        let payload = serde_json::to_string(&request)
868            .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
869
870        let response = self
871            .send_api_request(ApiType::SendDurableExecutionCallbackSuccess, payload)
872            .await
873            .map_err(|e| {
874                DurableError::checkpoint_retriable(format!("Communication error: {}", e))
875            })?;
876
877        if let Some(error) = response.error {
878            return Err(DurableError::execution(error));
879        }
880
881        Ok(())
882    }
883
884    /// Send a callback failure response.
885    pub async fn send_callback_failure(
886        &self,
887        callback_id: &str,
888        error: &ErrorObject,
889    ) -> Result<(), DurableError> {
890        let request = SendCallbackFailureRequest {
891            callback_id: callback_id.to_string(),
892            error: error.clone(),
893        };
894
895        let payload = serde_json::to_string(&request)
896            .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
897
898        let response = self
899            .send_api_request(ApiType::SendDurableExecutionCallbackFailure, payload)
900            .await
901            .map_err(|e| {
902                DurableError::checkpoint_retriable(format!("Communication error: {}", e))
903            })?;
904
905        if let Some(error) = response.error {
906            return Err(DurableError::execution(error));
907        }
908
909        Ok(())
910    }
911
912    /// Send a callback heartbeat.
913    pub async fn send_callback_heartbeat(&self, callback_id: &str) -> Result<(), DurableError> {
914        let request = SendCallbackHeartbeatRequest {
915            callback_id: callback_id.to_string(),
916        };
917
918        let payload = serde_json::to_string(&request)
919            .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
920
921        let response = self
922            .send_api_request(ApiType::SendDurableExecutionCallbackHeartbeat, payload)
923            .await
924            .map_err(|e| {
925                DurableError::checkpoint_retriable(format!("Communication error: {}", e))
926            })?;
927
928        if let Some(error) = response.error {
929            return Err(DurableError::execution(error));
930        }
931
932        Ok(())
933    }
934
935    /// Get Node.js-compatible history events for an execution.
936    ///
937    /// Returns events in the Node.js SDK compatible format, suitable for
938    /// cross-SDK history comparison. These events use PascalCase field names
939    /// and include detailed event-specific information.
940    ///
941    /// # Arguments
942    ///
943    /// * `execution_id` - The execution ID to get history events for
944    ///
945    /// # Returns
946    ///
947    /// A vector of Node.js-compatible history events.
948    pub async fn get_nodejs_history_events(
949        &self,
950        execution_id: &str,
951    ) -> Result<Vec<super::nodejs_event_types::NodeJsHistoryEvent>, DurableError> {
952        use super::types::GetNodeJsHistoryEventsRequest;
953
954        let request = GetNodeJsHistoryEventsRequest {
955            execution_id: execution_id.to_string(),
956        };
957
958        let payload = serde_json::to_string(&request)
959            .map_err(|e| DurableError::validation(format!("Failed to serialize request: {}", e)))?;
960
961        let response = self
962            .send_api_request(ApiType::GetNodeJsHistoryEvents, payload)
963            .await
964            .map_err(|e| {
965                DurableError::checkpoint_retriable(format!("Communication error: {}", e))
966            })?;
967
968        if let Some(error) = response.error {
969            return Err(DurableError::checkpoint_retriable(error));
970        }
971
972        let payload = response
973            .payload
974            .ok_or_else(|| DurableError::checkpoint_retriable("Empty response payload"))?;
975
976        serde_json::from_str(&payload)
977            .map_err(|e| DurableError::validation(format!("Failed to parse response: {}", e)))
978    }
979}
980
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkpoint_server::InvocationResult;

    /// Serializes the tests in this module.
    ///
    /// Every test resets and re-creates the process-wide singleton, and the
    /// test harness runs tests on parallel threads by default. Without this
    /// lock, a reset from one test can land between the two `get_instance`
    /// calls of another. An async mutex is used so the guard can be held
    /// across `.await` points; it is not poisoned when a test panics, so one
    /// failure does not cascade into the others.
    static SINGLETON_TEST_LOCK: std::sync::OnceLock<tokio::sync::Mutex<()>> =
        std::sync::OnceLock::new();

    /// Acquire the module-wide test lock; hold the guard for the whole test.
    async fn lock_singleton() -> tokio::sync::MutexGuard<'static, ()> {
        SINGLETON_TEST_LOCK
            .get_or_init(|| tokio::sync::Mutex::new(()))
            .lock()
            .await
    }

    /// `get_instance` returns the same `Arc` until the singleton is reset.
    #[tokio::test]
    async fn test_get_instance() {
        let _serial = lock_singleton().await;
        CheckpointWorkerManager::reset_instance_for_testing();

        let instance = CheckpointWorkerManager::get_instance(None).unwrap();
        assert!(instance.params().checkpoint_delay.is_none());

        // Getting instance again should return the same one
        let instance2 = CheckpointWorkerManager::get_instance(None).unwrap();
        assert!(Arc::ptr_eq(&instance, &instance2));

        CheckpointWorkerManager::reset_instance_for_testing();
    }

    /// Starting an execution yields a non-empty execution ID and echoes the
    /// invocation ID.
    #[tokio::test]
    async fn test_start_execution() {
        let _serial = lock_singleton().await;
        CheckpointWorkerManager::reset_instance_for_testing();

        let manager = CheckpointWorkerManager::get_instance(None).unwrap();

        let request = StartDurableExecutionRequest {
            invocation_id: "inv-1".to_string(),
            payload: Some(r#"{"test": true}"#.to_string()),
        };

        let payload = serde_json::to_string(&request).unwrap();
        let response = manager
            .send_api_request(ApiType::StartDurableExecution, payload)
            .await
            .unwrap();

        assert!(!response.is_error());
        assert!(response.payload.is_some());

        let result: InvocationResult = serde_json::from_str(&response.payload.unwrap()).unwrap();
        assert!(!result.execution_id.is_empty());
        assert_eq!(result.invocation_id, "inv-1");

        CheckpointWorkerManager::reset_instance_for_testing();
    }

    /// A checkpoint sent with the token returned by start is accepted.
    #[tokio::test]
    async fn test_checkpoint_workflow() {
        let _serial = lock_singleton().await;
        CheckpointWorkerManager::reset_instance_for_testing();

        let manager = CheckpointWorkerManager::get_instance(None).unwrap();

        // Start execution
        let start_request = StartDurableExecutionRequest {
            invocation_id: "inv-1".to_string(),
            payload: Some("{}".to_string()),
        };
        let payload = serde_json::to_string(&start_request).unwrap();
        let response = manager
            .send_api_request(ApiType::StartDurableExecution, payload)
            .await
            .unwrap();

        let result: InvocationResult = serde_json::from_str(&response.payload.unwrap()).unwrap();

        // Send checkpoint using OperationUpdate::start helper
        let checkpoint_request = CheckpointDurableExecutionStateRequest {
            durable_execution_arn: result.execution_id.clone(),
            checkpoint_token: result.checkpoint_token.clone(),
            operations: vec![OperationUpdate::start(
                "op-1",
                durable_execution_sdk::OperationType::Step,
            )
            .with_name("test-step")],
        };

        let payload = serde_json::to_string(&checkpoint_request).unwrap();
        let response = manager
            .send_api_request(ApiType::CheckpointDurableExecutionState, payload)
            .await
            .unwrap();

        assert!(!response.is_error());

        CheckpointWorkerManager::reset_instance_for_testing();
    }
}