1use crate::client::{QueueProvider, SessionProvider};
216use crate::error::{ConfigurationError, QueueError, SerializationError};
217use crate::message::{
218 Message, MessageId, QueueName, ReceiptHandle, ReceivedMessage, SessionId, Timestamp,
219};
220use crate::provider::{AwsSqsConfig, ProviderType, SessionSupport};
221use async_trait::async_trait;
222use chrono::{DateTime, Duration, Utc};
223use hmac::{Hmac, KeyInit, Mac};
224use reqwest::Client as HttpClient;
225use sha2::{Digest, Sha256};
226use std::collections::HashMap;
227use std::fmt;
228use std::sync::Arc;
229use tokio::sync::RwLock;
230
231#[cfg(test)]
232#[path = "aws_tests.rs"]
233mod tests;
234
235#[derive(Debug, thiserror::Error)]
241pub enum AwsError {
242 #[error("Authentication failed: {0}")]
243 Authentication(String),
244
245 #[error("Network error: {0}")]
246 NetworkError(String),
247
248 #[error("SQS service error: {0}")]
249 ServiceError(String),
250
251 #[error("Queue not found: {0}")]
252 QueueNotFound(String),
253
254 #[error("Invalid receipt handle: {0}")]
255 InvalidReceipt(String),
256
257 #[error("Message too large: {size} bytes (max: {max_size})")]
258 MessageTooLarge { size: usize, max_size: usize },
259
260 #[error("Invalid configuration: {0}")]
261 ConfigurationError(String),
262
263 #[error("Serialization error: {0}")]
264 SerializationError(String),
265
266 #[error("Sessions not supported on standard queues")]
267 SessionsNotSupported,
268}
269
270impl AwsError {
271 pub fn is_transient(&self) -> bool {
273 match self {
274 Self::Authentication(_) => false,
275 Self::NetworkError(_) => true,
276 Self::ServiceError(_) => true, Self::QueueNotFound(_) => false,
278 Self::InvalidReceipt(_) => false,
279 Self::MessageTooLarge { .. } => false,
280 Self::ConfigurationError(_) => false,
281 Self::SerializationError(_) => false,
282 Self::SessionsNotSupported => false,
283 }
284 }
285
286 pub fn to_queue_error(self) -> QueueError {
288 match self {
289 Self::Authentication(msg) => QueueError::AuthenticationFailed { message: msg },
290 Self::NetworkError(msg) => QueueError::ConnectionFailed { message: msg },
291 Self::ServiceError(msg) => QueueError::ProviderError {
292 provider: "AwsSqs".to_string(),
293 code: "ServiceError".to_string(),
294 message: msg,
295 },
296 Self::QueueNotFound(queue) => QueueError::QueueNotFound { queue_name: queue },
297 Self::InvalidReceipt(receipt) => QueueError::InvalidReceipt { receipt },
298 Self::MessageTooLarge { size, max_size } => {
299 QueueError::MessageTooLarge { size, max_size }
300 }
301 Self::ConfigurationError(msg) => {
302 QueueError::ConfigurationError(ConfigurationError::Invalid { message: msg })
303 }
304 Self::SerializationError(msg) => QueueError::SerializationError(
305 SerializationError::JsonError(serde_json::Error::io(std::io::Error::new(
306 std::io::ErrorKind::InvalidData,
307 msg,
308 ))),
309 ),
310 Self::SessionsNotSupported => QueueError::ProviderError {
311 provider: "AwsSqs".to_string(),
312 code: "SessionsNotSupported".to_string(),
313 message:
314 "Standard queues do not support session-based operations. Use FIFO queues."
315 .to_string(),
316 },
317 }
318 }
319}
320
321type HmacSha256 = Hmac<Sha256>;
326
327#[derive(Clone)]
340struct AwsV4Signer {
341 access_key: String,
342 secret_key: String,
343 region: String,
344 service: String,
345}
346
347impl AwsV4Signer {
348 fn new(access_key: String, secret_key: String, region: String) -> Self {
356 Self {
357 access_key,
358 secret_key,
359 region,
360 service: "sqs".to_string(),
361 }
362 }
363
364 fn sign_request(
380 &self,
381 method: &str,
382 host: &str,
383 path: &str,
384 query_params: &HashMap<String, String>,
385 body: &str,
386 timestamp: &DateTime<Utc>,
387 ) -> HashMap<String, String> {
388 let date_stamp = timestamp.format("%Y%m%d").to_string();
389 let amz_date = timestamp.format("%Y%m%dT%H%M%SZ").to_string();
390
391 let canonical_uri = path;
393
394 let mut canonical_query_string = query_params
396 .iter()
397 .map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
398 .collect::<Vec<_>>();
399 canonical_query_string.sort();
400 let canonical_query_string = canonical_query_string.join("&");
401
402 let canonical_headers = format!("host:{}\nx-amz-date:{}\n", host, amz_date);
404 let signed_headers = "host;x-amz-date";
405
406 let payload_hash = hex::encode(Sha256::digest(body.as_bytes()));
408
409 let canonical_request = format!(
411 "{}\n{}\n{}\n{}\n{}\n{}",
412 method,
413 canonical_uri,
414 canonical_query_string,
415 canonical_headers,
416 signed_headers,
417 payload_hash
418 );
419
420 let algorithm = "AWS4-HMAC-SHA256";
422 let credential_scope = format!(
423 "{}/{}/{}/aws4_request",
424 date_stamp, self.region, self.service
425 );
426 let canonical_request_hash = hex::encode(Sha256::digest(canonical_request.as_bytes()));
427
428 let string_to_sign = format!(
429 "{}\n{}\n{}\n{}",
430 algorithm, amz_date, credential_scope, canonical_request_hash
431 );
432
433 let signature = self.calculate_signature(&string_to_sign, &date_stamp);
435
436 let authorization_header = format!(
438 "{} Credential={}/{}, SignedHeaders={}, Signature={}",
439 algorithm, self.access_key, credential_scope, signed_headers, signature
440 );
441
442 let mut headers = HashMap::new();
443 headers.insert("Authorization".to_string(), authorization_header);
444 headers.insert("x-amz-date".to_string(), amz_date);
445 headers.insert("host".to_string(), host.to_string());
446
447 headers
448 }
449
450 fn calculate_signature(&self, string_to_sign: &str, date_stamp: &str) -> String {
460 let k_secret = format!("AWS4{}", self.secret_key);
461 let k_date = self.hmac_sha256(k_secret.as_bytes(), date_stamp.as_bytes());
462 let k_region = self.hmac_sha256(&k_date, self.region.as_bytes());
463 let k_service = self.hmac_sha256(&k_region, self.service.as_bytes());
464 let k_signing = self.hmac_sha256(&k_service, b"aws4_request");
465 let signature = self.hmac_sha256(&k_signing, string_to_sign.as_bytes());
466
467 hex::encode(signature)
468 }
469
470 fn hmac_sha256(&self, key: &[u8], data: &[u8]) -> Vec<u8> {
472 let mut mac = HmacSha256::new_from_slice(key).expect("HMAC can take key of any size");
473 mac.update(data);
474 mac.finalize().into_bytes().to_vec()
475 }
476}
477
478#[derive(Debug, Clone)]
484struct AwsCredentials {
485 access_key_id: String,
486 secret_access_key: String,
487 session_token: Option<String>,
488 expiration: DateTime<Utc>,
489}
490
491impl AwsCredentials {
492 fn is_expired(&self) -> bool {
494 let now = Utc::now();
495 let buffer = Duration::minutes(5);
496 self.expiration - buffer <= now
497 }
498}
499
500struct AwsCredentialProvider {
510 http_client: HttpClient,
511 cached_credentials: Arc<RwLock<Option<AwsCredentials>>>,
512 explicit_config: Option<(String, String)>, }
514
515impl AwsCredentialProvider {
516 fn new(
518 http_client: HttpClient,
519 access_key_id: Option<String>,
520 secret_access_key: Option<String>,
521 ) -> Self {
522 let explicit_config = match (access_key_id, secret_access_key) {
523 (Some(key_id), Some(secret)) => Some((key_id, secret)),
524 _ => None,
525 };
526
527 Self {
528 http_client,
529 cached_credentials: Arc::new(RwLock::new(None)),
530 explicit_config,
531 }
532 }
533
534 async fn get_credentials(&self) -> Result<AwsCredentials, AwsError> {
536 {
538 let cache = self.cached_credentials.read().await;
539 if let Some(creds) = cache.as_ref() {
540 if !creds.is_expired() {
541 return Ok(creds.clone());
542 }
543 }
544 }
545
546 let creds = self.fetch_credentials().await?;
548
549 {
551 let mut cache = self.cached_credentials.write().await;
552 *cache = Some(creds.clone());
553 }
554
555 Ok(creds)
556 }
557
558 async fn fetch_credentials(&self) -> Result<AwsCredentials, AwsError> {
560 if let Some((key_id, secret)) = &self.explicit_config {
562 return Ok(AwsCredentials {
563 access_key_id: key_id.clone(),
564 secret_access_key: secret.clone(),
565 session_token: None,
566 expiration: Utc::now() + Duration::days(365), });
568 }
569
570 if let Ok(creds) = self.fetch_from_environment() {
572 return Ok(creds);
573 }
574
575 if let Ok(creds) = self.fetch_from_ecs_metadata().await {
577 return Ok(creds);
578 }
579
580 if let Ok(creds) = self.fetch_from_ec2_metadata().await {
582 return Ok(creds);
583 }
584
585 Err(AwsError::Authentication(
586 "No credentials found in credential chain".to_string(),
587 ))
588 }
589
590 fn fetch_from_environment(&self) -> Result<AwsCredentials, AwsError> {
592 let access_key_id = std::env::var("AWS_ACCESS_KEY_ID")
593 .map_err(|_| AwsError::Authentication("AWS_ACCESS_KEY_ID not set".to_string()))?;
594
595 let secret_access_key = std::env::var("AWS_SECRET_ACCESS_KEY")
596 .map_err(|_| AwsError::Authentication("AWS_SECRET_ACCESS_KEY not set".to_string()))?;
597
598 let session_token = std::env::var("AWS_SESSION_TOKEN").ok();
599
600 Ok(AwsCredentials {
601 access_key_id,
602 secret_access_key,
603 session_token,
604 expiration: Utc::now() + Duration::days(365), })
606 }
607
608 async fn fetch_from_ecs_metadata(&self) -> Result<AwsCredentials, AwsError> {
610 let relative_uri =
611 std::env::var("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI").map_err(|_| {
612 AwsError::Authentication(
613 "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI not set".to_string(),
614 )
615 })?;
616
617 let endpoint = format!("http://169.254.170.2{}", relative_uri);
618
619 let response = self
620 .http_client
621 .get(&endpoint)
622 .timeout(std::time::Duration::from_secs(2))
623 .send()
624 .await
625 .map_err(|e| {
626 AwsError::Authentication(format!("Failed to fetch ECS credentials: {}", e))
627 })?;
628
629 if !response.status().is_success() {
630 return Err(AwsError::Authentication(format!(
631 "ECS metadata returned error: {}",
632 response.status()
633 )));
634 }
635
636 let body = response
637 .text()
638 .await
639 .map_err(|e| AwsError::Authentication(format!("Failed to read ECS metadata: {}", e)))?;
640
641 self.parse_credentials_json(&body)
642 }
643
644 async fn fetch_from_ec2_metadata(&self) -> Result<AwsCredentials, AwsError> {
646 let token = self
648 .http_client
649 .put("http://169.254.169.254/latest/api/token")
650 .header("X-aws-ec2-metadata-token-ttl-seconds", "21600")
651 .timeout(std::time::Duration::from_secs(2))
652 .send()
653 .await
654 .map_err(|e| AwsError::Authentication(format!("Failed to get IMDSv2 token: {}", e)))?
655 .text()
656 .await
657 .map_err(|e| AwsError::Authentication(format!("Failed to read IMDSv2 token: {}", e)))?;
658
659 let role_name = self
661 .http_client
662 .get("http://169.254.169.254/latest/meta-data/iam/security-credentials/")
663 .header("X-aws-ec2-metadata-token", &token)
664 .timeout(std::time::Duration::from_secs(2))
665 .send()
666 .await
667 .map_err(|e| AwsError::Authentication(format!("Failed to fetch IAM role name: {}", e)))?
668 .text()
669 .await
670 .map_err(|e| {
671 AwsError::Authentication(format!("Failed to read IAM role name: {}", e))
672 })?;
673
674 let credentials_url = format!(
676 "http://169.254.169.254/latest/meta-data/iam/security-credentials/{}",
677 role_name.trim()
678 );
679
680 let response = self
681 .http_client
682 .get(&credentials_url)
683 .header("X-aws-ec2-metadata-token", &token)
684 .timeout(std::time::Duration::from_secs(2))
685 .send()
686 .await
687 .map_err(|e| {
688 AwsError::Authentication(format!("Failed to fetch EC2 credentials: {}", e))
689 })?;
690
691 if !response.status().is_success() {
692 return Err(AwsError::Authentication(format!(
693 "EC2 metadata returned error: {}",
694 response.status()
695 )));
696 }
697
698 let body = response
699 .text()
700 .await
701 .map_err(|e| AwsError::Authentication(format!("Failed to read EC2 metadata: {}", e)))?;
702
703 self.parse_credentials_json(&body)
704 }
705
706 fn parse_credentials_json(&self, json: &str) -> Result<AwsCredentials, AwsError> {
708 let access_key_id = Self::extract_json_field(json, "AccessKeyId")?;
710 let secret_access_key = Self::extract_json_field(json, "SecretAccessKey")?;
711 let session_token = Self::extract_json_field(json, "Token").ok();
712 let expiration_str = Self::extract_json_field(json, "Expiration")?;
713
714 let expiration = DateTime::parse_from_rfc3339(&expiration_str)
716 .map_err(|e| AwsError::Authentication(format!("Invalid expiration timestamp: {}", e)))?
717 .with_timezone(&Utc);
718
719 Ok(AwsCredentials {
720 access_key_id,
721 secret_access_key,
722 session_token,
723 expiration,
724 })
725 }
726
727 fn extract_json_field(json: &str, field: &str) -> Result<String, AwsError> {
729 let pattern = format!("\"{}\": \"", field);
730 let start = json.find(&pattern).ok_or_else(|| {
731 AwsError::Authentication(format!("Field '{}' not found in JSON", field))
732 })?;
733
734 let value_start = start + pattern.len();
735 let value_end = json[value_start..].find('"').ok_or_else(|| {
736 AwsError::Authentication(format!("Malformed JSON for field '{}'", field))
737 })? + value_start;
738
739 Ok(json[value_start..value_end].to_string())
740 }
741}
742
743pub struct AwsSqsProvider {
763 http_client: HttpClient,
764 credential_provider: AwsCredentialProvider,
765 config: AwsSqsConfig,
766 endpoint: String,
767 queue_url_cache: Arc<RwLock<HashMap<QueueName, String>>>,
768}
769
770impl AwsSqsProvider {
771 pub async fn new(config: AwsSqsConfig) -> Result<Self, AwsError> {
803 if config.region.is_empty() {
805 return Err(AwsError::ConfigurationError(
806 "Region cannot be empty".to_string(),
807 ));
808 }
809
810 let endpoint = format!("https://sqs.{}.amazonaws.com", config.region);
812
813 let http_client = HttpClient::builder()
815 .timeout(std::time::Duration::from_secs(30))
816 .build()
817 .map_err(|e| AwsError::NetworkError(format!("Failed to create HTTP client: {}", e)))?;
818
819 let credential_provider = AwsCredentialProvider::new(
821 http_client.clone(),
822 config.access_key_id.clone(),
823 config.secret_access_key.clone(),
824 );
825
826 Ok(Self {
827 http_client,
828 credential_provider,
829 config,
830 endpoint,
831 queue_url_cache: Arc::new(RwLock::new(HashMap::new())),
832 })
833 }
834
835 async fn get_queue_url(&self, queue_name: &QueueName) -> Result<String, AwsError> {
845 {
847 let cache = self.queue_url_cache.read().await;
848 if let Some(url) = cache.get(queue_name) {
849 return Ok(url.clone());
850 }
851 }
852
853 let mut params = HashMap::new();
855 params.insert("Action".to_string(), "GetQueueUrl".to_string());
856 params.insert("QueueName".to_string(), queue_name.as_str().to_string());
857 params.insert("Version".to_string(), "2012-11-05".to_string());
858
859 let response = self.make_request("POST", "/", ¶ms, "").await?;
861
862 let queue_url = self.parse_queue_url_response(&response)?;
864
865 let mut cache = self.queue_url_cache.write().await;
867 cache.insert(queue_name.clone(), queue_url.clone());
868
869 Ok(queue_url)
870 }
871
872 async fn make_request(
874 &self,
875 method: &str,
876 path: &str,
877 query_params: &HashMap<String, String>,
878 body: &str,
879 ) -> Result<String, AwsError> {
880 let credentials = self.credential_provider.get_credentials().await?;
882
883 let signer = AwsV4Signer::new(
885 credentials.access_key_id.clone(),
886 credentials.secret_access_key.clone(),
887 self.config.region.clone(),
888 );
889
890 let host = self
892 .endpoint
893 .strip_prefix("https://")
894 .unwrap_or(&self.endpoint);
895
896 let timestamp = Utc::now();
898
899 let mut auth_headers =
901 signer.sign_request(method, host, path, query_params, body, ×tamp);
902
903 if let Some(session_token) = &credentials.session_token {
905 auth_headers.insert("X-Amz-Security-Token".to_string(), session_token.clone());
906 }
907
908 let mut url = format!("{}{}", self.endpoint, path);
910 if !query_params.is_empty() {
911 let query_string = query_params
912 .iter()
913 .map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
914 .collect::<Vec<_>>()
915 .join("&");
916 url = format!("{}?{}", url, query_string);
917 }
918
919 let mut request = self.http_client.request(
921 method
922 .parse()
923 .map_err(|e| AwsError::ConfigurationError(format!("Invalid HTTP method: {}", e)))?,
924 &url,
925 );
926
927 for (key, value) in auth_headers {
929 request = request.header(&key, value);
930 }
931
932 if !body.is_empty() {
934 request = request.body(body.to_string());
935 }
936
937 let response = request.send().await.map_err(|e| {
939 if e.is_timeout() {
940 AwsError::NetworkError(format!("Request timeout: {}", e))
941 } else if e.is_connect() {
942 AwsError::NetworkError(format!("Connection failed: {}", e))
943 } else {
944 AwsError::NetworkError(format!("HTTP request failed: {}", e))
945 }
946 })?;
947
948 let status = response.status();
950 let response_body = response
951 .text()
952 .await
953 .map_err(|e| AwsError::NetworkError(format!("Failed to read response body: {}", e)))?;
954
955 if !status.is_success() {
956 return Err(self.parse_error_response(&response_body, status.as_u16()));
958 }
959
960 Ok(response_body)
961 }
962
963 fn parse_queue_url_response(&self, xml: &str) -> Result<String, AwsError> {
965 use quick_xml::events::Event;
966 use quick_xml::Reader;
967
968 let mut reader = Reader::from_str(xml);
969 reader.config_mut().trim_text(true);
970
971 let mut in_queue_url = false;
972 let mut buf = Vec::new();
973
974 loop {
975 match reader.read_event_into(&mut buf) {
976 Ok(Event::Start(ref e)) if e.name().as_ref() == b"QueueUrl" => {
977 in_queue_url = true;
978 }
979 Ok(Event::Text(e)) if in_queue_url => {
980 return e
981 .decode()
982 .map_err(|e| {
983 AwsError::SerializationError(format!("Failed to parse XML: {}", e))
984 })
985 .and_then(|s| {
986 quick_xml::escape::unescape(&s)
987 .map(|u| u.into_owned())
988 .map_err(|e| {
989 AwsError::SerializationError(format!(
990 "Failed to unescape XML: {}",
991 e
992 ))
993 })
994 });
995 }
996 Ok(Event::Eof) => break,
997 Err(e) => {
998 return Err(AwsError::SerializationError(format!(
999 "XML parsing error: {}",
1000 e
1001 )))
1002 }
1003 _ => {}
1004 }
1005 buf.clear();
1006 }
1007
1008 Err(AwsError::SerializationError(
1009 "QueueUrl not found in response".to_string(),
1010 ))
1011 }
1012
1013 fn parse_error_response(&self, xml: &str, status_code: u16) -> AwsError {
1015 use quick_xml::events::Event;
1016 use quick_xml::Reader;
1017
1018 let mut reader = Reader::from_str(xml);
1019 reader.config_mut().trim_text(true);
1020
1021 let mut error_code = None;
1022 let mut error_message = None;
1023 let mut in_error = false;
1024 let mut in_code = false;
1025 let mut in_message = false;
1026 let mut buf = Vec::new();
1027
1028 loop {
1029 match reader.read_event_into(&mut buf) {
1030 Ok(Event::Start(ref e)) => match e.name().as_ref() {
1031 b"Error" => in_error = true,
1032 b"Code" if in_error => in_code = true,
1033 b"Message" if in_error => in_message = true,
1034 _ => {}
1035 },
1036 Ok(Event::Text(e)) => {
1037 if in_code {
1038 error_code = e.decode().ok().and_then(|s| {
1039 quick_xml::escape::unescape(&s).ok().map(|u| u.into_owned())
1040 });
1041 in_code = false;
1042 } else if in_message {
1043 error_message = e.decode().ok().and_then(|s| {
1044 quick_xml::escape::unescape(&s).ok().map(|u| u.into_owned())
1045 });
1046 in_message = false;
1047 }
1048 }
1049 Ok(Event::End(ref e)) if e.name().as_ref() == b"Error" => {
1050 in_error = false;
1051 }
1052 Ok(Event::Eof) => break,
1053 Err(_) => break,
1054 _ => {}
1055 }
1056 buf.clear();
1057 }
1058
1059 let code = error_code.unwrap_or_else(|| "Unknown".to_string());
1060 let message = error_message.unwrap_or_else(|| "Unknown error".to_string());
1061
1062 match code.as_str() {
1064 "AWS.SimpleQueueService.NonExistentQueue" | "QueueDoesNotExist" => {
1065 AwsError::QueueNotFound(message)
1066 }
1067 "InvalidClientTokenId" | "UnrecognizedClientException" | "SignatureDoesNotMatch" => {
1068 AwsError::Authentication(format!("{}: {}", code, message))
1069 }
1070 "InvalidReceiptHandle" | "ReceiptHandleIsInvalid" => AwsError::InvalidReceipt(message),
1071 _ if status_code == 401 || status_code == 403 => {
1072 AwsError::Authentication(format!("{}: {}", code, message))
1073 }
1074 _ if status_code >= 500 => AwsError::ServiceError(format!("{}: {}", code, message)),
1075 _ => AwsError::ServiceError(format!("{}: {}", code, message)),
1076 }
1077 }
1078
1079 fn parse_send_message_response(&self, xml: &str) -> Result<MessageId, AwsError> {
1081 use quick_xml::events::Event;
1082 use quick_xml::Reader;
1083
1084 let mut reader = Reader::from_str(xml);
1085 reader.config_mut().trim_text(true);
1086
1087 let mut in_message_id = false;
1088 let mut buf = Vec::new();
1089
1090 loop {
1091 match reader.read_event_into(&mut buf) {
1092 Ok(Event::Start(ref e)) if e.name().as_ref() == b"MessageId" => {
1093 in_message_id = true;
1094 }
1095 Ok(Event::Text(e)) if in_message_id => {
1096 let msg_id = e.decode().map(|s| s.into_owned()).map_err(|e| {
1097 AwsError::SerializationError(format!("Failed to parse XML: {}", e))
1098 })?;
1099
1100 use std::str::FromStr;
1102 let message_id =
1103 MessageId::from_str(&msg_id).unwrap_or_else(|_| MessageId::new());
1104 return Ok(message_id);
1105 }
1106 Ok(Event::Eof) => break,
1107 Err(e) => {
1108 return Err(AwsError::SerializationError(format!(
1109 "XML parsing error: {}",
1110 e
1111 )))
1112 }
1113 _ => {}
1114 }
1115 buf.clear();
1116 }
1117
1118 Err(AwsError::SerializationError(
1119 "MessageId not found in response".to_string(),
1120 ))
1121 }
1122
1123 fn parse_receive_message_response(
1125 &self,
1126 xml: &str,
1127 queue: &QueueName,
1128 ) -> Result<Vec<ReceivedMessage>, AwsError> {
1129 use quick_xml::events::Event;
1130 use quick_xml::Reader;
1131
1132 let mut reader = Reader::from_str(xml);
1133 reader.config_mut().trim_text(true);
1134
1135 let mut messages = Vec::new();
1136 let mut in_message = false;
1137 let mut current_message_id: Option<String> = None;
1138 let mut current_receipt_handle: Option<String> = None;
1139 let mut current_body: Option<String> = None;
1140 let mut current_session_id: Option<String> = None;
1141 let mut current_delivery_count: u32 = 1;
1142
1143 let mut in_message_id = false;
1144 let mut in_receipt_handle = false;
1145 let mut in_body = false;
1146 let mut in_attribute_name = false;
1147 let mut in_attribute_value = false;
1148 let mut current_attribute_name: Option<String> = None;
1149
1150 let mut buf = Vec::new();
1151
1152 loop {
1153 match reader.read_event_into(&mut buf) {
1154 Ok(Event::Start(ref e)) => match e.name().as_ref() {
1155 b"Message" => {
1156 in_message = true;
1157 current_message_id = None;
1159 current_receipt_handle = None;
1160 current_body = None;
1161 current_session_id = None;
1162 current_delivery_count = 1;
1163 }
1164 b"MessageId" if in_message => in_message_id = true,
1165 b"ReceiptHandle" if in_message => in_receipt_handle = true,
1166 b"Body" if in_message => in_body = true,
1167 b"Name" if in_message => in_attribute_name = true,
1168 b"Value" if in_message => in_attribute_value = true,
1169 _ => {}
1170 },
1171 Ok(Event::Text(e)) => {
1172 let text = e.decode().ok().map(|s| s.into_owned());
1173 if in_message_id {
1174 current_message_id = text;
1175 in_message_id = false;
1176 } else if in_receipt_handle {
1177 current_receipt_handle = text;
1178 in_receipt_handle = false;
1179 } else if in_body {
1180 current_body = text;
1181 in_body = false;
1182 } else if in_attribute_name {
1183 current_attribute_name = text;
1184 in_attribute_name = false;
1185 } else if in_attribute_value {
1186 if let Some(ref attr_name) = current_attribute_name {
1187 match attr_name.as_str() {
1188 "MessageGroupId" => current_session_id = text,
1189 "ApproximateReceiveCount" => {
1190 if let Some(count_str) = text {
1191 current_delivery_count =
1192 count_str.parse::<u32>().unwrap_or(1);
1193 }
1194 }
1195 _ => {}
1196 }
1197 }
1198 in_attribute_value = false;
1199 current_attribute_name = None;
1200 }
1201 }
1202 Ok(Event::End(ref e)) if e.name().as_ref() == b"Message" => {
1203 in_message = false;
1204
1205 if let (Some(body_base64), Some(receipt_handle)) =
1207 (current_body.as_ref(), current_receipt_handle.as_ref())
1208 {
1209 use base64::{engine::general_purpose::STANDARD, Engine};
1211 let body_bytes = STANDARD.decode(body_base64).map_err(|e| {
1212 AwsError::SerializationError(format!("Base64 decode failed: {}", e))
1213 })?;
1214 let body = bytes::Bytes::from(body_bytes);
1215
1216 use std::str::FromStr;
1218 let message_id = current_message_id
1219 .as_ref()
1220 .and_then(|id| MessageId::from_str(id).ok())
1221 .unwrap_or_default();
1222
1223 let session_id = current_session_id
1225 .as_ref()
1226 .and_then(|id| SessionId::new(id.clone()).ok());
1227
1228 let handle_with_queue = format!("{}|{}", queue.as_str(), receipt_handle);
1231 let expires_at = Timestamp::now();
1232 let receipt =
1233 ReceiptHandle::new(handle_with_queue, expires_at, ProviderType::AwsSqs);
1234
1235 let received_message = ReceivedMessage {
1237 message_id,
1238 body,
1239 attributes: HashMap::new(),
1240 session_id,
1241 correlation_id: None,
1242 receipt_handle: receipt,
1243 delivery_count: current_delivery_count,
1244 first_delivered_at: Timestamp::now(),
1245 delivered_at: Timestamp::now(),
1246 };
1247
1248 messages.push(received_message);
1249 }
1250 }
1251 Ok(Event::Eof) => break,
1252 Err(e) => {
1253 return Err(AwsError::SerializationError(format!(
1254 "XML parsing error: {}",
1255 e
1256 )))
1257 }
1258 _ => {}
1259 }
1260 buf.clear();
1261 }
1262
1263 Ok(messages)
1264 }
1265
1266 fn is_fifo_queue(queue_name: &QueueName) -> bool {
1268 queue_name.as_str().ends_with(".fifo")
1269 }
1270}
1271
1272impl fmt::Debug for AwsSqsProvider {
1273 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1274 f.debug_struct("AwsSqsProvider")
1275 .field("config", &self.config)
1276 .field("queue_url_cache_size", &"<redacted>")
1277 .finish()
1278 }
1279}
1280
1281#[async_trait]
1282impl QueueProvider for AwsSqsProvider {
1283 async fn send_message(
1284 &self,
1285 queue: &QueueName,
1286 message: &Message,
1287 ) -> Result<MessageId, QueueError> {
1288 let queue_url = self
1289 .get_queue_url(queue)
1290 .await
1291 .map_err(|e| e.to_queue_error())?;
1292
1293 use base64::{engine::general_purpose::STANDARD, Engine};
1295 let body_base64 = STANDARD.encode(&message.body);
1296
1297 if body_base64.len() > 256 * 1024 {
1299 return Err(AwsError::MessageTooLarge {
1300 size: body_base64.len(),
1301 max_size: 256 * 1024,
1302 }
1303 .to_queue_error());
1304 }
1305
1306 let mut params = HashMap::new();
1308 params.insert("Action".to_string(), "SendMessage".to_string());
1309 params.insert("Version".to_string(), "2012-11-05".to_string());
1310 params.insert("QueueUrl".to_string(), queue_url.clone());
1311 params.insert("MessageBody".to_string(), body_base64);
1312
1313 if Self::is_fifo_queue(queue) {
1315 if let Some(ref session_id) = message.session_id {
1316 params.insert(
1317 "MessageGroupId".to_string(),
1318 session_id.as_str().to_string(),
1319 );
1320 let dedup_id = uuid::Uuid::new_v4().to_string();
1322 params.insert("MessageDeduplicationId".to_string(), dedup_id);
1323 } else {
1324 return Err(QueueError::ValidationError(
1326 crate::error::ValidationError::Required {
1327 field: "session_id".to_string(),
1328 },
1329 ));
1330 }
1331 }
1332
1333 let response = self
1335 .make_request("POST", "/", ¶ms, "")
1336 .await
1337 .map_err(|e| e.to_queue_error())?;
1338
1339 let message_id = self
1341 .parse_send_message_response(&response)
1342 .map_err(|e| e.to_queue_error())?;
1343
1344 Ok(message_id)
1345 }
1346
1347 async fn send_messages(
1348 &self,
1349 queue: &QueueName,
1350 messages: &[Message],
1351 ) -> Result<Vec<MessageId>, QueueError> {
1352 if messages.is_empty() {
1353 return Ok(Vec::new());
1354 }
1355
1356 let max_batch = self.max_batch_size() as usize;
1358 let mut all_message_ids = Vec::new();
1359
1360 for chunk in messages.chunks(max_batch) {
1362 let message_ids = self.send_messages_batch(queue, chunk).await?;
1363 all_message_ids.extend(message_ids);
1364 }
1365
1366 Ok(all_message_ids)
1367 }
1368
1369 async fn receive_message(
1370 &self,
1371 queue: &QueueName,
1372 timeout: Duration,
1373 ) -> Result<Option<ReceivedMessage>, QueueError> {
1374 let messages = self.receive_messages(queue, 1, timeout).await?;
1375 Ok(messages.into_iter().next())
1376 }
1377
1378 async fn receive_messages(
1379 &self,
1380 queue: &QueueName,
1381 max_messages: u32,
1382 timeout: Duration,
1383 ) -> Result<Vec<ReceivedMessage>, QueueError> {
1384 let queue_url = self
1385 .get_queue_url(queue)
1386 .await
1387 .map_err(|e| e.to_queue_error())?;
1388
1389 let wait_time_seconds = timeout.num_seconds().clamp(0, 20); let mut params = HashMap::new();
1394 params.insert("Action".to_string(), "ReceiveMessage".to_string());
1395 params.insert("Version".to_string(), "2012-11-05".to_string());
1396 params.insert("QueueUrl".to_string(), queue_url);
1397 params.insert(
1398 "MaxNumberOfMessages".to_string(),
1399 max_messages.min(10).to_string(), );
1401 params.insert("WaitTimeSeconds".to_string(), wait_time_seconds.to_string());
1402 params.insert("AttributeName.1".to_string(), "All".to_string()); let response = self
1406 .make_request("POST", "/", ¶ms, "")
1407 .await
1408 .map_err(|e| e.to_queue_error())?;
1409
1410 let messages = self
1412 .parse_receive_message_response(&response, queue)
1413 .map_err(|e| e.to_queue_error())?;
1414
1415 Ok(messages)
1416 }
1417
1418 async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
1419 let handle_str = receipt.handle();
1427 let parts: Vec<&str> = handle_str.split('|').collect();
1428
1429 if parts.len() != 2 {
1430 return Err(QueueError::InvalidReceipt {
1431 receipt: handle_str.to_string(),
1432 });
1433 }
1434
1435 let queue_name =
1436 QueueName::new(parts[0].to_string()).map_err(QueueError::ValidationError)?;
1437 let receipt_token = parts[1];
1438
1439 let queue_url = self
1441 .get_queue_url(&queue_name)
1442 .await
1443 .map_err(|e| e.to_queue_error())?;
1444
1445 let mut params = HashMap::new();
1447 params.insert("Action".to_string(), "DeleteMessage".to_string());
1448 params.insert("Version".to_string(), "2012-11-05".to_string());
1449 params.insert("QueueUrl".to_string(), queue_url);
1450 params.insert("ReceiptHandle".to_string(), receipt_token.to_string());
1451
1452 let _response = self
1454 .make_request("POST", "/", ¶ms, "")
1455 .await
1456 .map_err(|e| e.to_queue_error())?;
1457
1458 Ok(())
1460 }
1461
1462 async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
1463 let handle_str = receipt.handle();
1465 let parts: Vec<&str> = handle_str.split('|').collect();
1466
1467 if parts.len() != 2 {
1468 return Err(QueueError::InvalidReceipt {
1469 receipt: handle_str.to_string(),
1470 });
1471 }
1472
1473 let queue_name =
1474 QueueName::new(parts[0].to_string()).map_err(QueueError::ValidationError)?;
1475 let receipt_token = parts[1];
1476
1477 let queue_url = self
1479 .get_queue_url(&queue_name)
1480 .await
1481 .map_err(|e| e.to_queue_error())?;
1482
1483 let mut params = HashMap::new();
1486 params.insert("Action".to_string(), "ChangeMessageVisibility".to_string());
1487 params.insert("Version".to_string(), "2012-11-05".to_string());
1488 params.insert("QueueUrl".to_string(), queue_url);
1489 params.insert("ReceiptHandle".to_string(), receipt_token.to_string());
1490 params.insert("VisibilityTimeout".to_string(), "0".to_string());
1491
1492 let _response = self
1494 .make_request("POST", "/", ¶ms, "")
1495 .await
1496 .map_err(|e| e.to_queue_error())?;
1497
1498 Ok(())
1500 }
1501
1502 async fn dead_letter_message(
1503 &self,
1504 receipt: &ReceiptHandle,
1505 _reason: &str,
1506 ) -> Result<(), QueueError> {
1507 self.complete_message(receipt).await
1510 }
1511
1512 async fn create_session_client(
1513 &self,
1514 queue: &QueueName,
1515 session_id: Option<SessionId>,
1516 ) -> Result<Box<dyn SessionProvider>, QueueError> {
1517 if !Self::is_fifo_queue(queue) {
1519 return Err(AwsError::SessionsNotSupported.to_queue_error());
1520 }
1521
1522 let queue_url = self
1524 .get_queue_url(queue)
1525 .await
1526 .map_err(|e| e.to_queue_error())?;
1527
1528 let session_id = session_id.ok_or_else(|| {
1530 QueueError::ValidationError(crate::error::ValidationError::Required {
1531 field: "session_id".to_string(),
1532 })
1533 })?;
1534
1535 Ok(Box::new(AwsSessionProvider::new(
1536 self.http_client.clone(),
1537 AwsCredentialProvider::new(
1538 self.http_client.clone(),
1539 self.config.access_key_id.clone(),
1540 self.config.secret_access_key.clone(),
1541 ),
1542 self.config.region.clone(),
1543 self.endpoint.clone(),
1544 queue_url,
1545 queue.clone(),
1546 session_id,
1547 )))
1548 }
1549
1550 fn provider_type(&self) -> ProviderType {
1551 ProviderType::AwsSqs
1552 }
1553
1554 fn supports_sessions(&self) -> SessionSupport {
1555 SessionSupport::Emulated
1556 }
1557
1558 fn supports_batching(&self) -> bool {
1559 true
1560 }
1561
1562 fn max_batch_size(&self) -> u32 {
1563 10 }
1565}
1566
1567impl AwsSqsProvider {
1569 async fn send_messages_batch(
1571 &self,
1572 queue: &QueueName,
1573 messages: &[Message],
1574 ) -> Result<Vec<MessageId>, QueueError> {
1575 if messages.is_empty() {
1576 return Ok(Vec::new());
1577 }
1578
1579 if messages.len() > 10 {
1581 return Err(QueueError::ValidationError(
1582 crate::error::ValidationError::OutOfRange {
1583 field: "messages".to_string(),
1584 message: format!("Batch size {} exceeds AWS SQS limit of 10", messages.len()),
1585 },
1586 ));
1587 }
1588
1589 let queue_url = self
1590 .get_queue_url(queue)
1591 .await
1592 .map_err(|e| e.to_queue_error())?;
1593
1594 let mut params = HashMap::new();
1596 params.insert("Action".to_string(), "SendMessageBatch".to_string());
1597 params.insert("Version".to_string(), "2012-11-05".to_string());
1598 params.insert("QueueUrl".to_string(), queue_url.clone());
1599
1600 use base64::{engine::general_purpose::STANDARD, Engine};
1602
1603 for (idx, message) in messages.iter().enumerate() {
1605 let entry_id = format!("msg-{}", idx);
1606 let body_base64 = STANDARD.encode(&message.body);
1607
1608 if body_base64.len() > 256 * 1024 {
1610 return Err(AwsError::MessageTooLarge {
1611 size: body_base64.len(),
1612 max_size: 256 * 1024,
1613 }
1614 .to_queue_error());
1615 }
1616
1617 params.insert(
1618 format!("SendMessageBatchRequestEntry.{}.Id", idx + 1),
1619 entry_id,
1620 );
1621 params.insert(
1622 format!("SendMessageBatchRequestEntry.{}.MessageBody", idx + 1),
1623 body_base64,
1624 );
1625
1626 if Self::is_fifo_queue(queue) {
1628 if let Some(ref session_id) = message.session_id {
1630 params.insert(
1631 format!("SendMessageBatchRequestEntry.{}.MessageGroupId", idx + 1),
1632 session_id.as_str().to_string(),
1633 );
1634 }
1635
1636 use sha2::{Digest, Sha256};
1639 let mut hasher = Sha256::new();
1640 hasher.update(&message.body);
1641 if let Some(ref session_id) = message.session_id {
1642 hasher.update(session_id.as_str().as_bytes());
1643 }
1644 let hash = hex::encode(hasher.finalize());
1645 params.insert(
1646 format!(
1647 "SendMessageBatchRequestEntry.{}.MessageDeduplicationId",
1648 idx + 1
1649 ),
1650 hash,
1651 );
1652 }
1653 }
1654
1655 let response = self
1657 .make_request("POST", "/", ¶ms, "")
1658 .await
1659 .map_err(|e| e.to_queue_error())?;
1660
1661 self.parse_send_message_batch_response(&response)
1663 .map_err(|e| e.to_queue_error())
1664 }
1665
1666 fn parse_send_message_batch_response(&self, xml: &str) -> Result<Vec<MessageId>, AwsError> {
1668 use quick_xml::events::Event;
1669 use quick_xml::Reader;
1670
1671 let mut reader = Reader::from_str(xml);
1672 reader.config_mut().trim_text(true);
1673
1674 let mut message_ids = Vec::new();
1675 let mut in_successful = false;
1676 let mut in_message_id = false;
1677 let mut buf = Vec::new();
1678
1679 loop {
1680 match reader.read_event_into(&mut buf) {
1681 Ok(Event::Start(ref e)) => match e.name().as_ref() {
1682 b"SendMessageBatchResultEntry" => in_successful = true,
1683 b"MessageId" if in_successful => in_message_id = true,
1684 _ => {}
1685 },
1686 Ok(Event::Text(e)) if in_message_id => {
1687 let msg_id = e.decode().map(|s| s.into_owned()).map_err(|e| {
1688 AwsError::SerializationError(format!("Failed to parse XML: {}", e))
1689 })?;
1690
1691 use std::str::FromStr;
1693 let message_id =
1694 MessageId::from_str(&msg_id).unwrap_or_else(|_| MessageId::new());
1695 message_ids.push(message_id);
1696 in_message_id = false;
1697 }
1698 Ok(Event::End(ref e)) if e.name().as_ref() == b"SendMessageBatchResultEntry" => {
1699 in_successful = false;
1700 }
1701 Ok(Event::Eof) => break,
1702 Err(e) => {
1703 return Err(AwsError::SerializationError(format!(
1704 "XML parsing error: {}",
1705 e
1706 )))
1707 }
1708 _ => {}
1709 }
1710 buf.clear();
1711 }
1712
1713 Ok(message_ids)
1714 }
1715}
1716
1717pub struct AwsSessionProvider {
1726 http_client: HttpClient,
1727 credential_provider: AwsCredentialProvider,
1728 region: String,
1729 endpoint: String,
1730 queue_url: String,
1731 queue_name: QueueName,
1732 session_id: SessionId,
1733}
1734
1735impl AwsSessionProvider {
1736 fn new(
1738 http_client: HttpClient,
1739 credential_provider: AwsCredentialProvider,
1740 region: String,
1741 endpoint: String,
1742 queue_url: String,
1743 queue_name: QueueName,
1744 session_id: SessionId,
1745 ) -> Self {
1746 Self {
1747 http_client,
1748 credential_provider,
1749 region,
1750 endpoint,
1751 queue_url,
1752 queue_name,
1753 session_id,
1754 }
1755 }
1756
1757 async fn get_queue_url(&self) -> Result<String, AwsError> {
1759 Ok(self.queue_url.clone())
1760 }
1761
1762 async fn make_request(
1764 &self,
1765 method: &str,
1766 path: &str,
1767 params: &HashMap<String, String>,
1768 body: &str,
1769 ) -> Result<String, AwsError> {
1770 use reqwest::header;
1771
1772 let credentials = self.credential_provider.get_credentials().await?;
1774
1775 let signer = AwsV4Signer::new(
1777 credentials.access_key_id.clone(),
1778 credentials.secret_access_key.clone(),
1779 self.region.clone(),
1780 );
1781
1782 let query_string = if params.is_empty() {
1784 String::new()
1785 } else {
1786 let mut pairs: Vec<String> = params
1787 .iter()
1788 .map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
1789 .collect();
1790 pairs.sort();
1791 pairs.join("&")
1792 };
1793
1794 let url = if query_string.is_empty() {
1795 format!("{}{}", self.endpoint, path)
1796 } else {
1797 format!("{}{}?{}", self.endpoint, path, query_string)
1798 };
1799
1800 let mut request_builder = self.http_client.request(
1802 method
1803 .parse()
1804 .map_err(|e| AwsError::NetworkError(format!("Invalid HTTP method: {}", e)))?,
1805 &url,
1806 );
1807
1808 let timestamp = Utc::now();
1810 let host = self
1811 .endpoint
1812 .trim_start_matches("https://")
1813 .trim_start_matches("http://");
1814 let mut signed_headers = signer.sign_request(method, host, path, params, body, ×tamp);
1815
1816 if let Some(session_token) = &credentials.session_token {
1818 signed_headers.insert("X-Amz-Security-Token".to_string(), session_token.clone());
1819 }
1820
1821 for (key, value) in signed_headers {
1822 request_builder = request_builder.header(key, value);
1823 }
1824
1825 if !body.is_empty() {
1827 request_builder = request_builder
1828 .header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
1829 .body(body.to_string());
1830 }
1831
1832 let response = request_builder
1834 .send()
1835 .await
1836 .map_err(|e| AwsError::NetworkError(format!("HTTP request failed: {}", e)))?;
1837
1838 let status = response.status();
1839 let response_text = response
1840 .text()
1841 .await
1842 .map_err(|e| AwsError::NetworkError(format!("Failed to read response: {}", e)))?;
1843
1844 if !status.is_success() {
1846 return Err(self.parse_error_response(&response_text, status.as_u16()));
1847 }
1848
1849 Ok(response_text)
1850 }
1851
1852 fn parse_error_response(&self, xml: &str, status_code: u16) -> AwsError {
1854 use quick_xml::events::Event;
1855 use quick_xml::Reader;
1856
1857 let mut reader = Reader::from_str(xml);
1858 reader.config_mut().trim_text(true);
1859
1860 let mut error_code = None;
1861 let mut error_message = None;
1862 let mut in_error = false;
1863 let mut in_code = false;
1864 let mut in_message = false;
1865 let mut buf = Vec::new();
1866
1867 loop {
1868 match reader.read_event_into(&mut buf) {
1869 Ok(Event::Start(ref e)) => match e.name().as_ref() {
1870 b"Error" => in_error = true,
1871 b"Code" if in_error => in_code = true,
1872 b"Message" if in_error => in_message = true,
1873 _ => {}
1874 },
1875 Ok(Event::Text(e)) => {
1876 if in_code {
1877 error_code = e.decode().ok().and_then(|s| {
1878 quick_xml::escape::unescape(&s).ok().map(|u| u.into_owned())
1879 });
1880 in_code = false;
1881 } else if in_message {
1882 error_message = e.decode().ok().and_then(|s| {
1883 quick_xml::escape::unescape(&s).ok().map(|u| u.into_owned())
1884 });
1885 in_message = false;
1886 }
1887 }
1888 Ok(Event::Eof) => break,
1889 Err(_) => break,
1890 _ => {}
1891 }
1892 buf.clear();
1893 }
1894
1895 match error_code.as_deref() {
1897 Some("InvalidParameterValue") | Some("MissingParameter") => AwsError::ServiceError(
1898 error_message.unwrap_or_else(|| "Invalid parameter".to_string()),
1899 ),
1900 Some("AccessDenied") | Some("InvalidClientTokenId") | Some("SignatureDoesNotMatch") => {
1901 AwsError::Authentication(
1902 error_message.unwrap_or_else(|| "Authentication failed".to_string()),
1903 )
1904 }
1905 Some("AWS.SimpleQueueService.NonExistentQueue") | Some("QueueDoesNotExist") => {
1906 AwsError::QueueNotFound(
1907 error_message.unwrap_or_else(|| "Queue not found".to_string()),
1908 )
1909 }
1910 _ => {
1911 if status_code >= 500 {
1912 AwsError::ServiceError(
1913 error_message.unwrap_or_else(|| "Service error".to_string()),
1914 )
1915 } else {
1916 AwsError::ServiceError(
1917 error_message.unwrap_or_else(|| format!("HTTP {}", status_code)),
1918 )
1919 }
1920 }
1921 }
1922 }
1923
1924 fn parse_receive_message_response(
1926 &self,
1927 xml: &str,
1928 queue: &QueueName,
1929 ) -> Result<Vec<ReceivedMessage>, AwsError> {
1930 use quick_xml::events::Event;
1931 use quick_xml::Reader;
1932
1933 let mut reader = Reader::from_str(xml);
1934 reader.config_mut().trim_text(true);
1935
1936 let mut messages = Vec::new();
1937 let mut in_message = false;
1938 let mut current_message_id: Option<String> = None;
1939 let mut current_receipt_handle: Option<String> = None;
1940 let mut current_body: Option<String> = None;
1941 let mut current_session_id: Option<String> = None;
1942 let mut current_delivery_count: u32 = 1;
1943
1944 let mut in_message_id = false;
1945 let mut in_receipt_handle = false;
1946 let mut in_body = false;
1947 let mut in_attribute_name = false;
1948 let mut in_attribute_value = false;
1949 let mut current_attribute_name: Option<String> = None;
1950
1951 let mut buf = Vec::new();
1952
1953 loop {
1954 match reader.read_event_into(&mut buf) {
1955 Ok(Event::Start(ref e)) => match e.name().as_ref() {
1956 b"Message" => {
1957 in_message = true;
1958 current_message_id = None;
1959 current_receipt_handle = None;
1960 current_body = None;
1961 current_session_id = None;
1962 current_delivery_count = 1;
1963 }
1964 b"MessageId" if in_message => in_message_id = true,
1965 b"ReceiptHandle" if in_message => in_receipt_handle = true,
1966 b"Body" if in_message => in_body = true,
1967 b"Name" if in_message => in_attribute_name = true,
1968 b"Value" if in_message => in_attribute_value = true,
1969 _ => {}
1970 },
1971 Ok(Event::Text(e)) => {
1972 let text = e.decode().ok().map(|s| s.into_owned());
1973 if in_message_id {
1974 current_message_id = text;
1975 in_message_id = false;
1976 } else if in_receipt_handle {
1977 current_receipt_handle = text;
1978 in_receipt_handle = false;
1979 } else if in_body {
1980 current_body = text;
1981 in_body = false;
1982 } else if in_attribute_name {
1983 current_attribute_name = text;
1984 in_attribute_name = false;
1985 } else if in_attribute_value {
1986 if let Some(ref attr_name) = current_attribute_name {
1987 match attr_name.as_str() {
1988 "MessageGroupId" => current_session_id = text,
1989 "ApproximateReceiveCount" => {
1990 if let Some(count_str) = text {
1991 current_delivery_count =
1992 count_str.parse::<u32>().unwrap_or(1);
1993 }
1994 }
1995 _ => {}
1996 }
1997 }
1998 in_attribute_value = false;
1999 current_attribute_name = None;
2000 }
2001 }
2002 Ok(Event::End(ref e)) if e.name().as_ref() == b"Message" => {
2003 in_message = false;
2004
2005 if let (Some(body_base64), Some(receipt_handle)) =
2006 (current_body.as_ref(), current_receipt_handle.as_ref())
2007 {
2008 use base64::{engine::general_purpose::STANDARD, Engine};
2009 let body_bytes = STANDARD.decode(body_base64).map_err(|e| {
2010 AwsError::SerializationError(format!("Base64 decode failed: {}", e))
2011 })?;
2012 let body = bytes::Bytes::from(body_bytes);
2013
2014 use std::str::FromStr;
2015 let message_id = current_message_id
2016 .as_ref()
2017 .and_then(|id| MessageId::from_str(id).ok())
2018 .unwrap_or_default();
2019
2020 let session_id = current_session_id
2021 .as_ref()
2022 .and_then(|id| SessionId::new(id.clone()).ok());
2023
2024 let handle_with_queue = format!("{}|{}", queue.as_str(), receipt_handle);
2025 let expires_at = Timestamp::now();
2026 let receipt =
2027 ReceiptHandle::new(handle_with_queue, expires_at, ProviderType::AwsSqs);
2028
2029 let received_message = ReceivedMessage {
2030 message_id,
2031 body,
2032 attributes: HashMap::new(),
2033 session_id,
2034 correlation_id: None,
2035 receipt_handle: receipt,
2036 delivery_count: current_delivery_count,
2037 first_delivered_at: Timestamp::now(),
2038 delivered_at: Timestamp::now(),
2039 };
2040
2041 messages.push(received_message);
2042 }
2043 }
2044 Ok(Event::Eof) => break,
2045 Err(e) => {
2046 return Err(AwsError::SerializationError(format!(
2047 "XML parsing error: {}",
2048 e
2049 )))
2050 }
2051 _ => {}
2052 }
2053 buf.clear();
2054 }
2055
2056 Ok(messages)
2057 }
2058}
2059
2060impl fmt::Debug for AwsSessionProvider {
2061 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
2062 f.debug_struct("AwsSessionProvider")
2063 .field("queue_name", &self.queue_name)
2064 .field("session_id", &self.session_id)
2065 .finish()
2066 }
2067}
2068
2069#[async_trait]
2070impl SessionProvider for AwsSessionProvider {
2071 async fn receive_message(
2072 &self,
2073 timeout: Duration,
2074 ) -> Result<Option<ReceivedMessage>, QueueError> {
2075 let queue_url = self.get_queue_url().await.map_err(|e| e.to_queue_error())?;
2078
2079 let mut params = HashMap::new();
2081 params.insert("Action".to_string(), "ReceiveMessage".to_string());
2082 params.insert("Version".to_string(), "2012-11-05".to_string());
2083 params.insert("QueueUrl".to_string(), queue_url);
2084 params.insert("MaxNumberOfMessages".to_string(), "1".to_string());
2085 params.insert(
2086 "WaitTimeSeconds".to_string(),
2087 timeout.num_seconds().clamp(0, 20).to_string(),
2088 );
2089 params.insert("AttributeName.1".to_string(), "All".to_string());
2090
2091 let response = self
2093 .make_request("POST", "/", ¶ms, "")
2094 .await
2095 .map_err(|e| e.to_queue_error())?;
2096
2097 let messages = self
2099 .parse_receive_message_response(&response, &self.queue_name)
2100 .map_err(|e| e.to_queue_error())?;
2101
2102 Ok(messages
2104 .into_iter()
2105 .find(|msg| msg.session_id.as_ref() == Some(&self.session_id)))
2106 }
2107
2108 async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
2109 let handle_str = receipt.handle();
2111 let parts: Vec<&str> = handle_str.split('|').collect();
2112
2113 if parts.len() != 2 {
2114 return Err(QueueError::InvalidReceipt {
2115 receipt: handle_str.to_string(),
2116 });
2117 }
2118
2119 let receipt_token = parts[1];
2120 let queue_url = self.get_queue_url().await.map_err(|e| e.to_queue_error())?;
2121
2122 let mut params = HashMap::new();
2124 params.insert("Action".to_string(), "DeleteMessage".to_string());
2125 params.insert("Version".to_string(), "2012-11-05".to_string());
2126 params.insert("QueueUrl".to_string(), queue_url);
2127 params.insert("ReceiptHandle".to_string(), receipt_token.to_string());
2128
2129 self.make_request("POST", "/", ¶ms, "")
2131 .await
2132 .map_err(|e| e.to_queue_error())?;
2133
2134 Ok(())
2135 }
2136
2137 async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
2138 let handle_str = receipt.handle();
2140 let parts: Vec<&str> = handle_str.split('|').collect();
2141
2142 if parts.len() != 2 {
2143 return Err(QueueError::InvalidReceipt {
2144 receipt: handle_str.to_string(),
2145 });
2146 }
2147
2148 let receipt_token = parts[1];
2149 let queue_url = self.get_queue_url().await.map_err(|e| e.to_queue_error())?;
2150
2151 let mut params = HashMap::new();
2153 params.insert("Action".to_string(), "ChangeMessageVisibility".to_string());
2154 params.insert("Version".to_string(), "2012-11-05".to_string());
2155 params.insert("QueueUrl".to_string(), queue_url);
2156 params.insert("ReceiptHandle".to_string(), receipt_token.to_string());
2157 params.insert("VisibilityTimeout".to_string(), "0".to_string());
2158
2159 self.make_request("POST", "/", ¶ms, "")
2161 .await
2162 .map_err(|e| e.to_queue_error())?;
2163
2164 Ok(())
2165 }
2166
2167 async fn dead_letter_message(
2168 &self,
2169 receipt: &ReceiptHandle,
2170 _reason: &str,
2171 ) -> Result<(), QueueError> {
2172 self.complete_message(receipt).await
2174 }
2175
2176 async fn renew_session_lock(&self) -> Result<(), QueueError> {
2177 Ok(())
2179 }
2180
2181 async fn close_session(&self) -> Result<(), QueueError> {
2182 Ok(())
2184 }
2185
2186 fn session_id(&self) -> &SessionId {
2187 &self.session_id
2188 }
2189
2190 fn session_expires_at(&self) -> Timestamp {
2191 Timestamp::from_datetime(chrono::Utc::now() + chrono::Duration::days(365))
2193 }
2194}