use arkflow_core::input::{register_input_builder, Ack, Input, InputBuilder};
use arkflow_core::{Error, MessageBatch};
use async_trait::async_trait;
use rdkafka::config::ClientConfig;
use rdkafka::consumer::{Consumer, StreamConsumer};
use rdkafka::message::Message as KafkaMessage;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
/// Configuration for the Kafka input component, deserialized from the
/// component's JSON config by `KafkaInputBuilder`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KafkaInputConfig {
    /// Broker addresses; joined with "," into `bootstrap.servers`.
    pub brokers: Vec<String>,
    /// Topics to subscribe to.
    pub topics: Vec<String>,
    /// Consumer group id (`group.id`).
    pub consumer_group: String,
    /// Optional `client.id` for the consumer.
    pub client_id: Option<String>,
    /// When true, `auto.offset.reset` is "latest"; otherwise "earliest".
    pub start_from_latest: bool,
}
/// Kafka input component. Holds the configuration and, once `connect` has
/// succeeded, the shared `StreamConsumer` used by `read` and by each
/// `KafkaAck` to store offsets.
pub struct KafkaInput {
    config: KafkaInputConfig,
    // `None` until `connect` succeeds; shared with every `KafkaAck` handed out
    // by `read`, hence the Arc<RwLock<...>>.
    consumer: Arc<RwLock<Option<StreamConsumer>>>,
}
impl KafkaInput {
pub fn new(config: KafkaInputConfig) -> Result<Self, Error> {
Ok(Self {
config,
consumer: Arc::new(RwLock::new(None)),
})
}
}
#[async_trait]
impl Input for KafkaInput {
async fn connect(&self) -> Result<(), Error> {
let mut client_config = ClientConfig::new();
client_config.set("bootstrap.servers", &self.config.brokers.join(","));
client_config.set("group.id", &self.config.consumer_group);
if let Some(client_id) = &self.config.client_id {
client_config.set("client.id", client_id);
}
if self.config.start_from_latest {
client_config.set("auto.offset.reset", "latest");
} else {
client_config.set("auto.offset.reset", "earliest");
}
let consumer: StreamConsumer = client_config
.create()
.map_err(|e| Error::Connection(format!("Unable to create a Kafka consumer: {}", e)))?;
let x: Vec<&str> = self
.config
.topics
.iter()
.map(|topic| topic.as_str())
.collect();
consumer.subscribe(&x).map_err(|e| {
Error::Connection(format!("You cannot subscribe to a Kafka topic: {}", e))
})?;
let consumer_arc = self.consumer.clone();
let mut consumer_guard = consumer_arc.write().await;
*consumer_guard = Some(consumer);
Ok(())
}
async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
let consumer_arc = self.consumer.clone();
let consumer_guard = consumer_arc.read().await;
if consumer_guard.is_none() {
return Err(Error::Connection("The input is not connected".to_string()));
}
let consumer = consumer_guard.as_ref().unwrap();
match consumer.recv().await {
Ok(kafka_message) => {
let payload = kafka_message.payload().ok_or_else(|| {
Error::Process("The Kafka message has no content".to_string())
})?;
let mut binary_data = Vec::new();
binary_data.push(payload.to_vec());
let msg_batch = MessageBatch::new_binary(binary_data);
let topic = kafka_message.topic().to_string();
let partition = kafka_message.partition();
let offset = kafka_message.offset();
let ack = KafkaAck {
consumer: self.consumer.clone(),
topic,
partition,
offset,
};
Ok((msg_batch, Arc::new(ack)))
}
Err(e) => Err(Error::Connection(format!(
"Error receiving Kafka message: {}",
e
))),
}
}
async fn close(&self) -> Result<(), Error> {
let mut consumer_guard = self.consumer.write().await;
if let Some(consumer) = consumer_guard.take() {
if let Err(e) = consumer.unassign() {
tracing::warn!("Error unassigning Kafka consumer: {}", e);
}
}
Ok(())
}
}
/// Acknowledgement handle for one consumed Kafka message. Carries the
/// topic/partition/offset of that message plus a handle to the shared
/// consumer so the offset can be stored when `ack` is called.
pub struct KafkaAck {
    // Shared with `KafkaInput`; may be `None` if the input was closed.
    consumer: Arc<RwLock<Option<StreamConsumer>>>,
    topic: String,
    partition: i32,
    offset: i64,
}
#[async_trait]
impl Ack for KafkaAck {
    /// Stores this message's offset so the consumer's auto-commit cycle can
    /// commit it later. Acking is best-effort: if the input has already been
    /// closed this is a no-op, and store failures are only logged.
    async fn ack(&self) {
        let guard = self.consumer.read().await;
        if let Some(consumer) = guard.as_ref() {
            let stored = consumer.store_offset(&self.topic, self.partition, self.offset);
            if let Err(e) = stored {
                tracing::error!("Error committing Kafka offset: {}", e);
            }
        }
    }
}
pub(crate) struct KafkaInputBuilder;
impl InputBuilder for KafkaInputBuilder {
fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Input>, Error> {
if config.is_none() {
return Err(Error::Config(
"Kafka input configuration is missing".to_string(),
));
}
let config: KafkaInputConfig = serde_json::from_value(config.clone().unwrap())?;
Ok(Arc::new(KafkaInput::new(config)?))
}
}
/// Registers the Kafka input builder with the framework under the "kafka"
/// component name. Call once at startup.
pub fn init() {
    register_input_builder("kafka", Arc::new(KafkaInputBuilder));
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the baseline test configuration, varying only the fields the
    /// individual tests care about.
    fn make_config(client_id: Option<String>, start_from_latest: bool) -> KafkaInputConfig {
        KafkaInputConfig {
            brokers: vec!["localhost:9092".to_string()],
            topics: vec!["test-topic".to_string()],
            consumer_group: "test-group".to_string(),
            client_id,
            start_from_latest,
        }
    }

    /// `new` succeeds and keeps the configuration unchanged.
    #[tokio::test]
    async fn test_kafka_input_new() {
        let input = KafkaInput::new(make_config(Some("test-client".to_string()), false));
        assert!(input.is_ok());
        let input = input.unwrap();
        assert_eq!(input.config.brokers, vec!["localhost:9092".to_string()]);
        assert_eq!(input.config.topics, vec!["test-topic".to_string()]);
        assert_eq!(input.config.consumer_group, "test-group".to_string());
        assert_eq!(input.config.client_id, Some("test-client".to_string()));
        assert!(!input.config.start_from_latest);
    }

    /// `read` before `connect` must fail with a Connection error.
    #[tokio::test]
    async fn test_kafka_input_read_not_connected() {
        let input = KafkaInput::new(make_config(None, true)).unwrap();
        let result = input.read().await;
        assert!(result.is_err());
        match result {
            Err(Error::Connection(msg)) => {
                assert_eq!(msg, "The input is not connected");
            }
            _ => panic!("Expected Connection error"),
        }
    }

    /// Acking when no consumer exists must be a silent no-op.
    #[tokio::test]
    async fn test_kafka_ack() {
        let input = KafkaInput::new(make_config(None, true)).unwrap();
        let ack = KafkaAck {
            consumer: input.consumer.clone(),
            topic: "test-topic".to_string(),
            partition: 0,
            offset: 100,
        };
        ack.ack().await;
    }
}