1use crate::client::{QueueProvider, SessionProvider};
216use crate::error::{ConfigurationError, QueueError, SerializationError};
217use crate::message::{
218 Message, MessageId, QueueName, ReceiptHandle, ReceivedMessage, SessionId, Timestamp,
219};
220use crate::provider::{AwsSqsConfig, ProviderType, SessionSupport};
221use async_trait::async_trait;
222use chrono::{DateTime, Duration, Utc};
223use hmac::{Hmac, Mac};
224use reqwest::Client as HttpClient;
225use sha2::{Digest, Sha256};
226use std::collections::HashMap;
227use std::fmt;
228use std::sync::Arc;
229use tokio::sync::RwLock;
230
231#[cfg(test)]
232#[path = "aws_tests.rs"]
233mod tests;
234
/// Errors raised by the AWS SQS provider.
///
/// The `#[error]` strings are the user-visible `Display` output of each variant.
#[derive(Debug, thiserror::Error)]
pub enum AwsError {
    /// Credentials could not be resolved, or the service rejected them.
    #[error("Authentication failed: {0}")]
    Authentication(String),

    /// The HTTP request could not be built, sent, or its body read.
    #[error("Network error: {0}")]
    NetworkError(String),

    /// SQS answered with an error that has no more specific variant.
    #[error("SQS service error: {0}")]
    ServiceError(String),

    /// The addressed queue does not exist.
    #[error("Queue not found: {0}")]
    QueueNotFound(String),

    /// The receipt handle was rejected by SQS.
    #[error("Invalid receipt handle: {0}")]
    InvalidReceipt(String),

    /// The encoded message body exceeds the allowed size.
    #[error("Message too large: {size} bytes (max: {max_size})")]
    MessageTooLarge { size: usize, max_size: usize },

    /// The provider configuration is unusable (for example an empty region).
    #[error("Invalid configuration: {0}")]
    ConfigurationError(String),

    /// A response could not be decoded (XML or Base64).
    #[error("Serialization error: {0}")]
    SerializationError(String),

    /// A session operation was requested on a non-FIFO queue.
    #[error("Sessions not supported on standard queues")]
    SessionsNotSupported,
}
269
270impl AwsError {
271 pub fn is_transient(&self) -> bool {
273 match self {
274 Self::Authentication(_) => false,
275 Self::NetworkError(_) => true,
276 Self::ServiceError(_) => true, Self::QueueNotFound(_) => false,
278 Self::InvalidReceipt(_) => false,
279 Self::MessageTooLarge { .. } => false,
280 Self::ConfigurationError(_) => false,
281 Self::SerializationError(_) => false,
282 Self::SessionsNotSupported => false,
283 }
284 }
285
286 pub fn to_queue_error(self) -> QueueError {
288 match self {
289 Self::Authentication(msg) => QueueError::AuthenticationFailed { message: msg },
290 Self::NetworkError(msg) => QueueError::ConnectionFailed { message: msg },
291 Self::ServiceError(msg) => QueueError::ProviderError {
292 provider: "AwsSqs".to_string(),
293 code: "ServiceError".to_string(),
294 message: msg,
295 },
296 Self::QueueNotFound(queue) => QueueError::QueueNotFound { queue_name: queue },
297 Self::InvalidReceipt(receipt) => QueueError::MessageNotFound { receipt },
298 Self::MessageTooLarge { size, max_size } => {
299 QueueError::MessageTooLarge { size, max_size }
300 }
301 Self::ConfigurationError(msg) => {
302 QueueError::ConfigurationError(ConfigurationError::Invalid { message: msg })
303 }
304 Self::SerializationError(msg) => QueueError::SerializationError(
305 SerializationError::JsonError(serde_json::Error::io(std::io::Error::new(
306 std::io::ErrorKind::InvalidData,
307 msg,
308 ))),
309 ),
310 Self::SessionsNotSupported => QueueError::ProviderError {
311 provider: "AwsSqs".to_string(),
312 code: "SessionsNotSupported".to_string(),
313 message:
314 "Standard queues do not support session-based operations. Use FIFO queues."
315 .to_string(),
316 },
317 }
318 }
319}
320
/// HMAC keyed with SHA-256; used by the request signer for key derivation and signing.
type HmacSha256 = Hmac<Sha256>;
326
/// Signs HTTP requests with AWS Signature Version 4.
#[derive(Clone)]
struct AwsV4Signer {
    /// Access key id placed in the `Credential=` part of the `Authorization` header.
    access_key: String,
    /// Secret key from which the signing key is derived.
    secret_key: String,
    /// Region component of the credential scope.
    region: String,
    /// Service component of the credential scope (set to `"sqs"` by `new`).
    service: String,
}
346
347impl AwsV4Signer {
348 fn new(access_key: String, secret_key: String, region: String) -> Self {
356 Self {
357 access_key,
358 secret_key,
359 region,
360 service: "sqs".to_string(),
361 }
362 }
363
364 fn sign_request(
380 &self,
381 method: &str,
382 host: &str,
383 path: &str,
384 query_params: &HashMap<String, String>,
385 body: &str,
386 timestamp: &DateTime<Utc>,
387 ) -> HashMap<String, String> {
388 let date_stamp = timestamp.format("%Y%m%d").to_string();
389 let amz_date = timestamp.format("%Y%m%dT%H%M%SZ").to_string();
390
391 let canonical_uri = path;
393
394 let mut canonical_query_string = query_params
396 .iter()
397 .map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
398 .collect::<Vec<_>>();
399 canonical_query_string.sort();
400 let canonical_query_string = canonical_query_string.join("&");
401
402 let canonical_headers = format!("host:{}\nx-amz-date:{}\n", host, amz_date);
404 let signed_headers = "host;x-amz-date";
405
406 let payload_hash = format!("{:x}", Sha256::digest(body.as_bytes()));
408
409 let canonical_request = format!(
411 "{}\n{}\n{}\n{}\n{}\n{}",
412 method,
413 canonical_uri,
414 canonical_query_string,
415 canonical_headers,
416 signed_headers,
417 payload_hash
418 );
419
420 let algorithm = "AWS4-HMAC-SHA256";
422 let credential_scope = format!(
423 "{}/{}/{}/aws4_request",
424 date_stamp, self.region, self.service
425 );
426 let canonical_request_hash = format!("{:x}", Sha256::digest(canonical_request.as_bytes()));
427
428 let string_to_sign = format!(
429 "{}\n{}\n{}\n{}",
430 algorithm, amz_date, credential_scope, canonical_request_hash
431 );
432
433 let signature = self.calculate_signature(&string_to_sign, &date_stamp);
435
436 let authorization_header = format!(
438 "{} Credential={}/{}, SignedHeaders={}, Signature={}",
439 algorithm, self.access_key, credential_scope, signed_headers, signature
440 );
441
442 let mut headers = HashMap::new();
443 headers.insert("Authorization".to_string(), authorization_header);
444 headers.insert("x-amz-date".to_string(), amz_date);
445 headers.insert("host".to_string(), host.to_string());
446
447 headers
448 }
449
450 fn calculate_signature(&self, string_to_sign: &str, date_stamp: &str) -> String {
460 let k_secret = format!("AWS4{}", self.secret_key);
461 let k_date = self.hmac_sha256(k_secret.as_bytes(), date_stamp.as_bytes());
462 let k_region = self.hmac_sha256(&k_date, self.region.as_bytes());
463 let k_service = self.hmac_sha256(&k_region, self.service.as_bytes());
464 let k_signing = self.hmac_sha256(&k_service, b"aws4_request");
465 let signature = self.hmac_sha256(&k_signing, string_to_sign.as_bytes());
466
467 hex::encode(signature)
468 }
469
470 fn hmac_sha256(&self, key: &[u8], data: &[u8]) -> Vec<u8> {
472 let mut mac = HmacSha256::new_from_slice(key).expect("HMAC can take key of any size");
473 mac.update(data);
474 mac.finalize().into_bytes().to_vec()
475 }
476}
477
/// A resolved set of AWS credentials together with their expiry time.
// NOTE(review): the derived `Debug` output includes `secret_access_key` and
// `session_token` — confirm these values are never logged.
#[derive(Debug, Clone)]
struct AwsCredentials {
    /// Access key id used to identify the caller.
    access_key_id: String,
    /// Secret key used to sign requests.
    secret_access_key: String,
    /// Session token, present for temporary credentials.
    session_token: Option<String>,
    /// Instant after which the credentials are no longer valid.
    expiration: DateTime<Utc>,
}
490
491impl AwsCredentials {
492 fn is_expired(&self) -> bool {
494 let now = Utc::now();
495 let buffer = Duration::minutes(5);
496 self.expiration - buffer <= now
497 }
498}
499
/// Resolves AWS credentials from several sources and caches the result.
struct AwsCredentialProvider {
    /// HTTP client used to query the ECS and EC2 metadata endpoints.
    http_client: HttpClient,
    /// Most recently resolved credentials, if any.
    cached_credentials: Arc<RwLock<Option<AwsCredentials>>>,
    /// Explicit `(access_key_id, secret_access_key)` pair; when present it is
    /// used before any other source is consulted.
    explicit_config: Option<(String, String)>,
}
514
515impl AwsCredentialProvider {
516 fn new(
518 http_client: HttpClient,
519 access_key_id: Option<String>,
520 secret_access_key: Option<String>,
521 ) -> Self {
522 let explicit_config = match (access_key_id, secret_access_key) {
523 (Some(key_id), Some(secret)) => Some((key_id, secret)),
524 _ => None,
525 };
526
527 Self {
528 http_client,
529 cached_credentials: Arc::new(RwLock::new(None)),
530 explicit_config,
531 }
532 }
533
534 async fn get_credentials(&self) -> Result<AwsCredentials, AwsError> {
536 {
538 let cache = self.cached_credentials.read().await;
539 if let Some(creds) = cache.as_ref() {
540 if !creds.is_expired() {
541 return Ok(creds.clone());
542 }
543 }
544 }
545
546 let creds = self.fetch_credentials().await?;
548
549 {
551 let mut cache = self.cached_credentials.write().await;
552 *cache = Some(creds.clone());
553 }
554
555 Ok(creds)
556 }
557
558 async fn fetch_credentials(&self) -> Result<AwsCredentials, AwsError> {
560 if let Some((key_id, secret)) = &self.explicit_config {
562 return Ok(AwsCredentials {
563 access_key_id: key_id.clone(),
564 secret_access_key: secret.clone(),
565 session_token: None,
566 expiration: Utc::now() + Duration::days(365), });
568 }
569
570 if let Ok(creds) = self.fetch_from_environment() {
572 return Ok(creds);
573 }
574
575 if let Ok(creds) = self.fetch_from_ecs_metadata().await {
577 return Ok(creds);
578 }
579
580 if let Ok(creds) = self.fetch_from_ec2_metadata().await {
582 return Ok(creds);
583 }
584
585 Err(AwsError::Authentication(
586 "No credentials found in credential chain".to_string(),
587 ))
588 }
589
590 fn fetch_from_environment(&self) -> Result<AwsCredentials, AwsError> {
592 let access_key_id = std::env::var("AWS_ACCESS_KEY_ID")
593 .map_err(|_| AwsError::Authentication("AWS_ACCESS_KEY_ID not set".to_string()))?;
594
595 let secret_access_key = std::env::var("AWS_SECRET_ACCESS_KEY")
596 .map_err(|_| AwsError::Authentication("AWS_SECRET_ACCESS_KEY not set".to_string()))?;
597
598 let session_token = std::env::var("AWS_SESSION_TOKEN").ok();
599
600 Ok(AwsCredentials {
601 access_key_id,
602 secret_access_key,
603 session_token,
604 expiration: Utc::now() + Duration::days(365), })
606 }
607
608 async fn fetch_from_ecs_metadata(&self) -> Result<AwsCredentials, AwsError> {
610 let relative_uri =
611 std::env::var("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI").map_err(|_| {
612 AwsError::Authentication(
613 "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI not set".to_string(),
614 )
615 })?;
616
617 let endpoint = format!("http://169.254.170.2{}", relative_uri);
618
619 let response = self
620 .http_client
621 .get(&endpoint)
622 .timeout(std::time::Duration::from_secs(2))
623 .send()
624 .await
625 .map_err(|e| {
626 AwsError::Authentication(format!("Failed to fetch ECS credentials: {}", e))
627 })?;
628
629 if !response.status().is_success() {
630 return Err(AwsError::Authentication(format!(
631 "ECS metadata returned error: {}",
632 response.status()
633 )));
634 }
635
636 let body = response
637 .text()
638 .await
639 .map_err(|e| AwsError::Authentication(format!("Failed to read ECS metadata: {}", e)))?;
640
641 self.parse_credentials_json(&body)
642 }
643
644 async fn fetch_from_ec2_metadata(&self) -> Result<AwsCredentials, AwsError> {
646 let token = self
648 .http_client
649 .put("http://169.254.169.254/latest/api/token")
650 .header("X-aws-ec2-metadata-token-ttl-seconds", "21600")
651 .timeout(std::time::Duration::from_secs(2))
652 .send()
653 .await
654 .map_err(|e| AwsError::Authentication(format!("Failed to get IMDSv2 token: {}", e)))?
655 .text()
656 .await
657 .map_err(|e| AwsError::Authentication(format!("Failed to read IMDSv2 token: {}", e)))?;
658
659 let role_name = self
661 .http_client
662 .get("http://169.254.169.254/latest/meta-data/iam/security-credentials/")
663 .header("X-aws-ec2-metadata-token", &token)
664 .timeout(std::time::Duration::from_secs(2))
665 .send()
666 .await
667 .map_err(|e| AwsError::Authentication(format!("Failed to fetch IAM role name: {}", e)))?
668 .text()
669 .await
670 .map_err(|e| {
671 AwsError::Authentication(format!("Failed to read IAM role name: {}", e))
672 })?;
673
674 let credentials_url = format!(
676 "http://169.254.169.254/latest/meta-data/iam/security-credentials/{}",
677 role_name.trim()
678 );
679
680 let response = self
681 .http_client
682 .get(&credentials_url)
683 .header("X-aws-ec2-metadata-token", &token)
684 .timeout(std::time::Duration::from_secs(2))
685 .send()
686 .await
687 .map_err(|e| {
688 AwsError::Authentication(format!("Failed to fetch EC2 credentials: {}", e))
689 })?;
690
691 if !response.status().is_success() {
692 return Err(AwsError::Authentication(format!(
693 "EC2 metadata returned error: {}",
694 response.status()
695 )));
696 }
697
698 let body = response
699 .text()
700 .await
701 .map_err(|e| AwsError::Authentication(format!("Failed to read EC2 metadata: {}", e)))?;
702
703 self.parse_credentials_json(&body)
704 }
705
706 fn parse_credentials_json(&self, json: &str) -> Result<AwsCredentials, AwsError> {
708 let access_key_id = Self::extract_json_field(json, "AccessKeyId")?;
710 let secret_access_key = Self::extract_json_field(json, "SecretAccessKey")?;
711 let session_token = Self::extract_json_field(json, "Token").ok();
712 let expiration_str = Self::extract_json_field(json, "Expiration")?;
713
714 let expiration = DateTime::parse_from_rfc3339(&expiration_str)
716 .map_err(|e| AwsError::Authentication(format!("Invalid expiration timestamp: {}", e)))?
717 .with_timezone(&Utc);
718
719 Ok(AwsCredentials {
720 access_key_id,
721 secret_access_key,
722 session_token,
723 expiration,
724 })
725 }
726
727 fn extract_json_field(json: &str, field: &str) -> Result<String, AwsError> {
729 let pattern = format!("\"{}\": \"", field);
730 let start = json.find(&pattern).ok_or_else(|| {
731 AwsError::Authentication(format!("Field '{}' not found in JSON", field))
732 })?;
733
734 let value_start = start + pattern.len();
735 let value_end = json[value_start..].find('"').ok_or_else(|| {
736 AwsError::Authentication(format!("Malformed JSON for field '{}'", field))
737 })? + value_start;
738
739 Ok(json[value_start..value_end].to_string())
740 }
741}
742
/// Queue provider backed by AWS SQS, talking to the service over signed HTTP requests.
pub struct AwsSqsProvider {
    /// Shared HTTP client used for all SQS requests.
    http_client: HttpClient,
    /// Source of (cached) credentials used to sign requests.
    credential_provider: AwsCredentialProvider,
    /// Provider configuration (region and optional explicit keys).
    config: AwsSqsConfig,
    /// Regional service endpoint, `https://sqs.{region}.amazonaws.com`.
    endpoint: String,
    /// Queue name -> queue URL lookups that have already been resolved.
    queue_url_cache: Arc<RwLock<HashMap<QueueName, String>>>,
}
769
770impl AwsSqsProvider {
771 pub async fn new(config: AwsSqsConfig) -> Result<Self, AwsError> {
803 if config.region.is_empty() {
805 return Err(AwsError::ConfigurationError(
806 "Region cannot be empty".to_string(),
807 ));
808 }
809
810 let endpoint = format!("https://sqs.{}.amazonaws.com", config.region);
812
813 let http_client = HttpClient::builder()
815 .timeout(std::time::Duration::from_secs(30))
816 .build()
817 .map_err(|e| AwsError::NetworkError(format!("Failed to create HTTP client: {}", e)))?;
818
819 let credential_provider = AwsCredentialProvider::new(
821 http_client.clone(),
822 config.access_key_id.clone(),
823 config.secret_access_key.clone(),
824 );
825
826 Ok(Self {
827 http_client,
828 credential_provider,
829 config,
830 endpoint,
831 queue_url_cache: Arc::new(RwLock::new(HashMap::new())),
832 })
833 }
834
835 async fn get_queue_url(&self, queue_name: &QueueName) -> Result<String, AwsError> {
845 {
847 let cache = self.queue_url_cache.read().await;
848 if let Some(url) = cache.get(queue_name) {
849 return Ok(url.clone());
850 }
851 }
852
853 let mut params = HashMap::new();
855 params.insert("Action".to_string(), "GetQueueUrl".to_string());
856 params.insert("QueueName".to_string(), queue_name.as_str().to_string());
857 params.insert("Version".to_string(), "2012-11-05".to_string());
858
859 let response = self.make_request("POST", "/", ¶ms, "").await?;
861
862 let queue_url = self.parse_queue_url_response(&response)?;
864
865 let mut cache = self.queue_url_cache.write().await;
867 cache.insert(queue_name.clone(), queue_url.clone());
868
869 Ok(queue_url)
870 }
871
872 async fn make_request(
874 &self,
875 method: &str,
876 path: &str,
877 query_params: &HashMap<String, String>,
878 body: &str,
879 ) -> Result<String, AwsError> {
880 let credentials = self.credential_provider.get_credentials().await?;
882
883 let signer = AwsV4Signer::new(
885 credentials.access_key_id.clone(),
886 credentials.secret_access_key.clone(),
887 self.config.region.clone(),
888 );
889
890 let host = self
892 .endpoint
893 .strip_prefix("https://")
894 .unwrap_or(&self.endpoint);
895
896 let timestamp = Utc::now();
898
899 let mut auth_headers =
901 signer.sign_request(method, host, path, query_params, body, ×tamp);
902
903 if let Some(session_token) = &credentials.session_token {
905 auth_headers.insert("X-Amz-Security-Token".to_string(), session_token.clone());
906 }
907
908 let mut url = format!("{}{}", self.endpoint, path);
910 if !query_params.is_empty() {
911 let query_string = query_params
912 .iter()
913 .map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
914 .collect::<Vec<_>>()
915 .join("&");
916 url = format!("{}?{}", url, query_string);
917 }
918
919 let mut request = self.http_client.request(
921 method
922 .parse()
923 .map_err(|e| AwsError::ConfigurationError(format!("Invalid HTTP method: {}", e)))?,
924 &url,
925 );
926
927 for (key, value) in auth_headers {
929 request = request.header(&key, value);
930 }
931
932 if !body.is_empty() {
934 request = request.body(body.to_string());
935 }
936
937 let response = request.send().await.map_err(|e| {
939 if e.is_timeout() {
940 AwsError::NetworkError(format!("Request timeout: {}", e))
941 } else if e.is_connect() {
942 AwsError::NetworkError(format!("Connection failed: {}", e))
943 } else {
944 AwsError::NetworkError(format!("HTTP request failed: {}", e))
945 }
946 })?;
947
948 let status = response.status();
950 let response_body = response
951 .text()
952 .await
953 .map_err(|e| AwsError::NetworkError(format!("Failed to read response body: {}", e)))?;
954
955 if !status.is_success() {
956 return Err(self.parse_error_response(&response_body, status.as_u16()));
958 }
959
960 Ok(response_body)
961 }
962
963 fn parse_queue_url_response(&self, xml: &str) -> Result<String, AwsError> {
965 use quick_xml::events::Event;
966 use quick_xml::Reader;
967
968 let mut reader = Reader::from_str(xml);
969 reader.trim_text(true);
970
971 let mut in_queue_url = false;
972 let mut buf = Vec::new();
973
974 loop {
975 match reader.read_event_into(&mut buf) {
976 Ok(Event::Start(ref e)) if e.name().as_ref() == b"QueueUrl" => {
977 in_queue_url = true;
978 }
979 Ok(Event::Text(e)) if in_queue_url => {
980 return e.unescape().map(|s| s.into_owned()).map_err(|e| {
981 AwsError::SerializationError(format!("Failed to parse XML: {}", e))
982 });
983 }
984 Ok(Event::Eof) => break,
985 Err(e) => {
986 return Err(AwsError::SerializationError(format!(
987 "XML parsing error: {}",
988 e
989 )))
990 }
991 _ => {}
992 }
993 buf.clear();
994 }
995
996 Err(AwsError::SerializationError(
997 "QueueUrl not found in response".to_string(),
998 ))
999 }
1000
1001 fn parse_error_response(&self, xml: &str, status_code: u16) -> AwsError {
1003 use quick_xml::events::Event;
1004 use quick_xml::Reader;
1005
1006 let mut reader = Reader::from_str(xml);
1007 reader.trim_text(true);
1008
1009 let mut error_code = None;
1010 let mut error_message = None;
1011 let mut in_error = false;
1012 let mut in_code = false;
1013 let mut in_message = false;
1014 let mut buf = Vec::new();
1015
1016 loop {
1017 match reader.read_event_into(&mut buf) {
1018 Ok(Event::Start(ref e)) => match e.name().as_ref() {
1019 b"Error" => in_error = true,
1020 b"Code" if in_error => in_code = true,
1021 b"Message" if in_error => in_message = true,
1022 _ => {}
1023 },
1024 Ok(Event::Text(e)) => {
1025 if in_code {
1026 error_code = e.unescape().ok().map(|s| s.into_owned());
1027 in_code = false;
1028 } else if in_message {
1029 error_message = e.unescape().ok().map(|s| s.into_owned());
1030 in_message = false;
1031 }
1032 }
1033 Ok(Event::End(ref e)) if e.name().as_ref() == b"Error" => {
1034 in_error = false;
1035 }
1036 Ok(Event::Eof) => break,
1037 Err(_) => break,
1038 _ => {}
1039 }
1040 buf.clear();
1041 }
1042
1043 let code = error_code.unwrap_or_else(|| "Unknown".to_string());
1044 let message = error_message.unwrap_or_else(|| "Unknown error".to_string());
1045
1046 match code.as_str() {
1048 "AWS.SimpleQueueService.NonExistentQueue" | "QueueDoesNotExist" => {
1049 AwsError::QueueNotFound(message)
1050 }
1051 "InvalidClientTokenId" | "UnrecognizedClientException" | "SignatureDoesNotMatch" => {
1052 AwsError::Authentication(format!("{}: {}", code, message))
1053 }
1054 "InvalidReceiptHandle" | "ReceiptHandleIsInvalid" => AwsError::InvalidReceipt(message),
1055 _ if status_code == 401 || status_code == 403 => {
1056 AwsError::Authentication(format!("{}: {}", code, message))
1057 }
1058 _ if status_code >= 500 => AwsError::ServiceError(format!("{}: {}", code, message)),
1059 _ => AwsError::ServiceError(format!("{}: {}", code, message)),
1060 }
1061 }
1062
1063 fn parse_send_message_response(&self, xml: &str) -> Result<MessageId, AwsError> {
1065 use quick_xml::events::Event;
1066 use quick_xml::Reader;
1067
1068 let mut reader = Reader::from_str(xml);
1069 reader.trim_text(true);
1070
1071 let mut in_message_id = false;
1072 let mut buf = Vec::new();
1073
1074 loop {
1075 match reader.read_event_into(&mut buf) {
1076 Ok(Event::Start(ref e)) if e.name().as_ref() == b"MessageId" => {
1077 in_message_id = true;
1078 }
1079 Ok(Event::Text(e)) if in_message_id => {
1080 let msg_id = e.unescape().map(|s| s.into_owned()).map_err(|e| {
1081 AwsError::SerializationError(format!("Failed to parse XML: {}", e))
1082 })?;
1083
1084 use std::str::FromStr;
1086 let message_id =
1087 MessageId::from_str(&msg_id).unwrap_or_else(|_| MessageId::new());
1088 return Ok(message_id);
1089 }
1090 Ok(Event::Eof) => break,
1091 Err(e) => {
1092 return Err(AwsError::SerializationError(format!(
1093 "XML parsing error: {}",
1094 e
1095 )))
1096 }
1097 _ => {}
1098 }
1099 buf.clear();
1100 }
1101
1102 Err(AwsError::SerializationError(
1103 "MessageId not found in response".to_string(),
1104 ))
1105 }
1106
1107 fn parse_receive_message_response(
1109 &self,
1110 xml: &str,
1111 queue: &QueueName,
1112 ) -> Result<Vec<ReceivedMessage>, AwsError> {
1113 use quick_xml::events::Event;
1114 use quick_xml::Reader;
1115
1116 let mut reader = Reader::from_str(xml);
1117 reader.trim_text(true);
1118
1119 let mut messages = Vec::new();
1120 let mut in_message = false;
1121 let mut current_message_id: Option<String> = None;
1122 let mut current_receipt_handle: Option<String> = None;
1123 let mut current_body: Option<String> = None;
1124 let mut current_session_id: Option<String> = None;
1125 let mut current_delivery_count: u32 = 1;
1126
1127 let mut in_message_id = false;
1128 let mut in_receipt_handle = false;
1129 let mut in_body = false;
1130 let mut in_attribute_name = false;
1131 let mut in_attribute_value = false;
1132 let mut current_attribute_name: Option<String> = None;
1133
1134 let mut buf = Vec::new();
1135
1136 loop {
1137 match reader.read_event_into(&mut buf) {
1138 Ok(Event::Start(ref e)) => match e.name().as_ref() {
1139 b"Message" => {
1140 in_message = true;
1141 current_message_id = None;
1143 current_receipt_handle = None;
1144 current_body = None;
1145 current_session_id = None;
1146 current_delivery_count = 1;
1147 }
1148 b"MessageId" if in_message => in_message_id = true,
1149 b"ReceiptHandle" if in_message => in_receipt_handle = true,
1150 b"Body" if in_message => in_body = true,
1151 b"Name" if in_message => in_attribute_name = true,
1152 b"Value" if in_message => in_attribute_value = true,
1153 _ => {}
1154 },
1155 Ok(Event::Text(e)) => {
1156 let text = e.unescape().ok().map(|s| s.into_owned());
1157 if in_message_id {
1158 current_message_id = text;
1159 in_message_id = false;
1160 } else if in_receipt_handle {
1161 current_receipt_handle = text;
1162 in_receipt_handle = false;
1163 } else if in_body {
1164 current_body = text;
1165 in_body = false;
1166 } else if in_attribute_name {
1167 current_attribute_name = text;
1168 in_attribute_name = false;
1169 } else if in_attribute_value {
1170 if let Some(ref attr_name) = current_attribute_name {
1171 match attr_name.as_str() {
1172 "MessageGroupId" => current_session_id = text,
1173 "ApproximateReceiveCount" => {
1174 if let Some(count_str) = text {
1175 current_delivery_count = count_str.parse().unwrap_or(1);
1176 }
1177 }
1178 _ => {}
1179 }
1180 }
1181 in_attribute_value = false;
1182 current_attribute_name = None;
1183 }
1184 }
1185 Ok(Event::End(ref e)) if e.name().as_ref() == b"Message" => {
1186 in_message = false;
1187
1188 if let (Some(body_base64), Some(receipt_handle)) =
1190 (current_body.as_ref(), current_receipt_handle.as_ref())
1191 {
1192 use base64::{engine::general_purpose::STANDARD, Engine};
1194 let body_bytes = STANDARD.decode(body_base64).map_err(|e| {
1195 AwsError::SerializationError(format!("Base64 decode failed: {}", e))
1196 })?;
1197 let body = bytes::Bytes::from(body_bytes);
1198
1199 use std::str::FromStr;
1201 let message_id = current_message_id
1202 .as_ref()
1203 .and_then(|id| MessageId::from_str(id).ok())
1204 .unwrap_or_default();
1205
1206 let session_id = current_session_id
1208 .as_ref()
1209 .and_then(|id| SessionId::new(id.clone()).ok());
1210
1211 let handle_with_queue = format!("{}|{}", queue.as_str(), receipt_handle);
1214 let expires_at = Timestamp::now();
1215 let receipt =
1216 ReceiptHandle::new(handle_with_queue, expires_at, ProviderType::AwsSqs);
1217
1218 let received_message = ReceivedMessage {
1220 message_id,
1221 body,
1222 attributes: HashMap::new(),
1223 session_id,
1224 correlation_id: None,
1225 receipt_handle: receipt,
1226 delivery_count: current_delivery_count,
1227 first_delivered_at: Timestamp::now(),
1228 delivered_at: Timestamp::now(),
1229 };
1230
1231 messages.push(received_message);
1232 }
1233 }
1234 Ok(Event::Eof) => break,
1235 Err(e) => {
1236 return Err(AwsError::SerializationError(format!(
1237 "XML parsing error: {}",
1238 e
1239 )))
1240 }
1241 _ => {}
1242 }
1243 buf.clear();
1244 }
1245
1246 Ok(messages)
1247 }
1248
1249 fn is_fifo_queue(queue_name: &QueueName) -> bool {
1251 queue_name.as_str().ends_with(".fifo")
1252 }
1253}
1254
1255impl fmt::Debug for AwsSqsProvider {
1256 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1257 f.debug_struct("AwsSqsProvider")
1258 .field("config", &self.config)
1259 .field("queue_url_cache_size", &"<redacted>")
1260 .finish()
1261 }
1262}
1263
1264#[async_trait]
1265impl QueueProvider for AwsSqsProvider {
1266 async fn send_message(
1267 &self,
1268 queue: &QueueName,
1269 message: &Message,
1270 ) -> Result<MessageId, QueueError> {
1271 let queue_url = self
1272 .get_queue_url(queue)
1273 .await
1274 .map_err(|e| e.to_queue_error())?;
1275
1276 use base64::{engine::general_purpose::STANDARD, Engine};
1278 let body_base64 = STANDARD.encode(&message.body);
1279
1280 if body_base64.len() > 256 * 1024 {
1282 return Err(AwsError::MessageTooLarge {
1283 size: body_base64.len(),
1284 max_size: 256 * 1024,
1285 }
1286 .to_queue_error());
1287 }
1288
1289 let mut params = HashMap::new();
1291 params.insert("Action".to_string(), "SendMessage".to_string());
1292 params.insert("Version".to_string(), "2012-11-05".to_string());
1293 params.insert("QueueUrl".to_string(), queue_url.clone());
1294 params.insert("MessageBody".to_string(), body_base64);
1295
1296 if Self::is_fifo_queue(queue) {
1298 if let Some(ref session_id) = message.session_id {
1299 params.insert(
1300 "MessageGroupId".to_string(),
1301 session_id.as_str().to_string(),
1302 );
1303 let dedup_id = uuid::Uuid::new_v4().to_string();
1305 params.insert("MessageDeduplicationId".to_string(), dedup_id);
1306 } else {
1307 return Err(QueueError::ValidationError(
1309 crate::error::ValidationError::Required {
1310 field: "session_id".to_string(),
1311 },
1312 ));
1313 }
1314 }
1315
1316 let response = self
1318 .make_request("POST", "/", ¶ms, "")
1319 .await
1320 .map_err(|e| e.to_queue_error())?;
1321
1322 let message_id = self
1324 .parse_send_message_response(&response)
1325 .map_err(|e| e.to_queue_error())?;
1326
1327 Ok(message_id)
1328 }
1329
1330 async fn send_messages(
1331 &self,
1332 queue: &QueueName,
1333 messages: &[Message],
1334 ) -> Result<Vec<MessageId>, QueueError> {
1335 if messages.is_empty() {
1336 return Ok(Vec::new());
1337 }
1338
1339 let max_batch = self.max_batch_size() as usize;
1341 let mut all_message_ids = Vec::new();
1342
1343 for chunk in messages.chunks(max_batch) {
1345 let message_ids = self.send_messages_batch(queue, chunk).await?;
1346 all_message_ids.extend(message_ids);
1347 }
1348
1349 Ok(all_message_ids)
1350 }
1351
1352 async fn receive_message(
1353 &self,
1354 queue: &QueueName,
1355 timeout: Duration,
1356 ) -> Result<Option<ReceivedMessage>, QueueError> {
1357 let messages = self.receive_messages(queue, 1, timeout).await?;
1358 Ok(messages.into_iter().next())
1359 }
1360
1361 async fn receive_messages(
1362 &self,
1363 queue: &QueueName,
1364 max_messages: u32,
1365 timeout: Duration,
1366 ) -> Result<Vec<ReceivedMessage>, QueueError> {
1367 let queue_url = self
1368 .get_queue_url(queue)
1369 .await
1370 .map_err(|e| e.to_queue_error())?;
1371
1372 let wait_time_seconds = timeout.num_seconds().clamp(0, 20); let mut params = HashMap::new();
1377 params.insert("Action".to_string(), "ReceiveMessage".to_string());
1378 params.insert("Version".to_string(), "2012-11-05".to_string());
1379 params.insert("QueueUrl".to_string(), queue_url);
1380 params.insert(
1381 "MaxNumberOfMessages".to_string(),
1382 max_messages.min(10).to_string(), );
1384 params.insert("WaitTimeSeconds".to_string(), wait_time_seconds.to_string());
1385 params.insert("AttributeName.1".to_string(), "All".to_string()); let response = self
1389 .make_request("POST", "/", ¶ms, "")
1390 .await
1391 .map_err(|e| e.to_queue_error())?;
1392
1393 let messages = self
1395 .parse_receive_message_response(&response, queue)
1396 .map_err(|e| e.to_queue_error())?;
1397
1398 Ok(messages)
1399 }
1400
1401 async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
1402 let handle_str = receipt.handle();
1410 let parts: Vec<&str> = handle_str.split('|').collect();
1411
1412 if parts.len() != 2 {
1413 return Err(QueueError::MessageNotFound {
1414 receipt: handle_str.to_string(),
1415 });
1416 }
1417
1418 let queue_name =
1419 QueueName::new(parts[0].to_string()).map_err(QueueError::ValidationError)?;
1420 let receipt_token = parts[1];
1421
1422 let queue_url = self
1424 .get_queue_url(&queue_name)
1425 .await
1426 .map_err(|e| e.to_queue_error())?;
1427
1428 let mut params = HashMap::new();
1430 params.insert("Action".to_string(), "DeleteMessage".to_string());
1431 params.insert("Version".to_string(), "2012-11-05".to_string());
1432 params.insert("QueueUrl".to_string(), queue_url);
1433 params.insert("ReceiptHandle".to_string(), receipt_token.to_string());
1434
1435 let _response = self
1437 .make_request("POST", "/", ¶ms, "")
1438 .await
1439 .map_err(|e| e.to_queue_error())?;
1440
1441 Ok(())
1443 }
1444
1445 async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
1446 let handle_str = receipt.handle();
1448 let parts: Vec<&str> = handle_str.split('|').collect();
1449
1450 if parts.len() != 2 {
1451 return Err(QueueError::MessageNotFound {
1452 receipt: handle_str.to_string(),
1453 });
1454 }
1455
1456 let queue_name =
1457 QueueName::new(parts[0].to_string()).map_err(QueueError::ValidationError)?;
1458 let receipt_token = parts[1];
1459
1460 let queue_url = self
1462 .get_queue_url(&queue_name)
1463 .await
1464 .map_err(|e| e.to_queue_error())?;
1465
1466 let mut params = HashMap::new();
1469 params.insert("Action".to_string(), "ChangeMessageVisibility".to_string());
1470 params.insert("Version".to_string(), "2012-11-05".to_string());
1471 params.insert("QueueUrl".to_string(), queue_url);
1472 params.insert("ReceiptHandle".to_string(), receipt_token.to_string());
1473 params.insert("VisibilityTimeout".to_string(), "0".to_string());
1474
1475 let _response = self
1477 .make_request("POST", "/", ¶ms, "")
1478 .await
1479 .map_err(|e| e.to_queue_error())?;
1480
1481 Ok(())
1483 }
1484
1485 async fn dead_letter_message(
1486 &self,
1487 receipt: &ReceiptHandle,
1488 _reason: &str,
1489 ) -> Result<(), QueueError> {
1490 self.complete_message(receipt).await
1493 }
1494
1495 async fn create_session_client(
1496 &self,
1497 queue: &QueueName,
1498 session_id: Option<SessionId>,
1499 ) -> Result<Box<dyn SessionProvider>, QueueError> {
1500 if !Self::is_fifo_queue(queue) {
1502 return Err(AwsError::SessionsNotSupported.to_queue_error());
1503 }
1504
1505 let queue_url = self
1507 .get_queue_url(queue)
1508 .await
1509 .map_err(|e| e.to_queue_error())?;
1510
1511 let session_id = session_id.ok_or_else(|| {
1513 QueueError::ValidationError(crate::error::ValidationError::Required {
1514 field: "session_id".to_string(),
1515 })
1516 })?;
1517
1518 Ok(Box::new(AwsSessionProvider::new(
1519 self.http_client.clone(),
1520 AwsCredentialProvider::new(
1521 self.http_client.clone(),
1522 self.config.access_key_id.clone(),
1523 self.config.secret_access_key.clone(),
1524 ),
1525 self.config.region.clone(),
1526 self.endpoint.clone(),
1527 queue_url,
1528 queue.clone(),
1529 session_id,
1530 )))
1531 }
1532
    /// Identifies this provider as [`ProviderType::AwsSqs`].
    fn provider_type(&self) -> ProviderType {
        ProviderType::AwsSqs
    }
1536
    /// Reports session support as [`SessionSupport::Emulated`].
    ///
    /// Session clients are only handed out for FIFO queues; see
    /// `create_session_client`.
    fn supports_sessions(&self) -> SessionSupport {
        SessionSupport::Emulated
    }
1540
    /// Reports that batch operations are supported (always `true`).
    fn supports_batching(&self) -> bool {
        true
    }
1544
    /// Largest number of messages accepted in a single batch: 10.
    fn max_batch_size(&self) -> u32 {
        10
    }
1548}
1549
1550impl AwsSqsProvider {
1552 async fn send_messages_batch(
1554 &self,
1555 queue: &QueueName,
1556 messages: &[Message],
1557 ) -> Result<Vec<MessageId>, QueueError> {
1558 if messages.is_empty() {
1559 return Ok(Vec::new());
1560 }
1561
1562 if messages.len() > 10 {
1564 return Err(QueueError::ValidationError(
1565 crate::error::ValidationError::OutOfRange {
1566 field: "messages".to_string(),
1567 message: format!("Batch size {} exceeds AWS SQS limit of 10", messages.len()),
1568 },
1569 ));
1570 }
1571
1572 let queue_url = self
1573 .get_queue_url(queue)
1574 .await
1575 .map_err(|e| e.to_queue_error())?;
1576
1577 let mut params = HashMap::new();
1579 params.insert("Action".to_string(), "SendMessageBatch".to_string());
1580 params.insert("Version".to_string(), "2012-11-05".to_string());
1581 params.insert("QueueUrl".to_string(), queue_url.clone());
1582
1583 use base64::{engine::general_purpose::STANDARD, Engine};
1585
1586 for (idx, message) in messages.iter().enumerate() {
1588 let entry_id = format!("msg-{}", idx);
1589 let body_base64 = STANDARD.encode(&message.body);
1590
1591 if body_base64.len() > 256 * 1024 {
1593 return Err(AwsError::MessageTooLarge {
1594 size: body_base64.len(),
1595 max_size: 256 * 1024,
1596 }
1597 .to_queue_error());
1598 }
1599
1600 params.insert(
1601 format!("SendMessageBatchRequestEntry.{}.Id", idx + 1),
1602 entry_id,
1603 );
1604 params.insert(
1605 format!("SendMessageBatchRequestEntry.{}.MessageBody", idx + 1),
1606 body_base64,
1607 );
1608
1609 if Self::is_fifo_queue(queue) {
1611 if let Some(ref session_id) = message.session_id {
1613 params.insert(
1614 format!("SendMessageBatchRequestEntry.{}.MessageGroupId", idx + 1),
1615 session_id.as_str().to_string(),
1616 );
1617 }
1618
1619 use sha2::{Digest, Sha256};
1622 let mut hasher = Sha256::new();
1623 hasher.update(&message.body);
1624 if let Some(ref session_id) = message.session_id {
1625 hasher.update(session_id.as_str().as_bytes());
1626 }
1627 let hash = format!("{:x}", hasher.finalize());
1628 params.insert(
1629 format!(
1630 "SendMessageBatchRequestEntry.{}.MessageDeduplicationId",
1631 idx + 1
1632 ),
1633 hash,
1634 );
1635 }
1636 }
1637
1638 let response = self
1640 .make_request("POST", "/", ¶ms, "")
1641 .await
1642 .map_err(|e| e.to_queue_error())?;
1643
1644 self.parse_send_message_batch_response(&response)
1646 .map_err(|e| e.to_queue_error())
1647 }
1648
1649 fn parse_send_message_batch_response(&self, xml: &str) -> Result<Vec<MessageId>, AwsError> {
1651 use quick_xml::events::Event;
1652 use quick_xml::Reader;
1653
1654 let mut reader = Reader::from_str(xml);
1655 reader.trim_text(true);
1656
1657 let mut message_ids = Vec::new();
1658 let mut in_successful = false;
1659 let mut in_message_id = false;
1660 let mut buf = Vec::new();
1661
1662 loop {
1663 match reader.read_event_into(&mut buf) {
1664 Ok(Event::Start(ref e)) => match e.name().as_ref() {
1665 b"SendMessageBatchResultEntry" => in_successful = true,
1666 b"MessageId" if in_successful => in_message_id = true,
1667 _ => {}
1668 },
1669 Ok(Event::Text(e)) if in_message_id => {
1670 let msg_id = e.unescape().map(|s| s.into_owned()).map_err(|e| {
1671 AwsError::SerializationError(format!("Failed to parse XML: {}", e))
1672 })?;
1673
1674 use std::str::FromStr;
1676 let message_id =
1677 MessageId::from_str(&msg_id).unwrap_or_else(|_| MessageId::new());
1678 message_ids.push(message_id);
1679 in_message_id = false;
1680 }
1681 Ok(Event::End(ref e)) if e.name().as_ref() == b"SendMessageBatchResultEntry" => {
1682 in_successful = false;
1683 }
1684 Ok(Event::Eof) => break,
1685 Err(e) => {
1686 return Err(AwsError::SerializationError(format!(
1687 "XML parsing error: {}",
1688 e
1689 )))
1690 }
1691 _ => {}
1692 }
1693 buf.clear();
1694 }
1695
1696 Ok(message_ids)
1697 }
1698}
1699
/// Session-scoped client bound to one queue and one session id.
///
/// Instances are created by `AwsSqsProvider::create_session_client`, which only
/// does so for FIFO queues. The provider receives messages from the queue and
/// returns only those whose session id equals `session_id`.
pub struct AwsSessionProvider {
    /// HTTP client used to send every request.
    http_client: HttpClient,
    /// Source of the credentials fetched before each request is signed.
    credential_provider: AwsCredentialProvider,
    /// Region handed to the request signer.
    region: String,
    /// Base URL requests are sent to; also the source of the signed host value.
    endpoint: String,
    /// Queue URL sent as the `QueueUrl` request parameter.
    queue_url: String,
    /// Queue name embedded in the receipt handles this provider produces.
    queue_name: QueueName,
    /// Session id that received messages must match.
    session_id: SessionId,
}
1717
1718impl AwsSessionProvider {
    /// Builds a session provider from its parts.
    ///
    /// Every argument is stored unchanged in the field of the same name; no I/O
    /// or validation happens here.
    fn new(
        http_client: HttpClient,
        credential_provider: AwsCredentialProvider,
        region: String,
        endpoint: String,
        queue_url: String,
        queue_name: QueueName,
        session_id: SessionId,
    ) -> Self {
        Self {
            http_client,
            credential_provider,
            region,
            endpoint,
            queue_url,
            queue_name,
            session_id,
        }
    }
1739
    /// Returns a copy of the queue URL stored at construction time.
    ///
    /// No lookup is performed, so this always yields `Ok`.
    async fn get_queue_url(&self) -> Result<String, AwsError> {
        Ok(self.queue_url.clone())
    }
1744
1745 async fn make_request(
1747 &self,
1748 method: &str,
1749 path: &str,
1750 params: &HashMap<String, String>,
1751 body: &str,
1752 ) -> Result<String, AwsError> {
1753 use reqwest::header;
1754
1755 let credentials = self.credential_provider.get_credentials().await?;
1757
1758 let signer = AwsV4Signer::new(
1760 credentials.access_key_id.clone(),
1761 credentials.secret_access_key.clone(),
1762 self.region.clone(),
1763 );
1764
1765 let query_string = if params.is_empty() {
1767 String::new()
1768 } else {
1769 let mut pairs: Vec<String> = params
1770 .iter()
1771 .map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
1772 .collect();
1773 pairs.sort();
1774 pairs.join("&")
1775 };
1776
1777 let url = if query_string.is_empty() {
1778 format!("{}{}", self.endpoint, path)
1779 } else {
1780 format!("{}{}?{}", self.endpoint, path, query_string)
1781 };
1782
1783 let mut request_builder = self.http_client.request(
1785 method
1786 .parse()
1787 .map_err(|e| AwsError::NetworkError(format!("Invalid HTTP method: {}", e)))?,
1788 &url,
1789 );
1790
1791 let timestamp = Utc::now();
1793 let host = self
1794 .endpoint
1795 .trim_start_matches("https://")
1796 .trim_start_matches("http://");
1797 let mut signed_headers = signer.sign_request(method, host, path, params, body, ×tamp);
1798
1799 if let Some(session_token) = &credentials.session_token {
1801 signed_headers.insert("X-Amz-Security-Token".to_string(), session_token.clone());
1802 }
1803
1804 for (key, value) in signed_headers {
1805 request_builder = request_builder.header(key, value);
1806 }
1807
1808 if !body.is_empty() {
1810 request_builder = request_builder
1811 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
1812 .body(body.to_string());
1813 }
1814
1815 let response = request_builder
1817 .send()
1818 .await
1819 .map_err(|e| AwsError::NetworkError(format!("HTTP request failed: {}", e)))?;
1820
1821 let status = response.status();
1822 let response_text = response
1823 .text()
1824 .await
1825 .map_err(|e| AwsError::NetworkError(format!("Failed to read response: {}", e)))?;
1826
1827 if !status.is_success() {
1829 return Err(self.parse_error_response(&response_text, status.as_u16()));
1830 }
1831
1832 Ok(response_text)
1833 }
1834
1835 fn parse_error_response(&self, xml: &str, status_code: u16) -> AwsError {
1837 use quick_xml::events::Event;
1838 use quick_xml::Reader;
1839
1840 let mut reader = Reader::from_str(xml);
1841 reader.trim_text(true);
1842
1843 let mut error_code = None;
1844 let mut error_message = None;
1845 let mut in_error = false;
1846 let mut in_code = false;
1847 let mut in_message = false;
1848 let mut buf = Vec::new();
1849
1850 loop {
1851 match reader.read_event_into(&mut buf) {
1852 Ok(Event::Start(ref e)) => match e.name().as_ref() {
1853 b"Error" => in_error = true,
1854 b"Code" if in_error => in_code = true,
1855 b"Message" if in_error => in_message = true,
1856 _ => {}
1857 },
1858 Ok(Event::Text(e)) => {
1859 if in_code {
1860 error_code = e.unescape().ok().map(|s| s.into_owned());
1861 in_code = false;
1862 } else if in_message {
1863 error_message = e.unescape().ok().map(|s| s.into_owned());
1864 in_message = false;
1865 }
1866 }
1867 Ok(Event::Eof) => break,
1868 Err(_) => break,
1869 _ => {}
1870 }
1871 buf.clear();
1872 }
1873
1874 match error_code.as_deref() {
1876 Some("InvalidParameterValue") | Some("MissingParameter") => AwsError::ServiceError(
1877 error_message.unwrap_or_else(|| "Invalid parameter".to_string()),
1878 ),
1879 Some("AccessDenied") | Some("InvalidClientTokenId") | Some("SignatureDoesNotMatch") => {
1880 AwsError::Authentication(
1881 error_message.unwrap_or_else(|| "Authentication failed".to_string()),
1882 )
1883 }
1884 Some("AWS.SimpleQueueService.NonExistentQueue") | Some("QueueDoesNotExist") => {
1885 AwsError::QueueNotFound(
1886 error_message.unwrap_or_else(|| "Queue not found".to_string()),
1887 )
1888 }
1889 _ => {
1890 if status_code >= 500 {
1891 AwsError::ServiceError(
1892 error_message.unwrap_or_else(|| "Service error".to_string()),
1893 )
1894 } else {
1895 AwsError::ServiceError(
1896 error_message.unwrap_or_else(|| format!("HTTP {}", status_code)),
1897 )
1898 }
1899 }
1900 }
1901 }
1902
    /// Parses a `ReceiveMessage` XML response into [`ReceivedMessage`] values.
    ///
    /// One message is produced per `<Message>` element that has both a `<Body>`
    /// and a `<ReceiptHandle>`; elements missing either are skipped without error.
    ///
    /// * The body text is base64-decoded.
    /// * The message id falls back to the default `MessageId` when absent or
    ///   unparsable.
    /// * The session id comes from the attribute named `MessageGroupId`; a value
    ///   rejected by `SessionId::new` yields `None`.
    /// * The delivery count comes from `ApproximateReceiveCount` and defaults to 1.
    /// * The receipt handle is stored as `{queue}|{receipt_handle}`.
    ///
    /// `attributes` is always empty and `correlation_id` is always `None`; both
    /// delivery timestamps are set to the time of parsing.
    ///
    /// # Errors
    ///
    /// `AwsError::SerializationError` when the XML is malformed or a body is not
    /// valid base64.
    fn parse_receive_message_response(
        &self,
        xml: &str,
        queue: &QueueName,
    ) -> Result<Vec<ReceivedMessage>, AwsError> {
        use quick_xml::events::Event;
        use quick_xml::Reader;

        let mut reader = Reader::from_str(xml);
        reader.trim_text(true);

        // Values collected for the `<Message>` element currently being read.
        let mut messages = Vec::new();
        let mut in_message = false;
        let mut current_message_id: Option<String> = None;
        let mut current_receipt_handle: Option<String> = None;
        let mut current_body: Option<String> = None;
        let mut current_session_id: Option<String> = None;
        let mut current_delivery_count: u32 = 1;

        // Flags naming the element whose text node is expected next.
        let mut in_message_id = false;
        let mut in_receipt_handle = false;
        let mut in_body = false;
        let mut in_attribute_name = false;
        let mut in_attribute_value = false;
        let mut current_attribute_name: Option<String> = None;

        let mut buf = Vec::new();

        loop {
            match reader.read_event_into(&mut buf) {
                Ok(Event::Start(ref e)) => match e.name().as_ref() {
                    b"Message" => {
                        // A new message starts: forget everything from the previous one.
                        in_message = true;
                        current_message_id = None;
                        current_receipt_handle = None;
                        current_body = None;
                        current_session_id = None;
                        current_delivery_count = 1;
                    }
                    b"MessageId" if in_message => in_message_id = true,
                    b"ReceiptHandle" if in_message => in_receipt_handle = true,
                    b"Body" if in_message => in_body = true,
                    b"Name" if in_message => in_attribute_name = true,
                    b"Value" if in_message => in_attribute_value = true,
                    _ => {}
                },
                Ok(Event::Text(e)) => {
                    // Text that cannot be unescaped is treated as absent.
                    let text = e.unescape().ok().map(|s| s.into_owned());
                    // The first flag that is set claims the text, in this fixed order.
                    if in_message_id {
                        current_message_id = text;
                        in_message_id = false;
                    } else if in_receipt_handle {
                        current_receipt_handle = text;
                        in_receipt_handle = false;
                    } else if in_body {
                        current_body = text;
                        in_body = false;
                    } else if in_attribute_name {
                        current_attribute_name = text;
                        in_attribute_name = false;
                    } else if in_attribute_value {
                        // Only two attribute names are interpreted; all others are ignored.
                        if let Some(ref attr_name) = current_attribute_name {
                            match attr_name.as_str() {
                                "MessageGroupId" => current_session_id = text,
                                "ApproximateReceiveCount" => {
                                    if let Some(count_str) = text {
                                        current_delivery_count = count_str.parse().unwrap_or(1);
                                    }
                                }
                                _ => {}
                            }
                        }
                        in_attribute_value = false;
                        current_attribute_name = None;
                    }
                }
                Ok(Event::End(ref e)) if e.name().as_ref() == b"Message" => {
                    in_message = false;

                    // A message is emitted only when both body and receipt handle were seen.
                    if let (Some(body_base64), Some(receipt_handle)) =
                        (current_body.as_ref(), current_receipt_handle.as_ref())
                    {
                        use base64::{engine::general_purpose::STANDARD, Engine};
                        let body_bytes = STANDARD.decode(body_base64).map_err(|e| {
                            AwsError::SerializationError(format!("Base64 decode failed: {}", e))
                        })?;
                        let body = bytes::Bytes::from(body_bytes);

                        use std::str::FromStr;
                        let message_id = current_message_id
                            .as_ref()
                            .and_then(|id| MessageId::from_str(id).ok())
                            .unwrap_or_default();

                        let session_id = current_session_id
                            .as_ref()
                            .and_then(|id| SessionId::new(id.clone()).ok());

                        // The queue name is prepended so the handle can be split again later.
                        let handle_with_queue = format!("{}|{}", queue.as_str(), receipt_handle);
                        // NOTE(review): the expiry is set to the current time, so the handle
                        // looks expired as soon as it is created if anything checks it —
                        // confirm whether the visibility timeout should be added here.
                        let expires_at = Timestamp::now();
                        let receipt =
                            ReceiptHandle::new(handle_with_queue, expires_at, ProviderType::AwsSqs);

                        let received_message = ReceivedMessage {
                            message_id,
                            body,
                            attributes: HashMap::new(),
                            session_id,
                            correlation_id: None,
                            receipt_handle: receipt,
                            delivery_count: current_delivery_count,
                            first_delivered_at: Timestamp::now(),
                            delivered_at: Timestamp::now(),
                        };

                        messages.push(received_message);
                    }
                }
                Ok(Event::Eof) => break,
                Err(e) => {
                    return Err(AwsError::SerializationError(format!(
                        "XML parsing error: {}",
                        e
                    )))
                }
                _ => {}
            }
            buf.clear();
        }

        Ok(messages)
    }
2036}
2037
2038impl fmt::Debug for AwsSessionProvider {
2039 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2040 f.debug_struct("AwsSessionProvider")
2041 .field("queue_name", &self.queue_name)
2042 .field("session_id", &self.session_id)
2043 .finish()
2044 }
2045}
2046
2047#[async_trait]
2048impl SessionProvider for AwsSessionProvider {
2049 async fn receive_message(
2050 &self,
2051 timeout: Duration,
2052 ) -> Result<Option<ReceivedMessage>, QueueError> {
2053 let queue_url = self.get_queue_url().await.map_err(|e| e.to_queue_error())?;
2056
2057 let mut params = HashMap::new();
2059 params.insert("Action".to_string(), "ReceiveMessage".to_string());
2060 params.insert("Version".to_string(), "2012-11-05".to_string());
2061 params.insert("QueueUrl".to_string(), queue_url);
2062 params.insert("MaxNumberOfMessages".to_string(), "1".to_string());
2063 params.insert(
2064 "WaitTimeSeconds".to_string(),
2065 timeout.num_seconds().clamp(0, 20).to_string(),
2066 );
2067 params.insert("AttributeName.1".to_string(), "All".to_string());
2068
2069 let response = self
2071 .make_request("POST", "/", ¶ms, "")
2072 .await
2073 .map_err(|e| e.to_queue_error())?;
2074
2075 let messages = self
2077 .parse_receive_message_response(&response, &self.queue_name)
2078 .map_err(|e| e.to_queue_error())?;
2079
2080 Ok(messages
2082 .into_iter()
2083 .find(|msg| msg.session_id.as_ref() == Some(&self.session_id)))
2084 }
2085
2086 async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
2087 let handle_str = receipt.handle();
2089 let parts: Vec<&str> = handle_str.split('|').collect();
2090
2091 if parts.len() != 2 {
2092 return Err(QueueError::MessageNotFound {
2093 receipt: handle_str.to_string(),
2094 });
2095 }
2096
2097 let receipt_token = parts[1];
2098 let queue_url = self.get_queue_url().await.map_err(|e| e.to_queue_error())?;
2099
2100 let mut params = HashMap::new();
2102 params.insert("Action".to_string(), "DeleteMessage".to_string());
2103 params.insert("Version".to_string(), "2012-11-05".to_string());
2104 params.insert("QueueUrl".to_string(), queue_url);
2105 params.insert("ReceiptHandle".to_string(), receipt_token.to_string());
2106
2107 self.make_request("POST", "/", ¶ms, "")
2109 .await
2110 .map_err(|e| e.to_queue_error())?;
2111
2112 Ok(())
2113 }
2114
2115 async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
2116 let handle_str = receipt.handle();
2118 let parts: Vec<&str> = handle_str.split('|').collect();
2119
2120 if parts.len() != 2 {
2121 return Err(QueueError::MessageNotFound {
2122 receipt: handle_str.to_string(),
2123 });
2124 }
2125
2126 let receipt_token = parts[1];
2127 let queue_url = self.get_queue_url().await.map_err(|e| e.to_queue_error())?;
2128
2129 let mut params = HashMap::new();
2131 params.insert("Action".to_string(), "ChangeMessageVisibility".to_string());
2132 params.insert("Version".to_string(), "2012-11-05".to_string());
2133 params.insert("QueueUrl".to_string(), queue_url);
2134 params.insert("ReceiptHandle".to_string(), receipt_token.to_string());
2135 params.insert("VisibilityTimeout".to_string(), "0".to_string());
2136
2137 self.make_request("POST", "/", ¶ms, "")
2139 .await
2140 .map_err(|e| e.to_queue_error())?;
2141
2142 Ok(())
2143 }
2144
    /// Handles a dead-letter request by delegating to `complete_message`, which
    /// issues a `DeleteMessage` request for the receipt.
    ///
    /// `_reason` is accepted to satisfy the interface and is not used.
    // NOTE(review): the message is deleted rather than forwarded anywhere;
    // presumably dead-lettering is left to the queue's own redrive
    // configuration — confirm.
    async fn dead_letter_message(
        &self,
        receipt: &ReceiptHandle,
        _reason: &str,
    ) -> Result<(), QueueError> {
        self.complete_message(receipt).await
    }
2153
    /// No-op: performs no request and always returns `Ok(())`.
    // NOTE(review): presumably there is no session lock to renew for this
    // provider — confirm.
    async fn renew_session_lock(&self) -> Result<(), QueueError> {
        Ok(())
    }
2158
    /// No-op: performs no request and always returns `Ok(())`.
    async fn close_session(&self) -> Result<(), QueueError> {
        Ok(())
    }
2163
    /// Returns the session id this provider was created for.
    fn session_id(&self) -> &SessionId {
        &self.session_id
    }
2167
2168 fn session_expires_at(&self) -> Timestamp {
2169 Timestamp::from_datetime(chrono::Utc::now() + chrono::Duration::days(365))
2171 }
2172}