Skip to main content

astrid_tools/
task.rs

1//! Task tool — spawns a sub-agent to handle a scoped task autonomously.
2
3use crate::subagent_spawner::SubAgentRequest;
4use crate::{BuiltinTool, ToolContext, ToolError, ToolResult};
5use serde_json::Value;
6use std::time::Duration;
7
/// Upper bound on a sub-agent timeout, in seconds (50 minutes).
///
/// Any `timeout_secs` argument larger than this is clamped down to it.
const MAX_TIMEOUT_SECS: u64 = 50 * 60;

/// Built-in tool that hands a scoped task off to a sub-agent.
pub struct TaskTool;
13
14#[async_trait::async_trait]
15impl BuiltinTool for TaskTool {
16    fn name(&self) -> &'static str {
17        "task"
18    }
19
20    fn description(&self) -> &'static str {
21        "Spawns a sub-agent to handle a complex, multi-step task autonomously. \
22         The sub-agent works within inherited capability bounds and returns a result."
23    }
24
25    fn input_schema(&self) -> Value {
26        serde_json::json!({
27            "type": "object",
28            "properties": {
29                "description": {
30                    "type": "string",
31                    "description": "A short description of the task (3-5 words)"
32                },
33                "prompt": {
34                    "type": "string",
35                    "description": "Detailed instructions for the sub-agent"
36                },
37                "timeout_secs": {
38                    "type": "integer",
39                    "description": "Optional timeout in seconds (default: 300)"
40                }
41            },
42            "required": ["description", "prompt"]
43        })
44    }
45
46    async fn execute(&self, args: Value, ctx: &ToolContext) -> ToolResult {
47        let description = args
48            .get("description")
49            .and_then(|v| v.as_str())
50            .ok_or_else(|| ToolError::InvalidArguments("missing 'description'".into()))?
51            .to_string();
52
53        let prompt = args
54            .get("prompt")
55            .and_then(|v| v.as_str())
56            .ok_or_else(|| ToolError::InvalidArguments("missing 'prompt'".into()))?
57            .to_string();
58
59        let timeout = args
60            .get("timeout_secs")
61            .and_then(serde_json::Value::as_u64)
62            .map(|s| Duration::from_secs(s.min(MAX_TIMEOUT_SECS)));
63
64        let spawner = ctx.subagent_spawner().await.ok_or_else(|| {
65            ToolError::ExecutionFailed("Sub-agent spawning is not available in this context".into())
66        })?;
67
68        let request = SubAgentRequest {
69            description,
70            prompt,
71            timeout,
72        };
73
74        match spawner.spawn(request).await {
75            Ok(result) => {
76                if result.success {
77                    Ok(result.output)
78                } else {
79                    let error_msg = result
80                        .error
81                        .unwrap_or_else(|| "sub-agent failed".to_string());
82                    Err(ToolError::ExecutionFailed(format!(
83                        "Sub-agent failed: {error_msg}"
84                    )))
85                }
86            },
87            Err(e) => Err(ToolError::ExecutionFailed(format!(
88                "Failed to spawn sub-agent: {e}"
89            ))),
90        }
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::subagent_spawner::{SubAgentResult, SubAgentSpawner};
    use std::sync::{Arc, Mutex};

    #[tokio::test]
    async fn test_task_without_spawner_returns_error() {
        let ctx = ToolContext::new(std::env::temp_dir(), None);
        let result = TaskTool
            .execute(
                serde_json::json!({
                    "description": "test",
                    "prompt": "do something"
                }),
                &ctx,
            )
            .await;

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("not available in this context")
        );
    }

    #[tokio::test]
    async fn test_task_missing_description() {
        let ctx = ToolContext::new(std::env::temp_dir(), None);
        let result = TaskTool
            .execute(serde_json::json!({"prompt": "do something"}), &ctx)
            .await;

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("missing 'description'")
        );
    }

    #[tokio::test]
    async fn test_task_missing_prompt() {
        let ctx = ToolContext::new(std::env::temp_dir(), None);
        let result = TaskTool
            .execute(serde_json::json!({"description": "test"}), &ctx)
            .await;

        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("missing 'prompt'"));
    }

    /// Spawner that ignores the request and returns a canned result.
    struct MockSpawner {
        response: SubAgentResult,
    }

    #[async_trait::async_trait]
    impl SubAgentSpawner for MockSpawner {
        async fn spawn(&self, _request: SubAgentRequest) -> Result<SubAgentResult, String> {
            Ok(self.response.clone())
        }
    }

    #[tokio::test]
    async fn test_task_with_mock_spawner_success() {
        let ctx = ToolContext::new(std::env::temp_dir(), None);
        let spawner = Arc::new(MockSpawner {
            response: SubAgentResult {
                success: true,
                output: "Task completed successfully".into(),
                duration_ms: 1000,
                tool_calls: 3,
                error: None,
            },
        });
        ctx.set_subagent_spawner(Some(spawner)).await;

        let result = TaskTool
            .execute(
                serde_json::json!({
                    "description": "test task",
                    "prompt": "do the thing"
                }),
                &ctx,
            )
            .await;

        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "Task completed successfully");
    }

    #[tokio::test]
    async fn test_task_with_mock_spawner_failure() {
        let ctx = ToolContext::new(std::env::temp_dir(), None);
        let spawner = Arc::new(MockSpawner {
            response: SubAgentResult {
                success: false,
                output: String::new(),
                duration_ms: 500,
                tool_calls: 1,
                error: Some("ran out of budget".into()),
            },
        });
        ctx.set_subagent_spawner(Some(spawner)).await;

        let result = TaskTool
            .execute(
                serde_json::json!({
                    "description": "failing task",
                    "prompt": "do something expensive"
                }),
                &ctx,
            )
            .await;

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("ran out of budget")
        );
    }

    /// Spawner that records the `timeout` of the request it receives, then succeeds.
    ///
    /// It exists so the tests can check what the tool actually forwards to the
    /// spawner, rather than only that the call succeeded.
    struct CapturingSpawner {
        seen_timeout: Arc<Mutex<Option<Option<Duration>>>>,
    }

    #[async_trait::async_trait]
    impl SubAgentSpawner for CapturingSpawner {
        async fn spawn(&self, request: SubAgentRequest) -> Result<SubAgentResult, String> {
            *self.seen_timeout.lock().unwrap() = Some(request.timeout);
            Ok(SubAgentResult {
                success: true,
                output: "done".into(),
                duration_ms: 100,
                tool_calls: 0,
                error: None,
            })
        }
    }

    /// Runs `TaskTool` with `args` against a `CapturingSpawner`.
    ///
    /// Returns the timeout the spawner received. An outer `None` means the spawner
    /// was never called.
    async fn forwarded_timeout(args: Value) -> Option<Option<Duration>> {
        let ctx = ToolContext::new(std::env::temp_dir(), None);
        let seen = Arc::new(Mutex::new(None));
        let spawner = Arc::new(CapturingSpawner {
            seen_timeout: Arc::clone(&seen),
        });
        ctx.set_subagent_spawner(Some(spawner)).await;

        let result = TaskTool.execute(args, &ctx).await;
        assert!(result.is_ok());

        // Copy the value out first so the lock guard is dropped before `seen` is.
        let captured = *seen.lock().unwrap();
        captured
    }

    #[tokio::test]
    async fn test_task_timeout_clamped_to_max() {
        // Pass an absurdly large timeout; the spawner must see MAX_TIMEOUT_SECS.
        let seen = forwarded_timeout(serde_json::json!({
            "description": "test",
            "prompt": "do something",
            "timeout_secs": 999_999_999_u64
        }))
        .await;

        assert_eq!(seen, Some(Some(Duration::from_secs(MAX_TIMEOUT_SECS))));
    }

    #[tokio::test]
    async fn test_task_timeout_within_range_is_forwarded() {
        let seen = forwarded_timeout(serde_json::json!({
            "description": "test",
            "prompt": "do something",
            "timeout_secs": 120
        }))
        .await;

        assert_eq!(seen, Some(Some(Duration::from_secs(120))));
    }

    #[tokio::test]
    async fn test_task_without_timeout_defers_to_spawner() {
        let seen = forwarded_timeout(serde_json::json!({
            "description": "test",
            "prompt": "do something"
        }))
        .await;

        assert_eq!(seen, Some(None));
    }

    #[tokio::test]
    async fn test_task_non_integer_timeout_is_ignored() {
        let seen = forwarded_timeout(serde_json::json!({
            "description": "test",
            "prompt": "do something",
            "timeout_secs": "600"
        }))
        .await;

        assert_eq!(seen, Some(None));
    }

    /// Spawner whose `spawn` always fails, e.g. because a concurrency limit is hit.
    struct ErrorSpawner;

    #[async_trait::async_trait]
    impl SubAgentSpawner for ErrorSpawner {
        async fn spawn(&self, _request: SubAgentRequest) -> Result<SubAgentResult, String> {
            Err("maximum concurrent subagents reached".into())
        }
    }

    #[tokio::test]
    async fn test_task_spawn_error() {
        let ctx = ToolContext::new(std::env::temp_dir(), None);
        ctx.set_subagent_spawner(Some(Arc::new(ErrorSpawner))).await;

        let result = TaskTool
            .execute(
                serde_json::json!({
                    "description": "test",
                    "prompt": "do something"
                }),
                &ctx,
            )
            .await;

        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("maximum concurrent subagents reached")
        );
    }
}