Skip to main content

durable_execution_sdk/
client.rs

1//! Service client for the AWS Durable Execution SDK.
2//!
3//! This module defines the `DurableServiceClient` trait and provides
4//! a Lambda-based implementation for communicating with the AWS Lambda
5//! durable execution service using the CheckpointDurableExecution and
6//! GetDurableExecutionState REST APIs.
7
8use std::sync::Arc;
9use std::time::SystemTime;
10
11use async_trait::async_trait;
12use aws_credential_types::provider::ProvideCredentials;
13use aws_sigv4::http_request::{sign, SignableBody, SignableRequest, SigningSettings};
14use aws_sigv4::sign::v4;
15use serde::{Deserialize, Serialize};
16
17use crate::error::{AwsError, DurableError};
18use crate::operation::{Operation, OperationUpdate};
19
20/// Trait for communicating with the durable execution service.
21///
22/// This trait abstracts the communication layer, allowing for different
23/// implementations (e.g., Lambda client, mock client for testing).
24#[async_trait]
25pub trait DurableServiceClient: Send + Sync {
26    /// Sends a batch of checkpoint operations to the service.
27    ///
28    /// # Arguments
29    ///
30    /// * `durable_execution_arn` - The ARN of the durable execution
31    /// * `checkpoint_token` - The token for this checkpoint batch
32    /// * `operations` - The operations to checkpoint
33    ///
34    /// # Returns
35    ///
36    /// A new checkpoint token on success, or an error on failure.
37    async fn checkpoint(
38        &self,
39        durable_execution_arn: &str,
40        checkpoint_token: &str,
41        operations: Vec<OperationUpdate>,
42    ) -> Result<CheckpointResponse, DurableError>;
43
44    /// Retrieves additional operations for pagination.
45    ///
46    /// # Arguments
47    ///
48    /// * `durable_execution_arn` - The ARN of the durable execution
49    /// * `next_marker` - The pagination marker from the previous response
50    ///
51    /// # Returns
52    ///
53    /// A list of operations and an optional next marker for further pagination.
54    async fn get_operations(
55        &self,
56        durable_execution_arn: &str,
57        next_marker: &str,
58    ) -> Result<GetOperationsResponse, DurableError>;
59}
60
61/// Response from a checkpoint operation.
62#[derive(Debug, Clone, Serialize, Deserialize, Default)]
63pub struct CheckpointResponse {
64    /// The new checkpoint token to use for subsequent checkpoints
65    #[serde(rename = "CheckpointToken", default)]
66    pub checkpoint_token: String,
67
68    /// The new execution state containing updated operations
69    /// This includes service-generated values like CallbackDetails.CallbackId
70    #[serde(
71        rename = "NewExecutionState",
72        skip_serializing_if = "Option::is_none",
73        default
74    )]
75    pub new_execution_state: Option<NewExecutionState>,
76}
77
78impl CheckpointResponse {
79    /// Creates a new CheckpointResponse with just a checkpoint token.
80    pub fn new(checkpoint_token: impl Into<String>) -> Self {
81        Self {
82            checkpoint_token: checkpoint_token.into(),
83            new_execution_state: None,
84        }
85    }
86}
87
88/// New execution state returned from checkpoint operations.
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct NewExecutionState {
91    /// The updated operations with service-generated values
92    #[serde(rename = "Operations", default)]
93    pub operations: Vec<Operation>,
94
95    /// Marker for the next page of results, if any
96    #[serde(rename = "NextMarker", skip_serializing_if = "Option::is_none")]
97    pub next_marker: Option<String>,
98}
99
100impl NewExecutionState {
101    /// Finds an operation by its ID.
102    ///
103    /// # Arguments
104    ///
105    /// * `operation_id` - The ID of the operation to find
106    ///
107    /// # Returns
108    ///
109    /// A reference to the operation if found, None otherwise.
110    pub fn find_operation(&self, operation_id: &str) -> Option<&Operation> {
111        self.operations
112            .iter()
113            .find(|op| op.operation_id == operation_id)
114    }
115}
116
117/// Response from a get_operations call.
118#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct GetOperationsResponse {
120    /// The retrieved operations
121    #[serde(rename = "Operations")]
122    pub operations: Vec<Operation>,
123
124    /// Marker for the next page of results, if any
125    #[serde(rename = "NextMarker", skip_serializing_if = "Option::is_none")]
126    pub next_marker: Option<String>,
127}
128
129/// Configuration for the Lambda durable service client.
130#[derive(Debug, Clone)]
131pub struct LambdaClientConfig {
132    /// AWS region for the Lambda service
133    pub region: String,
134    /// Optional custom endpoint URL (for testing)
135    pub endpoint_url: Option<String>,
136}
137
138impl Default for LambdaClientConfig {
139    fn default() -> Self {
140        Self {
141            region: "us-east-1".to_string(),
142            endpoint_url: None,
143        }
144    }
145}
146
147impl LambdaClientConfig {
148    /// Creates a new LambdaClientConfig with the specified region.
149    pub fn with_region(region: impl Into<String>) -> Self {
150        Self {
151            region: region.into(),
152            endpoint_url: None,
153        }
154    }
155
156    /// Creates a new LambdaClientConfig from AWS SDK config.
157    pub fn from_aws_config(config: &aws_config::SdkConfig) -> Self {
158        Self {
159            region: config
160                .region()
161                .map(|r| r.to_string())
162                .unwrap_or_else(|| "us-east-1".to_string()),
163            endpoint_url: None,
164        }
165    }
166}
167
168/// Lambda-based implementation of the DurableServiceClient.
169///
170/// This client uses the AWS Lambda REST APIs (CheckpointDurableExecution and
171/// GetDurableExecutionState) to communicate with the durable execution service.
172pub struct LambdaDurableServiceClient {
173    /// HTTP client for making requests
174    http_client: reqwest::Client,
175    /// AWS credentials provider
176    credentials_provider: Arc<dyn ProvideCredentials>,
177    /// Configuration for the client
178    config: LambdaClientConfig,
179}
180
181impl LambdaDurableServiceClient {
182    /// Creates a new LambdaDurableServiceClient from AWS config.
183    ///
184    /// # Errors
185    ///
186    /// Returns `DurableError::Configuration` if no credentials provider is configured.
187    pub async fn from_env() -> Result<Self, DurableError> {
188        let aws_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
189            .load()
190            .await;
191        Self::from_aws_config(&aws_config)
192    }
193
194    /// Creates a new LambdaDurableServiceClient from AWS SDK config.
195    ///
196    /// # Errors
197    ///
198    /// Returns `DurableError::Configuration` if no credentials provider is configured.
199    pub fn from_aws_config(aws_config: &aws_config::SdkConfig) -> Result<Self, DurableError> {
200        let credentials_provider = aws_config
201            .credentials_provider()
202            .ok_or_else(|| DurableError::Configuration {
203                message: "No credentials provider configured in AWS SDK config".to_string(),
204            })?
205            .clone();
206
207        Ok(Self {
208            http_client: reqwest::Client::new(),
209            credentials_provider: Arc::from(credentials_provider),
210            config: LambdaClientConfig::from_aws_config(aws_config),
211        })
212    }
213
214    /// Creates a new LambdaDurableServiceClient from AWS SDK config with a custom
215    /// user-agent string appended to HTTP requests for SDK identification.
216    ///
217    /// # Errors
218    ///
219    /// Returns `DurableError::Configuration` if no credentials provider is configured.
220    pub fn from_aws_config_with_user_agent(
221        aws_config: &aws_config::SdkConfig,
222        sdk_name: &str,
223        sdk_version: &str,
224    ) -> Result<Self, DurableError> {
225        let credentials_provider = aws_config
226            .credentials_provider()
227            .ok_or_else(|| DurableError::Configuration {
228                message: "No credentials provider configured in AWS SDK config".to_string(),
229            })?
230            .clone();
231
232        let user_agent = format!("{}/{}", sdk_name, sdk_version);
233        let mut headers = reqwest::header::HeaderMap::new();
234        if let Ok(value) = reqwest::header::HeaderValue::from_str(&user_agent) {
235            headers.insert(reqwest::header::USER_AGENT, value);
236        }
237
238        let http_client = reqwest::Client::builder()
239            .default_headers(headers)
240            .build()
241            .map_err(|e| DurableError::Configuration {
242                message: format!("Failed to build HTTP client: {}", e),
243            })?;
244
245        Ok(Self {
246            http_client,
247            credentials_provider: Arc::from(credentials_provider),
248            config: LambdaClientConfig::from_aws_config(aws_config),
249        })
250    }
251
252    /// Creates a new LambdaDurableServiceClient with custom configuration.
253    pub fn with_config(
254        credentials_provider: Arc<dyn ProvideCredentials>,
255        config: LambdaClientConfig,
256    ) -> Self {
257        Self {
258            http_client: reqwest::Client::new(),
259            credentials_provider,
260            config,
261        }
262    }
263
264    /// Returns the Lambda service endpoint URL.
265    fn endpoint_url(&self) -> String {
266        self.config
267            .endpoint_url
268            .clone()
269            .unwrap_or_else(|| format!("https://lambda.{}.amazonaws.com", self.config.region))
270    }
271
272    /// Signs an HTTP request using AWS SigV4 and returns the signed headers.
273    async fn sign_request(
274        &self,
275        method: &str,
276        uri: &str,
277        body: &[u8],
278    ) -> Result<Vec<(String, String)>, DurableError> {
279        let credentials = self
280            .credentials_provider
281            .provide_credentials()
282            .await
283            .map_err(|e| DurableError::Checkpoint {
284                message: format!("Failed to get AWS credentials: {}", e),
285                is_retriable: true,
286                aws_error: None,
287            })?;
288
289        let identity = credentials.into();
290        let signing_settings = SigningSettings::default();
291        let signing_params = v4::SigningParams::builder()
292            .identity(&identity)
293            .region(&self.config.region)
294            .name("lambda")
295            .time(SystemTime::now())
296            .settings(signing_settings)
297            .build()
298            .map_err(|e| DurableError::Checkpoint {
299                message: format!("Failed to build signing params: {}", e),
300                is_retriable: false,
301                aws_error: None,
302            })?;
303
304        let signable_request = SignableRequest::new(
305            method,
306            uri,
307            std::iter::empty::<(&str, &str)>(),
308            SignableBody::Bytes(body),
309        )
310        .map_err(|e| DurableError::Checkpoint {
311            message: format!("Failed to create signable request: {}", e),
312            is_retriable: false,
313            aws_error: None,
314        })?;
315
316        let (signing_instructions, _signature) = sign(signable_request, &signing_params.into())
317            .map_err(|e| DurableError::Checkpoint {
318                message: format!("Failed to sign request: {}", e),
319                is_retriable: false,
320                aws_error: None,
321            })?
322            .into_parts();
323
324        // Build a temporary HTTP request to apply signing instructions
325        let mut temp_request = http::Request::builder()
326            .method(method)
327            .uri(uri)
328            .body(())
329            .map_err(|e| DurableError::Checkpoint {
330                message: format!("Failed to build temp request: {}", e),
331                is_retriable: false,
332                aws_error: None,
333            })?;
334
335        signing_instructions.apply_to_request_http1x(&mut temp_request);
336
337        // Extract headers from the signed request
338        let headers: Vec<(String, String)> = temp_request
339            .headers()
340            .iter()
341            .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
342            .collect();
343
344        Ok(headers)
345    }
346}
347
348/// Request payload for checkpoint operations.
349#[derive(Debug, Clone, Serialize)]
350struct CheckpointRequestBody {
351    #[serde(rename = "CheckpointToken")]
352    checkpoint_token: String,
353    #[serde(rename = "Updates")]
354    updates: Vec<OperationUpdate>,
355}
356
357/// Request payload for get_operations.
358#[derive(Debug, Clone, Serialize)]
359struct GetOperationsRequestBody {
360    #[serde(rename = "NextMarker")]
361    next_marker: String,
362}
363
364#[async_trait]
365impl DurableServiceClient for LambdaDurableServiceClient {
366    async fn checkpoint(
367        &self,
368        durable_execution_arn: &str,
369        checkpoint_token: &str,
370        operations: Vec<OperationUpdate>,
371    ) -> Result<CheckpointResponse, DurableError> {
372        let request_body = CheckpointRequestBody {
373            checkpoint_token: checkpoint_token.to_string(),
374            updates: operations,
375        };
376
377        let body = serde_json::to_vec(&request_body).map_err(|e| DurableError::SerDes {
378            message: format!("Failed to serialize checkpoint request: {}", e),
379        })?;
380
381        // URL-encode the durable execution ARN for the path
382        let encoded_arn = urlencoding::encode(durable_execution_arn);
383        let uri = format!(
384            "{}/2025-12-01/durable-executions/{}/checkpoint",
385            self.endpoint_url(),
386            encoded_arn
387        );
388
389        // Sign the request
390        let signed_headers = self.sign_request("POST", &uri, &body).await?;
391
392        // Build and send the request
393        let mut request = self
394            .http_client
395            .post(&uri)
396            .header("Content-Type", "application/json")
397            .body(body);
398
399        for (name, value) in signed_headers {
400            request = request.header(&name, &value);
401        }
402
403        let response = request.send().await.map_err(|e| DurableError::Checkpoint {
404            message: format!("HTTP request failed: {}", e),
405            is_retriable: e.is_timeout() || e.is_connect(),
406            aws_error: None,
407        })?;
408
409        let status = response.status();
410        let response_body = response
411            .bytes()
412            .await
413            .map_err(|e| DurableError::Checkpoint {
414                message: format!("Failed to read response body: {}", e),
415                is_retriable: true,
416                aws_error: None,
417            })?;
418
419        if !status.is_success() {
420            let error_message = String::from_utf8_lossy(&response_body);
421
422            // Check for specific error conditions
423
424            // Size limit exceeded (413 Request Entity Too Large or specific error message)
425            // Requirements: 25.6 - Handle size limit errors gracefully, return FAILED without retry
426            if status.as_u16() == 413
427                || error_message.contains("RequestEntityTooLarge")
428                || error_message.contains("Payload size exceeded")
429                || error_message.contains("Request too large")
430            {
431                return Err(DurableError::SizeLimit {
432                    message: format!("Checkpoint payload size exceeded: {}", error_message),
433                    actual_size: None,
434                    max_size: None,
435                });
436            }
437
438            // Resource not found (404)
439            // Requirements: 18.6 - Handle ResourceNotFoundException appropriately
440            if status.as_u16() == 404
441                || error_message.contains("ResourceNotFoundException")
442                || error_message.contains("not found")
443            {
444                return Err(DurableError::ResourceNotFound {
445                    message: format!("Durable execution not found: {}", error_message),
446                    resource_id: Some(durable_execution_arn.to_string()),
447                });
448            }
449
450            // Throttling (429 Too Many Requests)
451            // Requirements: 18.5 - Handle ThrottlingException with appropriate retry behavior
452            if status.as_u16() == 429
453                || error_message.contains("ThrottlingException")
454                || error_message.contains("TooManyRequestsException")
455                || error_message.contains("Rate exceeded")
456            {
457                return Err(DurableError::Throttling {
458                    message: format!("Rate limit exceeded: {}", error_message),
459                    retry_after_ms: None, // Could parse Retry-After header if available
460                });
461            }
462
463            // Check for InvalidParameterValueException with "Invalid checkpoint token" message
464            // This indicates the token was already consumed or is invalid, and Lambda should retry
465            // Requirements: 2.11
466            let is_invalid_token = error_message.contains("Invalid checkpoint token");
467
468            // Determine if the error is retriable:
469            // - Server errors (5xx) are retriable
470            // - Invalid checkpoint token errors are retriable (Lambda will provide a fresh token)
471            // Note: Throttling (429) is handled separately above
472            let is_retriable = status.is_server_error() || is_invalid_token;
473
474            return Err(DurableError::Checkpoint {
475                message: format!("Checkpoint API returned {}: {}", status, error_message),
476                is_retriable,
477                aws_error: Some(AwsError {
478                    code: if is_invalid_token {
479                        "InvalidParameterValueException".to_string()
480                    } else {
481                        status.to_string()
482                    },
483                    message: error_message.to_string(),
484                    request_id: None,
485                }),
486            });
487        }
488
489        let checkpoint_response: CheckpointResponse = serde_json::from_slice(&response_body)
490            .map_err(|e| DurableError::SerDes {
491                message: format!("Failed to deserialize checkpoint response: {}", e),
492            })?;
493
494        Ok(checkpoint_response)
495    }
496
497    async fn get_operations(
498        &self,
499        durable_execution_arn: &str,
500        next_marker: &str,
501    ) -> Result<GetOperationsResponse, DurableError> {
502        let request_body = GetOperationsRequestBody {
503            next_marker: next_marker.to_string(),
504        };
505
506        let body = serde_json::to_vec(&request_body).map_err(|e| DurableError::SerDes {
507            message: format!("Failed to serialize get_operations request: {}", e),
508        })?;
509
510        // URL-encode the durable execution ARN for the path
511        let encoded_arn = urlencoding::encode(durable_execution_arn);
512        let uri = format!(
513            "{}/2025-12-01/durable-executions/{}/state",
514            self.endpoint_url(),
515            encoded_arn
516        );
517
518        // Sign the request
519        let signed_headers = self.sign_request("POST", &uri, &body).await?;
520
521        // Build and send the request
522        let mut request = self
523            .http_client
524            .post(&uri)
525            .header("Content-Type", "application/json")
526            .body(body);
527
528        for (name, value) in signed_headers {
529            request = request.header(&name, &value);
530        }
531
532        let response = request.send().await.map_err(|e| DurableError::Invocation {
533            message: format!("HTTP request failed: {}", e),
534            termination_reason: crate::error::TerminationReason::InvocationError,
535        })?;
536
537        let status = response.status();
538        let response_body = response
539            .bytes()
540            .await
541            .map_err(|e| DurableError::Invocation {
542                message: format!("Failed to read response body: {}", e),
543                termination_reason: crate::error::TerminationReason::InvocationError,
544            })?;
545
546        if !status.is_success() {
547            let error_message = String::from_utf8_lossy(&response_body);
548
549            // Resource not found (404)
550            // Requirements: 18.6 - Handle ResourceNotFoundException appropriately
551            if status.as_u16() == 404
552                || error_message.contains("ResourceNotFoundException")
553                || error_message.contains("not found")
554            {
555                return Err(DurableError::ResourceNotFound {
556                    message: format!("Durable execution not found: {}", error_message),
557                    resource_id: Some(durable_execution_arn.to_string()),
558                });
559            }
560
561            // Throttling (429 Too Many Requests)
562            // Requirements: 18.5 - Handle ThrottlingException with appropriate retry behavior
563            if status.as_u16() == 429
564                || error_message.contains("ThrottlingException")
565                || error_message.contains("TooManyRequestsException")
566                || error_message.contains("Rate exceeded")
567            {
568                return Err(DurableError::Throttling {
569                    message: format!("Rate limit exceeded: {}", error_message),
570                    retry_after_ms: None,
571                });
572            }
573
574            return Err(DurableError::Invocation {
575                message: format!(
576                    "GetDurableExecutionState API returned {}: {}",
577                    status, error_message
578                ),
579                termination_reason: crate::error::TerminationReason::InvocationError,
580            });
581        }
582
583        let operations_response: GetOperationsResponse = serde_json::from_slice(&response_body)
584            .map_err(|e| DurableError::SerDes {
585                message: format!("Failed to deserialize get_operations response: {}", e),
586            })?;
587
588        Ok(operations_response)
589    }
590}
591
/// In-memory [`DurableServiceClient`] for unit tests.
///
/// Responses queued through the `with_*` builders are returned in FIFO order.
/// Once a queue is drained, `checkpoint` returns a `"mock-token"` response and
/// `get_operations` returns an empty page with no next marker.
#[cfg(test)]
pub struct MockDurableServiceClient {
    /// Pending results for `checkpoint`, oldest first.
    checkpoint_responses:
        std::sync::Mutex<std::collections::VecDeque<Result<CheckpointResponse, DurableError>>>,
    /// Pending results for `get_operations`, oldest first.
    get_operations_responses:
        std::sync::Mutex<std::collections::VecDeque<Result<GetOperationsResponse, DurableError>>>,
}

#[cfg(test)]
impl Default for MockDurableServiceClient {
    /// Same as [`MockDurableServiceClient::new`].
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
impl MockDurableServiceClient {
    /// Creates a mock with empty response queues.
    pub fn new() -> Self {
        Self {
            checkpoint_responses: std::sync::Mutex::new(std::collections::VecDeque::new()),
            get_operations_responses: std::sync::Mutex::new(std::collections::VecDeque::new()),
        }
    }

    /// Queues `response` to be returned by a later `checkpoint` call.
    pub fn with_checkpoint_response(
        self,
        response: Result<CheckpointResponse, DurableError>,
    ) -> Self {
        self.checkpoint_responses.lock().unwrap().push_back(response);
        self
    }

    /// Queues `response` to be returned by a later `get_operations` call.
    pub fn with_get_operations_response(
        self,
        response: Result<GetOperationsResponse, DurableError>,
    ) -> Self {
        self.get_operations_responses
            .lock()
            .unwrap()
            .push_back(response);
        self
    }

    /// Queues `count` successful checkpoint responses with tokens
    /// `token-0`, `token-1`, ... in that order.
    pub fn with_checkpoint_responses(self, count: usize) -> Self {
        self.checkpoint_responses
            .lock()
            .unwrap()
            .extend((0..count).map(|i| {
                Ok(CheckpointResponse {
                    checkpoint_token: format!("token-{}", i),
                    new_execution_state: None,
                })
            }));
        self
    }
}

#[cfg(test)]
#[async_trait]
impl DurableServiceClient for MockDurableServiceClient {
    /// Pops the oldest queued checkpoint result, or returns `"mock-token"`.
    async fn checkpoint(
        &self,
        _durable_execution_arn: &str,
        _checkpoint_token: &str,
        _operations: Vec<OperationUpdate>,
    ) -> Result<CheckpointResponse, DurableError> {
        // Pop in its own statement so the mutex guard is released right away.
        let queued = self.checkpoint_responses.lock().unwrap().pop_front();
        queued.unwrap_or_else(|| {
            Ok(CheckpointResponse {
                checkpoint_token: "mock-token".to_string(),
                new_execution_state: None,
            })
        })
    }

    /// Pops the oldest queued page, or returns an empty page with no marker.
    async fn get_operations(
        &self,
        _durable_execution_arn: &str,
        _next_marker: &str,
    ) -> Result<GetOperationsResponse, DurableError> {
        let queued = self.get_operations_responses.lock().unwrap().pop_front();
        queued.unwrap_or_else(|| {
            Ok(GetOperationsResponse {
                operations: Vec::new(),
                next_marker: None,
            })
        })
    }
}
674
675/// Type alias for a shared DurableServiceClient.
676pub type SharedDurableServiceClient = Arc<dyn DurableServiceClient>;
677
678#[cfg(test)]
679mod tests {
680    use super::*;
681    use crate::operation::OperationType;
682
683    #[test]
684    fn test_checkpoint_response_serialization() {
685        let response = CheckpointResponse {
686            checkpoint_token: "token-123".to_string(),
687            new_execution_state: None,
688        };
689        let json = serde_json::to_string(&response).unwrap();
690        assert!(json.contains(r#""CheckpointToken":"token-123""#));
691    }
692
693    #[test]
694    fn test_checkpoint_response_deserialization() {
695        let json = r#"{"CheckpointToken": "token-456"}"#;
696        let response: CheckpointResponse = serde_json::from_str(json).unwrap();
697        assert_eq!(response.checkpoint_token, "token-456");
698    }
699
700    #[test]
701    fn test_get_operations_response_serialization() {
702        let response = GetOperationsResponse {
703            operations: vec![Operation::new("op-1", OperationType::Step)],
704            next_marker: Some("marker-123".to_string()),
705        };
706        let json = serde_json::to_string(&response).unwrap();
707        assert!(json.contains(r#""Operations""#));
708        assert!(json.contains(r#""NextMarker":"marker-123""#));
709    }
710
711    #[test]
712    fn test_get_operations_response_deserialization() {
713        let json = r#"{
714            "Operations": [
715                {
716                    "Id": "op-1",
717                    "Type": "STEP",
718                    "Status": "SUCCEEDED"
719                }
720            ],
721            "NextMarker": "marker-456"
722        }"#;
723        let response: GetOperationsResponse = serde_json::from_str(json).unwrap();
724        assert_eq!(response.operations.len(), 1);
725        assert_eq!(response.operations[0].operation_id, "op-1");
726        assert_eq!(response.next_marker, Some("marker-456".to_string()));
727    }
728
729    #[test]
730    fn test_get_operations_response_without_marker() {
731        let json = r#"{
732            "Operations": []
733        }"#;
734        let response: GetOperationsResponse = serde_json::from_str(json).unwrap();
735        assert!(response.operations.is_empty());
736        assert!(response.next_marker.is_none());
737    }
738
739    #[test]
740    fn test_lambda_client_config_default() {
741        let config = LambdaClientConfig::default();
742        assert_eq!(config.region, "us-east-1");
743        assert!(config.endpoint_url.is_none());
744    }
745
746    #[test]
747    fn test_lambda_client_config_with_region() {
748        let config = LambdaClientConfig::with_region("us-west-2");
749        assert_eq!(config.region, "us-west-2");
750    }
751
752    #[tokio::test]
753    async fn test_mock_client_checkpoint() {
754        let client = MockDurableServiceClient::new();
755        let result = client
756            .checkpoint(
757                "arn:aws:lambda:us-east-1:123456789012:function:test",
758                "token-123",
759                vec![],
760            )
761            .await;
762        assert!(result.is_ok());
763        assert_eq!(result.unwrap().checkpoint_token, "mock-token");
764    }
765
766    #[tokio::test]
767    async fn test_mock_client_checkpoint_with_custom_response() {
768        let client =
769            MockDurableServiceClient::new().with_checkpoint_response(Ok(CheckpointResponse {
770                checkpoint_token: "custom-token".to_string(),
771                new_execution_state: None,
772            }));
773        let result = client
774            .checkpoint(
775                "arn:aws:lambda:us-east-1:123456789012:function:test",
776                "token-123",
777                vec![],
778            )
779            .await;
780        assert!(result.is_ok());
781        assert_eq!(result.unwrap().checkpoint_token, "custom-token");
782    }
783
784    #[tokio::test]
785    async fn test_mock_client_checkpoint_with_error() {
786        let client = MockDurableServiceClient::new()
787            .with_checkpoint_response(Err(DurableError::checkpoint_retriable("Test error")));
788        let result = client
789            .checkpoint(
790                "arn:aws:lambda:us-east-1:123456789012:function:test",
791                "token-123",
792                vec![],
793            )
794            .await;
795        assert!(result.is_err());
796        assert!(result.unwrap_err().is_retriable());
797    }
798
799    #[tokio::test]
800    async fn test_mock_client_get_operations() {
801        let client = MockDurableServiceClient::new();
802        let result = client
803            .get_operations(
804                "arn:aws:lambda:us-east-1:123456789012:function:test",
805                "marker-123",
806            )
807            .await;
808        assert!(result.is_ok());
809        let response = result.unwrap();
810        assert!(response.operations.is_empty());
811        assert!(response.next_marker.is_none());
812    }
813
814    #[tokio::test]
815    async fn test_mock_client_get_operations_with_custom_response() {
816        let client = MockDurableServiceClient::new().with_get_operations_response(Ok(
817            GetOperationsResponse {
818                operations: vec![Operation::new("op-1", OperationType::Step)],
819                next_marker: Some("next-marker".to_string()),
820            },
821        ));
822        let result = client
823            .get_operations(
824                "arn:aws:lambda:us-east-1:123456789012:function:test",
825                "marker-123",
826            )
827            .await;
828        assert!(result.is_ok());
829        let response = result.unwrap();
830        assert_eq!(response.operations.len(), 1);
831        assert_eq!(response.next_marker, Some("next-marker".to_string()));
832    }
833
834    #[test]
835    fn test_checkpoint_request_body_serialization() {
836        let request = CheckpointRequestBody {
837            checkpoint_token: "token-123".to_string(),
838            updates: vec![OperationUpdate::start("op-1", OperationType::Step)],
839        };
840        let json = serde_json::to_string(&request).unwrap();
841        assert!(json.contains(r#""CheckpointToken":"token-123""#));
842        assert!(json.contains(r#""Updates""#));
843    }
844
845    #[test]
846    fn test_get_operations_request_body_serialization() {
847        let request = GetOperationsRequestBody {
848            next_marker: "marker-123".to_string(),
849        };
850        let json = serde_json::to_string(&request).unwrap();
851        assert!(json.contains(r#""NextMarker":"marker-123""#));
852    }
853
854    #[tokio::test]
855    async fn test_mock_client_checkpoint_with_invalid_token_error() {
856        // Test that invalid checkpoint token errors are properly marked as retriable
857        let error = DurableError::Checkpoint {
858            message: "Checkpoint API returned 400: Invalid checkpoint token".to_string(),
859            is_retriable: true,
860            aws_error: Some(AwsError {
861                code: "InvalidParameterValueException".to_string(),
862                message: "Invalid checkpoint token: token has been consumed".to_string(),
863                request_id: None,
864            }),
865        };
866
867        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
868        let result = client
869            .checkpoint(
870                "arn:aws:lambda:us-east-1:123456789012:function:test",
871                "consumed-token",
872                vec![],
873            )
874            .await;
875
876        assert!(result.is_err());
877        let err = result.unwrap_err();
878        assert!(err.is_retriable());
879        assert!(err.is_invalid_checkpoint_token());
880    }
881
882    #[tokio::test]
883    async fn test_mock_client_checkpoint_with_size_limit_error() {
884        // Test that size limit errors are properly returned as non-retriable
885        let error = DurableError::SizeLimit {
886            message: "Checkpoint payload size exceeded".to_string(),
887            actual_size: Some(7_000_000),
888            max_size: Some(6_000_000),
889        };
890
891        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
892        let result = client
893            .checkpoint(
894                "arn:aws:lambda:us-east-1:123456789012:function:test",
895                "token-123",
896                vec![],
897            )
898            .await;
899
900        assert!(result.is_err());
901        let err = result.unwrap_err();
902        assert!(err.is_size_limit());
903        assert!(!err.is_retriable());
904    }
905
906    #[tokio::test]
907    async fn test_mock_client_checkpoint_with_throttling_error() {
908        // Test that throttling errors are properly returned
909        let error = DurableError::Throttling {
910            message: "Rate limit exceeded".to_string(),
911            retry_after_ms: Some(5000),
912        };
913
914        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
915        let result = client
916            .checkpoint(
917                "arn:aws:lambda:us-east-1:123456789012:function:test",
918                "token-123",
919                vec![],
920            )
921            .await;
922
923        assert!(result.is_err());
924        let err = result.unwrap_err();
925        assert!(err.is_throttling());
926        assert_eq!(err.get_retry_after_ms(), Some(5000));
927    }
928
929    #[tokio::test]
930    async fn test_mock_client_checkpoint_with_resource_not_found_error() {
931        // Test that resource not found errors are properly returned
932        let error = DurableError::ResourceNotFound {
933            message: "Durable execution not found".to_string(),
934            resource_id: Some("arn:aws:lambda:us-east-1:123456789012:function:test".to_string()),
935        };
936
937        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
938        let result = client
939            .checkpoint(
940                "arn:aws:lambda:us-east-1:123456789012:function:test",
941                "token-123",
942                vec![],
943            )
944            .await;
945
946        assert!(result.is_err());
947        let err = result.unwrap_err();
948        assert!(err.is_resource_not_found());
949        assert!(!err.is_retriable());
950    }
951
952    #[tokio::test]
953    async fn test_mock_client_get_operations_with_throttling_error() {
954        // Test that throttling errors are properly returned for get_operations
955        let error = DurableError::Throttling {
956            message: "Rate limit exceeded".to_string(),
957            retry_after_ms: None,
958        };
959
960        let client = MockDurableServiceClient::new().with_get_operations_response(Err(error));
961        let result = client
962            .get_operations(
963                "arn:aws:lambda:us-east-1:123456789012:function:test",
964                "marker-123",
965            )
966            .await;
967
968        assert!(result.is_err());
969        let err = result.unwrap_err();
970        assert!(err.is_throttling());
971    }
972
973    #[tokio::test]
974    async fn test_mock_client_get_operations_with_resource_not_found_error() {
975        // Test that resource not found errors are properly returned for get_operations
976        let error = DurableError::ResourceNotFound {
977            message: "Durable execution not found".to_string(),
978            resource_id: Some("arn:aws:lambda:us-east-1:123456789012:function:test".to_string()),
979        };
980
981        let client = MockDurableServiceClient::new().with_get_operations_response(Err(error));
982        let result = client
983            .get_operations(
984                "arn:aws:lambda:us-east-1:123456789012:function:test",
985                "marker-123",
986            )
987            .await;
988
989        assert!(result.is_err());
990        let err = result.unwrap_err();
991        assert!(err.is_resource_not_found());
992    }
993
994    /// Helper to build a test SdkConfig with fake credentials.
995    fn test_sdk_config() -> aws_config::SdkConfig {
996        let creds =
997            aws_credential_types::Credentials::new("test-key", "test-secret", None, None, "test");
998        let provider = aws_credential_types::provider::SharedCredentialsProvider::new(creds);
999        aws_config::SdkConfig::builder()
1000            .credentials_provider(provider)
1001            .region(aws_config::Region::new("us-east-1"))
1002            .build()
1003    }
1004
1005    #[test]
1006    fn test_from_aws_config_with_user_agent_constructs_client() {
1007        // Verify from_aws_config_with_user_agent creates a client successfully (Req 14.1)
1008        let config = test_sdk_config();
1009        let client =
1010            LambdaDurableServiceClient::from_aws_config_with_user_agent(&config, "my-sdk", "1.0.0")
1011                .unwrap();
1012        // Client should be constructed with the correct region config
1013        assert_eq!(client.config.region, "us-east-1");
1014    }
1015
1016    #[test]
1017    fn test_from_aws_config_unchanged() {
1018        // Verify existing from_aws_config still works without user-agent (Req 14.2)
1019        let config = test_sdk_config();
1020        let client = LambdaDurableServiceClient::from_aws_config(&config).unwrap();
1021        assert_eq!(client.config.region, "us-east-1");
1022    }
1023
1024    #[tokio::test]
1025    async fn test_user_agent_header_is_sent() {
1026        // Verify the user-agent header is actually included in HTTP requests (Req 14.1)
1027        use tokio::io::{AsyncReadExt, AsyncWriteExt};
1028        use tokio::net::TcpListener;
1029
1030        // Start a local TCP server that captures the User-Agent header
1031        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1032        let addr = listener.local_addr().unwrap();
1033
1034        let server_handle = tokio::spawn(async move {
1035            let (mut socket, _) = listener.accept().await.unwrap();
1036            let mut buf = vec![0u8; 4096];
1037            let n = socket.read(&mut buf).await.unwrap();
1038            let request = String::from_utf8_lossy(&buf[..n]).to_string();
1039
1040            // Send a minimal HTTP response
1041            let response = "HTTP/1.1 400 Bad Request\r\nContent-Length: 2\r\n\r\n{}";
1042            let _ = socket.write_all(response.as_bytes()).await;
1043            let _ = socket.shutdown().await;
1044
1045            request
1046        });
1047
1048        // Create client with custom user-agent pointing at our local server
1049        let sdk_name = "test-sdk";
1050        let sdk_version = "2.3.4";
1051        let user_agent = format!("{}/{}", sdk_name, sdk_version);
1052
1053        let config = test_sdk_config();
1054        let mut client = LambdaDurableServiceClient::from_aws_config_with_user_agent(
1055            &config,
1056            sdk_name,
1057            sdk_version,
1058        )
1059        .unwrap();
1060        client.config.endpoint_url = Some(format!("http://{}", addr));
1061
1062        // Make a request (it will fail, but we just need to capture the headers)
1063        let _ = client
1064            .checkpoint("arn:test:execution", "token", vec![])
1065            .await;
1066
1067        // Check the captured request contains our user-agent
1068        let captured_request = server_handle.await.unwrap();
1069        assert!(
1070            captured_request.contains(&user_agent),
1071            "Expected User-Agent '{}' in request headers, got:\n{}",
1072            user_agent,
1073            captured_request
1074        );
1075    }
1076}