cool_task/
worker.rs

1//! Worker 实现
2
3use crate::job::JobHandlerFactory;
4use crate::queue::Queue;
5use parking_lot::RwLock;
6use std::collections::HashMap;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::Semaphore;
10
/// Settings that control how a worker polls its queue and runs jobs.
#[derive(Debug, Clone)]
pub struct WorkerConfig {
    /// Maximum number of jobs processed at the same time.
    pub concurrency: usize,
    /// Delay between queue polls, in milliseconds.
    pub poll_interval: u64,
    /// Whether the worker should be started automatically.
    /// NOTE(review): not read anywhere in this file; presumably consumed by
    /// whoever constructs the worker — confirm.
    pub auto_start: bool,
}

impl Default for WorkerConfig {
    /// Four concurrent jobs, a one-second (1000 ms) poll interval, auto-start on.
    fn default() -> Self {
        const CONCURRENCY: usize = 4;
        const POLL_INTERVAL_MS: u64 = 1_000;
        const AUTO_START: bool = true;

        WorkerConfig {
            concurrency: CONCURRENCY,
            poll_interval: POLL_INTERVAL_MS,
            auto_start: AUTO_START,
        }
    }
}
31
32/// Worker
33pub struct Worker {
34    /// 队列
35    queue: Arc<Queue>,
36    /// 配置
37    config: WorkerConfig,
38    /// 任务处理器工厂映射
39    handlers: Arc<RwLock<HashMap<String, Box<dyn JobHandlerFactory>>>>,
40    /// 并发控制信号量
41    semaphore: Arc<Semaphore>,
42    /// 是否运行中
43    running: Arc<RwLock<bool>>,
44}
45
46impl Worker {
47    /// 创建 Worker
48    pub fn new(queue: Queue, config: WorkerConfig) -> Self {
49        let concurrency = config.concurrency;
50        Self {
51            queue: Arc::new(queue),
52            config,
53            handlers: Arc::new(RwLock::new(HashMap::new())),
54            semaphore: Arc::new(Semaphore::new(concurrency)),
55            running: Arc::new(RwLock::new(false)),
56        }
57    }
58
59    /// 注册任务处理器
60    pub fn register<F: JobHandlerFactory + 'static>(&self, factory: F) {
61        let mut handlers = self.handlers.write();
62        handlers.insert(factory.name().to_string(), Box::new(factory));
63    }
64
65    /// 启动 Worker
66    pub async fn start(&self) {
67        {
68            let mut running = self.running.write();
69            if *running {
70                return;
71            }
72            *running = true;
73        }
74
75        tracing::info!("Worker 已启动,并发数: {}", self.config.concurrency);
76
77        let queue = Arc::clone(&self.queue);
78        let handlers = Arc::clone(&self.handlers);
79        let semaphore = Arc::clone(&self.semaphore);
80        let running = Arc::clone(&self.running);
81        let poll_interval = self.config.poll_interval;
82
83        tokio::spawn(async move {
84            loop {
85                // 检查是否停止
86                if !*running.read() {
87                    break;
88                }
89
90                // 检查队列是否暂停
91                if let Ok(paused) = queue.is_paused().await {
92                    if paused {
93                        tokio::time::sleep(Duration::from_millis(poll_interval)).await;
94                        continue;
95                    }
96                }
97
98                // 获取信号量
99                let permit = match semaphore.clone().try_acquire_owned() {
100                    Ok(permit) => permit,
101                    Err(_) => {
102                        tokio::time::sleep(Duration::from_millis(100)).await;
103                        continue;
104                    }
105                };
106
107                // 获取下一个任务
108                let job = match queue.get_next_job().await {
109                    Ok(Some(job)) => job,
110                    Ok(None) => {
111                        drop(permit);
112                        tokio::time::sleep(Duration::from_millis(poll_interval)).await;
113                        continue;
114                    }
115                    Err(e) => {
116                        tracing::error!("获取任务失败: {}", e);
117                        drop(permit);
118                        tokio::time::sleep(Duration::from_millis(poll_interval)).await;
119                        continue;
120                    }
121                };
122
123                // 查找处理器
124                let handler = {
125                    let handlers = handlers.read();
126                    handlers.get(&job.name).map(|f| f.create(job.data.clone()))
127                };
128
129                let queue_clone = Arc::clone(&queue);
130                let mut job_clone = job.clone();
131
132                // 异步处理任务
133                tokio::spawn(async move {
134                    let _permit = permit; // 保持 permit 直到任务完成
135
136                    match handler {
137                        Some(handler) => {
138                            tracing::info!("开始处理任务: {} ({})", job_clone.name, job_clone.id);
139
140                            // 设置超时
141                            let timeout_duration = Duration::from_secs(job_clone.options.timeout);
142                            let result =
143                                tokio::time::timeout(timeout_duration, handler.handle()).await;
144
145                            match result {
146                                Ok(Ok(value)) => {
147                                    tracing::info!(
148                                        "任务完成: {} ({})",
149                                        job_clone.name,
150                                        job_clone.id
151                                    );
152                                    handler.on_completed(&value).await;
153                                    let _ =
154                                        queue_clone.complete_job(&mut job_clone, Some(value)).await;
155                                }
156                                Ok(Err(e)) => {
157                                    tracing::error!(
158                                        "任务失败: {} ({}) - {}",
159                                        job_clone.name,
160                                        job_clone.id,
161                                        e
162                                    );
163                                    handler.on_failed(&e).await;
164                                    let _ =
165                                        queue_clone.fail_job(&mut job_clone, &e.to_string()).await;
166                                }
167                                Err(_) => {
168                                    tracing::error!(
169                                        "任务超时: {} ({})",
170                                        job_clone.name,
171                                        job_clone.id
172                                    );
173                                    let _ =
174                                        queue_clone.fail_job(&mut job_clone, "任务执行超时").await;
175                                }
176                            }
177                        }
178                        None => {
179                            tracing::error!("未找到任务处理器: {}", job_clone.name);
180                            let _ = queue_clone
181                                .fail_job(&mut job_clone, "未找到任务处理器")
182                                .await;
183                        }
184                    }
185                });
186            }
187
188            tracing::info!("Worker 已停止");
189        });
190    }
191
192    /// 停止 Worker
193    pub fn stop(&self) {
194        let mut running = self.running.write();
195        *running = false;
196    }
197
198    /// 是否运行中
199    pub fn is_running(&self) -> bool {
200        *self.running.read()
201    }
202}