1use std::sync::Arc;
4use std::time::Instant;
5use rdkafka::consumer::{StreamConsumer, Consumer};
6use rdkafka::ClientConfig;
7use rdkafka::message::BorrowedMessage;
8use rdkafka::producer::{FutureProducer, FutureRecord};
9use rdkafka::Message;
10use tokio::time::timeout;
11use tracing::{info, warn, error};
12
13use crate::storage::TaskStorage;
14use crate::HandlerRegistry;
15use crate::types::*;
16use crate::TaskMetrics;
17use crate::metrics::AtomicInc;
18
19pub struct TaskWorker {
25 consumer: Arc<StreamConsumer>,
27 storage: Arc<dyn TaskStorage>,
29 registry: HandlerRegistry,
31 metrics: Arc<TaskMetrics>,
33 config: TaskWorkerConfig,
35 producer: FutureProducer,
37 running: Arc<std::sync::atomic::AtomicBool>,
39}
40
41impl TaskWorker {
42 pub fn new(
44 config: TaskWorkerConfig,
45 storage: Arc<dyn TaskStorage>,
46 registry: HandlerRegistry,
47 metrics: Arc<TaskMetrics>,
48 ) -> Result<Self, String> {
49 let consumer: StreamConsumer = ClientConfig::new()
50 .set("bootstrap.servers", &config.brokers)
51 .set("group.id", &config.group_id)
52 .set("enable.auto.commit", "false")
53 .set("auto.offset.reset", "earliest")
54 .set("session.timeout.ms", "30000")
55 .set("max.poll.interval.ms", "600000")
56 .create()
57 .map_err(|e| format!("Kafka Consumer 创建失败: {}", e))?;
58
59 let producer: FutureProducer = ClientConfig::new()
60 .set("bootstrap.servers", &config.brokers)
61 .set("message.timeout.ms", "5000")
62 .create()
63 .map_err(|e| format!("Kafka DLQ Producer 创建失败: {}", e))?;
64
65 Ok(Self {
66 consumer: Arc::new(consumer),
67 storage,
68 registry,
69 metrics,
70 config,
71 producer,
72 running: Arc::new(std::sync::atomic::AtomicBool::new(true)),
73 })
74 }
75
76 pub async fn run(&self, topics: &[String]) -> Result<(), String> {
78 self.consumer
79 .subscribe(&topics.iter().map(|s| s.as_str()).collect::<Vec<_>>())
80 .map_err(|e| format!("Kafka Consumer 订阅失败: {}", e))?;
81
82 info!("TaskWorker 启动,订阅 topics: {:?}", topics);
83
84 while self.running.load(std::sync::atomic::Ordering::Relaxed) {
85 match self.consumer.recv().await {
86 Err(e) => {
87 error!("Kafka 接收消息失败: {}", e);
88 tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
89 }
90 Ok(msg) => {
91 self.handle_message(&msg).await;
92 }
93 }
94 }
95
96 info!("TaskWorker 已停止");
97 Ok(())
98 }
99
100 pub fn stop(&self) {
102 self.running.store(false, std::sync::atomic::Ordering::Relaxed);
103 }
104
105 async fn handle_message(&self, msg: &BorrowedMessage<'_>) {
107 let payload = match msg.payload() {
108 Some(p) => p,
109 None => {
110 warn!("收到空消息");
111 return;
112 }
113 };
114
115 let task_msg: TaskMessage = match serde_json::from_slice(payload) {
116 Ok(m) => m,
117 Err(e) => {
118 error!("消息反序列化失败: {}", e);
119 return;
120 }
121 };
122
123 if let Err(reason) = self.check_message_age(&task_msg) {
124 warn!(task_id = %task_msg.task_id, reason = reason, "消息已过期,跳过");
125 self.commit_offset(msg).await;
126 return;
127 }
128
129 self.metrics.total.inc();
130
131 let (handler, config) = match self.registry.get(task_msg.task_type) {
132 Some(h) => h,
133 None => {
134 warn!(task_type = task_msg.task_type, "未找到 handler");
135 self.commit_offset(msg).await;
136 return;
137 }
138 };
139
140 let _ = self.storage.update_task_status(&task_msg.task_id, TaskStatus::Processing).await;
141
142 let started_at = Instant::now();
143
144 let result = if config.timeout_seconds > 0 {
145 match timeout(
146 tokio::time::Duration::from_secs(config.timeout_seconds),
147 handler.execute(task_msg.payload.clone()),
148 )
149 .await
150 {
151 Ok(r) => r,
152 Err(_) => Err(format!("任务超时 ({}s)", config.timeout_seconds)),
153 }
154 } else {
155 handler.execute(task_msg.payload.clone()).await
156 };
157
158 let elapsed_ms = started_at.elapsed().as_millis() as i64;
159
160 match result {
161 Ok(output) => {
162 self.handle_success(&task_msg, &output, elapsed_ms).await;
163 }
164 Err(e) => {
165 self.handle_failure(&task_msg, &e, &config, elapsed_ms).await;
166 }
167 }
168
169 self.commit_offset(msg).await;
170 }
171
172 async fn commit_offset(&self, msg: &BorrowedMessage<'_>) {
174 if let Err(e) = self.consumer.commit_message(msg, rdkafka::consumer::CommitMode::Async) {
175 error!(error = %e, "Kafka offset 提交失败");
176 }
177 }
178
179 fn check_message_age(&self, msg: &TaskMessage) -> Result<(), String> {
181 let submitted = chrono::DateTime::parse_from_rfc3339(&msg.submitted_at)
182 .map_err(|e| format!("解析 submitted_at 失败: {}", e))?;
183 let age = chrono::Utc::now()
184 .signed_duration_since(submitted.with_timezone(&chrono::Utc))
185 .num_seconds();
186 if age > self.config.max_message_age_secs as i64 {
187 return Err(format!(
188 "消息已超过最大时效 {}s(实际 {}s)",
189 self.config.max_message_age_secs, age
190 ));
191 }
192 Ok(())
193 }
194
195 async fn handle_success(&self, msg: &TaskMessage, output: &serde_json::Value, elapsed_ms: i64) {
197 let _ = self.storage.update_task_status(&msg.task_id, TaskStatus::Completed).await;
198 let _ = self.storage.save_task_result(&msg.task_id, output).await;
199 let _ = self.storage.log_execution(&msg.task_id, TaskStatus::Completed, None, elapsed_ms).await;
200 self.metrics.completed.inc();
201 info!(task_id = %msg.task_id, elapsed_ms = elapsed_ms, "任务执行成功");
202 }
203
204 async fn handle_failure(
206 &self,
207 msg: &TaskMessage,
208 err_msg: &str,
209 config: &TaskConfig,
210 elapsed_ms: i64,
211 ) {
212 self.metrics.failed.inc();
213
214 let current_retries = self.storage.get_retry_count(&msg.task_id).await.unwrap_or(0);
215 let attempt = current_retries + 1;
216
217 if attempt as u32 > config.max_retries {
218 if let Some(ref dlq_topic) = config.dead_letter_topic {
219 warn!(task_id = %msg.task_id, attempt = attempt, "转入死信队列");
220 let _ = self.send_to_dlq(msg, dlq_topic, err_msg).await;
221 let _ = self.storage.update_task_status(&msg.task_id, TaskStatus::DeadLetter).await;
222 } else {
223 warn!(task_id = %msg.task_id, attempt = attempt, "超过最大重试次数");
224 let _ = self.storage.update_task_status(&msg.task_id, TaskStatus::Failed).await;
225 let _ = self.storage.save_task_result(
226 &msg.task_id,
227 &serde_json::json!({"error": err_msg, "retries": attempt}),
228 ).await;
229 }
230 let _ = self.storage.log_execution(&msg.task_id, TaskStatus::Failed, Some(err_msg), elapsed_ms).await;
231 } else {
232 let _ = self.storage.update_retry(&msg.task_id, attempt).await;
233 let _ = self.storage.log_execution(&msg.task_id, TaskStatus::Failed, Some(err_msg), elapsed_ms).await;
234 info!(task_id = %msg.task_id, attempt = attempt, "任务失败,等待重试: {}", err_msg);
235 }
236 }
237
238 async fn send_to_dlq(
240 &self,
241 msg: &TaskMessage,
242 dead_letter_topic: &str,
243 reason: &str,
244 ) -> Result<(), String> {
245 let payload = serde_json::to_vec(msg).map_err(|e| format!("DLQ 序列化失败: {}", e))?;
246 let record = FutureRecord::to(dead_letter_topic)
247 .key(&msg.task_id)
248 .payload(&payload);
249
250 self.producer
251 .send(record, std::time::Duration::from_secs(5))
252 .await
253 .map_err(|(e, _)| format!("DLQ 发送失败: {}", e))?;
254
255 info!(task_id = %msg.task_id, reason = reason, "任务已转入死信队列: {}", dead_letter_topic);
256 Ok(())
257 }
258}