aster/background/
task_queue.rs1use chrono::{DateTime, Utc};
12use std::collections::HashMap;
13use std::future::Future;
14use std::pin::Pin;
15use std::sync::Arc;
16use tokio::sync::{Mutex, RwLock};
17
18use super::types::{QueueStatus, TaskPriority, TaskStatus, TaskType};
19
/// Boxed, one-shot closure that produces the future performing a task's work.
///
/// The returned future resolves to the task's JSON result on success or to an
/// error message on failure. The closure is `FnOnce`, so it can be invoked at
/// most once; the queue takes it out of the task when the task is started.
pub type TaskExecutor = Box<
    dyn FnOnce() -> Pin<Box<dyn Future<Output = Result<serde_json::Value, String>> + Send>>
        + Send
        + Sync,
>;
26
/// A unit of work tracked by [`SimpleTaskQueue`], together with its
/// bookkeeping (timestamps, status, outcome).
pub struct QueuedTask {
    /// Identifier used as the key in the queue's running/completed/failed maps
    /// and for lookups and cancellation.
    pub id: String,
    /// Kind of task; not interpreted by the queue logic in this file.
    pub task_type: TaskType,
    /// Scheduling priority; the queue orders pending tasks by `priority.order()`.
    pub priority: TaskPriority,
    /// The work to run. The queue takes this out (leaving `None`) when the
    /// task is started.
    pub execute: Option<TaskExecutor>,
    /// When the task was enqueued; overwritten with the current time by
    /// `SimpleTaskQueue::enqueue`.
    pub enqueue_time: DateTime<Utc>,
    /// When the task was started, once it has been dispatched.
    pub start_time: Option<DateTime<Utc>>,
    /// When the task's executor finished, whether it succeeded or failed.
    pub end_time: Option<DateTime<Utc>>,
    /// Optional caller-supplied key/value data; not read by the queue logic
    /// in this file.
    pub metadata: Option<HashMap<String, serde_json::Value>>,
    /// Lifecycle state; reset to `Pending` on enqueue and updated as the task
    /// is started and finishes.
    pub status: TaskStatus,
    /// Value returned by the executor when it succeeded.
    pub result: Option<serde_json::Value>,
    /// Error message returned by the executor when it failed.
    pub error: Option<String>,
}
41
/// Construction-time settings for a task queue.
#[derive(Debug, Clone)]
pub struct TaskQueueOptions {
    /// Maximum number of tasks the queue keeps running at the same time.
    pub max_concurrent: usize,
}

impl Default for TaskQueueOptions {
    /// Returns the default settings, which allow ten concurrent tasks.
    fn default() -> Self {
        const DEFAULT_MAX_CONCURRENT: usize = 10;

        TaskQueueOptions {
            max_concurrent: DEFAULT_MAX_CONCURRENT,
        }
    }
}
53
/// Shared, thread-safe callback that receives a reference to the task a
/// lifecycle event (start, completion, failure) refers to.
pub type TaskCallback = Arc<dyn Fn(&QueuedTask) + Send + Sync>;
56
/// Priority-ordered async task queue with a cap on concurrently running tasks.
///
/// Tasks move from `queue` to `running` when started and then to either
/// `completed` or `failed` depending on the executor's result.
pub struct SimpleTaskQueue {
    /// Pending tasks, kept ordered by ascending `priority.order()`; tasks with
    /// equal priority stay in insertion order.
    queue: Arc<Mutex<Vec<QueuedTask>>>,
    /// Tasks that have been started and not yet finished, keyed by task id.
    running: Arc<RwLock<HashMap<String, QueuedTask>>>,
    /// Tasks whose executor returned `Ok`, keyed by task id.
    completed: Arc<RwLock<HashMap<String, QueuedTask>>>,
    /// Tasks whose executor returned `Err`, keyed by task id.
    failed: Arc<RwLock<HashMap<String, QueuedTask>>>,
    /// Maximum number of entries allowed in `running` before dispatch stops.
    max_concurrent: usize,
    /// Invoked with the task just before it is recorded as running.
    on_task_start: Option<TaskCallback>,
    /// Invoked with the task after its executor succeeded.
    on_task_complete: Option<TaskCallback>,
    /// Invoked with the task after its executor failed.
    on_task_failed: Option<TaskCallback>,
}
68
69impl SimpleTaskQueue {
70 pub fn new(options: TaskQueueOptions) -> Self {
72 Self {
73 queue: Arc::new(Mutex::new(Vec::new())),
74 running: Arc::new(RwLock::new(HashMap::new())),
75 completed: Arc::new(RwLock::new(HashMap::new())),
76 failed: Arc::new(RwLock::new(HashMap::new())),
77 max_concurrent: options.max_concurrent,
78 on_task_start: None,
79 on_task_complete: None,
80 on_task_failed: None,
81 }
82 }
83
84 pub fn set_on_task_start(&mut self, callback: TaskCallback) {
86 self.on_task_start = Some(callback);
87 }
88
89 pub fn set_on_task_complete(&mut self, callback: TaskCallback) {
91 self.on_task_complete = Some(callback);
92 }
93
94 pub fn set_on_task_failed(&mut self, callback: TaskCallback) {
96 self.on_task_failed = Some(callback);
97 }
98
99 pub async fn enqueue(&self, mut task: QueuedTask) -> String {
101 task.status = TaskStatus::Pending;
102 task.enqueue_time = Utc::now();
103 let task_id = task.id.clone();
104
105 let mut queue = self.queue.lock().await;
106
107 let insert_index = queue
109 .iter()
110 .position(|t| t.priority.order() > task.priority.order())
111 .unwrap_or(queue.len());
112
113 queue.insert(insert_index, task);
114 drop(queue);
115
116 self.process_next().await;
118
119 task_id
120 }
121
122 async fn process_next(&self) {
124 let running_count = self.running.read().await.len();
125 if running_count >= self.max_concurrent {
126 return;
127 }
128
129 let mut queue = self.queue.lock().await;
130 if queue.is_empty() {
131 return;
132 }
133
134 let mut task = queue.remove(0);
135 drop(queue);
136
137 task.status = TaskStatus::Running;
139 task.start_time = Some(Utc::now());
140 let task_id = task.id.clone();
141
142 if let Some(ref callback) = self.on_task_start {
144 callback(&task);
145 }
146
147 let executor = task.execute.take();
149 self.running.write().await.insert(task_id.clone(), task);
150
151 if let Some(exec) = executor {
153 let running = Arc::clone(&self.running);
154 let completed = Arc::clone(&self.completed);
155 let failed = Arc::clone(&self.failed);
156 let on_complete = self.on_task_complete.clone();
157 let on_failed = self.on_task_failed.clone();
158
159 tokio::spawn(async move {
160 let result = exec().await;
161
162 if let Some(mut task) = running.write().await.remove(&task_id) {
163 task.end_time = Some(Utc::now());
164
165 match result {
166 Ok(value) => {
167 task.result = Some(value);
168 task.status = TaskStatus::Completed;
169 if let Some(cb) = on_complete {
170 cb(&task);
171 }
172 completed.write().await.insert(task_id, task);
173 }
174 Err(e) => {
175 task.error = Some(e);
176 task.status = TaskStatus::Failed;
177 if let Some(cb) = on_failed {
178 cb(&task);
179 }
180 failed.write().await.insert(task_id, task);
181 }
182 }
183 }
184 });
185 }
186 }
187
188 pub async fn get_task(&self, task_id: &str) -> Option<TaskStatus> {
190 if self.queue.lock().await.iter().any(|t| t.id == task_id) {
192 return Some(TaskStatus::Pending);
193 }
194 if self.running.read().await.contains_key(task_id) {
196 return Some(TaskStatus::Running);
197 }
198 if self.completed.read().await.contains_key(task_id) {
200 return Some(TaskStatus::Completed);
201 }
202 if self.failed.read().await.contains_key(task_id) {
204 return Some(TaskStatus::Failed);
205 }
206 None
207 }
208
209 pub async fn get_status(&self) -> QueueStatus {
211 let queued = self.queue.lock().await.len();
212 let running = self.running.read().await.len();
213 let completed = self.completed.read().await.len();
214 let failed = self.failed.read().await.len();
215
216 QueueStatus {
217 queued,
218 running,
219 completed,
220 failed,
221 capacity: self.max_concurrent,
222 available: self.max_concurrent.saturating_sub(running),
223 }
224 }
225
226 pub async fn cancel(&self, task_id: &str) -> bool {
228 let mut queue = self.queue.lock().await;
229 if let Some(pos) = queue.iter().position(|t| t.id == task_id) {
230 let mut task = queue.remove(pos);
231 task.status = TaskStatus::Cancelled;
232 return true;
233 }
234 false
235 }
236
237 pub async fn clear(&self) -> usize {
239 let mut queue = self.queue.lock().await;
240 let count = queue.len();
241 queue.clear();
242 count
243 }
244
245 pub async fn cleanup_completed(&self) -> usize {
247 let mut completed = self.completed.write().await;
248 let count = completed.len();
249 completed.clear();
250 count
251 }
252
253 pub async fn cleanup_failed(&self) -> usize {
255 let mut failed = self.failed.write().await;
256 let count = failed.len();
257 failed.clear();
258 count
259 }
260
261 pub async fn get_queued_by_priority(&self) -> HashMap<TaskPriority, usize> {
263 let queue = self.queue.lock().await;
264 let mut counts = HashMap::new();
265 counts.insert(TaskPriority::High, 0);
266 counts.insert(TaskPriority::Normal, 0);
267 counts.insert(TaskPriority::Low, 0);
268
269 for task in queue.iter() {
270 *counts.entry(task.priority).or_insert(0) += 1;
271 }
272 counts
273 }
274}