use aws_sdk_sqs::config::http::HttpResponse;
use aws_sdk_sqs::error::{ProvideErrorMetadata, SdkError};
use aws_sdk_sqs::types::{Message, MessageSystemAttributeName};
use std::collections::{HashMap, HashSet, VecDeque};
use std::future::Future;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::sync::oneshot::error::TryRecvError;
use tokio::sync::{Notify, mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use crate::backend::ConsumerOptionsInner as ConsumerOptions;
use crate::backends::sns::client::SnsClient;
use crate::backends::sns::router;
use crate::backends::sns::topology::QueueRegistry;
use crate::consumer_supervisor::{SupervisorOutcome, drive_fifo_until_timeout};
use crate::error::{Result, ShoveError};
use crate::handler::MessageHandler;
use crate::metadata::{DeadMessageMetadata, MessageMetadata};
use crate::outcome::Outcome;
use crate::retry::Backoff;
use crate::topic::{SequencedTopic, Topic};
use crate::topology::{QueueTopology, SequenceFailure};
use crate::{DEFAULT_HANDLER_TIMEOUT, metrics};
use crate::{DEFAULT_MAX_MESSAGE_SIZE, Sqs};
/// Converts an AWS SDK error into a [`ShoveError`], prefixing the message with `context`.
///
/// * Timeouts, dispatch failures and response errors map to `ShoveError::Connection`.
/// * Request construction failures map to `ShoveError::Topology`.
/// * Service errors map to `ShoveError::Connection` when their error code is one of the
///   throttling codes treated as transient here, and to `ShoveError::Topology` otherwise.
/// * Any other SDK error variant maps to `ShoveError::Unknown`.
fn map_sqs_error<E>(context: &str, e: SdkError<E, HttpResponse>) -> ShoveError
where
    E: std::fmt::Debug + std::fmt::Display + ProvideErrorMetadata,
{
    match &e {
        SdkError::TimeoutError(_) | SdkError::DispatchFailure(_) | SdkError::ResponseError(_) => {
            ShoveError::Connection(format!("{context}: {e}"))
        }
        SdkError::ConstructionFailure(_) => ShoveError::Topology(format!("{context}: {e}")),
        SdkError::ServiceError(service_err) => {
            // Only throttling-style codes are considered transient.
            match ProvideErrorMetadata::code(service_err.err()) {
                Some(
                    "RequestThrottled" | "Throttling" | "KMS.ThrottlingException" | "OverLimit",
                ) => ShoveError::Connection(format!("{context}: {e}")),
                _ => ShoveError::Topology(format!("{context}: {e}")),
            }
        }
        _ => ShoveError::Unknown(format!("unrecognized AWS SDK error in {context}: {e}")),
    }
}
/// Consumer for SQS queues: pairs a client with the registry used to look up
/// queue URLs by queue name.
#[derive(Clone)]
pub struct SqsConsumer {
/// Client wrapper; the SQS client (`sqs()`) and the DLQ shutdown token
/// (`shutdown_token()`) are obtained from it.
client: SnsClient,
/// Shared registry queried by queue name to obtain the queue URL.
queue_registry: Arc<QueueRegistry>,
}
impl SqsConsumer {
pub fn new(client: SnsClient, queue_registry: Arc<QueueRegistry>) -> Self {
Self {
client,
queue_registry,
}
}
async fn resolve_queue_url(&self, queue_name: &str) -> Result<String> {
self.queue_registry.get(queue_name).await.ok_or_else(|| {
ShoveError::Topology(format!(
"no SQS queue URL registered for '{queue_name}'. Declare topology first."
))
})
}
}
/// Builds the handler-facing [`MessageMetadata`] for an SQS message.
///
/// The delivery id is the SQS message id (empty when absent), and a message counts as
/// redelivered as soon as its retry count is above zero.
fn extract_metadata(msg: &Message) -> MessageMetadata {
    let attempts = router::get_retry_count(msg);
    let delivery_id = msg.message_id().unwrap_or_default().to_owned();
    let headers = Arc::new(router::extract_message_attributes(msg));
    MessageMetadata {
        retry_count: attempts,
        delivery_id,
        redelivered: attempts > 0,
        headers,
    }
}
/// The two fields of a JSON envelope that `extract_payload` inspects; any other
/// fields in the body are not captured by this struct.
#[derive(serde::Deserialize)]
struct SnsEnvelope {
/// Value of the `Type` field; only `"Notification"` envelopes are unwrapped.
#[serde(rename = "Type")]
notification_type: String,
/// Value of the `Message` field: the payload returned when the envelope is unwrapped.
#[serde(rename = "Message")]
message: String,
}
/// Returns the application payload carried by an SQS message body.
///
/// When `body` parses as an [`SnsEnvelope`] whose type is `"Notification"`, the inner
/// `Message` string is returned (owned). In every other case the body is returned
/// unchanged and borrowed.
fn extract_payload(body: &str) -> std::borrow::Cow<'_, str> {
    use std::borrow::Cow;

    match serde_json::from_str::<SnsEnvelope>(body) {
        Ok(envelope) if envelope.notification_type == "Notification" => {
            Cow::Owned(envelope.message)
        }
        _ => Cow::Borrowed(body),
    }
}
/// Builds the [`DeadMessageMetadata`] passed to `handle_dead` for a DLQ message.
///
/// The death count mirrors the message's retry count, the reason is always
/// `"max_receives_exceeded"`, and `queue_name` is recorded as the original queue.
fn extract_dead_metadata(msg: &Message, queue_name: &str) -> DeadMessageMetadata {
    let message = extract_metadata(msg);
    DeadMessageMetadata {
        // Read the count before `message` is moved into the struct.
        death_count: message.retry_count,
        reason: Some("max_receives_exceeded".into()),
        original_queue: Some(queue_name.to_owned()),
        message,
    }
}
/// Repeatedly runs the consume future produced by `f`, backing off between failures.
///
/// * `Ok(())` from `f` ends the loop successfully.
/// * A non-retryable error is returned immediately.
/// * A retryable error observed after shutdown was requested ends the loop with `Ok(())`.
/// * Otherwise the attempt counter is bumped; once it reaches `max_reconnect_attempts`
///   (when set) a `ShoveError::Connection` is returned, else the loop sleeps for the next
///   backoff delay (waking early, with `Ok(())`, if shutdown is requested) and retries.
///
/// `queue` is only used for log and error messages.
async fn run_with_reconnect<F, Fut>(
    shutdown: &CancellationToken,
    queue: &str,
    max_reconnect_attempts: Option<u32>,
    mut f: F,
) -> Result<()>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<()>>,
{
    let mut backoff = Backoff::default();
    let mut attempts = 0u32;
    loop {
        let Err(e) = f().await else {
            return Ok(());
        };
        if !e.is_retryable() {
            return Err(e);
        }
        if shutdown.is_cancelled() {
            return Ok(());
        }

        attempts += 1;
        match max_reconnect_attempts {
            Some(max) if attempts >= max => {
                tracing::error!(
                    queue,
                    attempts,
                    error = %e,
                    "max reconnect attempts reached, giving up"
                );
                return Err(ShoveError::Connection(format!(
                    "consumer on '{queue}' exhausted {max} reconnect attempt(s): {e}"
                )));
            }
            _ => {}
        }

        let delay = backoff.next().expect("backoff is infinite");
        warn!(
            queue,
            attempt = attempts,
            ?max_reconnect_attempts,
            "consumer error, reconnecting in {delay:?}: {e}"
        );
        tokio::select! {
            _ = tokio::time::sleep(delay) => {}
            _ = shutdown.cancelled() => return Ok(()),
        }
    }
}
/// Runs a handler future on its own task and converts every failure mode into an [`Outcome`].
///
/// * A handler that finishes yields its own outcome.
/// * A handler task that fails to join (logged as a panic) yields `Outcome::Retry`.
/// * With `timeout` set, a handler that overruns is aborted, a timeout failure is recorded,
///   and `Outcome::Retry` is returned.
///
/// An in-flight guard is held for the duration of the call, and consumption and
/// processing-duration metrics are recorded for the final outcome.
async fn invoke_handler<F>(
    fut: F,
    timeout: Option<Duration>,
    topic: &str,
    group: Option<&str>,
) -> Outcome
where
    F: std::future::Future<Output = Outcome> + Send + 'static,
{
    let _inflight = metrics::InflightGuard::from_refs(topic, group);
    let started = std::time::Instant::now();
    let mut task = tokio::spawn(fut);

    // `None` means the handler timed out; `Some` carries the join result.
    let joined = match timeout {
        None => Some(task.await),
        Some(duration) => match tokio::time::timeout(duration, &mut task).await {
            Ok(join_result) => Some(join_result),
            Err(_) => {
                task.abort();
                warn!("handler exceeded timeout ({duration:?}), retrying message");
                metrics::record_failed(topic, group, metrics::FailReason::Timeout);
                None
            }
        },
    };

    let outcome = match joined {
        Some(Ok(outcome)) => outcome,
        Some(Err(e)) => {
            warn!(error = %e, "handler task panicked, retrying message");
            Outcome::Retry
        }
        None => Outcome::Retry,
    };

    let elapsed = started.elapsed().as_secs_f64();
    metrics::record_consumed(topic, group, &outcome);
    metrics::record_processing_duration(topic, group, &outcome, elapsed);
    outcome
}
/// Spawns a task that runs `handler` on `message` and reports the result.
///
/// The returned receiver yields the handler's [`Outcome`]; after the outcome has been
/// sent, `notify` is signalled so the consume loop can wake up and settle the message.
#[allow(clippy::too_many_arguments)]
fn spawn_handler<T, H>(
    handler: &Arc<H>,
    ctx: &Arc<H::Context>,
    message: T::Message,
    metadata: MessageMetadata,
    timeout: Option<Duration>,
    notify: &Arc<Notify>,
    topic: Arc<str>,
    group: Option<Arc<str>>,
) -> oneshot::Receiver<Outcome>
where
    T: Topic,
    H: MessageHandler<T>,
{
    let (outcome_tx, outcome_rx) = oneshot::channel();
    let handler = Arc::clone(handler);
    let ctx = Arc::clone(ctx);
    let notify = Arc::clone(notify);

    tokio::spawn(async move {
        let handling = async move { handler.handle(message, metadata, ctx.as_ref()).await };
        let outcome = invoke_handler(handling, timeout, &topic, group.as_deref()).await;
        // The receiver may already be gone; there is nothing to do about it here.
        let _ = outcome_tx.send(outcome);
        notify.notify_one();
    });

    outcome_rx
}
/// A message that has been handed to a handler task by the concurrent consume loop
/// and is waiting for its outcome to be routed.
struct PendingMessage {
/// Receipt handle passed to the router when the outcome is settled.
receipt_handle: String,
/// The received message, kept so its body and attributes are available to `route_outcome`.
msg: Arc<Message>,
/// Retry count read from the message when it was dispatched.
retry_count: u32,
/// Receives the handler's outcome; a closed channel is treated as a handler panic.
outcome_rx: oneshot::Receiver<Outcome>,
}
/// Runs the unordered consume loop for one queue until shutdown or a receive failure.
///
/// Every iteration of the loop:
/// 1. settles finished handlers strictly in dispatch order (front of `in_flight`), batching
///    acks in groups of up to 10 and routing every other outcome individually;
/// 2. on shutdown, flushes pending acks, requeues messages that were buffered but never
///    dispatched, waits for each in-flight handler (bounded by `handler_timeout`, or
///    `DEFAULT_HANDLER_TIMEOUT` when unset), routes the outcomes and returns `Ok(())`;
/// 3. dispatches buffered messages to handlers until `prefetch_count` are in flight,
///    rejecting messages that exceeded `max_retries`, are oversized, or fail to decode;
/// 4. polls SQS when the buffer is empty and there is spare capacity, otherwise waits for
///    a handler to complete.
///
/// The wait in step 4 also wakes on shutdown, so a saturated consumer (or one configured
/// with `prefetch_count == 0`) still reacts to a shutdown request instead of blocking until
/// a handler happens to finish.
///
/// # Errors
/// Returns the error produced by `map_sqs_error` when `ReceiveMessage` fails.
async fn consume_loop_concurrent<T, H>(
    sqs: &aws_sdk_sqs::Client,
    queue_url: &str,
    topology: &'static QueueTopology,
    handler: &Arc<H>,
    ctx: &Arc<H::Context>,
    options: &ConsumerOptions,
) -> Result<()>
where
    T: Topic,
    H: MessageHandler<T>,
{
    let notify = Arc::new(Notify::new());
    let topic: Arc<str> = Arc::from(topology.queue());
    let group: Option<Arc<str>> = options.consumer_group.as_deref().map(Arc::from);
    let max_in_flight = options.prefetch_count as usize;
    // Receive batch: the configured size (10 when unset), capped at the SQS maximum of 10.
    let receive_batch: usize = {
        let configured = if options.receive_batch_size > 0 {
            options.receive_batch_size as usize
        } else {
            10
        };
        configured.min(10)
    };
    let mut in_flight: VecDeque<PendingMessage> = VecDeque::with_capacity(max_in_flight);
    let mut local_buffer: VecDeque<Message> = VecDeque::with_capacity(receive_batch);
    let mut pending_acks: Vec<String> = Vec::with_capacity(10);
    info!(
        queue_url,
        max_in_flight, receive_batch, "SQS consumer started"
    );
    loop {
        // Step 1: settle completed handlers, oldest first. Stops at the first handler
        // that is still running so outcomes are routed in dispatch order.
        while let Some(front) = in_flight.front_mut() {
            match front.outcome_rx.try_recv() {
                Ok(Outcome::Ack) => {
                    let msg = in_flight.pop_front().expect("in_flight was just peeked");
                    debug!(queue_url, receipt_handle = %msg.receipt_handle, "message acked (pending flush)");
                    pending_acks.push(msg.receipt_handle);
                    if pending_acks.len() >= 10 {
                        let batch_size = pending_acks.len();
                        debug!(queue_url, batch_size, "flushing full ack batch");
                        router::route_ack_batch(sqs, queue_url, std::mem::take(&mut pending_acks))
                            .await;
                    }
                }
                Ok(outcome) => {
                    // Flush earlier acks first so routing order matches completion order.
                    if !pending_acks.is_empty() {
                        let batch_size = pending_acks.len();
                        debug!(
                            queue_url,
                            batch_size,
                            ?outcome,
                            "flushing ack batch before non-ack outcome"
                        );
                        router::route_ack_batch(sqs, queue_url, std::mem::take(&mut pending_acks))
                            .await;
                    }
                    let msg = in_flight.pop_front().expect("in_flight was just peeked");
                    debug!(queue_url, ?outcome, "message handled");
                    route_outcome(
                        sqs,
                        queue_url,
                        &msg.receipt_handle,
                        &msg.msg,
                        outcome,
                        topology,
                        msg.retry_count,
                    )
                    .await;
                }
                Err(TryRecvError::Empty) => break,
                Err(TryRecvError::Closed) => {
                    // The sender was dropped without sending: the handler task died.
                    if !pending_acks.is_empty() {
                        let batch_size = pending_acks.len();
                        debug!(
                            queue_url,
                            batch_size, "flushing ack batch after handler panic"
                        );
                        router::route_ack_batch(sqs, queue_url, std::mem::take(&mut pending_acks))
                            .await;
                    }
                    let msg = in_flight.pop_front().expect("in_flight was just peeked");
                    warn!(queue_url, "handler task panicked, retrying message");
                    route_outcome(
                        sqs,
                        queue_url,
                        &msg.receipt_handle,
                        &msg.msg,
                        Outcome::Retry,
                        topology,
                        msg.retry_count,
                    )
                    .await;
                }
            }
        }
        options.processing.store(
            !in_flight.is_empty() || !local_buffer.is_empty(),
            Ordering::Release,
        );

        // Step 2: graceful shutdown.
        if options.shutdown.is_cancelled() {
            debug!(
                "shutdown signal, requeueing {} buffered, draining {} in-flight on {queue_url}",
                local_buffer.len(),
                in_flight.len()
            );
            if !pending_acks.is_empty() {
                let batch_size = pending_acks.len();
                debug!(queue_url, batch_size, "flushing ack batch on shutdown");
                router::route_ack_batch(sqs, queue_url, std::mem::take(&mut pending_acks)).await;
            }
            // Messages that never reached a handler go back to the queue.
            for msg in local_buffer.drain(..) {
                if let Some(rh) = msg.receipt_handle() {
                    router::route_requeue(sqs, queue_url, rh).await;
                }
            }
            let drain_timeout = options.handler_timeout.unwrap_or(DEFAULT_HANDLER_TIMEOUT);
            let mut drain_acks: Vec<String> = Vec::with_capacity(in_flight.len());
            for pending in in_flight {
                // A timed-out wait and a dropped sender are both treated as Retry.
                let outcome = tokio::time::timeout(drain_timeout, pending.outcome_rx)
                    .await
                    .unwrap_or_else(|_| {
                        warn!(
                            queue_url,
                            "handler timed out during shutdown drain, retrying"
                        );
                        Ok(Outcome::Retry)
                    })
                    .unwrap_or(Outcome::Retry);
                if matches!(outcome, Outcome::Ack) {
                    drain_acks.push(pending.receipt_handle);
                } else {
                    route_outcome(
                        sqs,
                        queue_url,
                        &pending.receipt_handle,
                        &pending.msg,
                        outcome,
                        topology,
                        pending.retry_count,
                    )
                    .await;
                }
            }
            if !drain_acks.is_empty() {
                let batch_size = drain_acks.len();
                debug!(
                    queue_url,
                    batch_size, "flushing ack batch on drain completion"
                );
                router::route_ack_batch(sqs, queue_url, drain_acks).await;
            }
            return Ok(());
        }

        // Step 3: dispatch buffered messages while there is handler capacity.
        while in_flight.len() < max_in_flight {
            let Some(msg) = local_buffer.pop_front() else {
                break;
            };
            let receipt_handle = msg.receipt_handle().unwrap_or_default().to_string();
            let retry_count = router::get_retry_count(&msg);
            if retry_count >= options.max_retries {
                warn!(
                    queue_url,
                    retry_count,
                    max_retries = options.max_retries,
                    "message exceeded max retries, rejecting"
                );
                router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
                continue;
            }
            let body = extract_payload(msg.body().unwrap_or_default());
            metrics::record_message_size(&topic, group.as_deref(), body.len());
            if let Err(e) = options.validate_payload_message_size(body.len()) {
                warn!(error = %e, queue_url, "rejecting oversized message");
                metrics::record_failed(&topic, group.as_deref(), metrics::FailReason::Oversize);
                router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
                continue;
            }
            let message: T::Message = match <T::Codec as crate::Codec<T::Message>>::decode(
                body.as_bytes(),
            ) {
                Ok(m) => m,
                Err(err) => {
                    error!(error = %err, queue_url, "failed to deserialize SQS message, rejecting");
                    metrics::record_failed(
                        &topic,
                        group.as_deref(),
                        metrics::FailReason::Deserialize,
                    );
                    router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
                    continue;
                }
            };
            let metadata = extract_metadata(&msg);
            debug!(
                queue_url,
                message_id = %metadata.delivery_id,
                retry_count = metadata.retry_count,
                "dispatching message to handler"
            );
            let rx = spawn_handler::<T, H>(
                handler,
                ctx,
                message,
                metadata,
                options.handler_timeout,
                &notify,
                Arc::clone(&topic),
                group.clone(),
            );
            in_flight.push_back(PendingMessage {
                receipt_handle,
                msg: Arc::new(msg),
                retry_count,
                outcome_rx: rx,
            });
            options.processing.store(true, Ordering::Relaxed);
        }

        // Step 4a: nothing buffered and capacity left -> poll SQS for more work.
        if local_buffer.is_empty() && in_flight.len() < max_in_flight {
            if !pending_acks.is_empty() {
                let batch_size = pending_acks.len();
                debug!(queue_url, batch_size, "flushing ack batch before poll");
                router::route_ack_batch(sqs, queue_url, std::mem::take(&mut pending_acks)).await;
            }
            let max_messages = receive_batch as i32;
            let receive_result = sqs
                .receive_message()
                .queue_url(queue_url)
                .wait_time_seconds(0)
                .max_number_of_messages(max_messages)
                .message_system_attribute_names(MessageSystemAttributeName::ApproximateReceiveCount)
                .message_attribute_names("All")
                .send()
                .await
                .map_err(|e| {
                    metrics::record_backend_error(
                        metrics::BackendLabel::SnsSqs,
                        metrics::BackendErrorKind::Consume,
                    );
                    map_sqs_error(&format!("SQS ReceiveMessage failed on {queue_url}"), e)
                })?;
            let msgs = receive_result.messages.unwrap_or_default();
            if msgs.is_empty() {
                debug!(queue_url, "queue empty, backing off 500ms");
                tokio::select! {
                    biased;
                    _ = options.shutdown.cancelled() => {}
                    _ = tokio::time::sleep(Duration::from_millis(500)) => {}
                }
            } else {
                debug!(
                    queue_url,
                    received = msgs.len(),
                    "received messages from SQS"
                );
            }
            local_buffer.extend(msgs);
            continue;
        }

        // Step 4b: saturated -> wait for a handler to finish, but never ignore shutdown.
        // Without the shutdown arm this wait would block forever when no handler can
        // complete (e.g. `prefetch_count == 0`) and would delay shutdown otherwise.
        if in_flight.len() >= max_in_flight {
            tokio::select! {
                biased;
                _ = options.shutdown.cancelled() => {}
                _ = notify.notified() => {}
            }
        }
    }
}
/// Routes a handler outcome for a single message through the matching router function.
///
/// `Ack` and `Reject` only need the receipt handle; `Retry` and `Defer` additionally pass
/// the message body, its message attributes (an empty map when absent) and the retry count.
async fn route_outcome(
    sqs: &aws_sdk_sqs::Client,
    queue_url: &str,
    receipt_handle: &str,
    msg: &Message,
    outcome: Outcome,
    topology: &'static QueueTopology,
    retry_count: u32,
) {
    // Only consumed by the Retry and Defer arms; cheap to prepare up front.
    let no_attrs = HashMap::new();
    let body = msg.body().unwrap_or_default();
    let attrs = msg.message_attributes().unwrap_or(&no_attrs);

    match outcome {
        Outcome::Ack => router::route_ack(sqs, queue_url, receipt_handle).await,
        Outcome::Reject => router::route_reject(sqs, queue_url, receipt_handle, topology).await,
        Outcome::Retry => {
            router::route_retry(
                sqs,
                queue_url,
                receipt_handle,
                body,
                attrs,
                topology,
                retry_count,
            )
            .await;
        }
        Outcome::Defer => {
            router::route_defer(
                sqs,
                queue_url,
                receipt_handle,
                body,
                attrs,
                topology,
                retry_count,
            )
            .await;
        }
    }
}
/// Per-sequence-key state tracked by the sequenced consume loop.
enum KeyState {
/// A message for this key has been handed to a handler and its outcome is pending.
InFlight {
/// Receipt handle used to route the outcome once the handler finishes.
receipt_handle: String,
/// The message being handled.
msg: Arc<Message>,
/// Retry count read from the message when it was dispatched.
retry_count: u32,
/// Receives the handler's outcome; a closed channel is treated as a handler panic.
outcome_rx: oneshot::Receiver<Outcome>,
},
/// A retry was routed for this key. New deliveries (retry count 0) are buffered
/// until a delivery with a retry count above zero arrives and clears this state.
AwaitingRetry,
}
/// Returns the message's `MessageGroupId` system attribute, used as its sequence key.
///
/// Yields `None` when the message carries no system attributes or no group id.
fn extract_sequence_key(msg: &Message) -> Option<String> {
    let system_attrs = msg.attributes()?;
    let group_id = system_attrs.get(&MessageSystemAttributeName::MessageGroupId)?;
    Some(group_id.to_string())
}
/// Spawns a task that runs `handler` on `message` for the sequence key `key`.
///
/// The returned receiver yields the handler's [`Outcome`]. After the outcome has been
/// sent, `key` is pushed onto `completed_tx` so the sequenced loop knows which key to settle.
#[allow(clippy::too_many_arguments)]
fn spawn_handler_keyed<T, H>(
    handler: &Arc<H>,
    ctx: &Arc<H::Context>,
    message: T::Message,
    metadata: MessageMetadata,
    timeout: Option<Duration>,
    completed_tx: &mpsc::UnboundedSender<String>,
    key: String,
    topic: Arc<str>,
    group: Option<Arc<str>>,
) -> oneshot::Receiver<Outcome>
where
    T: Topic,
    H: MessageHandler<T>,
{
    let (outcome_tx, outcome_rx) = oneshot::channel();
    let handler = Arc::clone(handler);
    let ctx = Arc::clone(ctx);
    let completions = completed_tx.clone();

    tokio::spawn(async move {
        let handling = async move { handler.handle(message, metadata, ctx.as_ref()).await };
        let outcome = invoke_handler(handling, timeout, &topic, group.as_deref()).await;
        // Send the outcome before the completion notice so the loop finds it ready.
        let _ = outcome_tx.send(outcome);
        let _ = completions.send(key);
    });

    outcome_rx
}
/// Supervises the sequenced consume loop for one shard queue, reconnecting on failure.
///
/// The set of poisoned keys and the map of locally buffered deliveries live here so they
/// are shared across reconnects of `consume_loop_sequenced`.
///
/// * When the loop returns `Ok(())`, any deliveries still buffered are requeued and the
///   shard stops successfully.
/// * When the loop fails after shutdown was requested, the shard stops with `Ok(())`.
/// * A non-retryable error is returned immediately, matching `run_with_reconnect`;
///   retrying it would only repeat the same failure.
/// * Any other error bumps the attempt counter; once it reaches
///   `options.max_reconnect_attempts` (when set) a `ShoveError::Connection` is returned,
///   otherwise the shard sleeps for the next backoff delay (ending early with `Ok(())` on
///   shutdown) and reconnects.
///
/// Locally buffered deliveries are dropped on every error path; they are not routed.
#[allow(clippy::too_many_arguments)]
async fn run_sequenced_shard<T, H>(
    sqs: &aws_sdk_sqs::Client,
    queue_url: &str,
    queue_name: &str,
    topology: &'static QueueTopology,
    handler: &Arc<H>,
    ctx: &Arc<H::Context>,
    options: &ConsumerOptions,
    on_failure: SequenceFailure,
) -> Result<()>
where
    T: Topic,
    H: MessageHandler<T>,
{
    let mut poisoned_keys = HashSet::new();
    let mut pending_deliveries: HashMap<String, VecDeque<Message>> = HashMap::new();
    let mut backoff = Backoff::default();
    let mut attempts = 0u32;
    loop {
        match consume_loop_sequenced::<T, H>(
            sqs,
            queue_url,
            topology,
            handler,
            ctx,
            options,
            on_failure,
            &mut poisoned_keys,
            &mut pending_deliveries,
        )
        .await
        {
            Ok(()) => {
                // Hand back anything that was buffered but never dispatched.
                for (_key, msgs) in pending_deliveries.drain() {
                    for msg in msgs {
                        let rh = msg.receipt_handle().unwrap_or_default();
                        router::route_requeue(sqs, queue_url, rh).await;
                    }
                }
                return Ok(());
            }
            Err(e) => {
                // Buffered deliveries are discarded locally on every error path.
                pending_deliveries.clear();
                if options.shutdown.is_cancelled() {
                    return Ok(());
                }
                // Fail fast on errors that a reconnect cannot fix.
                if !e.is_retryable() {
                    return Err(e);
                }
                attempts += 1;
                if let Some(max) = options.max_reconnect_attempts
                    && attempts >= max
                {
                    tracing::error!(
                        queue = queue_name,
                        attempts,
                        error = %e,
                        "max reconnect attempts reached, giving up"
                    );
                    return Err(ShoveError::Connection(format!(
                        "consumer on '{queue_name}' exhausted {max} reconnect attempt(s): {e}"
                    )));
                }
                let delay = backoff.next().expect("backoff is infinite");
                warn!(
                    queue = queue_name,
                    attempt = attempts,
                    max_reconnect_attempts = ?options.max_reconnect_attempts,
                    "consumer error, reconnecting in {delay:?}: {e}"
                );
                tokio::select! {
                    _ = tokio::time::sleep(delay) => {}
                    _ = options.shutdown.cancelled() => return Ok(()),
                }
            }
        }
    }
}
#[allow(clippy::too_many_arguments)]
async fn consume_loop_sequenced<T, H>(
sqs: &aws_sdk_sqs::Client,
queue_url: &str,
topology: &'static QueueTopology,
handler: &Arc<H>,
ctx: &Arc<H::Context>,
options: &ConsumerOptions,
on_failure: SequenceFailure,
poisoned_keys: &mut HashSet<String>,
pending_deliveries: &mut HashMap<String, VecDeque<Message>>,
) -> Result<()>
where
T: Topic,
H: MessageHandler<T>,
{
let prefetch = options.prefetch_count as usize;
let (completed_tx, mut completed_rx) = mpsc::unbounded_channel::<String>();
let topic: Arc<str> = Arc::from(topology.queue());
let group: Option<Arc<str>> = options.consumer_group.as_deref().map(Arc::from);
let mut key_states: HashMap<String, KeyState> = HashMap::new();
let mut in_flight_count: usize = 0;
info!(queue_url, prefetch, "sequenced SQS consumer started");
loop {
while let Ok(key) = completed_rx.try_recv() {
let Some(state) = key_states.remove(&key) else {
continue;
};
let KeyState::InFlight {
receipt_handle,
msg,
retry_count,
mut outcome_rx,
} = state
else {
key_states.insert(key, state);
continue;
};
let outcome = match outcome_rx.try_recv() {
Ok(o) => o,
Err(TryRecvError::Closed) => {
warn!(queue_url, sequence_key = %key, "handler task panicked, retrying");
Outcome::Retry
}
Err(TryRecvError::Empty) => {
key_states.insert(
key,
KeyState::InFlight {
receipt_handle,
msg,
retry_count,
outcome_rx,
},
);
continue;
}
};
debug!(queue_url, sequence_key = %key, ?outcome, "message handled (sequenced)");
match outcome {
Outcome::Ack => {
router::route_ack(sqs, queue_url, &receipt_handle).await;
in_flight_count -= 1;
drain_pending_for_key::<T, H>(
sqs,
queue_url,
&key,
handler,
ctx,
options,
on_failure,
topology,
poisoned_keys,
&completed_tx,
&mut key_states,
&mut in_flight_count,
pending_deliveries,
&topic,
&group,
)
.await;
}
Outcome::Reject => {
if on_failure == SequenceFailure::FailAll {
info!(
sequence_key = %key,
queue_url,
"poisoning sequence key (FailAll)"
);
poisoned_keys.insert(key.clone());
}
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
in_flight_count -= 1;
drain_pending_for_key::<T, H>(
sqs,
queue_url,
&key,
handler,
ctx,
options,
on_failure,
topology,
poisoned_keys,
&completed_tx,
&mut key_states,
&mut in_flight_count,
pending_deliveries,
&topic,
&group,
)
.await;
}
Outcome::Retry => {
router::route_retry_fifo(
sqs,
queue_url,
&receipt_handle,
topology,
retry_count,
)
.await;
in_flight_count -= 1;
key_states.insert(key, KeyState::AwaitingRetry);
}
Outcome::Defer => {
warn!(
queue_url,
sequence_key = %key,
"Defer is not supported on sequenced (FIFO) consumers, treating as Retry"
);
router::route_retry_fifo(
sqs,
queue_url,
&receipt_handle,
topology,
retry_count,
)
.await;
in_flight_count -= 1;
key_states.insert(key, KeyState::AwaitingRetry);
}
}
}
options
.processing
.store(in_flight_count > 0, Ordering::Relaxed);
let can_accept = in_flight_count < prefetch;
tokio::select! {
biased;
_ = options.shutdown.cancelled() => {
debug!(
"shutdown signal, draining {} in-flight messages on {queue_url}",
in_flight_count
);
for (key, state) in key_states.drain() {
if let KeyState::InFlight { receipt_handle, msg: _, retry_count, outcome_rx } = state {
let outcome = outcome_rx.await.unwrap_or(Outcome::Retry);
debug!(
queue_url,
sequence_key = %key,
?outcome,
"draining in-flight message on shutdown"
);
match outcome {
Outcome::Ack => {
router::route_ack(sqs, queue_url, &receipt_handle).await;
}
Outcome::Reject => {
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
}
Outcome::Retry => {
router::route_retry_fifo(
sqs,
queue_url,
&receipt_handle,
topology,
retry_count,
)
.await;
}
Outcome::Defer => {
warn!(
queue_url,
sequence_key = %key,
"Defer is not supported on sequenced (FIFO) consumers, treating as Retry"
);
router::route_retry_fifo(
sqs,
queue_url,
&receipt_handle,
topology,
retry_count,
)
.await;
}
}
}
}
for (_key, msgs) in pending_deliveries.drain() {
for msg in msgs {
let rh = msg.receipt_handle().unwrap_or_default();
router::route_reject(sqs, queue_url, rh, topology).await;
}
}
return Ok(());
}
Some(key) = completed_rx.recv() => {
let _ = completed_tx.send(key);
}
result = async {
sqs.receive_message()
.queue_url(queue_url)
.wait_time_seconds(5)
.max_number_of_messages(prefetch.saturating_sub(in_flight_count).min(10) as i32)
.message_system_attribute_names(MessageSystemAttributeName::ApproximateReceiveCount)
.message_system_attribute_names(MessageSystemAttributeName::MessageGroupId)
.message_attribute_names("All")
.send()
.await
}, if can_accept => {
let messages = result
.map_err(|e| {
metrics::record_backend_error(
metrics::BackendLabel::SnsSqs,
metrics::BackendErrorKind::Consume,
);
map_sqs_error(
&format!("SQS ReceiveMessage failed on {queue_url}"),
e,
)
})?
.messages
.unwrap_or_default();
debug!(queue_url, received = messages.len(), "received messages from SQS (sequenced)");
for msg in messages {
let receipt_handle = msg.receipt_handle().unwrap_or_default().to_string();
let retry_count = router::get_retry_count(&msg);
let seq_key = match extract_sequence_key(&msg) {
Some(k) => k,
None => {
warn!(
queue_url,
"message missing MessageGroupId, rejecting"
);
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
continue;
}
};
if on_failure == SequenceFailure::FailAll
&& poisoned_keys.contains(&seq_key)
{
warn!(
sequence_key = %seq_key,
queue_url,
"message with poisoned sequence key, rejecting"
);
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
continue;
}
if retry_count >= options.max_retries {
warn!(
queue_url,
retry_count,
max_retries = options.max_retries,
"message exceeded max retries, rejecting"
);
if on_failure == SequenceFailure::FailAll {
info!(
sequence_key = %seq_key,
queue_url,
"poisoning sequence key (FailAll)"
);
poisoned_keys.insert(seq_key.clone());
if let Some(pending) = pending_deliveries.remove(&seq_key) {
for pd in pending {
let rh = pd.receipt_handle().unwrap_or_default();
router::route_reject(sqs, queue_url, rh, topology).await;
}
}
}
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
continue;
}
match key_states.get(&seq_key) {
Some(KeyState::InFlight { .. }) => {
if let Some(limit) = options.max_pending_per_key {
let current_len = pending_deliveries
.get(&seq_key)
.map_or(0, |q| q.len());
if current_len >= limit {
warn!(
sequence_key = %seq_key,
queue_url,
limit,
"per-key pending buffer full, rejecting"
);
metrics::record_failed(
&topic,
group.as_deref(),
metrics::FailReason::PendingFull,
);
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
continue;
}
}
debug!(
sequence_key = %seq_key,
queue_url,
"key in-flight, buffering delivery locally"
);
pending_deliveries
.entry(seq_key)
.or_insert_with(|| VecDeque::with_capacity(4))
.push_back(msg);
continue;
}
Some(KeyState::AwaitingRetry) => {
if retry_count > 0 {
debug!(
sequence_key = %seq_key,
queue_url,
retry_count,
"returning retry clears AwaitingRetry"
);
key_states.remove(&seq_key);
} else {
if let Some(limit) = options.max_pending_per_key {
let current_len = pending_deliveries
.get(&seq_key)
.map_or(0, |q| q.len());
if current_len >= limit {
warn!(
sequence_key = %seq_key,
queue_url,
limit,
"per-key pending buffer full, rejecting"
);
metrics::record_failed(
&topic,
group.as_deref(),
metrics::FailReason::PendingFull,
);
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
continue;
}
}
debug!(
sequence_key = %seq_key,
queue_url,
"key awaiting retry, buffering new delivery locally"
);
pending_deliveries
.entry(seq_key)
.or_default()
.push_back(msg);
continue;
}
}
None => {}
}
let body = extract_payload(msg.body().unwrap_or_default());
metrics::record_message_size(&topic, group.as_deref(), body.len());
if let Err(e) = options.validate_payload_message_size(body.len()) {
warn!(
error = %e,
queue_url,
sequence_key = %seq_key,
"rejecting oversized message"
);
metrics::record_failed(
&topic,
group.as_deref(),
metrics::FailReason::Oversize,
);
if on_failure == SequenceFailure::FailAll {
poisoned_keys.insert(seq_key.clone());
}
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
continue;
}
let message: T::Message = match <T::Codec as crate::Codec<T::Message>>::decode(
body.as_bytes(),
) {
Ok(m) => m,
Err(err) => {
error!(
error = %err,
queue_url,
sequence_key = %seq_key,
"failed to deserialize SQS message, rejecting"
);
metrics::record_failed(
&topic,
group.as_deref(),
metrics::FailReason::Deserialize,
);
if on_failure == SequenceFailure::FailAll {
poisoned_keys.insert(seq_key.clone());
}
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
continue;
}
};
let metadata = extract_metadata(&msg);
debug!(
queue_url,
sequence_key = %seq_key,
retry_count,
"dispatching sequenced message to handler"
);
let rx = spawn_handler_keyed::<T, H>(
handler,
ctx,
message,
metadata,
options.handler_timeout,
&completed_tx,
seq_key.clone(),
Arc::clone(&topic),
group.clone(),
);
key_states.insert(
seq_key,
KeyState::InFlight {
receipt_handle,
msg: Arc::new(msg),
retry_count,
outcome_rx: rx,
},
);
in_flight_count += 1;
options.processing.store(true, Ordering::Relaxed);
}
}
}
}
}
#[allow(clippy::too_many_arguments)]
async fn drain_pending_for_key<T, H>(
sqs: &aws_sdk_sqs::Client,
queue_url: &str,
key: &str,
handler: &Arc<H>,
ctx: &Arc<H::Context>,
options: &ConsumerOptions,
on_failure: SequenceFailure,
topology: &'static QueueTopology,
poisoned_keys: &mut HashSet<String>,
completed_tx: &mpsc::UnboundedSender<String>,
key_states: &mut HashMap<String, KeyState>,
in_flight_count: &mut usize,
pending_deliveries: &mut HashMap<String, VecDeque<Message>>,
topic: &Arc<str>,
group: &Option<Arc<str>>,
) where
T: Topic,
H: MessageHandler<T>,
{
if on_failure == SequenceFailure::FailAll && poisoned_keys.contains(key) {
if let Some(pending) = pending_deliveries.remove(key) {
for pd in pending {
let rh = pd.receipt_handle().unwrap_or_default();
router::route_reject(sqs, queue_url, rh, topology).await;
}
}
return;
}
let Some(pending) = pending_deliveries.get_mut(key) else {
return;
};
while let Some(msg) = pending.pop_front() {
let receipt_handle = msg.receipt_handle().unwrap_or_default().to_string();
let retry_count = router::get_retry_count(&msg);
if retry_count >= options.max_retries {
warn!(
queue_url,
sequence_key = %key,
retry_count,
"buffered message exceeded max retries, rejecting"
);
if on_failure == SequenceFailure::FailAll {
poisoned_keys.insert(key.to_string());
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
while let Some(pd) = pending.pop_front() {
let rh = pd.receipt_handle().unwrap_or_default();
router::route_reject(sqs, queue_url, rh, topology).await;
}
pending_deliveries.remove(key);
return;
}
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
continue;
}
let body = extract_payload(msg.body().unwrap_or_default());
metrics::record_message_size(topic, group.as_deref(), body.len());
if let Err(e) = options.validate_payload_message_size(body.len()) {
warn!(
error = %e,
queue_url,
sequence_key = %key,
"rejecting oversized buffered message"
);
metrics::record_failed(topic, group.as_deref(), metrics::FailReason::Oversize);
if on_failure == SequenceFailure::FailAll {
poisoned_keys.insert(key.to_string());
while let Some(pd) = pending.pop_front() {
let rh = pd.receipt_handle().unwrap_or_default();
router::route_reject(sqs, queue_url, rh, topology).await;
}
pending_deliveries.remove(key);
return;
}
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
continue;
}
let message: T::Message =
match <T::Codec as crate::Codec<T::Message>>::decode(body.as_bytes()) {
Ok(m) => m,
Err(err) => {
error!(
error = %err,
queue_url,
sequence_key = %key,
"failed to deserialize buffered SQS message, rejecting"
);
metrics::record_failed(
topic,
group.as_deref(),
metrics::FailReason::Deserialize,
);
if on_failure == SequenceFailure::FailAll {
poisoned_keys.insert(key.to_string());
while let Some(pd) = pending.pop_front() {
let rh = pd.receipt_handle().unwrap_or_default();
router::route_reject(sqs, queue_url, rh, topology).await;
}
pending_deliveries.remove(key);
return;
}
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
continue;
}
};
let metadata = extract_metadata(&msg);
let rx = spawn_handler_keyed::<T, H>(
handler,
ctx,
message,
metadata,
options.handler_timeout,
completed_tx,
key.to_string(),
Arc::clone(topic),
group.clone(),
);
key_states.insert(
key.to_string(),
KeyState::InFlight {
receipt_handle,
msg: Arc::new(msg),
retry_count,
outcome_rx: rx,
},
);
*in_flight_count += 1;
if pending.is_empty() {
pending_deliveries.remove(key);
}
return;
}
pending_deliveries.remove(key);
}
/// Runs the dead-letter consume loop for one DLQ until shutdown or a receive failure.
///
/// Messages are received in batches of up to 10 with a 5 second wait and processed one at
/// a time: a message that decodes is passed to `handle_dead` together with its dead-letter
/// metadata; a message larger than `DEFAULT_MAX_MESSAGE_SIZE` or one that fails to decode
/// is logged and discarded. Every message is acked afterwards, whatever happened to it.
///
/// # Errors
/// Returns the error produced by `map_sqs_error` when `ReceiveMessage` fails.
async fn consume_dlq_loop<T, H>(
    sqs: &aws_sdk_sqs::Client,
    queue_url: &str,
    original_queue: &str,
    handler: &Arc<H>,
    ctx: &Arc<H::Context>,
    shutdown: &CancellationToken,
) -> Result<()>
where
    T: Topic,
    H: MessageHandler<T>,
{
    info!(queue_url, "DLQ consumer started");
    loop {
        tokio::select! {
            biased;
            _ = shutdown.cancelled() => {
                debug!("shutdown signal received, stopping DLQ consumer on {queue_url}");
                return Ok(());
            }
            result = sqs
                .receive_message()
                .queue_url(queue_url)
                .wait_time_seconds(5)
                .max_number_of_messages(10)
                .message_system_attribute_names(MessageSystemAttributeName::ApproximateReceiveCount)
                .message_attribute_names("All")
                .send() => {
                let output = match result {
                    Ok(output) => output,
                    Err(e) => {
                        metrics::record_backend_error(
                            metrics::BackendLabel::SnsSqs,
                            metrics::BackendErrorKind::Consume,
                        );
                        return Err(map_sqs_error(
                            &format!("SQS ReceiveMessage failed on DLQ {queue_url}"),
                            e,
                        ));
                    }
                };
                let messages = output.messages.unwrap_or_default();
                debug!(queue_url, received = messages.len(), "received messages from DLQ");
                for msg in messages {
                    let receipt_handle = msg.receipt_handle().unwrap_or_default().to_owned();
                    let body = extract_payload(msg.body().unwrap_or_default());
                    let metadata = extract_dead_metadata(&msg, original_queue);
                    if body.len() <= DEFAULT_MAX_MESSAGE_SIZE {
                        match <T::Codec as crate::Codec<T::Message>>::decode(body.as_bytes()) {
                            Ok(message) => {
                                debug!(
                                    queue_url,
                                    delivery_id = %metadata.message.delivery_id,
                                    death_count = metadata.death_count,
                                    "dispatching DLQ message to handle_dead"
                                );
                                handler.handle_dead(message, metadata, ctx.as_ref()).await;
                            }
                            Err(err) => {
                                error!(
                                    error = %err,
                                    delivery_id = %metadata.message.delivery_id,
                                    "failed to deserialize message from DLQ — discarding"
                                );
                            }
                        }
                    } else {
                        warn!(
                            bytes = body.len(),
                            max = DEFAULT_MAX_MESSAGE_SIZE,
                            delivery_id = %metadata.message.delivery_id,
                            "oversized DLQ message — discarding"
                        );
                    }
                    // Handled or discarded, the message is removed from the DLQ.
                    debug!(queue_url, "acking DLQ message");
                    router::route_ack(sqs, queue_url, &receipt_handle).await;
                }
            }
        }
    }
}
impl SqsConsumer {
pub fn run<T, H>(
&self,
handler: H,
ctx: H::Context,
options: crate::ConsumerOptions<Sqs>,
) -> impl Future<Output = Result<()>> + Send
where
T: Topic,
H: MessageHandler<T>,
{
self.run_with_inner::<T, H>(handler, ctx, options.into_inner())
}
pub(crate) fn run_with_inner<T, H>(
&self,
handler: H,
ctx: H::Context,
options: ConsumerOptions,
) -> impl Future<Output = Result<()>> + Send
where
T: Topic,
H: MessageHandler<T>,
{
let client = self.client.clone();
let queue_registry = self.queue_registry.clone();
async move {
let topology = T::topology();
let consumer = SqsConsumer::new(client, queue_registry);
let queue_url = consumer.resolve_queue_url(topology.queue()).await?;
let handler = Arc::new(handler);
let ctx = Arc::new(ctx);
let sqs = consumer.client.sqs().clone();
run_with_reconnect(
&options.shutdown,
topology.queue(),
options.max_reconnect_attempts,
|| {
consume_loop_concurrent::<T, H>(
&sqs, &queue_url, topology, &handler, &ctx, &options,
)
},
)
.await
}
}
pub fn run_fifo<T, H>(
&self,
handler: H,
ctx: H::Context,
options: crate::ConsumerOptions<Sqs>,
) -> impl Future<Output = Result<()>> + Send
where
T: SequencedTopic,
H: MessageHandler<T>,
{
self.run_fifo_with_inner::<T, H>(handler, ctx, options.into_inner())
}
pub async fn run_fifo_until_timeout<T, H, S>(
&self,
handler: H,
ctx: H::Context,
options: crate::ConsumerOptions<Sqs>,
signal: S,
drain_timeout: Duration,
) -> SupervisorOutcome
where
T: SequencedTopic,
H: MessageHandler<T>,
S: Future<Output = ()> + Send + 'static,
{
self.run_fifo_until_timeout_with_inner::<T, H, S>(
handler,
ctx,
options.into_inner(),
signal,
drain_timeout,
)
.await
}
pub(crate) async fn run_fifo_until_timeout_with_inner<T, H, S>(
&self,
handler: H,
ctx: H::Context,
options: ConsumerOptions,
signal: S,
drain_timeout: Duration,
) -> SupervisorOutcome
where
T: SequencedTopic,
H: MessageHandler<T>,
S: Future<Output = ()> + Send + 'static,
{
let shutdown = options.shutdown.clone();
let handles = match self.spawn_fifo_shards::<T, H>(handler, ctx, options).await {
Ok(h) => h,
Err(e) => {
error!(error = %e, "run_fifo_until_timeout: shard spawn failed");
return SupervisorOutcome {
errors: 1,
panics: 0,
timed_out: false,
};
}
};
drive_fifo_until_timeout(handles, shutdown, signal, drain_timeout).await
}
pub(crate) async fn run_fifo_with_inner<T, H>(
&self,
handler: H,
ctx: H::Context,
options: ConsumerOptions,
) -> Result<()>
where
T: SequencedTopic,
H: MessageHandler<T>,
{
let handles = self
.spawn_fifo_shards::<T, H>(handler, ctx, options)
.await?;
for handle in handles {
match handle.await {
Ok(Ok(())) => {}
Ok(Err(e)) => error!("SQS sequenced shard task failed: {e}"),
Err(e) => error!("SQS sequenced shard task panicked: {e}"),
}
}
Ok(())
}
pub(crate) async fn spawn_fifo_shards<T, H>(
&self,
handler: H,
ctx: H::Context,
options: ConsumerOptions,
) -> Result<Vec<tokio::task::JoinHandle<Result<()>>>>
where
T: SequencedTopic,
H: MessageHandler<T>,
{
let topology = T::topology();
let seq = topology
.sequencing()
.ok_or_else(|| ShoveError::Topology("run_fifo requires a sequenced topic".into()))?;
let handler = Arc::new(handler);
let ctx = Arc::new(ctx);
let consumer = SqsConsumer::new(self.client.clone(), self.queue_registry.clone());
let on_failure = seq.on_failure();
let mut handles = Vec::new();
for i in 0..seq.routing_shards() {
let shard_queue_name = format!("{}-seq-{i}", topology.queue());
let shard_queue_url = consumer.resolve_queue_url(&shard_queue_name).await?;
let sqs = consumer.client.sqs().clone();
let h = handler.clone();
let c = ctx.clone();
let opts = options.clone();
handles.push(tokio::spawn(async move {
run_sequenced_shard::<T, H>(
&sqs,
&shard_queue_url,
&shard_queue_name,
topology,
&h,
&c,
&opts,
on_failure,
)
.await
}));
}
Ok(handles)
}
pub fn run_dlq<T, H>(
&self,
handler: H,
ctx: H::Context,
) -> impl Future<Output = Result<()>> + Send
where
T: Topic,
H: MessageHandler<T>,
{
let client = self.client.clone();
let queue_registry = self.queue_registry.clone();
async move {
let topology = T::topology();
let dlq = topology.dlq().ok_or_else(|| {
ShoveError::Topology(format!(
"topic '{}' has no DLQ configured",
topology.queue()
))
})?;
let consumer = SqsConsumer::new(client, queue_registry);
let queue_url = consumer.resolve_queue_url(dlq).await?;
let handler = Arc::new(handler);
let ctx = Arc::new(ctx);
let sqs = consumer.client.sqs().clone();
let shutdown = consumer.client.shutdown_token();
run_with_reconnect(&shutdown, dlq, None, || {
consume_dlq_loop::<T, H>(
&sqs,
&queue_url,
topology.queue(),
&handler,
&ctx,
&shutdown,
)
})
.await
}
}
}