Skip to main content

durable_execution_sdk/
client.rs

1//! Service client for the AWS Durable Execution SDK.
2//!
3//! This module defines the `DurableServiceClient` trait and provides
4//! a Lambda-based implementation for communicating with the AWS Lambda
5//! durable execution service using the CheckpointDurableExecution and
6//! GetDurableExecutionState REST APIs.
7
8use std::sync::Arc;
9use std::time::SystemTime;
10
11use async_trait::async_trait;
12use aws_credential_types::provider::ProvideCredentials;
13use aws_sigv4::http_request::{sign, SignableBody, SignableRequest, SigningSettings};
14use aws_sigv4::sign::v4;
15use serde::{Deserialize, Serialize};
16
17use crate::error::{AwsError, DurableError};
18use crate::operation::{Operation, OperationUpdate};
19
20/// Trait for communicating with the durable execution service.
21///
22/// This trait abstracts the communication layer, allowing for different
23/// implementations (e.g., Lambda client, mock client for testing).
24#[async_trait]
25pub trait DurableServiceClient: Send + Sync {
26    /// Sends a batch of checkpoint operations to the service.
27    ///
28    /// # Arguments
29    ///
30    /// * `durable_execution_arn` - The ARN of the durable execution
31    /// * `checkpoint_token` - The token for this checkpoint batch
32    /// * `operations` - The operations to checkpoint
33    ///
34    /// # Returns
35    ///
36    /// A new checkpoint token on success, or an error on failure.
37    async fn checkpoint(
38        &self,
39        durable_execution_arn: &str,
40        checkpoint_token: &str,
41        operations: Vec<OperationUpdate>,
42    ) -> Result<CheckpointResponse, DurableError>;
43
44    /// Retrieves additional operations for pagination.
45    ///
46    /// # Arguments
47    ///
48    /// * `durable_execution_arn` - The ARN of the durable execution
49    /// * `next_marker` - The pagination marker from the previous response
50    ///
51    /// # Returns
52    ///
53    /// A list of operations and an optional next marker for further pagination.
54    async fn get_operations(
55        &self,
56        durable_execution_arn: &str,
57        next_marker: &str,
58    ) -> Result<GetOperationsResponse, DurableError>;
59}
60
61/// Response from a checkpoint operation.
62#[derive(Debug, Clone, Serialize, Deserialize, Default)]
63pub struct CheckpointResponse {
64    /// The new checkpoint token to use for subsequent checkpoints
65    #[serde(rename = "CheckpointToken", default)]
66    pub checkpoint_token: String,
67
68    /// The new execution state containing updated operations
69    /// This includes service-generated values like CallbackDetails.CallbackId
70    #[serde(
71        rename = "NewExecutionState",
72        skip_serializing_if = "Option::is_none",
73        default
74    )]
75    pub new_execution_state: Option<NewExecutionState>,
76}
77
78impl CheckpointResponse {
79    /// Creates a new CheckpointResponse with just a checkpoint token.
80    pub fn new(checkpoint_token: impl Into<String>) -> Self {
81        Self {
82            checkpoint_token: checkpoint_token.into(),
83            new_execution_state: None,
84        }
85    }
86}
87
88/// New execution state returned from checkpoint operations.
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct NewExecutionState {
91    /// The updated operations with service-generated values
92    #[serde(rename = "Operations", default)]
93    pub operations: Vec<Operation>,
94
95    /// Marker for the next page of results, if any
96    #[serde(rename = "NextMarker", skip_serializing_if = "Option::is_none")]
97    pub next_marker: Option<String>,
98}
99
100impl NewExecutionState {
101    /// Finds an operation by its ID.
102    ///
103    /// # Arguments
104    ///
105    /// * `operation_id` - The ID of the operation to find
106    ///
107    /// # Returns
108    ///
109    /// A reference to the operation if found, None otherwise.
110    pub fn find_operation(&self, operation_id: &str) -> Option<&Operation> {
111        self.operations
112            .iter()
113            .find(|op| op.operation_id == operation_id)
114    }
115}
116
117/// Response from a get_operations call.
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct GetOperationsResponse {
120    /// The retrieved operations
121    #[serde(rename = "Operations")]
122    pub operations: Vec<Operation>,
123
124    /// Marker for the next page of results, if any
125    #[serde(rename = "NextMarker", skip_serializing_if = "Option::is_none")]
126    pub next_marker: Option<String>,
127}
128
/// Configuration for the Lambda durable service client.
///
/// Implements `PartialEq`/`Eq` so whole configurations can be compared directly,
/// e.g. in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LambdaClientConfig {
    /// AWS region, used for SigV4 signing and to build the default Lambda endpoint
    pub region: String,
    /// Optional endpoint URL that replaces the regional Lambda endpoint (e.g. for testing)
    pub endpoint_url: Option<String>,
}
137
138impl Default for LambdaClientConfig {
139    fn default() -> Self {
140        Self {
141            region: "us-east-1".to_string(),
142            endpoint_url: None,
143        }
144    }
145}
146
147impl LambdaClientConfig {
148    /// Creates a new LambdaClientConfig with the specified region.
149    pub fn with_region(region: impl Into<String>) -> Self {
150        Self {
151            region: region.into(),
152            endpoint_url: None,
153        }
154    }
155
156    /// Creates a new LambdaClientConfig from AWS SDK config.
157    pub fn from_aws_config(config: &aws_config::SdkConfig) -> Self {
158        Self {
159            region: config
160                .region()
161                .map(|r| r.to_string())
162                .unwrap_or_else(|| "us-east-1".to_string()),
163            endpoint_url: None,
164        }
165    }
166}
167
168/// Lambda-based implementation of the DurableServiceClient.
169///
170/// This client uses the AWS Lambda REST APIs (CheckpointDurableExecution and
171/// GetDurableExecutionState) to communicate with the durable execution service.
172pub struct LambdaDurableServiceClient {
173    /// HTTP client for making requests
174    http_client: reqwest::Client,
175    /// AWS credentials provider
176    credentials_provider: Arc<dyn ProvideCredentials>,
177    /// Configuration for the client
178    config: LambdaClientConfig,
179}
180
181impl LambdaDurableServiceClient {
182    /// Creates a new LambdaDurableServiceClient from AWS config.
183    pub async fn from_env() -> Self {
184        let aws_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
185            .load()
186            .await;
187        Self::from_aws_config(&aws_config)
188    }
189
190    /// Creates a new LambdaDurableServiceClient from AWS SDK config.
191    pub fn from_aws_config(aws_config: &aws_config::SdkConfig) -> Self {
192        let credentials_provider = aws_config
193            .credentials_provider()
194            .expect("No credentials provider configured")
195            .clone();
196
197        Self {
198            http_client: reqwest::Client::new(),
199            credentials_provider: Arc::from(credentials_provider),
200            config: LambdaClientConfig::from_aws_config(aws_config),
201        }
202    }
203
204    /// Creates a new LambdaDurableServiceClient from AWS SDK config with a custom
205    /// user-agent string appended to HTTP requests for SDK identification.
206    pub fn from_aws_config_with_user_agent(
207        aws_config: &aws_config::SdkConfig,
208        sdk_name: &str,
209        sdk_version: &str,
210    ) -> Self {
211        let credentials_provider = aws_config
212            .credentials_provider()
213            .expect("No credentials provider configured")
214            .clone();
215
216        let user_agent = format!("{}/{}", sdk_name, sdk_version);
217        let mut headers = reqwest::header::HeaderMap::new();
218        if let Ok(value) = reqwest::header::HeaderValue::from_str(&user_agent) {
219            headers.insert(reqwest::header::USER_AGENT, value);
220        }
221
222        let http_client = reqwest::Client::builder()
223            .default_headers(headers)
224            .build()
225            .unwrap_or_else(|_| reqwest::Client::new());
226
227        Self {
228            http_client,
229            credentials_provider: Arc::from(credentials_provider),
230            config: LambdaClientConfig::from_aws_config(aws_config),
231        }
232    }
233
234    /// Creates a new LambdaDurableServiceClient with custom configuration.
235    pub fn with_config(
236        credentials_provider: Arc<dyn ProvideCredentials>,
237        config: LambdaClientConfig,
238    ) -> Self {
239        Self {
240            http_client: reqwest::Client::new(),
241            credentials_provider,
242            config,
243        }
244    }
245
246    /// Creates a new LambdaDurableServiceClient with a Lambda client (for backward compatibility).
247    /// Note: This extracts the region from the Lambda client but uses direct HTTP calls.
248    ///
249    /// IMPORTANT: This method requires that the Lambda client was created with credentials.
250    /// If you're using this in a Lambda function, prefer using `from_env()` instead.
251    pub fn new(_lambda_client: aws_sdk_lambda::Client) -> Self {
252        // Note: We can't easily extract credentials from the Lambda client anymore
253        // as the credentials_provider() method is deprecated and returns None.
254        // This method is kept for backward compatibility but will panic if used.
255        // Users should use from_env() or from_aws_config() instead.
256        panic!(
257            "LambdaDurableServiceClient::new() is deprecated. \
258             Use LambdaDurableServiceClient::from_env() or from_aws_config() instead."
259        );
260    }
261
262    /// Returns the Lambda service endpoint URL.
263    fn endpoint_url(&self) -> String {
264        self.config
265            .endpoint_url
266            .clone()
267            .unwrap_or_else(|| format!("https://lambda.{}.amazonaws.com", self.config.region))
268    }
269
270    /// Signs an HTTP request using AWS SigV4 and returns the signed headers.
271    async fn sign_request(
272        &self,
273        method: &str,
274        uri: &str,
275        body: &[u8],
276    ) -> Result<Vec<(String, String)>, DurableError> {
277        let credentials = self
278            .credentials_provider
279            .provide_credentials()
280            .await
281            .map_err(|e| DurableError::Checkpoint {
282                message: format!("Failed to get AWS credentials: {}", e),
283                is_retriable: true,
284                aws_error: None,
285            })?;
286
287        let identity = credentials.into();
288        let signing_settings = SigningSettings::default();
289        let signing_params = v4::SigningParams::builder()
290            .identity(&identity)
291            .region(&self.config.region)
292            .name("lambda")
293            .time(SystemTime::now())
294            .settings(signing_settings)
295            .build()
296            .map_err(|e| DurableError::Checkpoint {
297                message: format!("Failed to build signing params: {}", e),
298                is_retriable: false,
299                aws_error: None,
300            })?;
301
302        let signable_request = SignableRequest::new(
303            method,
304            uri,
305            std::iter::empty::<(&str, &str)>(),
306            SignableBody::Bytes(body),
307        )
308        .map_err(|e| DurableError::Checkpoint {
309            message: format!("Failed to create signable request: {}", e),
310            is_retriable: false,
311            aws_error: None,
312        })?;
313
314        let (signing_instructions, _signature) = sign(signable_request, &signing_params.into())
315            .map_err(|e| DurableError::Checkpoint {
316                message: format!("Failed to sign request: {}", e),
317                is_retriable: false,
318                aws_error: None,
319            })?
320            .into_parts();
321
322        // Build a temporary HTTP request to apply signing instructions
323        let mut temp_request = http::Request::builder()
324            .method(method)
325            .uri(uri)
326            .body(())
327            .map_err(|e| DurableError::Checkpoint {
328                message: format!("Failed to build temp request: {}", e),
329                is_retriable: false,
330                aws_error: None,
331            })?;
332
333        signing_instructions.apply_to_request_http1x(&mut temp_request);
334
335        // Extract headers from the signed request
336        let headers: Vec<(String, String)> = temp_request
337            .headers()
338            .iter()
339            .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
340            .collect();
341
342        Ok(headers)
343    }
344}
345
346/// Request payload for checkpoint operations.
347#[derive(Debug, Clone, Serialize)]
348struct CheckpointRequestBody {
349    #[serde(rename = "CheckpointToken")]
350    checkpoint_token: String,
351    #[serde(rename = "Updates")]
352    updates: Vec<OperationUpdate>,
353}
354
355/// Request payload for get_operations.
356#[derive(Debug, Clone, Serialize)]
357struct GetOperationsRequestBody {
358    #[serde(rename = "NextMarker")]
359    next_marker: String,
360}
361
362#[async_trait]
363impl DurableServiceClient for LambdaDurableServiceClient {
364    async fn checkpoint(
365        &self,
366        durable_execution_arn: &str,
367        checkpoint_token: &str,
368        operations: Vec<OperationUpdate>,
369    ) -> Result<CheckpointResponse, DurableError> {
370        let request_body = CheckpointRequestBody {
371            checkpoint_token: checkpoint_token.to_string(),
372            updates: operations,
373        };
374
375        let body = serde_json::to_vec(&request_body).map_err(|e| DurableError::SerDes {
376            message: format!("Failed to serialize checkpoint request: {}", e),
377        })?;
378
379        // URL-encode the durable execution ARN for the path
380        let encoded_arn = urlencoding::encode(durable_execution_arn);
381        let uri = format!(
382            "{}/2025-12-01/durable-executions/{}/checkpoint",
383            self.endpoint_url(),
384            encoded_arn
385        );
386
387        // Sign the request
388        let signed_headers = self.sign_request("POST", &uri, &body).await?;
389
390        // Build and send the request
391        let mut request = self
392            .http_client
393            .post(&uri)
394            .header("Content-Type", "application/json")
395            .body(body);
396
397        for (name, value) in signed_headers {
398            request = request.header(&name, &value);
399        }
400
401        let response = request.send().await.map_err(|e| DurableError::Checkpoint {
402            message: format!("HTTP request failed: {}", e),
403            is_retriable: e.is_timeout() || e.is_connect(),
404            aws_error: None,
405        })?;
406
407        let status = response.status();
408        let response_body = response
409            .bytes()
410            .await
411            .map_err(|e| DurableError::Checkpoint {
412                message: format!("Failed to read response body: {}", e),
413                is_retriable: true,
414                aws_error: None,
415            })?;
416
417        if !status.is_success() {
418            let error_message = String::from_utf8_lossy(&response_body);
419
420            // Check for specific error conditions
421
422            // Size limit exceeded (413 Request Entity Too Large or specific error message)
423            // Requirements: 25.6 - Handle size limit errors gracefully, return FAILED without retry
424            if status.as_u16() == 413
425                || error_message.contains("RequestEntityTooLarge")
426                || error_message.contains("Payload size exceeded")
427                || error_message.contains("Request too large")
428            {
429                return Err(DurableError::SizeLimit {
430                    message: format!("Checkpoint payload size exceeded: {}", error_message),
431                    actual_size: None,
432                    max_size: None,
433                });
434            }
435
436            // Resource not found (404)
437            // Requirements: 18.6 - Handle ResourceNotFoundException appropriately
438            if status.as_u16() == 404
439                || error_message.contains("ResourceNotFoundException")
440                || error_message.contains("not found")
441            {
442                return Err(DurableError::ResourceNotFound {
443                    message: format!("Durable execution not found: {}", error_message),
444                    resource_id: Some(durable_execution_arn.to_string()),
445                });
446            }
447
448            // Throttling (429 Too Many Requests)
449            // Requirements: 18.5 - Handle ThrottlingException with appropriate retry behavior
450            if status.as_u16() == 429
451                || error_message.contains("ThrottlingException")
452                || error_message.contains("TooManyRequestsException")
453                || error_message.contains("Rate exceeded")
454            {
455                return Err(DurableError::Throttling {
456                    message: format!("Rate limit exceeded: {}", error_message),
457                    retry_after_ms: None, // Could parse Retry-After header if available
458                });
459            }
460
461            // Check for InvalidParameterValueException with "Invalid checkpoint token" message
462            // This indicates the token was already consumed or is invalid, and Lambda should retry
463            // Requirements: 2.11
464            let is_invalid_token = error_message.contains("Invalid checkpoint token");
465
466            // Determine if the error is retriable:
467            // - Server errors (5xx) are retriable
468            // - Invalid checkpoint token errors are retriable (Lambda will provide a fresh token)
469            // Note: Throttling (429) is handled separately above
470            let is_retriable = status.is_server_error() || is_invalid_token;
471
472            return Err(DurableError::Checkpoint {
473                message: format!("Checkpoint API returned {}: {}", status, error_message),
474                is_retriable,
475                aws_error: Some(AwsError {
476                    code: if is_invalid_token {
477                        "InvalidParameterValueException".to_string()
478                    } else {
479                        status.to_string()
480                    },
481                    message: error_message.to_string(),
482                    request_id: None,
483                }),
484            });
485        }
486
487        let checkpoint_response: CheckpointResponse = serde_json::from_slice(&response_body)
488            .map_err(|e| DurableError::SerDes {
489                message: format!("Failed to deserialize checkpoint response: {}", e),
490            })?;
491
492        Ok(checkpoint_response)
493    }
494
495    async fn get_operations(
496        &self,
497        durable_execution_arn: &str,
498        next_marker: &str,
499    ) -> Result<GetOperationsResponse, DurableError> {
500        let request_body = GetOperationsRequestBody {
501            next_marker: next_marker.to_string(),
502        };
503
504        let body = serde_json::to_vec(&request_body).map_err(|e| DurableError::SerDes {
505            message: format!("Failed to serialize get_operations request: {}", e),
506        })?;
507
508        // URL-encode the durable execution ARN for the path
509        let encoded_arn = urlencoding::encode(durable_execution_arn);
510        let uri = format!(
511            "{}/2025-12-01/durable-executions/{}/state",
512            self.endpoint_url(),
513            encoded_arn
514        );
515
516        // Sign the request
517        let signed_headers = self.sign_request("POST", &uri, &body).await?;
518
519        // Build and send the request
520        let mut request = self
521            .http_client
522            .post(&uri)
523            .header("Content-Type", "application/json")
524            .body(body);
525
526        for (name, value) in signed_headers {
527            request = request.header(&name, &value);
528        }
529
530        let response = request.send().await.map_err(|e| DurableError::Invocation {
531            message: format!("HTTP request failed: {}", e),
532            termination_reason: crate::error::TerminationReason::InvocationError,
533        })?;
534
535        let status = response.status();
536        let response_body = response
537            .bytes()
538            .await
539            .map_err(|e| DurableError::Invocation {
540                message: format!("Failed to read response body: {}", e),
541                termination_reason: crate::error::TerminationReason::InvocationError,
542            })?;
543
544        if !status.is_success() {
545            let error_message = String::from_utf8_lossy(&response_body);
546
547            // Resource not found (404)
548            // Requirements: 18.6 - Handle ResourceNotFoundException appropriately
549            if status.as_u16() == 404
550                || error_message.contains("ResourceNotFoundException")
551                || error_message.contains("not found")
552            {
553                return Err(DurableError::ResourceNotFound {
554                    message: format!("Durable execution not found: {}", error_message),
555                    resource_id: Some(durable_execution_arn.to_string()),
556                });
557            }
558
559            // Throttling (429 Too Many Requests)
560            // Requirements: 18.5 - Handle ThrottlingException with appropriate retry behavior
561            if status.as_u16() == 429
562                || error_message.contains("ThrottlingException")
563                || error_message.contains("TooManyRequestsException")
564                || error_message.contains("Rate exceeded")
565            {
566                return Err(DurableError::Throttling {
567                    message: format!("Rate limit exceeded: {}", error_message),
568                    retry_after_ms: None,
569                });
570            }
571
572            return Err(DurableError::Invocation {
573                message: format!(
574                    "GetDurableExecutionState API returned {}: {}",
575                    status, error_message
576                ),
577                termination_reason: crate::error::TerminationReason::InvocationError,
578            });
579        }
580
581        let operations_response: GetOperationsResponse = serde_json::from_slice(&response_body)
582            .map_err(|e| DurableError::SerDes {
583                message: format!("Failed to deserialize get_operations response: {}", e),
584            })?;
585
586        Ok(operations_response)
587    }
588}
589
/// Scriptable in-memory [`DurableServiceClient`] for unit tests.
///
/// Results are queued per method and handed out oldest-first. An empty queue
/// falls back to a fixed default success.
#[cfg(test)]
pub struct MockDurableServiceClient {
    /// Queued results for successive `checkpoint` calls
    checkpoint_responses: std::sync::Mutex<Vec<Result<CheckpointResponse, DurableError>>>,
    /// Queued results for successive `get_operations` calls
    get_operations_responses: std::sync::Mutex<Vec<Result<GetOperationsResponse, DurableError>>>,
}
596
#[cfg(test)]
impl MockDurableServiceClient {
    /// Creates a mock with both response queues empty.
    pub fn new() -> Self {
        Self {
            checkpoint_responses: std::sync::Mutex::new(Vec::new()),
            get_operations_responses: std::sync::Mutex::new(Vec::new()),
        }
    }

    /// Queues one result for a future `checkpoint` call (consumed oldest-first).
    pub fn with_checkpoint_response(
        self,
        response: Result<CheckpointResponse, DurableError>,
    ) -> Self {
        self.checkpoint_responses.lock().unwrap().push(response);
        self
    }

    /// Queues one result for a future `get_operations` call (consumed oldest-first).
    pub fn with_get_operations_response(
        self,
        response: Result<GetOperationsResponse, DurableError>,
    ) -> Self {
        self.get_operations_responses.lock().unwrap().push(response);
        self
    }

    /// Adds multiple checkpoint responses at once.
    ///
    /// Appends `count` successful responses with tokens `token-0`, `token-1`, …,
    /// each with no execution state.
    pub fn with_checkpoint_responses(self, count: usize) -> Self {
        // The lock guard is a temporary, released at the end of this statement,
        // before `self` is returned.
        self.checkpoint_responses
            .lock()
            .unwrap()
            .extend((0..count).map(|i| {
                Ok(CheckpointResponse {
                    checkpoint_token: format!("token-{}", i),
                    new_execution_state: None,
                })
            }));
        self
    }
}

/// Equivalent to [`MockDurableServiceClient::new`] (satisfies clippy `new_without_default`).
#[cfg(test)]
impl Default for MockDurableServiceClient {
    fn default() -> Self {
        Self::new()
    }
}
635
#[cfg(test)]
#[async_trait]
impl DurableServiceClient for MockDurableServiceClient {
    /// Returns the oldest queued checkpoint result, or `Ok` with token
    /// `"mock-token"` when nothing is queued. All arguments are ignored.
    async fn checkpoint(
        &self,
        _durable_execution_arn: &str,
        _checkpoint_token: &str,
        _operations: Vec<OperationUpdate>,
    ) -> Result<CheckpointResponse, DurableError> {
        let mut queued = self.checkpoint_responses.lock().unwrap();
        if !queued.is_empty() {
            return queued.remove(0);
        }
        Ok(CheckpointResponse {
            checkpoint_token: "mock-token".to_string(),
            new_execution_state: None,
        })
    }

    /// Returns the oldest queued get_operations result, or an empty final page
    /// when nothing is queued. All arguments are ignored.
    async fn get_operations(
        &self,
        _durable_execution_arn: &str,
        _next_marker: &str,
    ) -> Result<GetOperationsResponse, DurableError> {
        let mut queued = self.get_operations_responses.lock().unwrap();
        if !queued.is_empty() {
            return queued.remove(0);
        }
        Ok(GetOperationsResponse {
            operations: Vec::new(),
            next_marker: None,
        })
    }
}
672
673/// Type alias for a shared DurableServiceClient.
674pub type SharedDurableServiceClient = Arc<dyn DurableServiceClient>;
675
676#[cfg(test)]
677mod tests {
678    use super::*;
679    use crate::operation::OperationType;
680
681    #[test]
682    fn test_checkpoint_response_serialization() {
683        let response = CheckpointResponse {
684            checkpoint_token: "token-123".to_string(),
685            new_execution_state: None,
686        };
687        let json = serde_json::to_string(&response).unwrap();
688        assert!(json.contains(r#""CheckpointToken":"token-123""#));
689    }
690
691    #[test]
692    fn test_checkpoint_response_deserialization() {
693        let json = r#"{"CheckpointToken": "token-456"}"#;
694        let response: CheckpointResponse = serde_json::from_str(json).unwrap();
695        assert_eq!(response.checkpoint_token, "token-456");
696    }
697
698    #[test]
699    fn test_get_operations_response_serialization() {
700        let response = GetOperationsResponse {
701            operations: vec![Operation::new("op-1", OperationType::Step)],
702            next_marker: Some("marker-123".to_string()),
703        };
704        let json = serde_json::to_string(&response).unwrap();
705        assert!(json.contains(r#""Operations""#));
706        assert!(json.contains(r#""NextMarker":"marker-123""#));
707    }
708
709    #[test]
710    fn test_get_operations_response_deserialization() {
711        let json = r#"{
712            "Operations": [
713                {
714                    "Id": "op-1",
715                    "Type": "STEP",
716                    "Status": "SUCCEEDED"
717                }
718            ],
719            "NextMarker": "marker-456"
720        }"#;
721        let response: GetOperationsResponse = serde_json::from_str(json).unwrap();
722        assert_eq!(response.operations.len(), 1);
723        assert_eq!(response.operations[0].operation_id, "op-1");
724        assert_eq!(response.next_marker, Some("marker-456".to_string()));
725    }
726
727    #[test]
728    fn test_get_operations_response_without_marker() {
729        let json = r#"{
730            "Operations": []
731        }"#;
732        let response: GetOperationsResponse = serde_json::from_str(json).unwrap();
733        assert!(response.operations.is_empty());
734        assert!(response.next_marker.is_none());
735    }
736
737    #[test]
738    fn test_lambda_client_config_default() {
739        let config = LambdaClientConfig::default();
740        assert_eq!(config.region, "us-east-1");
741        assert!(config.endpoint_url.is_none());
742    }
743
744    #[test]
745    fn test_lambda_client_config_with_region() {
746        let config = LambdaClientConfig::with_region("us-west-2");
747        assert_eq!(config.region, "us-west-2");
748    }
749
750    #[tokio::test]
751    async fn test_mock_client_checkpoint() {
752        let client = MockDurableServiceClient::new();
753        let result = client
754            .checkpoint(
755                "arn:aws:lambda:us-east-1:123456789012:function:test",
756                "token-123",
757                vec![],
758            )
759            .await;
760        assert!(result.is_ok());
761        assert_eq!(result.unwrap().checkpoint_token, "mock-token");
762    }
763
764    #[tokio::test]
765    async fn test_mock_client_checkpoint_with_custom_response() {
766        let client =
767            MockDurableServiceClient::new().with_checkpoint_response(Ok(CheckpointResponse {
768                checkpoint_token: "custom-token".to_string(),
769                new_execution_state: None,
770            }));
771        let result = client
772            .checkpoint(
773                "arn:aws:lambda:us-east-1:123456789012:function:test",
774                "token-123",
775                vec![],
776            )
777            .await;
778        assert!(result.is_ok());
779        assert_eq!(result.unwrap().checkpoint_token, "custom-token");
780    }
781
782    #[tokio::test]
783    async fn test_mock_client_checkpoint_with_error() {
784        let client = MockDurableServiceClient::new()
785            .with_checkpoint_response(Err(DurableError::checkpoint_retriable("Test error")));
786        let result = client
787            .checkpoint(
788                "arn:aws:lambda:us-east-1:123456789012:function:test",
789                "token-123",
790                vec![],
791            )
792            .await;
793        assert!(result.is_err());
794        assert!(result.unwrap_err().is_retriable());
795    }
796
797    #[tokio::test]
798    async fn test_mock_client_get_operations() {
799        let client = MockDurableServiceClient::new();
800        let result = client
801            .get_operations(
802                "arn:aws:lambda:us-east-1:123456789012:function:test",
803                "marker-123",
804            )
805            .await;
806        assert!(result.is_ok());
807        let response = result.unwrap();
808        assert!(response.operations.is_empty());
809        assert!(response.next_marker.is_none());
810    }
811
812    #[tokio::test]
813    async fn test_mock_client_get_operations_with_custom_response() {
814        let client = MockDurableServiceClient::new().with_get_operations_response(Ok(
815            GetOperationsResponse {
816                operations: vec![Operation::new("op-1", OperationType::Step)],
817                next_marker: Some("next-marker".to_string()),
818            },
819        ));
820        let result = client
821            .get_operations(
822                "arn:aws:lambda:us-east-1:123456789012:function:test",
823                "marker-123",
824            )
825            .await;
826        assert!(result.is_ok());
827        let response = result.unwrap();
828        assert_eq!(response.operations.len(), 1);
829        assert_eq!(response.next_marker, Some("next-marker".to_string()));
830    }
831
832    #[test]
833    fn test_checkpoint_request_body_serialization() {
834        let request = CheckpointRequestBody {
835            checkpoint_token: "token-123".to_string(),
836            updates: vec![OperationUpdate::start("op-1", OperationType::Step)],
837        };
838        let json = serde_json::to_string(&request).unwrap();
839        assert!(json.contains(r#""CheckpointToken":"token-123""#));
840        assert!(json.contains(r#""Updates""#));
841    }
842
843    #[test]
844    fn test_get_operations_request_body_serialization() {
845        let request = GetOperationsRequestBody {
846            next_marker: "marker-123".to_string(),
847        };
848        let json = serde_json::to_string(&request).unwrap();
849        assert!(json.contains(r#""NextMarker":"marker-123""#));
850    }
851
852    #[tokio::test]
853    async fn test_mock_client_checkpoint_with_invalid_token_error() {
854        // Test that invalid checkpoint token errors are properly marked as retriable
855        let error = DurableError::Checkpoint {
856            message: "Checkpoint API returned 400: Invalid checkpoint token".to_string(),
857            is_retriable: true,
858            aws_error: Some(AwsError {
859                code: "InvalidParameterValueException".to_string(),
860                message: "Invalid checkpoint token: token has been consumed".to_string(),
861                request_id: None,
862            }),
863        };
864
865        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
866        let result = client
867            .checkpoint(
868                "arn:aws:lambda:us-east-1:123456789012:function:test",
869                "consumed-token",
870                vec![],
871            )
872            .await;
873
874        assert!(result.is_err());
875        let err = result.unwrap_err();
876        assert!(err.is_retriable());
877        assert!(err.is_invalid_checkpoint_token());
878    }
879
880    #[tokio::test]
881    async fn test_mock_client_checkpoint_with_size_limit_error() {
882        // Test that size limit errors are properly returned as non-retriable
883        let error = DurableError::SizeLimit {
884            message: "Checkpoint payload size exceeded".to_string(),
885            actual_size: Some(7_000_000),
886            max_size: Some(6_000_000),
887        };
888
889        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
890        let result = client
891            .checkpoint(
892                "arn:aws:lambda:us-east-1:123456789012:function:test",
893                "token-123",
894                vec![],
895            )
896            .await;
897
898        assert!(result.is_err());
899        let err = result.unwrap_err();
900        assert!(err.is_size_limit());
901        assert!(!err.is_retriable());
902    }
903
904    #[tokio::test]
905    async fn test_mock_client_checkpoint_with_throttling_error() {
906        // Test that throttling errors are properly returned
907        let error = DurableError::Throttling {
908            message: "Rate limit exceeded".to_string(),
909            retry_after_ms: Some(5000),
910        };
911
912        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
913        let result = client
914            .checkpoint(
915                "arn:aws:lambda:us-east-1:123456789012:function:test",
916                "token-123",
917                vec![],
918            )
919            .await;
920
921        assert!(result.is_err());
922        let err = result.unwrap_err();
923        assert!(err.is_throttling());
924        assert_eq!(err.get_retry_after_ms(), Some(5000));
925    }
926
927    #[tokio::test]
928    async fn test_mock_client_checkpoint_with_resource_not_found_error() {
929        // Test that resource not found errors are properly returned
930        let error = DurableError::ResourceNotFound {
931            message: "Durable execution not found".to_string(),
932            resource_id: Some("arn:aws:lambda:us-east-1:123456789012:function:test".to_string()),
933        };
934
935        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
936        let result = client
937            .checkpoint(
938                "arn:aws:lambda:us-east-1:123456789012:function:test",
939                "token-123",
940                vec![],
941            )
942            .await;
943
944        assert!(result.is_err());
945        let err = result.unwrap_err();
946        assert!(err.is_resource_not_found());
947        assert!(!err.is_retriable());
948    }
949
950    #[tokio::test]
951    async fn test_mock_client_get_operations_with_throttling_error() {
952        // Test that throttling errors are properly returned for get_operations
953        let error = DurableError::Throttling {
954            message: "Rate limit exceeded".to_string(),
955            retry_after_ms: None,
956        };
957
958        let client = MockDurableServiceClient::new().with_get_operations_response(Err(error));
959        let result = client
960            .get_operations(
961                "arn:aws:lambda:us-east-1:123456789012:function:test",
962                "marker-123",
963            )
964            .await;
965
966        assert!(result.is_err());
967        let err = result.unwrap_err();
968        assert!(err.is_throttling());
969    }
970
971    #[tokio::test]
972    async fn test_mock_client_get_operations_with_resource_not_found_error() {
973        // Test that resource not found errors are properly returned for get_operations
974        let error = DurableError::ResourceNotFound {
975            message: "Durable execution not found".to_string(),
976            resource_id: Some("arn:aws:lambda:us-east-1:123456789012:function:test".to_string()),
977        };
978
979        let client = MockDurableServiceClient::new().with_get_operations_response(Err(error));
980        let result = client
981            .get_operations(
982                "arn:aws:lambda:us-east-1:123456789012:function:test",
983                "marker-123",
984            )
985            .await;
986
987        assert!(result.is_err());
988        let err = result.unwrap_err();
989        assert!(err.is_resource_not_found());
990    }
991
992    /// Helper to build a test SdkConfig with fake credentials.
993    fn test_sdk_config() -> aws_config::SdkConfig {
994        let creds =
995            aws_credential_types::Credentials::new("test-key", "test-secret", None, None, "test");
996        let provider = aws_credential_types::provider::SharedCredentialsProvider::new(creds);
997        aws_config::SdkConfig::builder()
998            .credentials_provider(provider)
999            .region(aws_config::Region::new("us-east-1"))
1000            .build()
1001    }
1002
1003    #[test]
1004    fn test_from_aws_config_with_user_agent_constructs_client() {
1005        // Verify from_aws_config_with_user_agent creates a client successfully (Req 14.1)
1006        let config = test_sdk_config();
1007        let client =
1008            LambdaDurableServiceClient::from_aws_config_with_user_agent(&config, "my-sdk", "1.0.0");
1009        // Client should be constructed with the correct region config
1010        assert_eq!(client.config.region, "us-east-1");
1011    }
1012
1013    #[test]
1014    fn test_from_aws_config_unchanged() {
1015        // Verify existing from_aws_config still works without user-agent (Req 14.2)
1016        let config = test_sdk_config();
1017        let client = LambdaDurableServiceClient::from_aws_config(&config);
1018        assert_eq!(client.config.region, "us-east-1");
1019    }
1020
1021    #[tokio::test]
1022    async fn test_user_agent_header_is_sent() {
1023        // Verify the user-agent header is actually included in HTTP requests (Req 14.1)
1024        use tokio::io::{AsyncReadExt, AsyncWriteExt};
1025        use tokio::net::TcpListener;
1026
1027        // Start a local TCP server that captures the User-Agent header
1028        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1029        let addr = listener.local_addr().unwrap();
1030
1031        let server_handle = tokio::spawn(async move {
1032            let (mut socket, _) = listener.accept().await.unwrap();
1033            let mut buf = vec![0u8; 4096];
1034            let n = socket.read(&mut buf).await.unwrap();
1035            let request = String::from_utf8_lossy(&buf[..n]).to_string();
1036
1037            // Send a minimal HTTP response
1038            let response = "HTTP/1.1 400 Bad Request\r\nContent-Length: 2\r\n\r\n{}";
1039            let _ = socket.write_all(response.as_bytes()).await;
1040            let _ = socket.shutdown().await;
1041
1042            request
1043        });
1044
1045        // Create client with custom user-agent pointing at our local server
1046        let sdk_name = "test-sdk";
1047        let sdk_version = "2.3.4";
1048        let user_agent = format!("{}/{}", sdk_name, sdk_version);
1049
1050        let config = test_sdk_config();
1051        let mut client = LambdaDurableServiceClient::from_aws_config_with_user_agent(
1052            &config,
1053            sdk_name,
1054            sdk_version,
1055        );
1056        client.config.endpoint_url = Some(format!("http://{}", addr));
1057
1058        // Make a request (it will fail, but we just need to capture the headers)
1059        let _ = client
1060            .checkpoint("arn:test:execution", "token", vec![])
1061            .await;
1062
1063        // Check the captured request contains our user-agent
1064        let captured_request = server_handle.await.unwrap();
1065        assert!(
1066            captured_request.contains(&user_agent),
1067            "Expected User-Agent '{}' in request headers, got:\n{}",
1068            user_agent,
1069            captured_request
1070        );
1071    }
1072}