durable_execution_sdk_testing/cloud_runner.rs
//! Cloud test runner for testing deployed Lambda functions.
//!
//! This module provides the `CloudDurableTestRunner` for testing durable functions
//! deployed to AWS Lambda, enabling integration testing against real AWS infrastructure.
//!
//! # Examples
//!
//! ```ignore
//! use durable_execution_sdk_testing::{
//!     CloudDurableTestRunner, CloudTestRunnerConfig, ExecutionStatus,
//! };
//!
//! #[tokio::test]
//! async fn test_deployed_workflow() {
//!     // `run` takes `&mut self`, so the runner must be declared mutable.
//!     let mut runner = CloudDurableTestRunner::<String>::new("my-function-name")
//!         .await
//!         .unwrap();
//!
//!     let result = runner.run("input".to_string()).await.unwrap();
//!     assert_eq!(result.get_status(), ExecutionStatus::Succeeded);
//! }
//! ```
23
24use std::collections::HashMap;
25use std::marker::PhantomData;
26use std::sync::Arc;
27use std::time::{Duration, Instant};
28
29use aws_sdk_lambda::Client as LambdaClient;
30use serde::de::DeserializeOwned;
31use serde::Serialize;
32use tokio::sync::RwLock;
33
34use crate::error::TestError;
35use crate::history_poller::{HistoryApiClient, HistoryPage, HistoryPoller};
36use crate::operation::{CallbackSender, DurableOperation};
37use crate::operation_handle::{OperationHandle, OperationMatcher};
38use crate::test_result::TestResult;
39use crate::types::{ExecutionStatus, TestResultError};
40use durable_execution_sdk::{
41 DurableServiceClient, LambdaDurableServiceClient, Operation, OperationStatus, OperationType,
42};
43
/// Polling and timeout settings for the cloud test runner.
///
/// These values decide how often a deployed function's execution is polled
/// and how long the runner waits before it stops waiting.
///
/// # Examples
///
/// ```
/// use durable_execution_sdk_testing::CloudTestRunnerConfig;
/// use std::time::Duration;
///
/// let config = CloudTestRunnerConfig {
///     poll_interval: Duration::from_millis(500),
///     timeout: Duration::from_secs(60),
/// };
/// ```
#[derive(Debug, Clone)]
pub struct CloudTestRunnerConfig {
    /// Delay between consecutive polls while waiting for completion.
    ///
    /// Default: 1000ms (1 second)
    pub poll_interval: Duration,

    /// Upper bound on how long to wait for the execution to complete.
    ///
    /// Default: 300 seconds (5 minutes)
    pub timeout: Duration,
}

impl Default for CloudTestRunnerConfig {
    /// Builds the default configuration: a 1 second poll interval and a
    /// 5 minute timeout.
    fn default() -> Self {
        let poll_interval = Duration::from_secs(1);
        let timeout = Duration::from_secs(5 * 60);
        Self {
            poll_interval,
            timeout,
        }
    }
}

impl CloudTestRunnerConfig {
    /// Returns a configuration populated with the default values.
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns this configuration with the polling interval replaced.
    ///
    /// # Arguments
    ///
    /// * `interval` - The interval between status polls
    pub fn with_poll_interval(self, interval: Duration) -> Self {
        Self {
            poll_interval: interval,
            ..self
        }
    }

    /// Returns this configuration with the timeout replaced.
    ///
    /// # Arguments
    ///
    /// * `timeout` - The maximum time to wait for execution completion
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self { timeout, ..self }
    }
}
107
108/// Internal storage for operations captured from cloud execution.
109#[derive(Debug, Default)]
110struct OperationStorage {
111 /// All operations in execution order
112 operations: Vec<Operation>,
113 /// Map from operation ID to index in operations vec
114 operations_by_id: HashMap<String, usize>,
115 /// Map from operation name to indices in operations vec
116 operations_by_name: HashMap<String, Vec<usize>>,
117}
118
119impl OperationStorage {
120 fn new() -> Self {
121 Self::default()
122 }
123
124 #[allow(dead_code)]
125 fn add_operation(&mut self, operation: Operation) {
126 let index = self.operations.len();
127 let id = operation.operation_id.clone();
128 let name = operation.name.clone();
129
130 self.operations.push(operation);
131 self.operations_by_id.insert(id, index);
132
133 if let Some(name) = name {
134 self.operations_by_name.entry(name).or_default().push(index);
135 }
136 }
137
138 /// If an operation with the same `operation_id` already exists, update it
139 /// in place; otherwise append it via `add_operation`.
140 #[allow(dead_code)]
141 pub(crate) fn add_or_update(&mut self, operation: Operation) {
142 if let Some(&idx) = self.operations_by_id.get(&operation.operation_id) {
143 self.operations[idx] = operation;
144 } else {
145 self.add_operation(operation);
146 }
147 }
148
149 fn get_by_id(&self, id: &str) -> Option<&Operation> {
150 self.operations_by_id
151 .get(id)
152 .and_then(|&idx| self.operations.get(idx))
153 }
154
155 fn get_by_name(&self, name: &str) -> Option<&Operation> {
156 self.operations_by_name
157 .get(name)
158 .and_then(|indices| indices.first())
159 .and_then(|&idx| self.operations.get(idx))
160 }
161
162 fn get_by_name_and_index(&self, name: &str, index: usize) -> Option<&Operation> {
163 self.operations_by_name
164 .get(name)
165 .and_then(|indices| indices.get(index))
166 .and_then(|&idx| self.operations.get(idx))
167 }
168
169 fn get_by_index(&self, index: usize) -> Option<&Operation> {
170 self.operations.get(index)
171 }
172
173 fn get_all(&self) -> &[Operation] {
174 &self.operations
175 }
176
177 fn clear(&mut self) {
178 self.operations.clear();
179 self.operations_by_id.clear();
180 self.operations_by_name.clear();
181 }
182}
183
184/// Real implementation of [`HistoryApiClient`] that calls the durable execution
185/// state API via a [`LambdaDurableServiceClient`].
186///
187/// Wraps a `LambdaClient` and creates an internal service client to make
188/// signed HTTP calls to the `GetDurableExecutionHistory` API endpoint.
189pub struct LambdaHistoryApiClient {
190 service_client: LambdaDurableServiceClient,
191}
192
193impl LambdaHistoryApiClient {
194 /// Creates a new `LambdaHistoryApiClient` from an AWS SDK config.
195 ///
196 /// The config is used to construct a `LambdaDurableServiceClient` that
197 /// makes signed HTTP calls to the durable execution state API.
198 ///
199 /// # Arguments
200 ///
201 /// * `aws_config` - The AWS SDK configuration (same one used to create the `LambdaClient`)
202 pub fn from_aws_config(aws_config: &aws_config::SdkConfig) -> Result<Self, TestError> {
203 Ok(Self {
204 service_client: LambdaDurableServiceClient::from_aws_config(aws_config).map_err(
205 |e| TestError::aws_error(format!("Failed to create service client: {}", e)),
206 )?,
207 })
208 }
209
210 /// Creates a new `LambdaHistoryApiClient` from an existing `LambdaDurableServiceClient`.
211 ///
212 /// Useful when a service client is already available.
213 pub fn from_service_client(service_client: LambdaDurableServiceClient) -> Self {
214 Self { service_client }
215 }
216
217 /// Maps an [`OperationStatus`] to an [`ExecutionStatus`] for terminal detection.
218 fn map_terminal_status(status: &OperationStatus) -> Option<ExecutionStatus> {
219 match status {
220 OperationStatus::Succeeded => Some(ExecutionStatus::Succeeded),
221 OperationStatus::Failed => Some(ExecutionStatus::Failed),
222 OperationStatus::Cancelled => Some(ExecutionStatus::Cancelled),
223 OperationStatus::TimedOut => Some(ExecutionStatus::TimedOut),
224 _ => None,
225 }
226 }
227}
228
229#[async_trait::async_trait]
230impl HistoryApiClient for LambdaHistoryApiClient {
231 /// Retrieves a single page of execution history by calling the durable execution state API.
232 ///
233 /// Detects terminal state by examining EXECUTION-type operations: when an execution
234 /// operation has a terminal status (Succeeded, Failed, Cancelled, TimedOut), the page
235 /// is marked as terminal with the corresponding status, result, and error.
236 async fn get_history(&self, arn: &str, marker: Option<&str>) -> Result<HistoryPage, TestError> {
237 let marker_str = marker.unwrap_or("");
238
239 let response = self
240 .service_client
241 .get_operations(arn, marker_str)
242 .await
243 .map_err(|e| {
244 TestError::aws_error(format!("GetDurableExecutionHistory failed: {}", e))
245 })?;
246
247 // Detect terminal state from EXECUTION-type operations
248 let mut is_terminal = false;
249 let mut terminal_status = None;
250 let mut terminal_result = None;
251 let mut terminal_error = None;
252
253 for op in &response.operations {
254 if op.operation_type == OperationType::Execution {
255 if let Some(exec_status) = Self::map_terminal_status(&op.status) {
256 is_terminal = true;
257 terminal_status = Some(exec_status);
258 terminal_result = op.result.clone();
259 if let Some(ref err) = op.error {
260 terminal_error =
261 Some(TestResultError::new(&err.error_type, &err.error_message));
262 }
263 break;
264 }
265 }
266 }
267
268 Ok(HistoryPage {
269 events: Vec::new(), // The state API returns operations, not raw history events
270 operations: response.operations,
271 next_marker: response.next_marker,
272 is_terminal,
273 terminal_status,
274 terminal_result,
275 terminal_error,
276 })
277 }
278}
279
280/// Sends callback signals (success, failure, heartbeat) to a durable execution
281/// via the AWS Lambda API.
282///
283/// This bridges the [`CallbackSender`] trait (used by [`OperationHandle`]) with
284/// the Lambda durable execution callback APIs, enabling handles to send callback
285/// responses during cloud test execution.
286pub(crate) struct CloudCallbackSender {
287 /// The AWS Lambda client used to send callback API calls
288 client: LambdaClient,
289 /// The ARN of the durable execution to send callbacks to
290 _durable_execution_arn: String,
291}
292
293impl CloudCallbackSender {
294 /// Creates a new `CloudCallbackSender`.
295 ///
296 /// # Arguments
297 ///
298 /// * `client` - The AWS Lambda client
299 /// * `durable_execution_arn` - The ARN of the durable execution
300 pub fn new(client: LambdaClient, durable_execution_arn: String) -> Self {
301 Self {
302 client,
303 _durable_execution_arn: durable_execution_arn,
304 }
305 }
306}
307
308#[async_trait::async_trait]
309impl CallbackSender for CloudCallbackSender {
310 /// Sends a success response for a callback operation.
311 ///
312 /// Calls the `SendDurableExecutionCallbackSuccess` Lambda API with the
313 /// callback ID and result payload.
314 ///
315 /// # Arguments
316 ///
317 /// * `callback_id` - The unique callback identifier
318 /// * `result` - The success result payload (serialized as bytes)
319 async fn send_success(&self, callback_id: &str, result: &str) -> Result<(), TestError> {
320 self.client
321 .send_durable_execution_callback_success()
322 .callback_id(callback_id)
323 .result(aws_sdk_lambda::primitives::Blob::new(result.as_bytes()))
324 .send()
325 .await
326 .map_err(|e| {
327 TestError::aws_error(format!(
328 "SendDurableExecutionCallbackSuccess failed for callback '{}': {}",
329 callback_id, e
330 ))
331 })?;
332 Ok(())
333 }
334
335 /// Sends a failure response for a callback operation.
336 ///
337 /// Calls the `SendDurableExecutionCallbackFailure` Lambda API with the
338 /// callback ID and error details.
339 ///
340 /// # Arguments
341 ///
342 /// * `callback_id` - The unique callback identifier
343 /// * `error` - The error information to send
344 async fn send_failure(
345 &self,
346 callback_id: &str,
347 error: &TestResultError,
348 ) -> Result<(), TestError> {
349 let error_object = aws_sdk_lambda::types::ErrorObject::builder()
350 .set_error_type(error.error_type.clone())
351 .set_error_message(error.error_message.clone())
352 .build();
353
354 self.client
355 .send_durable_execution_callback_failure()
356 .callback_id(callback_id)
357 .error(error_object)
358 .send()
359 .await
360 .map_err(|e| {
361 TestError::aws_error(format!(
362 "SendDurableExecutionCallbackFailure failed for callback '{}': {}",
363 callback_id, e
364 ))
365 })?;
366 Ok(())
367 }
368
369 /// Sends a heartbeat for a callback operation.
370 ///
371 /// Calls the `SendDurableExecutionCallbackHeartbeat` Lambda API with the
372 /// callback ID to keep the callback active.
373 ///
374 /// # Arguments
375 ///
376 /// * `callback_id` - The unique callback identifier
377 async fn send_heartbeat(&self, callback_id: &str) -> Result<(), TestError> {
378 self.client
379 .send_durable_execution_callback_heartbeat()
380 .callback_id(callback_id)
381 .send()
382 .await
383 .map_err(|e| {
384 TestError::aws_error(format!(
385 "SendDurableExecutionCallbackHeartbeat failed for callback '{}': {}",
386 callback_id, e
387 ))
388 })?;
389 Ok(())
390 }
391}
392
393/// Cloud test runner for testing deployed Lambda functions.
394///
395/// Invokes deployed Lambda functions and polls for execution completion,
396/// enabling integration testing against real AWS infrastructure.
397///
398/// # Type Parameters
399///
400/// * `O` - The output type (must be deserializable)
401///
402/// # Examples
403///
404/// ```ignore
405/// use durable_execution_sdk_testing::CloudDurableTestRunner;
406///
407/// // Create runner with default AWS config
408/// let runner = CloudDurableTestRunner::<String>::new("my-function")
409/// .await
410/// .unwrap();
411///
412/// // Run test
413/// let result = runner.run("input".to_string()).await.unwrap();
414/// println!("Status: {:?}", result.get_status());
415/// ```
416pub struct CloudDurableTestRunner<O>
417where
418 O: DeserializeOwned + Send,
419{
420 /// The Lambda function name or ARN
421 function_name: String,
422 /// The AWS Lambda client
423 lambda_client: LambdaClient,
424 /// The AWS SDK configuration (stored for creating service clients during run)
425 aws_config: Option<aws_config::SdkConfig>,
426 /// Configuration for polling and timeouts
427 config: CloudTestRunnerConfig,
428 /// Storage for captured operations
429 operation_storage: OperationStorage,
430 /// Pre-registered operation handles for lazy population during execution
431 handles: Vec<OperationHandle>,
432 /// Shared operations list for child operation enumeration across handles
433 all_operations: Arc<RwLock<Vec<Operation>>>,
434 /// Phantom data for the output type
435 _phantom: PhantomData<O>,
436}
437
438impl<O> std::fmt::Debug for CloudDurableTestRunner<O>
439where
440 O: DeserializeOwned + Send,
441{
442 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
443 f.debug_struct("CloudDurableTestRunner")
444 .field("function_name", &self.function_name)
445 .field("config", &self.config)
446 .field("operation_count", &self.operation_storage.operations.len())
447 .field("handle_count", &self.handles.len())
448 .finish()
449 }
450}
451
452impl<O> CloudDurableTestRunner<O>
453where
454 O: DeserializeOwned + Send,
455{
456 /// Creates a new cloud test runner for the given Lambda function.
457 ///
458 /// This constructor uses the default AWS configuration, which loads
459 /// credentials from environment variables, AWS config files, or IAM roles.
460 ///
461 /// # Arguments
462 ///
463 /// * `function_name` - The Lambda function name or ARN
464 ///
465 /// # Returns
466 ///
467 /// A new `CloudDurableTestRunner` configured with default settings.
468 /// # Examples
469 ///
470 /// ```ignore
471 /// use durable_execution_sdk_testing::CloudDurableTestRunner;
472 ///
473 /// let runner = CloudDurableTestRunner::<String>::new("my-function")
474 /// .await
475 /// .unwrap();
476 /// ```
477 pub async fn new(function_name: impl Into<String>) -> Result<Self, TestError> {
478 let aws_cfg = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
479 let lambda_client = LambdaClient::new(&aws_cfg);
480
481 Ok(Self {
482 function_name: function_name.into(),
483 lambda_client,
484 aws_config: Some(aws_cfg),
485 config: CloudTestRunnerConfig::default(),
486 operation_storage: OperationStorage::new(),
487 handles: Vec::new(),
488 all_operations: Arc::new(RwLock::new(Vec::new())),
489 _phantom: PhantomData,
490 })
491 }
492
493 /// Creates a new cloud test runner with a custom Lambda client.
494 ///
495 /// This constructor allows using a pre-configured Lambda client,
496 /// useful for testing with custom credentials or endpoints.
497 ///
498 /// # Arguments
499 ///
500 /// * `function_name` - The Lambda function name or ARN
501 /// * `client` - A pre-configured Lambda client
502 ///
503 /// # Returns
504 ///
505 /// A new `CloudDurableTestRunner` using the provided client.
506 /// # Examples
507 ///
508 /// ```ignore
509 /// use durable_execution_sdk_testing::CloudDurableTestRunner;
510 /// use aws_sdk_lambda::Client as LambdaClient;
511 ///
512 /// let config = aws_config::load_defaults(aws_config::BehaviorVersion::latest()).await;
513 /// let custom_client = LambdaClient::new(&config);
514 ///
515 /// let runner = CloudDurableTestRunner::<String>::with_client(
516 /// "my-function",
517 /// custom_client,
518 /// );
519 /// ```
520 pub fn with_client(function_name: impl Into<String>, client: LambdaClient) -> Self {
521 Self {
522 function_name: function_name.into(),
523 lambda_client: client,
524 aws_config: None,
525 config: CloudTestRunnerConfig::default(),
526 operation_storage: OperationStorage::new(),
527 handles: Vec::new(),
528 all_operations: Arc::new(RwLock::new(Vec::new())),
529 _phantom: PhantomData,
530 }
531 }
532
533 /// Configures the test runner with custom settings.
534 ///
535 /// # Arguments
536 ///
537 /// * `config` - The configuration to use
538 ///
539 /// # Returns
540 ///
541 /// The runner with updated configuration.
542 /// # Examples
543 ///
544 /// ```ignore
545 /// use durable_execution_sdk_testing::{CloudDurableTestRunner, CloudTestRunnerConfig};
546 /// use std::time::Duration;
547 ///
548 /// let runner = CloudDurableTestRunner::<String>::new("my-function")
549 /// .await
550 /// .unwrap()
551 /// .with_config(CloudTestRunnerConfig {
552 /// poll_interval: Duration::from_millis(500),
553 /// timeout: Duration::from_secs(60),
554 /// });
555 /// ```
556 pub fn with_config(mut self, config: CloudTestRunnerConfig) -> Self {
557 self.config = config;
558 self
559 }
560
561 /// Returns the function name.
562 pub fn function_name(&self) -> &str {
563 &self.function_name
564 }
565
566 /// Returns the current configuration.
567 pub fn config(&self) -> &CloudTestRunnerConfig {
568 &self.config
569 }
570
571 /// Returns a reference to the Lambda client.
572 pub fn lambda_client(&self) -> &LambdaClient {
573 &self.lambda_client
574 }
575}
576
577impl<O> CloudDurableTestRunner<O>
578where
579 O: DeserializeOwned + Send,
580{
581 /// Runs the durable function and polls for execution completion.
582 ///
583 /// This method invokes the Lambda function, then polls the
584 /// `GetDurableExecutionHistory` API until the execution reaches a terminal
585 /// state or the configured timeout elapses. During polling, operations are
586 /// stored in `OperationStorage` and waiting `OperationHandle` instances are
587 /// notified.
588 ///
589 /// # Arguments
590 ///
591 /// * `payload` - The input payload to send to the Lambda function
592 ///
593 /// # Returns
594 ///
595 /// A `TestResult` reflecting the full execution outcome, including all
596 /// operations and history events collected during polling.
597 /// # Examples
598 ///
599 /// ```ignore
600 /// use durable_execution_sdk_testing::CloudDurableTestRunner;
601 ///
602 /// let mut runner = CloudDurableTestRunner::<String>::new("my-function")
603 /// .await
604 /// .unwrap();
605 ///
606 /// let result = runner.run("input").await.unwrap();
607 /// println!("Status: {:?}", result.get_status());
608 /// ```
609 pub async fn run<I>(&mut self, payload: I) -> Result<TestResult<O>, TestError>
610 where
611 I: Serialize + Send,
612 {
613 // Requirement 3.4: Clear storage at start of each run
614 self.operation_storage.clear();
615
616 // Requirement 8.1, 8.4: Invoke Lambda and extract ARN
617 let arn = self.invoke_lambda(&payload).await?;
618
619 // Requirement 1.1: Create HistoryPoller with the ARN
620 let history_client = self.create_history_client()?;
621 let mut poller = HistoryPoller::new(history_client, arn.clone(), self.config.poll_interval);
622
623 // Requirement 6.4: Create CloudCallbackSender and configure handles
624 let callback_sender: Arc<dyn CallbackSender> = Arc::new(CloudCallbackSender::new(
625 self.lambda_client.clone(),
626 arn.clone(),
627 ));
628 for handle in &self.handles {
629 let mut sender = handle.callback_sender.write().await;
630 *sender = Some(callback_sender.clone());
631 }
632
633 // Requirement 7.2: Use configured timeout
634 let deadline = Instant::now() + self.config.timeout;
635 let mut all_events = Vec::new();
636
637 loop {
638 // Requirement 1.4: Check timeout
639 if Instant::now() >= deadline {
640 return Ok(TestResult::with_status(
641 ExecutionStatus::TimedOut,
642 self.operation_storage.get_all().to_vec(),
643 ));
644 }
645
646 // Requirement 7.1: Wait poll_interval between cycles
647 tokio::time::sleep(self.config.poll_interval).await;
648
649 let poll_result = poller.poll_once().await?;
650
651 // Requirement 3.1: Populate OperationStorage (deduplicated)
652 for op in &poll_result.operations {
653 self.operation_storage.add_or_update(op.clone());
654 }
655
656 // Requirement 5.5: Notify waiting OperationHandles
657 self.notify_handles().await;
658
659 // Requirement 9.1: Collect history events
660 all_events.extend(poll_result.events);
661
662 // Requirement 1.2, 1.5: Check terminal state
663 if let Some(terminal) = poll_result.terminal {
664 let mut result = match terminal.status {
665 ExecutionStatus::Succeeded => {
666 // Requirement 8.2: Parse result from terminal event
667 let output: O =
668 serde_json::from_str(terminal.result.as_deref().unwrap_or("null"))?;
669 TestResult::success(output, self.operation_storage.get_all().to_vec())
670 }
671 status => {
672 // Requirement 8.3: Parse error from terminal event
673 let mut r = TestResult::with_status(
674 status,
675 self.operation_storage.get_all().to_vec(),
676 );
677 if let Some(err) = terminal.error {
678 r.set_error(err);
679 }
680 r
681 }
682 };
683 // Requirement 9.1, 9.2, 9.3: Include all history events
684 result.set_history_events(all_events);
685 return Ok(result);
686 }
687 }
688 }
689
690 /// Invokes the Lambda function and extracts the `DurableExecutionArn` from the response.
691 async fn invoke_lambda<I: Serialize>(&self, payload: &I) -> Result<String, TestError> {
692 let payload_json = serde_json::to_vec(payload)?;
693
694 let invoke_result = self
695 .lambda_client
696 .invoke()
697 .function_name(&self.function_name)
698 .payload(aws_sdk_lambda::primitives::Blob::new(payload_json))
699 .send()
700 .await
701 .map_err(|e| TestError::aws_error(format!("Lambda invoke failed: {}", e)))?;
702
703 // Check for function error
704 if let Some(function_error) = invoke_result.function_error() {
705 let error_payload = invoke_result
706 .payload()
707 .map(|p| String::from_utf8_lossy(p.as_ref()).to_string())
708 .unwrap_or_else(|| "Unknown error".to_string());
709
710 return Err(TestError::aws_error(format!(
711 "Lambda function error ({}): {}",
712 function_error, error_payload
713 )));
714 }
715
716 // Parse the response to extract DurableExecutionArn
717 let response_payload = invoke_result
718 .payload()
719 .ok_or_else(|| TestError::aws_error("No response payload from Lambda"))?;
720
721 let response_str = String::from_utf8_lossy(response_payload.as_ref());
722
723 // Try to parse as JSON and extract the ARN
724 let response_json: serde_json::Value = serde_json::from_str(&response_str)
725 .map_err(|e| TestError::aws_error(format!("Failed to parse Lambda response: {}", e)))?;
726
727 let arn = response_json
728 .get("DurableExecutionArn")
729 .or_else(|| response_json.get("durableExecutionArn"))
730 .and_then(|v| v.as_str())
731 .ok_or_else(|| {
732 TestError::aws_error(format!(
733 "Lambda response missing DurableExecutionArn: {}",
734 response_str
735 ))
736 })?;
737
738 Ok(arn.to_string())
739 }
740
741 /// Creates a `LambdaHistoryApiClient` from the stored AWS config.
742 fn create_history_client(&self) -> Result<LambdaHistoryApiClient, TestError> {
743 match &self.aws_config {
744 Some(cfg) => LambdaHistoryApiClient::from_aws_config(cfg),
745 None => {
746 // Fallback: create a service client from a default config.
747 // This path is used when with_client() was called without an SdkConfig.
748 let service_client = LambdaDurableServiceClient::from_aws_config(
749 &aws_config::SdkConfig::builder().build(),
750 )
751 .map_err(|e| {
752 TestError::aws_error(format!("Failed to create service client: {}", e))
753 })?;
754 Ok(LambdaHistoryApiClient::from_service_client(service_client))
755 }
756 }
757 }
758}
759
760impl<O> CloudDurableTestRunner<O>
761where
762 O: DeserializeOwned + Send,
763{
764 // =========================================================================
765 // Operation Handle Methods (Requirements 5.1, 5.2, 5.3, 5.4)
766 // =========================================================================
767
768 /// Returns a lazy `OperationHandle` that populates with the first operation
769 /// matching the given name during execution.
770 ///
771 /// # Arguments
772 ///
773 /// * `name` - The operation name to match against
774 /// # Examples
775 ///
776 /// ```ignore
777 /// let handle = runner.get_operation_handle("my-callback");
778 /// // handle is unpopulated until run() executes and produces a matching operation
779 /// ```
780 pub fn get_operation_handle(&mut self, name: &str) -> OperationHandle {
781 let handle = OperationHandle::new(
782 OperationMatcher::ByName(name.to_string()),
783 self.all_operations.clone(),
784 );
785 self.handles.push(handle.clone());
786 handle
787 }
788
789 /// Returns a lazy `OperationHandle` that populates with the operation
790 /// at the given execution order index.
791 ///
792 /// # Arguments
793 ///
794 /// * `index` - The zero-based execution order index
795 /// # Examples
796 ///
797 /// ```ignore
798 /// let handle = runner.get_operation_handle_by_index(0);
799 /// // handle populates with the first operation created during execution
800 /// ```
801 pub fn get_operation_handle_by_index(&mut self, index: usize) -> OperationHandle {
802 let handle = OperationHandle::new(
803 OperationMatcher::ByIndex(index),
804 self.all_operations.clone(),
805 );
806 self.handles.push(handle.clone());
807 handle
808 }
809
810 /// Returns a lazy `OperationHandle` that populates with the nth operation
811 /// matching the given name during execution.
812 ///
813 /// # Arguments
814 ///
815 /// * `name` - The operation name to match against
816 /// * `index` - The zero-based index among operations with that name
817 /// # Examples
818 ///
819 /// ```ignore
820 /// let handle = runner.get_operation_handle_by_name_and_index("process", 1);
821 /// // handle populates with the second "process" operation during execution
822 /// ```
823 pub fn get_operation_handle_by_name_and_index(
824 &mut self,
825 name: &str,
826 index: usize,
827 ) -> OperationHandle {
828 let handle = OperationHandle::new(
829 OperationMatcher::ByNameAndIndex(name.to_string(), index),
830 self.all_operations.clone(),
831 );
832 self.handles.push(handle.clone());
833 handle
834 }
835
836 /// Returns a lazy `OperationHandle` that populates with the operation
837 /// matching the given unique ID.
838 ///
839 /// # Arguments
840 ///
841 /// * `id` - The unique operation ID to match against
842 /// # Examples
843 ///
844 /// ```ignore
845 /// let handle = runner.get_operation_handle_by_id("op-abc-123");
846 /// // handle populates with the operation whose ID matches during execution
847 /// ```
848 pub fn get_operation_handle_by_id(&mut self, id: &str) -> OperationHandle {
849 let handle = OperationHandle::new(
850 OperationMatcher::ById(id.to_string()),
851 self.all_operations.clone(),
852 );
853 self.handles.push(handle.clone());
854 handle
855 }
856
857 /// Notifies all registered operation handles with matching operation data
858 /// from the operation storage, and updates the shared `all_operations` list.
859 ///
860 /// All matched operations are collected first, then the shared `all_operations`
861 /// list and individual handles are updated together to ensure consistency.
862 pub(crate) async fn notify_handles(&self) {
863 // Phase 1: Collect all matched operations (no async locks held)
864 let matched: Vec<_> = self
865 .handles
866 .iter()
867 .map(|handle| {
868 let matched_op = match &handle.matcher {
869 OperationMatcher::ByName(name) => {
870 self.operation_storage.get_by_name(name).cloned()
871 }
872 OperationMatcher::ByIndex(idx) => {
873 self.operation_storage.get_by_index(*idx).cloned()
874 }
875 OperationMatcher::ById(id) => self.operation_storage.get_by_id(id).cloned(),
876 OperationMatcher::ByNameAndIndex(name, idx) => self
877 .operation_storage
878 .get_by_name_and_index(name, *idx)
879 .cloned(),
880 };
881 (handle, matched_op)
882 })
883 .collect();
884
885 // Phase 2: Update shared all_operations first, then handles atomically
886 let mut all_ops = self.all_operations.write().await;
887 *all_ops = self.operation_storage.get_all().to_vec();
888 drop(all_ops);
889
890 // Phase 3: Update handles and send notifications
891 for (handle, matched_op) in matched {
892 if let Some(op) = matched_op {
893 let status = op.status;
894 let mut inner = handle.inner.write().await;
895 *inner = Some(op);
896 drop(inner);
897 let _ = handle.status_tx.send(Some(status));
898 }
899 }
900 }
901}
902
903impl<O> CloudDurableTestRunner<O>
904where
905 O: DeserializeOwned + Send,
906{
907 // =========================================================================
908 // Operation Lookup Methods (Requirements 4.1, 4.2, 4.3, 4.4)
909 // =========================================================================
910
911 /// Gets the first operation with the given name.
912 ///
913 /// # Arguments
914 ///
915 /// * `name` - The operation name to search for
916 ///
917 /// # Returns
918 ///
919 /// A `DurableOperation` wrapping the first operation with that name,
920 /// or `None` if no operation with that name exists.
921 /// # Examples
922 ///
923 /// ```ignore
924 /// use durable_execution_sdk_testing::CloudDurableTestRunner;
925 ///
926 /// let mut runner = CloudDurableTestRunner::<String>::new("my-function")
927 /// .await
928 /// .unwrap();
929 /// let _ = runner.run("input".to_string()).await.unwrap();
930 ///
931 /// if let Some(op) = runner.get_operation("process_data") {
932 /// println!("Found operation: {:?}", op.get_status());
933 /// }
934 /// ```
935 pub fn get_operation(&self, name: &str) -> Option<DurableOperation> {
936 let all_ops = self.cached_all_operations();
937 self.operation_storage
938 .get_by_name(name)
939 .cloned()
940 .map(|op| DurableOperation::new(op).with_operations(Arc::clone(&all_ops)))
941 }
942
943 /// Gets an operation by its index in the execution order.
944 ///
945 /// # Arguments
946 ///
947 /// * `index` - The zero-based index of the operation
948 ///
949 /// # Returns
950 ///
951 /// A `DurableOperation` at that index, or `None` if the index is out of bounds.
952 /// # Examples
953 ///
954 /// ```ignore
955 /// use durable_execution_sdk_testing::CloudDurableTestRunner;
956 ///
957 /// let mut runner = CloudDurableTestRunner::<String>::new("my-function")
958 /// .await
959 /// .unwrap();
960 /// let _ = runner.run("input".to_string()).await.unwrap();
961 ///
962 /// // Get the first operation
963 /// if let Some(op) = runner.get_operation_by_index(0) {
964 /// println!("First operation: {:?}", op.get_type());
965 /// }
966 /// ```
967 pub fn get_operation_by_index(&self, index: usize) -> Option<DurableOperation> {
968 let all_ops = self.cached_all_operations();
969 self.operation_storage
970 .get_by_index(index)
971 .cloned()
972 .map(|op| DurableOperation::new(op).with_operations(Arc::clone(&all_ops)))
973 }
974
975 /// Gets an operation by name and occurrence index.
976 ///
977 /// This is useful when multiple operations have the same name and you need
978 /// to access a specific occurrence.
979 ///
980 /// # Arguments
981 ///
982 /// * `name` - The operation name to search for
983 /// * `index` - The zero-based index among operations with that name
984 ///
985 /// # Returns
986 ///
987 /// A `DurableOperation` at that name/index combination, or `None` if not found.
988 /// # Examples
989 ///
990 /// ```ignore
991 /// use durable_execution_sdk_testing::CloudDurableTestRunner;
992 ///
993 /// let mut runner = CloudDurableTestRunner::<String>::new("my-function")
994 /// .await
995 /// .unwrap();
996 /// let _ = runner.run("input".to_string()).await.unwrap();
997 ///
998 /// // Get the second "process" operation
999 /// if let Some(op) = runner.get_operation_by_name_and_index("process", 1) {
1000 /// println!("Second process operation: {:?}", op.get_status());
1001 /// }
1002 /// ```
1003 pub fn get_operation_by_name_and_index(
1004 &self,
1005 name: &str,
1006 index: usize,
1007 ) -> Option<DurableOperation> {
1008 let all_ops = self.cached_all_operations();
1009 self.operation_storage
1010 .get_by_name_and_index(name, index)
1011 .cloned()
1012 .map(|op| DurableOperation::new(op).with_operations(Arc::clone(&all_ops)))
1013 }
1014
1015 /// Gets an operation by its unique ID.
1016 ///
1017 /// # Arguments
1018 ///
1019 /// * `id` - The unique operation ID
1020 ///
1021 /// # Returns
1022 ///
1023 /// A `DurableOperation` with that ID, or `None` if no operation with that ID exists.
1024 /// # Examples
1025 ///
1026 /// ```ignore
1027 /// use durable_execution_sdk_testing::CloudDurableTestRunner;
1028 ///
1029 /// let mut runner = CloudDurableTestRunner::<String>::new("my-function")
1030 /// .await
1031 /// .unwrap();
1032 /// let _ = runner.run("input".to_string()).await.unwrap();
1033 ///
1034 /// if let Some(op) = runner.get_operation_by_id("op-123") {
1035 /// println!("Found operation: {:?}", op.get_name());
1036 /// }
1037 /// ```
1038 pub fn get_operation_by_id(&self, id: &str) -> Option<DurableOperation> {
1039 let all_ops = self.cached_all_operations();
1040 self.operation_storage
1041 .get_by_id(id)
1042 .cloned()
1043 .map(|op| DurableOperation::new(op).with_operations(Arc::clone(&all_ops)))
1044 }
1045
1046 /// Gets all captured operations.
1047 ///
1048 /// # Returns
1049 ///
1050 /// A vector of all operations in execution order.
1051 ///
1052 /// # Examples
1053 ///
1054 /// ```ignore
1055 /// use durable_execution_sdk_testing::CloudDurableTestRunner;
1056 ///
1057 /// let mut runner = CloudDurableTestRunner::<String>::new("my-function")
1058 /// .await
1059 /// .unwrap();
1060 /// let _ = runner.run("input".to_string()).await.unwrap();
1061 ///
1062 /// let all_ops = runner.get_all_operations();
1063 /// println!("Total operations: {}", all_ops.len());
1064 /// ```
1065 pub fn get_all_operations(&self) -> Vec<DurableOperation> {
1066 let all_ops = self.cached_all_operations();
1067 self.operation_storage
1068 .get_all()
1069 .iter()
1070 .cloned()
1071 .map(|op| DurableOperation::new(op).with_operations(Arc::clone(&all_ops)))
1072 .collect()
1073 }
1074
1075 /// Creates a shared snapshot of all operations, avoiding repeated Vec clones
1076 /// across multiple lookup calls.
1077 fn cached_all_operations(&self) -> Arc<Vec<Operation>> {
1078 Arc::new(self.operation_storage.get_all().to_vec())
1079 }
1080
1081 /// Returns the number of captured operations.
1082 pub fn operation_count(&self) -> usize {
1083 self.operation_storage.operations.len()
1084 }
1085
1086 /// Clears all captured operations.
1087 ///
1088 /// This is useful when reusing the runner for multiple test runs.
1089 pub fn clear_operations(&mut self) {
1090 self.operation_storage.clear();
1091 }
1092}
1093
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an operation with the given id and type and assigns it a name.
    fn named_operation(id: &str, kind: OperationType, name: &str) -> Operation {
        let mut op = Operation::new(id, kind);
        op.name = Some(name.to_string());
        op
    }

    /// The default configuration polls every second and times out after five minutes.
    #[test]
    fn test_config_default() {
        let defaults = CloudTestRunnerConfig::default();

        assert_eq!(defaults.poll_interval, Duration::from_millis(1000));
        assert_eq!(defaults.timeout, Duration::from_secs(300));
    }

    /// Builder setters override both the poll interval and the timeout.
    #[test]
    fn test_config_builder() {
        let poll_interval = Duration::from_millis(500);
        let timeout = Duration::from_secs(60);

        let built = CloudTestRunnerConfig::new()
            .with_poll_interval(poll_interval)
            .with_timeout(timeout);

        assert_eq!(built.poll_interval, poll_interval);
        assert_eq!(built.timeout, timeout);
    }

    /// Exercises every lookup on `OperationStorage`, then `clear`.
    #[test]
    fn test_operation_storage() {
        let mut storage = OperationStorage::new();

        // Insert three operations; the first and last share a name.
        for (id, kind, name) in [
            ("op-001", OperationType::Step, "step1"),
            ("op-002", OperationType::Wait, "wait1"),
            ("op-003", OperationType::Step, "step1"),
        ] {
            storage.add_operation(named_operation(id, kind, name));
        }

        // Lookup by id: known ids resolve, unknown ids do not.
        for known_id in ["op-001", "op-002"] {
            assert!(storage.get_by_id(known_id).is_some());
        }
        assert!(storage.get_by_id("nonexistent").is_none());

        // Lookup by name yields the first operation with that name.
        let by_name = storage.get_by_name("step1").unwrap();
        assert_eq!(by_name.operation_id, "op-001");

        // Lookup by name and occurrence index yields the second "step1".
        let by_name_and_index = storage.get_by_name_and_index("step1", 1).unwrap();
        assert_eq!(by_name_and_index.operation_id, "op-003");

        // Lookup by position follows insertion order.
        let by_position = storage.get_by_index(0).unwrap();
        assert_eq!(by_position.operation_id, "op-001");

        // All three operations are reported.
        assert_eq!(storage.get_all().len(), 3);

        // Clearing empties the storage.
        storage.clear();
        assert_eq!(storage.get_all().len(), 0);
    }
}