Skip to main content

axum_apcore/engine/
tasks.rs

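// Async task management for axum-apcore.
//
// Tracks submitted tasks (status, result, error message), enforces capacity
// and concurrency limits, supports cancellation, and cleans up old entries.
// The work itself runs elsewhere; callers report outcomes via `complete`/`fail`.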
5
6use std::collections::HashMap;
7use std::sync::{Arc, Mutex};
8
9use apcore::async_task::TaskStatus;
10use apcore::cancel::CancelToken;
11use serde_json::Value;
12
13use crate::config::ApcoreSettings;
14use crate::errors::AxumApcoreError;
15
16/// Manages async task submission, tracking, and cancellation.
17#[derive(Clone)]
18pub struct TaskManager {
19    tasks: Arc<Mutex<HashMap<String, TaskEntry>>>,
20    max_concurrent: usize,
21    max_tasks: usize,
22}
23
24/// Serializable task info returned from list/status queries.
25#[derive(Debug, Clone, serde::Serialize)]
26pub struct TaskInfo {
27    pub task_id: String,
28    pub module_id: String,
29    pub status: String,
30    pub result: Option<Value>,
31    pub error: Option<String>,
32    pub created_at: String,
33}
34
35#[derive(Debug, Clone)]
36struct TaskEntry {
37    pub status: TaskStatus,
38    pub module_id: String,
39    pub result: Option<Value>,
40    pub error: Option<String>,
41    pub cancel_token: CancelToken,
42    pub created_at: chrono::DateTime<chrono::Utc>,
43}
44
45impl TaskEntry {
46    fn to_info(&self, task_id: &str) -> TaskInfo {
47        TaskInfo {
48            task_id: task_id.to_string(),
49            module_id: self.module_id.clone(),
50            status: format!("{:?}", self.status),
51            result: self.result.clone(),
52            error: self.error.clone(),
53            created_at: self.created_at.to_rfc3339(),
54        }
55    }
56}
57
58impl TaskManager {
59    /// Create a new TaskManager from settings.
60    pub fn from_settings(settings: &ApcoreSettings) -> Self {
61        Self {
62            tasks: Arc::new(Mutex::new(HashMap::new())),
63            max_concurrent: settings.task_max_concurrent,
64            max_tasks: settings.task_max_tasks,
65        }
66    }
67
68    /// Submit a new async task. Returns the task ID and a CancelToken.
69    pub fn submit(
70        &self,
71        task_id: String,
72        module_id: String,
73    ) -> Result<(String, CancelToken), AxumApcoreError> {
74        let mut tasks = self.tasks.lock().expect("task lock poisoned");
75
76        if tasks.len() >= self.max_tasks {
77            return Err(AxumApcoreError::Execution(apcore::ModuleError::new(
78                apcore::ErrorCode::GeneralInternalError,
79                format!("Maximum task limit reached ({})", self.max_tasks),
80            )));
81        }
82
83        let active_count = tasks
84            .values()
85            .filter(|t| matches!(t.status, TaskStatus::Running))
86            .count();
87
88        if active_count >= self.max_concurrent {
89            return Err(AxumApcoreError::Execution(apcore::ModuleError::new(
90                apcore::ErrorCode::GeneralInternalError,
91                format!(
92                    "Maximum concurrent task limit reached ({})",
93                    self.max_concurrent
94                ),
95            )));
96        }
97
98        let cancel_token = CancelToken::new();
99        tasks.insert(
100            task_id.clone(),
101            TaskEntry {
102                status: TaskStatus::Running,
103                module_id,
104                result: None,
105                error: None,
106                cancel_token: cancel_token.clone(),
107                created_at: chrono::Utc::now(),
108            },
109        );
110
111        Ok((task_id, cancel_token))
112    }
113
114    /// Get the status of a task.
115    pub fn get_status(&self, task_id: &str) -> Option<TaskStatus> {
116        let tasks = self.tasks.lock().expect("task lock poisoned");
117        tasks.get(task_id).map(|t| t.status)
118    }
119
120    /// Get full info for a task.
121    pub fn get_task_info(&self, task_id: &str) -> Option<TaskInfo> {
122        let tasks = self.tasks.lock().expect("task lock poisoned");
123        tasks.get(task_id).map(|t| t.to_info(task_id))
124    }
125
126    /// Get the result of a completed task.
127    pub fn get_result(&self, task_id: &str) -> Option<Value> {
128        let tasks = self.tasks.lock().expect("task lock poisoned");
129        tasks.get(task_id).and_then(|t| {
130            if matches!(t.status, TaskStatus::Completed) {
131                t.result.clone()
132            } else {
133                None
134            }
135        })
136    }
137
138    /// List tasks, optionally filtered by status.
139    pub fn list_tasks(&self, status_filter: Option<&str>) -> Vec<TaskInfo> {
140        let tasks = self.tasks.lock().expect("task lock poisoned");
141        tasks
142            .iter()
143            .filter(|(_, entry)| {
144                status_filter
145                    .map(|s| format!("{:?}", entry.status).to_lowercase() == s.to_lowercase())
146                    .unwrap_or(true)
147            })
148            .map(|(id, entry)| entry.to_info(id))
149            .collect()
150    }
151
152    /// Mark a task as completed with a result.
153    pub fn complete(&self, task_id: &str, result: Value) {
154        let mut tasks = self.tasks.lock().expect("task lock poisoned");
155        if let Some(entry) = tasks.get_mut(task_id) {
156            entry.status = TaskStatus::Completed;
157            entry.result = Some(result);
158        }
159    }
160
161    /// Mark a task as failed with an error message.
162    pub fn fail(&self, task_id: &str, error: String) {
163        let mut tasks = self.tasks.lock().expect("task lock poisoned");
164        if let Some(entry) = tasks.get_mut(task_id) {
165            entry.status = TaskStatus::Failed;
166            entry.error = Some(error);
167        }
168    }
169
170    /// Cancel a running task.
171    pub fn cancel(&self, task_id: &str) -> bool {
172        let mut tasks = self.tasks.lock().expect("task lock poisoned");
173        if let Some(entry) = tasks.get_mut(task_id) {
174            if matches!(entry.status, TaskStatus::Running) {
175                entry.cancel_token.cancel();
176                entry.status = TaskStatus::Cancelled;
177                return true;
178            }
179        }
180        false
181    }
182
183    /// Remove completed/failed/cancelled tasks older than the cleanup age.
184    pub fn cleanup(&self, max_age_secs: u64) -> usize {
185        let mut tasks = self.tasks.lock().expect("task lock poisoned");
186        let before = tasks.len();
187        let cutoff = chrono::Utc::now() - chrono::Duration::seconds(max_age_secs as i64);
188        tasks.retain(|_, entry| {
189            matches!(entry.status, TaskStatus::Running) || entry.created_at > cutoff
190        });
191        before - tasks.len()
192    }
193
194    /// Count tasks by status.
195    pub fn count(&self) -> (usize, usize, usize, usize) {
196        let tasks = self.tasks.lock().expect("task lock poisoned");
197        let mut running = 0;
198        let mut completed = 0;
199        let mut failed = 0;
200        let mut cancelled = 0;
201        for entry in tasks.values() {
202            match entry.status {
203                TaskStatus::Running => running += 1,
204                TaskStatus::Completed => completed += 1,
205                TaskStatus::Failed => failed += 1,
206                TaskStatus::Cancelled => cancelled += 1,
207                _ => {}
208            }
209        }
210        (running, completed, failed, cancelled)
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a manager with default settings; the tests below assume the
    /// defaults allow at least two concurrent tasks.
    fn make_manager() -> TaskManager {
        let settings = ApcoreSettings::default();
        TaskManager::from_settings(&settings)
    }

    #[test]
    fn test_submit_and_get_status() {
        let mgr = make_manager();
        let (id, _token) = mgr.submit("task-1".into(), "mod.a".into()).unwrap();
        assert_eq!(id, "task-1");
        assert!(matches!(
            mgr.get_status("task-1"),
            Some(TaskStatus::Running)
        ));
    }

    #[test]
    fn test_complete_task() {
        let mgr = make_manager();
        mgr.submit("task-1".into(), "mod.a".into()).unwrap();
        mgr.complete("task-1", serde_json::json!({"result": 42}));
        assert!(matches!(
            mgr.get_status("task-1"),
            Some(TaskStatus::Completed)
        ));
    }

    #[test]
    fn test_get_result() {
        let mgr = make_manager();
        mgr.submit("task-1".into(), "mod.a".into()).unwrap();
        assert!(mgr.get_result("task-1").is_none()); // Not completed yet
        mgr.complete("task-1", serde_json::json!({"val": 1}));
        assert_eq!(
            mgr.get_result("task-1").unwrap(),
            serde_json::json!({"val": 1})
        );
    }

    #[test]
    fn test_fail_task() {
        let mgr = make_manager();
        mgr.submit("task-1".into(), "mod.a".into()).unwrap();
        mgr.fail("task-1", "something went wrong".into());
        assert!(matches!(mgr.get_status("task-1"), Some(TaskStatus::Failed)));
    }

    #[test]
    fn test_fail_records_error_message() {
        let mgr = make_manager();
        mgr.submit("t1".into(), "mod.a".into()).unwrap();
        mgr.fail("t1", "boom".into());
        let info = mgr.get_task_info("t1").unwrap();
        assert_eq!(info.status, "Failed");
        assert_eq!(info.error.as_deref(), Some("boom"));
        assert!(info.result.is_none());
        assert!(mgr.get_result("t1").is_none());
    }

    #[test]
    fn test_cancel_task() {
        let mgr = make_manager();
        let (_id, token) = mgr.submit("task-1".into(), "mod.a".into()).unwrap();
        assert!(!token.is_cancelled());
        assert!(mgr.cancel("task-1"));
        assert!(token.is_cancelled());
        assert!(matches!(
            mgr.get_status("task-1"),
            Some(TaskStatus::Cancelled)
        ));
    }

    #[test]
    fn test_cancel_completed_task_fails() {
        let mgr = make_manager();
        mgr.submit("task-1".into(), "mod.a".into()).unwrap();
        mgr.complete("task-1", serde_json::json!(null));
        assert!(!mgr.cancel("task-1"));
    }

    #[test]
    fn test_cancel_nonexistent() {
        let mgr = make_manager();
        assert!(!mgr.cancel("nonexistent"));
    }

    #[test]
    fn test_get_status_nonexistent() {
        let mgr = make_manager();
        assert!(mgr.get_status("nonexistent").is_none());
    }

    #[test]
    fn test_list_tasks() {
        let mgr = make_manager();
        mgr.submit("t1".into(), "mod.a".into()).unwrap();
        mgr.submit("t2".into(), "mod.b".into()).unwrap();
        mgr.complete("t1", serde_json::json!(null));

        let all = mgr.list_tasks(None);
        assert_eq!(all.len(), 2);

        let running = mgr.list_tasks(Some("running"));
        assert_eq!(running.len(), 1);
        assert_eq!(running[0].task_id, "t2");
    }

    #[test]
    fn test_list_tasks_filter_is_case_insensitive() {
        let mgr = make_manager();
        mgr.submit("t1".into(), "mod.a".into()).unwrap();
        mgr.complete("t1", serde_json::json!(null));
        assert_eq!(mgr.list_tasks(Some("COMPLETED")).len(), 1);
        assert!(mgr.list_tasks(Some("running")).is_empty());
    }

    #[test]
    fn test_count() {
        let mgr = make_manager();
        mgr.submit("t1".into(), "mod.a".into()).unwrap();
        mgr.submit("t2".into(), "mod.b".into()).unwrap();
        mgr.complete("t1", serde_json::json!(null));
        let (running, completed, failed, cancelled) = mgr.count();
        assert_eq!(running, 1);
        assert_eq!(completed, 1);
        assert_eq!(failed, 0);
        assert_eq!(cancelled, 0);
    }

    #[test]
    fn test_cleanup() {
        let mgr = make_manager();
        mgr.submit("t1".into(), "mod.a".into()).unwrap();
        mgr.complete("t1", serde_json::json!(null));
        // Cleanup with 0 age = remove everything not running
        let removed = mgr.cleanup(0);
        assert_eq!(removed, 1);
    }

    #[test]
    fn test_cleanup_keeps_running_tasks() {
        let mgr = make_manager();
        mgr.submit("t1".into(), "mod.a".into()).unwrap();
        assert_eq!(mgr.cleanup(0), 0);
        assert!(matches!(mgr.get_status("t1"), Some(TaskStatus::Running)));
    }

    #[test]
    fn test_get_task_info() {
        let mgr = make_manager();
        mgr.submit("t1".into(), "mod.a".into()).unwrap();
        let info = mgr.get_task_info("t1").unwrap();
        assert_eq!(info.task_id, "t1");
        assert_eq!(info.module_id, "mod.a");
        assert_eq!(info.status, "Running");
    }

    #[test]
    fn test_get_task_info_nonexistent() {
        let mgr = make_manager();
        assert!(mgr.get_task_info("nonexistent").is_none());
    }
}