1use async_trait::async_trait;
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use std::collections::{BinaryHeap, HashMap};
9use std::time::Duration;
10
11use crate::types::{Layer2Result, TaskId};
12
/// Lifecycle state of a [`Task`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
pub enum TaskStatus {
    /// Not yet started. `Task::new` creates tasks in this state, and it is the
    /// `Default` variant.
    #[default]
    Pending,
    /// Executing. `TaskManager::update_status` stamps `started_at` on entry.
    Running,
    /// Finished successfully. `TaskManager::update_status` stamps `completed_at`.
    Completed,
    /// Finished unsuccessfully. `TaskManager::update_status` stamps `completed_at`.
    Failed,
    /// Cancelled via `TaskManagerTrait::cancel` (no timestamp is recorded).
    Cancelled,
}
23
/// Scheduling priority of a [`Task`].
///
/// The derived `Ord` ranks `Low < Normal < High < Urgent`; the explicit
/// discriminants follow the same order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
pub enum TaskPriority {
    /// Lowest priority.
    Low = 0,
    /// Default priority; also what `Task::new` assigns.
    #[default]
    Normal = 1,
    /// Above normal.
    High = 2,
    /// Highest priority.
    Urgent = 3,
}
33
/// A unit of work tracked by [`TaskManager`].
///
/// Equality is by `id` only; ordering considers `priority` and then
/// `created_at` (see the `PartialEq` / `Ord` impls for `Task`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
    /// Identifier assigned by `Task::new`; the only field used for equality.
    pub id: TaskId,
    /// Task name, as passed to `Task::new`.
    pub name: String,
    /// Task description, as passed to `Task::new`.
    pub description: String,
    /// Current lifecycle state; `Task::new` starts at `Pending`.
    pub status: TaskStatus,
    /// Scheduling priority; `Task::new` starts at `Normal`.
    pub priority: TaskPriority,
    /// Ids of tasks that must be `Completed` before `can_execute` returns `true`.
    pub dependencies: Vec<TaskId>,
    /// UTC time at which `Task::new` ran.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// Set when `TaskManager::update_status` moves the task to `Running`.
    pub started_at: Option<chrono::DateTime<chrono::Utc>>,
    /// Set when `TaskManager::update_status` moves the task to `Completed` or `Failed`.
    pub completed_at: Option<chrono::DateTime<chrono::Utc>>,
    /// Optional time limit.
    /// NOTE(review): only stored here; nothing in this module enforces it.
    pub timeout: Option<Duration>,
    /// Retries performed so far; starts at 0.
    /// NOTE(review): nothing in this module increments it.
    pub retry_count: u32,
    /// Retry budget; `Task::new` sets 3.
    /// NOTE(review): not consulted anywhere in this module.
    pub max_retries: u32,
    /// Arbitrary JSON values keyed by name, populated via `with_metadata`.
    pub metadata: HashMap<String, serde_json::Value>,
}
51
52impl Task {
53 pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
54 Self {
55 id: TaskId::new(),
56 name: name.into(),
57 description: description.into(),
58 status: TaskStatus::Pending,
59 priority: TaskPriority::Normal,
60 dependencies: Vec::new(),
61 created_at: chrono::Utc::now(),
62 started_at: None,
63 completed_at: None,
64 timeout: None,
65 retry_count: 0,
66 max_retries: 3,
67 metadata: HashMap::new(),
68 }
69 }
70
71 pub fn with_priority(mut self, priority: TaskPriority) -> Self {
72 self.priority = priority;
73 self
74 }
75
76 pub fn with_timeout(mut self, timeout: Duration) -> Self {
77 self.timeout = Some(timeout);
78 self
79 }
80
81 pub fn with_dependency(mut self, task_id: TaskId) -> Self {
82 self.dependencies.push(task_id);
83 self
84 }
85
86 pub fn with_metadata(mut self, key: &str, value: serde_json::Value) -> Self {
87 self.metadata.insert(key.to_string(), value);
88 self
89 }
90
91 pub fn can_execute(&self, completed: &HashMap<TaskId, TaskStatus>) -> bool {
93 self.dependencies
94 .iter()
95 .all(|dep_id| completed.get(dep_id) == Some(&TaskStatus::Completed))
96 }
97
98 pub fn duration(&self) -> Option<Duration> {
100 self.started_at.and_then(|start| {
101 self.completed_at
102 .map(|end| Duration::from_secs((end - start).num_seconds() as u64))
103 })
104 }
105}
106
107impl Eq for Task {}
108
109impl PartialEq for Task {
110 fn eq(&self, other: &Self) -> bool {
111 self.id == other.id
112 }
113}
114
115impl Ord for Task {
116 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
117 other
119 .priority
120 .cmp(&self.priority)
121 .then_with(|| other.created_at.cmp(&self.created_at))
122 }
123}
124
125impl PartialOrd for Task {
126 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
127 Some(self.cmp(other))
128 }
129}
130
/// Operations for storing tasks, tracking their status and handing them out
/// for execution.
///
/// Implementors must be `Send + Sync`. `update_status` and `cancel` are
/// `async` (via `async_trait`).
#[async_trait]
pub trait TaskManagerTrait: Send + Sync {
    /// Stores `task` so it can be looked up and dispatched; returns its id.
    fn add(&self, task: Task) -> Layer2Result<TaskId>;

    /// Returns a copy of the task with `id`, or `None` if it is not stored.
    fn get(&self, id: &TaskId) -> Option<Task>;

    /// Sets the status of task `id`.
    /// Returns `Ok(true)` if the task exists and `Ok(false)` otherwise.
    async fn update_status(&self, id: &TaskId, status: TaskStatus) -> Layer2Result<bool>;

    /// Marks task `id` as `Cancelled`; same return convention as `update_status`.
    async fn cancel(&self, id: &TaskId) -> Layer2Result<bool>;

    /// Takes the next task to execute from the dispatch queue, or `None`.
    fn next(&self) -> Option<Task>;

    /// Number of stored tasks, regardless of status.
    fn count(&self) -> usize;

    /// Number of stored tasks whose status equals `status`.
    fn count_by_status(&self, status: TaskStatus) -> usize;

    /// Removes every `Completed` task and returns how many were removed.
    fn cleanup_completed(&self) -> usize;
}
158
/// In-memory implementation of [`TaskManagerTrait`].
///
/// `tasks` holds the live state. `update_status` edits only `tasks`, so
/// entries in `queue` keep whatever state they had when `add` cloned them.
pub struct TaskManager {
    /// All stored tasks, keyed by id.
    tasks: RwLock<HashMap<TaskId, Task>>,
    /// Clones of added tasks, popped in `Ord` order by `next`.
    queue: RwLock<BinaryHeap<Task>>,
}
164
165impl TaskManager {
166 pub fn new() -> Self {
167 Self {
168 tasks: RwLock::new(HashMap::new()),
169 queue: RwLock::new(BinaryHeap::new()),
170 }
171 }
172}
173
174impl Default for TaskManager {
175 fn default() -> Self {
176 Self::new()
177 }
178}
179
180#[async_trait]
181impl TaskManagerTrait for TaskManager {
182 fn add(&self, task: Task) -> Layer2Result<TaskId> {
183 let id = task.id.clone();
184
185 self.queue.write().push(task.clone());
186 self.tasks.write().insert(id.clone(), task);
187
188 Ok(id)
189 }
190
191 fn get(&self, id: &TaskId) -> Option<Task> {
192 self.tasks.read().get(id).cloned()
193 }
194
195 async fn update_status(&self, id: &TaskId, status: TaskStatus) -> Layer2Result<bool> {
196 let mut tasks = self.tasks.write();
197
198 if let Some(task) = tasks.get_mut(id) {
199 task.status = status;
200
201 if status == TaskStatus::Running {
202 task.started_at = Some(chrono::Utc::now());
203 } else if matches!(status, TaskStatus::Completed | TaskStatus::Failed) {
204 task.completed_at = Some(chrono::Utc::now());
205 }
206
207 Ok(true)
208 } else {
209 Ok(false)
210 }
211 }
212
213 async fn cancel(&self, id: &TaskId) -> Layer2Result<bool> {
214 self.update_status(id, TaskStatus::Cancelled).await
215 }
216
217 fn next(&self) -> Option<Task> {
218 let tasks = self.tasks.read();
219 let completed: HashMap<TaskId, TaskStatus> = tasks
220 .iter()
221 .filter(|(_, t)| t.status == TaskStatus::Completed)
222 .map(|(id, t)| (id.clone(), t.status))
223 .collect();
224
225 self.queue
226 .write()
227 .pop()
228 .filter(|t| t.can_execute(&completed))
229 }
230
231 fn count(&self) -> usize {
232 self.tasks.read().len()
233 }
234
235 fn count_by_status(&self, status: TaskStatus) -> usize {
236 self.tasks
237 .read()
238 .values()
239 .filter(|t| t.status == status)
240 .count()
241 }
242
243 fn cleanup_completed(&self) -> usize {
244 let mut tasks = self.tasks.write();
245 let completed: Vec<TaskId> = tasks
246 .iter()
247 .filter(|(_, t)| t.status == TaskStatus::Completed)
248 .map(|(id, _)| id.clone())
249 .collect();
250
251 let count = completed.len();
252 for id in completed {
253 tasks.remove(&id);
254 }
255
256 let mut queue = self.queue.write();
258 *queue = tasks.values().cloned().collect();
259
260 count
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built task is Pending with Normal priority.
    #[test]
    fn test_task_creation() {
        let fresh = Task::new("test", "Test task");
        assert_eq!(fresh.status, TaskStatus::Pending);
        assert_eq!(fresh.priority, TaskPriority::Normal);
    }

    /// `with_priority` overrides the default priority.
    #[test]
    fn test_task_priority() {
        let boosted = Task::new("test", "Test").with_priority(TaskPriority::High);
        assert_eq!(boosted.priority, TaskPriority::High);
    }

    /// Adding one task to an empty manager raises the count from 0 to 1.
    #[test]
    fn test_task_manager() {
        let manager = TaskManager::new();
        assert_eq!(manager.count(), 0);

        let _id = manager.add(Task::new("test", "Test task")).unwrap();
        assert_eq!(manager.count(), 1);
    }
}