Skip to main content

pawan/tools/
task.rs

1//! Task tool: spawn an in-process subagent with restricted tools.
2//!
3//! This tool runs a child `PawanAgent` with a narrowed `ToolRegistry`, a smaller
4//! context window, and a hard timeout. Subagents are depth-limited (they cannot
5//! spawn other agents).
6
7use super::Tool;
8use crate::agent::backend::LlmBackend;
9use crate::agent::PawanAgent;
10use crate::config::PawanConfig;
11use crate::tools::{bash, batch, edit, file, git, lsp_tool, mise, native, ToolRegistry, ToolTier};
12use crate::{PawanError, Result};
13use async_trait::async_trait;
14use futures::stream::{self, StreamExt};
15use serde::Deserialize;
16use serde_json::{json, Value};
17use std::path::{Path, PathBuf};
18use std::sync::Arc;
19use std::time::{Duration, Instant};
20use tokio::sync::Semaphore;
21use tokio::time::timeout;
22
23use crate::subagent::SubagentHandle;
24
/// Per-subagent timeout, in seconds, used when the caller omits `timeout`.
const DEFAULT_TIMEOUT_SECS: u64 = 5 * 60;

/// Maximum number of subagents that run concurrently in parallel (`tasks`) mode.
const MAX_PARALLEL_SUBAGENTS: usize = 4;
28
29#[derive(Debug, Clone, Deserialize)]
30struct TaskItem {
31    agent: String,
32    assignment: String,
33    #[serde(default)]
34    context: Option<String>,
35    #[serde(default)]
36    description: Option<String>,
37}
38
39#[derive(Debug, Clone, Deserialize)]
40struct TaskArgs {
41    /// Single-task mode: agent type.
42    #[serde(default)]
43    agent: Option<String>,
44    /// Single-task mode: assignment text.
45    #[serde(default)]
46    assignment: Option<String>,
47    /// Parallel mode: one or more subagent jobs (max 8).
48    #[serde(default)]
49    tasks: Option<Vec<TaskItem>>,
50    #[serde(default)]
51    context: Option<String>,
52    #[serde(default)]
53    description: Option<String>,
54    #[serde(default)]
55    model: Option<String>,
56    /// Timeout in seconds (default: 300).
57    #[serde(default)]
58    timeout: Option<u64>,
59}
60
/// Tool that spawns depth-limited, in-process subagents inside a workspace.
#[derive(Clone, Debug)]
pub struct TaskTool {
    /// Workspace directory every subagent and its tools operate on.
    workspace_root: PathBuf,
}
65
66impl TaskTool {
67    pub fn new(workspace_root: PathBuf) -> Self {
68        Self { workspace_root }
69    }
70
71    fn known_agent_types() -> &'static [&'static str] {
72        &[
73            "explore",
74            "plan",
75            "task",
76            "reviewer",
77            "designer",
78            "librarian",
79        ]
80    }
81
82    fn validate_agent_type(agent: &str) -> std::result::Result<(), String> {
83        if Self::known_agent_types().contains(&agent) {
84            Ok(())
85        } else {
86            Err(format!(
87                "unknown agent type '{agent}'. Valid types: {}",
88                Self::known_agent_types().join(", ")
89            ))
90        }
91    }
92
93    fn validate_assignment(assignment: &str) -> std::result::Result<(), String> {
94        if assignment.trim().is_empty() {
95            Err("assignment must be non-empty".to_string())
96        } else {
97            Ok(())
98        }
99    }
100
101    fn system_prompt_for(agent: &str) -> String {
102        match agent {
103            "explore" => "You are a read-only exploration subagent. Use only the allowed read/search tools to gather facts. Do not propose or apply code edits. Return concise findings with file paths and evidence.".to_string(),
104            "plan" => "You are an architecture subagent. Do not modify code. Make design decisions and propose an implementation plan with tradeoffs, invariants, and acceptance criteria.".to_string(),
105            "reviewer" => "You are a code review subagent. Do not modify code. Identify bugs, security issues, and quality concerns. Return a structured review report with severity and recommendations.".to_string(),
106            "designer" => "You are a UI/UX design subagent. If editing tools are available, you may implement UI changes carefully. Prioritize accessibility and consistency.".to_string(),
107            "librarian" => "You are a research subagent. Verify details from authoritative sources and the local codebase. Do not modify code. Return actionable guidance.".to_string(),
108            _ => "You are a subagent executing a delegated task. Follow the assignment precisely and return the final result. Do not spawn other agents.".to_string(),
109        }
110    }
111
112    fn build_user_prompt(context: Option<&str>, assignment: &str) -> String {
113        match context {
114            Some(ctx) if !ctx.trim().is_empty() => format!(
115                "{ctx}\n\n[Assignment]\n{assignment}\n\n[Constraints]\n- Subagent depth limit: you cannot spawn other agents.\n"
116            ),
117            _ => format!(
118                "[Assignment]\n{assignment}\n\n[Constraints]\n- Subagent depth limit: you cannot spawn other agents.\n"
119            ),
120        }
121    }
122
123    fn registry_for(agent: &str, workspace_root: &Path) -> ToolRegistry {
124        use ToolTier::*;
125        let workspace_root = workspace_root.to_path_buf();
126        let mut reg = ToolRegistry::new();
127
128        // Read/search tools
129        reg.register_with_tier(
130            Arc::new(file::ReadFileTool::new(workspace_root.clone())),
131            Core,
132        );
133        reg.register_with_tier(
134            Arc::new(file::ListDirectoryTool::new(workspace_root.clone())),
135            Standard,
136        );
137        reg.register_with_tier(
138            Arc::new(native::GlobSearchTool::new(workspace_root.clone())),
139            Core,
140        );
141        reg.register_with_tier(
142            Arc::new(native::GrepSearchTool::new(workspace_root.clone())),
143            Core,
144        );
145        reg.register_with_tier(
146            Arc::new(native::AstGrepTool::new(workspace_root.clone())),
147            Core,
148        );
149        reg.register_with_tier(
150            Arc::new(native::RipgrepTool::new(workspace_root.clone())),
151            Extended,
152        );
153        reg.register_with_tier(
154            Arc::new(native::FdTool::new(workspace_root.clone())),
155            Extended,
156        );
157
158        match agent {
159            "explore" | "plan" | "reviewer" | "librarian" => {
160                reg.register_with_tier(
161                    Arc::new(git::GitStatusTool::new(workspace_root.clone())),
162                    Standard,
163                );
164                reg.register_with_tier(
165                    Arc::new(git::GitDiffTool::new(workspace_root.clone())),
166                    Standard,
167                );
168                reg.register_with_tier(
169                    Arc::new(git::GitLogTool::new(workspace_root.clone())),
170                    Standard,
171                );
172                reg.register_with_tier(
173                    Arc::new(git::GitBlameTool::new(workspace_root.clone())),
174                    Standard,
175                );
176                reg
177            }
178            "task" | "designer" => {
179                reg.register_with_tier(Arc::new(bash::BashTool::new(workspace_root.clone())), Core);
180                reg.register_with_tier(
181                    Arc::new(file::WriteFileTool::new(workspace_root.clone())),
182                    Core,
183                );
184                reg.register_with_tier(
185                    Arc::new(edit::EditFileTool::new(workspace_root.clone())),
186                    Core,
187                );
188                reg.register_with_tier(
189                    Arc::new(edit::EditFileLinesTool::new(workspace_root.clone())),
190                    Standard,
191                );
192                reg.register_with_tier(
193                    Arc::new(edit::InsertAfterTool::new(workspace_root.clone())),
194                    Standard,
195                );
196                reg.register_with_tier(
197                    Arc::new(edit::AppendFileTool::new(workspace_root.clone())),
198                    Standard,
199                );
200
201                reg.register_with_tier(
202                    Arc::new(git::GitStatusTool::new(workspace_root.clone())),
203                    Standard,
204                );
205                reg.register_with_tier(
206                    Arc::new(git::GitDiffTool::new(workspace_root.clone())),
207                    Standard,
208                );
209                reg.register_with_tier(
210                    Arc::new(git::GitAddTool::new(workspace_root.clone())),
211                    Standard,
212                );
213                reg.register_with_tier(
214                    Arc::new(git::GitCommitTool::new(workspace_root.clone())),
215                    Standard,
216                );
217                reg.register_with_tier(
218                    Arc::new(git::GitLogTool::new(workspace_root.clone())),
219                    Standard,
220                );
221                reg.register_with_tier(
222                    Arc::new(git::GitBlameTool::new(workspace_root.clone())),
223                    Standard,
224                );
225                reg.register_with_tier(
226                    Arc::new(git::GitBranchTool::new(workspace_root.clone())),
227                    Standard,
228                );
229                reg.register_with_tier(
230                    Arc::new(git::GitCheckoutTool::new(workspace_root.clone())),
231                    Standard,
232                );
233                reg.register_with_tier(
234                    Arc::new(git::GitStashTool::new(workspace_root.clone())),
235                    Standard,
236                );
237
238                reg.register_with_tier(
239                    Arc::new(batch::BatchTool::new(workspace_root.clone())),
240                    Standard,
241                );
242                reg.register_with_tier(
243                    Arc::new(lsp_tool::LspTool::new(workspace_root.clone())),
244                    Extended,
245                );
246                reg.register_with_tier(
247                    Arc::new(mise::MiseTool::new(workspace_root.clone())),
248                    Extended,
249                );
250                reg.register_with_tier(
251                    Arc::new(native::SdTool::new(workspace_root.clone())),
252                    Extended,
253                );
254                reg.register_with_tier(
255                    Arc::new(native::ErdTool::new(workspace_root.clone())),
256                    Extended,
257                );
258                reg
259            }
260            _ => reg,
261        }
262    }
263
264    fn short_label(description: Option<&str>, assignment: &str) -> String {
265        description
266            .map(str::trim)
267            .filter(|s| !s.is_empty())
268            .map(str::to_string)
269            .unwrap_or_else(|| {
270                let one_line = assignment.lines().next().unwrap_or(assignment).trim();
271                if one_line.chars().count() > 48 {
272                    format!("{}…", one_line.chars().take(45).collect::<String>())
273                } else {
274                    one_line.to_string()
275                }
276            })
277    }
278
279    fn aggregate_batch_results(results: Vec<(usize, Result<Value>)>) -> Value {
280        let mut succeeded = 0usize;
281        let mut failed = 0usize;
282        let mut items = Vec::with_capacity(results.len());
283
284        for (index, result) in results {
285            match result {
286                Ok(v) => {
287                    let status = v.get("status").and_then(|s| s.as_str()).unwrap_or("error");
288                    if status == "completed" {
289                        succeeded += 1;
290                    } else {
291                        failed += 1;
292                    }
293                    items.push(json!({
294                        "index": index,
295                        "status": status,
296                        "agent": v.get("agent"),
297                        "result": v.get("result"),
298                        "duration_ms": v.get("duration_ms"),
299                        "usage": v.get("usage"),
300                        "subagent_id": v.get("subagent_id"),
301                    }));
302                }
303                Err(e) => {
304                    failed += 1;
305                    items.push(json!({
306                        "index": index,
307                        "status": "error",
308                        "result": e.to_string(),
309                    }));
310                }
311            }
312        }
313
314        let total = items.len();
315        json!({
316            "mode": "batch",
317            "success": failed == 0,
318            "total": total,
319            "succeeded": succeeded,
320            "failed": failed,
321            "results": items,
322        })
323    }
324
325    #[allow(clippy::too_many_arguments)]
326    async fn run_subagent(
327        &self,
328        agent_type: &str,
329        assignment: &str,
330        context: Option<&str>,
331        model: Option<&str>,
332        timeout_secs: u64,
333        label: &str,
334        backend_override: Option<Box<dyn LlmBackend>>,
335    ) -> Result<Value> {
336        let started = Instant::now();
337        let handle = SubagentHandle::start(label, "task", Some(agent_type.to_string()));
338
339        let mut config = PawanConfig {
340            system_prompt: Some(Self::system_prompt_for(agent_type)),
341            max_context_tokens: 32_000,
342            max_tool_iterations: 20,
343            ..Default::default()
344        };
345        config.eruka.enabled = false;
346        if let Some(m) = model {
347            config.model = m.to_string();
348        }
349
350        let tools = Self::registry_for(agent_type, &self.workspace_root);
351        let prompt = Self::build_user_prompt(context, assignment);
352
353        let mut agent = PawanAgent::new(config, self.workspace_root.clone()).with_tools(tools);
354        if let Some(backend) = backend_override {
355            agent = agent.with_backend(backend);
356        }
357
358        let progress = handle.clone();
359        let on_tool_start: crate::agent::ToolStartCallback = Box::new(move |name: &str| {
360            progress.set_tool(name);
361        });
362        let progress_done = handle.clone();
363        let on_tool: crate::agent::ToolCallback = Box::new(move |_record| {
364            progress_done.clear_tool();
365        });
366
367        let run = agent.execute_with_callbacks(&prompt, None, Some(on_tool), Some(on_tool_start));
368        let response = match timeout(Duration::from_secs(timeout_secs), run).await {
369            Ok(Ok(res)) => res,
370            Ok(Err(e)) => {
371                handle.complete_err(e.to_string());
372                let duration_ms = started.elapsed().as_millis() as u64;
373                let out = json!({
374                    "agent": agent_type,
375                    "status": "error",
376                    "result": e.to_string(),
377                    "duration_ms": duration_ms,
378                    "subagent_id": handle.id(),
379                });
380                handle.dismiss();
381                return Ok(out);
382            }
383            Err(_) => {
384                handle.complete_err(format!("subagent timeout after {timeout_secs}s"));
385                let duration_ms = started.elapsed().as_millis() as u64;
386                let out = json!({
387                    "agent": agent_type,
388                    "status": "error",
389                    "result": format!("subagent timeout after {timeout_secs}s"),
390                    "duration_ms": duration_ms,
391                    "subagent_id": handle.id(),
392                });
393                handle.dismiss();
394                return Ok(out);
395            }
396        };
397
398        handle.complete_ok();
399        let duration_ms = started.elapsed().as_millis() as u64;
400        let out = json!({
401            "agent": agent_type,
402            "status": "completed",
403            "result": response.content,
404            "duration_ms": duration_ms,
405            "subagent_id": handle.id(),
406            "usage": {
407                "prompt_tokens": response.usage.prompt_tokens,
408                "completion_tokens": response.usage.completion_tokens,
409                "total_tokens": response.usage.total_tokens,
410                "reasoning_tokens": response.usage.reasoning_tokens,
411                "action_tokens": response.usage.action_tokens,
412            }
413        });
414        handle.dismiss();
415        Ok(out)
416    }
417
418    async fn run_tasks_parallel(
419        &self,
420        tasks: Vec<TaskItem>,
421        model: Option<&str>,
422        timeout_secs: u64,
423    ) -> Result<Value> {
424        if tasks.is_empty() {
425            return Ok(json!({
426                "mode": "batch",
427                "success": true,
428                "total": 0,
429                "succeeded": 0,
430                "failed": 0,
431                "results": [],
432            }));
433        }
434
435        let semaphore = Arc::new(Semaphore::new(MAX_PARALLEL_SUBAGENTS));
436        let model = model.map(str::to_string);
437
438        let results: Vec<(usize, Result<Value>)> = stream::iter(tasks.into_iter().enumerate())
439            .map(|(index, item)| {
440                let sem = Arc::clone(&semaphore);
441                let tool = self.clone();
442                let model = model.clone();
443                async move {
444                    let _permit = sem.acquire().await.expect("semaphore");
445                    let label =
446                        TaskTool::short_label(item.description.as_deref(), &item.assignment);
447                    let result = tool
448                        .run_subagent(
449                            &item.agent,
450                            &item.assignment,
451                            item.context.as_deref(),
452                            model.as_deref(),
453                            timeout_secs,
454                            &label,
455                            None,
456                        )
457                        .await;
458                    (index, result)
459                }
460            })
461            .buffered(MAX_PARALLEL_SUBAGENTS)
462            .collect()
463            .await;
464
465        Ok(Self::aggregate_batch_results(results))
466    }
467}
468
469#[async_trait]
470impl Tool for TaskTool {
471    fn name(&self) -> &str {
472        "task"
473    }
474
475    fn description(&self) -> &str {
476        "Spawn an in-process subagent with restricted tools to complete an assignment."
477    }
478
479    fn mutating(&self) -> bool {
480        true
481    }
482
483    fn parameters_schema(&self) -> Value {
484        json!({
485            "type": "object",
486            "properties": {
487                "agent": {"type": "string", "description": "Agent type (single-task mode)"},
488                "assignment": {"type": "string", "description": "Assignment (single-task mode)"},
489                "tasks": {
490                    "type": "array",
491                    "description": "Parallel subagents (max 8). Each item needs agent + assignment.",
492                    "items": {
493                        "type": "object",
494                        "properties": {
495                            "agent": {"type": "string"},
496                            "assignment": {"type": "string"},
497                            "context": {"type": "string"},
498                            "description": {"type": "string", "description": "Short label for TUI"}
499                        },
500                        "required": ["agent", "assignment"]
501                    }
502                },
503                "context": {"type": "string"},
504                "description": {"type": "string", "description": "Short label for TUI (single-task)"},
505                "model": {"type": "string"},
506                "timeout": {"type": "integer"}
507            }
508        })
509    }
510
511    async fn execute(&self, args: Value) -> Result<Value> {
512        let parsed: TaskArgs = serde_json::from_value(args)
513            .map_err(|e| PawanError::Tool(format!("invalid task args: {e}")))?;
514
515        let timeout_secs = parsed.timeout.unwrap_or(DEFAULT_TIMEOUT_SECS);
516
517        if let Some(tasks) = parsed.tasks {
518            if tasks.len() > 8 {
519                return Err(PawanError::Tool(
520                    "task tool accepts at most 8 parallel subagents".into(),
521                ));
522            }
523            for item in &tasks {
524                Self::validate_agent_type(&item.agent).map_err(PawanError::Tool)?;
525                Self::validate_assignment(&item.assignment).map_err(PawanError::Tool)?;
526            }
527            return self
528                .run_tasks_parallel(tasks, parsed.model.as_deref(), timeout_secs)
529                .await;
530        }
531
532        let agent = parsed
533            .agent
534            .as_deref()
535            .ok_or_else(|| PawanError::Tool("agent is required (or pass tasks array)".into()))?;
536        let assignment = parsed.assignment.as_deref().ok_or_else(|| {
537            PawanError::Tool("assignment is required (or pass tasks array)".into())
538        })?;
539
540        Self::validate_agent_type(agent).map_err(PawanError::Tool)?;
541        Self::validate_assignment(assignment).map_err(PawanError::Tool)?;
542
543        let label = Self::short_label(parsed.description.as_deref(), assignment);
544        self.run_subagent(
545            agent,
546            assignment,
547            parsed.context.as_deref(),
548            parsed.model.as_deref(),
549            timeout_secs,
550            &label,
551            None,
552        )
553        .await
554    }
555}
556
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::backend::mock::{MockBackend, MockResponse};
    use serde_json::json;

    /// Creates a `TaskTool` rooted in a fresh temporary directory.
    ///
    /// The returned `TempDir` guard must outlive the tool, so callers bind it to
    /// a named (not `_`) variable.
    fn tool_in_tempdir() -> (tempfile::TempDir, TaskTool) {
        let workspace = tempfile::tempdir().unwrap();
        let tool = TaskTool::new(workspace.path().to_path_buf());
        (workspace, tool)
    }

    #[test]
    fn batch_aggregate_counts_failures() {
        let summary = TaskTool::aggregate_batch_results(vec![
            (0, Ok(json!({"status": "completed"}))),
            (1, Ok(json!({"status": "error"}))),
            (2, Err(PawanError::Tool("boom".into()))),
        ]);

        assert_eq!(summary["succeeded"], 1);
        assert_eq!(summary["failed"], 2);
        assert_eq!(summary["success"], false);
    }

    #[tokio::test]
    async fn unknown_agent_type_rejects() {
        let (_workspace, tool) = tool_in_tempdir();

        let err = tool
            .execute(json!({"agent": "nope", "assignment": "hi", "timeout": 5}))
            .await
            .expect_err("unknown agent types must be rejected");

        assert!(err.to_string().contains("unknown agent type"));
    }

    #[tokio::test]
    async fn timeout_returns_error_status() {
        let (_workspace, tool) = tool_in_tempdir();

        let out = tool
            .run_subagent(
                "explore",
                "This will time out immediately.",
                None,
                None,
                0,
                "timeout test",
                None,
            )
            .await
            .unwrap();

        assert_eq!(out["status"].as_str().unwrap(), "error");
        assert!(out["result"].as_str().unwrap().contains("timeout"));
    }

    #[tokio::test]
    async fn explore_agent_runs_and_returns_findings_with_mock_backend() {
        let (_workspace, tool) = tool_in_tempdir();
        let mock = MockBackend::new(vec![MockResponse::text(
            "Findings: crates/pawan-core/src/lib.rs is the crate root.",
        )]);

        let out = tool
            .run_subagent(
                "explore",
                "Explore the repo and return findings.",
                Some("Context here"),
                None,
                5,
                "explore repo",
                Some(Box::new(mock)),
            )
            .await
            .unwrap();

        assert_eq!(out["agent"].as_str().unwrap(), "explore");
        assert_eq!(out["status"].as_str().unwrap(), "completed");
        assert!(out["result"].as_str().unwrap().contains("Findings:"));
    }
}