use aws_sdk_sqs::config::http::HttpResponse;
use aws_sdk_sqs::error::{ProvideErrorMetadata, SdkError};
use aws_sdk_sqs::types::{Message, MessageAttributeValue, MessageSystemAttributeName};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use tokio::sync::oneshot::error::TryRecvError;
use tokio::sync::{Notify, mpsc, oneshot};
use tokio_util::sync::CancellationToken;
use tracing::{debug, error, info, warn};
use crate::backend::ConsumerOptionsInner as ConsumerOptions;
use crate::backends::sns::client::SnsClient;
use crate::backends::sns::router;
use crate::backends::sns::topology::QueueRegistry;
use crate::error::{Result, ShoveError};
use crate::handler::MessageHandler;
use crate::metadata::{DeadMessageMetadata, MessageMetadata};
use crate::outcome::Outcome;
use crate::retry::Backoff;
use crate::topic::{SequencedTopic, Topic};
use crate::topology::{QueueTopology, SequenceFailure};
use crate::{DEFAULT_MAX_MESSAGE_SIZE, Sqs};
/// Map an SQS SDK error to a [`ShoveError`], classifying it as transient or permanent.
///
/// Transient conditions become [`ShoveError::Connection`], which the reconnect loops treat
/// as retryable:
/// - timeouts, dispatch failures and unparseable responses;
/// - throttling service errors;
/// - service errors answered with an HTTP 5xx status.
///
/// Request-construction failures and every other service error become
/// [`ShoveError::Topology`].
///
/// The resulting message is `"{context}: {e}"`.
fn map_sqs_error<E>(context: &str, e: SdkError<E, HttpResponse>) -> ShoveError
where
    E: std::fmt::Debug + std::fmt::Display + ProvideErrorMetadata,
{
    /// Service error codes that signal throttling rather than a permanent failure.
    const THROTTLING_CODES: &[&str] = &[
        "RequestThrottled",
        "Throttling",
        "KMS.ThrottlingException",
        "OverLimit",
    ];

    let transient = match &e {
        SdkError::ConstructionFailure(_) => false,
        SdkError::ServiceError(se) => {
            let throttled = ProvideErrorMetadata::code(se.err())
                .is_some_and(|code| THROTTLING_CODES.contains(&code));
            // A 5xx is a failure on the service side; it must not permanently stop the
            // consumer the way a missing queue or denied access should.
            let server_error = (500..600).contains(&se.raw().status().as_u16());
            throttled || server_error
        }
        // TimeoutError, DispatchFailure, ResponseError and any other variant.
        _ => true,
    };

    let message = format!("{context}: {e}");
    if transient {
        ShoveError::Connection(message)
    } else {
        ShoveError::Topology(message)
    }
}
/// SQS consumer entry point.
///
/// Queue URLs are looked up in a shared [`QueueRegistry`], and the SQS client used for
/// receiving comes from the wrapped [`SnsClient`] (via `client.sqs()`).
#[derive(Clone)]
pub struct SqsConsumer {
// Client wrapper providing the SQS client and the shutdown token used by the DLQ consumer.
client: SnsClient,
// Registry mapping declared queue names to their SQS queue URLs.
queue_registry: Arc<QueueRegistry>,
}
impl SqsConsumer {
    /// Create a consumer from a client and the registry of declared queue URLs.
    pub fn new(client: SnsClient, queue_registry: Arc<QueueRegistry>) -> Self {
        Self { client, queue_registry }
    }

    /// Return the SQS queue URL registered under `queue_name`.
    ///
    /// Fails with [`ShoveError::Topology`] when the queue has not been declared yet.
    async fn resolve_queue_url(&self, queue_name: &str) -> Result<String> {
        match self.queue_registry.get(queue_name).await {
            Some(url) => Ok(url),
            None => Err(ShoveError::Topology(format!(
                "no SQS queue URL registered for '{queue_name}'. Declare topology first."
            ))),
        }
    }
}
/// Build the handler-facing [`MessageMetadata`] for a received SQS message.
///
/// The delivery id is the SQS message id (empty when absent). A message counts as
/// redelivered whenever the router reports a non-zero retry count.
fn extract_metadata(msg: &Message) -> MessageMetadata {
    let retry_count = router::get_retry_count(msg);
    let delivery_id = msg.message_id().map(str::to_owned).unwrap_or_default();
    let headers = router::extract_message_attributes(msg);
    MessageMetadata {
        retry_count,
        delivery_id,
        redelivered: retry_count > 0,
        headers,
    }
}
/// Minimal view of an SNS notification envelope found in an SQS message body.
///
/// Only the `Type` and `Message` JSON fields are read; other envelope fields are ignored
/// (no `deny_unknown_fields`).
#[derive(serde::Deserialize)]
struct SnsEnvelope {
// Envelope kind; `extract_payload` only unwraps envelopes whose value is "Notification".
#[serde(rename = "Type")]
notification_type: String,
// The published payload, carried as a JSON string inside the envelope.
#[serde(rename = "Message")]
message: String,
}
fn extract_payload(body: &str) -> std::borrow::Cow<'_, str> {
if let Ok(envelope) = serde_json::from_str::<SnsEnvelope>(body)
&& envelope.notification_type == "Notification"
{
return std::borrow::Cow::Owned(envelope.message);
}
std::borrow::Cow::Borrowed(body)
}
/// Build the metadata handed to `handle_dead` for a message read from a dead-letter queue.
///
/// The death count mirrors the message's retry count. The reason is always recorded as
/// `max_receives_exceeded`, and `queue_name` is recorded as the original queue.
fn extract_dead_metadata(msg: &Message, queue_name: &str) -> DeadMessageMetadata {
    let message = extract_metadata(msg);
    DeadMessageMetadata {
        death_count: message.retry_count,
        reason: Some("max_receives_exceeded".into()),
        original_queue: Some(queue_name.to_owned()),
        message,
    }
}
/// Repeatedly run the consume loop produced by `f`, reconnecting after retryable errors.
///
/// Between attempts it sleeps for the next delay from `Backoff::default()`, giving up early
/// if `shutdown` is cancelled.
///
/// Returns `Ok(())` when `f` completes successfully or shutdown is requested. The first
/// non-retryable error is returned unchanged.
async fn run_with_reconnect<F, Fut>(
    shutdown: &CancellationToken,
    queue: &str,
    mut f: F,
) -> Result<()>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<()>>,
{
    let mut backoff = Backoff::default();
    loop {
        let Err(e) = f().await else {
            return Ok(());
        };
        if !e.is_retryable() {
            return Err(e);
        }
        if shutdown.is_cancelled() {
            return Ok(());
        }
        let delay = backoff.next().expect("backoff is infinite");
        warn!("consumer error on {queue}: {e}. Reconnecting in {delay:?}");
        let stop_requested = tokio::select! {
            _ = tokio::time::sleep(delay) => false,
            _ = shutdown.cancelled() => true,
        };
        if stop_requested {
            return Ok(());
        }
    }
}
/// Run `handler` on one message, optionally bounded by `timeout`.
///
/// When the timeout elapses, the handler future is dropped, a warning is logged, and the
/// message is reported as [`Outcome::Retry`].
async fn invoke_handler<T: Topic, H: MessageHandler<T>>(
    handler: &H,
    ctx: &H::Context,
    message: T::Message,
    metadata: MessageMetadata,
    timeout: Option<Duration>,
) -> Outcome {
    let handling = handler.handle(message, metadata, ctx);
    let Some(duration) = timeout else {
        return handling.await;
    };
    match tokio::time::timeout(duration, handling).await {
        Ok(outcome) => outcome,
        Err(_elapsed) => {
            warn!("handler exceeded timeout ({duration:?}), retrying message");
            Outcome::Retry
        }
    }
}
/// Completion guard owned by a handler task spawned from [`spawn_handler`].
///
/// On drop it first releases the outcome sender (if no outcome was sent) and only then wakes
/// the consumer loop. The drop happens on normal return, while a panicked handler task is
/// torn down, or when the task is cancelled unpolled.
///
/// A woken loop therefore always sees either the outcome or `TryRecvError::Closed`, never
/// `Empty`. It also means a panicking handler can no longer leave the loop parked on
/// `Notify` forever.
struct HandlerCompletion {
    outcome_tx: Option<oneshot::Sender<Outcome>>,
    notify: Arc<Notify>,
}

impl HandlerCompletion {
    /// Publish the handler's outcome; the wake-up follows when the guard is dropped.
    fn finish(&mut self, outcome: Outcome) {
        if let Some(tx) = self.outcome_tx.take() {
            // The receiver only disappears if the consumer loop already exited.
            let _ = tx.send(outcome);
        }
    }
}

impl Drop for HandlerCompletion {
    fn drop(&mut self) {
        // Close the channel before notifying so the woken loop never observes `Empty`.
        drop(self.outcome_tx.take());
        self.notify.notify_one();
    }
}

/// Spawn `handler` for one message on the Tokio runtime and return a receiver for its outcome.
///
/// `notify` is signalled exactly once when the task ends, even if the handler panics. In
/// that case the receiver yields `Closed` instead of an outcome. `timeout` is applied by
/// [`invoke_handler`].
fn spawn_handler<T, H>(
    handler: &Arc<H>,
    ctx: &Arc<H::Context>,
    message: T::Message,
    metadata: MessageMetadata,
    timeout: Option<Duration>,
    notify: &Arc<Notify>,
) -> oneshot::Receiver<Outcome>
where
    T: Topic,
    H: MessageHandler<T>,
{
    let (outcome_tx, outcome_rx) = oneshot::channel();
    let handler = Arc::clone(handler);
    let ctx = Arc::clone(ctx);
    // Built outside the task so the wake-up also fires if the task is dropped before polling.
    let mut completion = HandlerCompletion {
        outcome_tx: Some(outcome_tx),
        notify: Arc::clone(notify),
    };
    tokio::spawn(async move {
        let outcome =
            invoke_handler::<T, H>(&handler, ctx.as_ref(), message, metadata, timeout).await;
        completion.finish(outcome);
    });
    outcome_rx
}
/// A message dispatched to a handler by the concurrent consumer, awaiting its outcome.
///
/// Holds what `route_outcome` needs to settle the delivery once the handler reports back.
struct PendingMessage {
// Receipt handle of this delivery, passed to the router when settling it.
receipt_handle: String,
// Raw SQS body as received (not the SNS-unwrapped payload), used by retry/defer routing.
body: String,
// Message attributes as received, used by retry/defer routing.
message_attributes: HashMap<String, MessageAttributeValue>,
// Retry count derived from the message at receive time.
retry_count: u32,
// Yields the handler's outcome; `Closed` if the handler task ended without sending one.
outcome_rx: oneshot::Receiver<Outcome>,
}
/// Per-request batch cap used for both receives and ack flushes.
///
/// SQS accepts at most 10 entries per `ReceiveMessage` / `DeleteMessageBatch` call.
const SQS_BATCH_LIMIT: usize = 10;

/// Hand the buffered ack receipt handles to `router::route_ack_batch` in one call.
///
/// No-op when nothing is buffered. `reason` is only used for logging.
async fn flush_ack_batch(
    sqs: &aws_sdk_sqs::Client,
    queue_url: &str,
    pending_acks: &mut Vec<String>,
    reason: &'static str,
) {
    if pending_acks.is_empty() {
        return;
    }
    let batch_size = pending_acks.len();
    debug!(queue_url, batch_size, reason, "flushing ack batch");
    router::route_ack_batch(sqs, queue_url, std::mem::take(pending_acks)).await;
}

/// Bring the concurrent consumer to a clean stop.
///
/// The steps are:
/// 1. Flush buffered acks.
/// 2. Requeue messages that were received but never dispatched.
/// 3. Wait for every in-flight handler and route its outcome. A handler that ended
///    without an outcome is retried.
async fn settle_concurrent(
    sqs: &aws_sdk_sqs::Client,
    queue_url: &str,
    topology: &'static QueueTopology,
    pending_acks: &mut Vec<String>,
    local_buffer: &mut VecDeque<Message>,
    in_flight: VecDeque<PendingMessage>,
) {
    flush_ack_batch(sqs, queue_url, pending_acks, "settling consumer").await;
    for msg in local_buffer.drain(..) {
        if let Some(rh) = msg.receipt_handle() {
            router::route_requeue(sqs, queue_url, rh).await;
        }
    }
    for pending in in_flight {
        let outcome = pending.outcome_rx.await.unwrap_or(Outcome::Retry);
        route_outcome(
            sqs,
            queue_url,
            &pending.receipt_handle,
            &pending.body,
            &pending.message_attributes,
            outcome,
            topology,
            pending.retry_count,
        )
        .await;
    }
}

/// Unordered SQS consume loop running up to `prefetch_count` handlers concurrently.
///
/// Messages are short-polled in batches of `receive_batch_size` (default and cap:
/// [`SQS_BATCH_LIMIT`]) and validated before dispatch. A message is rejected if:
/// - it has reached `max_retries`;
/// - its payload is oversized;
/// - its payload fails to deserialize.
///
/// Outcomes are settled in receive order. Acks are batched; every other outcome goes
/// through [`route_outcome`].
///
/// On shutdown, and before returning a `ReceiveMessage` error, buffered messages are
/// requeued and in-flight handlers are awaited and settled. Completed work is therefore
/// not redelivered and processed twice after a reconnect. Both paths wait on handlers; a
/// handler without `handler_timeout` can delay them indefinitely.
async fn consume_loop_concurrent<T, H>(
    sqs: &aws_sdk_sqs::Client,
    queue_url: &str,
    topology: &'static QueueTopology,
    handler: &Arc<H>,
    ctx: &Arc<H::Context>,
    options: &ConsumerOptions,
) -> Result<()>
where
    T: Topic,
    H: MessageHandler<T>,
{
    let notify = Arc::new(Notify::new());
    // A limit of zero would never dispatch anything and then wait on `notify` for a
    // completion that can never happen; always allow at least one handler.
    let max_in_flight = (options.prefetch_count as usize).max(1);
    let receive_batch = if options.receive_batch_size > 0 {
        (options.receive_batch_size as usize).min(SQS_BATCH_LIMIT)
    } else {
        SQS_BATCH_LIMIT
    };
    let mut in_flight: VecDeque<PendingMessage> = VecDeque::with_capacity(max_in_flight);
    let mut local_buffer: VecDeque<Message> = VecDeque::with_capacity(receive_batch);
    let mut pending_acks: Vec<String> = Vec::with_capacity(SQS_BATCH_LIMIT);
    info!(queue_url, max_in_flight, receive_batch, "SQS consumer started");
    loop {
        // Settle finished handlers strictly in receive order; stop at the first still running.
        while let Some(front) = in_flight.front_mut() {
            let outcome = match front.outcome_rx.try_recv() {
                Ok(outcome) => outcome,
                Err(TryRecvError::Empty) => break,
                // The sender was dropped without an outcome: the handler task panicked.
                Err(TryRecvError::Closed) => {
                    warn!(queue_url, "handler task panicked, retrying message");
                    Outcome::Retry
                }
            };
            let msg = in_flight.pop_front().expect("front_mut just returned Some");
            if matches!(outcome, Outcome::Ack) {
                debug!(queue_url, receipt_handle = %msg.receipt_handle, "message acked (pending flush)");
                pending_acks.push(msg.receipt_handle);
                if pending_acks.len() >= SQS_BATCH_LIMIT {
                    flush_ack_batch(sqs, queue_url, &mut pending_acks, "full ack batch").await;
                }
                continue;
            }
            // Flush earlier acks first so they are not held back behind this route.
            flush_ack_batch(sqs, queue_url, &mut pending_acks, "before non-ack outcome").await;
            debug!(queue_url, ?outcome, "message handled");
            route_outcome(
                sqs,
                queue_url,
                &msg.receipt_handle,
                &msg.body,
                &msg.message_attributes,
                outcome,
                topology,
                msg.retry_count,
            )
            .await;
        }

        options.processing.store(
            !in_flight.is_empty() || !local_buffer.is_empty(),
            Ordering::Release,
        );

        if options.shutdown.is_cancelled() {
            debug!(
                "shutdown signal, requeueing {} buffered, draining {} in-flight on {queue_url}",
                local_buffer.len(),
                in_flight.len()
            );
            settle_concurrent(
                sqs,
                queue_url,
                topology,
                &mut pending_acks,
                &mut local_buffer,
                std::mem::take(&mut in_flight),
            )
            .await;
            return Ok(());
        }

        // Dispatch buffered messages while there is handler capacity.
        while in_flight.len() < max_in_flight {
            let Some(msg) = local_buffer.pop_front() else {
                break;
            };
            let receipt_handle = msg.receipt_handle().unwrap_or_default().to_string();
            let retry_count = router::get_retry_count(&msg);
            if retry_count >= options.max_retries {
                warn!(
                    queue_url,
                    retry_count,
                    max_retries = options.max_retries,
                    "message exceeded max retries, rejecting"
                );
                router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
                continue;
            }
            let body = extract_payload(msg.body().unwrap_or_default());
            if let Err(e) = options.validate_payload_message_size(body.len()) {
                warn!(error = %e, queue_url, "rejecting oversized message");
                router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
                continue;
            }
            let message: T::Message = match serde_json::from_str(&body) {
                Ok(m) => m,
                Err(err) => {
                    error!(error = %err, queue_url, "failed to deserialize SQS message, rejecting");
                    router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
                    continue;
                }
            };
            let metadata = extract_metadata(&msg);
            debug!(
                queue_url,
                message_id = %metadata.delivery_id,
                retry_count = metadata.retry_count,
                "dispatching message to handler"
            );
            let outcome_rx = spawn_handler::<T, H>(
                handler,
                ctx,
                message,
                metadata,
                options.handler_timeout,
                &notify,
            );
            in_flight.push_back(PendingMessage {
                receipt_handle,
                body: msg.body().unwrap_or_default().to_string(),
                message_attributes: msg.message_attributes().cloned().unwrap_or_default(),
                retry_count,
                outcome_rx,
            });
            options.processing.store(true, Ordering::Relaxed);
        }

        // Buffer drained with spare capacity: poll SQS for more work.
        if local_buffer.is_empty() && in_flight.len() < max_in_flight {
            flush_ack_batch(sqs, queue_url, &mut pending_acks, "before poll").await;
            let received = sqs
                .receive_message()
                .queue_url(queue_url)
                .wait_time_seconds(0)
                .max_number_of_messages(receive_batch as i32)
                .message_system_attribute_names(MessageSystemAttributeName::ApproximateReceiveCount)
                .message_attribute_names("All")
                .send()
                .await;
            let receive_result = match received {
                Ok(output) => output,
                Err(e) => {
                    let err =
                        map_sqs_error(&format!("SQS ReceiveMessage failed on {queue_url}"), e);
                    // Settle work that already ran so it is not redelivered after reconnect.
                    settle_concurrent(
                        sqs,
                        queue_url,
                        topology,
                        &mut pending_acks,
                        &mut local_buffer,
                        std::mem::take(&mut in_flight),
                    )
                    .await;
                    return Err(err);
                }
            };
            let msgs = receive_result.messages.unwrap_or_default();
            if msgs.is_empty() {
                debug!(queue_url, "queue empty, backing off 500ms");
                tokio::select! {
                    biased;
                    _ = options.shutdown.cancelled() => {}
                    _ = tokio::time::sleep(Duration::from_millis(500)) => {}
                }
            } else {
                debug!(queue_url, received = msgs.len(), "received messages from SQS");
            }
            local_buffer.extend(msgs);
            continue;
        }

        // At capacity: wait until some handler finishes.
        if in_flight.len() >= max_in_flight {
            notify.notified().await;
        }
    }
}
/// Forward a handler outcome for one delivery to the matching router action.
///
/// The mapping is:
/// - `Ack` → `router::route_ack`
/// - `Reject` → `router::route_reject`
/// - `Retry` → `router::route_retry`
/// - `Defer` → `router::route_defer`
///
/// `Retry` and `Defer` also receive the original body, message attributes, topology and
/// retry count.
#[allow(clippy::too_many_arguments)]
async fn route_outcome(
    sqs: &aws_sdk_sqs::Client,
    queue_url: &str,
    receipt_handle: &str,
    body: &str,
    message_attributes: &HashMap<String, MessageAttributeValue>,
    outcome: Outcome,
    topology: &'static QueueTopology,
    retry_count: u32,
) {
    match outcome {
        Outcome::Ack => router::route_ack(sqs, queue_url, receipt_handle).await,
        Outcome::Reject => router::route_reject(sqs, queue_url, receipt_handle, topology).await,
        Outcome::Defer => {
            router::route_defer(
                sqs,
                queue_url,
                receipt_handle,
                body,
                message_attributes,
                topology,
                retry_count,
            )
            .await;
        }
        Outcome::Retry => {
            router::route_retry(
                sqs,
                queue_url,
                receipt_handle,
                body,
                message_attributes,
                topology,
                retry_count,
            )
            .await;
        }
    }
}
/// Per-sequence-key state tracked by the sequenced (FIFO) consumer.
enum KeyState {
// A handler is running for the key's current message; the fields describe that delivery.
InFlight {
receipt_handle: String,
body: String,
message_attributes: HashMap<String, MessageAttributeValue>,
retry_count: u32,
// Yields the handler's outcome once the task finishes.
outcome_rx: oneshot::Receiver<Outcome>,
},
// The key's last message was sent back for retry. New deliveries (retry count 0) are
// buffered until a delivery with a non-zero retry count arrives and clears this state.
AwaitingRetry,
}
/// Read the `MessageGroupId` system attribute, which the sequenced consumer uses as the
/// ordering key. Returns `None` when the attribute is absent.
fn extract_sequence_key(msg: &Message) -> Option<String> {
    let attributes = msg.attributes()?;
    let group_id = attributes.get(&MessageSystemAttributeName::MessageGroupId)?;
    Some(group_id.to_string())
}
/// Completion guard owned by a handler task spawned from [`spawn_handler_keyed`].
///
/// On drop it releases the outcome sender (if no outcome was sent) and then announces the
/// key on the completion channel. The drop happens on normal return, while a panicked task
/// is torn down, or on cancellation. The consumer loop therefore always learns that the key
/// is done, and when it looks finds either the outcome or `TryRecvError::Closed`.
///
/// Previously a panicking handler never announced its key, which left the key `InFlight`
/// forever and stranded its buffered deliveries.
struct KeyedCompletion {
    outcome_tx: Option<oneshot::Sender<Outcome>>,
    completed_tx: mpsc::UnboundedSender<String>,
    key: String,
}

impl KeyedCompletion {
    /// Publish the handler's outcome; the key is announced when the guard is dropped.
    fn finish(&mut self, outcome: Outcome) {
        if let Some(tx) = self.outcome_tx.take() {
            // The receiver only disappears if the consumer loop already exited.
            let _ = tx.send(outcome);
        }
    }
}

impl Drop for KeyedCompletion {
    fn drop(&mut self) {
        // Close the channel before announcing so the loop never observes `Empty` for this key.
        drop(self.outcome_tx.take());
        let _ = self.completed_tx.send(std::mem::take(&mut self.key));
    }
}

/// Spawn `handler` for the current message of sequence key `key`.
///
/// Returns a receiver for the handler's outcome. `key` is sent on `completed_tx` exactly
/// once when the task ends, after the outcome has been published (or the outcome channel
/// closed, if the handler panicked).
fn spawn_handler_keyed<T, H>(
    handler: &Arc<H>,
    ctx: &Arc<H::Context>,
    message: T::Message,
    metadata: MessageMetadata,
    timeout: Option<Duration>,
    completed_tx: &mpsc::UnboundedSender<String>,
    key: String,
) -> oneshot::Receiver<Outcome>
where
    T: Topic,
    H: MessageHandler<T>,
{
    let (outcome_tx, outcome_rx) = oneshot::channel();
    let handler = Arc::clone(handler);
    let ctx = Arc::clone(ctx);
    // Built outside the task so the key is announced even if the task is dropped unpolled.
    let mut completion = KeyedCompletion {
        outcome_tx: Some(outcome_tx),
        completed_tx: completed_tx.clone(),
        key,
    };
    tokio::spawn(async move {
        let outcome =
            invoke_handler::<T, H>(&handler, ctx.as_ref(), message, metadata, timeout).await;
        completion.finish(outcome);
    });
    outcome_rx
}
#[allow(clippy::too_many_arguments)]
async fn run_sequenced_shard<T, H>(
sqs: &aws_sdk_sqs::Client,
queue_url: &str,
queue_name: &str,
topology: &'static QueueTopology,
handler: &Arc<H>,
ctx: &Arc<H::Context>,
options: &ConsumerOptions,
on_failure: SequenceFailure,
) -> Result<()>
where
T: Topic,
H: MessageHandler<T>,
{
let mut poisoned_keys = HashSet::new();
let mut pending_deliveries: HashMap<String, VecDeque<Message>> = HashMap::new();
let mut backoff = Backoff::default();
loop {
match consume_loop_sequenced::<T, H>(
sqs,
queue_url,
topology,
handler,
ctx,
options,
on_failure,
&mut poisoned_keys,
&mut pending_deliveries,
)
.await
{
Ok(()) => {
for (_key, msgs) in pending_deliveries.drain() {
for msg in msgs {
let rh = msg.receipt_handle().unwrap_or_default();
router::route_requeue(sqs, queue_url, rh).await;
}
}
return Ok(());
}
Err(e) => {
if options.shutdown.is_cancelled() {
pending_deliveries.clear();
return Ok(());
}
pending_deliveries.clear();
let delay = backoff.next().expect("backoff is infinite");
warn!("consumer error on {queue_name}: {e}. Reconnecting in {delay:?}");
tokio::select! {
_ = tokio::time::sleep(delay) => {}
_ = options.shutdown.cancelled() => return Ok(()),
}
}
}
}
}
/// Graceful-shutdown tail of [`consume_loop_sequenced`].
///
/// First, each in-flight handler is awaited and its outcome routed. A handler that ended
/// without an outcome counts as `Retry`; a `Reject` poisons the key under `FailAll`.
///
/// Then every locally buffered delivery is requeued, because it was never processed and
/// must not be dead-lettered. The exception is a key poisoned under
/// `SequenceFailure::FailAll`: those deliveries are rejected, as the running loop would.
async fn finish_sequenced_shutdown(
    sqs: &aws_sdk_sqs::Client,
    queue_url: &str,
    topology: &'static QueueTopology,
    on_failure: SequenceFailure,
    poisoned_keys: &mut HashSet<String>,
    key_states: HashMap<String, KeyState>,
    pending_deliveries: &mut HashMap<String, VecDeque<Message>>,
) {
    for (key, state) in key_states {
        let KeyState::InFlight {
            receipt_handle,
            retry_count,
            outcome_rx,
            ..
        } = state
        else {
            continue;
        };
        let outcome = outcome_rx.await.unwrap_or(Outcome::Retry);
        debug!(
            queue_url,
            sequence_key = %key,
            ?outcome,
            "draining in-flight message on shutdown"
        );
        match outcome {
            Outcome::Ack => router::route_ack(sqs, queue_url, &receipt_handle).await,
            Outcome::Reject => {
                if on_failure == SequenceFailure::FailAll {
                    info!(sequence_key = %key, queue_url, "poisoning sequence key (FailAll)");
                    poisoned_keys.insert(key);
                }
                router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
            }
            Outcome::Retry | Outcome::Defer => {
                if matches!(outcome, Outcome::Defer) {
                    warn!(
                        queue_url,
                        sequence_key = %key,
                        "Defer is not supported on sequenced (FIFO) consumers, treating as Retry"
                    );
                }
                router::route_retry_fifo(sqs, queue_url, &receipt_handle, topology, retry_count)
                    .await;
            }
        }
    }
    for (key, msgs) in pending_deliveries.drain() {
        let poisoned = on_failure == SequenceFailure::FailAll && poisoned_keys.contains(&key);
        for msg in msgs {
            let rh = msg.receipt_handle().unwrap_or_default();
            if poisoned {
                router::route_reject(sqs, queue_url, rh, topology).await;
            } else {
                router::route_requeue(sqs, queue_url, rh).await;
            }
        }
    }
}

/// Buffer `msg` behind the message currently holding `key`.
///
/// If the key's buffer already holds `options.max_pending_per_key` deliveries, the message
/// is rejected instead. `reason` is only used for logging.
#[allow(clippy::too_many_arguments)]
async fn buffer_sequenced_delivery(
    sqs: &aws_sdk_sqs::Client,
    queue_url: &str,
    topology: &'static QueueTopology,
    options: &ConsumerOptions,
    pending_deliveries: &mut HashMap<String, VecDeque<Message>>,
    key: String,
    msg: Message,
    reason: &'static str,
) {
    let buffered = pending_deliveries.get(&key).map_or(0, VecDeque::len);
    if let Some(limit) = options.max_pending_per_key
        && buffered >= limit
    {
        warn!(
            sequence_key = %key,
            queue_url,
            limit,
            "per-key pending buffer full, rejecting"
        );
        let receipt_handle = msg.receipt_handle().unwrap_or_default();
        router::route_reject(sqs, queue_url, receipt_handle, topology).await;
        return;
    }
    debug!(sequence_key = %key, queue_url, reason, "buffering delivery locally");
    pending_deliveries.entry(key).or_default().push_back(msg);
}

/// Let `key` make progress after one of its deliveries was rejected in the receive path
/// without reaching a handler.
///
/// If a handler is still running for the key, or the key is awaiting a retry and this was
/// a fresh delivery (retry count 0), nothing changes. Otherwise the rejected message was the
/// key's head of line: a pending `AwaitingRetry` is cleared and the key's buffered
/// deliveries are drained.
///
/// Before this, such a rejection left the key stuck in `AwaitingRetry` with its buffered
/// deliveries stranded.
#[allow(clippy::too_many_arguments)]
async fn release_key_after_reject<T, H>(
    sqs: &aws_sdk_sqs::Client,
    queue_url: &str,
    key: &str,
    retry_count: u32,
    handler: &Arc<H>,
    ctx: &Arc<H::Context>,
    options: &ConsumerOptions,
    on_failure: SequenceFailure,
    topology: &'static QueueTopology,
    poisoned_keys: &mut HashSet<String>,
    completed_tx: &mpsc::UnboundedSender<String>,
    key_states: &mut HashMap<String, KeyState>,
    in_flight_count: &mut usize,
    pending_deliveries: &mut HashMap<String, VecDeque<Message>>,
) where
    T: Topic,
    H: MessageHandler<T>,
{
    match key_states.get(key) {
        Some(KeyState::InFlight { .. }) => return,
        Some(KeyState::AwaitingRetry) if retry_count == 0 => return,
        Some(KeyState::AwaitingRetry) => {
            debug!(
                sequence_key = %key,
                queue_url,
                retry_count,
                "rejected retry clears AwaitingRetry"
            );
            key_states.remove(key);
        }
        None => {}
    }
    drain_pending_for_key::<T, H>(
        sqs,
        queue_url,
        key,
        handler,
        ctx,
        options,
        on_failure,
        topology,
        poisoned_keys,
        completed_tx,
        key_states,
        in_flight_count,
        pending_deliveries,
    )
    .await;
}

/// Per-key ordered consume loop for one FIFO shard queue.
///
/// Messages are keyed by their `MessageGroupId`, and at most one message per key is with a
/// handler at any time. Later deliveries for a busy key are buffered in
/// `pending_deliveries` and released one at a time by [`drain_pending_for_key`].
///
/// Outcomes are routed as follows:
/// - `Ack` / `Reject` release the key; `Reject` also poisons it under
///   `SequenceFailure::FailAll`.
/// - `Retry` / `Defer` send the message back with `route_retry_fifo` and park the key in
///   `AwaitingRetry` until a redelivery arrives.
///
/// `prefetch_count` bounds the number of handlers running (minimum 1). Returns `Ok(())`
/// after a graceful shutdown; `ReceiveMessage` failures are returned as errors.
#[allow(clippy::too_many_arguments)]
async fn consume_loop_sequenced<T, H>(
    sqs: &aws_sdk_sqs::Client,
    queue_url: &str,
    topology: &'static QueueTopology,
    handler: &Arc<H>,
    ctx: &Arc<H::Context>,
    options: &ConsumerOptions,
    on_failure: SequenceFailure,
    poisoned_keys: &mut HashSet<String>,
    pending_deliveries: &mut HashMap<String, VecDeque<Message>>,
) -> Result<()>
where
    T: Topic,
    H: MessageHandler<T>,
{
    // Zero would disable receiving entirely and leave the shard idle forever.
    let prefetch = (options.prefetch_count as usize).max(1);
    let (completed_tx, mut completed_rx) = mpsc::unbounded_channel::<String>();
    let mut key_states: HashMap<String, KeyState> = HashMap::new();
    let mut in_flight_count: usize = 0;
    let receive = |max_messages: i32| {
        sqs.receive_message()
            .queue_url(queue_url)
            .wait_time_seconds(5)
            .max_number_of_messages(max_messages)
            .message_system_attribute_names(MessageSystemAttributeName::ApproximateReceiveCount)
            .message_system_attribute_names(MessageSystemAttributeName::MessageGroupId)
            .message_attribute_names("All")
            .send()
    };
    // The long poll is kept alive across loop iterations instead of being recreated (and
    // dropped) every time a handler completes. Dropping a poll SQS has already answered
    // leaves the returned messages, and their message groups, blocked until the visibility
    // timeout expires.
    let mut pending_receive = None;
    info!(queue_url, prefetch, "sequenced SQS consumer started");
    loop {
        // Settle every handler that has announced completion since the last pass.
        while let Ok(key) = completed_rx.try_recv() {
            let (receipt_handle, retry_count, outcome) = match key_states.remove(&key) {
                Some(KeyState::InFlight {
                    receipt_handle,
                    body,
                    message_attributes,
                    retry_count,
                    mut outcome_rx,
                }) => match outcome_rx.try_recv() {
                    Ok(outcome) => (receipt_handle, retry_count, outcome),
                    Err(TryRecvError::Closed) => {
                        warn!(queue_url, sequence_key = %key, "handler task panicked, retrying");
                        (receipt_handle, retry_count, Outcome::Retry)
                    }
                    Err(TryRecvError::Empty) => {
                        // Completion is announced only after the outcome is sent, so this
                        // should not happen; keep the message in flight rather than lose it.
                        key_states.insert(
                            key,
                            KeyState::InFlight {
                                receipt_handle,
                                body,
                                message_attributes,
                                retry_count,
                                outcome_rx,
                            },
                        );
                        continue;
                    }
                },
                Some(KeyState::AwaitingRetry) => {
                    key_states.insert(key, KeyState::AwaitingRetry);
                    continue;
                }
                None => continue,
            };
            debug!(queue_url, sequence_key = %key, ?outcome, "message handled (sequenced)");
            in_flight_count -= 1;
            match outcome {
                Outcome::Ack => router::route_ack(sqs, queue_url, &receipt_handle).await,
                Outcome::Reject => {
                    if on_failure == SequenceFailure::FailAll {
                        info!(sequence_key = %key, queue_url, "poisoning sequence key (FailAll)");
                        poisoned_keys.insert(key.clone());
                    }
                    router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
                }
                Outcome::Retry | Outcome::Defer => {
                    if matches!(outcome, Outcome::Defer) {
                        warn!(
                            queue_url,
                            sequence_key = %key,
                            "Defer is not supported on sequenced (FIFO) consumers, treating as Retry"
                        );
                    }
                    router::route_retry_fifo(sqs, queue_url, &receipt_handle, topology, retry_count)
                        .await;
                    // Later deliveries for this key stay buffered until the retry returns.
                    key_states.insert(key, KeyState::AwaitingRetry);
                    continue;
                }
            }
            drain_pending_for_key::<T, H>(
                sqs,
                queue_url,
                &key,
                handler,
                ctx,
                options,
                on_failure,
                topology,
                poisoned_keys,
                &completed_tx,
                &mut key_states,
                &mut in_flight_count,
                pending_deliveries,
            )
            .await;
        }

        options
            .processing
            .store(in_flight_count > 0, Ordering::Relaxed);

        // In-flight count never grows while a poll is outstanding, so the requested batch
        // size stays within the remaining capacity.
        if in_flight_count < prefetch && pending_receive.is_none() {
            let max_messages = prefetch.saturating_sub(in_flight_count).min(10) as i32;
            pending_receive = Some(Box::pin(receive(max_messages)));
        }

        tokio::select! {
            biased;
            _ = options.shutdown.cancelled() => {
                debug!(
                    "shutdown signal, draining {} in-flight messages on {queue_url}",
                    in_flight_count
                );
                drop(pending_receive.take());
                finish_sequenced_shutdown(
                    sqs,
                    queue_url,
                    topology,
                    on_failure,
                    poisoned_keys,
                    std::mem::take(&mut key_states),
                    pending_deliveries,
                )
                .await;
                return Ok(());
            }
            Some(key) = completed_rx.recv() => {
                // Hand the key back so the completion pass at the top of the loop settles it.
                let _ = completed_tx.send(key);
            }
            result = async {
                pending_receive
                    .as_mut()
                    .expect("branch is enabled only while a receive is pending")
                    .await
            }, if pending_receive.is_some() => {
                pending_receive = None;
                let messages = result
                    .map_err(|e| {
                        map_sqs_error(&format!("SQS ReceiveMessage failed on {queue_url}"), e)
                    })?
                    .messages
                    .unwrap_or_default();
                debug!(queue_url, received = messages.len(), "received messages from SQS (sequenced)");
                for msg in messages {
                    let receipt_handle = msg.receipt_handle().unwrap_or_default().to_string();
                    let retry_count = router::get_retry_count(&msg);
                    let Some(seq_key) = extract_sequence_key(&msg) else {
                        warn!(queue_url, "message missing MessageGroupId, rejecting");
                        router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
                        continue;
                    };

                    if on_failure == SequenceFailure::FailAll && poisoned_keys.contains(&seq_key) {
                        warn!(
                            sequence_key = %seq_key,
                            queue_url,
                            "message with poisoned sequence key, rejecting"
                        );
                        router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
                        release_key_after_reject::<T, H>(
                            sqs,
                            queue_url,
                            &seq_key,
                            retry_count,
                            handler,
                            ctx,
                            options,
                            on_failure,
                            topology,
                            poisoned_keys,
                            &completed_tx,
                            &mut key_states,
                            &mut in_flight_count,
                            pending_deliveries,
                        )
                        .await;
                        continue;
                    }

                    if retry_count >= options.max_retries {
                        warn!(
                            queue_url,
                            retry_count,
                            max_retries = options.max_retries,
                            "message exceeded max retries, rejecting"
                        );
                        if on_failure == SequenceFailure::FailAll {
                            info!(
                                sequence_key = %seq_key,
                                queue_url,
                                "poisoning sequence key (FailAll)"
                            );
                            poisoned_keys.insert(seq_key.clone());
                            if let Some(pending) = pending_deliveries.remove(&seq_key) {
                                for pd in pending {
                                    let rh = pd.receipt_handle().unwrap_or_default();
                                    router::route_reject(sqs, queue_url, rh, topology).await;
                                }
                            }
                        }
                        router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
                        release_key_after_reject::<T, H>(
                            sqs,
                            queue_url,
                            &seq_key,
                            retry_count,
                            handler,
                            ctx,
                            options,
                            on_failure,
                            topology,
                            poisoned_keys,
                            &completed_tx,
                            &mut key_states,
                            &mut in_flight_count,
                            pending_deliveries,
                        )
                        .await;
                        continue;
                    }

                    match key_states.get(&seq_key) {
                        Some(KeyState::InFlight { .. }) => {
                            buffer_sequenced_delivery(
                                sqs,
                                queue_url,
                                topology,
                                options,
                                pending_deliveries,
                                seq_key,
                                msg,
                                "key in-flight",
                            )
                            .await;
                            continue;
                        }
                        Some(KeyState::AwaitingRetry) if retry_count == 0 => {
                            buffer_sequenced_delivery(
                                sqs,
                                queue_url,
                                topology,
                                options,
                                pending_deliveries,
                                seq_key,
                                msg,
                                "key awaiting retry",
                            )
                            .await;
                            continue;
                        }
                        Some(KeyState::AwaitingRetry) => {
                            debug!(
                                sequence_key = %seq_key,
                                queue_url,
                                retry_count,
                                "returning retry clears AwaitingRetry"
                            );
                            key_states.remove(&seq_key);
                        }
                        None => {}
                    }

                    // This message is now the key's head of line.
                    let body = extract_payload(msg.body().unwrap_or_default());
                    let decoded = match options.validate_payload_message_size(body.len()) {
                        Err(e) => {
                            warn!(
                                error = %e,
                                queue_url,
                                sequence_key = %seq_key,
                                "rejecting oversized message"
                            );
                            None
                        }
                        Ok(_) => match serde_json::from_str::<T::Message>(&body) {
                            Ok(message) => Some(message),
                            Err(err) => {
                                error!(
                                    error = %err,
                                    queue_url,
                                    sequence_key = %seq_key,
                                    "failed to deserialize SQS message, rejecting"
                                );
                                None
                            }
                        },
                    };
                    let Some(message) = decoded else {
                        if on_failure == SequenceFailure::FailAll {
                            poisoned_keys.insert(seq_key.clone());
                        }
                        router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
                        release_key_after_reject::<T, H>(
                            sqs,
                            queue_url,
                            &seq_key,
                            retry_count,
                            handler,
                            ctx,
                            options,
                            on_failure,
                            topology,
                            poisoned_keys,
                            &completed_tx,
                            &mut key_states,
                            &mut in_flight_count,
                            pending_deliveries,
                        )
                        .await;
                        continue;
                    };

                    let metadata = extract_metadata(&msg);
                    debug!(
                        queue_url,
                        sequence_key = %seq_key,
                        retry_count,
                        "dispatching sequenced message to handler"
                    );
                    let outcome_rx = spawn_handler_keyed::<T, H>(
                        handler,
                        ctx,
                        message,
                        metadata,
                        options.handler_timeout,
                        &completed_tx,
                        seq_key.clone(),
                    );
                    key_states.insert(
                        seq_key,
                        KeyState::InFlight {
                            receipt_handle,
                            body: msg.body().unwrap_or_default().to_string(),
                            message_attributes: msg
                                .message_attributes()
                                .cloned()
                                .unwrap_or_default(),
                            retry_count,
                            outcome_rx,
                        },
                    );
                    in_flight_count += 1;
                    options.processing.store(true, Ordering::Relaxed);
                }
            }
        }
    }
}
#[allow(clippy::too_many_arguments)]
async fn drain_pending_for_key<T, H>(
sqs: &aws_sdk_sqs::Client,
queue_url: &str,
key: &str,
handler: &Arc<H>,
ctx: &Arc<H::Context>,
options: &ConsumerOptions,
on_failure: SequenceFailure,
topology: &'static QueueTopology,
poisoned_keys: &mut HashSet<String>,
completed_tx: &mpsc::UnboundedSender<String>,
key_states: &mut HashMap<String, KeyState>,
in_flight_count: &mut usize,
pending_deliveries: &mut HashMap<String, VecDeque<Message>>,
) where
T: Topic,
H: MessageHandler<T>,
{
if on_failure == SequenceFailure::FailAll && poisoned_keys.contains(key) {
if let Some(pending) = pending_deliveries.remove(key) {
for pd in pending {
let rh = pd.receipt_handle().unwrap_or_default();
router::route_reject(sqs, queue_url, rh, topology).await;
}
}
return;
}
let Some(pending) = pending_deliveries.get_mut(key) else {
return;
};
while let Some(msg) = pending.pop_front() {
let receipt_handle = msg.receipt_handle().unwrap_or_default().to_string();
let retry_count = router::get_retry_count(&msg);
if retry_count >= options.max_retries {
warn!(
queue_url,
sequence_key = %key,
retry_count,
"buffered message exceeded max retries, rejecting"
);
if on_failure == SequenceFailure::FailAll {
poisoned_keys.insert(key.to_string());
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
while let Some(pd) = pending.pop_front() {
let rh = pd.receipt_handle().unwrap_or_default();
router::route_reject(sqs, queue_url, rh, topology).await;
}
pending_deliveries.remove(key);
return;
}
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
continue;
}
let body = extract_payload(msg.body().unwrap_or_default());
if let Err(e) = options.validate_payload_message_size(body.len()) {
warn!(
error = %e,
queue_url,
sequence_key = %key,
"rejecting oversized buffered message"
);
if on_failure == SequenceFailure::FailAll {
poisoned_keys.insert(key.to_string());
while let Some(pd) = pending.pop_front() {
let rh = pd.receipt_handle().unwrap_or_default();
router::route_reject(sqs, queue_url, rh, topology).await;
}
pending_deliveries.remove(key);
return;
}
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
continue;
}
let message: T::Message = match serde_json::from_str(&body) {
Ok(m) => m,
Err(err) => {
error!(
error = %err,
queue_url,
sequence_key = %key,
"failed to deserialize buffered SQS message, rejecting"
);
if on_failure == SequenceFailure::FailAll {
poisoned_keys.insert(key.to_string());
while let Some(pd) = pending.pop_front() {
let rh = pd.receipt_handle().unwrap_or_default();
router::route_reject(sqs, queue_url, rh, topology).await;
}
pending_deliveries.remove(key);
return;
}
router::route_reject(sqs, queue_url, &receipt_handle, topology).await;
continue;
}
};
let metadata = extract_metadata(&msg);
let rx = spawn_handler_keyed::<T, H>(
handler,
ctx,
message,
metadata,
options.handler_timeout,
completed_tx,
key.to_string(),
);
key_states.insert(
key.to_string(),
KeyState::InFlight {
receipt_handle,
body: msg.body().unwrap_or_default().to_string(),
message_attributes: msg.message_attributes().cloned().unwrap_or_default(),
retry_count,
outcome_rx: rx,
},
);
*in_flight_count += 1;
if pending.is_empty() {
pending_deliveries.remove(key);
}
return;
}
pending_deliveries.remove(key);
}
async fn consume_dlq_loop<T, H>(
sqs: &aws_sdk_sqs::Client,
queue_url: &str,
original_queue: &str,
handler: &Arc<H>,
ctx: &Arc<H::Context>,
shutdown: &CancellationToken,
) -> Result<()>
where
T: Topic,
H: MessageHandler<T>,
{
info!(queue_url, "DLQ consumer started");
loop {
tokio::select! {
biased;
_ = shutdown.cancelled() => {
debug!("shutdown signal received, stopping DLQ consumer on {queue_url}");
return Ok(());
}
result = sqs
.receive_message()
.queue_url(queue_url)
.wait_time_seconds(5)
.max_number_of_messages(10)
.message_system_attribute_names(MessageSystemAttributeName::ApproximateReceiveCount)
.message_attribute_names("All")
.send() => {
let output = result.map_err(|e| {
map_sqs_error(&format!("SQS ReceiveMessage failed on DLQ {queue_url}"), e)
})?;
let messages = output.messages.unwrap_or_default();
debug!(queue_url, received = messages.len(), "received messages from DLQ");
for msg in messages {
let receipt_handle = msg.receipt_handle().unwrap_or_default().to_string();
let body = extract_payload(msg.body().unwrap_or_default());
let metadata = extract_dead_metadata(&msg, original_queue);
if body.len() > DEFAULT_MAX_MESSAGE_SIZE {
warn!(
bytes = body.len(),
max = DEFAULT_MAX_MESSAGE_SIZE,
delivery_id = %metadata.message.delivery_id,
"oversized DLQ message — discarding"
);
} else {
match serde_json::from_str::<T::Message>(&body) {
Err(err) => {
error!(
error = %err,
delivery_id = %metadata.message.delivery_id,
"failed to deserialize message from DLQ — discarding"
);
}
Ok(message) => {
debug!(
queue_url,
delivery_id = %metadata.message.delivery_id,
death_count = metadata.death_count,
"dispatching DLQ message to handle_dead"
);
handler.handle_dead(message, metadata, ctx.as_ref()).await;
}
}
}
debug!(queue_url, "acking DLQ message");
router::route_ack(sqs, queue_url, &receipt_handle).await;
}
}
}
}
}
impl SqsConsumer {
pub fn run<T, H>(
&self,
handler: H,
ctx: H::Context,
options: crate::ConsumerOptions<Sqs>,
) -> impl Future<Output = Result<()>> + Send
where
T: Topic,
H: MessageHandler<T>,
{
self.run_with_inner::<T, H>(handler, ctx, options.into_inner())
}
pub(crate) fn run_with_inner<T, H>(
&self,
handler: H,
ctx: H::Context,
options: ConsumerOptions,
) -> impl Future<Output = Result<()>> + Send
where
T: Topic,
H: MessageHandler<T>,
{
let client = self.client.clone();
let queue_registry = self.queue_registry.clone();
async move {
let topology = T::topology();
let consumer = SqsConsumer::new(client, queue_registry);
let queue_url = consumer.resolve_queue_url(topology.queue()).await?;
let handler = Arc::new(handler);
let ctx = Arc::new(ctx);
let sqs = consumer.client.sqs().clone();
run_with_reconnect(&options.shutdown, topology.queue(), || {
consume_loop_concurrent::<T, H>(
&sqs, &queue_url, topology, &handler, &ctx, &options,
)
})
.await
}
}
pub fn run_fifo<T, H>(
&self,
handler: H,
ctx: H::Context,
options: crate::ConsumerOptions<Sqs>,
) -> impl Future<Output = Result<()>> + Send
where
T: SequencedTopic,
H: MessageHandler<T>,
{
self.run_fifo_with_inner::<T, H>(handler, ctx, options.into_inner())
}
pub(crate) fn run_fifo_with_inner<T, H>(
&self,
handler: H,
ctx: H::Context,
options: ConsumerOptions,
) -> impl Future<Output = Result<()>> + Send
where
T: SequencedTopic,
H: MessageHandler<T>,
{
let client = self.client.clone();
let queue_registry = self.queue_registry.clone();
async move {
let topology = T::topology();
let seq = topology.sequencing().ok_or_else(|| {
ShoveError::Topology("run_fifo requires a sequenced topic".into())
})?;
let handler = Arc::new(handler);
let ctx = Arc::new(ctx);
let consumer = SqsConsumer::new(client, queue_registry);
let on_failure = seq.on_failure();
let mut handles = Vec::new();
for i in 0..seq.routing_shards() {
let shard_queue_name = format!("{}-seq-{i}", topology.queue());
let shard_queue_url = consumer.resolve_queue_url(&shard_queue_name).await?;
let sqs = consumer.client.sqs().clone();
let h = handler.clone();
let c = ctx.clone();
let opts = options.clone();
handles.push(tokio::spawn(async move {
run_sequenced_shard::<T, H>(
&sqs,
&shard_queue_url,
&shard_queue_name,
topology,
&h,
&c,
&opts,
on_failure,
)
.await
}));
}
for handle in handles {
if let Err(e) = handle.await {
error!("shard consumer task panicked: {e}");
}
}
Ok(())
}
}
pub fn run_dlq<T, H>(
&self,
handler: H,
ctx: H::Context,
) -> impl Future<Output = Result<()>> + Send
where
T: Topic,
H: MessageHandler<T>,
{
let client = self.client.clone();
let queue_registry = self.queue_registry.clone();
async move {
let topology = T::topology();
let dlq = topology.dlq().ok_or_else(|| {
ShoveError::Topology(format!(
"topic '{}' has no DLQ configured",
topology.queue()
))
})?;
let consumer = SqsConsumer::new(client, queue_registry);
let queue_url = consumer.resolve_queue_url(dlq).await?;
let handler = Arc::new(handler);
let ctx = Arc::new(ctx);
let sqs = consumer.client.sqs().clone();
let shutdown = consumer.client.shutdown_token();
run_with_reconnect(&shutdown, dlq, || {
consume_dlq_loop::<T, H>(
&sqs,
&queue_url,
topology.queue(),
&handler,
&ctx,
&shutdown,
)
})
.await
}
}
}