Skip to main content

alun_task/
worker.rs

//! Task consumer — consumes tasks from Kafka and dispatches them to handlers for execution.
2
3use std::sync::Arc;
4use std::time::Instant;
5use rdkafka::consumer::{StreamConsumer, Consumer};
6use rdkafka::ClientConfig;
7use rdkafka::message::BorrowedMessage;
8use rdkafka::producer::{FutureProducer, FutureRecord};
9use rdkafka::Message;
10use tokio::time::timeout;
11use tracing::{info, warn, error};
12
13use crate::storage::TaskStorage;
14use crate::HandlerRegistry;
15use crate::types::*;
16use crate::TaskMetrics;
17use crate::metrics::AtomicInc;
18
19/// 任务执行器
20///
21/// 从 Kafka 消费 `TaskMessage`,按 `task_type` 查找注册的 `TaskHandler`,
22/// 执行并委托 `TaskStorage` 记录结果、更新状态、处理重试和死信队列。
23/// 不持有任何 SQL 或表名——持久化完全交由 storage 代理。
24pub struct TaskWorker {
25    /// Kafka 消费者
26    consumer: Arc<StreamConsumer>,
27    /// 任务持久化接口
28    storage: Arc<dyn TaskStorage>,
29    /// 处理器注册中心
30    registry: HandlerRegistry,
31    /// 任务执行指标
32    metrics: Arc<TaskMetrics>,
33    /// TaskWorker 运行时配置
34    config: TaskWorkerConfig,
35    /// Kafka 生产者(用于发送死信消息)
36    producer: FutureProducer,
37    /// 运行状态标志
38    running: Arc<std::sync::atomic::AtomicBool>,
39}
40
41impl TaskWorker {
42    /// 创建任务执行器
43    pub fn new(
44        config: TaskWorkerConfig,
45        storage: Arc<dyn TaskStorage>,
46        registry: HandlerRegistry,
47        metrics: Arc<TaskMetrics>,
48    ) -> Result<Self, String> {
49        let consumer: StreamConsumer = ClientConfig::new()
50            .set("bootstrap.servers", &config.brokers)
51            .set("group.id", &config.group_id)
52            .set("enable.auto.commit", "false")
53            .set("auto.offset.reset", "earliest")
54            .set("session.timeout.ms", "30000")
55            .set("max.poll.interval.ms", "600000")
56            .create()
57            .map_err(|e| format!("Kafka Consumer 创建失败: {}", e))?;
58
59        let producer: FutureProducer = ClientConfig::new()
60            .set("bootstrap.servers", &config.brokers)
61            .set("message.timeout.ms", "5000")
62            .create()
63            .map_err(|e| format!("Kafka DLQ Producer 创建失败: {}", e))?;
64
65        Ok(Self {
66            consumer: Arc::new(consumer),
67            storage,
68            registry,
69            metrics,
70            config,
71            producer,
72            running: Arc::new(std::sync::atomic::AtomicBool::new(true)),
73        })
74    }
75
76    /// 订阅 topic 并启动消费循环
77    pub async fn run(&self, topics: &[String]) -> Result<(), String> {
78        self.consumer
79            .subscribe(&topics.iter().map(|s| s.as_str()).collect::<Vec<_>>())
80            .map_err(|e| format!("Kafka Consumer 订阅失败: {}", e))?;
81
82        info!("TaskWorker 启动,订阅 topics: {:?}", topics);
83
84        while self.running.load(std::sync::atomic::Ordering::Relaxed) {
85            match self.consumer.recv().await {
86                Err(e) => {
87                    error!("Kafka 接收消息失败: {}", e);
88                    tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
89                }
90                Ok(msg) => {
91                    self.handle_message(&msg).await;
92                }
93            }
94        }
95
96        info!("TaskWorker 已停止");
97        Ok(())
98    }
99
100    /// 停止消费循环
101    pub fn stop(&self) {
102        self.running.store(false, std::sync::atomic::Ordering::Relaxed);
103    }
104
105    /// 处理单条 Kafka 消息:反序列化、查找 handler、执行、更新状态
106    async fn handle_message(&self, msg: &BorrowedMessage<'_>) {
107        let payload = match msg.payload() {
108            Some(p) => p,
109            None => {
110                warn!("收到空消息");
111                return;
112            }
113        };
114
115        let task_msg: TaskMessage = match serde_json::from_slice(payload) {
116            Ok(m) => m,
117            Err(e) => {
118                error!("消息反序列化失败: {}", e);
119                return;
120            }
121        };
122
123        if let Err(reason) = self.check_message_age(&task_msg) {
124            warn!(task_id = %task_msg.task_id, reason = reason, "消息已过期,跳过");
125            self.commit_offset(msg).await;
126            return;
127        }
128
129        self.metrics.total.inc();
130
131        let (handler, config) = match self.registry.get(task_msg.task_type) {
132            Some(h) => h,
133            None => {
134                warn!(task_type = task_msg.task_type, "未找到 handler");
135                self.commit_offset(msg).await;
136                return;
137            }
138        };
139
140        let _ = self.storage.update_task_status(&task_msg.task_id, TaskStatus::Processing).await;
141
142        let started_at = Instant::now();
143
144        let result = if config.timeout_seconds > 0 {
145            match timeout(
146                tokio::time::Duration::from_secs(config.timeout_seconds),
147                handler.execute(task_msg.payload.clone()),
148            )
149            .await
150            {
151                Ok(r) => r,
152                Err(_) => Err(format!("任务超时 ({}s)", config.timeout_seconds)),
153            }
154        } else {
155            handler.execute(task_msg.payload.clone()).await
156        };
157
158        let elapsed_ms = started_at.elapsed().as_millis() as i64;
159
160        match result {
161            Ok(output) => {
162                self.handle_success(&task_msg, &output, elapsed_ms).await;
163            }
164            Err(e) => {
165                self.handle_failure(&task_msg, &e, &config, elapsed_ms).await;
166            }
167        }
168
169        self.commit_offset(msg).await;
170    }
171
172    /// 异步提交 Kafka 消息 offset
173    async fn commit_offset(&self, msg: &BorrowedMessage<'_>) {
174        if let Err(e) = self.consumer.commit_message(msg, rdkafka::consumer::CommitMode::Async) {
175            error!(error = %e, "Kafka offset 提交失败");
176        }
177    }
178
179    /// 检查消息是否超过最大时效,超过则返回错误
180    fn check_message_age(&self, msg: &TaskMessage) -> Result<(), String> {
181        let submitted = chrono::DateTime::parse_from_rfc3339(&msg.submitted_at)
182            .map_err(|e| format!("解析 submitted_at 失败: {}", e))?;
183        let age = chrono::Utc::now()
184            .signed_duration_since(submitted.with_timezone(&chrono::Utc))
185            .num_seconds();
186        if age > self.config.max_message_age_secs as i64 {
187            return Err(format!(
188                "消息已超过最大时效 {}s(实际 {}s)",
189                self.config.max_message_age_secs, age
190            ));
191        }
192        Ok(())
193    }
194
195    /// 处理任务执行成功:更新状态为 Completed、存储结果、记录执行日志
196    async fn handle_success(&self, msg: &TaskMessage, output: &serde_json::Value, elapsed_ms: i64) {
197        let _ = self.storage.update_task_status(&msg.task_id, TaskStatus::Completed).await;
198        let _ = self.storage.save_task_result(&msg.task_id, output).await;
199        let _ = self.storage.log_execution(&msg.task_id, TaskStatus::Completed, None, elapsed_ms).await;
200        self.metrics.completed.inc();
201        info!(task_id = %msg.task_id, elapsed_ms = elapsed_ms, "任务执行成功");
202    }
203
204    /// 处理任务执行失败:判断是否超过最大重试次数,决定是转入死信队列还是等待重试
205    async fn handle_failure(
206        &self,
207        msg: &TaskMessage,
208        err_msg: &str,
209        config: &TaskConfig,
210        elapsed_ms: i64,
211    ) {
212        self.metrics.failed.inc();
213
214        let current_retries = self.storage.get_retry_count(&msg.task_id).await.unwrap_or(0);
215        let attempt = current_retries + 1;
216
217        if attempt as u32 > config.max_retries {
218            if let Some(ref dlq_topic) = config.dead_letter_topic {
219                warn!(task_id = %msg.task_id, attempt = attempt, "转入死信队列");
220                let _ = self.send_to_dlq(msg, dlq_topic, err_msg).await;
221                let _ = self.storage.update_task_status(&msg.task_id, TaskStatus::DeadLetter).await;
222            } else {
223                warn!(task_id = %msg.task_id, attempt = attempt, "超过最大重试次数");
224                let _ = self.storage.update_task_status(&msg.task_id, TaskStatus::Failed).await;
225                let _ = self.storage.save_task_result(
226                    &msg.task_id,
227                    &serde_json::json!({"error": err_msg, "retries": attempt}),
228                ).await;
229            }
230            let _ = self.storage.log_execution(&msg.task_id, TaskStatus::Failed, Some(err_msg), elapsed_ms).await;
231        } else {
232            let _ = self.storage.update_retry(&msg.task_id, attempt).await;
233            let _ = self.storage.log_execution(&msg.task_id, TaskStatus::Failed, Some(err_msg), elapsed_ms).await;
234            info!(task_id = %msg.task_id, attempt = attempt, "任务失败,等待重试: {}", err_msg);
235        }
236    }
237
238    /// 发送任务到死信队列(DLQ)
239    async fn send_to_dlq(
240        &self,
241        msg: &TaskMessage,
242        dead_letter_topic: &str,
243        reason: &str,
244    ) -> Result<(), String> {
245        let payload = serde_json::to_vec(msg).map_err(|e| format!("DLQ 序列化失败: {}", e))?;
246        let record = FutureRecord::to(dead_letter_topic)
247            .key(&msg.task_id)
248            .payload(&payload);
249
250        self.producer
251            .send(record, std::time::Duration::from_secs(5))
252            .await
253            .map_err(|(e, _)| format!("DLQ 发送失败: {}", e))?;
254
255        info!(task_id = %msg.task_id, reason = reason, "任务已转入死信队列: {}", dead_letter_topic);
256        Ok(())
257    }
258}