use crate::client::{QueueProvider, SessionProvider};
use crate::error::{ConfigurationError, QueueError, SerializationError};
use crate::message::{
Message, MessageId, QueueName, ReceiptHandle, ReceivedMessage, SessionId, Timestamp,
};
use crate::provider::{AwsSqsConfig, ProviderType, SessionSupport};
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use hmac::{Hmac, KeyInit, Mac};
use reqwest::Client as HttpClient;
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
use tokio::sync::RwLock;
#[cfg(test)]
#[path = "aws_tests.rs"]
mod tests;
#[derive(Debug, thiserror::Error)]
pub enum AwsError {
#[error("Authentication failed: {0}")]
Authentication(String),
#[error("Network error: {0}")]
NetworkError(String),
#[error("SQS service error: {0}")]
ServiceError(String),
#[error("Queue not found: {0}")]
QueueNotFound(String),
#[error("Invalid receipt handle: {0}")]
InvalidReceipt(String),
#[error("Message too large: {size} bytes (max: {max_size})")]
MessageTooLarge { size: usize, max_size: usize },
#[error("Invalid configuration: {0}")]
ConfigurationError(String),
#[error("Serialization error: {0}")]
SerializationError(String),
#[error("Sessions not supported on standard queues")]
SessionsNotSupported,
}
impl AwsError {
pub fn is_transient(&self) -> bool {
match self {
Self::Authentication(_) => false,
Self::NetworkError(_) => true,
Self::ServiceError(_) => true, Self::QueueNotFound(_) => false,
Self::InvalidReceipt(_) => false,
Self::MessageTooLarge { .. } => false,
Self::ConfigurationError(_) => false,
Self::SerializationError(_) => false,
Self::SessionsNotSupported => false,
}
}
pub fn to_queue_error(self) -> QueueError {
match self {
Self::Authentication(msg) => QueueError::AuthenticationFailed { message: msg },
Self::NetworkError(msg) => QueueError::ConnectionFailed { message: msg },
Self::ServiceError(msg) => QueueError::ProviderError {
provider: "AwsSqs".to_string(),
code: "ServiceError".to_string(),
message: msg,
},
Self::QueueNotFound(queue) => QueueError::QueueNotFound { queue_name: queue },
Self::InvalidReceipt(receipt) => QueueError::InvalidReceipt { receipt },
Self::MessageTooLarge { size, max_size } => {
QueueError::MessageTooLarge { size, max_size }
}
Self::ConfigurationError(msg) => {
QueueError::ConfigurationError(ConfigurationError::Invalid { message: msg })
}
Self::SerializationError(msg) => QueueError::SerializationError(
SerializationError::JsonError(serde_json::Error::io(std::io::Error::new(
std::io::ErrorKind::InvalidData,
msg,
))),
),
Self::SessionsNotSupported => QueueError::ProviderError {
provider: "AwsSqs".to_string(),
code: "SessionsNotSupported".to_string(),
message:
"Standard queues do not support session-based operations. Use FIFO queues."
.to_string(),
},
}
}
}
/// HMAC-SHA256, used for every step of the SigV4 signing-key derivation.
type HmacSha256 = Hmac<Sha256>;

/// Produces AWS Signature Version 4 headers for one service in one region.
///
/// `AwsV4Signer::new` always sets the service to `"sqs"`.
#[derive(Clone)]
struct AwsV4Signer {
    /// Service name placed in the credential scope.
    service: String,
    /// Region placed in the credential scope and the key derivation.
    region: String,
    /// Access key id reported in the `Authorization` header.
    access_key: String,
    /// Secret key at the root of the signing-key derivation.
    secret_key: String,
}
impl AwsV4Signer {
    /// Creates a signer for the `sqs` service in `region`.
    fn new(access_key: String, secret_key: String, region: String) -> Self {
        Self {
            access_key,
            secret_key,
            region,
            service: "sqs".to_string(),
        }
    }

    /// Computes AWS Signature Version 4 headers for one request.
    ///
    /// Returns a map with `Authorization`, `x-amz-date` and `host` entries. Only
    /// `host` and `x-amz-date` are listed as signed headers, so any other header
    /// attached to the request (e.g. a session token) is not covered by the
    /// signature.
    ///
    /// * `method` – HTTP method, e.g. `"POST"`.
    /// * `host` – endpoint host without scheme, as sent in the `Host` header.
    /// * `path` – request path, used verbatim as the canonical URI.
    /// * `query_params` – raw (unencoded) query parameters; they are
    ///   percent-encoded and sorted here as SigV4 requires.
    /// * `body` – request payload; its SHA-256 is part of the canonical request.
    /// * `timestamp` – signing instant, also returned as `x-amz-date`.
    fn sign_request(
        &self,
        method: &str,
        host: &str,
        path: &str,
        query_params: &HashMap<String, String>,
        body: &str,
        timestamp: &DateTime<Utc>,
    ) -> HashMap<String, String> {
        let date_stamp = timestamp.format("%Y%m%d").to_string();
        let amz_date = timestamp.format("%Y%m%dT%H%M%SZ").to_string();
        let canonical_uri = path;
        let canonical_query_string = Self::canonical_query_string(query_params);
        let canonical_headers = format!("host:{}\nx-amz-date:{}\n", host, amz_date);
        let signed_headers = "host;x-amz-date";
        let payload_hash = hex::encode(Sha256::digest(body.as_bytes()));
        let canonical_request = format!(
            "{}\n{}\n{}\n{}\n{}\n{}",
            method,
            canonical_uri,
            canonical_query_string,
            canonical_headers,
            signed_headers,
            payload_hash
        );
        let algorithm = "AWS4-HMAC-SHA256";
        let credential_scope = format!(
            "{}/{}/{}/aws4_request",
            date_stamp, self.region, self.service
        );
        let canonical_request_hash = hex::encode(Sha256::digest(canonical_request.as_bytes()));
        let string_to_sign = format!(
            "{}\n{}\n{}\n{}",
            algorithm, amz_date, credential_scope, canonical_request_hash
        );
        let signature = self.calculate_signature(&string_to_sign, &date_stamp);
        let authorization_header = format!(
            "{} Credential={}/{}, SignedHeaders={}, Signature={}",
            algorithm, self.access_key, credential_scope, signed_headers, signature
        );
        let mut headers = HashMap::new();
        headers.insert("Authorization".to_string(), authorization_header);
        headers.insert("x-amz-date".to_string(), amz_date);
        headers.insert("host".to_string(), host.to_string());
        headers
    }

    /// Builds the SigV4 canonical query string from raw query parameters.
    ///
    /// Names and values are percent-encoded first, then sorted by encoded name
    /// and, for equal names, by encoded value. Sorting the joined `name=value`
    /// strings instead would misorder a name that is a prefix of another when
    /// the next character sorts below `=` (e.g. `A` must precede `A.1`).
    fn canonical_query_string(query_params: &HashMap<String, String>) -> String {
        let mut pairs: Vec<(String, String)> = query_params
            .iter()
            .map(|(k, v)| {
                (
                    urlencoding::encode(k).into_owned(),
                    urlencoding::encode(v).into_owned(),
                )
            })
            .collect();
        pairs.sort_unstable();
        pairs
            .iter()
            .map(|(k, v)| format!("{}={}", k, v))
            .collect::<Vec<_>>()
            .join("&")
    }

    /// Derives the SigV4 signing key (secret → date → region → service →
    /// `aws4_request`) and returns the hex-encoded signature of `string_to_sign`.
    fn calculate_signature(&self, string_to_sign: &str, date_stamp: &str) -> String {
        let k_secret = format!("AWS4{}", self.secret_key);
        let k_date = self.hmac_sha256(k_secret.as_bytes(), date_stamp.as_bytes());
        let k_region = self.hmac_sha256(&k_date, self.region.as_bytes());
        let k_service = self.hmac_sha256(&k_region, self.service.as_bytes());
        let k_signing = self.hmac_sha256(&k_service, b"aws4_request");
        let signature = self.hmac_sha256(&k_signing, string_to_sign.as_bytes());
        hex::encode(signature)
    }

    /// Computes HMAC-SHA256 of `data` under `key`.
    fn hmac_sha256(&self, key: &[u8], data: &[u8]) -> Vec<u8> {
        let mut mac = HmacSha256::new_from_slice(key).expect("HMAC can take key of any size");
        mac.update(data);
        mac.finalize().into_bytes().to_vec()
    }
}
#[derive(Debug, Clone)]
struct AwsCredentials {
access_key_id: String,
secret_access_key: String,
session_token: Option<String>,
expiration: DateTime<Utc>,
}
impl AwsCredentials {
fn is_expired(&self) -> bool {
let now = Utc::now();
let buffer = Duration::minutes(5);
self.expiration - buffer <= now
}
}
/// Resolves AWS credentials and caches them until they near expiry.
///
/// Explicit keys take precedence. Otherwise the chain is tried in order:
/// environment variables, then ECS container metadata, then EC2 instance metadata.
struct AwsCredentialProvider {
    /// Explicit `(access_key_id, secret_access_key)`; when set, the chain is skipped.
    explicit_config: Option<(String, String)>,
    /// Most recently fetched credentials.
    cached_credentials: Arc<RwLock<Option<AwsCredentials>>>,
    /// Client used for metadata-service requests.
    http_client: HttpClient,
}
impl AwsCredentialProvider {
fn new(
http_client: HttpClient,
access_key_id: Option<String>,
secret_access_key: Option<String>,
) -> Self {
let explicit_config = match (access_key_id, secret_access_key) {
(Some(key_id), Some(secret)) => Some((key_id, secret)),
_ => None,
};
Self {
http_client,
cached_credentials: Arc::new(RwLock::new(None)),
explicit_config,
}
}
async fn get_credentials(&self) -> Result<AwsCredentials, AwsError> {
{
let cache = self.cached_credentials.read().await;
if let Some(creds) = cache.as_ref() {
if !creds.is_expired() {
return Ok(creds.clone());
}
}
}
let creds = self.fetch_credentials().await?;
{
let mut cache = self.cached_credentials.write().await;
*cache = Some(creds.clone());
}
Ok(creds)
}
async fn fetch_credentials(&self) -> Result<AwsCredentials, AwsError> {
if let Some((key_id, secret)) = &self.explicit_config {
return Ok(AwsCredentials {
access_key_id: key_id.clone(),
secret_access_key: secret.clone(),
session_token: None,
expiration: Utc::now() + Duration::days(365), });
}
if let Ok(creds) = self.fetch_from_environment() {
return Ok(creds);
}
if let Ok(creds) = self.fetch_from_ecs_metadata().await {
return Ok(creds);
}
if let Ok(creds) = self.fetch_from_ec2_metadata().await {
return Ok(creds);
}
Err(AwsError::Authentication(
"No credentials found in credential chain".to_string(),
))
}
fn fetch_from_environment(&self) -> Result<AwsCredentials, AwsError> {
let access_key_id = std::env::var("AWS_ACCESS_KEY_ID")
.map_err(|_| AwsError::Authentication("AWS_ACCESS_KEY_ID not set".to_string()))?;
let secret_access_key = std::env::var("AWS_SECRET_ACCESS_KEY")
.map_err(|_| AwsError::Authentication("AWS_SECRET_ACCESS_KEY not set".to_string()))?;
let session_token = std::env::var("AWS_SESSION_TOKEN").ok();
Ok(AwsCredentials {
access_key_id,
secret_access_key,
session_token,
expiration: Utc::now() + Duration::days(365), })
}
async fn fetch_from_ecs_metadata(&self) -> Result<AwsCredentials, AwsError> {
let relative_uri =
std::env::var("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI").map_err(|_| {
AwsError::Authentication(
"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI not set".to_string(),
)
})?;
let endpoint = format!("http://169.254.170.2{}", relative_uri);
let response = self
.http_client
.get(&endpoint)
.timeout(std::time::Duration::from_secs(2))
.send()
.await
.map_err(|e| {
AwsError::Authentication(format!("Failed to fetch ECS credentials: {}", e))
})?;
if !response.status().is_success() {
return Err(AwsError::Authentication(format!(
"ECS metadata returned error: {}",
response.status()
)));
}
let body = response
.text()
.await
.map_err(|e| AwsError::Authentication(format!("Failed to read ECS metadata: {}", e)))?;
self.parse_credentials_json(&body)
}
async fn fetch_from_ec2_metadata(&self) -> Result<AwsCredentials, AwsError> {
let token = self
.http_client
.put("http://169.254.169.254/latest/api/token")
.header("X-aws-ec2-metadata-token-ttl-seconds", "21600")
.timeout(std::time::Duration::from_secs(2))
.send()
.await
.map_err(|e| AwsError::Authentication(format!("Failed to get IMDSv2 token: {}", e)))?
.text()
.await
.map_err(|e| AwsError::Authentication(format!("Failed to read IMDSv2 token: {}", e)))?;
let role_name = self
.http_client
.get("http://169.254.169.254/latest/meta-data/iam/security-credentials/")
.header("X-aws-ec2-metadata-token", &token)
.timeout(std::time::Duration::from_secs(2))
.send()
.await
.map_err(|e| AwsError::Authentication(format!("Failed to fetch IAM role name: {}", e)))?
.text()
.await
.map_err(|e| {
AwsError::Authentication(format!("Failed to read IAM role name: {}", e))
})?;
let credentials_url = format!(
"http://169.254.169.254/latest/meta-data/iam/security-credentials/{}",
role_name.trim()
);
let response = self
.http_client
.get(&credentials_url)
.header("X-aws-ec2-metadata-token", &token)
.timeout(std::time::Duration::from_secs(2))
.send()
.await
.map_err(|e| {
AwsError::Authentication(format!("Failed to fetch EC2 credentials: {}", e))
})?;
if !response.status().is_success() {
return Err(AwsError::Authentication(format!(
"EC2 metadata returned error: {}",
response.status()
)));
}
let body = response
.text()
.await
.map_err(|e| AwsError::Authentication(format!("Failed to read EC2 metadata: {}", e)))?;
self.parse_credentials_json(&body)
}
fn parse_credentials_json(&self, json: &str) -> Result<AwsCredentials, AwsError> {
let access_key_id = Self::extract_json_field(json, "AccessKeyId")?;
let secret_access_key = Self::extract_json_field(json, "SecretAccessKey")?;
let session_token = Self::extract_json_field(json, "Token").ok();
let expiration_str = Self::extract_json_field(json, "Expiration")?;
let expiration = DateTime::parse_from_rfc3339(&expiration_str)
.map_err(|e| AwsError::Authentication(format!("Invalid expiration timestamp: {}", e)))?
.with_timezone(&Utc);
Ok(AwsCredentials {
access_key_id,
secret_access_key,
session_token,
expiration,
})
}
fn extract_json_field(json: &str, field: &str) -> Result<String, AwsError> {
let pattern = format!("\"{}\": \"", field);
let start = json.find(&pattern).ok_or_else(|| {
AwsError::Authentication(format!("Field '{}' not found in JSON", field))
})?;
let value_start = start + pattern.len();
let value_end = json[value_start..].find('"').ok_or_else(|| {
AwsError::Authentication(format!("Malformed JSON for field '{}'", field))
})? + value_start;
Ok(json[value_start..value_end].to_string())
}
}
/// [`QueueProvider`] backed by the Amazon SQS query API.
pub struct AwsSqsProvider {
    /// Regional endpoint, `https://sqs.<region>.amazonaws.com`.
    endpoint: String,
    /// Configuration the provider was built from.
    config: AwsSqsConfig,
    /// Client shared by every SQS request.
    http_client: HttpClient,
    /// Source of (cached) signing credentials.
    credential_provider: AwsCredentialProvider,
    /// Queue name → queue URL, filled lazily on first use of each queue.
    queue_url_cache: Arc<RwLock<HashMap<QueueName, String>>>,
}
impl AwsSqsProvider {
pub async fn new(config: AwsSqsConfig) -> Result<Self, AwsError> {
if config.region.is_empty() {
return Err(AwsError::ConfigurationError(
"Region cannot be empty".to_string(),
));
}
let endpoint = format!("https://sqs.{}.amazonaws.com", config.region);
let http_client = HttpClient::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.map_err(|e| AwsError::NetworkError(format!("Failed to create HTTP client: {}", e)))?;
let credential_provider = AwsCredentialProvider::new(
http_client.clone(),
config.access_key_id.clone(),
config.secret_access_key.clone(),
);
Ok(Self {
http_client,
credential_provider,
config,
endpoint,
queue_url_cache: Arc::new(RwLock::new(HashMap::new())),
})
}
async fn get_queue_url(&self, queue_name: &QueueName) -> Result<String, AwsError> {
{
let cache = self.queue_url_cache.read().await;
if let Some(url) = cache.get(queue_name) {
return Ok(url.clone());
}
}
let mut params = HashMap::new();
params.insert("Action".to_string(), "GetQueueUrl".to_string());
params.insert("QueueName".to_string(), queue_name.as_str().to_string());
params.insert("Version".to_string(), "2012-11-05".to_string());
let response = self.make_request("POST", "/", ¶ms, "").await?;
let queue_url = self.parse_queue_url_response(&response)?;
let mut cache = self.queue_url_cache.write().await;
cache.insert(queue_name.clone(), queue_url.clone());
Ok(queue_url)
}
async fn make_request(
&self,
method: &str,
path: &str,
query_params: &HashMap<String, String>,
body: &str,
) -> Result<String, AwsError> {
let credentials = self.credential_provider.get_credentials().await?;
let signer = AwsV4Signer::new(
credentials.access_key_id.clone(),
credentials.secret_access_key.clone(),
self.config.region.clone(),
);
let host = self
.endpoint
.strip_prefix("https://")
.unwrap_or(&self.endpoint);
let timestamp = Utc::now();
let mut auth_headers =
signer.sign_request(method, host, path, query_params, body, ×tamp);
if let Some(session_token) = &credentials.session_token {
auth_headers.insert("X-Amz-Security-Token".to_string(), session_token.clone());
}
let mut url = format!("{}{}", self.endpoint, path);
if !query_params.is_empty() {
let query_string = query_params
.iter()
.map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
.collect::<Vec<_>>()
.join("&");
url = format!("{}?{}", url, query_string);
}
let mut request = self.http_client.request(
method
.parse()
.map_err(|e| AwsError::ConfigurationError(format!("Invalid HTTP method: {}", e)))?,
&url,
);
for (key, value) in auth_headers {
request = request.header(&key, value);
}
if !body.is_empty() {
request = request.body(body.to_string());
}
let response = request.send().await.map_err(|e| {
if e.is_timeout() {
AwsError::NetworkError(format!("Request timeout: {}", e))
} else if e.is_connect() {
AwsError::NetworkError(format!("Connection failed: {}", e))
} else {
AwsError::NetworkError(format!("HTTP request failed: {}", e))
}
})?;
let status = response.status();
let response_body = response
.text()
.await
.map_err(|e| AwsError::NetworkError(format!("Failed to read response body: {}", e)))?;
if !status.is_success() {
return Err(self.parse_error_response(&response_body, status.as_u16()));
}
Ok(response_body)
}
fn parse_queue_url_response(&self, xml: &str) -> Result<String, AwsError> {
use quick_xml::events::Event;
use quick_xml::Reader;
let mut reader = Reader::from_str(xml);
reader.config_mut().trim_text(true);
let mut in_queue_url = false;
let mut buf = Vec::new();
loop {
match reader.read_event_into(&mut buf) {
Ok(Event::Start(ref e)) if e.name().as_ref() == b"QueueUrl" => {
in_queue_url = true;
}
Ok(Event::Text(e)) if in_queue_url => {
return e
.decode()
.map_err(|e| {
AwsError::SerializationError(format!("Failed to parse XML: {}", e))
})
.and_then(|s| {
quick_xml::escape::unescape(&s)
.map(|u| u.into_owned())
.map_err(|e| {
AwsError::SerializationError(format!(
"Failed to unescape XML: {}",
e
))
})
});
}
Ok(Event::Eof) => break,
Err(e) => {
return Err(AwsError::SerializationError(format!(
"XML parsing error: {}",
e
)))
}
_ => {}
}
buf.clear();
}
Err(AwsError::SerializationError(
"QueueUrl not found in response".to_string(),
))
}
fn parse_error_response(&self, xml: &str, status_code: u16) -> AwsError {
use quick_xml::events::Event;
use quick_xml::Reader;
let mut reader = Reader::from_str(xml);
reader.config_mut().trim_text(true);
let mut error_code = None;
let mut error_message = None;
let mut in_error = false;
let mut in_code = false;
let mut in_message = false;
let mut buf = Vec::new();
loop {
match reader.read_event_into(&mut buf) {
Ok(Event::Start(ref e)) => match e.name().as_ref() {
b"Error" => in_error = true,
b"Code" if in_error => in_code = true,
b"Message" if in_error => in_message = true,
_ => {}
},
Ok(Event::Text(e)) => {
if in_code {
error_code = e.decode().ok().and_then(|s| {
quick_xml::escape::unescape(&s).ok().map(|u| u.into_owned())
});
in_code = false;
} else if in_message {
error_message = e.decode().ok().and_then(|s| {
quick_xml::escape::unescape(&s).ok().map(|u| u.into_owned())
});
in_message = false;
}
}
Ok(Event::End(ref e)) if e.name().as_ref() == b"Error" => {
in_error = false;
}
Ok(Event::Eof) => break,
Err(_) => break,
_ => {}
}
buf.clear();
}
let code = error_code.unwrap_or_else(|| "Unknown".to_string());
let message = error_message.unwrap_or_else(|| "Unknown error".to_string());
match code.as_str() {
"AWS.SimpleQueueService.NonExistentQueue" | "QueueDoesNotExist" => {
AwsError::QueueNotFound(message)
}
"InvalidClientTokenId" | "UnrecognizedClientException" | "SignatureDoesNotMatch" => {
AwsError::Authentication(format!("{}: {}", code, message))
}
"InvalidReceiptHandle" | "ReceiptHandleIsInvalid" => AwsError::InvalidReceipt(message),
_ if status_code == 401 || status_code == 403 => {
AwsError::Authentication(format!("{}: {}", code, message))
}
_ if status_code >= 500 => AwsError::ServiceError(format!("{}: {}", code, message)),
_ => AwsError::ServiceError(format!("{}: {}", code, message)),
}
}
fn parse_send_message_response(&self, xml: &str) -> Result<MessageId, AwsError> {
use quick_xml::events::Event;
use quick_xml::Reader;
let mut reader = Reader::from_str(xml);
reader.config_mut().trim_text(true);
let mut in_message_id = false;
let mut buf = Vec::new();
loop {
match reader.read_event_into(&mut buf) {
Ok(Event::Start(ref e)) if e.name().as_ref() == b"MessageId" => {
in_message_id = true;
}
Ok(Event::Text(e)) if in_message_id => {
let msg_id = e.decode().map(|s| s.into_owned()).map_err(|e| {
AwsError::SerializationError(format!("Failed to parse XML: {}", e))
})?;
use std::str::FromStr;
let message_id =
MessageId::from_str(&msg_id).unwrap_or_else(|_| MessageId::new());
return Ok(message_id);
}
Ok(Event::Eof) => break,
Err(e) => {
return Err(AwsError::SerializationError(format!(
"XML parsing error: {}",
e
)))
}
_ => {}
}
buf.clear();
}
Err(AwsError::SerializationError(
"MessageId not found in response".to_string(),
))
}
fn parse_receive_message_response(
&self,
xml: &str,
queue: &QueueName,
) -> Result<Vec<ReceivedMessage>, AwsError> {
use quick_xml::events::Event;
use quick_xml::Reader;
let mut reader = Reader::from_str(xml);
reader.config_mut().trim_text(true);
let mut messages = Vec::new();
let mut in_message = false;
let mut current_message_id: Option<String> = None;
let mut current_receipt_handle: Option<String> = None;
let mut current_body: Option<String> = None;
let mut current_session_id: Option<String> = None;
let mut current_delivery_count: u32 = 1;
let mut in_message_id = false;
let mut in_receipt_handle = false;
let mut in_body = false;
let mut in_attribute_name = false;
let mut in_attribute_value = false;
let mut current_attribute_name: Option<String> = None;
let mut buf = Vec::new();
loop {
match reader.read_event_into(&mut buf) {
Ok(Event::Start(ref e)) => match e.name().as_ref() {
b"Message" => {
in_message = true;
current_message_id = None;
current_receipt_handle = None;
current_body = None;
current_session_id = None;
current_delivery_count = 1;
}
b"MessageId" if in_message => in_message_id = true,
b"ReceiptHandle" if in_message => in_receipt_handle = true,
b"Body" if in_message => in_body = true,
b"Name" if in_message => in_attribute_name = true,
b"Value" if in_message => in_attribute_value = true,
_ => {}
},
Ok(Event::Text(e)) => {
let text = e.decode().ok().map(|s| s.into_owned());
if in_message_id {
current_message_id = text;
in_message_id = false;
} else if in_receipt_handle {
current_receipt_handle = text;
in_receipt_handle = false;
} else if in_body {
current_body = text;
in_body = false;
} else if in_attribute_name {
current_attribute_name = text;
in_attribute_name = false;
} else if in_attribute_value {
if let Some(ref attr_name) = current_attribute_name {
match attr_name.as_str() {
"MessageGroupId" => current_session_id = text,
"ApproximateReceiveCount" => {
if let Some(count_str) = text {
current_delivery_count =
count_str.parse::<u32>().unwrap_or(1);
}
}
_ => {}
}
}
in_attribute_value = false;
current_attribute_name = None;
}
}
Ok(Event::End(ref e)) if e.name().as_ref() == b"Message" => {
in_message = false;
if let (Some(body_base64), Some(receipt_handle)) =
(current_body.as_ref(), current_receipt_handle.as_ref())
{
use base64::{engine::general_purpose::STANDARD, Engine};
let body_bytes = STANDARD.decode(body_base64).map_err(|e| {
AwsError::SerializationError(format!("Base64 decode failed: {}", e))
})?;
let body = bytes::Bytes::from(body_bytes);
use std::str::FromStr;
let message_id = current_message_id
.as_ref()
.and_then(|id| MessageId::from_str(id).ok())
.unwrap_or_default();
let session_id = current_session_id
.as_ref()
.and_then(|id| SessionId::new(id.clone()).ok());
let handle_with_queue = format!("{}|{}", queue.as_str(), receipt_handle);
let expires_at = Timestamp::now();
let receipt =
ReceiptHandle::new(handle_with_queue, expires_at, ProviderType::AwsSqs);
let received_message = ReceivedMessage {
message_id,
body,
attributes: HashMap::new(),
session_id,
correlation_id: None,
receipt_handle: receipt,
delivery_count: current_delivery_count,
first_delivered_at: Timestamp::now(),
delivered_at: Timestamp::now(),
};
messages.push(received_message);
}
}
Ok(Event::Eof) => break,
Err(e) => {
return Err(AwsError::SerializationError(format!(
"XML parsing error: {}",
e
)))
}
_ => {}
}
buf.clear();
}
Ok(messages)
}
fn is_fifo_queue(queue_name: &QueueName) -> bool {
queue_name.as_str().ends_with(".fifo")
}
}
impl fmt::Debug for AwsSqsProvider {
    /// Formats the provider with its configuration. The queue-URL cache contents
    /// are replaced by a placeholder.
    ///
    /// NOTE(review): the output includes whatever `AwsSqsConfig`'s `Debug` prints,
    /// which may include configured credentials; confirm that impl redacts them.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("AwsSqsProvider");
        builder.field("config", &self.config);
        builder.field("queue_url_cache_size", &"<redacted>");
        builder.finish()
    }
}
#[async_trait]
impl QueueProvider for AwsSqsProvider {
    /// Sends one message and returns the id SQS assigned to it.
    ///
    /// The body is base64-encoded and must not exceed 256 KiB once encoded.
    /// For FIFO (`.fifo`) queues the message's session id is required and becomes
    /// the `MessageGroupId`. Both checks run before any network call.
    async fn send_message(
        &self,
        queue: &QueueName,
        message: &Message,
    ) -> Result<MessageId, QueueError> {
        use base64::{engine::general_purpose::STANDARD, Engine};
        const MAX_ENCODED_BODY: usize = 256 * 1024;

        let body_base64 = STANDARD.encode(&message.body);
        if body_base64.len() > MAX_ENCODED_BODY {
            return Err(AwsError::MessageTooLarge {
                size: body_base64.len(),
                max_size: MAX_ENCODED_BODY,
            }
            .to_queue_error());
        }
        let is_fifo = Self::is_fifo_queue(queue);
        if is_fifo && message.session_id.is_none() {
            return Err(QueueError::ValidationError(
                crate::error::ValidationError::Required {
                    field: "session_id".to_string(),
                },
            ));
        }

        let queue_url = self
            .get_queue_url(queue)
            .await
            .map_err(|e| e.to_queue_error())?;
        let mut params = HashMap::new();
        params.insert("Action".to_string(), "SendMessage".to_string());
        params.insert("Version".to_string(), "2012-11-05".to_string());
        params.insert("QueueUrl".to_string(), queue_url);
        params.insert("MessageBody".to_string(), body_base64);
        if is_fifo {
            if let Some(ref session_id) = message.session_id {
                params.insert(
                    "MessageGroupId".to_string(),
                    session_id.as_str().to_string(),
                );
                // A fresh random id per call, so two sends of the same message
                // are never deduplicated on this path.
                params.insert(
                    "MessageDeduplicationId".to_string(),
                    uuid::Uuid::new_v4().to_string(),
                );
            }
        }
        let response = self
            .make_request("POST", "/", &params, "")
            .await
            .map_err(|e| e.to_queue_error())?;
        self.parse_send_message_response(&response)
            .map_err(|e| e.to_queue_error())
    }

    /// Sends `messages` in chunks of at most `max_batch_size()` entries.
    ///
    /// Returns the ids of all sent messages. The first failing chunk aborts the
    /// call; earlier chunks have already been sent.
    async fn send_messages(
        &self,
        queue: &QueueName,
        messages: &[Message],
    ) -> Result<Vec<MessageId>, QueueError> {
        if messages.is_empty() {
            return Ok(Vec::new());
        }
        let max_batch = self.max_batch_size() as usize;
        let mut all_message_ids = Vec::with_capacity(messages.len());
        for chunk in messages.chunks(max_batch) {
            let message_ids = self.send_messages_batch(queue, chunk).await?;
            all_message_ids.extend(message_ids);
        }
        Ok(all_message_ids)
    }

    /// Receives at most one message, long-polling for up to `timeout` (capped at 20 s).
    async fn receive_message(
        &self,
        queue: &QueueName,
        timeout: Duration,
    ) -> Result<Option<ReceivedMessage>, QueueError> {
        let messages = self.receive_messages(queue, 1, timeout).await?;
        Ok(messages.into_iter().next())
    }

    /// Receives up to `max_messages` (capped at 10) messages.
    ///
    /// Long-polls for `timeout`, clamped to 0..=20 seconds.
    async fn receive_messages(
        &self,
        queue: &QueueName,
        max_messages: u32,
        timeout: Duration,
    ) -> Result<Vec<ReceivedMessage>, QueueError> {
        // SQS only accepts MaxNumberOfMessages in 1..=10, so a request for zero
        // messages is answered locally instead of becoming a service error.
        if max_messages == 0 {
            return Ok(Vec::new());
        }
        let queue_url = self
            .get_queue_url(queue)
            .await
            .map_err(|e| e.to_queue_error())?;
        let wait_time_seconds = timeout.num_seconds().clamp(0, 20);
        let mut params = HashMap::new();
        params.insert("Action".to_string(), "ReceiveMessage".to_string());
        params.insert("Version".to_string(), "2012-11-05".to_string());
        params.insert("QueueUrl".to_string(), queue_url);
        params.insert(
            "MaxNumberOfMessages".to_string(),
            max_messages.min(10).to_string(),
        );
        params.insert("WaitTimeSeconds".to_string(), wait_time_seconds.to_string());
        // Request all system attributes (e.g. MessageGroupId, ApproximateReceiveCount).
        params.insert("AttributeName.1".to_string(), "All".to_string());
        let response = self
            .make_request("POST", "/", &params, "")
            .await
            .map_err(|e| e.to_queue_error())?;
        self.parse_receive_message_response(&response, queue)
            .map_err(|e| e.to_queue_error())
    }

    /// Deletes the message identified by `receipt` (`DeleteMessage`).
    ///
    /// `receipt` must be a `"<queue>|<sqs receipt handle>"` handle as produced by
    /// `receive_messages`.
    async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
        let handle_str = receipt.handle();
        let parts: Vec<&str> = handle_str.split('|').collect();
        if parts.len() != 2 {
            return Err(QueueError::InvalidReceipt {
                receipt: handle_str.to_string(),
            });
        }
        let queue_name =
            QueueName::new(parts[0].to_string()).map_err(QueueError::ValidationError)?;
        let receipt_token = parts[1];
        let queue_url = self
            .get_queue_url(&queue_name)
            .await
            .map_err(|e| e.to_queue_error())?;
        let mut params = HashMap::new();
        params.insert("Action".to_string(), "DeleteMessage".to_string());
        params.insert("Version".to_string(), "2012-11-05".to_string());
        params.insert("QueueUrl".to_string(), queue_url);
        params.insert("ReceiptHandle".to_string(), receipt_token.to_string());
        self.make_request("POST", "/", &params, "")
            .await
            .map_err(|e| e.to_queue_error())?;
        Ok(())
    }

    /// Makes the message visible again immediately, by setting its visibility
    /// timeout to 0 (`ChangeMessageVisibility`).
    async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
        let handle_str = receipt.handle();
        let parts: Vec<&str> = handle_str.split('|').collect();
        if parts.len() != 2 {
            return Err(QueueError::InvalidReceipt {
                receipt: handle_str.to_string(),
            });
        }
        let queue_name =
            QueueName::new(parts[0].to_string()).map_err(QueueError::ValidationError)?;
        let receipt_token = parts[1];
        let queue_url = self
            .get_queue_url(&queue_name)
            .await
            .map_err(|e| e.to_queue_error())?;
        let mut params = HashMap::new();
        params.insert("Action".to_string(), "ChangeMessageVisibility".to_string());
        params.insert("Version".to_string(), "2012-11-05".to_string());
        params.insert("QueueUrl".to_string(), queue_url);
        params.insert("ReceiptHandle".to_string(), receipt_token.to_string());
        params.insert("VisibilityTimeout".to_string(), "0".to_string());
        self.make_request("POST", "/", &params, "")
            .await
            .map_err(|e| e.to_queue_error())?;
        Ok(())
    }

    /// Deletes the message exactly like `complete_message`; `_reason` is not recorded.
    ///
    /// NOTE(review): nothing is written to a dead-letter queue here; presumably an
    /// SQS redrive policy is expected to handle dead-lettering — confirm.
    async fn dead_letter_message(
        &self,
        receipt: &ReceiptHandle,
        _reason: &str,
    ) -> Result<(), QueueError> {
        self.complete_message(receipt).await
    }

    /// Creates a session client bound to one FIFO queue and one session id.
    ///
    /// The client gets its own credential provider (with its own cache) built
    /// from the same explicit configuration.
    ///
    /// # Errors
    ///
    /// * A `SessionsNotSupported` provider error for non-FIFO queues.
    /// * A validation error when `session_id` is `None`. This is checked before
    ///   any network call.
    /// * Errors from resolving the queue URL.
    async fn create_session_client(
        &self,
        queue: &QueueName,
        session_id: Option<SessionId>,
    ) -> Result<Box<dyn SessionProvider>, QueueError> {
        if !Self::is_fifo_queue(queue) {
            return Err(AwsError::SessionsNotSupported.to_queue_error());
        }
        let session_id = session_id.ok_or_else(|| {
            QueueError::ValidationError(crate::error::ValidationError::Required {
                field: "session_id".to_string(),
            })
        })?;
        let queue_url = self
            .get_queue_url(queue)
            .await
            .map_err(|e| e.to_queue_error())?;
        Ok(Box::new(AwsSessionProvider::new(
            self.http_client.clone(),
            AwsCredentialProvider::new(
                self.http_client.clone(),
                self.config.access_key_id.clone(),
                self.config.secret_access_key.clone(),
            ),
            self.config.region.clone(),
            self.endpoint.clone(),
            queue_url,
            queue.clone(),
            session_id,
        )))
    }

    fn provider_type(&self) -> ProviderType {
        ProviderType::AwsSqs
    }

    /// Sessions are emulated on top of FIFO message groups.
    fn supports_sessions(&self) -> SessionSupport {
        SessionSupport::Emulated
    }

    fn supports_batching(&self) -> bool {
        true
    }

    /// SQS accepts at most 10 entries per batch request.
    fn max_batch_size(&self) -> u32 {
        10
    }
}
impl AwsSqsProvider {
async fn send_messages_batch(
&self,
queue: &QueueName,
messages: &[Message],
) -> Result<Vec<MessageId>, QueueError> {
if messages.is_empty() {
return Ok(Vec::new());
}
if messages.len() > 10 {
return Err(QueueError::ValidationError(
crate::error::ValidationError::OutOfRange {
field: "messages".to_string(),
message: format!("Batch size {} exceeds AWS SQS limit of 10", messages.len()),
},
));
}
let queue_url = self
.get_queue_url(queue)
.await
.map_err(|e| e.to_queue_error())?;
let mut params = HashMap::new();
params.insert("Action".to_string(), "SendMessageBatch".to_string());
params.insert("Version".to_string(), "2012-11-05".to_string());
params.insert("QueueUrl".to_string(), queue_url.clone());
use base64::{engine::general_purpose::STANDARD, Engine};
for (idx, message) in messages.iter().enumerate() {
let entry_id = format!("msg-{}", idx);
let body_base64 = STANDARD.encode(&message.body);
if body_base64.len() > 256 * 1024 {
return Err(AwsError::MessageTooLarge {
size: body_base64.len(),
max_size: 256 * 1024,
}
.to_queue_error());
}
params.insert(
format!("SendMessageBatchRequestEntry.{}.Id", idx + 1),
entry_id,
);
params.insert(
format!("SendMessageBatchRequestEntry.{}.MessageBody", idx + 1),
body_base64,
);
if Self::is_fifo_queue(queue) {
if let Some(ref session_id) = message.session_id {
params.insert(
format!("SendMessageBatchRequestEntry.{}.MessageGroupId", idx + 1),
session_id.as_str().to_string(),
);
}
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(&message.body);
if let Some(ref session_id) = message.session_id {
hasher.update(session_id.as_str().as_bytes());
}
let hash = hex::encode(hasher.finalize());
params.insert(
format!(
"SendMessageBatchRequestEntry.{}.MessageDeduplicationId",
idx + 1
),
hash,
);
}
}
let response = self
.make_request("POST", "/", ¶ms, "")
.await
.map_err(|e| e.to_queue_error())?;
self.parse_send_message_batch_response(&response)
.map_err(|e| e.to_queue_error())
}
fn parse_send_message_batch_response(&self, xml: &str) -> Result<Vec<MessageId>, AwsError> {
use quick_xml::events::Event;
use quick_xml::Reader;
let mut reader = Reader::from_str(xml);
reader.config_mut().trim_text(true);
let mut message_ids = Vec::new();
let mut in_successful = false;
let mut in_message_id = false;
let mut buf = Vec::new();
loop {
match reader.read_event_into(&mut buf) {
Ok(Event::Start(ref e)) => match e.name().as_ref() {
b"SendMessageBatchResultEntry" => in_successful = true,
b"MessageId" if in_successful => in_message_id = true,
_ => {}
},
Ok(Event::Text(e)) if in_message_id => {
let msg_id = e.decode().map(|s| s.into_owned()).map_err(|e| {
AwsError::SerializationError(format!("Failed to parse XML: {}", e))
})?;
use std::str::FromStr;
let message_id =
MessageId::from_str(&msg_id).unwrap_or_else(|_| MessageId::new());
message_ids.push(message_id);
in_message_id = false;
}
Ok(Event::End(ref e)) if e.name().as_ref() == b"SendMessageBatchResultEntry" => {
in_successful = false;
}
Ok(Event::Eof) => break,
Err(e) => {
return Err(AwsError::SerializationError(format!(
"XML parsing error: {}",
e
)))
}
_ => {}
}
buf.clear();
}
Ok(message_ids)
}
}
/// SQS client scoped to one FIFO queue and one session id.
///
/// The session id is presumably used as the FIFO `MessageGroupId` — confirm
/// against the session methods.
pub struct AwsSessionProvider {
    /// Name of the queue this session belongs to.
    queue_name: QueueName,
    /// Queue URL resolved when the session was created.
    queue_url: String,
    /// Session this client is bound to.
    session_id: SessionId,
    /// Region used for request signing.
    region: String,
    /// SQS endpoint base URL.
    endpoint: String,
    /// HTTP client for SQS requests.
    http_client: HttpClient,
    /// Source of signing credentials.
    credential_provider: AwsCredentialProvider,
}
impl AwsSessionProvider {
/// Assembles a session provider from already-resolved parts; performs no I/O.
fn new(
    http_client: HttpClient,
    credential_provider: AwsCredentialProvider,
    region: String,
    endpoint: String,
    queue_url: String,
    queue_name: QueueName,
    session_id: SessionId,
) -> Self {
    Self {
        queue_name,
        queue_url,
        session_id,
        region,
        endpoint,
        credential_provider,
        http_client,
    }
}
/// Returns the queue URL captured when this session was created; always `Ok`.
async fn get_queue_url(&self) -> Result<String, AwsError> {
    let url = self.queue_url.to_owned();
    Ok(url)
}
async fn make_request(
&self,
method: &str,
path: &str,
params: &HashMap<String, String>,
body: &str,
) -> Result<String, AwsError> {
use reqwest::header;
let credentials = self.credential_provider.get_credentials().await?;
let signer = AwsV4Signer::new(
credentials.access_key_id.clone(),
credentials.secret_access_key.clone(),
self.region.clone(),
);
let query_string = if params.is_empty() {
String::new()
} else {
let mut pairs: Vec<String> = params
.iter()
.map(|(k, v)| format!("{}={}", urlencoding::encode(k), urlencoding::encode(v)))
.collect();
pairs.sort();
pairs.join("&")
};
let url = if query_string.is_empty() {
format!("{}{}", self.endpoint, path)
} else {
format!("{}{}?{}", self.endpoint, path, query_string)
};
let mut request_builder = self.http_client.request(
method
.parse()
.map_err(|e| AwsError::NetworkError(format!("Invalid HTTP method: {}", e)))?,
&url,
);
let timestamp = Utc::now();
let host = self
.endpoint
.trim_start_matches("https://")
.trim_start_matches("http://");
let mut signed_headers = signer.sign_request(method, host, path, params, body, ×tamp);
if let Some(session_token) = &credentials.session_token {
signed_headers.insert("X-Amz-Security-Token".to_string(), session_token.clone());
}
for (key, value) in signed_headers {
request_builder = request_builder.header(key, value);
}
if !body.is_empty() {
request_builder = request_builder
.header(header::CONTENT_TYPE, "application/x-www-form-urlencoded")
.body(body.to_string());
}
let response = request_builder
.send()
.await
.map_err(|e| AwsError::NetworkError(format!("HTTP request failed: {}", e)))?;
let status = response.status();
let response_text = response
.text()
.await
.map_err(|e| AwsError::NetworkError(format!("Failed to read response: {}", e)))?;
if !status.is_success() {
return Err(self.parse_error_response(&response_text, status.as_u16()));
}
Ok(response_text)
}
fn parse_error_response(&self, xml: &str, status_code: u16) -> AwsError {
use quick_xml::events::Event;
use quick_xml::Reader;
let mut reader = Reader::from_str(xml);
reader.config_mut().trim_text(true);
let mut error_code = None;
let mut error_message = None;
let mut in_error = false;
let mut in_code = false;
let mut in_message = false;
let mut buf = Vec::new();
loop {
match reader.read_event_into(&mut buf) {
Ok(Event::Start(ref e)) => match e.name().as_ref() {
b"Error" => in_error = true,
b"Code" if in_error => in_code = true,
b"Message" if in_error => in_message = true,
_ => {}
},
Ok(Event::Text(e)) => {
if in_code {
error_code = e.decode().ok().and_then(|s| {
quick_xml::escape::unescape(&s).ok().map(|u| u.into_owned())
});
in_code = false;
} else if in_message {
error_message = e.decode().ok().and_then(|s| {
quick_xml::escape::unescape(&s).ok().map(|u| u.into_owned())
});
in_message = false;
}
}
Ok(Event::Eof) => break,
Err(_) => break,
_ => {}
}
buf.clear();
}
match error_code.as_deref() {
Some("InvalidParameterValue") | Some("MissingParameter") => AwsError::ServiceError(
error_message.unwrap_or_else(|| "Invalid parameter".to_string()),
),
Some("AccessDenied") | Some("InvalidClientTokenId") | Some("SignatureDoesNotMatch") => {
AwsError::Authentication(
error_message.unwrap_or_else(|| "Authentication failed".to_string()),
)
}
Some("AWS.SimpleQueueService.NonExistentQueue") | Some("QueueDoesNotExist") => {
AwsError::QueueNotFound(
error_message.unwrap_or_else(|| "Queue not found".to_string()),
)
}
_ => {
if status_code >= 500 {
AwsError::ServiceError(
error_message.unwrap_or_else(|| "Service error".to_string()),
)
} else {
AwsError::ServiceError(
error_message.unwrap_or_else(|| format!("HTTP {}", status_code)),
)
}
}
}
}
fn parse_receive_message_response(
&self,
xml: &str,
queue: &QueueName,
) -> Result<Vec<ReceivedMessage>, AwsError> {
use quick_xml::events::Event;
use quick_xml::Reader;
let mut reader = Reader::from_str(xml);
reader.config_mut().trim_text(true);
let mut messages = Vec::new();
let mut in_message = false;
let mut current_message_id: Option<String> = None;
let mut current_receipt_handle: Option<String> = None;
let mut current_body: Option<String> = None;
let mut current_session_id: Option<String> = None;
let mut current_delivery_count: u32 = 1;
let mut in_message_id = false;
let mut in_receipt_handle = false;
let mut in_body = false;
let mut in_attribute_name = false;
let mut in_attribute_value = false;
let mut current_attribute_name: Option<String> = None;
let mut buf = Vec::new();
loop {
match reader.read_event_into(&mut buf) {
Ok(Event::Start(ref e)) => match e.name().as_ref() {
b"Message" => {
in_message = true;
current_message_id = None;
current_receipt_handle = None;
current_body = None;
current_session_id = None;
current_delivery_count = 1;
}
b"MessageId" if in_message => in_message_id = true,
b"ReceiptHandle" if in_message => in_receipt_handle = true,
b"Body" if in_message => in_body = true,
b"Name" if in_message => in_attribute_name = true,
b"Value" if in_message => in_attribute_value = true,
_ => {}
},
Ok(Event::Text(e)) => {
let text = e.decode().ok().map(|s| s.into_owned());
if in_message_id {
current_message_id = text;
in_message_id = false;
} else if in_receipt_handle {
current_receipt_handle = text;
in_receipt_handle = false;
} else if in_body {
current_body = text;
in_body = false;
} else if in_attribute_name {
current_attribute_name = text;
in_attribute_name = false;
} else if in_attribute_value {
if let Some(ref attr_name) = current_attribute_name {
match attr_name.as_str() {
"MessageGroupId" => current_session_id = text,
"ApproximateReceiveCount" => {
if let Some(count_str) = text {
current_delivery_count =
count_str.parse::<u32>().unwrap_or(1);
}
}
_ => {}
}
}
in_attribute_value = false;
current_attribute_name = None;
}
}
Ok(Event::End(ref e)) if e.name().as_ref() == b"Message" => {
in_message = false;
if let (Some(body_base64), Some(receipt_handle)) =
(current_body.as_ref(), current_receipt_handle.as_ref())
{
use base64::{engine::general_purpose::STANDARD, Engine};
let body_bytes = STANDARD.decode(body_base64).map_err(|e| {
AwsError::SerializationError(format!("Base64 decode failed: {}", e))
})?;
let body = bytes::Bytes::from(body_bytes);
use std::str::FromStr;
let message_id = current_message_id
.as_ref()
.and_then(|id| MessageId::from_str(id).ok())
.unwrap_or_default();
let session_id = current_session_id
.as_ref()
.and_then(|id| SessionId::new(id.clone()).ok());
let handle_with_queue = format!("{}|{}", queue.as_str(), receipt_handle);
let expires_at = Timestamp::now();
let receipt =
ReceiptHandle::new(handle_with_queue, expires_at, ProviderType::AwsSqs);
let received_message = ReceivedMessage {
message_id,
body,
attributes: HashMap::new(),
session_id,
correlation_id: None,
receipt_handle: receipt,
delivery_count: current_delivery_count,
first_delivered_at: Timestamp::now(),
delivered_at: Timestamp::now(),
};
messages.push(received_message);
}
}
Ok(Event::Eof) => break,
Err(e) => {
return Err(AwsError::SerializationError(format!(
"XML parsing error: {}",
e
)))
}
_ => {}
}
buf.clear();
}
Ok(messages)
}
}
impl fmt::Debug for AwsSessionProvider {
    // Only the queue name and session id are shown. The HTTP client, credential provider,
    // region, endpoint and queue URL are left out of debug output.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = formatter.debug_struct("AwsSessionProvider");
        builder.field("queue_name", &self.queue_name);
        builder.field("session_id", &self.session_id);
        builder.finish()
    }
}
#[async_trait]
impl SessionProvider for AwsSessionProvider {
    /// Receives at most one message belonging to this session.
    ///
    /// Long-polls for up to `timeout`, clamped to 0–20 whole seconds. The request does not
    /// filter by message group, so a message from another group can be returned. Such a
    /// message is released again (visibility timeout 0) instead of staying hidden from its
    /// rightful consumer, and `Ok(None)` is returned.
    async fn receive_message(
        &self,
        timeout: Duration,
    ) -> Result<Option<ReceivedMessage>, QueueError> {
        let queue_url = self.get_queue_url().await.map_err(|e| e.to_queue_error())?;
        let mut params = HashMap::new();
        params.insert("Action".to_string(), "ReceiveMessage".to_string());
        params.insert("Version".to_string(), "2012-11-05".to_string());
        params.insert("QueueUrl".to_string(), queue_url);
        params.insert("MaxNumberOfMessages".to_string(), "1".to_string());
        params.insert(
            "WaitTimeSeconds".to_string(),
            timeout.num_seconds().clamp(0, 20).to_string(),
        );
        // Request all system attributes; MessageGroupId identifies the session.
        params.insert("AttributeName.1".to_string(), "All".to_string());
        let response = self
            .make_request("POST", "/", &params, "")
            .await
            .map_err(|e| e.to_queue_error())?;
        let messages = self
            .parse_receive_message_response(&response, &self.queue_name)
            .map_err(|e| e.to_queue_error())?;

        let mut matched = None;
        for message in messages {
            if matched.is_none() && message.session_id.as_ref() == Some(&self.session_id) {
                matched = Some(message);
            } else {
                // Not ours: make it visible again right away. This is best effort. If the
                // release fails, SQS re-delivers the message once its visibility timeout
                // expires, so the error is deliberately ignored.
                let _ = self.abandon_message(&message.receipt_handle).await;
            }
        }
        Ok(matched)
    }

    /// Deletes the message identified by `receipt` (SQS `DeleteMessage`).
    ///
    /// # Errors
    /// `QueueError::InvalidReceipt` if the handle is not exactly `"{queue}|{sqs_receipt}"`.
    /// Request failures are converted with `to_queue_error`.
    async fn complete_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
        let handle_str = receipt.handle();
        // Require exactly one '|' separator between queue name and SQS receipt handle.
        let receipt_token = match handle_str.split_once('|') {
            Some((_, token)) if !token.contains('|') => token,
            _ => {
                return Err(QueueError::InvalidReceipt {
                    receipt: handle_str.to_string(),
                })
            }
        };
        let queue_url = self.get_queue_url().await.map_err(|e| e.to_queue_error())?;
        let mut params = HashMap::new();
        params.insert("Action".to_string(), "DeleteMessage".to_string());
        params.insert("Version".to_string(), "2012-11-05".to_string());
        params.insert("QueueUrl".to_string(), queue_url);
        params.insert("ReceiptHandle".to_string(), receipt_token.to_string());
        self.make_request("POST", "/", &params, "")
            .await
            .map_err(|e| e.to_queue_error())?;
        Ok(())
    }

    /// Makes the message immediately visible again (SQS `ChangeMessageVisibility` with
    /// `VisibilityTimeout=0`).
    ///
    /// # Errors
    /// `QueueError::InvalidReceipt` if the handle is not exactly `"{queue}|{sqs_receipt}"`.
    /// Request failures are converted with `to_queue_error`.
    async fn abandon_message(&self, receipt: &ReceiptHandle) -> Result<(), QueueError> {
        let handle_str = receipt.handle();
        // Require exactly one '|' separator between queue name and SQS receipt handle.
        let receipt_token = match handle_str.split_once('|') {
            Some((_, token)) if !token.contains('|') => token,
            _ => {
                return Err(QueueError::InvalidReceipt {
                    receipt: handle_str.to_string(),
                })
            }
        };
        let queue_url = self.get_queue_url().await.map_err(|e| e.to_queue_error())?;
        let mut params = HashMap::new();
        params.insert("Action".to_string(), "ChangeMessageVisibility".to_string());
        params.insert("Version".to_string(), "2012-11-05".to_string());
        params.insert("QueueUrl".to_string(), queue_url);
        params.insert("ReceiptHandle".to_string(), receipt_token.to_string());
        params.insert("VisibilityTimeout".to_string(), "0".to_string());
        self.make_request("POST", "/", &params, "")
            .await
            .map_err(|e| e.to_queue_error())?;
        Ok(())
    }

    /// Implemented as `complete_message`: the message is deleted.
    ///
    /// NOTE(review): `_reason` is discarded and nothing is forwarded to a dead-letter queue
    /// here — confirm this is the intended semantics.
    async fn dead_letter_message(
        &self,
        receipt: &ReceiptHandle,
        _reason: &str,
    ) -> Result<(), QueueError> {
        self.complete_message(receipt).await
    }

    /// No-op: this provider holds no session lock to renew.
    async fn renew_session_lock(&self) -> Result<(), QueueError> {
        Ok(())
    }

    /// No-op: there is no server-side session state to release.
    async fn close_session(&self) -> Result<(), QueueError> {
        Ok(())
    }

    fn session_id(&self) -> &SessionId {
        &self.session_id
    }

    /// Reports an expiry 365 days from now, because no session lock is held that could expire.
    fn session_expires_at(&self) -> Timestamp {
        Timestamp::from_datetime(chrono::Utc::now() + chrono::Duration::days(365))
    }
}