use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Instant;
use rdkafka::ClientConfig;
use rdkafka::Message as _;
use rdkafka::TopicPartitionList;
use rdkafka::consumer::{Consumer, StreamConsumer};
use tokio::sync::{RwLock, watch};
use crate::config::KafkaIngestConfig;
use crate::errors::OrionError;
use crate::kafka::producer::KafkaProducer;
use crate::metrics;
/// State handed to the background consume loop when it is spawned.
struct ConsumeLoopContext {
/// Shared Kafka consumer; `start_consumer` stores a clone of the same `Arc` in `ConsumerHandle`.
consumer: Arc<StreamConsumer>,
/// Maps a Kafka topic name to the channel name its messages are processed under.
topic_map: HashMap<String, String>,
/// Engine handle passed to `crate::engine::acquire_engine_read` for every message.
engine: Arc<RwLock<Arc<dataflow_rs::Engine>>>,
/// Dead-letter producer; nothing is forwarded unless this and `dlq_topic` are both `Some`.
dlq_producer: Option<Arc<KafkaProducer>>,
/// Dead-letter topic name; nothing is forwarded unless this and `dlq_producer` are both `Some`.
dlq_topic: Option<String>,
/// Upper bound, in milliseconds, on processing a single message.
processing_timeout_ms: u64,
/// Number of permits in the consume loop's backpressure semaphore.
max_inflight: usize,
/// Period of the lag poller in seconds; `0` means no poller is spawned.
lag_poll_interval_secs: u64,
}
use rdkafka::message::Headers;
/// Handle returned by `start_consumer` for controlling the running consumer task.
pub struct ConsumerHandle {
/// Sending `true` asks the consume loop to stop.
shutdown_tx: watch::Sender<bool>,
/// Join handle of the spawned consume loop task.
join_handle: tokio::task::JoinHandle<()>,
/// The consumer used by the consume loop; used here to pause/resume its assignment.
consumer: Arc<StreamConsumer>,
/// Names of the topics the consumer was subscribed to at start-up.
topics: HashSet<String>,
}
impl ConsumerHandle {
pub async fn shutdown(self) {
if let Err(e) = self.shutdown_tx.send(true) {
tracing::error!(error = %e, "Failed to send Kafka consumer shutdown signal");
}
if let Err(e) = self.join_handle.await {
tracing::error!(error = %e, "Kafka consumer task panicked during shutdown");
}
}
pub fn pause(&self) -> Result<(), OrionError> {
let assignment = self
.consumer
.assignment()
.map_err(|e| OrionError::Internal(format!("Failed to get consumer assignment: {e}")))?;
if assignment.count() == 0 {
return Ok(());
}
self.consumer.pause(&assignment).map_err(|e| {
OrionError::Internal(format!("Failed to pause consumer partitions: {e}"))
})?;
Ok(())
}
pub fn resume(&self) -> Result<(), OrionError> {
let assignment = self
.consumer
.assignment()
.map_err(|e| OrionError::Internal(format!("Failed to get consumer assignment: {e}")))?;
if assignment.count() == 0 {
return Ok(());
}
self.consumer.resume(&assignment).map_err(|e| {
OrionError::Internal(format!("Failed to resume consumer partitions: {e}"))
})?;
Ok(())
}
pub fn topics(&self) -> &HashSet<String> {
&self.topics
}
}
/// Creates a Kafka consumer for the configured topics and spawns its consume loop.
///
/// Offsets are committed manually (`enable.auto.commit=false`) and a group without a
/// committed position starts from the earliest offset. A failed broker metadata probe is
/// only logged; it does not prevent start-up.
///
/// # Errors
/// Returns `OrionError::InternalSource` if the consumer cannot be created or the
/// subscription is rejected.
pub fn start_consumer(
    config: &KafkaIngestConfig,
    engine: Arc<RwLock<Arc<dataflow_rs::Engine>>>,
    dlq_producer: Option<Arc<KafkaProducer>>,
    dlq_topic: Option<String>,
) -> Result<ConsumerHandle, OrionError> {
    let mut client_config = ClientConfig::new();
    client_config
        .set("bootstrap.servers", config.brokers.join(","))
        .set("group.id", &config.group_id)
        .set("enable.auto.commit", "false")
        .set("auto.offset.reset", "earliest");
    let consumer: StreamConsumer =
        client_config
            .create()
            .map_err(|e| OrionError::InternalSource {
                context: "Failed to create Kafka consumer".to_string(),
                source: Box::new(e),
            })?;

    // Connectivity probe only: the outcome is logged and start-up proceeds either way.
    match consumer.fetch_metadata(None, std::time::Duration::from_secs(5)) {
        Err(e) => {
            tracing::warn!(
                error = %e,
                "Kafka broker connectivity check failed — consumer will retry on its own"
            );
        }
        Ok(metadata) => {
            tracing::info!(
                brokers = metadata.brokers().len(),
                topics = metadata.topics().len(),
                "Kafka broker connectivity verified"
            );
        }
    }

    // One pass over the mappings builds the routing table, the subscription list and the
    // topic set exposed through the handle.
    let mapping_count = config.topics.len();
    let mut topic_map: HashMap<String, String> = HashMap::with_capacity(mapping_count);
    let mut topic_set: HashSet<String> = HashSet::with_capacity(mapping_count);
    let mut topic_names: Vec<&str> = Vec::with_capacity(mapping_count);
    for mapping in &config.topics {
        topic_map.insert(mapping.topic.clone(), mapping.channel.clone());
        topic_set.insert(mapping.topic.clone());
        topic_names.push(mapping.topic.as_str());
    }

    consumer
        .subscribe(&topic_names)
        .map_err(|e| OrionError::InternalSource {
            context: "Failed to subscribe to Kafka topics".to_string(),
            source: Box::new(e),
        })?;

    let (shutdown_tx, shutdown_rx) = watch::channel(false);
    let consumer = Arc::new(consumer);
    let loop_ctx = ConsumeLoopContext {
        consumer: Arc::clone(&consumer),
        topic_map,
        engine,
        dlq_producer,
        dlq_topic,
        processing_timeout_ms: config.processing_timeout_ms,
        max_inflight: config.max_inflight,
        lag_poll_interval_secs: config.lag_poll_interval_secs,
    };
    let join_handle = tokio::spawn(consume_loop(loop_ctx, shutdown_rx));

    Ok(ConsumerHandle {
        shutdown_tx,
        join_handle,
        consumer,
        topics: topic_set,
    })
}
/// Decodes one Kafka message, runs it through the engine for its mapped channel, and
/// records the outcome.
///
/// Handling by case:
/// - no channel mapping for the topic, or no payload: logged and skipped;
/// - payload is not valid UTF-8 or not valid JSON: logged and forwarded to the DLQ (when
///   one is configured) so the message is not lost once its offset is committed;
/// - processing timeout, processing error, or workflow errors on the message: metrics,
///   error log and DLQ via `report_failure_and_dlq`;
/// - success: message count and duration metrics are recorded.
///
/// This function never fails; the caller commits the offset afterwards regardless.
async fn process_one_kafka_message(
    ctx: &ConsumeLoopContext,
    msg: &rdkafka::message::BorrowedMessage<'_>,
) {
    let topic = msg.topic().to_string();
    let Some(channel) = ctx.topic_map.get(&topic).cloned() else {
        tracing::warn!(topic = %topic, "No channel mapping for topic, skipping");
        return;
    };

    let payload = match msg.payload_view::<str>() {
        Some(Ok(text)) => text,
        Some(Err(e)) => {
            tracing::warn!(
                topic = %topic,
                error = %e,
                "Failed to decode Kafka message payload as UTF-8, skipping"
            );
            // Forward the raw bytes: the DLQ envelope is built with a lossy UTF-8
            // conversion, so undecodable payloads can still be dead-lettered instead of
            // being dropped silently.
            send_to_dlq(
                &ctx.dlq_producer,
                &ctx.dlq_topic,
                &topic,
                msg.payload().unwrap_or_default(),
                &format!("UTF-8 decode error: {e}"),
            )
            .await;
            return;
        }
        None => {
            tracing::warn!(topic = %topic, "Empty Kafka message, skipping");
            return;
        }
    };

    let data: serde_json::Value = match serde_json::from_str(payload) {
        Ok(v) => v,
        Err(e) => {
            tracing::warn!(
                topic = %topic,
                error = %e,
                "Failed to parse Kafka message as JSON, skipping"
            );
            send_to_dlq(
                &ctx.dlq_producer,
                &ctx.dlq_topic,
                &topic,
                payload.as_bytes(),
                &format!("JSON parse error: {e}"),
            )
            .await;
            return;
        }
    };

    // The helper also sets the extracted context as parent of the current span.
    let _parent_cx = extract_kafka_trace_context(msg);
    let start = Instant::now();
    let mut message = dataflow_rs::Message::from_value(&data);
    inject_kafka_metadata(&mut message, &topic, msg);

    let engine_ref = crate::engine::acquire_engine_read(&ctx.engine).await;
    let process_result = tokio::time::timeout(
        std::time::Duration::from_millis(ctx.processing_timeout_ms),
        engine_ref.process_message_for_channel(&channel, &mut message),
    )
    .await;

    // Classify the outcome as (message_status, error_kind, log_msg, dlq_reason);
    // `None` means the message was processed cleanly.
    let failure: Option<(&'static str, &'static str, &'static str, String)> = match process_result
    {
        Err(_) => Some((
            "timeout",
            "kafka_timeout",
            "Kafka message processing timed out",
            format!(
                "Processing timed out after {}ms",
                ctx.processing_timeout_ms
            ),
        )),
        Ok(Err(e)) => Some((
            "error",
            "kafka_processing",
            "Failed to process Kafka message",
            format!("Processing error: {e}"),
        )),
        Ok(Ok(())) if message.has_errors() => {
            let summary = message
                .errors()
                .iter()
                .map(|e| format!("{}: {}", e.code, e.message))
                .collect::<Vec<_>>()
                .join("; ");
            Some((
                "error",
                "kafka_processing",
                "Kafka message processed with workflow errors",
                format!("Workflow errors: {summary}"),
            ))
        }
        Ok(Ok(())) => None,
    };

    match failure {
        Some((message_status, error_kind, log_msg, dlq_reason)) => {
            report_failure_and_dlq(
                ctx,
                FailureReport {
                    channel: &channel,
                    topic: &topic,
                    payload: payload.as_bytes(),
                    message_status,
                    error_kind,
                    log_msg,
                    dlq_reason: &dlq_reason,
                },
            )
            .await;
        }
        None => {
            let duration = start.elapsed().as_secs_f64();
            metrics::record_message(&channel, "ok");
            metrics::record_message_duration(&channel, duration);
            tracing::debug!(
                topic = %topic,
                channel = %channel,
                "Kafka message processed successfully"
            );
        }
    }
}
/// Builds an OpenTelemetry context from the message's headers and sets it as the parent
/// of the current tracing span.
///
/// Only headers whose value decodes as UTF-8 and is present are considered. The extracted
/// context is also returned to the caller.
fn extract_kafka_trace_context(
    msg: &rdkafka::message::BorrowedMessage<'_>,
) -> opentelemetry::Context {
    use opentelemetry::propagation::TextMapPropagator;
    use opentelemetry_sdk::propagation::TraceContextPropagator;
    use tracing_opentelemetry::OpenTelemetrySpanExt;

    /// Adapter exposing the collected header map to the propagator.
    struct KafkaHeaderExtractor(HashMap<String, String>);

    impl opentelemetry::propagation::Extractor for KafkaHeaderExtractor {
        fn get(&self, key: &str) -> Option<&str> {
            self.0.get(key).map(String::as_str)
        }

        fn keys(&self) -> Vec<&str> {
            self.0.keys().map(String::as_str).collect()
        }
    }

    // For repeated keys the last header wins, as with sequential inserts.
    let carrier: HashMap<String, String> = msg
        .headers()
        .map(|headers| {
            (0..headers.count())
                .filter_map(|idx| {
                    let header = headers.get_as::<str>(idx).ok()?;
                    let value = header.value?;
                    Some((header.key.to_string(), value.to_string()))
                })
                .collect()
        })
        .unwrap_or_default();

    let parent = TraceContextPropagator::new().extract(&KafkaHeaderExtractor(carrier));
    let _ = tracing::Span::current().set_parent(parent.clone());
    parent
}
/// Main consume loop: receives messages one at a time, processes each, and commits its
/// offset, until a shutdown signal (`true` on `shutdown_rx`) is observed.
///
/// When `ctx.lag_poll_interval_secs > 0` a lag-polling task is spawned alongside the loop
/// and aborted when the loop exits.
///
/// If the shutdown sender is dropped without ever sending `true`, the loop keeps consuming
/// but stops watching the channel: a closed watch channel makes `changed()` return an
/// error immediately on every call, which would otherwise turn the loop into a busy spin.
async fn consume_loop(ctx: ConsumeLoopContext, mut shutdown_rx: watch::Receiver<bool>) {
    // At least one permit is required: a zero-permit semaphore would block the loop
    // forever before it could ever observe a shutdown signal.
    let backpressure = Arc::new(tokio::sync::Semaphore::new(ctx.max_inflight.max(1)));

    let lag_handle = if ctx.lag_poll_interval_secs > 0 {
        Some(tokio::spawn(poll_consumer_lag(
            ctx.consumer.clone(),
            shutdown_rx.clone(),
            ctx.lag_poll_interval_secs,
        )))
    } else {
        None
    };

    tracing::info!(
        topics = ?ctx.topic_map.keys().collect::<Vec<_>>(),
        max_inflight = ctx.max_inflight,
        lag_poll_secs = ctx.lag_poll_interval_secs,
        "Kafka consumer started"
    );

    // Cleared once the shutdown sender is gone so the select below stops polling it.
    let mut shutdown_open = true;

    loop {
        // The permit is held for the whole iteration and released when it goes out of
        // scope; acquisition only fails if the semaphore has been closed.
        let Ok(_permit) = backpressure.clone().acquire_owned().await else {
            break;
        };

        tokio::select! {
            changed = shutdown_rx.changed(), if shutdown_open => {
                match changed {
                    Ok(()) => {
                        if *shutdown_rx.borrow() {
                            tracing::info!("Kafka consumer shutting down");
                            break;
                        }
                    }
                    Err(_) => {
                        shutdown_open = false;
                        tracing::warn!(
                            "Kafka consumer shutdown channel closed; continuing without shutdown signal"
                        );
                    }
                }
            }
            msg_result = ctx.consumer.recv() => {
                match msg_result {
                    Ok(msg) => {
                        process_one_kafka_message(&ctx, &msg).await;
                        // Committed regardless of outcome; failures were already
                        // reported (and dead-lettered where applicable) by the processor.
                        commit_offset(&ctx.consumer, &msg);
                    }
                    Err(e) => {
                        tracing::error!(error = %e, "Kafka consumer error");
                    }
                }
            }
        }
    }

    if let Some(handle) = lag_handle {
        handle.abort();
    }
    tracing::info!("Kafka consumer stopped");
}
/// Asynchronously commits the offset of `msg`; a failed commit is logged and otherwise
/// ignored.
fn commit_offset(consumer: &StreamConsumer, msg: &rdkafka::message::BorrowedMessage<'_>) {
    let outcome = consumer.commit_message(msg, rdkafka::consumer::CommitMode::Async);
    match outcome {
        Ok(()) => {}
        Err(e) => {
            tracing::error!(error = %e, "Failed to commit Kafka offset");
        }
    }
}
/// Writes the Kafka origin of `msg` into the message context under `metadata.*`.
///
/// Sets `kafka_topic`, `kafka_partition` and `kafka_offset` unconditionally, and
/// `kafka_key` only when the message has a key that is valid UTF-8.
fn inject_kafka_metadata(
    message: &mut dataflow_rs::Message,
    topic: &str,
    msg: &rdkafka::message::BorrowedMessage<'_>,
) {
    use dataflow_rs::engine::utils::set_nested_value;
    use datavalue::OwnedDataValue;
    use rdkafka::Message as KafkaMsg;

    // Collected in the same order the values are written: topic, key, partition, offset.
    let mut entries = vec![(
        "metadata.kafka_topic",
        OwnedDataValue::from(topic.to_string()),
    )];

    let utf8_key = msg.key().and_then(|raw| std::str::from_utf8(raw).ok());
    if let Some(key) = utf8_key {
        entries.push((
            "metadata.kafka_key",
            OwnedDataValue::from(key.to_string()),
        ));
    }

    entries.push((
        "metadata.kafka_partition",
        OwnedDataValue::from_i64(i64::from(msg.partition())),
    ));
    entries.push((
        "metadata.kafka_offset",
        OwnedDataValue::from_i64(msg.offset()),
    ));

    for (path, value) in entries {
        set_nested_value(&mut message.context, path, value);
    }
}
/// Periodically reports consumer lag until `true` is observed on `shutdown_rx`.
///
/// Every `interval_secs` seconds the committed offsets are fetched on a blocking thread
/// and handed to `report_lag_for_partitions`. Fetch failures are logged at debug level and
/// the poller carries on with the next tick.
///
/// Returns immediately when `interval_secs` is `0`, because `tokio::time::interval`
/// panics on a zero period. If the shutdown sender is dropped without sending `true`, the
/// poller keeps running but stops watching the channel: a closed watch channel makes
/// `changed()` return an error immediately on every call, which would otherwise busy-spin.
async fn poll_consumer_lag(
    consumer: Arc<StreamConsumer>,
    mut shutdown_rx: watch::Receiver<bool>,
    interval_secs: u64,
) {
    if interval_secs == 0 {
        return;
    }

    let mut interval = tokio::time::interval(std::time::Duration::from_secs(interval_secs));
    // The first tick of a tokio interval completes immediately; consume it so the first
    // lag report happens one full period after start.
    interval.tick().await;

    // Cleared once the shutdown sender is gone so the select below stops polling it.
    let mut shutdown_open = true;

    loop {
        tokio::select! {
            changed = shutdown_rx.changed(), if shutdown_open => {
                match changed {
                    Ok(()) => {
                        if *shutdown_rx.borrow() {
                            break;
                        }
                    }
                    Err(_) => shutdown_open = false,
                }
            }
            _ = interval.tick() => {
                let consumer = consumer.clone();
                // `committed` and the watermark queries block, so they run off the
                // async worker threads.
                let joined = tokio::task::spawn_blocking(move || {
                    let committed = match consumer.committed(std::time::Duration::from_secs(5)) {
                        Ok(tpl) => tpl,
                        Err(e) => {
                            tracing::debug!(error = %e, "Failed to fetch committed offsets for lag metric");
                            return;
                        }
                    };
                    report_lag_for_partitions(&consumer, &committed);
                })
                .await;
                if let Err(e) = joined {
                    tracing::debug!(error = %e, "Kafka lag polling task did not complete");
                }
            }
        }
    }
}
/// Publishes the consumer lag metric for every partition in `committed`.
///
/// Lag is the partition's high watermark minus the consumer's position, never below zero.
/// The position is the committed offset when there is one. For partitions with no
/// committed offset (`Invalid`) or an explicit `Beginning`, the low watermark is used:
/// the consumer is configured with `auto.offset.reset=earliest`, so that is where it
/// starts, and counting from `0` would include records that no longer exist in the log.
/// Partitions with any other offset kind are skipped. Watermark lookup failures are
/// logged at debug level and the partition is skipped.
fn report_lag_for_partitions(consumer: &StreamConsumer, committed: &TopicPartitionList) {
    for elem in committed.elements() {
        let topic = elem.topic();
        let partition = elem.partition();

        // `None` means "no concrete committed offset; fall back to the low watermark".
        let committed_offset = match elem.offset() {
            rdkafka::Offset::Offset(n) => Some(n),
            rdkafka::Offset::Invalid | rdkafka::Offset::Beginning => None,
            _ => continue,
        };

        match consumer.fetch_watermarks(topic, partition, std::time::Duration::from_secs(5)) {
            Ok((low, high)) => {
                let position = committed_offset.unwrap_or(low);
                let lag = high.saturating_sub(position).max(0);
                metrics::set_kafka_consumer_lag(topic, partition, lag as f64);
            }
            Err(e) => {
                tracing::debug!(
                    topic = %topic,
                    partition = partition,
                    error = %e,
                    "Failed to fetch watermarks for lag metric"
                );
            }
        }
    }
}
/// Description of a failed message, consumed by `report_failure_and_dlq`.
struct FailureReport<'a> {
/// Channel the message was being processed for; passed to `metrics::record_message`.
channel: &'a str,
/// Source Kafka topic of the message.
topic: &'a str,
/// Original message payload forwarded to the DLQ.
payload: &'a [u8],
/// Status passed to `metrics::record_message` (callers in this file use "timeout" or "error").
message_status: &'static str,
/// Error kind passed to `metrics::record_error`.
error_kind: &'static str,
/// Text of the error log line.
log_msg: &'a str,
/// Failure reason; logged as the `error` field and sent as the DLQ error text.
dlq_reason: &'a str,
}
async fn report_failure_and_dlq(ctx: &ConsumeLoopContext, failure: FailureReport<'_>) {
metrics::record_message(failure.channel, failure.message_status);
metrics::record_error(failure.error_kind);
tracing::error!(
topic = %failure.topic,
channel = %failure.channel,
error = %failure.dlq_reason,
"{}",
failure.log_msg
);
send_to_dlq(
&ctx.dlq_producer,
&ctx.dlq_topic,
failure.topic,
failure.payload,
failure.dlq_reason,
)
.await;
}
fn build_dlq_message(source_topic: &str, payload: &[u8], error: &str) -> serde_json::Value {
serde_json::json!({
"source_topic": source_topic,
"error": error,
"original_payload": String::from_utf8_lossy(payload),
"timestamp": chrono::Utc::now().to_rfc3339(),
})
}
async fn send_to_dlq(
producer: &Option<Arc<KafkaProducer>>,
dlq_topic: &Option<String>,
source_topic: &str,
payload: &[u8],
error: &str,
) {
if let (Some(producer), Some(topic)) = (producer, dlq_topic) {
let dlq_message = build_dlq_message(source_topic, payload, error);
let dlq_payload =
serde_json::to_string(&dlq_message).expect("DLQ envelope is always serialisable");
if let Err(e) = producer
.send(topic, Some(source_topic), dlq_payload.as_bytes())
.await
{
tracing::error!(
dlq_topic = %topic,
error = %e,
"Failed to send message to DLQ"
);
} else {
tracing::debug!(
dlq_topic = %topic,
source_topic = %source_topic,
"Message sent to DLQ"
);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The topic -> channel table is derived one-to-one from the configured mappings.
    #[test]
    fn test_topic_map_construction() {
        let mappings = vec![
            crate::config::TopicMapping {
                topic: "orders".into(),
                channel: "order-channel".into(),
            },
            crate::config::TopicMapping {
                topic: "events".into(),
                channel: "event-channel".into(),
            },
        ];
        let config = crate::config::KafkaIngestConfig {
            enabled: true,
            brokers: vec!["localhost:9092".into()],
            group_id: "test".into(),
            topics: mappings,
            dlq: crate::config::DlqConfig::default(),
            processing_timeout_ms: 60_000,
            max_inflight: 10,
            lag_poll_interval_secs: 30,
        };

        let mut topic_map: HashMap<String, String> = HashMap::new();
        for mapping in &config.topics {
            topic_map.insert(mapping.topic.clone(), mapping.channel.clone());
        }

        assert_eq!(topic_map.len(), 2);
        for (topic, channel) in [("orders", "order-channel"), ("events", "event-channel")] {
            assert_eq!(topic_map.get(topic).expect("test"), channel);
        }
        assert!(!topic_map.contains_key("unknown"));
    }

    /// The envelope echoes topic, error and payload, and carries an RFC 3339 timestamp.
    #[test]
    fn test_dlq_message_format() {
        let raw = br#"{"data": {"broken": true}}"#;
        let envelope = build_dlq_message("test-topic", raw, "JSON parse error");

        assert_eq!(envelope["source_topic"], "test-topic");
        assert_eq!(envelope["error"], "JSON parse error");
        assert_eq!(
            envelope["original_payload"],
            r#"{"data": {"broken": true}}"#
        );

        let stamp = envelope["timestamp"].as_str().expect("test");
        assert!(stamp.contains("T"));
        let has_zone = stamp.ends_with('Z') || stamp.contains('+');
        assert!(has_zone);
    }

    /// Bytes that are not valid UTF-8 are kept as replacement characters.
    #[test]
    fn test_dlq_message_invalid_utf8_payload() {
        let raw: &[u8] = &[0xFF, 0xFE, 0xFD];
        let envelope = build_dlq_message("bad-topic", raw, "UTF-8 decode error");

        assert_eq!(envelope["source_topic"], "bad-topic");
        assert_eq!(envelope["error"], "UTF-8 decode error");

        let echoed = envelope["original_payload"].as_str().expect("test");
        assert!(echoed.contains('\u{FFFD}'));
    }

    /// An empty payload yields an empty `original_payload` string.
    #[test]
    fn test_dlq_message_empty_payload() {
        let envelope = build_dlq_message("topic", b"", "empty message");

        assert_eq!(envelope["source_topic"], "topic");
        assert_eq!(envelope["error"], "empty message");
        assert_eq!(envelope["original_payload"], "");
    }
}