// argentor_builtins/task_status.rs

1use crate::agent_delegate::TaskQueueHandle;
2use argentor_core::{ArgentorResult, ToolCall, ToolResult};
3use argentor_skills::skill::{Skill, SkillDescriptor};
4use async_trait::async_trait;
5use std::sync::Arc;
6
7/// Skill for querying task status from the orchestrator's task queue.
8pub struct TaskStatusSkill {
9    descriptor: SkillDescriptor,
10    queue: Arc<dyn TaskQueueHandle>,
11}
12
13impl TaskStatusSkill {
14    /// Create a new task status skill backed by the given queue handle.
15    pub fn new(queue: Arc<dyn TaskQueueHandle>) -> Self {
16        Self {
17            descriptor: SkillDescriptor {
18                name: "task_status".to_string(),
19                description: "Query the status of orchestration tasks. Use action 'query' with \
20                    a task_id to check a specific task, 'list' to see all tasks, or 'summary' \
21                    for aggregate counts."
22                    .to_string(),
23                parameters_schema: serde_json::json!({
24                    "type": "object",
25                    "properties": {
26                        "action": {
27                            "type": "string",
28                            "enum": ["query", "list", "summary"],
29                            "description": "The query to perform"
30                        },
31                        "task_id": {
32                            "type": "string",
33                            "description": "Task ID (required for 'query' action)"
34                        }
35                    },
36                    "required": ["action"]
37                }),
38                required_capabilities: vec![],
39                requires_approval: false,
40            },
41            queue,
42        }
43    }
44}
45
46#[async_trait]
47impl Skill for TaskStatusSkill {
48    fn descriptor(&self) -> &SkillDescriptor {
49        &self.descriptor
50    }
51
52    async fn execute(&self, call: ToolCall) -> ArgentorResult<ToolResult> {
53        let action = call.arguments["action"].as_str().unwrap_or("").to_string();
54
55        match action.as_str() {
56            "query" => {
57                let task_id = call.arguments["task_id"].as_str().unwrap_or("").to_string();
58                if task_id.is_empty() {
59                    return Ok(ToolResult::error(
60                        &call.id,
61                        "task_id is required for query action",
62                    ));
63                }
64
65                match self.queue.get_task_info(&task_id).await? {
66                    Some(info) => Ok(ToolResult::success(
67                        &call.id,
68                        serde_json::to_string(&info).unwrap_or_else(|_| "{}".to_string()),
69                    )),
70                    None => Ok(ToolResult::success(
71                        &call.id,
72                        serde_json::json!({
73                            "found": false,
74                            "task_id": task_id
75                        })
76                        .to_string(),
77                    )),
78                }
79            }
80            "list" => {
81                let tasks = self.queue.list_tasks().await?;
82                Ok(ToolResult::success(
83                    &call.id,
84                    serde_json::json!({
85                        "count": tasks.len(),
86                        "tasks": tasks
87                    })
88                    .to_string(),
89                ))
90            }
91            "summary" => {
92                let summary = self.queue.task_summary().await?;
93                Ok(ToolResult::success(
94                    &call.id,
95                    serde_json::to_string(&summary).unwrap_or_else(|_| "{}".to_string()),
96                ))
97            }
98            _ => Ok(ToolResult::error(
99                &call.id,
100                "Invalid action. Use 'query', 'list', or 'summary'",
101            )),
102        }
103    }
104}
105
#[cfg(test)]
// Test code: a failed unwrap/expect is simply a failed test, so the lints are relaxed here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::agent_delegate::{TaskInfo, TaskSummary};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::sync::RwLock;

    /// In-memory [`TaskQueueHandle`] used to drive the skill without a real orchestrator.
    struct MockQueue {
        tasks: RwLock<Vec<TaskInfo>>,
        /// Source of ids (`t-100`, `t-101`, ...) for tasks created via `add_task`.
        counter: AtomicUsize,
    }

    impl MockQueue {
        fn with_tasks(tasks: Vec<TaskInfo>) -> Self {
            Self {
                tasks: RwLock::new(tasks),
                counter: AtomicUsize::new(100),
            }
        }
    }

    #[async_trait]
    impl TaskQueueHandle for MockQueue {
        async fn add_task(
            &self,
            description: String,
            role: String,
            _deps: Vec<String>,
        ) -> ArgentorResult<String> {
            let id = format!("t-{}", self.counter.fetch_add(1, Ordering::SeqCst));
            self.tasks.write().await.push(TaskInfo {
                id: id.clone(),
                description,
                role,
                status: "pending".to_string(),
            });
            Ok(id)
        }

        async fn get_task_info(&self, task_id: &str) -> ArgentorResult<Option<TaskInfo>> {
            Ok(self
                .tasks
                .read()
                .await
                .iter()
                .find(|t| t.id == task_id)
                .cloned())
        }

        async fn list_tasks(&self) -> ArgentorResult<Vec<TaskInfo>> {
            Ok(self.tasks.read().await.clone())
        }

        async fn task_summary(&self) -> ArgentorResult<TaskSummary> {
            let tasks = self.tasks.read().await;
            Ok(TaskSummary {
                total: tasks.len(),
                pending: tasks.iter().filter(|t| t.status == "pending").count(),
                running: tasks.iter().filter(|t| t.status == "running").count(),
                completed: tasks.iter().filter(|t| t.status == "completed").count(),
                failed: tasks.iter().filter(|t| t.status == "failed").count(),
                needs_review: tasks
                    .iter()
                    .filter(|t| t.status == "needs_human_review")
                    .count(),
            })
        }
    }

    fn sample_tasks() -> Vec<TaskInfo> {
        vec![
            TaskInfo {
                id: "task-1".into(),
                description: "Spec".into(),
                role: "spec".into(),
                status: "completed".into(),
            },
            TaskInfo {
                id: "task-2".into(),
                description: "Code".into(),
                role: "coder".into(),
                status: "running".into(),
            },
            TaskInfo {
                id: "task-3".into(),
                description: "Test".into(),
                role: "tester".into(),
                status: "pending".into(),
            },
        ]
    }

    /// Build a skill backed by a mock queue pre-populated with `tasks`.
    fn skill_with(tasks: Vec<TaskInfo>) -> TaskStatusSkill {
        TaskStatusSkill::new(Arc::new(MockQueue::with_tasks(tasks)))
    }

    /// Build a `task_status` tool call with the given id and JSON arguments.
    fn tool_call(id: &str, arguments: serde_json::Value) -> ToolCall {
        ToolCall {
            id: id.into(),
            name: "task_status".into(),
            arguments,
        }
    }

    /// Execute `call`, assert the result is not an error, and parse its JSON payload.
    async fn run_ok(skill: &TaskStatusSkill, call: ToolCall) -> serde_json::Value {
        let result = skill.execute(call).await.unwrap();
        assert!(!result.is_error, "unexpected error result: {}", result.content);
        serde_json::from_str(&result.content).unwrap()
    }

    #[tokio::test]
    async fn test_query_existing_task() {
        let skill = skill_with(sample_tasks());
        let call = tool_call(
            "t1",
            serde_json::json!({ "action": "query", "task_id": "task-2" }),
        );
        let parsed = run_ok(&skill, call).await;
        assert_eq!(parsed["status"], "running");
        assert_eq!(parsed["role"], "coder");
    }

    #[tokio::test]
    async fn test_query_nonexistent_task() {
        let skill = skill_with(sample_tasks());
        let call = tool_call(
            "t2",
            serde_json::json!({ "action": "query", "task_id": "nope" }),
        );
        let parsed = run_ok(&skill, call).await;
        assert_eq!(parsed["found"], false);
        assert_eq!(parsed["task_id"], "nope");
    }

    #[tokio::test]
    async fn test_list_tasks() {
        let skill = skill_with(sample_tasks());
        let call = tool_call("t3", serde_json::json!({ "action": "list" }));
        let parsed = run_ok(&skill, call).await;
        assert_eq!(parsed["count"], 3);
        assert_eq!(parsed["tasks"].as_array().unwrap().len(), 3);
    }

    #[tokio::test]
    async fn test_list_empty_queue() {
        let skill = skill_with(vec![]);
        let call = tool_call("t3e", serde_json::json!({ "action": "list" }));
        let parsed = run_ok(&skill, call).await;
        assert_eq!(parsed["count"], 0);
        assert!(parsed["tasks"].as_array().unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_summary() {
        let skill = skill_with(sample_tasks());
        let call = tool_call("t4", serde_json::json!({ "action": "summary" }));
        let parsed = run_ok(&skill, call).await;
        assert_eq!(parsed["total"], 3);
        assert_eq!(parsed["pending"], 1);
        assert_eq!(parsed["running"], 1);
        assert_eq!(parsed["completed"], 1);
    }

    #[tokio::test]
    async fn test_query_missing_task_id_error() {
        let skill = skill_with(vec![]);
        let call = tool_call("t5", serde_json::json!({ "action": "query" }));
        let result = skill.execute(call).await.unwrap();
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn test_query_empty_task_id_error() {
        let skill = skill_with(sample_tasks());
        let call = tool_call("t6", serde_json::json!({ "action": "query", "task_id": "" }));
        let result = skill.execute(call).await.unwrap();
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn test_invalid_action_error() {
        let skill = skill_with(sample_tasks());
        let call = tool_call("t7", serde_json::json!({ "action": "delete" }));
        let result = skill.execute(call).await.unwrap();
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn test_missing_action_error() {
        let skill = skill_with(sample_tasks());
        let call = tool_call("t8", serde_json::json!({}));
        let result = skill.execute(call).await.unwrap();
        assert!(result.is_error);
    }
}