use std::collections::{BTreeSet, HashMap};
use std::future::Future;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
use bytes::Bytes;
use rdkafka::ClientConfig;
use rdkafka::consumer::{CommitMode, Consumer as RdkafkaConsumer, StreamConsumer};
use rdkafka::error::{KafkaError, KafkaResult};
use rdkafka::message::{BorrowedMessage, Header, Headers, Message, OwnedHeaders};
use rdkafka::{Offset, TopicPartitionList};
use tokio::sync::{Semaphore, mpsc};
use tokio_util::sync::CancellationToken;
use crate::backend::ConsumerOptionsInner as ConsumerOptions;
use crate::consumer::validate_message_size;
use crate::consumer_supervisor::{SupervisorOutcome, drive_fifo_until_timeout};
use crate::error::Result;
use crate::handler::MessageHandler;
use crate::metadata::{DeadMessageMetadata, MessageMetadata};
use crate::metrics;
use crate::outcome::Outcome;
use crate::retry::Backoff;
use crate::topic::{SequencedTopic, Topic};
use crate::topology::QueueTopology;
use crate::{HoldQueue, Kafka, ShoveError};
#[cfg(feature = "kafka-msk-iam")]
use super::msk_iam::MskIamContext;
use super::client::KafkaClient;
use super::constants::{
DEATH_COUNT_HEADER, DEATH_REASON_HEADER, FETCH_MIN_BYTES, FETCH_WAIT_MAX_MS,
MAX_POLL_INTERVAL_MS, MAX_PUBLISH_ATTEMPTS, MESSAGE_ID_HEADER, ORIGINAL_QUEUE_HEADER,
RETRY_COUNT_HEADER, SESSION_TIMEOUT_MS,
};
use super::consumer_group::KafkaAutoOffsetReset;
/// Tracks out-of-order completions for one partition and works out how far
/// the commit position may advance.
struct PartitionTracker {
    /// Lowest offset that has not yet been folded into a commit position.
    next_to_commit: i64,
    /// Finished offsets at or above `next_to_commit` that are still waiting
    /// for an earlier offset to finish.
    completed: BTreeSet<i64>,
}

impl PartitionTracker {
    /// Creates a tracker whose commit position starts at `first_offset`.
    fn new(first_offset: i64) -> Self {
        Self {
            next_to_commit: first_offset,
            completed: BTreeSet::new(),
        }
    }

    /// Records that `offset` has finished processing.
    ///
    /// Offsets below the current commit position are ignored: the commit
    /// position has already moved past them and `drain_committable` would
    /// never remove them, so keeping them would only grow the set.
    fn mark_complete(&mut self, offset: i64) {
        if offset >= self.next_to_commit {
            self.completed.insert(offset);
        }
    }

    /// Advances the commit position across the contiguous run of completed
    /// offsets starting at `next_to_commit`.
    ///
    /// Returns the new position (one past the last contiguous completed
    /// offset) when it moved, or `None` when nothing new can be committed.
    fn drain_committable(&mut self) -> Option<i64> {
        let mut next = self.next_to_commit;
        while self.completed.remove(&next) {
            next += 1;
        }
        if next > self.next_to_commit {
            self.next_to_commit = next;
            Some(next)
        } else {
            None
        }
    }
}
/// Per-topic collection of [`PartitionTracker`]s, keyed by partition id.
struct OffsetTracker {
    /// Topic name used when building commit lists.
    topic: String,
    /// One tracker per partition that has delivered at least one message.
    partitions: HashMap<i32, PartitionTracker>,
}

impl OffsetTracker {
    /// Creates an empty tracker for `topic`.
    fn new(topic: String) -> Self {
        let partitions = HashMap::new();
        Self { topic, partitions }
    }

    /// Registers `partition` on first sight, seeding its commit position with
    /// `offset`. Later calls for the same partition leave the tracker as is.
    fn track_received(&mut self, partition: i32, offset: i64) {
        let _ = self
            .partitions
            .entry(partition)
            .or_insert_with(|| PartitionTracker::new(offset));
    }

    /// Marks `offset` as finished on `partition`; a partition that was never
    /// registered is silently ignored.
    fn mark_complete(&mut self, partition: i32, offset: i64) {
        let Some(state) = self.partitions.get_mut(&partition) else {
            return;
        };
        state.mark_complete(offset);
    }

    /// Builds a commit list holding every partition whose commit position
    /// advanced, or `None` when no partition made progress.
    fn drain_committable(&mut self) -> Option<TopicPartitionList> {
        let topic = self.topic.as_str();
        let mut pending: Option<TopicPartitionList> = None;
        for (partition, state) in self.partitions.iter_mut() {
            let Some(commit_offset) = state.drain_committable() else {
                continue;
            };
            // A failure to add the entry is deliberately ignored.
            pending
                .get_or_insert_with(TopicPartitionList::new)
                .add_partition_offset(topic, *partition, Offset::Offset(commit_offset))
                .ok();
        }
        pending
    }
}
fn extract_string_headers(msg: &BorrowedMessage<'_>) -> Arc<HashMap<String, String>> {
let mut out = HashMap::new();
if let Some(headers) = msg.headers() {
for idx in 0..headers.count() {
let header = headers.get(idx);
if let Some(value) = header.value
&& let Ok(s) = std::str::from_utf8(value)
{
out.insert(header.key.to_string(), s.to_string());
}
}
}
Arc::new(out)
}
/// Reads the retry counter from `headers`, falling back to `0` when the
/// header is missing or is not a valid `u32`.
fn get_retry_count(headers: &HashMap<String, String>) -> u32 {
    match headers.get(RETRY_COUNT_HEADER) {
        Some(raw) => raw.parse::<u32>().unwrap_or(0),
        None => 0,
    }
}
/// Assembles the [`MessageMetadata`] handed to a handler from the decoded
/// headers.
///
/// The delivery id comes from the message-id header and is empty when that
/// header is missing. The header map is shared, not copied.
fn build_message_metadata(
    headers: &Arc<HashMap<String, String>>,
    redelivered: bool,
) -> MessageMetadata {
    let delivery_id = match headers.get(MESSAGE_ID_HEADER) {
        Some(id) => id.clone(),
        None => String::new(),
    };
    MessageMetadata {
        retry_count: get_retry_count(headers),
        delivery_id,
        redelivered,
        headers: Arc::clone(headers),
    }
}
/// Assembles the [`DeadMessageMetadata`] for a message read from the DLQ.
///
/// The reason and original queue are taken verbatim from their headers when
/// present. The death count defaults to `0` when missing or unparsable.
fn build_dead_metadata(headers: &Arc<HashMap<String, String>>) -> DeadMessageMetadata {
    let lookup = |name: &str| headers.get(name).cloned();
    let death_count = match headers.get(DEATH_COUNT_HEADER) {
        Some(raw) => raw.parse::<u32>().unwrap_or(0),
        None => 0,
    };
    DeadMessageMetadata {
        message: build_message_metadata(headers, false),
        reason: lookup(DEATH_REASON_HEADER),
        original_queue: lookup(ORIGINAL_QUEUE_HEADER),
        death_count,
    }
}
/// Builds the header set for a republished message.
///
/// Every original header is carried over except the retry counter and the
/// message id. Those two are re-added as `retry_count` and as the original
/// id (empty when absent) followed by `message_id_suffix`.
fn headers_with_retry_count(
    original: &HashMap<String, String>,
    retry_count: u32,
    message_id_suffix: &str,
) -> OwnedHeaders {
    let carried = original.iter().filter(|(name, _)| {
        name.as_str() != RETRY_COUNT_HEADER && name.as_str() != MESSAGE_ID_HEADER
    });
    let out = carried.fold(
        OwnedHeaders::new_with_capacity(original.len() + 2),
        |acc, (name, value)| {
            acc.insert(Header {
                key: name.as_str(),
                value: Some(value.as_bytes()),
            })
        },
    );

    let retry_text = retry_count.to_string();
    let out = out.insert(Header {
        key: RETRY_COUNT_HEADER,
        value: Some(retry_text.as_bytes()),
    });

    let mut new_id = original.get(MESSAGE_ID_HEADER).cloned().unwrap_or_default();
    new_id.push_str(message_id_suffix);
    out.insert(Header {
        key: MESSAGE_ID_HEADER,
        value: Some(new_id.as_bytes()),
    })
}
/// Builds the header set for a message that is being dead-lettered.
///
/// Every original header is carried over except the death-reason,
/// original-queue, death-count and message-id headers, which are rewritten:
///
/// * death reason — `reason`
/// * original queue — `original_queue`
/// * death count — the previous count (0 when missing or unparsable) plus one
/// * message id — the original id (empty when absent) with a `-dlq` suffix
fn headers_for_dlq(
    original: &HashMap<String, String>,
    reason: &str,
    original_queue: &str,
) -> OwnedHeaders {
    let mut headers = OwnedHeaders::new_with_capacity(original.len() + 4);
    for (k, v) in original {
        if k == DEATH_REASON_HEADER
            || k == ORIGINAL_QUEUE_HEADER
            || k == DEATH_COUNT_HEADER
            || k == MESSAGE_ID_HEADER
        {
            continue;
        }
        headers = headers.insert(Header {
            key: k.as_str(),
            value: Some(v.as_bytes()),
        });
    }
    headers = headers.insert(Header {
        key: DEATH_REASON_HEADER,
        value: Some(reason.as_bytes()),
    });
    headers = headers.insert(Header {
        key: ORIGINAL_QUEUE_HEADER,
        value: Some(original_queue.as_bytes()),
    });
    let current_death_count = original
        .get(DEATH_COUNT_HEADER)
        .and_then(|v| v.parse::<u32>().ok())
        .unwrap_or(0);
    // The count comes from a message header, so it can be any u32. Saturate
    // instead of overflowing (panic in debug, wrap to 0 in release).
    let next_death_count = current_death_count.saturating_add(1);
    headers = headers.insert(Header {
        key: DEATH_COUNT_HEADER,
        value: Some(next_death_count.to_string().as_bytes()),
    });
    let original_id = original.get(MESSAGE_ID_HEADER).cloned().unwrap_or_default();
    headers = headers.insert(Header {
        key: MESSAGE_ID_HEADER,
        value: Some(format!("{original_id}-dlq").as_bytes()),
    });
    headers
}
/// Rewrites outcomes that sequenced (FIFO) consumers cannot honour.
///
/// `Defer` is downgraded to `Retry` with a warning; every other outcome is
/// returned unchanged.
fn adjust_outcome_for_fifo(outcome: Outcome) -> Outcome {
    if matches!(outcome, Outcome::Defer) {
        tracing::warn!("Defer is not supported on sequenced consumers — treating as Retry");
        return Outcome::Retry;
    }
    outcome
}
/// Publishes `payload` to the dead-letter topic configured on `topology`.
///
/// When no DLQ is configured the message is dropped with a warning and
/// `Ok(())` is returned. Otherwise the message is published with DLQ headers
/// derived from `headers`, `reason` and the source queue name, making up to
/// `MAX_PUBLISH_ATTEMPTS` attempts.
///
/// # Errors
///
/// Returns whatever error `publish_with_retry` reports.
async fn publish_to_dlq(
    client: &KafkaClient,
    topology: &QueueTopology,
    payload: &[u8],
    key: Option<&[u8]>,
    headers: &HashMap<String, String>,
    reason: &str,
) -> Result<()> {
    let Some(dlq) = topology.dlq() else {
        tracing::warn!(
            queue = topology.queue(),
            "no DLQ configured, message will be discarded"
        );
        return Ok(());
    };
    let dlq_topic = dlq.to_string();
    let source_queue = topology.queue();
    let dlq_headers = headers_for_dlq(headers, reason, source_queue);
    let publish = client.publish_with_retry(
        &dlq_topic,
        key,
        dlq_headers,
        payload,
        MAX_PUBLISH_ATTEMPTS,
        "DLQ publish",
    );
    publish.await
}
type CompletionHandle = Option<(mpsc::Sender<(i32, i64)>, i32, i64)>;
fn signal_completion(handle: CompletionHandle, queue: &str) {
if let Some((tx, partition, offset)) = handle
&& tx.try_send((partition, offset)).is_err()
{
tracing::error!(
queue,
partition,
offset,
"completion channel full — logic bug in offset tracker"
);
}
}
/// Acts on the handler's `outcome` for one message.
///
/// * `Ack` — signals completion.
/// * `Retry` — once `retry_count + 1` reaches `max_retries` the message is
///   dead-lettered; otherwise it is republished to `topic` after a delay with
///   an incremented retry counter. The delay is 1s when there are no hold
///   queues, else the hold queue at index `retry_count` (clamped to the last).
/// * `Reject` — dead-letters the message.
/// * `Defer` — republishes after the first hold queue's delay (1s when there
///   are none) without changing the retry counter.
///
/// `completion` (when present) is used to report the offset as finished.
/// `retry_permit` is handed to the delayed republish on the `Retry`/`Defer`
/// paths and otherwise dropped when this function returns.
///
/// Returns `true` when the message was fully handled from the caller's point
/// of view, `false` when a DLQ publish or an inline republish failed.
#[allow(clippy::too_many_arguments)]
async fn route_outcome(
    client: &KafkaClient,
    topic: &str,
    group: Option<&str>,
    payload: &[u8],
    key: Option<Bytes>,
    headers: &HashMap<String, String>,
    outcome: Outcome,
    topology: &'static QueueTopology,
    retry_count: u32,
    max_retries: u32,
    hold_queues: &[HoldQueue],
    retry_permit: Option<tokio::sync::OwnedSemaphorePermit>,
    completion: CompletionHandle,
    shutdown: CancellationToken,
) -> bool {
    match outcome {
        Outcome::Ack => {
            signal_completion(completion, topic);
            true
        }
        Outcome::Retry => {
            // `retry_count` is parsed from a message header, so it can be any
            // u32. Saturate: a wrapping add would reset the counter to 0 and
            // the message would never reach the DLQ.
            let new_count = retry_count.saturating_add(1);
            if new_count >= max_retries {
                metrics::record_failed(topic, group, metrics::FailReason::MaxRetriesExceeded);
                let dlq_ok = publish_to_dlq(
                    client,
                    topology,
                    payload,
                    key.as_deref(),
                    headers,
                    "max_retries_exceeded",
                )
                .await;
                // NOTE(review): completion is signalled even when the DLQ
                // publish failed, so with a completion handle the offset is
                // still reported as finished — confirm this is intended.
                signal_completion(completion, topic);
                return match dlq_ok {
                    Ok(()) => true,
                    Err(e) => {
                        tracing::error!(error = %e, "failed to publish to DLQ after exhausting retries");
                        false
                    }
                };
            }
            let delay = if hold_queues.is_empty() {
                Duration::from_secs(1)
            } else {
                // Retries beyond the configured hold queues reuse the last one.
                let idx = (retry_count as usize).min(hold_queues.len() - 1);
                hold_queues[idx].delay()
            };
            let retry_headers =
                headers_with_retry_count(headers, new_count, &format!("-r{new_count}"));
            run_delayed_republish(
                client.clone(),
                topic.to_string(),
                key,
                retry_headers,
                payload.to_vec(),
                delay,
                retry_permit,
                completion,
                shutdown,
                "retry republish",
            )
            .await
        }
        Outcome::Reject => {
            metrics::record_failed(topic, group, metrics::FailReason::Rejected);
            let dlq_ok = publish_to_dlq(
                client,
                topology,
                payload,
                key.as_deref(),
                headers,
                "rejected",
            )
            .await;
            // NOTE(review): as on the retry-exhausted path, completion is
            // signalled regardless of the DLQ publish result.
            signal_completion(completion, topic);
            match dlq_ok {
                Ok(()) => true,
                Err(e) => {
                    tracing::error!(error = %e, "failed to publish rejected message to DLQ");
                    false
                }
            }
        }
        Outcome::Defer => {
            let delay = if hold_queues.is_empty() {
                Duration::from_secs(1)
            } else {
                hold_queues[0].delay()
            };
            // A fresh UUID suffix gives every deferral a distinct message id
            // while the retry counter stays the same.
            let defer_headers = headers_with_retry_count(
                headers,
                retry_count,
                &format!("-d{}", uuid::Uuid::new_v4()),
            );
            run_delayed_republish(
                client.clone(),
                topic.to_string(),
                key,
                defer_headers,
                payload.to_vec(),
                delay,
                retry_permit,
                completion,
                shutdown,
                "defer republish",
            )
            .await
        }
    }
}
/// Republishes a message to `topic` after `delay`, unless `shutdown` fires
/// first.
///
/// The mode depends on `completion`:
///
/// * `Some(_)` — the wait and publish run in a spawned task and this function
///   returns `true` at once. The task signals completion only after a
///   successful publish and releases `retry_permit` when it finishes.
/// * `None` — the wait and publish run inline. Returns `true` only when the
///   publish succeeded, `false` on shutdown or publish failure.
///
/// `label` names the operation in logs and is passed to `publish_with_retry`.
#[allow(clippy::too_many_arguments)]
async fn run_delayed_republish(
    client: KafkaClient,
    topic: String,
    key: Option<Bytes>,
    headers: OwnedHeaders,
    payload: Vec<u8>,
    delay: Duration,
    retry_permit: Option<tokio::sync::OwnedSemaphorePermit>,
    completion: CompletionHandle,
    shutdown: CancellationToken,
    label: &'static str,
) -> bool {
    if completion.is_none() {
        // Inline path: the caller waits for the result.
        let interrupted = tokio::select! {
            _ = tokio::time::sleep(delay) => false,
            _ = shutdown.cancelled() => true,
        };
        if interrupted {
            tracing::debug!(
                queue = %topic,
                label,
                "shutdown fired before FIFO republish; skipping — \
                 message will be redelivered on restart"
            );
            drop(retry_permit);
            return false;
        }
        let published = client
            .publish_with_retry(
                &topic,
                key.as_deref(),
                headers,
                &payload,
                MAX_PUBLISH_ATTEMPTS,
                label,
            )
            .await;
        let ok = if let Err(e) = published {
            tracing::error!(
                error = %e,
                label,
                "FIFO delayed republish failed — leaving offset uncommitted for redelivery"
            );
            false
        } else {
            true
        };
        drop(retry_permit);
        return ok;
    }

    // Detached path: the permit and completion handle move into the task.
    tokio::spawn(async move {
        let interrupted = tokio::select! {
            _ = tokio::time::sleep(delay) => false,
            _ = shutdown.cancelled() => true,
        };
        if interrupted {
            tracing::debug!(
                queue = %topic,
                label,
                "shutdown fired before delayed republish; dropping permit — \
                 offset stays uncommitted, message will be redelivered on restart"
            );
            drop(retry_permit);
            return;
        }
        let published = client
            .publish_with_retry(
                &topic,
                key.as_deref(),
                headers,
                &payload,
                MAX_PUBLISH_ATTEMPTS,
                label,
            )
            .await;
        if let Err(e) = published {
            tracing::error!(
                error = %e,
                label,
                "delayed republish failed — leaving offset uncommitted for redelivery"
            );
        } else {
            signal_completion(completion, &topic);
        }
        drop(retry_permit);
    });
    true
}
/// Drives a handler future to an [`Outcome`], guarding against panics and,
/// when `timeout` is set, against overruns.
///
/// A panic or a timeout both yield `Outcome::Retry`; a timeout also records a
/// `Timeout` failure metric. An in-flight guard is held for the whole call,
/// and the consumed-count and processing-duration metrics are recorded for
/// the resulting outcome.
async fn invoke_handler<F>(
    fut: F,
    timeout: Option<Duration>,
    topic: &str,
    group: Option<&str>,
) -> Outcome
where
    F: std::future::Future<Output = Outcome> + Send,
{
    use futures_util::FutureExt;
    use std::panic::AssertUnwindSafe;

    let _inflight = metrics::InflightGuard::from_refs(topic, group);
    let start = std::time::Instant::now();
    let guarded = AssertUnwindSafe(fut).catch_unwind();

    // Outer `Err(limit)` = timed out; inner `Err(_)` = handler panicked.
    let settled = match timeout {
        Some(limit) => tokio::time::timeout(limit, guarded)
            .await
            .map_err(|_elapsed| limit),
        None => Ok(guarded.await),
    };
    let outcome = match settled {
        Ok(Ok(produced)) => produced,
        Ok(Err(_panic)) => {
            tracing::warn!("handler panicked, retrying message");
            Outcome::Retry
        }
        Err(duration) => {
            tracing::warn!("handler timed out after {duration:?}, retrying");
            metrics::record_failed(topic, group, metrics::FailReason::Timeout);
            Outcome::Retry
        }
    };

    let elapsed = start.elapsed().as_secs_f64();
    metrics::record_consumed(topic, group, &outcome);
    metrics::record_processing_duration(topic, group, &outcome, elapsed);
    outcome
}
/// Converts a [`KafkaError`] into a [`ShoveError`], prefixing it with
/// `context`.
///
/// Client configuration or creation failures, fatal consumption errors,
/// cancellation and NUL errors become `ShoveError::Topology`; everything else
/// becomes `ShoveError::Connection`.
fn map_kafka_error(context: &str, e: KafkaError) -> ShoveError {
    let message = format!("{context}: {e}");
    match &e {
        KafkaError::ClientConfig(..)
        | KafkaError::ClientCreation(_)
        | KafkaError::MessageConsumptionFatal(_)
        | KafkaError::Canceled
        | KafkaError::Nul(_) => ShoveError::Topology(message),
        _ => ShoveError::Connection(message),
    }
}
/// A stream consumer that is either a plain rdkafka consumer or, with the
/// `kafka-msk-iam` feature, one built with an MSK IAM client context.
pub(super) enum KafkaStreamConsumer {
    Default(StreamConsumer),
    #[cfg(feature = "kafka-msk-iam")]
    MskIam(StreamConsumer<MskIamContext>),
}

/// Evaluates `$call` with `$inner` bound to whichever concrete consumer the
/// wrapper holds, so each forwarding method is written only once.
macro_rules! on_stream_consumer {
    ($wrapper:expr, $inner:ident => $call:expr) => {
        match $wrapper {
            KafkaStreamConsumer::Default($inner) => $call,
            #[cfg(feature = "kafka-msk-iam")]
            KafkaStreamConsumer::MskIam($inner) => $call,
        }
    };
}

impl KafkaStreamConsumer {
    /// Forwards to the wrapped consumer's `subscribe`.
    pub(super) fn subscribe(&self, topics: &[&str]) -> KafkaResult<()> {
        on_stream_consumer!(self, inner => inner.subscribe(topics))
    }

    /// Forwards to the wrapped consumer's `recv`.
    pub(super) async fn recv(&self) -> KafkaResult<BorrowedMessage<'_>> {
        on_stream_consumer!(self, inner => inner.recv().await)
    }

    /// Forwards to the wrapped consumer's `commit`.
    pub(super) fn commit(&self, tpl: &TopicPartitionList, mode: CommitMode) -> KafkaResult<()> {
        on_stream_consumer!(self, inner => inner.commit(tpl, mode))
    }

    /// Forwards to the wrapped consumer's `commit_message`.
    pub(super) fn commit_message(
        &self,
        msg: &BorrowedMessage<'_>,
        mode: CommitMode,
    ) -> KafkaResult<()> {
        on_stream_consumer!(self, inner => inner.commit_message(msg, mode))
    }
}
fn create_stream_consumer(
mut base: ClientConfig,
group_id: &str,
auto_offset_reset: KafkaAutoOffsetReset,
#[cfg(feature = "kafka-msk-iam")] msk_context: Option<MskIamContext>,
) -> Result<KafkaStreamConsumer> {
let client_id = format!("shove-{}", uuid::Uuid::new_v4().simple());
base.set("group.id", group_id)
.set("client.id", client_id)
.set("partition.assignment.strategy", "cooperative-sticky")
.set("enable.auto.commit", "false")
.set("auto.offset.reset", auto_offset_reset.as_rdkafka_str())
.set("session.timeout.ms", SESSION_TIMEOUT_MS.to_string())
.set("max.poll.interval.ms", MAX_POLL_INTERVAL_MS.to_string())
.set("fetch.min.bytes", FETCH_MIN_BYTES.to_string())
.set("fetch.wait.max.ms", FETCH_WAIT_MAX_MS.to_string());
#[cfg(feature = "kafka-msk-iam")]
if let Some(ctx) = msk_context {
let consumer: StreamConsumer<MskIamContext> = base
.create_with_context(ctx)
.map_err(|e| map_kafka_error("failed to create MSK consumer", e))?;
return Ok(KafkaStreamConsumer::MskIam(consumer));
}
let consumer: StreamConsumer = base
.create()
.map_err(|e| map_kafka_error("failed to create consumer", e))?;
Ok(KafkaStreamConsumer::Default(consumer))
}
/// Calls `f` repeatedly, sleeping with backoff between retryable failures.
///
/// The loop ends when:
///
/// * `f` returns `Ok(())` — returns `Ok(())`;
/// * `f` fails with a non-retryable error — that error is returned;
/// * `shutdown` is cancelled after a failure or during the backoff sleep —
///   returns `Ok(())`;
/// * the failure count reaches `max_reconnect_attempts` (when set) — returns
///   a `ShoveError::Connection` describing the exhaustion.
///
/// `label` identifies the consumer in logs and in the exhaustion error.
async fn run_with_reconnect<F, Fut>(
    shutdown: &CancellationToken,
    label: &str,
    max_reconnect_attempts: Option<u32>,
    mut f: F,
) -> Result<()>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<()>>,
{
    let mut backoff = Backoff::default();
    let mut attempts = 0u32;
    loop {
        let failure = match f().await {
            Ok(()) => return Ok(()),
            Err(failure) => failure,
        };
        if !failure.is_retryable() {
            return Err(failure);
        }
        if shutdown.is_cancelled() {
            return Ok(());
        }

        attempts += 1;
        if let Some(limit) = max_reconnect_attempts {
            if attempts >= limit {
                tracing::error!(
                    label,
                    attempts,
                    error = %failure,
                    "max reconnect attempts reached, giving up"
                );
                return Err(ShoveError::Connection(format!(
                    "consumer on '{}' exhausted {} reconnect attempt(s): {}",
                    label, limit, failure
                )));
            }
        }

        let delay = backoff.next().expect("backoff is infinite");
        tracing::warn!(
            label,
            attempt = attempts,
            ?max_reconnect_attempts,
            error = %failure,
            delay_ms = delay.as_millis() as u64,
            "consumer error, reconnecting"
        );
        // Sleep out the backoff, but wake up at once on shutdown.
        let cancelled = tokio::select! {
            _ = tokio::time::sleep(delay) => false,
            _ = shutdown.cancelled() => true,
        };
        if cancelled {
            return Ok(());
        }
    }
}
/// Kafka consumer front-end; a cloneable wrapper around a [`KafkaClient`].
#[derive(Clone)]
pub struct KafkaConsumer {
    /// Client the consumer's methods clone and use for configuration and
    /// publishing.
    client: KafkaClient,
}

impl KafkaConsumer {
    /// Wraps `client` in a consumer.
    pub fn new(client: KafkaClient) -> Self {
        KafkaConsumer { client }
    }
}
impl KafkaConsumer {
/// Runs the concurrent consumer for topic `T` with the public options type.
///
/// Unwraps `options` to the backend-internal form and delegates to
/// `run_with_inner`.
pub async fn run<T, H>(
    &self,
    handler: H,
    ctx: H::Context,
    options: crate::ConsumerOptions<Kafka>,
) -> Result<()>
where
    T: Topic,
    H: MessageHandler<T>,
{
    let inner = options.into_inner();
    self.run_with_inner::<T, H>(handler, ctx, inner).await
}
/// Runs the concurrent (non-FIFO) consumer loop for topic `T`.
///
/// Each session creates a consumer, subscribes to the topic's queue and then
/// loops: drain finished offsets, commit whatever became contiguous, and
/// either react to shutdown or receive the next message. Every decodable
/// message is handled in its own spawned task, with at most `prefetch_count`
/// tasks holding a semaphore permit at once. Oversized or undecodable
/// messages are sent to the DLQ and marked complete immediately.
///
/// On shutdown the loop waits for all permits to come back, performs a final
/// best-effort commit and returns `Ok(())`.
///
/// # Errors
///
/// Session failures (consumer creation, subscribe, commit, receive) go
/// through `run_with_reconnect`, which retries or returns the error.
pub(crate) async fn run_with_inner<T, H>(
    &self,
    handler: H,
    ctx: H::Context,
    options: ConsumerOptions,
) -> Result<()>
where
    T: Topic,
    H: MessageHandler<T>,
{
    let topology = T::topology();
    let queue = topology.queue();
    // An explicit group id from the options wins over the derived default.
    let group_id = options
        .kafka_group_id
        .as_deref()
        .map(str::to_string)
        .unwrap_or_else(|| super::constants::consumer_group_id(queue));
    let auto_offset_reset = options
        .kafka_auto_offset_reset
        .unwrap_or(KafkaAutoOffsetReset::Earliest);
    let shutdown = options.shutdown.clone();
    let processing = options.processing.clone();
    let max_retries = options.max_retries;
    let prefetch_count = options.prefetch_count;
    let handler_timeout = options.handler_timeout;
    let max_message_size = options.max_message_size;
    let hold_queues = topology.hold_queues();
    let handler = Arc::new(handler);
    let ctx = Arc::new(ctx);
    let client = self.client.clone();
    tracing::info!(
        queue,
        group_id,
        prefetch_count,
        max_retries,
        "Kafka consumer started"
    );
    // Shared across reconnects, so tasks from an earlier session still count
    // against the concurrency limit.
    let semaphore = Arc::new(Semaphore::new(prefetch_count as usize));
    let topic: Arc<str> = Arc::from(queue);
    let group: Option<Arc<str>> = options.consumer_group.clone();
    run_with_reconnect(&shutdown, queue, options.max_reconnect_attempts, || {
        let handler = handler.clone();
        let ctx = ctx.clone();
        let client = client.clone();
        let processing = processing.clone();
        let shutdown = shutdown.clone();
        let group_id = group_id.clone();
        let semaphore = semaphore.clone();
        let topic = topic.clone();
        let group = group.clone();
        async move {
            let consumer = create_stream_consumer(
                client.base_config(),
                &group_id,
                auto_offset_reset,
                #[cfg(feature = "kafka-msk-iam")]
                client.msk_context(),
            )?;
            consumer
                .subscribe(&[queue])
                .map_err(|e| map_kafka_error("failed to subscribe", e))?;
            let queue_owned = queue.to_string();
            let mut tracker = OffsetTracker::new(queue_owned.clone());
            let consumer = Arc::new(consumer);
            // Handler tasks report finished offsets through this channel. A
            // task holds a permit until after it has sent, so at most
            // `prefetch_count` task sends can be pending between two drains.
            let (completion_tx, mut completion_rx) =
                mpsc::channel::<(i32, i64)>(prefetch_count as usize);
            loop {
                // Fold finished offsets into the tracker and commit whatever
                // became contiguous.
                while let Ok((partition, offset)) = completion_rx.try_recv() {
                    tracker.mark_complete(partition, offset);
                }
                if let Some(tpl) = tracker.drain_committable() {
                    consumer
                        .commit(&tpl, CommitMode::Async)
                        .map_err(|e| map_kafka_error("commit failed", e))?;
                }
                tokio::select! {
                    _ = shutdown.cancelled() => {
                        tracing::info!(queue, "shutdown signal received, draining in-flight tasks");
                        // Holding every permit means no handler task is still running.
                        let _ = semaphore.acquire_many(prefetch_count as u32).await;
                        while let Ok((partition, offset)) = completion_rx.try_recv() {
                            tracker.mark_complete(partition, offset);
                        }
                        if let Some(tpl) = tracker.drain_committable() {
                            consumer.commit(&tpl, CommitMode::Async).ok();
                        }
                        return Ok(());
                    }
                    msg_result = consumer.recv() => {
                        let msg = match msg_result {
                            Ok(msg) => msg,
                            Err(e) => {
                                tracing::error!(error = %e, queue, "consumer recv error");
                                return Err(map_kafka_error(
                                    &format!("consumer recv error on {queue}"),
                                    e,
                                ));
                            }
                        };
                        let payload_slice = msg.payload().unwrap_or_default();
                        let headers = extract_string_headers(&msg);
                        let partition = msg.partition();
                        let offset = msg.offset();
                        let key = msg.key().map(Bytes::copy_from_slice);
                        tracker.track_received(partition, offset);
                        metrics::record_message_size(&topic, group.as_deref(), payload_slice.len());
                        if let Err(e) = validate_message_size(payload_slice.len(), max_message_size) {
                            tracing::warn!(
                                error = %e,
                                queue,
                                "rejecting oversized message to DLQ"
                            );
                            metrics::record_failed(
                                &topic,
                                group.as_deref(),
                                metrics::FailReason::Oversize,
                            );
                            if let Err(dlq_err) = publish_to_dlq(
                                &client,
                                topology,
                                payload_slice,
                                key.as_deref(),
                                &headers,
                                &e.to_string(),
                            ).await {
                                tracing::error!(
                                    error = %dlq_err,
                                    "failed to publish oversized message to DLQ"
                                );
                            }
                            // Mark the offset directly. This path holds no
                            // permit, so a `try_send` here could find the
                            // channel already filled by finished tasks, fail,
                            // and leave the partition's commits stalled.
                            tracker.mark_complete(partition, offset);
                            continue;
                        }
                        let payload: T::Message = match <T::Codec as crate::Codec<T::Message>>::decode(payload_slice) {
                            Ok(m) => m,
                            Err(e) => {
                                tracing::error!(
                                    error = %e,
                                    queue,
                                    "failed to deserialize message, sending to DLQ"
                                );
                                metrics::record_failed(
                                    &topic,
                                    group.as_deref(),
                                    metrics::FailReason::Deserialize,
                                );
                                if let Err(dlq_err) = publish_to_dlq(
                                    &client,
                                    topology,
                                    payload_slice,
                                    key.as_deref(),
                                    &headers,
                                    "deserialization_error",
                                ).await {
                                    tracing::error!(
                                        error = %dlq_err,
                                        "failed to publish bad message to DLQ"
                                    );
                                }
                                // Same reasoning as the oversize path above.
                                tracker.mark_complete(partition, offset);
                                continue;
                            }
                        };
                        let payload_bytes = payload_slice.to_vec();
                        let metadata = build_message_metadata(&headers, false);
                        let retry_count = metadata.retry_count;
                        // Waits here while `prefetch_count` handlers are in flight.
                        let permit = semaphore.clone().acquire_owned().await.map_err(|_| {
                            ShoveError::Connection("semaphore closed".to_string())
                        })?;
                        let task_client = client.clone();
                        let task_processing = processing.clone();
                        let task_semaphore = semaphore.clone();
                        let task_prefetch = prefetch_count;
                        let task_tx = completion_tx.clone();
                        let task_topic = topic.clone();
                        let task_handler = handler.clone();
                        let task_ctx = ctx.clone();
                        let task_group = group.clone();
                        let task_shutdown = shutdown.clone();
                        tokio::spawn(async move {
                            task_processing.store(true, Ordering::Release);
                            let outcome = invoke_handler(
                                async move {
                                    task_handler
                                        .handle(payload, metadata, task_ctx.as_ref())
                                        .await
                                },
                                handler_timeout,
                                &task_topic,
                                task_group.as_deref(),
                            )
                            .await;
                            route_outcome(
                                &task_client,
                                &task_topic,
                                task_group.as_deref(),
                                &payload_bytes,
                                key,
                                &headers,
                                outcome,
                                topology,
                                retry_count,
                                max_retries,
                                hold_queues,
                                Some(permit),
                                Some((task_tx, partition, offset)),
                                task_shutdown,
                            )
                            .await;
                            // Clear the busy flag only once every permit is back.
                            if task_semaphore.available_permits() == task_prefetch as usize {
                                task_processing.store(false, Ordering::Release);
                            }
                        });
                    }
                }
            }
        }
    })
    .await
}
/// Runs the sequenced (FIFO) consumer for topic `T` with the public options
/// type.
///
/// Unwraps `options` to the backend-internal form and delegates to
/// `run_fifo_with_inner`.
pub async fn run_fifo<T, H>(
    &self,
    handler: H,
    ctx: H::Context,
    options: crate::ConsumerOptions<Kafka>,
) -> Result<()>
where
    T: SequencedTopic,
    H: MessageHandler<T>,
{
    let inner = options.into_inner();
    self.run_fifo_with_inner::<T, H>(handler, ctx, inner).await
}
pub(crate) async fn run_fifo_with_inner<T, H>(
&self,
handler: H,
ctx: H::Context,
options: ConsumerOptions,
) -> Result<()>
where
T: SequencedTopic,
H: MessageHandler<T>,
{
let handles = self.spawn_fifo_shards::<T, H>(handler, ctx, options)?;
for handle in handles {
match handle.await {
Ok(Ok(())) => {}
Ok(Err(e)) => tracing::error!("Kafka FIFO consumer task failed: {e}"),
Err(e) => tracing::error!("Kafka FIFO consumer task panicked: {e}"),
}
}
Ok(())
}
/// Spawns the task that runs the sequenced (FIFO) consumer for topic `T` and
/// returns its join handle in a one-element vector.
///
/// The task consumes with the group id `"<queue>-fifo"`, handles one message
/// at a time, and commits a message's offset only after its outcome has been
/// routed successfully. Oversized or undecodable messages are sent to the DLQ
/// and committed straight away.
///
/// When routing fails, the session ends instead of moving on to the next
/// message. Kafka commits are positional, so committing a later message would
/// also advance past the failed one and it would never be redelivered.
///
/// # Errors
///
/// Returns `ShoveError::Topology` when the topic has no sequencing config.
/// Errors inside the task are reported through the join handle.
pub(crate) fn spawn_fifo_shards<T, H>(
    &self,
    handler: H,
    ctx: H::Context,
    options: ConsumerOptions,
) -> Result<Vec<tokio::task::JoinHandle<Result<()>>>>
where
    T: SequencedTopic,
    H: MessageHandler<T>,
{
    let topology = T::topology();
    let queue = topology.queue().to_string();
    // Only checked for presence; the config itself is not used here.
    let _seq_config = topology.sequencing().ok_or_else(|| {
        ShoveError::Topology(format!(
            "run_fifo called on {queue} without sequencing config"
        ))
    })?;
    let shutdown = options.shutdown.clone();
    let processing = options.processing.clone();
    let max_retries = options.max_retries;
    let handler_timeout = options.handler_timeout;
    let max_message_size = options.max_message_size;
    let hold_queues = topology.hold_queues();
    let handler = Arc::new(handler);
    let ctx = Arc::new(ctx);
    let client = self.client.clone();
    let group_id = format!("{queue}-fifo");
    let auto_offset_reset = options
        .kafka_auto_offset_reset
        .unwrap_or(KafkaAutoOffsetReset::Earliest);
    let topic: Arc<str> = Arc::from(queue.as_str());
    let group: Option<Arc<str>> = options.consumer_group.clone();
    tracing::info!(queue, group_id, max_retries, "Kafka FIFO consumer started");
    let shard_task = tokio::spawn(async move {
        run_with_reconnect(&shutdown, &queue, options.max_reconnect_attempts, || {
            let handler = handler.clone();
            let ctx = ctx.clone();
            let client = client.clone();
            let shutdown = shutdown.clone();
            let processing = processing.clone();
            let group_id = group_id.clone();
            let queue = queue.clone();
            let topic = topic.clone();
            let group = group.clone();
            async move {
                let consumer = create_stream_consumer(
                    client.base_config(),
                    &group_id,
                    auto_offset_reset,
                    #[cfg(feature = "kafka-msk-iam")]
                    client.msk_context(),
                )?;
                consumer
                    .subscribe(&[queue.as_str()])
                    .map_err(|e| map_kafka_error("failed to subscribe", e))?;
                loop {
                    tokio::select! {
                        _ = shutdown.cancelled() => {
                            tracing::info!(queue, "shutdown signal received, stopping FIFO consumer");
                            return Ok(());
                        }
                        msg_result = consumer.recv() => {
                            let msg = match msg_result {
                                Ok(msg) => msg,
                                Err(e) => {
                                    tracing::error!(error = %e, queue, "FIFO consumer recv error");
                                    return Err(map_kafka_error(
                                        &format!("FIFO consumer recv error on {queue}"),
                                        e,
                                    ));
                                }
                            };
                            let payload_bytes = msg.payload().unwrap_or_default();
                            let headers = extract_string_headers(&msg);
                            let key = msg.key().map(Bytes::copy_from_slice);
                            metrics::record_message_size(&topic, group.as_deref(), payload_bytes.len());
                            if let Err(e) = validate_message_size(payload_bytes.len(), max_message_size) {
                                tracing::warn!(
                                    error = %e,
                                    queue,
                                    "rejecting oversized FIFO message to DLQ"
                                );
                                metrics::record_failed(
                                    &topic,
                                    group.as_deref(),
                                    metrics::FailReason::Oversize,
                                );
                                if let Err(dlq_err) = publish_to_dlq(
                                    &client,
                                    topology,
                                    payload_bytes,
                                    key.as_deref(),
                                    &headers,
                                    &e.to_string(),
                                ).await {
                                    tracing::error!(
                                        error = %dlq_err,
                                        "failed to publish oversized message to DLQ"
                                    );
                                }
                                consumer.commit_message(&msg, CommitMode::Async).ok();
                                continue;
                            }
                            let payload: T::Message = match <T::Codec as crate::Codec<T::Message>>::decode(payload_bytes) {
                                Ok(m) => m,
                                Err(e) => {
                                    tracing::error!(
                                        error = %e,
                                        queue,
                                        "failed to deserialize FIFO message, sending to DLQ"
                                    );
                                    metrics::record_failed(
                                        &topic,
                                        group.as_deref(),
                                        metrics::FailReason::Deserialize,
                                    );
                                    if let Err(dlq_err) = publish_to_dlq(
                                        &client,
                                        topology,
                                        payload_bytes,
                                        key.as_deref(),
                                        &headers,
                                        "deserialization_error",
                                    ).await {
                                        tracing::error!(
                                            error = %dlq_err,
                                            "failed to publish bad message to DLQ"
                                        );
                                    }
                                    consumer.commit_message(&msg, CommitMode::Async).ok();
                                    continue;
                                }
                            };
                            let metadata = build_message_metadata(&headers, false);
                            let retry_count = metadata.retry_count;
                            processing.store(true, Ordering::Release);
                            let handler_clone = handler.clone();
                            let ctx_clone = ctx.clone();
                            let outcome = invoke_handler(
                                async move {
                                    handler_clone
                                        .handle(payload, metadata, ctx_clone.as_ref())
                                        .await
                                },
                                handler_timeout,
                                &topic,
                                group.as_deref(),
                            )
                            .await;
                            let outcome = adjust_outcome_for_fifo(outcome);
                            // No permit and no completion handle: routing runs
                            // inline and its result decides whether to commit.
                            let route_ok = route_outcome(
                                &client,
                                &queue,
                                group.as_deref(),
                                payload_bytes,
                                key,
                                &headers,
                                outcome,
                                topology,
                                retry_count,
                                max_retries,
                                hold_queues,
                                None,
                                None,
                                shutdown.clone(),
                            )
                            .await;
                            processing.store(false, Ordering::Release);
                            if !route_ok {
                                // Do not move on to the next message: its
                                // commit would advance past this uncommitted
                                // offset and the message would be lost.
                                if shutdown.is_cancelled() {
                                    return Ok(());
                                }
                                // Assumes `ShoveError::Connection` is retryable,
                                // so `run_with_reconnect` starts a new session
                                // from the last committed offset — confirm
                                // against `ShoveError::is_retryable`.
                                return Err(ShoveError::Connection(format!(
                                    "FIFO consumer on {queue} failed to route message at partition {} offset {}; restarting session for redelivery",
                                    msg.partition(),
                                    msg.offset(),
                                )));
                            }
                            consumer.commit_message(&msg, CommitMode::Async).ok();
                        }
                    }
                }
            }
        })
        .await
    });
    Ok(vec![shard_task])
}
/// Runs the FIFO consumer for topic `T` under supervision, using the public
/// options type.
///
/// Unwraps `options` and delegates to `run_fifo_until_timeout_with_inner`,
/// passing `signal` and `drain_timeout` through unchanged.
pub async fn run_fifo_until_timeout<T, H, S>(
    &self,
    handler: H,
    ctx: H::Context,
    options: crate::ConsumerOptions<Kafka>,
    signal: S,
    drain_timeout: Duration,
) -> SupervisorOutcome
where
    T: SequencedTopic,
    H: MessageHandler<T>,
    S: Future<Output = ()> + Send + 'static,
{
    let inner = options.into_inner();
    let supervised =
        self.run_fifo_until_timeout_with_inner::<T, H, S>(handler, ctx, inner, signal, drain_timeout);
    supervised.await
}
/// Spawns the FIFO shard tasks for topic `T` and hands them, with the
/// shutdown token, `signal` and `drain_timeout`, to
/// `drive_fifo_until_timeout`.
///
/// When the shards cannot be spawned, the failure is logged and an outcome
/// with one error, no panics and no timeout is returned.
pub(crate) async fn run_fifo_until_timeout_with_inner<T, H, S>(
    &self,
    handler: H,
    ctx: H::Context,
    options: ConsumerOptions,
    signal: S,
    drain_timeout: Duration,
) -> SupervisorOutcome
where
    T: SequencedTopic,
    H: MessageHandler<T>,
    S: Future<Output = ()> + Send + 'static,
{
    // Clone the token first: `options` is consumed by the spawn call.
    let shutdown = options.shutdown.clone();
    let spawned = self.spawn_fifo_shards::<T, H>(handler, ctx, options);
    if let Err(e) = &spawned {
        tracing::error!(error = %e, "run_fifo_until_timeout: shard spawn failed");
    }
    let Ok(handles) = spawned else {
        return SupervisorOutcome {
            errors: 1,
            panics: 0,
            timed_out: false,
        };
    };
    drive_fifo_until_timeout(handles, shutdown, signal, drain_timeout).await
}
/// Runs the dead-letter consumer for topic `T` with default consumer options.
pub async fn run_dlq<T, H>(&self, handler: H, ctx: H::Context) -> Result<()>
where
    T: Topic,
    H: MessageHandler<T>,
{
    let defaults = crate::ConsumerOptions::<Kafka>::new();
    self.run_dlq_with_inner::<T, H>(handler, ctx, defaults.into_inner())
        .await
}
/// Consumes the dead-letter topic of `T`, passing each decodable message to
/// `handler.handle_dead` one at a time.
///
/// Messages above `options.max_message_size` and messages that fail to decode
/// are logged, committed and skipped. The loop stops when the client's
/// shutdown token is cancelled; only `max_message_size` is read from
/// `options`. Reconnects are unlimited.
///
/// # Errors
///
/// Returns `ShoveError::Topology` when the topic has no DLQ configured, or
/// whatever `run_with_reconnect` gives up with.
pub(crate) async fn run_dlq_with_inner<T, H>(
    &self,
    handler: H,
    ctx: H::Context,
    options: ConsumerOptions,
) -> Result<()>
where
    T: Topic,
    H: MessageHandler<T>,
{
    let topology = T::topology();
    let Some(dlq) = topology.dlq() else {
        return Err(ShoveError::Topology(
            "run_dlq requires a DLQ to be configured".into(),
        ));
    };
    let dlq_group_id = super::constants::dlq_consumer_group_id(dlq);
    let shutdown = self.client.shutdown_token();
    let handler = Arc::new(handler);
    let ctx = Arc::new(ctx);
    let client = self.client.clone();
    let max_message_size = options.max_message_size;
    tracing::info!(dlq, group_id = dlq_group_id, "Kafka DLQ consumer started");

    let session = || {
        let handler = Arc::clone(&handler);
        let ctx = Arc::clone(&ctx);
        let session_client = client.clone();
        let shutdown = shutdown.clone();
        let dlq_group_id = dlq_group_id.clone();
        async move {
            let consumer = create_stream_consumer(
                session_client.base_config(),
                &dlq_group_id,
                KafkaAutoOffsetReset::Earliest,
                #[cfg(feature = "kafka-msk-iam")]
                session_client.msk_context(),
            )?;
            if let Err(e) = consumer.subscribe(&[dlq]) {
                return Err(map_kafka_error("failed to subscribe to DLQ", e));
            }
            loop {
                tokio::select! {
                    _ = shutdown.cancelled() => {
                        tracing::info!(dlq, "shutdown signal received, stopping DLQ consumer");
                        return Ok(());
                    }
                    received = consumer.recv() => {
                        let msg = match received {
                            Err(e) => {
                                tracing::error!(error = %e, dlq, "DLQ consumer recv error");
                                return Err(map_kafka_error(
                                    &format!("DLQ consumer recv error on {dlq}"),
                                    e,
                                ));
                            }
                            Ok(msg) => msg,
                        };
                        let payload_bytes = msg.payload().unwrap_or_default();
                        let headers = extract_string_headers(&msg);

                        // Too large: drop it, but commit so it is not re-read.
                        match max_message_size {
                            Some(max) if payload_bytes.len() > max => {
                                tracing::warn!(
                                    bytes = payload_bytes.len(),
                                    max,
                                    dlq,
                                    "oversized DLQ message — discarding"
                                );
                                consumer.commit_message(&msg, CommitMode::Async).ok();
                                continue;
                            }
                            _ => {}
                        }

                        let decoded = <T::Codec as crate::Codec<T::Message>>::decode(payload_bytes);
                        let payload: T::Message = match decoded {
                            Ok(value) => value,
                            Err(e) => {
                                tracing::error!(
                                    error = %e,
                                    dlq,
                                    "failed to deserialize DLQ message, acking anyway"
                                );
                                consumer.commit_message(&msg, CommitMode::Async).ok();
                                continue;
                            }
                        };

                        let metadata = build_dead_metadata(&headers);
                        handler.handle_dead(payload, metadata, ctx.as_ref()).await;
                        let committed = consumer.commit_message(&msg, CommitMode::Async);
                        if let Err(e) = committed {
                            tracing::error!(error = %e, dlq, "failed to commit DLQ message");
                        }
                    }
                }
            }
        }
    };
    run_with_reconnect(&shutdown, dlq, None, session).await
}
}