1use std::sync::Arc;
9use std::time::SystemTime;
10
11use async_trait::async_trait;
12use aws_credential_types::provider::ProvideCredentials;
13use aws_sigv4::http_request::{sign, SignableBody, SignableRequest, SigningSettings};
14use aws_sigv4::sign::v4;
15use serde::{Deserialize, Serialize};
16
17use crate::error::{AwsError, DurableError};
18use crate::operation::{Operation, OperationUpdate};
19
20#[async_trait]
25pub trait DurableServiceClient: Send + Sync {
26 async fn checkpoint(
38 &self,
39 durable_execution_arn: &str,
40 checkpoint_token: &str,
41 operations: Vec<OperationUpdate>,
42 ) -> Result<CheckpointResponse, DurableError>;
43
44 async fn get_operations(
55 &self,
56 durable_execution_arn: &str,
57 next_marker: &str,
58 ) -> Result<GetOperationsResponse, DurableError>;
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize, Default)]
63pub struct CheckpointResponse {
64 #[serde(rename = "CheckpointToken", default)]
66 pub checkpoint_token: String,
67
68 #[serde(
71 rename = "NewExecutionState",
72 skip_serializing_if = "Option::is_none",
73 default
74 )]
75 pub new_execution_state: Option<NewExecutionState>,
76}
77
78impl CheckpointResponse {
79 pub fn new(checkpoint_token: impl Into<String>) -> Self {
81 Self {
82 checkpoint_token: checkpoint_token.into(),
83 new_execution_state: None,
84 }
85 }
86}
87
88#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct NewExecutionState {
91 #[serde(rename = "Operations", default)]
93 pub operations: Vec<Operation>,
94
95 #[serde(rename = "NextMarker", skip_serializing_if = "Option::is_none")]
97 pub next_marker: Option<String>,
98}
99
100impl NewExecutionState {
101 pub fn find_operation(&self, operation_id: &str) -> Option<&Operation> {
111 self.operations
112 .iter()
113 .find(|op| op.operation_id == operation_id)
114 }
115}
116
117#[derive(Debug, Clone, Serialize, Deserialize)]
119pub struct GetOperationsResponse {
120 #[serde(rename = "Operations")]
122 pub operations: Vec<Operation>,
123
124 #[serde(rename = "NextMarker", skip_serializing_if = "Option::is_none")]
126 pub next_marker: Option<String>,
127}
128
129#[derive(Debug, Clone)]
131pub struct LambdaClientConfig {
132 pub region: String,
134 pub endpoint_url: Option<String>,
136}
137
138impl Default for LambdaClientConfig {
139 fn default() -> Self {
140 Self {
141 region: "us-east-1".to_string(),
142 endpoint_url: None,
143 }
144 }
145}
146
147impl LambdaClientConfig {
148 pub fn with_region(region: impl Into<String>) -> Self {
150 Self {
151 region: region.into(),
152 endpoint_url: None,
153 }
154 }
155
156 pub fn from_aws_config(config: &aws_config::SdkConfig) -> Self {
158 Self {
159 region: config
160 .region()
161 .map(|r| r.to_string())
162 .unwrap_or_else(|| "us-east-1".to_string()),
163 endpoint_url: None,
164 }
165 }
166}
167
168pub struct LambdaDurableServiceClient {
173 http_client: reqwest::Client,
175 credentials_provider: Arc<dyn ProvideCredentials>,
177 config: LambdaClientConfig,
179}
180
181impl LambdaDurableServiceClient {
182 pub async fn from_env() -> Self {
184 let aws_config = aws_config::defaults(aws_config::BehaviorVersion::latest())
185 .load()
186 .await;
187 Self::from_aws_config(&aws_config)
188 }
189
190 pub fn from_aws_config(aws_config: &aws_config::SdkConfig) -> Self {
192 let credentials_provider = aws_config
193 .credentials_provider()
194 .expect("No credentials provider configured")
195 .clone();
196
197 Self {
198 http_client: reqwest::Client::new(),
199 credentials_provider: Arc::from(credentials_provider),
200 config: LambdaClientConfig::from_aws_config(aws_config),
201 }
202 }
203
204 pub fn from_aws_config_with_user_agent(
207 aws_config: &aws_config::SdkConfig,
208 sdk_name: &str,
209 sdk_version: &str,
210 ) -> Self {
211 let credentials_provider = aws_config
212 .credentials_provider()
213 .expect("No credentials provider configured")
214 .clone();
215
216 let user_agent = format!("{}/{}", sdk_name, sdk_version);
217 let mut headers = reqwest::header::HeaderMap::new();
218 if let Ok(value) = reqwest::header::HeaderValue::from_str(&user_agent) {
219 headers.insert(reqwest::header::USER_AGENT, value);
220 }
221
222 let http_client = reqwest::Client::builder()
223 .default_headers(headers)
224 .build()
225 .unwrap_or_else(|_| reqwest::Client::new());
226
227 Self {
228 http_client,
229 credentials_provider: Arc::from(credentials_provider),
230 config: LambdaClientConfig::from_aws_config(aws_config),
231 }
232 }
233
234 pub fn with_config(
236 credentials_provider: Arc<dyn ProvideCredentials>,
237 config: LambdaClientConfig,
238 ) -> Self {
239 Self {
240 http_client: reqwest::Client::new(),
241 credentials_provider,
242 config,
243 }
244 }
245
246 pub fn new(_lambda_client: aws_sdk_lambda::Client) -> Self {
252 panic!(
257 "LambdaDurableServiceClient::new() is deprecated. \
258 Use LambdaDurableServiceClient::from_env() or from_aws_config() instead."
259 );
260 }
261
262 fn endpoint_url(&self) -> String {
264 self.config
265 .endpoint_url
266 .clone()
267 .unwrap_or_else(|| format!("https://lambda.{}.amazonaws.com", self.config.region))
268 }
269
270 async fn sign_request(
272 &self,
273 method: &str,
274 uri: &str,
275 body: &[u8],
276 ) -> Result<Vec<(String, String)>, DurableError> {
277 let credentials = self
278 .credentials_provider
279 .provide_credentials()
280 .await
281 .map_err(|e| DurableError::Checkpoint {
282 message: format!("Failed to get AWS credentials: {}", e),
283 is_retriable: true,
284 aws_error: None,
285 })?;
286
287 let identity = credentials.into();
288 let signing_settings = SigningSettings::default();
289 let signing_params = v4::SigningParams::builder()
290 .identity(&identity)
291 .region(&self.config.region)
292 .name("lambda")
293 .time(SystemTime::now())
294 .settings(signing_settings)
295 .build()
296 .map_err(|e| DurableError::Checkpoint {
297 message: format!("Failed to build signing params: {}", e),
298 is_retriable: false,
299 aws_error: None,
300 })?;
301
302 let signable_request = SignableRequest::new(
303 method,
304 uri,
305 std::iter::empty::<(&str, &str)>(),
306 SignableBody::Bytes(body),
307 )
308 .map_err(|e| DurableError::Checkpoint {
309 message: format!("Failed to create signable request: {}", e),
310 is_retriable: false,
311 aws_error: None,
312 })?;
313
314 let (signing_instructions, _signature) = sign(signable_request, &signing_params.into())
315 .map_err(|e| DurableError::Checkpoint {
316 message: format!("Failed to sign request: {}", e),
317 is_retriable: false,
318 aws_error: None,
319 })?
320 .into_parts();
321
322 let mut temp_request = http::Request::builder()
324 .method(method)
325 .uri(uri)
326 .body(())
327 .map_err(|e| DurableError::Checkpoint {
328 message: format!("Failed to build temp request: {}", e),
329 is_retriable: false,
330 aws_error: None,
331 })?;
332
333 signing_instructions.apply_to_request_http1x(&mut temp_request);
334
335 let headers: Vec<(String, String)> = temp_request
337 .headers()
338 .iter()
339 .map(|(name, value)| (name.to_string(), value.to_str().unwrap_or("").to_string()))
340 .collect();
341
342 Ok(headers)
343 }
344}
345
346#[derive(Debug, Clone, Serialize)]
348struct CheckpointRequestBody {
349 #[serde(rename = "CheckpointToken")]
350 checkpoint_token: String,
351 #[serde(rename = "Updates")]
352 updates: Vec<OperationUpdate>,
353}
354
355#[derive(Debug, Clone, Serialize)]
357struct GetOperationsRequestBody {
358 #[serde(rename = "NextMarker")]
359 next_marker: String,
360}
361
362#[async_trait]
363impl DurableServiceClient for LambdaDurableServiceClient {
364 async fn checkpoint(
365 &self,
366 durable_execution_arn: &str,
367 checkpoint_token: &str,
368 operations: Vec<OperationUpdate>,
369 ) -> Result<CheckpointResponse, DurableError> {
370 let request_body = CheckpointRequestBody {
371 checkpoint_token: checkpoint_token.to_string(),
372 updates: operations,
373 };
374
375 let body = serde_json::to_vec(&request_body).map_err(|e| DurableError::SerDes {
376 message: format!("Failed to serialize checkpoint request: {}", e),
377 })?;
378
379 let encoded_arn = urlencoding::encode(durable_execution_arn);
381 let uri = format!(
382 "{}/2025-12-01/durable-executions/{}/checkpoint",
383 self.endpoint_url(),
384 encoded_arn
385 );
386
387 let signed_headers = self.sign_request("POST", &uri, &body).await?;
389
390 let mut request = self
392 .http_client
393 .post(&uri)
394 .header("Content-Type", "application/json")
395 .body(body);
396
397 for (name, value) in signed_headers {
398 request = request.header(&name, &value);
399 }
400
401 let response = request.send().await.map_err(|e| DurableError::Checkpoint {
402 message: format!("HTTP request failed: {}", e),
403 is_retriable: e.is_timeout() || e.is_connect(),
404 aws_error: None,
405 })?;
406
407 let status = response.status();
408 let response_body = response
409 .bytes()
410 .await
411 .map_err(|e| DurableError::Checkpoint {
412 message: format!("Failed to read response body: {}", e),
413 is_retriable: true,
414 aws_error: None,
415 })?;
416
417 if !status.is_success() {
418 let error_message = String::from_utf8_lossy(&response_body);
419
420 if status.as_u16() == 413
425 || error_message.contains("RequestEntityTooLarge")
426 || error_message.contains("Payload size exceeded")
427 || error_message.contains("Request too large")
428 {
429 return Err(DurableError::SizeLimit {
430 message: format!("Checkpoint payload size exceeded: {}", error_message),
431 actual_size: None,
432 max_size: None,
433 });
434 }
435
436 if status.as_u16() == 404
439 || error_message.contains("ResourceNotFoundException")
440 || error_message.contains("not found")
441 {
442 return Err(DurableError::ResourceNotFound {
443 message: format!("Durable execution not found: {}", error_message),
444 resource_id: Some(durable_execution_arn.to_string()),
445 });
446 }
447
448 if status.as_u16() == 429
451 || error_message.contains("ThrottlingException")
452 || error_message.contains("TooManyRequestsException")
453 || error_message.contains("Rate exceeded")
454 {
455 return Err(DurableError::Throttling {
456 message: format!("Rate limit exceeded: {}", error_message),
457 retry_after_ms: None, });
459 }
460
461 let is_invalid_token = error_message.contains("Invalid checkpoint token");
465
466 let is_retriable = status.is_server_error() || is_invalid_token;
471
472 return Err(DurableError::Checkpoint {
473 message: format!("Checkpoint API returned {}: {}", status, error_message),
474 is_retriable,
475 aws_error: Some(AwsError {
476 code: if is_invalid_token {
477 "InvalidParameterValueException".to_string()
478 } else {
479 status.to_string()
480 },
481 message: error_message.to_string(),
482 request_id: None,
483 }),
484 });
485 }
486
487 let checkpoint_response: CheckpointResponse = serde_json::from_slice(&response_body)
488 .map_err(|e| DurableError::SerDes {
489 message: format!("Failed to deserialize checkpoint response: {}", e),
490 })?;
491
492 Ok(checkpoint_response)
493 }
494
495 async fn get_operations(
496 &self,
497 durable_execution_arn: &str,
498 next_marker: &str,
499 ) -> Result<GetOperationsResponse, DurableError> {
500 let request_body = GetOperationsRequestBody {
501 next_marker: next_marker.to_string(),
502 };
503
504 let body = serde_json::to_vec(&request_body).map_err(|e| DurableError::SerDes {
505 message: format!("Failed to serialize get_operations request: {}", e),
506 })?;
507
508 let encoded_arn = urlencoding::encode(durable_execution_arn);
510 let uri = format!(
511 "{}/2025-12-01/durable-executions/{}/state",
512 self.endpoint_url(),
513 encoded_arn
514 );
515
516 let signed_headers = self.sign_request("POST", &uri, &body).await?;
518
519 let mut request = self
521 .http_client
522 .post(&uri)
523 .header("Content-Type", "application/json")
524 .body(body);
525
526 for (name, value) in signed_headers {
527 request = request.header(&name, &value);
528 }
529
530 let response = request.send().await.map_err(|e| DurableError::Invocation {
531 message: format!("HTTP request failed: {}", e),
532 termination_reason: crate::error::TerminationReason::InvocationError,
533 })?;
534
535 let status = response.status();
536 let response_body = response
537 .bytes()
538 .await
539 .map_err(|e| DurableError::Invocation {
540 message: format!("Failed to read response body: {}", e),
541 termination_reason: crate::error::TerminationReason::InvocationError,
542 })?;
543
544 if !status.is_success() {
545 let error_message = String::from_utf8_lossy(&response_body);
546
547 if status.as_u16() == 404
550 || error_message.contains("ResourceNotFoundException")
551 || error_message.contains("not found")
552 {
553 return Err(DurableError::ResourceNotFound {
554 message: format!("Durable execution not found: {}", error_message),
555 resource_id: Some(durable_execution_arn.to_string()),
556 });
557 }
558
559 if status.as_u16() == 429
562 || error_message.contains("ThrottlingException")
563 || error_message.contains("TooManyRequestsException")
564 || error_message.contains("Rate exceeded")
565 {
566 return Err(DurableError::Throttling {
567 message: format!("Rate limit exceeded: {}", error_message),
568 retry_after_ms: None,
569 });
570 }
571
572 return Err(DurableError::Invocation {
573 message: format!(
574 "GetDurableExecutionState API returned {}: {}",
575 status, error_message
576 ),
577 termination_reason: crate::error::TerminationReason::InvocationError,
578 });
579 }
580
581 let operations_response: GetOperationsResponse = serde_json::from_slice(&response_body)
582 .map_err(|e| DurableError::SerDes {
583 message: format!("Failed to deserialize get_operations response: {}", e),
584 })?;
585
586 Ok(operations_response)
587 }
588}
589
/// In-memory [`DurableServiceClient`] for unit tests.
///
/// Queued responses are returned in FIFO order. Once a queue is empty, the
/// mock returns a fixed success response: token `"mock-token"` for
/// `checkpoint`, and an empty page for `get_operations`.
#[cfg(test)]
#[derive(Default)]
pub struct MockDurableServiceClient {
    checkpoint_responses:
        std::sync::Mutex<std::collections::VecDeque<Result<CheckpointResponse, DurableError>>>,
    get_operations_responses:
        std::sync::Mutex<std::collections::VecDeque<Result<GetOperationsResponse, DurableError>>>,
}

#[cfg(test)]
impl MockDurableServiceClient {
    /// Creates a mock with empty response queues.
    pub fn new() -> Self {
        Self::default()
    }

    /// Queues one `checkpoint` result.
    pub fn with_checkpoint_response(
        self,
        response: Result<CheckpointResponse, DurableError>,
    ) -> Self {
        self.checkpoint_responses.lock().unwrap().push_back(response);
        self
    }

    /// Queues one `get_operations` result.
    pub fn with_get_operations_response(
        self,
        response: Result<GetOperationsResponse, DurableError>,
    ) -> Self {
        self.get_operations_responses
            .lock()
            .unwrap()
            .push_back(response);
        self
    }

    /// Queues `count` successful checkpoint responses with tokens
    /// `token-0` through `token-{count-1}`.
    pub fn with_checkpoint_responses(self, count: usize) -> Self {
        self.checkpoint_responses
            .lock()
            .unwrap()
            .extend((0..count).map(|i| {
                Ok(CheckpointResponse {
                    checkpoint_token: format!("token-{}", i),
                    new_execution_state: None,
                })
            }));
        self
    }
}

#[cfg(test)]
#[async_trait]
impl DurableServiceClient for MockDurableServiceClient {
    async fn checkpoint(
        &self,
        _durable_execution_arn: &str,
        _checkpoint_token: &str,
        _operations: Vec<OperationUpdate>,
    ) -> Result<CheckpointResponse, DurableError> {
        // The guard is a temporary dropped at the end of this statement; it is
        // never held across an await.
        let queued = self.checkpoint_responses.lock().unwrap().pop_front();
        queued.unwrap_or_else(|| {
            Ok(CheckpointResponse {
                checkpoint_token: "mock-token".to_string(),
                new_execution_state: None,
            })
        })
    }

    async fn get_operations(
        &self,
        _durable_execution_arn: &str,
        _next_marker: &str,
    ) -> Result<GetOperationsResponse, DurableError> {
        let queued = self.get_operations_responses.lock().unwrap().pop_front();
        queued.unwrap_or_else(|| {
            Ok(GetOperationsResponse {
                operations: Vec::new(),
                next_marker: None,
            })
        })
    }
}
672
673pub type SharedDurableServiceClient = Arc<dyn DurableServiceClient>;
675
#[cfg(test)]
mod tests {
    use super::*;
    use crate::operation::OperationType;

    /// ARN passed to the mock in every mock-client test.
    const TEST_ARN: &str = "arn:aws:lambda:us-east-1:123456789012:function:test";

    #[test]
    fn test_checkpoint_response_serialization() {
        let response = CheckpointResponse {
            checkpoint_token: "token-123".to_string(),
            new_execution_state: None,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains(r#""CheckpointToken":"token-123""#));
    }

    #[test]
    fn test_checkpoint_response_deserialization() {
        let json = r#"{"CheckpointToken": "token-456"}"#;
        let response: CheckpointResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.checkpoint_token, "token-456");
    }

    #[test]
    fn test_get_operations_response_serialization() {
        let response = GetOperationsResponse {
            operations: vec![Operation::new("op-1", OperationType::Step)],
            next_marker: Some("marker-123".to_string()),
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains(r#""Operations""#));
        assert!(json.contains(r#""NextMarker":"marker-123""#));
    }

    #[test]
    fn test_get_operations_response_deserialization() {
        let json = r#"{
            "Operations": [
                {
                    "Id": "op-1",
                    "Type": "STEP",
                    "Status": "SUCCEEDED"
                }
            ],
            "NextMarker": "marker-456"
        }"#;
        let response: GetOperationsResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.operations.len(), 1);
        assert_eq!(response.operations[0].operation_id, "op-1");
        assert_eq!(response.next_marker, Some("marker-456".to_string()));
    }

    #[test]
    fn test_get_operations_response_without_marker() {
        let json = r#"{
            "Operations": []
        }"#;
        let response: GetOperationsResponse = serde_json::from_str(json).unwrap();
        assert!(response.operations.is_empty());
        assert!(response.next_marker.is_none());
    }

    #[test]
    fn test_lambda_client_config_default() {
        let config = LambdaClientConfig::default();
        assert_eq!(config.region, "us-east-1");
        assert!(config.endpoint_url.is_none());
    }

    #[test]
    fn test_lambda_client_config_with_region() {
        let config = LambdaClientConfig::with_region("us-west-2");
        assert_eq!(config.region, "us-west-2");
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint() {
        let client = MockDurableServiceClient::new();
        let result = client.checkpoint(TEST_ARN, "token-123", vec![]).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().checkpoint_token, "mock-token");
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_custom_response() {
        let client =
            MockDurableServiceClient::new().with_checkpoint_response(Ok(CheckpointResponse {
                checkpoint_token: "custom-token".to_string(),
                new_execution_state: None,
            }));
        let result = client.checkpoint(TEST_ARN, "token-123", vec![]).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().checkpoint_token, "custom-token");
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_error() {
        let client = MockDurableServiceClient::new()
            .with_checkpoint_response(Err(DurableError::checkpoint_retriable("Test error")));
        let result = client.checkpoint(TEST_ARN, "token-123", vec![]).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().is_retriable());
    }

    #[tokio::test]
    async fn test_mock_client_get_operations() {
        let client = MockDurableServiceClient::new();
        let result = client.get_operations(TEST_ARN, "marker-123").await;
        assert!(result.is_ok());
        let response = result.unwrap();
        assert!(response.operations.is_empty());
        assert!(response.next_marker.is_none());
    }

    #[tokio::test]
    async fn test_mock_client_get_operations_with_custom_response() {
        let client = MockDurableServiceClient::new().with_get_operations_response(Ok(
            GetOperationsResponse {
                operations: vec![Operation::new("op-1", OperationType::Step)],
                next_marker: Some("next-marker".to_string()),
            },
        ));
        let result = client.get_operations(TEST_ARN, "marker-123").await;
        assert!(result.is_ok());
        let response = result.unwrap();
        assert_eq!(response.operations.len(), 1);
        assert_eq!(response.next_marker, Some("next-marker".to_string()));
    }

    #[test]
    fn test_checkpoint_request_body_serialization() {
        let request = CheckpointRequestBody {
            checkpoint_token: "token-123".to_string(),
            updates: vec![OperationUpdate::start("op-1", OperationType::Step)],
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains(r#""CheckpointToken":"token-123""#));
        assert!(json.contains(r#""Updates""#));
    }

    #[test]
    fn test_get_operations_request_body_serialization() {
        let request = GetOperationsRequestBody {
            next_marker: "marker-123".to_string(),
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains(r#""NextMarker":"marker-123""#));
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_invalid_token_error() {
        let error = DurableError::Checkpoint {
            message: "Checkpoint API returned 400: Invalid checkpoint token".to_string(),
            is_retriable: true,
            aws_error: Some(AwsError {
                code: "InvalidParameterValueException".to_string(),
                message: "Invalid checkpoint token: token has been consumed".to_string(),
                request_id: None,
            }),
        };

        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
        let result = client.checkpoint(TEST_ARN, "consumed-token", vec![]).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.is_retriable());
        assert!(err.is_invalid_checkpoint_token());
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_size_limit_error() {
        let error = DurableError::SizeLimit {
            message: "Checkpoint payload size exceeded".to_string(),
            actual_size: Some(7_000_000),
            max_size: Some(6_000_000),
        };

        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
        let result = client.checkpoint(TEST_ARN, "token-123", vec![]).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.is_size_limit());
        assert!(!err.is_retriable());
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_throttling_error() {
        let error = DurableError::Throttling {
            message: "Rate limit exceeded".to_string(),
            retry_after_ms: Some(5000),
        };

        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
        let result = client.checkpoint(TEST_ARN, "token-123", vec![]).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.is_throttling());
        assert_eq!(err.get_retry_after_ms(), Some(5000));
    }

    #[tokio::test]
    async fn test_mock_client_checkpoint_with_resource_not_found_error() {
        let error = DurableError::ResourceNotFound {
            message: "Durable execution not found".to_string(),
            resource_id: Some(TEST_ARN.to_string()),
        };

        let client = MockDurableServiceClient::new().with_checkpoint_response(Err(error));
        let result = client.checkpoint(TEST_ARN, "token-123", vec![]).await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.is_resource_not_found());
        assert!(!err.is_retriable());
    }

    #[tokio::test]
    async fn test_mock_client_get_operations_with_throttling_error() {
        let error = DurableError::Throttling {
            message: "Rate limit exceeded".to_string(),
            retry_after_ms: None,
        };

        let client = MockDurableServiceClient::new().with_get_operations_response(Err(error));
        let result = client.get_operations(TEST_ARN, "marker-123").await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.is_throttling());
    }

    #[tokio::test]
    async fn test_mock_client_get_operations_with_resource_not_found_error() {
        let error = DurableError::ResourceNotFound {
            message: "Durable execution not found".to_string(),
            resource_id: Some(TEST_ARN.to_string()),
        };

        let client = MockDurableServiceClient::new().with_get_operations_response(Err(error));
        let result = client.get_operations(TEST_ARN, "marker-123").await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(err.is_resource_not_found());
    }

    /// SDK config with static test credentials in `us-east-1`.
    fn test_sdk_config() -> aws_config::SdkConfig {
        let creds =
            aws_credential_types::Credentials::new("test-key", "test-secret", None, None, "test");
        let provider = aws_credential_types::provider::SharedCredentialsProvider::new(creds);
        aws_config::SdkConfig::builder()
            .credentials_provider(provider)
            .region(aws_config::Region::new("us-east-1"))
            .build()
    }

    #[test]
    fn test_from_aws_config_with_user_agent_constructs_client() {
        let config = test_sdk_config();
        let client =
            LambdaDurableServiceClient::from_aws_config_with_user_agent(&config, "my-sdk", "1.0.0");
        assert_eq!(client.config.region, "us-east-1");
    }

    #[test]
    fn test_from_aws_config_unchanged() {
        let config = test_sdk_config();
        let client = LambdaDurableServiceClient::from_aws_config(&config);
        assert_eq!(client.config.region, "us-east-1");
    }

    #[tokio::test]
    async fn test_user_agent_header_is_sent() {
        use tokio::io::{AsyncReadExt, AsyncWriteExt};
        use tokio::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();

        let server_handle = tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();

            // A single read() may return only part of the request, so keep
            // reading until the blank line ending the header section (or EOF).
            let mut raw = Vec::new();
            let mut chunk = [0u8; 4096];
            loop {
                let n = socket.read(&mut chunk).await.unwrap();
                if n == 0 {
                    break;
                }
                raw.extend_from_slice(&chunk[..n]);
                if raw.windows(4).any(|w| w == b"\r\n\r\n") {
                    break;
                }
            }
            let request = String::from_utf8_lossy(&raw).into_owned();

            let response = "HTTP/1.1 400 Bad Request\r\nContent-Length: 2\r\n\r\n{}";
            let _ = socket.write_all(response.as_bytes()).await;
            let _ = socket.shutdown().await;

            request
        });

        let sdk_name = "test-sdk";
        let sdk_version = "2.3.4";
        let user_agent = format!("{}/{}", sdk_name, sdk_version);

        let config = test_sdk_config();
        let mut client = LambdaDurableServiceClient::from_aws_config_with_user_agent(
            &config,
            sdk_name,
            sdk_version,
        );
        client.config.endpoint_url = Some(format!("http://{}", addr));

        // The call fails with the canned 400; only the captured request matters.
        let _ = client
            .checkpoint("arn:test:execution", "token", vec![])
            .await;

        let captured_request = server_handle.await.unwrap();
        assert!(
            captured_request.contains(&user_agent),
            "Expected User-Agent '{}' in request headers, got:\n{}",
            user_agent,
            captured_request
        );
    }
}