1use serde::{Deserialize, Serialize};
6
7use arkflow_core::output::{register_output_builder, Output, OutputBuilder};
8use arkflow_core::{Content, Error, MessageBatch};
9
10use async_trait::async_trait;
11use rdkafka::config::ClientConfig;
12use rdkafka::error::KafkaResult;
13use rdkafka::message::ToBytes;
14use rdkafka::producer::future_producer::OwnedDeliveryResult;
15use rdkafka::producer::{FutureProducer, FutureRecord, Producer};
16use rdkafka::util::Timeout;
17use std::sync::Arc;
18use std::time::Duration;
19use tokio::sync::RwLock;
20
/// Compression codec applied to messages produced to Kafka.
///
/// Serialized in lowercase (e.g. `"gzip"`) via `rename_all`, matching the
/// values librdkafka accepts for `compression.type`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum CompressionType {
    None,
    Gzip,
    Snappy,
    Lz4,
}
29
30impl std::fmt::Display for CompressionType {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 match self {
33 CompressionType::None => write!(f, "none"),
34 CompressionType::Gzip => write!(f, "gzip"),
35 CompressionType::Snappy => write!(f, "snappy"),
36 CompressionType::Lz4 => write!(f, "lz4"),
37 }
38 }
39}
40
/// Configuration for the Kafka output component.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KafkaOutputConfig {
    /// Broker addresses; joined with "," into `bootstrap.servers` on connect.
    pub brokers: Vec<String>,
    /// Topic every message in a batch is produced to.
    pub topic: String,
    /// Optional fixed message key applied to every message (note: a fixed key
    /// routes all messages to the same partition).
    pub key: Option<String>,
    /// Optional `client.id` to set on the producer.
    pub client_id: Option<String>,
    /// Optional compression codec, mapped to `compression.type`.
    pub compression: Option<CompressionType>,
    /// Optional `acks` setting passed through verbatim (e.g. "0", "1", "all").
    pub acks: Option<String>,
}
57
/// Kafka output component, generic over the producer client so tests can
/// substitute a mock for the real `FutureProducer`.
struct KafkaOutput<T> {
    config: KafkaOutputConfig,
    // The producer is created in `connect` and removed in `close`;
    // `None` means "not connected".
    producer: Arc<RwLock<Option<T>>>,
}
63
64impl<T: KafkaClient> KafkaOutput<T> {
65 pub fn new(config: KafkaOutputConfig) -> Result<Self, Error> {
67 Ok(Self {
68 config,
69 producer: Arc::new(RwLock::new(None)),
70 })
71 }
72}
73
74#[async_trait]
75impl<T: KafkaClient> Output for KafkaOutput<T> {
76 async fn connect(&self) -> Result<(), Error> {
77 let mut client_config = ClientConfig::new();
78
79 client_config.set("bootstrap.servers", &self.config.brokers.join(","));
81
82 if let Some(client_id) = &self.config.client_id {
84 client_config.set("client.id", client_id);
85 }
86
87 if let Some(compression) = &self.config.compression {
89 client_config.set("compression.type", compression.to_string().to_lowercase());
90 }
91
92 if let Some(acks) = &self.config.acks {
94 client_config.set("acks", acks);
95 }
96
97 let producer = T::create(&client_config)
99 .map_err(|e| Error::Connection(format!("A Kafka producer cannot be created: {}", e)))?;
100
101 let producer_arc = self.producer.clone();
103 let mut producer_guard = producer_arc.write().await;
104 *producer_guard = Some(producer);
105
106 Ok(())
107 }
108
109 async fn write(&self, msg: &MessageBatch) -> Result<(), Error> {
110 let producer_arc = self.producer.clone();
111 let producer_guard = producer_arc.read().await;
112 let producer = producer_guard.as_ref().ok_or_else(|| {
113 Error::Connection("The Kafka producer is not initialized".to_string())
114 })?;
115
116 let payloads = msg.as_string()?;
117 if payloads.is_empty() {
118 return Ok(());
119 }
120
121 match &msg.content {
122 Content::Arrow(_) => {
123 return Err(Error::Process(
124 "The arrow format is not supported".to_string(),
125 ))
126 }
127 Content::Binary(v) => {
128 for x in v {
129 let mut record = FutureRecord::to(&self.config.topic).payload(&x);
131
132 if let Some(key) = &self.config.key {
134 record = record.key(key);
135 }
136
137 producer
139 .send(record, Duration::from_secs(5))
140 .await
141 .map_err(|(e, _)| {
142 Error::Process(format!("Failed to send a Kafka message: {}", e))
143 })?;
144 }
145 }
146 }
147 Ok(())
148 }
149
150 async fn close(&self) -> Result<(), Error> {
151 let producer_arc = self.producer.clone();
153 let mut producer_guard = producer_arc.write().await;
154
155 if let Some(producer) = producer_guard.take() {
156 producer.flush(Duration::from_secs(30)).map_err(|e| {
158 Error::Connection(format!(
159 "Failed to refresh the message when the Kafka producer is disabled: {}",
160 e
161 ))
162 })?;
163 }
164 Ok(())
165 }
166}
167
168pub(crate) struct KafkaOutputBuilder;
169impl OutputBuilder for KafkaOutputBuilder {
170 fn build(&self, config: &Option<serde_json::Value>) -> Result<Arc<dyn Output>, Error> {
171 if config.is_none() {
172 return Err(Error::Config(
173 "HTTP output configuration is missing".to_string(),
174 ));
175 }
176 let config: KafkaOutputConfig = serde_json::from_value(config.clone().unwrap())?;
177
178 Ok(Arc::new(KafkaOutput::<FutureProducer>::new(config)?))
179 }
180}
181
/// Registers this component with the framework under the "kafka" output name.
pub fn init() {
    register_output_builder("kafka", Arc::new(KafkaOutputBuilder));
}
/// Minimal abstraction over an rdkafka producer so the output logic can be
/// exercised in unit tests with a mock implementation.
#[async_trait]
trait KafkaClient: Send + Sync {
    /// Creates a client from the given rdkafka configuration.
    fn create(config: &ClientConfig) -> KafkaResult<Self>
    where
        Self: Sized;

    /// Sends a single record, waiting up to `queue_timeout` for queue space.
    async fn send<K, P, T>(
        &self,
        record: FutureRecord<'_, K, P>,
        queue_timeout: T,
    ) -> OwnedDeliveryResult
    where
        K: ToBytes + ?Sized + Sync,
        P: ToBytes + ?Sized + Sync,
        T: Into<Timeout> + Sync + Send;

    /// Flushes buffered messages, blocking up to `timeout`.
    fn flush<T: Into<Timeout>>(&self, timeout: T) -> KafkaResult<()>;
}
/// Production implementation: delegates directly to rdkafka's `FutureProducer`.
#[async_trait]
impl KafkaClient for FutureProducer {
    fn create(config: &ClientConfig) -> KafkaResult<Self> {
        config.create()
    }
    async fn send<K, P, T>(
        &self,
        record: FutureRecord<'_, K, P>,
        queue_timeout: T,
    ) -> OwnedDeliveryResult
    where
        K: ToBytes + ?Sized + Sync,
        P: ToBytes + ?Sized + Sync,
        T: Into<Timeout> + Sync + Send,
    {
        // Fully-qualified call to avoid recursing into this trait method.
        FutureProducer::send(self, record, queue_timeout).await
    }

    fn flush<T: Into<Timeout>>(&self, timeout: T) -> KafkaResult<()> {
        // Disambiguate: `Producer::flush`, not `KafkaClient::flush`.
        Producer::flush(self, timeout)
    }
}
225
/// Unit tests driven through `MockKafkaClient`; no real broker is required.
#[cfg(test)]
mod tests {
    use super::*;
    use rdkafka::Timestamp;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use tokio::sync::Mutex;

    /// Test double for `KafkaClient` that records sent messages in memory.
    struct MockKafkaClient {
        // Set to true on creation; not inspected by the current tests.
        connected: Arc<AtomicBool>,
        // (topic, payload, optional key) for every message "sent".
        sent_messages: Arc<Mutex<Vec<(String, Vec<u8>, Option<String>)>>>,
        // When true, `send` and `flush` return errors to simulate failures.
        should_fail: Arc<AtomicBool>,
    }

    impl MockKafkaClient {
        /// A healthy mock that accepts and records every message.
        fn new() -> Self {
            Self {
                connected: Arc::new(AtomicBool::new(true)),
                sent_messages: Arc::new(Mutex::new(Vec::new())),
                should_fail: Arc::new(AtomicBool::new(false)),
            }
        }

        /// A mock pre-configured to fail `send`/`flush`.
        /// NOTE(review): unused by the current tests, which flip
        /// `should_fail` on an already-connected producer instead.
        fn with_failure() -> Self {
            let client = Self::new();
            client.should_fail.store(true, Ordering::SeqCst);
            client
        }
    }

    #[async_trait]
    impl KafkaClient for MockKafkaClient {
        /// Mirrors a connection failure when no brokers were configured.
        fn create(config: &ClientConfig) -> KafkaResult<Self> {
            if config.get("bootstrap.servers").unwrap_or("") == "" {
                return Err(rdkafka::error::KafkaError::ClientCreation(
                    "Failed to create client".to_string(),
                ));
            }
            Ok(Self::new())
        }

        async fn send<K, P, T>(
            &self,
            record: FutureRecord<'_, K, P>,
            _queue_timeout: T,
        ) -> OwnedDeliveryResult
        where
            K: ToBytes + ?Sized + Sync,
            P: ToBytes + ?Sized + Sync,
            T: Into<Timeout> + Sync + Send,
        {
            if self.should_fail.load(Ordering::SeqCst) {
                // rdkafka's send error returns the failed message alongside
                // the error, so a full OwnedMessage must be fabricated here.
                let err = rdkafka::error::KafkaError::MessageProduction(
                    rdkafka::types::RDKafkaErrorCode::QueueFull,
                );
                let payload = rdkafka::message::OwnedMessage::new(
                    Some(record.payload.unwrap().to_bytes().to_vec()),
                    None,
                    record.topic.to_string(),
                    Timestamp::NotAvailable,
                    0,
                    0,
                    None,
                );
                return Err((err, payload));
            }

            // Record (topic, payload, key) so tests can assert on what was sent.
            let mut messages = self.sent_messages.lock().await;
            messages.push((
                record.topic.to_string(),
                record.payload.unwrap().to_bytes().to_vec(),
                record
                    .key
                    .map(|k| String::from_utf8_lossy(k.to_bytes()).to_string()),
            ));

            // Success delivery result: (partition, offset).
            Ok((
                rdkafka::types::RDKafkaRespErr::RD_KAFKA_RESP_ERR_NO_ERROR as i32,
                0,
            ))
        }

        fn flush<T: Into<Timeout>>(&self, _timeout: T) -> KafkaResult<()> {
            if self.should_fail.load(Ordering::SeqCst) {
                return Err(rdkafka::error::KafkaError::Flush(
                    rdkafka::types::RDKafkaErrorCode::QueueFull,
                ));
            }
            Ok(())
        }
    }

    /// Constructing the output component with a valid config succeeds.
    #[tokio::test]
    async fn test_kafka_output_new() {
        let config = KafkaOutputConfig {
            brokers: vec!["localhost:9092".to_string()],
            topic: "test-topic".to_string(),
            key: None,
            client_id: None,
            compression: None,
            acks: None,
        };

        let output = KafkaOutput::<MockKafkaClient>::new(config);
        assert!(output.is_ok(), "Failed to create Kafka output component");
    }

    /// `connect` stores a producer in the internal slot.
    #[tokio::test]
    async fn test_kafka_output_connect() {
        let config = KafkaOutputConfig {
            brokers: vec!["localhost:9092".to_string()],
            topic: "test-topic".to_string(),
            key: None,
            client_id: None,
            compression: None,
            acks: None,
        };

        let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
        let result = output.connect().await;
        assert!(result.is_ok(), "Failed to connect to Kafka");

        let producer_guard = output.producer.read().await;
        assert!(producer_guard.is_some(), "Kafka producer not initialized");
    }

    /// Empty broker list → empty `bootstrap.servers` → mock rejects creation.
    #[tokio::test]
    async fn test_kafka_output_connect_failure() {
        let config = KafkaOutputConfig {
            brokers: vec![],
            topic: "test-topic".to_string(),
            key: None,
            client_id: None,
            compression: None,
            acks: None,
        };

        let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
        let result = output.connect().await;
        assert!(result.is_err(), "Connection should fail with empty brokers");
    }

    /// A written message reaches the mock with the right topic/payload/no key.
    #[tokio::test]
    async fn test_kafka_output_write() {
        let config = KafkaOutputConfig {
            brokers: vec!["localhost:9092".to_string()],
            topic: "test-topic".to_string(),
            key: None,
            client_id: None,
            compression: None,
            acks: None,
        };

        let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
        output.connect().await.unwrap();

        let msg = MessageBatch::from_string("test message");
        let result = output.write(&msg).await;
        assert!(result.is_ok(), "Failed to write message to Kafka");

        let producer_guard = output.producer.read().await;
        let producer = producer_guard.as_ref().unwrap();
        let messages = producer.sent_messages.lock().await;
        assert_eq!(messages.len(), 1, "Message not sent to Kafka");
        assert_eq!(messages[0].0, "test-topic", "Wrong topic");
        assert_eq!(messages[0].1, b"test message", "Wrong message content");
        assert_eq!(messages[0].2, None, "Key should be None");
    }

    /// A configured key is attached to the produced message.
    #[tokio::test]
    async fn test_kafka_output_write_with_key() {
        let config = KafkaOutputConfig {
            brokers: vec!["localhost:9092".to_string()],
            topic: "test-topic".to_string(),
            key: Some("test-key".to_string()),
            client_id: None,
            compression: None,
            acks: None,
        };

        let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
        output.connect().await.unwrap();

        let msg = MessageBatch::from_string("test message");
        let result = output.write(&msg).await;
        assert!(result.is_ok(), "Failed to write message to Kafka");

        let producer_guard = output.producer.read().await;
        let producer = producer_guard.as_ref().unwrap();
        let messages = producer.sent_messages.lock().await;
        assert_eq!(messages.len(), 1, "Message not sent to Kafka");
        assert_eq!(messages[0].2, Some("test-key".to_string()), "Wrong key");
    }

    /// Writing before `connect` yields a Connection error.
    #[tokio::test]
    async fn test_kafka_output_write_without_connect() {
        let config = KafkaOutputConfig {
            brokers: vec!["localhost:9092".to_string()],
            topic: "test-topic".to_string(),
            key: None,
            client_id: None,
            compression: None,
            acks: None,
        };

        let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
        let msg = MessageBatch::from_string("test message");
        let result = output.write(&msg).await;

        assert!(result.is_err(), "Write should fail when not connected");
        // Specifically a Connection error, not just any error.
        match result {
            Err(Error::Connection(_)) => {} _ => panic!("Expected Connection error"),
        }
    }

    /// A failing mock producer surfaces as a write error.
    #[tokio::test]
    async fn test_kafka_output_write_failure() {
        let config = KafkaOutputConfig {
            brokers: vec!["localhost:9092".to_string()],
            topic: "test-topic".to_string(),
            key: None,
            client_id: None,
            compression: None,
            acks: None,
        };

        let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
        output.connect().await.unwrap();

        // Flip the mock into failure mode after a successful connect.
        let producer_guard = output.producer.read().await;
        let producer = producer_guard.as_ref().unwrap();
        producer.should_fail.store(true, Ordering::SeqCst);

        let msg = MessageBatch::from_string("test message");
        let result = output.write(&msg).await;
        assert!(result.is_err(), "Write should fail with producer error");
    }

    /// `close` flushes and clears the producer slot.
    #[tokio::test]
    async fn test_kafka_output_close() {
        let config = KafkaOutputConfig {
            brokers: vec!["localhost:9092".to_string()],
            topic: "test-topic".to_string(),
            key: None,
            client_id: None,
            compression: None,
            acks: None,
        };

        let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
        output.connect().await.unwrap();

        let result = output.close().await;
        assert!(result.is_ok(), "Failed to close Kafka connection");

        let producer_guard = output.producer.read().await;
        assert!(producer_guard.is_none(), "Kafka producer not cleared");
    }

    /// A failing flush during `close` propagates as an error.
    #[tokio::test]
    async fn test_kafka_output_close_failure() {
        let config = KafkaOutputConfig {
            brokers: vec!["localhost:9092".to_string()],
            topic: "test-topic".to_string(),
            key: None,
            client_id: None,
            compression: None,
            acks: None,
        };

        let output = KafkaOutput::<MockKafkaClient>::new(config).unwrap();
        output.connect().await.unwrap();

        // Scope the read guard so `close` can take the write lock afterwards.
        {
            let producer_guard = output.producer.read().await;
            let producer = producer_guard.as_ref().unwrap();
            producer.should_fail.store(true, Ordering::SeqCst);
        }

        let result = output.close().await;
        assert!(result.is_err(), "Close should fail with flush error");
    }
}