1use crate::agent_delegate::TaskQueueHandle;
2use argentor_core::{ArgentorResult, ToolCall, ToolResult};
3use argentor_skills::skill::{Skill, SkillDescriptor};
4use async_trait::async_trait;
5use std::sync::Arc;
6
/// Skill that answers status queries about orchestration tasks.
///
/// Exposed to agents as the `task_status` tool. It only reads from the queue
/// (single-task lookup, full listing, aggregate summary) and never mutates it.
pub struct TaskStatusSkill {
    /// Static descriptor (name, description, JSON parameter schema) returned
    /// by `Skill::descriptor`.
    descriptor: SkillDescriptor,
    /// Shared handle to the task queue that all status queries are answered from.
    queue: Arc<dyn TaskQueueHandle>,
}
12
13impl TaskStatusSkill {
14 pub fn new(queue: Arc<dyn TaskQueueHandle>) -> Self {
16 Self {
17 descriptor: SkillDescriptor {
18 name: "task_status".to_string(),
19 description: "Query the status of orchestration tasks. Use action 'query' with \
20 a task_id to check a specific task, 'list' to see all tasks, or 'summary' \
21 for aggregate counts."
22 .to_string(),
23 parameters_schema: serde_json::json!({
24 "type": "object",
25 "properties": {
26 "action": {
27 "type": "string",
28 "enum": ["query", "list", "summary"],
29 "description": "The query to perform"
30 },
31 "task_id": {
32 "type": "string",
33 "description": "Task ID (required for 'query' action)"
34 }
35 },
36 "required": ["action"]
37 }),
38 required_capabilities: vec![],
39 requires_approval: false,
40 },
41 queue,
42 }
43 }
44}
45
46#[async_trait]
47impl Skill for TaskStatusSkill {
48 fn descriptor(&self) -> &SkillDescriptor {
49 &self.descriptor
50 }
51
52 async fn execute(&self, call: ToolCall) -> ArgentorResult<ToolResult> {
53 let action = call.arguments["action"].as_str().unwrap_or("").to_string();
54
55 match action.as_str() {
56 "query" => {
57 let task_id = call.arguments["task_id"].as_str().unwrap_or("").to_string();
58 if task_id.is_empty() {
59 return Ok(ToolResult::error(
60 &call.id,
61 "task_id is required for query action",
62 ));
63 }
64
65 match self.queue.get_task_info(&task_id).await? {
66 Some(info) => Ok(ToolResult::success(
67 &call.id,
68 serde_json::to_string(&info).unwrap_or_else(|_| "{}".to_string()),
69 )),
70 None => Ok(ToolResult::success(
71 &call.id,
72 serde_json::json!({
73 "found": false,
74 "task_id": task_id
75 })
76 .to_string(),
77 )),
78 }
79 }
80 "list" => {
81 let tasks = self.queue.list_tasks().await?;
82 Ok(ToolResult::success(
83 &call.id,
84 serde_json::json!({
85 "count": tasks.len(),
86 "tasks": tasks
87 })
88 .to_string(),
89 ))
90 }
91 "summary" => {
92 let summary = self.queue.task_summary().await?;
93 Ok(ToolResult::success(
94 &call.id,
95 serde_json::to_string(&summary).unwrap_or_else(|_| "{}".to_string()),
96 ))
97 }
98 _ => Ok(ToolResult::error(
99 &call.id,
100 "Invalid action. Use 'query', 'list', or 'summary'",
101 )),
102 }
103 }
104}
105
#[cfg(test)]
// Panicking via unwrap/expect is the intended failure signal inside tests.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::agent_delegate::{TaskInfo, TaskSummary};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::sync::RwLock;

    /// In-memory `TaskQueueHandle` used to drive the skill in tests.
    struct MockQueue {
        tasks: RwLock<Vec<TaskInfo>>,
        /// Source of ids for tasks created through `add_task`.
        counter: AtomicUsize,
    }

    impl MockQueue {
        fn with_tasks(tasks: Vec<TaskInfo>) -> Self {
            Self {
                tasks: RwLock::new(tasks),
                counter: AtomicUsize::new(100),
            }
        }
    }

    #[async_trait]
    impl TaskQueueHandle for MockQueue {
        async fn add_task(
            &self,
            description: String,
            role: String,
            _deps: Vec<String>,
        ) -> ArgentorResult<String> {
            let id = format!("t-{}", self.counter.fetch_add(1, Ordering::SeqCst));
            self.tasks.write().await.push(TaskInfo {
                id: id.clone(),
                description,
                role,
                status: "pending".to_string(),
            });
            Ok(id)
        }

        async fn get_task_info(&self, task_id: &str) -> ArgentorResult<Option<TaskInfo>> {
            Ok(self
                .tasks
                .read()
                .await
                .iter()
                .find(|t| t.id == task_id)
                .cloned())
        }

        async fn list_tasks(&self) -> ArgentorResult<Vec<TaskInfo>> {
            Ok(self.tasks.read().await.clone())
        }

        async fn task_summary(&self) -> ArgentorResult<TaskSummary> {
            let tasks = self.tasks.read().await;
            Ok(TaskSummary {
                total: tasks.len(),
                pending: tasks.iter().filter(|t| t.status == "pending").count(),
                running: tasks.iter().filter(|t| t.status == "running").count(),
                completed: tasks.iter().filter(|t| t.status == "completed").count(),
                failed: tasks.iter().filter(|t| t.status == "failed").count(),
                needs_review: tasks
                    .iter()
                    .filter(|t| t.status == "needs_human_review")
                    .count(),
            })
        }
    }

    fn sample_tasks() -> Vec<TaskInfo> {
        vec![
            TaskInfo {
                id: "task-1".into(),
                description: "Spec".into(),
                role: "spec".into(),
                status: "completed".into(),
            },
            TaskInfo {
                id: "task-2".into(),
                description: "Code".into(),
                role: "coder".into(),
                status: "running".into(),
            },
            TaskInfo {
                id: "task-3".into(),
                description: "Test".into(),
                role: "tester".into(),
                status: "pending".into(),
            },
        ]
    }

    /// Builds a skill over `tasks`, executes one `task_status` call with the
    /// given `arguments`, and unwraps the queue-level result.
    async fn run(tasks: Vec<TaskInfo>, call_id: &str, arguments: serde_json::Value) -> ToolResult {
        let skill = TaskStatusSkill::new(Arc::new(MockQueue::with_tasks(tasks)));
        let call = ToolCall {
            id: call_id.into(),
            name: "task_status".into(),
            arguments,
        };
        skill.execute(call).await.unwrap()
    }

    /// Parses a tool result's content as JSON.
    fn parse(result: &ToolResult) -> serde_json::Value {
        serde_json::from_str(&result.content).unwrap()
    }

    #[tokio::test]
    async fn test_query_existing_task() {
        let result = run(
            sample_tasks(),
            "t1",
            serde_json::json!({ "action": "query", "task_id": "task-2" }),
        )
        .await;
        assert!(!result.is_error);
        let parsed = parse(&result);
        assert_eq!(parsed["status"], "running");
        assert_eq!(parsed["role"], "coder");
    }

    #[tokio::test]
    async fn test_query_nonexistent_task() {
        let result = run(
            sample_tasks(),
            "t2",
            serde_json::json!({ "action": "query", "task_id": "nope" }),
        )
        .await;
        assert!(!result.is_error);
        let parsed = parse(&result);
        assert_eq!(parsed["found"], false);
        assert_eq!(parsed["task_id"], "nope");
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let result = run(sample_tasks(), "t3", serde_json::json!({ "action": "list" })).await;
        assert!(!result.is_error);
        let parsed = parse(&result);
        assert_eq!(parsed["count"], 3);
        assert_eq!(parsed["tasks"].as_array().unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_list_empty_queue() {
        let result = run(vec![], "t3b", serde_json::json!({ "action": "list" })).await;
        assert!(!result.is_error);
        let parsed = parse(&result);
        assert_eq!(parsed["count"], 0);
        assert!(parsed["tasks"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_summary() {
        let result = run(sample_tasks(), "t4", serde_json::json!({ "action": "summary" })).await;
        assert!(!result.is_error);
        let parsed = parse(&result);
        assert_eq!(parsed["total"], 3);
        assert_eq!(parsed["pending"], 1);
        assert_eq!(parsed["running"], 1);
        assert_eq!(parsed["completed"], 1);
        assert_eq!(parsed["failed"], 0);
    }

    #[tokio::test]
    async fn test_query_missing_task_id_error() {
        let result = run(vec![], "t5", serde_json::json!({ "action": "query" })).await;
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn test_query_empty_task_id_error() {
        let result = run(
            sample_tasks(),
            "t6",
            serde_json::json!({ "action": "query", "task_id": "" }),
        )
        .await;
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn test_invalid_action_error() {
        let result = run(sample_tasks(), "t7", serde_json::json!({ "action": "delete" })).await;
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn test_missing_action_error() {
        let result = run(sample_tasks(), "t8", serde_json::json!({})).await;
        assert!(result.is_error);
    }
}