use crate::input::{Ack, Input};
use crate::{Error, MessageBatch};
use async_trait::async_trait;
use rdkafka::config::ClientConfig;
use rdkafka::consumer::{Consumer, StreamConsumer};
use rdkafka::message::Message as KafkaMessage;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
/// Configuration for a Kafka input source.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct KafkaInputConfig {
// Broker addresses (host:port); joined with "," into `bootstrap.servers`.
pub brokers: Vec<String>,
// Topics to subscribe to.
pub topics: Vec<String>,
// Consumer group id (`group.id`).
pub consumer_group: String,
// Optional `client.id`; omitted from the client config when `None`.
pub client_id: Option<String>,
// When true, `auto.offset.reset` is "latest"; otherwise "earliest".
pub start_from_latest: bool,
}
/// Kafka input component: owns the (lazily created) stream consumer.
pub struct KafkaInput {
config: KafkaInputConfig,
// `None` until `connect` succeeds; shared with each `KafkaAck` so acks
// can store offsets on the live consumer.
consumer: Arc<RwLock<Option<StreamConsumer>>>,
}
impl KafkaInput {
pub fn new(config: &KafkaInputConfig) -> Result<Self, Error> {
Ok(Self {
config: config.clone(),
consumer: Arc::new(RwLock::new(None)),
})
}
}
#[async_trait]
impl Input for KafkaInput {
async fn connect(&self) -> Result<(), Error> {
let mut client_config = ClientConfig::new();
client_config.set("bootstrap.servers", &self.config.brokers.join(","));
client_config.set("group.id", &self.config.consumer_group);
if let Some(client_id) = &self.config.client_id {
client_config.set("client.id", client_id);
}
if self.config.start_from_latest {
client_config.set("auto.offset.reset", "latest");
} else {
client_config.set("auto.offset.reset", "earliest");
}
let consumer: StreamConsumer = client_config
.create()
.map_err(|e| Error::Connection(format!("无法创建Kafka消费者: {}", e)))?;
let x: Vec<&str> = self
.config
.topics
.iter()
.map(|topic| topic.as_str())
.collect();
consumer
.subscribe(&x)
.map_err(|e| Error::Connection(format!("无法订阅Kafka主题: {}", e)))?;
let consumer_arc = self.consumer.clone();
let mut consumer_guard = consumer_arc.write().await;
*consumer_guard = Some(consumer);
Ok(())
}
async fn read(&self) -> Result<(MessageBatch, Arc<dyn Ack>), Error> {
let consumer_arc = self.consumer.clone();
let consumer_guard = consumer_arc.read().await;
if consumer_guard.is_none() {
return Err(Error::Connection("输入未连接".to_string()));
}
let consumer = consumer_guard.as_ref().unwrap();
match consumer.recv().await {
Ok(kafka_message) => {
let payload = kafka_message
.payload()
.ok_or_else(|| Error::Processing("Kafka消息没有内容".to_string()))?;
let mut binary_data = Vec::new();
binary_data.push(payload.to_vec());
let msg_batch = MessageBatch::new_binary(binary_data);
let topic = kafka_message.topic().to_string();
let partition = kafka_message.partition();
let offset = kafka_message.offset();
let ack = KafkaAck {
consumer: self.consumer.clone(),
topic,
partition,
offset,
};
Ok((msg_batch, Arc::new(ack)))
}
Err(e) => Err(Error::Connection(format!("Kafka消息接收错误: {}", e))),
}
}
async fn close(&self) -> Result<(), Error> {
let mut consumer_guard = self.consumer.write().await;
if let Some(consumer) = consumer_guard.take() {
if let Err(e) = consumer.unassign() {
tracing::warn!("无法取消Kafka消费者分配: {}", e);
}
}
Ok(())
}
}
/// Ack handle for one Kafka message: remembers the message's coordinates and
/// a shared reference to the consumer so the offset can be stored on ack.
pub struct KafkaAck {
consumer: Arc<RwLock<Option<StreamConsumer>>>,
topic: String,
partition: i32,
offset: i64,
}
#[async_trait]
impl Ack for KafkaAck {
    /// Stores this message's offset on the consumer; the stored offset is
    /// committed later by librdkafka. A failure is logged, not propagated.
    /// If the input was already closed, the ack is silently a no-op.
    async fn ack(&self) {
        let guard = self.consumer.read().await;
        let Some(consumer) = guard.as_ref() else {
            return;
        };
        // NOTE(review): librdkafka's offset store adds +1 internally, so
        // passing the message's own offset should resume correctly — confirm
        // against the rdkafka version in use.
        if let Err(e) = consumer.store_offset(&self.topic, self.partition, self.offset) {
            tracing::error!("无法提交Kafka偏移量: {}", e);
        }
    }
}