use serde::{Deserialize, Serialize};
use arkflow_core::output::{register_output_builder, Output, OutputBuilder};
use arkflow_core::{Content, Error, MessageBatch};
use async_trait::async_trait;
use rdkafka::config::ClientConfig;
use rdkafka::error::KafkaResult;
use rdkafka::message::ToBytes;
use rdkafka::producer::future_producer::OwnedDeliveryResult;
use rdkafka::producer::{FutureProducer, FutureRecord, Producer};
use rdkafka::util::Timeout;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
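/// Compression codec applied to produced messages, mirroring Kafka's
/// `compression.type` setting.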
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CompressionType {
None,
Gzip,
Snappy,
Lz4,
}
impl std::fmt::Display for CompressionType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CompressionType::None => write!(f, "none"),
CompressionType::Gzip => write!(f, "gzip"),
CompressionType::Snappy => write!(f, "snappy"),
CompressionType::Lz4 => write!(f, "lz4"),
}
}
}
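/// Configuration for the Kafka output.
///
/// An illustrative JSON document that deserializes into this struct
/// (values are placeholders, not defaults):
///
/// ```json
/// {
///   "brokers": ["localhost:9092"],
///   "topic": "events",
///   "key": "my-key",
///   "client_id": "arkflow-producer",
///   "compression": "gzip",
///   "acks": "all"
/// }
/// ```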
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KafkaOutputConfig {
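    /// Kafka broker addresses, joined into `bootstrap.servers`.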
pub brokers: Vec<String>,
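    /// Topic that every message in the batch is produced to.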
pub topic: String,
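    /// Optional static key attached to every produced message.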
pub key: Option<String>,
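    /// Optional `client.id` reported to the brokers.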
pub client_id: Option<String>,
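    /// Optional compression codec; serialized names are lowercase
    /// ("none", "gzip", "snappy", "lz4").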
pub compression: Option<CompressionType>,
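    /// Optional `acks` setting, e.g. "0", "1", or "all".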
pub acks: Option<String>,
}
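/// Kafka output component, generic over the producer type so unit tests can
/// substitute a mock for [`FutureProducer`].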
struct KafkaOutput<T> {
config: KafkaOutputConfig,
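    /// Producer handle; `None` until `connect` succeeds.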
producer: Arc<RwLock<Option<T>>>,
}
impl<T: KafkaClient> KafkaOutput<T> {
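    /// Creates a new, not-yet-connected Kafka output.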
pub fn new(config: KafkaOutputConfig) -> Result<Self, Error> {
Ok(Self {
config,
producer: Arc::new(RwLock::new(None)),
})
}
}
#[async_trait]
impl<T: KafkaClient> Output for KafkaOutput<T> {
async fn connect(&self) -> Result<(), Error> {
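        // Translate the output configuration into librdkafka client properties.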
let mut client_config = ClientConfig::new();
client_config.set("bootstrap.servers", &self.config.brokers.join(","));
if let Some(client_id) = &self.config.client_id {
client_config.set("client.id", client_id);
}
if let Some(compression) = &self.config.compression {
client_config.set("compression.type", compression.to_string().to_lowercase());
}
if let Some(acks) = &self.config.acks {
client_config.set("acks", acks);
}
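        // Create the producer and store it for subsequent writes.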
let producer = T::create(&client_config)
            .map_err(|e| Error::Connection(format!("Failed to create Kafka producer: {}", e)))?;
        *self.producer.write().await = Some(producer);
Ok(())
}
async fn write(&self, msg: &MessageBatch) -> Result<(), Error> {
        let producer_guard = self.producer.read().await;
let producer = producer_guard.as_ref().ok_or_else(|| {
Error::Connection("The Kafka producer is not initialized".to_string())
})?;
        // Nothing to produce for an empty batch.
        if msg.as_string()?.is_empty() {
            return Ok(());
        }
match &msg.content {
Content::Arrow(_) => {
return Err(Error::Process(
"The arrow format is not supported".to_string(),
))
}
            Content::Binary(v) => {
                for payload in v {
                    let mut record = FutureRecord::to(&self.config.topic).payload(payload);
if let Some(key) = &self.config.key {
record = record.key(key);
}
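                    // Block for up to 5s if the local queue is full, then await delivery.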
producer
.send(record, Duration::from_secs(5))
.await
.map_err(|(e, _)| {
Error::Process(format!("Failed to send a Kafka message: {}", e))
})?;
}
}
}
Ok(())
}
async fn close(&self) -> Result<(), Error> {
        let mut producer_guard = self.producer.write().await;
if let Some(producer) = producer_guard.take() {
            // Flush any buffered messages before dropping the producer.
            producer.flush(Duration::from_secs(30)).map_err(|e| {
                Error::Connection(format!(
                    "Failed to flush pending messages while closing the Kafka producer: {}",
                    e
                ))
            })?;
}
Ok(())
}
}
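/// Builds [`KafkaOutput`] instances backed by [`FutureProducer`] from JSON configuration.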
pub(crate) struct KafkaOutputBuilder;
impl OutputBuilder for KafkaOutputBuilder {
fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Output>, Error> {
if config.is_none() {
return Err(Error::Config(
"HTTP output configuration is missing".to_string(),
));
}
let config: KafkaOutputConfig = serde_json::from_value(config.clone().unwrap())?;
Ok(Arc::new(KafkaOutput::<FutureProducer>::new(config)?))
}
}
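/// Registers the Kafka output builder under the "kafka" type name.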
pub fn init() {
register_output_builder("kafka", Arc::new(KafkaOutputBuilder));
}
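/// Minimal producer abstraction used by [`KafkaOutput`] so that unit tests can
/// substitute an in-memory mock for the real [`FutureProducer`].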
#[async_trait]
trait KafkaClient: Send + Sync {
fn create(config: &ClientConfig) -> KafkaResult<Self>
where
Self: Sized;
async fn send<K, P, T>(
&self,
record: FutureRecord<'_, K, P>,
queue_timeout: T,
) -> OwnedDeliveryResult
where
K: ToBytes + ?Sized + Sync,
P: ToBytes + ?Sized + Sync,
T: Into<Timeout> + Sync + Send;
fn flush<T: Into<Timeout>>(&self, timeout: T) -> KafkaResult<()>;
}
#[async_trait]
impl KafkaClient for FutureProducer {
fn create(config: &ClientConfig) -> KafkaResult<Self> {
config.create()
}
async fn send<K, P, T>(
&self,
record: FutureRecord<'_, K, P>,
queue_timeout: T,
) -> OwnedDeliveryResult
where
K: ToBytes + ?Sized + Sync,
P: ToBytes + ?Sized + Sync,
T: Into<Timeout> + Sync + Send,
{
FutureProducer::send(self, record, queue_timeout).await
}
fn flush<T: Into<Timeout>>(&self, timeout: T) -> KafkaResult<()> {
Producer::flush(self, timeout)
}
}
#[cfg(test)]
mod tests {
use super::*;
use rdkafka::Timestamp;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::Mutex;
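    /// In-memory producer mock: records every "sent" message and can be
    /// switched into a failure mode for error-path tests.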
struct MockKafkaClient {
connected: Arc<AtomicBool>,
sent_messages: Arc<Mutex<Vec<(String, Vec<u8>, Option<String>)>>>,
should_fail: Arc<AtomicBool>,
}
impl MockKafkaClient {
fn new() -> Self {
Self {
connected: Arc::new(AtomicBool::new(true)),
sent_messages: Arc::new(Mutex::new(Vec::new())),
should_fail: Arc::new(AtomicBool::new(false)),
}
}
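        /// Builds a mock whose send and flush calls always fail.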
fn with_failure() -> Self {
let client = Self::new();
client.should_fail.store(true, Ordering::SeqCst);
client
}
}
#[async_trait]
impl KafkaClient for MockKafkaClient {
fn create(config: &ClientConfig) -> KafkaResult<Self> {
if config.get("bootstrap.servers").unwrap_or("") == "" {
return Err(rdkafka::error::KafkaError::ClientCreation(
"Failed to create client".to_string(),
));
}
Ok(Self::new())
}
async fn send<K, P, T>(
&self,
record: FutureRecord<'_, K, P>,
_queue_timeout: T,
) -> OwnedDeliveryResult
where
K: ToBytes + ?Sized + Sync,
P: ToBytes + ?Sized + Sync,
T: Into<Timeout> + Sync + Send,
{
if self.should_fail.load(Ordering::SeqCst) {
let err = rdkafka::error::KafkaError::MessageProduction(
rdkafka::types::RDKafkaErrorCode::QueueFull,
);
let payload = rdkafka::message::OwnedMessage::new(
Some(record.payload.unwrap().to_bytes().to_vec()),
None,
record.topic.to_string(),
Timestamp::NotAvailable,
0,
0,
None,
);
return Err((err, payload));
}
let mut messages = self.sent_messages.lock().await;
messages.push((
record.topic.to_string(),
record.payload.unwrap().to_bytes().to_vec(),
record
.key
.map(|k| String::from_utf8_lossy(k.to_bytes()).to_string()),
));
            // Report a successful delivery as (partition, offset).
            Ok((0, 0))
}
fn flush<T: Into<Timeout>>(&self, _timeout: T) -> KafkaResult<()> {
if self.should_fail.load(Ordering::SeqCst) {
return Err(rdkafka::error::KafkaError::Flush(
rdkafka::types::RDKafkaErrorCode::QueueFull,
));
}
Ok(())
}
}
#[tokio::test]
async fn test_kafka_output_new() {
let config = KafkaOutputConfig {
brokers: vec!["localhost:9092".to_string()],
topic: "test-topic".to_string(),
key: None,
client_id: None,
compression: None,
acks: None,
};
let output = KafkaOutput::<MockKafkaClient>::new(config);
assert!(output.is_ok(), "Failed to create Kafka output component");
}
#[tokio::test]
async fn test_kafka_output_connect() {
let config = KafkaOutputConfig {
brokers: vec!["localhost:9092".to_string()],
topic: "test-topic".to_string(),
key: None,
client_id: None,
compression: None,
acks: None,
};
let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
let result = output.connect().await;
assert!(result.is_ok(), "Failed to connect to Kafka");
let producer_guard = output.producer.read().await;
assert!(producer_guard.is_some(), "Kafka producer not initialized");
}
#[tokio::test]
async fn test_kafka_output_connect_failure() {
let config = KafkaOutputConfig {
brokers: vec![],
topic: "test-topic".to_string(),
key: None,
client_id: None,
compression: None,
acks: None,
};
let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
let result = output.connect().await;
assert!(result.is_err(), "Connection should fail with empty brokers");
}
#[tokio::test]
async fn test_kafka_output_write() {
let config = KafkaOutputConfig {
brokers: vec!["localhost:9092".to_string()],
topic: "test-topic".to_string(),
key: None,
client_id: None,
compression: None,
acks: None,
};
let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
output.connect().await.unwrap();
let msg = MessageBatch::from_string("test message");
let result = output.write(&msg).await;
assert!(result.is_ok(), "Failed to write message to Kafka");
let producer_guard = output.producer.read().await;
let producer = producer_guard.as_ref().unwrap();
let messages = producer.sent_messages.lock().await;
assert_eq!(messages.len(), 1, "Message not sent to Kafka");
assert_eq!(messages[0].0, "test-topic", "Wrong topic");
assert_eq!(messages[0].1, b"test message", "Wrong message content");
assert_eq!(messages[0].2, None, "Key should be None");
}
#[tokio::test]
async fn test_kafka_output_write_with_key() {
let config = KafkaOutputConfig {
brokers: vec!["localhost:9092".to_string()],
topic: "test-topic".to_string(),
key: Some("test-key".to_string()),
client_id: None,
compression: None,
acks: None,
};
let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
output.connect().await.unwrap();
let msg = MessageBatch::from_string("test message");
let result = output.write(&msg).await;
assert!(result.is_ok(), "Failed to write message to Kafka");
let producer_guard = output.producer.read().await;
let producer = producer_guard.as_ref().unwrap();
let messages = producer.sent_messages.lock().await;
assert_eq!(messages.len(), 1, "Message not sent to Kafka");
assert_eq!(messages[0].2, Some("test-key".to_string()), "Wrong key");
}
#[tokio::test]
async fn test_kafka_output_write_without_connect() {
let config = KafkaOutputConfig {
brokers: vec!["localhost:9092".to_string()],
topic: "test-topic".to_string(),
key: None,
client_id: None,
compression: None,
acks: None,
};
let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
let msg = MessageBatch::from_string("test message");
let result = output.write(&msg).await;
assert!(result.is_err(), "Write should fail when not connected");
        match result {
            Err(Error::Connection(_)) => {}
            _ => panic!("Expected Connection error"),
        }
}
#[tokio::test]
async fn test_kafka_output_write_failure() {
let config = KafkaOutputConfig {
brokers: vec!["localhost:9092".to_string()],
topic: "test-topic".to_string(),
key: None,
client_id: None,
compression: None,
acks: None,
};
let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
output.connect().await.unwrap();
let producer_guard = output.producer.read().await;
let producer = producer_guard.as_ref().unwrap();
producer.should_fail.store(true, Ordering::SeqCst);
let msg = MessageBatch::from_string("test message");
let result = output.write(&msg).await;
assert!(result.is_err(), "Write should fail with producer error");
}
#[tokio::test]
async fn test_kafka_output_close() {
let config = KafkaOutputConfig {
brokers: vec!["localhost:9092".to_string()],
topic: "test-topic".to_string(),
key: None,
client_id: None,
compression: None,
acks: None,
};
let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
output.connect().await.unwrap();
let result = output.close().await;
assert!(result.is_ok(), "Failed to close Kafka connection");
let producer_guard = output.producer.read().await;
assert!(producer_guard.is_none(), "Kafka producer not cleared");
}
#[tokio::test]
async fn test_kafka_output_close_failure() {
let config = KafkaOutputConfig {
brokers: vec!["localhost:9092".to_string()],
topic: "test-topic".to_string(),
key: None,
client_id: None,
compression: None,
acks: None,
};
let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
output.connect().await.unwrap();
{
let producer_guard = output.producer.read().await;
let producer = producer_guard.as_ref().unwrap();
producer.should_fail.store(true, Ordering::SeqCst);
}
let result = output.close().await;
assert!(result.is_err(), "Close should fail with flush error");
}
}