Skip to main content

durable_execution_sdk/
client.rs

1//! Service client for the AWS Durable Execution SDK.
2//!
3//! This module defines the `DurableServiceClient` trait and provides
4//! a Lambda-based implementation for communicating with the AWS Lambda
5//! durable execution service using the CheckpointDurableExecution and
6//! GetDurableExecutionState REST APIs.
7
8use std::sync::Arc;
9use std::time::SystemTime;
10
11use async_trait::async_trait;
12use aws_credential_types::provider::ProvideCredentials;
13use aws_sigv4::http_request::{sign, SignableBody, SignableRequest, SigningSettings};
14use aws_sigv4::sign::v4;
15use serde::{Deserialize, Serialize};
16
17use crate::error::{AwsError, DurableError};
18use crate::operation::{Operation, OperationUpdate};
19
20/// Trait for communicating with the durable execution service.
21///
22/// This trait abstracts the communication layer, allowing for different
23/// implementations (e.g., Lambda client, mock client for testing).
24#[async_trait]
25pub trait DurableServiceClient: Send + Sync {
26    /// Sends a batch of checkpoint operations to the service.
27    ///
28    /// # Arguments
29    ///
30    /// * `durable_execution_arn` - The ARN of the durable execution
31    /// * `checkpoint_token` - The token for this checkpoint batch
32    /// * `operations` - The operations to checkpoint
33    ///
34    /// # Returns
35    ///
36    /// A new checkpoint token on success, or an error on failure.
37    async fn checkpoint(
38        &self,
39        durable_execution_arn: &str,
40        checkpoint_token: &str,
41        operations: Vec<OperationUpdate>,
42    ) -> Result<CheckpointResponse, DurableError>;
43
44    /// Retrieves additional operations for pagination.
45    ///
46    /// # Arguments
47    ///
48    /// * `durable_execution_arn` - The ARN of the durable execution
49    /// * `next_marker` - The pagination marker from the previous response
50    ///
51    /// # Returns
52    ///
53    /// A list of operations and an optional next marker for further pagination.
54    async fn get_operations(
55        &self,
56        durable_execution_arn: &str,
57        next_marker: &str,
58    ) -> Result<GetOperationsResponse, DurableError>;
59}
60
61/// Response from a checkpoint operation.
62#[derive(Debug, Clone, Serialize, Deserialize, Default)]
63pub struct CheckpointResponse {
64    /// The new checkpoint token to use for subsequent checkpoints
65    #[serde(rename = "CheckpointToken", default)]
66    pub checkpoint_token: String,
67
68    /// The new execution state containing updated operations
69    /// This includes service-generated values like CallbackDetails.CallbackId
70    #[serde(
71        rename = "NewExecutionState",
72        skip_serializing_if = "Option::is_none",
73        default
74    )]
75    pub new_execution_state: Option<NewExecutionState>,
76}
77
78impl CheckpointResponse {
79    /// Creates a new CheckpointResponse with just a checkpoint token.
80    pub fn new(checkpoint_token: impl Into<String>) -> Self {
81        Self {
82            checkpoint_token: checkpoint_token.into(),
83            new_execution_state: None,
84        }
85    }
86}
87
88/// New execution state returned from checkpoint operations.
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct NewExecutionState {
91    /// The updated operations with service-generated values
92    #[serde(rename = "Operations", default)]
93    pub operations: Vec<Operation>,
94
95    /// Marker for the next page of results, if any
96    #[serde(rename = "NextMarker", skip_serializing_if = "Option::is_none")]
97    pub next_marker: Option<String>,
98}
99
100impl NewExecutionState {
101    /// Finds an operation by its ID.
102    ///
103    /// # Arguments
104    ///
105    /// * `operation_id` - The ID of the operation to find
106    ///
107    /// # Returns
108    ///
109    /// A reference to the operation if found, None otherwise.
110    pub fn find_operation(&self, operation_id: &str) -> Option<&Operation> {
111        self.operations
112            .iter()
113            .find(|op| op.operation_id == operation_id)
114    }
115}
116
117/// Response from a get_operations call.
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct GetOperationsResponse {
120    /// The retrieved operations
121    #[serde(rename = "Operations")]
122    pub operations: Vec<Operation>,
123
124    /// Marker for the next page of results, if any
125    #[serde(rename = "NextMarker", skip_serializing_if = "Option::is_none")]
126    pub next_marker: Option<String>,
127}
128
/// Settings that control how the Lambda durable service client reaches the
/// service.
#[derive(Clone, Debug)]
pub struct LambdaClientConfig {
    /// AWS region used for the Lambda service.
    pub region: String,
    /// Custom endpoint URL override (for testing); `None` means no override.
    pub endpoint_url: Option<String>,
}
137
138impl Default for LambdaClientConfig {
139    fn default() -> Self {
140        Self {
141            region: "us-east-1".to_string(),
142            endpoint_url: None,
143        }
144    }
145}
146
147impl LambdaClientConfig {
148    /// Creates a new LambdaClientConfig with the specified region.
149    pub fn with_region(region: impl Into<String>) -> Self {
150        Self {
151            region: region.into(),
152            endpoint_url: None,
153        }
154    }
155
156    /// Creates a new LambdaClientConfig from AWS SDK config.
157    pub fn from_aws_config(config: &aws_config::SdkConfig) -> Self {
158        Self {
159            region: config
160                .region()
161                .map(|r| r.to_string())
162                .unwrap_or_else(|| "us-east-1".to_string()),
163            endpoint_url: None,
164        }
165    }
166}
167
168/// Lambda-based implementation of the DurableServiceClient.
169///
170/// This client uses the AWS Lambda REST APIs (CheckpointDurableExecution and
171/// GetDurableExecutionState) to communicate with the durable execution service.
172pub struct LambdaDurableServiceClient {
173    /// HTTP client for making requests
174    http_client: reqwest::Client,
175    /// AWS credentials provider
176    credentials_provider: Arc<dyn ProvideCredentials>,
177    /// Configuration for the client
178    config: LambdaClientConfig,
179}
180
181impl LambdaDurableServiceClient {
182    /// Creates a new LambdaDurableServiceClient from AWS config.
183    ///
184    /// # Errors
185    ///
186    /// Returns `DurableError::Configuration` if no credentials provider is configured.
187    pub async fn from_env() -> Result<Self, DurableError> {
188        let aws_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
189            .load()
190            .await;
191        Self::from_aws_config(&aws_config)
192    }
193
194    /// Creates a new LambdaDurableServiceClient from AWS SDK config.
195    ///
196    /// # Errors
197    ///
198    /// Returns `DurableError::Configuration` if no credentials provider is configured.
199    pub fn from_aws_config(aws_config: &aws_config::SdkConfig) -> Result<Self, DurableError> {
200        let credentials_provider = aws_config
201            .credentials_provider()
202            .ok_or_else(|| DurableError::Configuration {
203                message: "No credentials provider configured in AWS SDK config".to_string(),
204            })?
205            .clone();
206
207        Ok(Self {
208            http_client: reqwest::Client::new(),
209            credentials_provider: Arc::from(credentials_provider),
210            config: LambdaClientConfig::from_aws_config(aws_config),
211        })
212    }
213
214    /// Creates a new LambdaDurableServiceClient from AWS SDK config with a custom
215    /// user-agent string appended to HTTP requests for SDK identification.
216    ///
217    /// # Errors
218    ///
219    /// Returns `DurableError::Configuration` if no credentials provider is configured.
220    pub fn from_aws_config_with_user_agent(
221        aws_config: &aws_config::SdkConfig,
222        sdk_name: &str,
223        sdk_version: &str,
224    ) -> Result<Self, DurableError> {
225        let credentials_provider = aws_config
226            .credentials_provider()
227            .ok_or_else(|| DurableError::Configuration {
228                message: "No credentials provider configured in AWS SDK config".to_string(),
229            })?
230            .clone();
231
232        let user_agent = format!("{}/{}", sdk_name, sdk_version);
233        let mut headers = reqwest::header::HeaderMap::new();
234        if let Ok(value) = reqwest::header::HeaderValue::from_str(&user_agent) {
235            headers.insert(reqwest::header::USER_AGENT, value);
236        }
237
238        let http_client = reqwest::Client::builder()
239            .default_headers(headers)
240            .build()
241            .map_err(|e| DurableError::Configuration {
242                message: format!("Failed to build HTTP client: {}", e),
243            })?;
244
245        Ok(Self {
246            http_client,
247            credentials_provider: Arc::from(credentials_provider),
248            config: LambdaClientConfig::from_aws_config(aws_config),
249        })
250    }
251
252    /// Creates a new LambdaDurableServiceClient with custom configuration.
253    pub fn with_config(
254        credentials_provider: Arc<dyn ProvideCredentials>,
255        config: LambdaClientConfig,
256    ) -> Self {
257        Self {
258            http_client: reqwest::Client::new(),
259            credentials_provider,
260            config,
261        }
262    }
263
264    /// Returns the Lambda service endpoint URL.
265    fn endpoint_url(&self) -> String {
266        self.config
267            .endpoint_url
268            .clone()
269            .unwrap_or_else(|| format!("https://lambda.{}.amazonaws.com", self.config.region))
270    }
271
272    /// Signs an HTTP request using AWS SigV4 and returns the signed headers.
273    async fn sign_request(
274        &self,
275        method: &str,
276        uri: &str,
277        body: &[u8],
278    ) -> Result<Vec<(String, String)>, DurableError> {
279        let credentials = self
280            .credentials_provider
281            .provide_credentials()
282            .await
283            .map_err(|e| DurableError::Checkpoint {
284                message: format!("Failed to get AWS credentials: {}", e),
285                is_retriable: true,
286                aws_error: None,
287            })?;
288
289        let identity = credentials.into();
290        let signing_settings = SigningSettings::default();
291        let signing_params = v4::SigningParams::builder()
292            .identity(&identity)
293            .region(&self.config.region)
294            .name("lambda")
295            .time(SystemTime::now())
296            .settings(signing_settings)
297            .build()
298            .map_err(|e| DurableError::Checkpoint {
299                message: format!("Failed to build signing params: {}", e),
300                is_retriable: false,
301                aws_error: None,
302            })?;
303
304        let signable_request = SignableRequest::new(
305            method,
306            uri,
307            std::iter::empty::<(&str, &str)>(),
308            SignableBody::Bytes(body),
309        )
310        .map_err(|e| DurableError::Checkpoint {
311            message: format!("Failed to create signable request: {}", e),
312            is_retriable: false,
313            aws_error: None,
314        })?;
315
316        let (signing_instructions, _signature) = sign(signable_request, &signing_params.into())
317            .map_err(|e| DurableError::Checkpoint {
318                message: format!("Failed to sign request: {}", e),
319                is_retriable: false,
320                aws_error: None,
321            })?
322            .into_parts();
323
324        // Build a temporary HTTP request to apply signing instructions
325        let mut temp_request = http::Request::builder()
326            .method(method)
327            .uri(uri)
328            .body(())
329            .map_err(|e| DurableError::Checkpoint {
330                message: format!("Failed to build temp request: {}", e),
331                is_retriable: false,
332                aws_error: None,
333            })?;
334
335        signing_instructions.apply_to_request_http1x(&mut temp_request);
336
337        // Extract headers from the signed request
338        let headers: Vec<(String, String)> = temp_request
339            .headers()
340            .iter()
341            .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
342            .collect();
343
344        Ok(headers)
345    }
346}
347
348/// Request payload for checkpoint operations.
349#[derive(Debug, Clone, Serialize)]
350struct CheckpointRequestBody {
351    #[serde(rename = "CheckpointToken")]
352    checkpoint_token: String,
353    #[serde(rename = "Updates")]
354    updates: Vec<OperationUpdate>,
355}
356
357/// Request payload for get_operations.
358#[derive(Debug, Clone, Serialize)]
359struct GetOperationsRequestBody {
360    #[serde(rename = "NextMarker")]
361    next_marker: String,
362}
363
364#[async_trait]
365impl DurableServiceClient for LambdaDurableServiceClient {
366    async fn checkpoint(
367        &self,
368        durable_execution_arn: &str,
369        checkpoint_token: &str,
370        operations: Vec<OperationUpdate>,
371    ) -> Result<CheckpointResponse, DurableError> {
372        let request_body = CheckpointRequestBody {
373            checkpoint_token: checkpoint_token.to_string(),
374            updates: operations,
375        };
376
377        let body = serde_json::to_vec(&request_body).map_err(|e| DurableError::SerDes {
378            message: format!("Failed to serialize checkpoint request: {}", e),
379        })?;
380
381        // URL-encode the durable execution ARN for the path
382        let encoded_arn = urlencoding::encode(durable_execution_arn);
383        let uri = format!(
384            "{}/2025-12-01/durable-executions/{}/checkpoint",
385            self.endpoint_url(),
386            encoded_arn
387        );
388
389        // Sign the request
390        let signed_headers = self.sign_request("POST", &uri, &body).await?;
391
392        // Build and send the request
393        let mut request = self
394            .http_client
395            .post(&uri)
396            .header("Content-Type", "application/json")
397            .body(body);
398
399        for (name, value) in signed_headers {
400            request = request.header(&name, &value);
401        }
402
403        let response = request.send().await.map_err(|e| DurableError::Checkpoint {
404            message: format!("HTTP request failed: {}", e),
405            is_retriable: e.is_timeout() || e.is_connect(),
406            aws_error: None,
407        })?;
408
409        let status = response.status();
410        let response_body = response
411            .bytes()
412            .await
413            .map_err(|e| DurableError::Checkpoint {
414                message: format!("Failed to read response body: {}", e),
415                is_retriable: true,
416                aws_error: None,
417            })?;
418
419        if !status.is_success() {
420            let error_message = String::from_utf8_lossy(&response_body);
421
422            // Check for specific error conditions
423
424            // Size limit exceeded (413 Request Entity Too Large or specific error message)
425            // Requirements: 25.6 - Handle size limit errors gracefully, return FAILED without retry
426            if status.as_u16() == 413
427                || error_message.contains("RequestEntityTooLarge")
428                || error_message.contains("Payload size exceeded")
429                || error_message.contains("Request too large")
430            {
431                return Err(DurableError::SizeLimit {
432                    message: format!("Checkpoint payload size exceeded: {}", error_message),
433                    actual_size: None,
434                    max_size: None,
435                });
436            }
437
438            // Resource not found (404)
439            // Requirements: 18.6 - Handle ResourceNotFoundException appropriately
440            if status.as_u16() == 404
441                || error_message.contains("ResourceNotFoundException")
442                || error_message.contains("not found")
443            {
444                return Err(DurableError::ResourceNotFound {
445                    message: format!("Durable execution not found: {}", error_message),
446                    resource_id: Some(durable_execution_arn.to_string()),
447                });
448            }
449
450            // Throttling (429 Too Many Requests)
451            // Requirements: 18.5 - Handle ThrottlingException with appropriate retry behavior
452            if status.as_u16() == 429
453                || error_message.contains("ThrottlingException")
454                || error_message.contains("TooManyRequestsException")
455                || error_message.contains("Rate exceeded")
456            {
457                return Err(DurableError::Throttling {
458                    message: format!("Rate limit exceeded: {}", error_message),
459                    retry_after_ms: None, // Could parse Retry-After header if available
460                });
461            }
462
463            // Check for InvalidParameterValueException with "Invalid checkpoint token" message
464            // This indicates the token was already consumed or is invalid, and Lambda should retry
465            // Requirements: 2.11
466            let is_invalid_token = error_message.contains("Invalid checkpoint token");
467
468            // Determine if the error is retriable:
469            // - Server errors (5xx) are retriable
470            // - Invalid checkpoint token errors are retriable (Lambda will provide a fresh token)
471            // Note: Throttling (429) is handled separately above
472            let is_retriable = status.is_server_error() || is_invalid_token;
473
474            return Err(DurableError::Checkpoint {
475                message: format!("Checkpoint API returned {}: {}", status, error_message),
476                is_retriable,
477                aws_error: Some(AwsError {
478                    code: if is_invalid_token {
479                        "InvalidParameterValueException".to_string()
480                    } else {
481                        status.to_string()
482                    },
483                    message: error_message.to_string(),
484                    request_id: None,
485                }),
486            });
487        }
488
489        let checkpoint_response: CheckpointResponse = serde_json::from_slice(&response_body)
490            .map_err(|e| DurableError::SerDes {
491                message: format!("Failed to deserialize checkpoint response: {}", e),
492            })?;
493
494        Ok(checkpoint_response)
495    }
496
497    async fn get_operations(
498        &self,
499        durable_execution_arn: &str,
500        next_marker: &str,
501    ) -> Result<GetOperationsResponse, DurableError> {
502        let request_body = GetOperationsRequestBody {
503            next_marker: next_marker.to_string(),
504        };
505
506        let body = serde_json::to_vec(&request_body).map_err(|e| DurableError::SerDes {
507            message: format!("Failed to serialize get_operations request: {}", e),
508        })?;
509
510        // URL-encode the durable execution ARN for the path
511        let encoded_arn = urlencoding::encode(durable_execution_arn);
512        let uri = format!(
513            "{}/2025-12-01/durable-executions/{}/state",
514            self.endpoint_url(),
515            encoded_arn
516        );
517
518        // Sign the request
519        let signed_headers = self.sign_request("POST", &uri, &body).await?;
520
521        // Build and send the request
522        let mut request = self
523            .http_client
524            .post(&uri)
525            .header("Content-Type", "application/json")
526            .body(body);
527
528        for (name, value) in signed_headers {
529            request = request.header(&name, &value);
530        }
531
532        let response = request.send().await.map_err(|e| DurableError::Invocation {
533            message: format!("HTTP request failed: {}", e),
534            termination_reason: crate::error::TerminationReason::InvocationError,
535        })?;
536
537        let status = response.status();
538        let response_body = response
539            .bytes()
540            .await
541            .map_err(|e| DurableError::Invocation {
542                message: format!("Failed to read response body: {}", e),
543                termination_reason: crate::error::TerminationReason::InvocationError,
544            })?;
545
546        if !status.is_success() {
547            let error_message = String::from_utf8_lossy(&response_body);
548
549            // Resource not found (404)
550            // Requirements: 18.6 - Handle ResourceNotFoundException appropriately
551            if status.as_u16() == 404
552                || error_message.contains("ResourceNotFoundException")
553                || error_message.contains("not found")
554            {
555                return Err(DurableError::ResourceNotFound {
556                    message: format!("Durable execution not found: {}", error_message),
557                    resource_id: Some(durable_execution_arn.to_string()),
558                });
559            }
560
561            // Throttling (429 Too Many Requests)
562            // Requirements: 18.5 - Handle ThrottlingException with appropriate retry behavior
563            if status.as_u16() == 429
564                || error_message.contains("ThrottlingException")
565                || error_message.contains("TooManyRequestsException")
566                || error_message.contains("Rate exceeded")
567            {
568                return Err(DurableError::Throttling {
569                    message: format!("Rate limit exceeded: {}", error_message),
570                    retry_after_ms: None,
571                });
572            }
573
574            return Err(DurableError::Invocation {
575                message: format!(
576                    "GetDurableExecutionState API returned {}: {}",
577                    status, error_message
578                ),
579                termination_reason: crate::error::TerminationReason::InvocationError,
580            });
581        }
582
583        let operations_response: GetOperationsResponse = serde_json::from_slice(&response_body)
584            .map_err(|e| DurableError::SerDes {
585                message: format!("Failed to deserialize get_operations response: {}", e),
586            })?;
587
588        Ok(operations_response)
589    }
590}
591
/// Test double for [`DurableServiceClient`] that replays queued responses.
///
/// Each call consumes the oldest queued response for that method.
#[cfg(test)]
pub struct MockDurableServiceClient {
    /// Queued results for `checkpoint`, oldest first.
    checkpoint_responses: std::sync::Mutex<Vec<Result<CheckpointResponse, DurableError>>>,
    /// Queued results for `get_operations`, oldest first.
    get_operations_responses: std::sync::Mutex<Vec<Result<GetOperationsResponse, DurableError>>>,
}
598
#[cfg(test)]
impl Default for MockDurableServiceClient {
    /// Creates a mock with both response queues empty.
    fn default() -> Self {
        let checkpoint_responses = std::sync::Mutex::new(Vec::new());
        let get_operations_responses = std::sync::Mutex::new(Vec::new());
        Self {
            checkpoint_responses,
            get_operations_responses,
        }
    }
}
608
#[cfg(test)]
impl MockDurableServiceClient {
    /// Creates a mock with no queued responses.
    pub fn new() -> Self {
        Self::default()
    }

    /// Queues one result to be returned by a later `checkpoint` call.
    pub fn with_checkpoint_response(
        self,
        response: Result<CheckpointResponse, DurableError>,
    ) -> Self {
        {
            let mut queue = self.checkpoint_responses.lock().unwrap();
            queue.push(response);
        }
        self
    }

    /// Queues one result to be returned by a later `get_operations` call.
    pub fn with_get_operations_response(
        self,
        response: Result<GetOperationsResponse, DurableError>,
    ) -> Self {
        {
            let mut queue = self.get_operations_responses.lock().unwrap();
            queue.push(response);
        }
        self
    }

    /// Queues `count` successful checkpoint responses with tokens
    /// `token-0`, `token-1`, … and no execution state.
    pub fn with_checkpoint_responses(self, count: usize) -> Self {
        let generated = (0..count).map(|index| {
            Ok(CheckpointResponse {
                checkpoint_token: format!("token-{}", index),
                new_execution_state: None,
            })
        });
        // The lock guard is a temporary released at the end of this statement.
        self.checkpoint_responses.lock().unwrap().extend(generated);
        self
    }
}
644
#[cfg(test)]
#[async_trait]
impl DurableServiceClient for MockDurableServiceClient {
    /// Returns the oldest queued checkpoint result, or a response with the
    /// token `mock-token` when nothing is queued. Arguments are ignored.
    async fn checkpoint(
        &self,
        _durable_execution_arn: &str,
        _checkpoint_token: &str,
        _operations: Vec<OperationUpdate>,
    ) -> Result<CheckpointResponse, DurableError> {
        let mut queued = self.checkpoint_responses.lock().unwrap();
        if !queued.is_empty() {
            return queued.remove(0);
        }

        Ok(CheckpointResponse {
            checkpoint_token: "mock-token".to_string(),
            new_execution_state: None,
        })
    }

    /// Returns the oldest queued get_operations result, or an empty page with
    /// no next marker when nothing is queued. Arguments are ignored.
    async fn get_operations(
        &self,
        _durable_execution_arn: &str,
        _next_marker: &str,
    ) -> Result<GetOperationsResponse, DurableError> {
        let mut queued = self.get_operations_responses.lock().unwrap();
        if !queued.is_empty() {
            return queued.remove(0);
        }

        Ok(GetOperationsResponse {
            operations: Vec::new(),
            next_marker: None,
        })
    }
}
681
682/// Type alias for a shared DurableServiceClient.
683pub type SharedDurableServiceClient = Arc<dyn DurableServiceClient>;
684
685#[cfg(test)]
686mod tests {
687    use super::*;
688    use crate::operation::OperationType;
689
690    #[test]
691    fn test_checkpoint_response_serialization() {
692        let response = CheckpointResponse {
693            checkpoint_token: "token-123".to_string(),
694            new_execution_state: None,
695        };
696        let json = serde_json::to_string(&response).unwrap();
697        assert!(json.contains(r#""CheckpointToken":"token-123""#));
698    }
699
700    #[test]
701    fn test_checkpoint_response_deserialization() {
702        let json = r#"{"CheckpointToken": "token-456"}"#;
703        let response: CheckpointResponse = serde_json::from_str(json).unwrap();
704        assert_eq!(response.checkpoint_token, "token-456");
705    }
706
707    #[test]
708    fn test_get_operations_response_serialization() {
709        let response = GetOperationsResponse {
710            operations: vec![Operation::new("op-1", OperationType::Step)],
711            next_marker: Some("marker-123".to_string()),
712        };
713        let json = serde_json::to_string(&response).unwrap();
714        assert!(json.contains(r#""Operations""#));
715        assert!(json.contains(r#""NextMarker":"marker-123""#));
716    }
717
718    #[test]
719    fn test_get_operations_response_deserialization() {
720        let json = r#"{
721            "Operations": [
722                {
723                    "Id": "op-1",
724                    "Type": "STEP",
725                    "Status": "SUCCEEDED"
726                }
727            ],
728            "NextMarker": "marker-456"
729        }"#;
730        let response: GetOperationsResponse = serde_json::from_str(json).unwrap();
731        assert_eq!(response.operations.len(), 1);
732        assert_eq!(response.operations[0].operation_id, "op-1");
733        assert_eq!(response.next_marker, Some("marker-456".to_string()));
734    }
735
736    #[test]
737    fn test_get_operations_response_without_marker() {
738        let json = r#"{
739            "Operations": []
740        }"#;
741        let response: GetOperationsResponse = serde_json::from_str(json).unwrap();
742        assert!(response.operations.is_empty());
743        assert!(response.next_marker.is_none());
744    }
745
746    #[test]
747    fn test_lambda_client_config_default() {
748        let config = LambdaClientConfig::default();
749        assert_eq!(config.region, "us-east-1");
750        assert!(config.endpoint_url.is_none());
751    }
752
753    #[test]
754    fn test_lambda_client_config_with_region() {
755        let config = LambdaClientConfig::with_region("us-west-2");
756        assert_eq!(config.region, "us-west-2");
757    }
758
759    #[tokio::test]
760    async fn test_mock_client_checkpoint() {
761        let client = MockDurableServiceClient::new();
762        let result = client
763            .checkpoint(
764                "arn:aws:lambda:us-east-1:123456789012:function:test",
765                "token-123",
766                vec![],
767            )
768            .await;
769        assert!(result.is_ok());
770        assert_eq!(result.unwrap().checkpoint_token, "mock-token");
771    }
772
773    #[tokio::test]
774    async fn test_mock_client_checkpoint_with_custom_response() {
775        let client =
776            MockDurableServiceClient::new().with_checkpoint_response(Ok(CheckpointResponse {
777                checkpoint_token: "custom-token".to_string(),
778                new_execution_state: None,
779            }));
780        let result = client
781            .checkpoint(
782                "arn:aws:lambda:us-east-1:123456789012:function:test",
783                "token-123",
784                vec![],
785            )
786            .await;
787        assert!(result.is_ok());
788        assert_eq!(result.unwrap().checkpoint_token, "custom-token");
789    }
790
791    #[tokio::test]
792    async fn test_mock_client_checkpoint_with_error() {
793        let client = MockDurableServiceClient::new()
794            .with_checkpoint_response(Err(DurableError::checkpoint_retriable("Test error")));
795        let result = client
796            .checkpoint(
797                "arn:aws:lambda:us-east-1:123456789012:function:test",
798                "token-123",
799                vec![],
800            )
801            .await;
802        assert!(result.is_err());
803        assert!(result.unwrap_err().is_retriable());
804    }
805
806    #[tokio::test]
807    async fn test_mock_client_get_operations() {
808        let client = MockDurableServiceClient::new();
809        let result = client
810            .get_operations(
811                "arn:aws:lambda:us-east-1:123456789012:function:test",
812                "marker-123",
813            )
814            .await;
815        assert!(result.is_ok());
816        let response = result.unwrap();
817        assert!(response.operations.is_empty());
818        assert!(response.next_marker.is_none());
819    }
820
821    #[tokio::test]
822    async fn test_mock_client_get_operations_with_custom_response() {
823        let client = MockDurableServiceClient::new().with_get_operations_response(Ok(
824            GetOperationsResponse {
825                operations: vec![Operation::new("op-1", OperationType::Step)],
826                next_marker: Some("next-marker".to_string()),
827            },
828        ));
829        let result = client
830            .get_operations(
831                "arn:aws:lambda:us-east-1:123456789012:function:test",
832                "marker-123",
833            )
834            .await;
835        assert!(result.is_ok());
836        let response = result.unwrap();
837        assert_eq!(response.operations.len(), 1);
838        assert_eq!(response.next_marker, Some("next-marker".to_string()));
839    }
840
841    #[test]
842    fn test_checkpoint_request_body_serialization() {
843        let request = CheckpointRequestBody {
844            checkpoint_token: "token-123".to_string(),
845            updates: vec![OperationUpdate::start("op-1", OperationType::Step)],
846        };
847        let json = serde_json::to_string(&request).unwrap();
848        assert!(json.contains(r#""CheckpointToken":"token-123""#));
849        assert!(json.contains(r#""Updates""#));
850    }
851
852    #[test]
853    fn test_get_operations_request_body_serialization() {
854        let request = GetOperationsRequestBody {
855            next_marker: "marker-123".to_string(),
856        };
857        let json = serde_json::to_string(&request).unwrap();
858        assert!(json.contains(r#""NextMarker":"marker-123""#));
859    }
860
861    #[tokio::test]
862    async fn test_mock_client_checkpoint_with_invalid_token_error() {
863        // Test that invalid checkpoint token errors are properly marked as retriable
864        let error = DurableError::Checkpoint {
865            message: "Checkpoint API returned 400: Invalid checkpoint token".to_string(),
866            is_retriable: true,
867            aws_error: Some(AwsError {
868                code: "InvalidParameterValueException".to_string(),
869                message: "Invalid checkpoint token: token has been consumed".to_string(),
870                request_id: None,
871            }),
872        };
873
874        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
875        let result = client
876            .checkpoint(
877                "arn:aws:lambda:us-east-1:123456789012:function:test",
878                "consumed-token",
879                vec![],
880            )
881            .await;
882
883        assert!(result.is_err());
884        let err = result.unwrap_err();
885        assert!(err.is_retriable());
886        assert!(err.is_invalid_checkpoint_token());
887    }
888
889    #[tokio::test]
890    async fn test_mock_client_checkpoint_with_size_limit_error() {
891        // Test that size limit errors are properly returned as non-retriable
892        let error = DurableError::SizeLimit {
893            message: "Checkpoint payload size exceeded".to_string(),
894            actual_size: Some(7_000_000),
895            max_size: Some(6_000_000),
896        };
897
898        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
899        let result = client
900            .checkpoint(
901                "arn:aws:lambda:us-east-1:123456789012:function:test",
902                "token-123",
903                vec![],
904            )
905            .await;
906
907        assert!(result.is_err());
908        let err = result.unwrap_err();
909        assert!(err.is_size_limit());
910        assert!(!err.is_retriable());
911    }
912
913    #[tokio::test]
914    async fn test_mock_client_checkpoint_with_throttling_error() {
915        // Test that throttling errors are properly returned
916        let error = DurableError::Throttling {
917            message: "Rate limit exceeded".to_string(),
918            retry_after_ms: Some(5000),
919        };
920
921        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
922        let result = client
923            .checkpoint(
924                "arn:aws:lambda:us-east-1:123456789012:function:test",
925                "token-123",
926                vec![],
927            )
928            .await;
929
930        assert!(result.is_err());
931        let err = result.unwrap_err();
932        assert!(err.is_throttling());
933        assert_eq!(err.get_retry_after_ms(), Some(5000));
934    }
935
936    #[tokio::test]
937    async fn test_mock_client_checkpoint_with_resource_not_found_error() {
938        // Test that resource not found errors are properly returned
939        let error = DurableError::ResourceNotFound {
940            message: "Durable execution not found".to_string(),
941            resource_id: Some("arn:aws:lambda:us-east-1:123456789012:function:test".to_string()),
942        };
943
944        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
945        let result = client
946            .checkpoint(
947                "arn:aws:lambda:us-east-1:123456789012:function:test",
948                "token-123",
949                vec![],
950            )
951            .await;
952
953        assert!(result.is_err());
954        let err = result.unwrap_err();
955        assert!(err.is_resource_not_found());
956        assert!(!err.is_retriable());
957    }
958
959    #[tokio::test]
960    async fn test_mock_client_get_operations_with_throttling_error() {
961        // Test that throttling errors are properly returned for get_operations
962        let error = DurableError::Throttling {
963            message: "Rate limit exceeded".to_string(),
964            retry_after_ms: None,
965        };
966
967        let client = MockDurableServiceClient::new().with_get_operations_response(Err(error));
968        let result = client
969            .get_operations(
970                "arn:aws:lambda:us-east-1:123456789012:function:test",
971                "marker-123",
972            )
973            .await;
974
975        assert!(result.is_err());
976        let err = result.unwrap_err();
977        assert!(err.is_throttling());
978    }
979
980    #[tokio::test]
981    async fn test_mock_client_get_operations_with_resource_not_found_error() {
982        // Test that resource not found errors are properly returned for get_operations
983        let error = DurableError::ResourceNotFound {
984            message: "Durable execution not found".to_string(),
985            resource_id: Some("arn:aws:lambda:us-east-1:123456789012:function:test".to_string()),
986        };
987
988        let client = MockDurableServiceClient::new().with_get_operations_response(Err(error));
989        let result = client
990            .get_operations(
991                "arn:aws:lambda:us-east-1:123456789012:function:test",
992                "marker-123",
993            )
994            .await;
995
996        assert!(result.is_err());
997        let err = result.unwrap_err();
998        assert!(err.is_resource_not_found());
999    }
1000
1001    /// Helper to build a test SdkConfig with fake credentials.
1002    fn test_sdk_config() -> aws_config::SdkConfig {
1003        let creds =
1004            aws_credential_types::Credentials::new("test-key", "test-secret", None, None, "test");
1005        let provider = aws_credential_types::provider::SharedCredentialsProvider::new(creds);
1006        aws_config::SdkConfig::builder()
1007            .credentials_provider(provider)
1008            .region(aws_config::Region::new("us-east-1"))
1009            .build()
1010    }
1011
1012    #[test]
1013    fn test_from_aws_config_with_user_agent_constructs_client() {
1014        // Verify from_aws_config_with_user_agent creates a client successfully (Req 14.1)
1015        let config = test_sdk_config();
1016        let client =
1017            LambdaDurableServiceClient::from_aws_config_with_user_agent(&config, "my-sdk", "1.0.0")
1018                .unwrap();
1019        // Client should be constructed with the correct region config
1020        assert_eq!(client.config.region, "us-east-1");
1021    }
1022
1023    #[test]
1024    fn test_from_aws_config_unchanged() {
1025        // Verify existing from_aws_config still works without user-agent (Req 14.2)
1026        let config = test_sdk_config();
1027        let client = LambdaDurableServiceClient::from_aws_config(&config).unwrap();
1028        assert_eq!(client.config.region, "us-east-1");
1029    }
1030
1031    #[tokio::test]
1032    async fn test_user_agent_header_is_sent() {
1033        // Verify the user-agent header is actually included in HTTP requests (Req 14.1)
1034        use tokio::io::{AsyncReadExt, AsyncWriteExt};
1035        use tokio::net::TcpListener;
1036
1037        // Start a local TCP server that captures the User-Agent header
1038        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1039        let addr = listener.local_addr().unwrap();
1040
1041        let server_handle = tokio::spawn(async move {
1042            let (mut socket, _) = listener.accept().await.unwrap();
1043            let mut buf = vec![0u8; 4096];
1044            let n = socket.read(&mut buf).await.unwrap();
1045            let request = String::from_utf8_lossy(&buf[..n]).to_string();
1046
1047            // Send a minimal HTTP response
1048            let response = "HTTP/1.1 400 Bad Request\r\nContent-Length: 2\r\n\r\n{}";
1049            let _ = socket.write_all(response.as_bytes()).await;
1050            let _ = socket.shutdown().await;
1051
1052            request
1053        });
1054
1055        // Create client with custom user-agent pointing at our local server
1056        let sdk_name = "test-sdk";
1057        let sdk_version = "2.3.4";
1058        let user_agent = format!("{}/{}", sdk_name, sdk_version);
1059
1060        let config = test_sdk_config();
1061        let mut client = LambdaDurableServiceClient::from_aws_config_with_user_agent(
1062            &config,
1063            sdk_name,
1064            sdk_version,
1065        )
1066        .unwrap();
1067        client.config.endpoint_url = Some(format!("http://{}", addr));
1068
1069        // Make a request (it will fail, but we just need to capture the headers)
1070        let _ = client
1071            .checkpoint("arn:test:execution", "token", vec![])
1072            .await;
1073
1074        // Check the captured request contains our user-agent
1075        let captured_request = server_handle.await.unwrap();
1076        assert!(
1077            captured_request.contains(&user_agent),
1078            "Expected User-Agent '{}' in request headers, got:\n{}",
1079            user_agent,
1080            captured_request
1081        );
1082    }
1083}