Skip to main content

aster/background/
task_queue.rs

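//! Simple task queue implementation.
//!
//! Supports priorities, concurrency control and status tracking.
//!
//! # Features
//! - FIFO queue
//! - Priority support (high/normal/low)
//! - Concurrency control
//! - Status tracking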
10
11use chrono::{DateTime, Utc};
12use std::collections::HashMap;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::Arc;
16use tokio::sync::{Mutex, RwLock};
17
18use super::types::{QueueStatus, TaskPriority, TaskStatus, TaskType};
19
20/// 任务执行函数类型
21pub type TaskExecutor = Box<
22    dyn FnOnce() -> Pin<Box<dyn Future<Output = Result<serde_json::Value, String>> + Send>>
23        + Send
24        + Sync,
25>;
26
27/// 队列中的任务
28pub struct QueuedTask {
29    pub id: String,
30    pub task_type: TaskType,
31    pub priority: TaskPriority,
32    pub execute: Option<TaskExecutor>,
33    pub enqueue_time: DateTime<Utc>,
34    pub start_time: Option<DateTime<Utc>>,
35    pub end_time: Option<DateTime<Utc>>,
36    pub metadata: Option<HashMap<String, serde_json::Value>>,
37    pub status: TaskStatus,
38    pub result: Option<serde_json::Value>,
39    pub error: Option<String>,
40}
41
/// Configuration for [`SimpleTaskQueue`].
#[derive(Debug, Clone)]
pub struct TaskQueueOptions {
    /// Maximum number of tasks allowed to run at the same time.
    pub max_concurrent: usize,
}

impl Default for TaskQueueOptions {
    /// Allows up to 10 concurrently running tasks.
    fn default() -> Self {
        const DEFAULT_MAX_CONCURRENT: usize = 10;
        TaskQueueOptions {
            max_concurrent: DEFAULT_MAX_CONCURRENT,
        }
    }
}
53
54/// 任务队列回调
55pub type TaskCallback = Arc<dyn Fn(&QueuedTask) + Send + Sync>;
56
57/// 简单任务队列
58pub struct SimpleTaskQueue {
59    queue: Arc<Mutex<Vec<QueuedTask>>>,
60    running: Arc<RwLock<HashMap<String, QueuedTask>>>,
61    completed: Arc<RwLock<HashMap<String, QueuedTask>>>,
62    failed: Arc<RwLock<HashMap<String, QueuedTask>>>,
63    max_concurrent: usize,
64    on_task_start: Option<TaskCallback>,
65    on_task_complete: Option<TaskCallback>,
66    on_task_failed: Option<TaskCallback>,
67}
68
69impl SimpleTaskQueue {
70    /// 创建新的任务队列
71    pub fn new(options: TaskQueueOptions) -> Self {
72        Self {
73            queue: Arc::new(Mutex::new(Vec::new())),
74            running: Arc::new(RwLock::new(HashMap::new())),
75            completed: Arc::new(RwLock::new(HashMap::new())),
76            failed: Arc::new(RwLock::new(HashMap::new())),
77            max_concurrent: options.max_concurrent,
78            on_task_start: None,
79            on_task_complete: None,
80            on_task_failed: None,
81        }
82    }
83
84    /// 设置任务开始回调
85    pub fn set_on_task_start(&mut self, callback: TaskCallback) {
86        self.on_task_start = Some(callback);
87    }
88
89    /// 设置任务完成回调
90    pub fn set_on_task_complete(&mut self, callback: TaskCallback) {
91        self.on_task_complete = Some(callback);
92    }
93
94    /// 设置任务失败回调
95    pub fn set_on_task_failed(&mut self, callback: TaskCallback) {
96        self.on_task_failed = Some(callback);
97    }
98
99    /// 添加任务到队列
100    pub async fn enqueue(&self, mut task: QueuedTask) -> String {
101        task.status = TaskStatus::Pending;
102        task.enqueue_time = Utc::now();
103        let task_id = task.id.clone();
104
105        let mut queue = self.queue.lock().await;
106
107        // 按优先级插入
108        let insert_index = queue
109            .iter()
110            .position(|t| t.priority.order() > task.priority.order())
111            .unwrap_or(queue.len());
112
113        queue.insert(insert_index, task);
114        drop(queue);
115
116        // 尝试处理下一个任务
117        self.process_next().await;
118
119        task_id
120    }
121
122    /// 处理队列中的下一个任务
123    async fn process_next(&self) {
124        let running_count = self.running.read().await.len();
125        if running_count >= self.max_concurrent {
126            return;
127        }
128
129        let mut queue = self.queue.lock().await;
130        if queue.is_empty() {
131            return;
132        }
133
134        let mut task = queue.remove(0);
135        drop(queue);
136
137        // 更新任务状态
138        task.status = TaskStatus::Running;
139        task.start_time = Some(Utc::now());
140        let task_id = task.id.clone();
141
142        // 触发回调
143        if let Some(ref callback) = self.on_task_start {
144            callback(&task);
145        }
146
147        // 取出执行器
148        let executor = task.execute.take();
149        self.running.write().await.insert(task_id.clone(), task);
150
151        // 执行任务
152        if let Some(exec) = executor {
153            let running = Arc::clone(&self.running);
154            let completed = Arc::clone(&self.completed);
155            let failed = Arc::clone(&self.failed);
156            let on_complete = self.on_task_complete.clone();
157            let on_failed = self.on_task_failed.clone();
158
159            tokio::spawn(async move {
160                let result = exec().await;
161
162                if let Some(mut task) = running.write().await.remove(&task_id) {
163                    task.end_time = Some(Utc::now());
164
165                    match result {
166                        Ok(value) => {
167                            task.result = Some(value);
168                            task.status = TaskStatus::Completed;
169                            if let Some(cb) = on_complete {
170                                cb(&task);
171                            }
172                            completed.write().await.insert(task_id, task);
173                        }
174                        Err(e) => {
175                            task.error = Some(e);
176                            task.status = TaskStatus::Failed;
177                            if let Some(cb) = on_failed {
178                                cb(&task);
179                            }
180                            failed.write().await.insert(task_id, task);
181                        }
182                    }
183                }
184            });
185        }
186    }
187
188    /// 获取任务状态
189    pub async fn get_task(&self, task_id: &str) -> Option<TaskStatus> {
190        // 在队列中查找
191        if self.queue.lock().await.iter().any(|t| t.id == task_id) {
192            return Some(TaskStatus::Pending);
193        }
194        // 在运行中查找
195        if self.running.read().await.contains_key(task_id) {
196            return Some(TaskStatus::Running);
197        }
198        // 在已完成中查找
199        if self.completed.read().await.contains_key(task_id) {
200            return Some(TaskStatus::Completed);
201        }
202        // 在失败中查找
203        if self.failed.read().await.contains_key(task_id) {
204            return Some(TaskStatus::Failed);
205        }
206        None
207    }
208
209    /// 获取队列状态统计
210    pub async fn get_status(&self) -> QueueStatus {
211        let queued = self.queue.lock().await.len();
212        let running = self.running.read().await.len();
213        let completed = self.completed.read().await.len();
214        let failed = self.failed.read().await.len();
215
216        QueueStatus {
217            queued,
218            running,
219            completed,
220            failed,
221            capacity: self.max_concurrent,
222            available: self.max_concurrent.saturating_sub(running),
223        }
224    }
225
226    /// 取消队列中的任务
227    pub async fn cancel(&self, task_id: &str) -> bool {
228        let mut queue = self.queue.lock().await;
229        if let Some(pos) = queue.iter().position(|t| t.id == task_id) {
230            let mut task = queue.remove(pos);
231            task.status = TaskStatus::Cancelled;
232            return true;
233        }
234        false
235    }
236
237    /// 清空队列
238    pub async fn clear(&self) -> usize {
239        let mut queue = self.queue.lock().await;
240        let count = queue.len();
241        queue.clear();
242        count
243    }
244
245    /// 清理已完成的任务
246    pub async fn cleanup_completed(&self) -> usize {
247        let mut completed = self.completed.write().await;
248        let count = completed.len();
249        completed.clear();
250        count
251    }
252
253    /// 清理失败的任务
254    pub async fn cleanup_failed(&self) -> usize {
255        let mut failed = self.failed.write().await;
256        let count = failed.len();
257        failed.clear();
258        count
259    }
260
261    /// 获取按优先级分组的队列任务数
262    pub async fn get_queued_by_priority(&self) -> HashMap<TaskPriority, usize> {
263        let queue = self.queue.lock().await;
264        let mut counts = HashMap::new();
265        counts.insert(TaskPriority::High, 0);
266        counts.insert(TaskPriority::Normal, 0);
267        counts.insert(TaskPriority::Low, 0);
268
269        for task in queue.iter() {
270            *counts.entry(task.priority).or_insert(0) += 1;
271        }
272        counts
273    }
274}