1use std::sync::Arc;
9use std::time::SystemTime;
10
11use async_trait::async_trait;
12use aws_credential_types::provider::ProvideCredentials;
13use aws_sigv4::http_request::{sign, SignableBody, SignableRequest, SigningSettings};
14use aws_sigv4::sign::v4;
15use serde::{Deserialize, Serialize};
16
17use crate::error::{AwsError, DurableError};
18use crate::operation::{Operation, OperationUpdate};
19
20#[async_trait]
25pub trait DurableServiceClient: Send + Sync {
26 async fn checkpoint(
38 &self,
39 durable_execution_arn: &str,
40 checkpoint_token: &str,
41 operations: Vec<OperationUpdate>,
42 ) -> Result<CheckpointResponse, DurableError>;
43
44 async fn get_operations(
55 &self,
56 durable_execution_arn: &str,
57 next_marker: &str,
58 ) -> Result<GetOperationsResponse, DurableError>;
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize, Default)]
63pub struct CheckpointResponse {
64 #[serde(rename = "CheckpointToken", default)]
66 pub checkpoint_token: String,
67
68 #[serde(
71 rename = "NewExecutionState",
72 skip_serializing_if = "Option::is_none",
73 default
74 )]
75 pub new_execution_state: Option<NewExecutionState>,
76}
77
78impl CheckpointResponse {
79 pub fn new(checkpoint_token: impl Into<String>) -> Self {
81 Self {
82 checkpoint_token: checkpoint_token.into(),
83 new_execution_state: None,
84 }
85 }
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct NewExecutionState {
91 #[serde(rename = "Operations", default)]
93 pub operations: Vec<Operation>,
94
95 #[serde(rename = "NextMarker", skip_serializing_if = "Option::is_none")]
97 pub next_marker: Option<String>,
98}
99
100impl NewExecutionState {
101 pub fn find_operation(&self, operation_id: &str) -> Option<&Operation> {
111 self.operations
112 .iter()
113 .find(|op| op.operation_id == operation_id)
114 }
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct GetOperationsResponse {
120 #[serde(rename = "Operations")]
122 pub operations: Vec<Operation>,
123
124 #[serde(rename = "NextMarker", skip_serializing_if = "Option::is_none")]
126 pub next_marker: Option<String>,
127}
128
129#[derive(Debug, Clone)]
131pub struct LambdaClientConfig {
132 pub region: String,
134 pub endpoint_url: Option<String>,
136}
137
138impl Default for LambdaClientConfig {
139 fn default() -> Self {
140 Self {
141 region: "us-east-1".to_string(),
142 endpoint_url: None,
143 }
144 }
145}
146
147impl LambdaClientConfig {
148 pub fn with_region(region: impl Into<String>) -> Self {
150 Self {
151 region: region.into(),
152 endpoint_url: None,
153 }
154 }
155
156 pub fn from_aws_config(config: &aws_config::SdkConfig) -> Self {
158 Self {
159 region: config
160 .region()
161 .map(|r| r.to_string())
162 .unwrap_or_else(|| "us-east-1".to_string()),
163 endpoint_url: None,
164 }
165 }
166}
167
168pub struct LambdaDurableServiceClient {
173 http_client: reqwest::Client,
175 credentials_provider: Arc<dyn ProvideCredentials>,
177 config: LambdaClientConfig,
179}
180
181impl LambdaDurableServiceClient {
182 pub async fn from_env() -> Result<Self, DurableError> {
188 let aws_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
189 .load()
190 .await;
191 Self::from_aws_config(&aws_config)
192 }
193
194 pub fn from_aws_config(aws_config: &aws_config::SdkConfig) -> Result<Self, DurableError> {
200 let credentials_provider = aws_config
201 .credentials_provider()
202 .ok_or_else(|| DurableError::Configuration {
203 message: "No credentials provider configured in AWS SDK config".to_string(),
204 })?
205 .clone();
206
207 Ok(Self {
208 http_client: reqwest::Client::new(),
209 credentials_provider: Arc::from(credentials_provider),
210 config: LambdaClientConfig::from_aws_config(aws_config),
211 })
212 }
213
214 pub fn from_aws_config_with_user_agent(
221 aws_config: &aws_config::SdkConfig,
222 sdk_name: &str,
223 sdk_version: &str,
224 ) -> Result<Self, DurableError> {
225 let credentials_provider = aws_config
226 .credentials_provider()
227 .ok_or_else(|| DurableError::Configuration {
228 message: "No credentials provider configured in AWS SDK config".to_string(),
229 })?
230 .clone();
231
232 let user_agent = format!("{}/{}", sdk_name, sdk_version);
233 let mut headers = reqwest::header::HeaderMap::new();
234 if let Ok(value) = reqwest::header::HeaderValue::from_str(&user_agent) {
235 headers.insert(reqwest::header::USER_AGENT, value);
236 }
237
238 let http_client = reqwest::Client::builder()
239 .default_headers(headers)
240 .build()
241 .map_err(|e| DurableError::Configuration {
242 message: format!("Failed to build HTTP client: {}", e),
243 })?;
244
245 Ok(Self {
246 http_client,
247 credentials_provider: Arc::from(credentials_provider),
248 config: LambdaClientConfig::from_aws_config(aws_config),
249 })
250 }
251
252 pub fn with_config(
254 credentials_provider: Arc<dyn ProvideCredentials>,
255 config: LambdaClientConfig,
256 ) -> Self {
257 Self {
258 http_client: reqwest::Client::new(),
259 credentials_provider,
260 config,
261 }
262 }
263
264 fn endpoint_url(&self) -> String {
266 self.config
267 .endpoint_url
268 .clone()
269 .unwrap_or_else(|| format!("https://lambda.{}.amazonaws.com", self.config.region))
270 }
271
272 async fn sign_request(
274 &self,
275 method: &str,
276 uri: &str,
277 body: &[u8],
278 ) -> Result<Vec<(String, String)>, DurableError> {
279 let credentials = self
280 .credentials_provider
281 .provide_credentials()
282 .await
283 .map_err(|e| DurableError::Checkpoint {
284 message: format!("Failed to get AWS credentials: {}", e),
285 is_retriable: true,
286 aws_error: None,
287 })?;
288
289 let identity = credentials.into();
290 let signing_settings = SigningSettings::default();
291 let signing_params = v4::SigningParams::builder()
292 .identity(&identity)
293 .region(&self.config.region)
294 .name("lambda")
295 .time(SystemTime::now())
296 .settings(signing_settings)
297 .build()
298 .map_err(|e| DurableError::Checkpoint {
299 message: format!("Failed to build signing params: {}", e),
300 is_retriable: false,
301 aws_error: None,
302 })?;
303
304 let signable_request = SignableRequest::new(
305 method,
306 uri,
307 std::iter::empty::<(&str, &str)>(),
308 SignableBody::Bytes(body),
309 )
310 .map_err(|e| DurableError::Checkpoint {
311 message: format!("Failed to create signable request: {}", e),
312 is_retriable: false,
313 aws_error: None,
314 })?;
315
316 let (signing_instructions, _signature) = sign(signable_request, &signing_params.into())
317 .map_err(|e| DurableError::Checkpoint {
318 message: format!("Failed to sign request: {}", e),
319 is_retriable: false,
320 aws_error: None,
321 })?
322 .into_parts();
323
324 let mut temp_request = http::Request::builder()
326 .method(method)
327 .uri(uri)
328 .body(())
329 .map_err(|e| DurableError::Checkpoint {
330 message: format!("Failed to build temp request: {}", e),
331 is_retriable: false,
332 aws_error: None,
333 })?;
334
335 signing_instructions.apply_to_request_http1x(&mut temp_request);
336
337 let headers: Vec<(String, String)> = temp_request
339 .headers()
340 .iter()
341 .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
342 .collect();
343
344 Ok(headers)
345 }
346}
347
348#[derive(Debug, Clone, Serialize)]
350struct CheckpointRequestBody {
351 #[serde(rename = "CheckpointToken")]
352 checkpoint_token: String,
353 #[serde(rename = "Updates")]
354 updates: Vec<OperationUpdate>,
355}
356
357#[derive(Debug, Clone, Serialize)]
359struct GetOperationsRequestBody {
360 #[serde(rename = "NextMarker")]
361 next_marker: String,
362}
363
364#[async_trait]
365impl DurableServiceClient for LambdaDurableServiceClient {
366 async fn checkpoint(
367 &self,
368 durable_execution_arn: &str,
369 checkpoint_token: &str,
370 operations: Vec<OperationUpdate>,
371 ) -> Result<CheckpointResponse, DurableError> {
372 let request_body = CheckpointRequestBody {
373 checkpoint_token: checkpoint_token.to_string(),
374 updates: operations,
375 };
376
377 let body = serde_json::to_vec(&request_body).map_err(|e| DurableError::SerDes {
378 message: format!("Failed to serialize checkpoint request: {}", e),
379 })?;
380
381 let encoded_arn = urlencoding::encode(durable_execution_arn);
383 let uri = format!(
384 "{}/2025-12-01/durable-executions/{}/checkpoint",
385 self.endpoint_url(),
386 encoded_arn
387 );
388
389 let signed_headers = self.sign_request("POST", &uri, &body).await?;
391
392 let mut request = self
394 .http_client
395 .post(&uri)
396 .header("Content-Type", "application/json")
397 .body(body);
398
399 for (name, value) in signed_headers {
400 request = request.header(&name, &value);
401 }
402
403 let response = request.send().await.map_err(|e| DurableError::Checkpoint {
404 message: format!("HTTP request failed: {}", e),
405 is_retriable: e.is_timeout() || e.is_connect(),
406 aws_error: None,
407 })?;
408
409 let status = response.status();
410 let response_body = response
411 .bytes()
412 .await
413 .map_err(|e| DurableError::Checkpoint {
414 message: format!("Failed to read response body: {}", e),
415 is_retriable: true,
416 aws_error: None,
417 })?;
418
419 if !status.is_success() {
420 let error_message = String::from_utf8_lossy(&response_body);
421
422 if status.as_u16() == 413
427 || error_message.contains("RequestEntityTooLarge")
428 || error_message.contains("Payload size exceeded")
429 || error_message.contains("Request too large")
430 {
431 return Err(DurableError::SizeLimit {
432 message: format!("Checkpoint payload size exceeded: {}", error_message),
433 actual_size: None,
434 max_size: None,
435 });
436 }
437
438 if status.as_u16() == 404
441 || error_message.contains("ResourceNotFoundException")
442 || error_message.contains("not found")
443 {
444 return Err(DurableError::ResourceNotFound {
445 message: format!("Durable execution not found: {}", error_message),
446 resource_id: Some(durable_execution_arn.to_string()),
447 });
448 }
449
450 if status.as_u16() == 429
453 || error_message.contains("ThrottlingException")
454 || error_message.contains("TooManyRequestsException")
455 || error_message.contains("Rate exceeded")
456 {
457 return Err(DurableError::Throttling {
458 message: format!("Rate limit exceeded: {}", error_message),
459 retry_after_ms: None, });
461 }
462
463 let is_invalid_token = error_message.contains("Invalid checkpoint token");
467
468 let is_retriable = status.is_server_error() || is_invalid_token;
473
474 return Err(DurableError::Checkpoint {
475 message: format!("Checkpoint API returned {}: {}", status, error_message),
476 is_retriable,
477 aws_error: Some(AwsError {
478 code: if is_invalid_token {
479 "InvalidParameterValueException".to_string()
480 } else {
481 status.to_string()
482 },
483 message: error_message.to_string(),
484 request_id: None,
485 }),
486 });
487 }
488
489 let checkpoint_response: CheckpointResponse = serde_json::from_slice(&response_body)
490 .map_err(|e| DurableError::SerDes {
491 message: format!("Failed to deserialize checkpoint response: {}", e),
492 })?;
493
494 Ok(checkpoint_response)
495 }
496
497 async fn get_operations(
498 &self,
499 durable_execution_arn: &str,
500 next_marker: &str,
501 ) -> Result<GetOperationsResponse, DurableError> {
502 let request_body = GetOperationsRequestBody {
503 next_marker: next_marker.to_string(),
504 };
505
506 let body = serde_json::to_vec(&request_body).map_err(|e| DurableError::SerDes {
507 message: format!("Failed to serialize get_operations request: {}", e),
508 })?;
509
510 let encoded_arn = urlencoding::encode(durable_execution_arn);
512 let uri = format!(
513 "{}/2025-12-01/durable-executions/{}/state",
514 self.endpoint_url(),
515 encoded_arn
516 );
517
518 let signed_headers = self.sign_request("POST", &uri, &body).await?;
520
521 let mut request = self
523 .http_client
524 .post(&uri)
525 .header("Content-Type", "application/json")
526 .body(body);
527
528 for (name, value) in signed_headers {
529 request = request.header(&name, &value);
530 }
531
532 let response = request.send().await.map_err(|e| DurableError::Invocation {
533 message: format!("HTTP request failed: {}", e),
534 termination_reason: crate::error::TerminationReason::InvocationError,
535 })?;
536
537 let status = response.status();
538 let response_body = response
539 .bytes()
540 .await
541 .map_err(|e| DurableError::Invocation {
542 message: format!("Failed to read response body: {}", e),
543 termination_reason: crate::error::TerminationReason::InvocationError,
544 })?;
545
546 if !status.is_success() {
547 let error_message = String::from_utf8_lossy(&response_body);
548
549 if status.as_u16() == 404
552 || error_message.contains("ResourceNotFoundException")
553 || error_message.contains("not found")
554 {
555 return Err(DurableError::ResourceNotFound {
556 message: format!("Durable execution not found: {}", error_message),
557 resource_id: Some(durable_execution_arn.to_string()),
558 });
559 }
560
561 if status.as_u16() == 429
564 || error_message.contains("ThrottlingException")
565 || error_message.contains("TooManyRequestsException")
566 || error_message.contains("Rate exceeded")
567 {
568 return Err(DurableError::Throttling {
569 message: format!("Rate limit exceeded: {}", error_message),
570 retry_after_ms: None,
571 });
572 }
573
574 return Err(DurableError::Invocation {
575 message: format!(
576 "GetDurableExecutionState API returned {}: {}",
577 status, error_message
578 ),
579 termination_reason: crate::error::TerminationReason::InvocationError,
580 });
581 }
582
583 let operations_response: GetOperationsResponse = serde_json::from_slice(&response_body)
584 .map_err(|e| DurableError::SerDes {
585 message: format!("Failed to deserialize get_operations response: {}", e),
586 })?;
587
588 Ok(operations_response)
589 }
590}
591
/// Test double for [`DurableServiceClient`].
///
/// Queued responses are handed out first-in, first-out. Once a queue is empty the mock
/// returns a fixed success: token `"mock-token"` for checkpoints, and an empty page with no
/// marker for get-operations.
#[cfg(test)]
#[derive(Default)]
pub struct MockDurableServiceClient {
    checkpoint_responses: std::sync::Mutex<Vec<Result<CheckpointResponse, DurableError>>>,
    get_operations_responses: std::sync::Mutex<Vec<Result<GetOperationsResponse, DurableError>>>,
}

#[cfg(test)]
impl MockDurableServiceClient {
    /// Creates a mock with empty response queues.
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `response` to the checkpoint queue.
    pub fn with_checkpoint_response(
        self,
        response: Result<CheckpointResponse, DurableError>,
    ) -> Self {
        self.checkpoint_responses.lock().unwrap().push(response);
        self
    }

    /// Appends `response` to the get-operations queue.
    pub fn with_get_operations_response(
        self,
        response: Result<GetOperationsResponse, DurableError>,
    ) -> Self {
        self.get_operations_responses.lock().unwrap().push(response);
        self
    }

    /// Appends `count` successful checkpoint responses with tokens `token-0`, `token-1`, ...
    pub fn with_checkpoint_responses(self, count: usize) -> Self {
        let generated = (0..count).map(|index| {
            Ok(CheckpointResponse {
                checkpoint_token: format!("token-{}", index),
                new_execution_state: None,
            })
        });
        self.checkpoint_responses.lock().unwrap().extend(generated);
        self
    }
}

#[cfg(test)]
#[async_trait]
impl DurableServiceClient for MockDurableServiceClient {
    async fn checkpoint(
        &self,
        _durable_execution_arn: &str,
        _checkpoint_token: &str,
        _operations: Vec<OperationUpdate>,
    ) -> Result<CheckpointResponse, DurableError> {
        let mut queued = self.checkpoint_responses.lock().unwrap();
        if !queued.is_empty() {
            return queued.remove(0);
        }
        Ok(CheckpointResponse {
            checkpoint_token: "mock-token".to_string(),
            new_execution_state: None,
        })
    }

    async fn get_operations(
        &self,
        _durable_execution_arn: &str,
        _next_marker: &str,
    ) -> Result<GetOperationsResponse, DurableError> {
        let mut queued = self.get_operations_responses.lock().unwrap();
        if !queued.is_empty() {
            return queued.remove(0);
        }
        Ok(GetOperationsResponse {
            operations: Vec::new(),
            next_marker: None,
        })
    }
}
681
682pub type SharedDurableServiceClient = Arc<dyn DurableServiceClient>;
684
#[cfg(test)]
mod tests {
    use super::*;
    use crate::operation::OperationType;

    /// Execution ARN passed by the mock-client tests; the mock ignores its value.
    const TEST_ARN: &str = "arn:aws:lambda:us-east-1:123456789012:function:test";

    /// Sends an empty checkpoint with `token` through `client`.
    async fn run_checkpoint(
        client: &MockDurableServiceClient,
        token: &str,
    ) -> Result<CheckpointResponse, DurableError> {
        client.checkpoint(TEST_ARN, token, vec![]).await
    }

    /// Requests the operations page after `"marker-123"` through `client`.
    async fn run_get_operations(
        client: &MockDurableServiceClient,
    ) -> Result<GetOperationsResponse, DurableError> {
        client.get_operations(TEST_ARN, "marker-123").await
    }

    /// Builds an SDK config with static test credentials in `us-east-1`.
    fn test_sdk_config() -> aws_config::SdkConfig {
        let provider = aws_credential_types::provider::SharedCredentialsProvider::new(
            aws_credential_types::Credentials::new("test-key", "test-secret", None, None, "test"),
        );
        aws_config::SdkConfig::builder()
            .credentials_provider(provider)
            .region(aws_config::Region::new("us-east-1"))
            .build()
    }

    #[test]
    fn test_checkpoint_response_serialization() {
        let json = serde_json::to_string(&CheckpointResponse::new("token-123")).unwrap();
        assert!(json.contains(r#""CheckpointToken":"token-123""#));
    }

    #[test]
    fn test_checkpoint_response_deserialization() {
        let parsed: CheckpointResponse =
            serde_json::from_str(r#"{"CheckpointToken": "token-456"}"#).unwrap();
        assert_eq!(parsed.checkpoint_token, "token-456");
    }

    #[test]
    fn test_get_operations_response_serialization() {
        let page = GetOperationsResponse {
            operations: vec![Operation::new("op-1", OperationType::Step)],
            next_marker: Some("marker-123".to_string()),
        };
        let json = serde_json::to_string(&page).unwrap();
        assert!(json.contains(r#""Operations""#));
        assert!(json.contains(r#""NextMarker":"marker-123""#));
    }

    #[test]
    fn test_get_operations_response_deserialization() {
        let payload = r#"{
            "Operations": [
                {
                    "Id": "op-1",
                    "Type": "STEP",
                    "Status": "SUCCEEDED"
                }
            ],
            "NextMarker": "marker-456"
        }"#;
        let page: GetOperationsResponse = serde_json::from_str(payload).unwrap();
        assert_eq!(page.operations.len(), 1);
        assert_eq!(page.operations[0].operation_id, "op-1");
        assert_eq!(page.next_marker, Some("marker-456".to_string()));
    }

    #[test]
    fn test_get_operations_response_without_marker() {
        let page: GetOperationsResponse = serde_json::from_str(
            r#"{
            "Operations": []
        }"#,
        )
        .unwrap();
        assert!(page.operations.is_empty());
        assert!(page.next_marker.is_none());
    }

    #[test]
    fn test_lambda_client_config_default() {
        let config = LambdaClientConfig::default();
        assert_eq!(config.region, "us-east-1");
        assert!(config.endpoint_url.is_none());
    }

    #[test]
    fn test_lambda_client_config_with_region() {
        assert_eq!(LambdaClientConfig::with_region("us-west-2").region, "us-west-2");
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint() {
        let client = MockDurableServiceClient::new();
        let response = run_checkpoint(&client, "token-123")
            .await
            .expect("default mock checkpoint should succeed");
        assert_eq!(response.checkpoint_token, "mock-token");
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_custom_response() {
        let client = MockDurableServiceClient::new()
            .with_checkpoint_response(Ok(CheckpointResponse::new("custom-token")));
        let response = run_checkpoint(&client, "token-123")
            .await
            .expect("queued success should be returned");
        assert_eq!(response.checkpoint_token, "custom-token");
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_error() {
        let client = MockDurableServiceClient::new()
            .with_checkpoint_response(Err(DurableError::checkpoint_retriable("Test error")));
        let err = run_checkpoint(&client, "token-123").await.unwrap_err();
        assert!(err.is_retriable());
    }

    #[tokio::test]
    async fn test_mock_client_get_operations() {
        let client = MockDurableServiceClient::new();
        let page = run_get_operations(&client)
            .await
            .expect("default mock page should succeed");
        assert!(page.operations.is_empty());
        assert!(page.next_marker.is_none());
    }

    #[tokio::test]
    async fn test_mock_client_get_operations_with_custom_response() {
        let client =
            MockDurableServiceClient::new().with_get_operations_response(Ok(GetOperationsResponse {
                operations: vec![Operation::new("op-1", OperationType::Step)],
                next_marker: Some("next-marker".to_string()),
            }));
        let page = run_get_operations(&client)
            .await
            .expect("queued page should be returned");
        assert_eq!(page.operations.len(), 1);
        assert_eq!(page.next_marker, Some("next-marker".to_string()));
    }

    #[test]
    fn test_checkpoint_request_body_serialization() {
        let json = serde_json::to_string(&CheckpointRequestBody {
            checkpoint_token: "token-123".to_string(),
            updates: vec![OperationUpdate::start("op-1", OperationType::Step)],
        })
        .unwrap();
        assert!(json.contains(r#""CheckpointToken":"token-123""#));
        assert!(json.contains(r#""Updates""#));
    }

    #[test]
    fn test_get_operations_request_body_serialization() {
        let json = serde_json::to_string(&GetOperationsRequestBody {
            next_marker: "marker-123".to_string(),
        })
        .unwrap();
        assert!(json.contains(r#""NextMarker":"marker-123""#));
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_invalid_token_error() {
        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(
            DurableError::Checkpoint {
                message: "Checkpoint API returned 400: Invalid checkpoint token".to_string(),
                is_retriable: true,
                aws_error: Some(AwsError {
                    code: "InvalidParameterValueException".to_string(),
                    message: "Invalid checkpoint token: token has been consumed".to_string(),
                    request_id: None,
                }),
            },
        ));

        let err = run_checkpoint(&client, "consumed-token").await.unwrap_err();
        assert!(err.is_retriable());
        assert!(err.is_invalid_checkpoint_token());
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_size_limit_error() {
        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(
            DurableError::SizeLimit {
                message: "Checkpoint payload size exceeded".to_string(),
                actual_size: Some(7_000_000),
                max_size: Some(6_000_000),
            },
        ));

        let err = run_checkpoint(&client, "token-123").await.unwrap_err();
        assert!(err.is_size_limit());
        assert!(!err.is_retriable());
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_throttling_error() {
        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(
            DurableError::Throttling {
                message: "Rate limit exceeded".to_string(),
                retry_after_ms: Some(5000),
            },
        ));

        let err = run_checkpoint(&client, "token-123").await.unwrap_err();
        assert!(err.is_throttling());
        assert_eq!(err.get_retry_after_ms(), Some(5000));
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_resource_not_found_error() {
        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(
            DurableError::ResourceNotFound {
                message: "Durable execution not found".to_string(),
                resource_id: Some(TEST_ARN.to_string()),
            },
        ));

        let err = run_checkpoint(&client, "token-123").await.unwrap_err();
        assert!(err.is_resource_not_found());
        assert!(!err.is_retriable());
    }

    #[tokio::test]
    async fn test_mock_client_get_operations_with_throttling_error() {
        let client = MockDurableServiceClient::new().with_get_operations_response(Err(
            DurableError::Throttling {
                message: "Rate limit exceeded".to_string(),
                retry_after_ms: None,
            },
        ));

        let err = run_get_operations(&client).await.unwrap_err();
        assert!(err.is_throttling());
    }

    #[tokio::test]
    async fn test_mock_client_get_operations_with_resource_not_found_error() {
        let client = MockDurableServiceClient::new().with_get_operations_response(Err(
            DurableError::ResourceNotFound {
                message: "Durable execution not found".to_string(),
                resource_id: Some(TEST_ARN.to_string()),
            },
        ));

        let err = run_get_operations(&client).await.unwrap_err();
        assert!(err.is_resource_not_found());
    }

    #[test]
    fn test_from_aws_config_with_user_agent_constructs_client() {
        let config = test_sdk_config();
        let client =
            LambdaDurableServiceClient::from_aws_config_with_user_agent(&config, "my-sdk", "1.0.0")
                .unwrap();
        assert_eq!(client.config.region, "us-east-1");
    }

    #[test]
    fn test_from_aws_config_unchanged() {
        let config = test_sdk_config();
        let client = LambdaDurableServiceClient::from_aws_config(&config).unwrap();
        assert_eq!(client.config.region, "us-east-1");
    }

    #[tokio::test]
    async fn test_user_agent_header_is_sent() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        // Minimal one-shot HTTP server that records the raw request and replies 400.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server_handle = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut buf = vec![0u8; 4096];
            let read = socket.read(&mut buf).await.unwrap();
            let raw_request = String::from_utf8_lossy(&buf[..read]).to_string();

            let reply = "HTTP/1.1 400 Bad Request\r\nContent-Length: 2\r\n\r\n{}";
            let _ = socket.write_all(reply.as_bytes()).await;
            let _ = socket.shutdown().await;

            raw_request
        });

        let (sdk_name, sdk_version) = ("test-sdk", "2.3.4");
        let expected_agent = format!("{}/{}", sdk_name, sdk_version);

        let config = test_sdk_config();
        let mut client = LambdaDurableServiceClient::from_aws_config_with_user_agent(
            &config,
            sdk_name,
            sdk_version,
        )
        .unwrap();
        client.config.endpoint_url = Some(format!("http://{}", addr));

        // The 400 reply makes this fail; only the captured request matters here.
        let _ = client
            .checkpoint("arn:test:execution", "token", vec![])
            .await;

        let captured_request = server_handle.await.unwrap();
        assert!(
            captured_request.contains(&expected_agent),
            "Expected User-Agent '{}' in request headers, got:\n{}",
            expected_agent,
            captured_request
        );
    }
}