1use std::sync::Arc;
9use std::time::SystemTime;
10
11use async_trait::async_trait;
12use aws_credential_types::provider::ProvideCredentials;
13use aws_sigv4::http_request::{sign, SignableBody, SignableRequest, SigningSettings};
14use aws_sigv4::sign::v4;
15use serde::{Deserialize, Serialize};
16
17use crate::error::{AwsError, DurableError};
18use crate::operation::{Operation, OperationUpdate};
19
20#[async_trait]
25pub trait DurableServiceClient: Send + Sync {
26 async fn checkpoint(
38 &self,
39 durable_execution_arn: &str,
40 checkpoint_token: &str,
41 operations: Vec<OperationUpdate>,
42 ) -> Result<CheckpointResponse, DurableError>;
43
44 async fn get_operations(
55 &self,
56 durable_execution_arn: &str,
57 next_marker: &str,
58 ) -> Result<GetOperationsResponse, DurableError>;
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize, Default)]
63pub struct CheckpointResponse {
64 #[serde(rename = "CheckpointToken", default)]
66 pub checkpoint_token: String,
67
68 #[serde(
71 rename = "NewExecutionState",
72 skip_serializing_if = "Option::is_none",
73 default
74 )]
75 pub new_execution_state: Option<NewExecutionState>,
76}
77
78impl CheckpointResponse {
79 pub fn new(checkpoint_token: impl Into<String>) -> Self {
81 Self {
82 checkpoint_token: checkpoint_token.into(),
83 new_execution_state: None,
84 }
85 }
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct NewExecutionState {
91 #[serde(rename = "Operations", default)]
93 pub operations: Vec<Operation>,
94
95 #[serde(rename = "NextMarker", skip_serializing_if = "Option::is_none")]
97 pub next_marker: Option<String>,
98}
99
100impl NewExecutionState {
101 pub fn find_operation(&self, operation_id: &str) -> Option<&Operation> {
111 self.operations
112 .iter()
113 .find(|op| op.operation_id == operation_id)
114 }
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct GetOperationsResponse {
120 #[serde(rename = "Operations")]
122 pub operations: Vec<Operation>,
123
124 #[serde(rename = "NextMarker", skip_serializing_if = "Option::is_none")]
126 pub next_marker: Option<String>,
127}
128
/// Connection settings for [`LambdaDurableServiceClient`].
#[derive(Clone, Debug)]
pub struct LambdaClientConfig {
    /// AWS region name, used for request signing and for the default endpoint.
    pub region: String,
    /// Explicit endpoint base URL; when `None` the client derives one from
    /// `region`.
    pub endpoint_url: Option<String>,
}

impl Default for LambdaClientConfig {
    /// Uses region `us-east-1` and no endpoint override.
    fn default() -> Self {
        let region = String::from("us-east-1");
        Self {
            region,
            endpoint_url: None,
        }
    }
}
146
147impl LambdaClientConfig {
148 pub fn with_region(region: impl Into<String>) -> Self {
150 Self {
151 region: region.into(),
152 endpoint_url: None,
153 }
154 }
155
156 pub fn from_aws_config(config: &aws_config::SdkConfig) -> Self {
158 Self {
159 region: config
160 .region()
161 .map(|r| r.to_string())
162 .unwrap_or_else(|| "us-east-1".to_string()),
163 endpoint_url: None,
164 }
165 }
166}
167
168pub struct LambdaDurableServiceClient {
173 http_client: reqwest::Client,
175 credentials_provider: Arc<dyn ProvideCredentials>,
177 config: LambdaClientConfig,
179}
180
181impl LambdaDurableServiceClient {
182 pub async fn from_env() -> Result<Self, DurableError> {
188 let aws_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
189 .load()
190 .await;
191 Self::from_aws_config(&aws_config)
192 }
193
194 pub fn from_aws_config(aws_config: &aws_config::SdkConfig) -> Result<Self, DurableError> {
200 let credentials_provider = aws_config
201 .credentials_provider()
202 .ok_or_else(|| DurableError::Configuration {
203 message: "No credentials provider configured in AWS SDK config".to_string(),
204 })?
205 .clone();
206
207 Ok(Self {
208 http_client: reqwest::Client::new(),
209 credentials_provider: Arc::from(credentials_provider),
210 config: LambdaClientConfig::from_aws_config(aws_config),
211 })
212 }
213
214 pub fn from_aws_config_with_user_agent(
221 aws_config: &aws_config::SdkConfig,
222 sdk_name: &str,
223 sdk_version: &str,
224 ) -> Result<Self, DurableError> {
225 let credentials_provider = aws_config
226 .credentials_provider()
227 .ok_or_else(|| DurableError::Configuration {
228 message: "No credentials provider configured in AWS SDK config".to_string(),
229 })?
230 .clone();
231
232 let user_agent = format!("{}/{}", sdk_name, sdk_version);
233 let mut headers = reqwest::header::HeaderMap::new();
234 if let Ok(value) = reqwest::header::HeaderValue::from_str(&user_agent) {
235 headers.insert(reqwest::header::USER_AGENT, value);
236 }
237
238 let http_client = reqwest::Client::builder()
239 .default_headers(headers)
240 .build()
241 .map_err(|e| DurableError::Configuration {
242 message: format!("Failed to build HTTP client: {}", e),
243 })?;
244
245 Ok(Self {
246 http_client,
247 credentials_provider: Arc::from(credentials_provider),
248 config: LambdaClientConfig::from_aws_config(aws_config),
249 })
250 }
251
252 pub fn with_config(
254 credentials_provider: Arc<dyn ProvideCredentials>,
255 config: LambdaClientConfig,
256 ) -> Self {
257 Self {
258 http_client: reqwest::Client::new(),
259 credentials_provider,
260 config,
261 }
262 }
263
264 fn endpoint_url(&self) -> String {
266 self.config
267 .endpoint_url
268 .clone()
269 .unwrap_or_else(|| format!("https://lambda.{}.amazonaws.com", self.config.region))
270 }
271
272 async fn sign_request(
274 &self,
275 method: &str,
276 uri: &str,
277 body: &[u8],
278 ) -> Result<Vec<(String, String)>, DurableError> {
279 let credentials = self
280 .credentials_provider
281 .provide_credentials()
282 .await
283 .map_err(|e| DurableError::Checkpoint {
284 message: format!("Failed to get AWS credentials: {}", e),
285 is_retriable: true,
286 aws_error: None,
287 })?;
288
289 let identity = credentials.into();
290 let signing_settings = SigningSettings::default();
291 let signing_params = v4::SigningParams::builder()
292 .identity(&identity)
293 .region(&self.config.region)
294 .name("lambda")
295 .time(SystemTime::now())
296 .settings(signing_settings)
297 .build()
298 .map_err(|e| DurableError::Checkpoint {
299 message: format!("Failed to build signing params: {}", e),
300 is_retriable: false,
301 aws_error: None,
302 })?;
303
304 let signable_request = SignableRequest::new(
305 method,
306 uri,
307 std::iter::empty::<(&str, &str)>(),
308 SignableBody::Bytes(body),
309 )
310 .map_err(|e| DurableError::Checkpoint {
311 message: format!("Failed to create signable request: {}", e),
312 is_retriable: false,
313 aws_error: None,
314 })?;
315
316 let (signing_instructions, _signature) = sign(signable_request, &signing_params.into())
317 .map_err(|e| DurableError::Checkpoint {
318 message: format!("Failed to sign request: {}", e),
319 is_retriable: false,
320 aws_error: None,
321 })?
322 .into_parts();
323
324 let mut temp_request = http::Request::builder()
326 .method(method)
327 .uri(uri)
328 .body(())
329 .map_err(|e| DurableError::Checkpoint {
330 message: format!("Failed to build temp request: {}", e),
331 is_retriable: false,
332 aws_error: None,
333 })?;
334
335 signing_instructions.apply_to_request_http1x(&mut temp_request);
336
337 let headers: Vec<(String, String)> = temp_request
339 .headers()
340 .iter()
341 .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
342 .collect();
343
344 Ok(headers)
345 }
346}
347
348#[derive(Debug, Clone, Serialize)]
350struct CheckpointRequestBody {
351 #[serde(rename = "CheckpointToken")]
352 checkpoint_token: String,
353 #[serde(rename = "Updates")]
354 updates: Vec<OperationUpdate>,
355}
356
357#[derive(Debug, Clone, Serialize)]
359struct GetOperationsRequestBody {
360 #[serde(rename = "NextMarker")]
361 next_marker: String,
362}
363
364#[async_trait]
365impl DurableServiceClient for LambdaDurableServiceClient {
366 async fn checkpoint(
367 &self,
368 durable_execution_arn: &str,
369 checkpoint_token: &str,
370 operations: Vec<OperationUpdate>,
371 ) -> Result<CheckpointResponse, DurableError> {
372 let request_body = CheckpointRequestBody {
373 checkpoint_token: checkpoint_token.to_string(),
374 updates: operations,
375 };
376
377 let body = serde_json::to_vec(&request_body).map_err(|e| DurableError::SerDes {
378 message: format!("Failed to serialize checkpoint request: {}", e),
379 })?;
380
381 let encoded_arn = urlencoding::encode(durable_execution_arn);
383 let uri = format!(
384 "{}/2025-12-01/durable-executions/{}/checkpoint",
385 self.endpoint_url(),
386 encoded_arn
387 );
388
389 let signed_headers = self.sign_request("POST", &uri, &body).await?;
391
392 let mut request = self
394 .http_client
395 .post(&uri)
396 .header("Content-Type", "application/json")
397 .body(body);
398
399 for (name, value) in signed_headers {
400 request = request.header(&name, &value);
401 }
402
403 let response = request.send().await.map_err(|e| DurableError::Checkpoint {
404 message: format!("HTTP request failed: {}", e),
405 is_retriable: e.is_timeout() || e.is_connect(),
406 aws_error: None,
407 })?;
408
409 let status = response.status();
410 let response_body = response
411 .bytes()
412 .await
413 .map_err(|e| DurableError::Checkpoint {
414 message: format!("Failed to read response body: {}", e),
415 is_retriable: true,
416 aws_error: None,
417 })?;
418
419 if !status.is_success() {
420 let error_message = String::from_utf8_lossy(&response_body);
421
422 if status.as_u16() == 413
427 || error_message.contains("RequestEntityTooLarge")
428 || error_message.contains("Payload size exceeded")
429 || error_message.contains("Request too large")
430 {
431 return Err(DurableError::SizeLimit {
432 message: format!("Checkpoint payload size exceeded: {}", error_message),
433 actual_size: None,
434 max_size: None,
435 });
436 }
437
438 if status.as_u16() == 404
441 || error_message.contains("ResourceNotFoundException")
442 || error_message.contains("not found")
443 {
444 return Err(DurableError::ResourceNotFound {
445 message: format!("Durable execution not found: {}", error_message),
446 resource_id: Some(durable_execution_arn.to_string()),
447 });
448 }
449
450 if status.as_u16() == 429
453 || error_message.contains("ThrottlingException")
454 || error_message.contains("TooManyRequestsException")
455 || error_message.contains("Rate exceeded")
456 {
457 return Err(DurableError::Throttling {
458 message: format!("Rate limit exceeded: {}", error_message),
459 retry_after_ms: None, });
461 }
462
463 let is_invalid_token = error_message.contains("Invalid checkpoint token");
467
468 let is_retriable = status.is_server_error() || is_invalid_token;
473
474 return Err(DurableError::Checkpoint {
475 message: format!("Checkpoint API returned {}: {}", status, error_message),
476 is_retriable,
477 aws_error: Some(AwsError {
478 code: if is_invalid_token {
479 "InvalidParameterValueException".to_string()
480 } else {
481 status.to_string()
482 },
483 message: error_message.to_string(),
484 request_id: None,
485 }),
486 });
487 }
488
489 let checkpoint_response: CheckpointResponse = serde_json::from_slice(&response_body)
490 .map_err(|e| DurableError::SerDes {
491 message: format!("Failed to deserialize checkpoint response: {}", e),
492 })?;
493
494 Ok(checkpoint_response)
495 }
496
497 async fn get_operations(
498 &self,
499 durable_execution_arn: &str,
500 next_marker: &str,
501 ) -> Result<GetOperationsResponse, DurableError> {
502 let request_body = GetOperationsRequestBody {
503 next_marker: next_marker.to_string(),
504 };
505
506 let body = serde_json::to_vec(&request_body).map_err(|e| DurableError::SerDes {
507 message: format!("Failed to serialize get_operations request: {}", e),
508 })?;
509
510 let encoded_arn = urlencoding::encode(durable_execution_arn);
512 let uri = format!(
513 "{}/2025-12-01/durable-executions/{}/state",
514 self.endpoint_url(),
515 encoded_arn
516 );
517
518 let signed_headers = self.sign_request("POST", &uri, &body).await?;
520
521 let mut request = self
523 .http_client
524 .post(&uri)
525 .header("Content-Type", "application/json")
526 .body(body);
527
528 for (name, value) in signed_headers {
529 request = request.header(&name, &value);
530 }
531
532 let response = request.send().await.map_err(|e| DurableError::Invocation {
533 message: format!("HTTP request failed: {}", e),
534 termination_reason: crate::error::TerminationReason::InvocationError,
535 })?;
536
537 let status = response.status();
538 let response_body = response
539 .bytes()
540 .await
541 .map_err(|e| DurableError::Invocation {
542 message: format!("Failed to read response body: {}", e),
543 termination_reason: crate::error::TerminationReason::InvocationError,
544 })?;
545
546 if !status.is_success() {
547 let error_message = String::from_utf8_lossy(&response_body);
548
549 if status.as_u16() == 404
552 || error_message.contains("ResourceNotFoundException")
553 || error_message.contains("not found")
554 {
555 return Err(DurableError::ResourceNotFound {
556 message: format!("Durable execution not found: {}", error_message),
557 resource_id: Some(durable_execution_arn.to_string()),
558 });
559 }
560
561 if status.as_u16() == 429
564 || error_message.contains("ThrottlingException")
565 || error_message.contains("TooManyRequestsException")
566 || error_message.contains("Rate exceeded")
567 {
568 return Err(DurableError::Throttling {
569 message: format!("Rate limit exceeded: {}", error_message),
570 retry_after_ms: None,
571 });
572 }
573
574 return Err(DurableError::Invocation {
575 message: format!(
576 "GetDurableExecutionState API returned {}: {}",
577 status, error_message
578 ),
579 termination_reason: crate::error::TerminationReason::InvocationError,
580 });
581 }
582
583 let operations_response: GetOperationsResponse = serde_json::from_slice(&response_body)
584 .map_err(|e| DurableError::SerDes {
585 message: format!("Failed to deserialize get_operations response: {}", e),
586 })?;
587
588 Ok(operations_response)
589 }
590}
591
/// In-memory [`DurableServiceClient`] for tests.
///
/// Responses are queued with the `with_*` builder methods and handed out in
/// the order they were added. When a queue is empty the corresponding call
/// returns a fixed default success value.
#[cfg(test)]
pub struct MockDurableServiceClient {
    /// Pending `checkpoint` results, oldest first. A `VecDeque` is used so
    /// taking the oldest entry does not shift the remaining ones.
    checkpoint_responses:
        std::sync::Mutex<std::collections::VecDeque<Result<CheckpointResponse, DurableError>>>,
    /// Pending `get_operations` results, oldest first.
    get_operations_responses:
        std::sync::Mutex<std::collections::VecDeque<Result<GetOperationsResponse, DurableError>>>,
}

#[cfg(test)]
impl MockDurableServiceClient {
    /// Creates a mock with both response queues empty.
    pub fn new() -> Self {
        Self {
            checkpoint_responses: std::sync::Mutex::new(std::collections::VecDeque::new()),
            get_operations_responses: std::sync::Mutex::new(std::collections::VecDeque::new()),
        }
    }

    /// Queues one result to be returned by a later `checkpoint` call.
    ///
    /// # Panics
    /// Panics if the queue's mutex is poisoned.
    pub fn with_checkpoint_response(
        self,
        response: Result<CheckpointResponse, DurableError>,
    ) -> Self {
        self.checkpoint_responses
            .lock()
            .unwrap()
            .push_back(response);
        self
    }

    /// Queues one result to be returned by a later `get_operations` call.
    ///
    /// # Panics
    /// Panics if the queue's mutex is poisoned.
    pub fn with_get_operations_response(
        self,
        response: Result<GetOperationsResponse, DurableError>,
    ) -> Self {
        self.get_operations_responses
            .lock()
            .unwrap()
            .push_back(response);
        self
    }

    /// Queues `count` successful checkpoint responses with tokens
    /// `token-0`, `token-1`, ... and no execution state.
    ///
    /// # Panics
    /// Panics if the queue's mutex is poisoned.
    pub fn with_checkpoint_responses(self, count: usize) -> Self {
        // The guard is a temporary and is released at the end of this
        // statement, before `self` is moved out.
        self.checkpoint_responses
            .lock()
            .unwrap()
            .extend((0..count).map(|i| {
                Ok(CheckpointResponse {
                    checkpoint_token: format!("token-{}", i),
                    new_execution_state: None,
                })
            }));
        self
    }
}

#[cfg(test)]
#[async_trait]
impl DurableServiceClient for MockDurableServiceClient {
    /// Returns the oldest queued checkpoint result, or a success with token
    /// `mock-token` when nothing is queued. All arguments are ignored.
    async fn checkpoint(
        &self,
        _durable_execution_arn: &str,
        _checkpoint_token: &str,
        _operations: Vec<OperationUpdate>,
    ) -> Result<CheckpointResponse, DurableError> {
        let queued = self.checkpoint_responses.lock().unwrap().pop_front();
        queued.unwrap_or_else(|| {
            Ok(CheckpointResponse {
                checkpoint_token: "mock-token".to_string(),
                new_execution_state: None,
            })
        })
    }

    /// Returns the oldest queued get-operations result, or an empty page
    /// with no marker when nothing is queued. All arguments are ignored.
    async fn get_operations(
        &self,
        _durable_execution_arn: &str,
        _next_marker: &str,
    ) -> Result<GetOperationsResponse, DurableError> {
        let queued = self.get_operations_responses.lock().unwrap().pop_front();
        queued.unwrap_or_else(|| {
            Ok(GetOperationsResponse {
                operations: Vec::new(),
                next_marker: None,
            })
        })
    }
}
674
675pub type SharedDurableServiceClient = Arc<dyn DurableServiceClient>;
677
#[cfg(test)]
mod tests {
    use super::*;
    use crate::operation::OperationType;

    /// ARN passed to the mock client in most tests; the mock ignores it.
    const TEST_ARN: &str = "arn:aws:lambda:us-east-1:123456789012:function:test";

    /// Runs a checkpoint call with an empty update list against `client`.
    async fn checkpoint_with_token(
        client: &MockDurableServiceClient,
        token: &str,
    ) -> Result<CheckpointResponse, DurableError> {
        client.checkpoint(TEST_ARN, token, vec![]).await
    }

    /// Runs a get-operations call with a fixed marker against `client`.
    async fn fetch_operations(
        client: &MockDurableServiceClient,
    ) -> Result<GetOperationsResponse, DurableError> {
        client.get_operations(TEST_ARN, "marker-123").await
    }

    // ---- wire format of the response types ----

    #[test]
    fn test_checkpoint_response_serialization() {
        let encoded = serde_json::to_string(&CheckpointResponse::new("token-123")).unwrap();
        assert!(encoded.contains(r#""CheckpointToken":"token-123""#));
    }

    #[test]
    fn test_checkpoint_response_deserialization() {
        let decoded: CheckpointResponse =
            serde_json::from_str(r#"{"CheckpointToken": "token-456"}"#).unwrap();
        assert_eq!(decoded.checkpoint_token, "token-456");
    }

    #[test]
    fn test_get_operations_response_serialization() {
        let page = GetOperationsResponse {
            operations: vec![Operation::new("op-1", OperationType::Step)],
            next_marker: Some("marker-123".to_string()),
        };
        let encoded = serde_json::to_string(&page).unwrap();
        assert!(encoded.contains(r#""Operations""#));
        assert!(encoded.contains(r#""NextMarker":"marker-123""#));
    }

    #[test]
    fn test_get_operations_response_deserialization() {
        let input = r#"{
            "Operations": [
                {
                    "Id": "op-1",
                    "Type": "STEP",
                    "Status": "SUCCEEDED"
                }
            ],
            "NextMarker": "marker-456"
        }"#;
        let decoded: GetOperationsResponse = serde_json::from_str(input).unwrap();
        assert_eq!(decoded.operations.len(), 1);
        assert_eq!(decoded.operations[0].operation_id, "op-1");
        assert_eq!(decoded.next_marker, Some("marker-456".to_string()));
    }

    #[test]
    fn test_get_operations_response_without_marker() {
        let input = r#"{
            "Operations": []
        }"#;
        let decoded: GetOperationsResponse = serde_json::from_str(input).unwrap();
        assert!(decoded.operations.is_empty());
        assert!(decoded.next_marker.is_none());
    }

    // ---- client configuration ----

    #[test]
    fn test_lambda_client_config_default() {
        let defaults = LambdaClientConfig::default();
        assert_eq!(defaults.region, "us-east-1");
        assert!(defaults.endpoint_url.is_none());
    }

    #[test]
    fn test_lambda_client_config_with_region() {
        let regional = LambdaClientConfig::with_region("us-west-2");
        assert_eq!(regional.region, "us-west-2");
    }

    // ---- mock client: default and queued responses ----

    #[tokio::test]
    async fn test_mock_client_checkpoint() {
        let client = MockDurableServiceClient::new();
        let response = checkpoint_with_token(&client, "token-123")
            .await
            .expect("an empty mock returns a default success");
        assert_eq!(response.checkpoint_token, "mock-token");
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_custom_response() {
        let client = MockDurableServiceClient::new()
            .with_checkpoint_response(Ok(CheckpointResponse::new("custom-token")));
        let response = checkpoint_with_token(&client, "token-123")
            .await
            .expect("the queued success is returned");
        assert_eq!(response.checkpoint_token, "custom-token");
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_error() {
        let client = MockDurableServiceClient::new()
            .with_checkpoint_response(Err(DurableError::checkpoint_retriable("Test error")));
        let err = checkpoint_with_token(&client, "token-123")
            .await
            .expect_err("the queued error is returned");
        assert!(err.is_retriable());
    }

    #[tokio::test]
    async fn test_mock_client_get_operations() {
        let client = MockDurableServiceClient::new();
        let page = fetch_operations(&client)
            .await
            .expect("an empty mock returns a default success");
        assert!(page.operations.is_empty());
        assert!(page.next_marker.is_none());
    }

    #[tokio::test]
    async fn test_mock_client_get_operations_with_custom_response() {
        let queued = GetOperationsResponse {
            operations: vec![Operation::new("op-1", OperationType::Step)],
            next_marker: Some("next-marker".to_string()),
        };
        let client = MockDurableServiceClient::new().with_get_operations_response(Ok(queued));
        let page = fetch_operations(&client)
            .await
            .expect("the queued success is returned");
        assert_eq!(page.operations.len(), 1);
        assert_eq!(page.next_marker, Some("next-marker".to_string()));
    }

    // ---- wire format of the request bodies ----

    #[test]
    fn test_checkpoint_request_body_serialization() {
        let body = CheckpointRequestBody {
            checkpoint_token: "token-123".to_string(),
            updates: vec![OperationUpdate::start("op-1", OperationType::Step)],
        };
        let encoded = serde_json::to_string(&body).unwrap();
        assert!(encoded.contains(r#""CheckpointToken":"token-123""#));
        assert!(encoded.contains(r#""Updates""#));
    }

    #[test]
    fn test_get_operations_request_body_serialization() {
        let body = GetOperationsRequestBody {
            next_marker: "marker-123".to_string(),
        };
        let encoded = serde_json::to_string(&body).unwrap();
        assert!(encoded.contains(r#""NextMarker":"marker-123""#));
    }

    // ---- mock client: error variants pass through unchanged ----

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_invalid_token_error() {
        let queued = DurableError::Checkpoint {
            message: "Checkpoint API returned 400: Invalid checkpoint token".to_string(),
            is_retriable: true,
            aws_error: Some(AwsError {
                code: "InvalidParameterValueException".to_string(),
                message: "Invalid checkpoint token: token has been consumed".to_string(),
                request_id: None,
            }),
        };

        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(queued));
        let err = checkpoint_with_token(&client, "consumed-token")
            .await
            .expect_err("the queued error is returned");

        assert!(err.is_retriable());
        assert!(err.is_invalid_checkpoint_token());
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_size_limit_error() {
        let queued = DurableError::SizeLimit {
            message: "Checkpoint payload size exceeded".to_string(),
            actual_size: Some(7_000_000),
            max_size: Some(6_000_000),
        };

        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(queued));
        let err = checkpoint_with_token(&client, "token-123")
            .await
            .expect_err("the queued error is returned");

        assert!(err.is_size_limit());
        assert!(!err.is_retriable());
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_throttling_error() {
        let queued = DurableError::Throttling {
            message: "Rate limit exceeded".to_string(),
            retry_after_ms: Some(5000),
        };

        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(queued));
        let err = checkpoint_with_token(&client, "token-123")
            .await
            .expect_err("the queued error is returned");

        assert!(err.is_throttling());
        assert_eq!(err.get_retry_after_ms(), Some(5000));
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_resource_not_found_error() {
        let queued = DurableError::ResourceNotFound {
            message: "Durable execution not found".to_string(),
            resource_id: Some(TEST_ARN.to_string()),
        };

        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(queued));
        let err = checkpoint_with_token(&client, "token-123")
            .await
            .expect_err("the queued error is returned");

        assert!(err.is_resource_not_found());
        assert!(!err.is_retriable());
    }

    #[tokio::test]
    async fn test_mock_client_get_operations_with_throttling_error() {
        let queued = DurableError::Throttling {
            message: "Rate limit exceeded".to_string(),
            retry_after_ms: None,
        };

        let client = MockDurableServiceClient::new().with_get_operations_response(Err(queued));
        let err = fetch_operations(&client)
            .await
            .expect_err("the queued error is returned");

        assert!(err.is_throttling());
    }

    #[tokio::test]
    async fn test_mock_client_get_operations_with_resource_not_found_error() {
        let queued = DurableError::ResourceNotFound {
            message: "Durable execution not found".to_string(),
            resource_id: Some(TEST_ARN.to_string()),
        };

        let client = MockDurableServiceClient::new().with_get_operations_response(Err(queued));
        let err = fetch_operations(&client)
            .await
            .expect_err("the queued error is returned");

        assert!(err.is_resource_not_found());
    }

    // ---- real client construction ----

    /// SDK config with static test credentials and region `us-east-1`.
    fn test_sdk_config() -> aws_config::SdkConfig {
        let static_credentials =
            aws_credential_types::Credentials::new("test-key", "test-secret", None, None, "test");
        let shared_provider =
            aws_credential_types::provider::SharedCredentialsProvider::new(static_credentials);
        aws_config::SdkConfig::builder()
            .credentials_provider(shared_provider)
            .region(aws_config::Region::new("us-east-1"))
            .build()
    }

    #[test]
    fn test_from_aws_config_with_user_agent_constructs_client() {
        let sdk_config = test_sdk_config();
        let client = LambdaDurableServiceClient::from_aws_config_with_user_agent(
            &sdk_config,
            "my-sdk",
            "1.0.0",
        )
        .unwrap();
        assert_eq!(client.config.region, "us-east-1");
    }

    #[test]
    fn test_from_aws_config_unchanged() {
        let sdk_config = test_sdk_config();
        let client = LambdaDurableServiceClient::from_aws_config(&sdk_config).unwrap();
        assert_eq!(client.config.region, "us-east-1");
    }

    #[tokio::test]
    async fn test_user_agent_header_is_sent() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        // Port 0 lets the OS pick a free port for the one-shot server.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        // Accept a single connection, capture what one read returns, and
        // answer with a small 400 so the client call finishes.
        let server_handle = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut buf = vec![0u8; 4096];
            let received = socket.read(&mut buf).await.unwrap();
            let captured = String::from_utf8_lossy(&buf[..received]).to_string();

            let reply = "HTTP/1.1 400 Bad Request\r\nContent-Length: 2\r\n\r\n{}";
            let _ = socket.write_all(reply.as_bytes()).await;
            let _ = socket.shutdown().await;

            captured
        });

        let sdk_name = "test-sdk";
        let sdk_version = "2.3.4";
        let user_agent = format!("{}/{}", sdk_name, sdk_version);

        let sdk_config = test_sdk_config();
        let mut client = LambdaDurableServiceClient::from_aws_config_with_user_agent(
            &sdk_config,
            sdk_name,
            sdk_version,
        )
        .unwrap();
        // Point the client at the local listener instead of a real endpoint.
        client.config.endpoint_url = Some(format!("http://{}", addr));

        // The call's outcome is irrelevant; only the bytes sent are checked.
        let _ = client
            .checkpoint("arn:test:execution", "token", vec![])
            .await;

        let captured_request = server_handle.await.unwrap();
        assert!(
            captured_request.contains(&user_agent),
            "Expected User-Agent '{}' in request headers, got:\n{}",
            user_agent,
            captured_request
        );
    }
}