use std::sync::Arc;
use std::time::Instant;
use rdkafka::consumer::{StreamConsumer, Consumer};
use rdkafka::ClientConfig;
use rdkafka::message::BorrowedMessage;
use rdkafka::producer::{FutureProducer, FutureRecord};
use rdkafka::Message;
use tokio::time::timeout;
use tracing::{info, warn, error};
use crate::storage::TaskStorage;
use crate::HandlerRegistry;
use crate::types::*;
use crate::TaskMetrics;
use crate::metrics::AtomicInc;
pub struct TaskWorker {
consumer: Arc<StreamConsumer>,
storage: Arc<dyn TaskStorage>,
registry: HandlerRegistry,
metrics: Arc<TaskMetrics>,
config: TaskWorkerConfig,
producer: FutureProducer,
running: Arc<std::sync::atomic::AtomicBool>,
}
impl TaskWorker {
pub fn new(
config: TaskWorkerConfig,
storage: Arc<dyn TaskStorage>,
registry: HandlerRegistry,
metrics: Arc<TaskMetrics>,
) -> Result<Self, String> {
let consumer: StreamConsumer = ClientConfig::new()
.set("bootstrap.servers", &config.brokers)
.set("group.id", &config.group_id)
.set("enable.auto.commit", "false")
.set("auto.offset.reset", "earliest")
.set("session.timeout.ms", "30000")
.set("max.poll.interval.ms", "600000")
.create()
.map_err(|e| format!("Kafka Consumer 创建失败: {}", e))?;
let producer: FutureProducer = ClientConfig::new()
.set("bootstrap.servers", &config.brokers)
.set("message.timeout.ms", "5000")
.create()
.map_err(|e| format!("Kafka DLQ Producer 创建失败: {}", e))?;
Ok(Self {
consumer: Arc::new(consumer),
storage,
registry,
metrics,
config,
producer,
running: Arc::new(std::sync::atomic::AtomicBool::new(true)),
})
}
pub async fn run(&self, topics: &[String]) -> Result<(), String> {
self.consumer
.subscribe(&topics.iter().map(|s| s.as_str()).collect::<Vec<_>>())
.map_err(|e| format!("Kafka Consumer 订阅失败: {}", e))?;
info!("TaskWorker 启动,订阅 topics: {:?}", topics);
while self.running.load(std::sync::atomic::Ordering::Relaxed) {
match self.consumer.recv().await {
Err(e) => {
error!("Kafka 接收消息失败: {}", e);
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
}
Ok(msg) => {
self.handle_message(&msg).await;
}
}
}
info!("TaskWorker 已停止");
Ok(())
}
pub fn stop(&self) {
self.running.store(false, std::sync::atomic::Ordering::Relaxed);
}
async fn handle_message(&self, msg: &BorrowedMessage<'_>) {
let payload = match msg.payload() {
Some(p) => p,
None => {
warn!("收到空消息");
return;
}
};
let task_msg: TaskMessage = match serde_json::from_slice(payload) {
Ok(m) => m,
Err(e) => {
error!("消息反序列化失败: {}", e);
return;
}
};
if let Err(reason) = self.check_message_age(&task_msg) {
warn!(task_id = %task_msg.task_id, reason = reason, "消息已过期,跳过");
self.commit_offset(msg).await;
return;
}
self.metrics.total.inc();
let (handler, config) = match self.registry.get(task_msg.task_type) {
Some(h) => h,
None => {
warn!(task_type = task_msg.task_type, "未找到 handler");
self.commit_offset(msg).await;
return;
}
};
let _ = self.storage.update_task_status(&task_msg.task_id, TaskStatus::Processing).await;
let started_at = Instant::now();
let result = if config.timeout_seconds > 0 {
match timeout(
tokio::time::Duration::from_secs(config.timeout_seconds),
handler.execute(task_msg.payload.clone()),
)
.await
{
Ok(r) => r,
Err(_) => Err(format!("任务超时 ({}s)", config.timeout_seconds)),
}
} else {
handler.execute(task_msg.payload.clone()).await
};
let elapsed_ms = started_at.elapsed().as_millis() as i64;
match result {
Ok(output) => {
self.handle_success(&task_msg, &output, elapsed_ms).await;
}
Err(e) => {
self.handle_failure(&task_msg, &e, &config, elapsed_ms).await;
}
}
self.commit_offset(msg).await;
}
async fn commit_offset(&self, msg: &BorrowedMessage<'_>) {
if let Err(e) = self.consumer.commit_message(msg, rdkafka::consumer::CommitMode::Async) {
error!(error = %e, "Kafka offset 提交失败");
}
}
fn check_message_age(&self, msg: &TaskMessage) -> Result<(), String> {
let submitted = chrono::DateTime::parse_from_rfc3339(&msg.submitted_at)
.map_err(|e| format!("解析 submitted_at 失败: {}", e))?;
let age = chrono::Utc::now()
.signed_duration_since(submitted.with_timezone(&chrono::Utc))
.num_seconds();
if age > self.config.max_message_age_secs as i64 {
return Err(format!(
"消息已超过最大时效 {}s(实际 {}s)",
self.config.max_message_age_secs, age
));
}
Ok(())
}
async fn handle_success(&self, msg: &TaskMessage, output: &serde_json::Value, elapsed_ms: i64) {
let _ = self.storage.update_task_status(&msg.task_id, TaskStatus::Completed).await;
let _ = self.storage.save_task_result(&msg.task_id, output).await;
let _ = self.storage.log_execution(&msg.task_id, TaskStatus::Completed, None, elapsed_ms).await;
self.metrics.completed.inc();
info!(task_id = %msg.task_id, elapsed_ms = elapsed_ms, "任务执行成功");
}
async fn handle_failure(
&self,
msg: &TaskMessage,
err_msg: &str,
config: &TaskConfig,
elapsed_ms: i64,
) {
self.metrics.failed.inc();
let current_retries = self.storage.get_retry_count(&msg.task_id).await.unwrap_or(0);
let attempt = current_retries + 1;
if attempt as u32 > config.max_retries {
if let Some(ref dlq_topic) = config.dead_letter_topic {
warn!(task_id = %msg.task_id, attempt = attempt, "转入死信队列");
let _ = self.send_to_dlq(msg, dlq_topic, err_msg).await;
let _ = self.storage.update_task_status(&msg.task_id, TaskStatus::DeadLetter).await;
} else {
warn!(task_id = %msg.task_id, attempt = attempt, "超过最大重试次数");
let _ = self.storage.update_task_status(&msg.task_id, TaskStatus::Failed).await;
let _ = self.storage.save_task_result(
&msg.task_id,
&serde_json::json!({"error": err_msg, "retries": attempt}),
).await;
}
let _ = self.storage.log_execution(&msg.task_id, TaskStatus::Failed, Some(err_msg), elapsed_ms).await;
} else {
let _ = self.storage.update_retry(&msg.task_id, attempt).await;
let _ = self.storage.log_execution(&msg.task_id, TaskStatus::Failed, Some(err_msg), elapsed_ms).await;
info!(task_id = %msg.task_id, attempt = attempt, "任务失败,等待重试: {}", err_msg);
}
}
async fn send_to_dlq(
&self,
msg: &TaskMessage,
dead_letter_topic: &str,
reason: &str,
) -> Result<(), String> {
let payload = serde_json::to_vec(msg).map_err(|e| format!("DLQ 序列化失败: {}", e))?;
let record = FutureRecord::to(dead_letter_topic)
.key(&msg.task_id)
.payload(&payload);
self.producer
.send(record, std::time::Duration::from_secs(5))
.await
.map_err(|(e, _)| format!("DLQ 发送失败: {}", e))?;
info!(task_id = %msg.task_id, reason = reason, "任务已转入死信队列: {}", dead_letter_topic);
Ok(())
}
}