Skip to main content

pawan/tools/
task.rs

1//! Task tool: spawn an in-process subagent with restricted tools.
2//!
3//! This tool runs a child `PawanAgent` with a narrowed `ToolRegistry`, a smaller
4//! context window, and a hard timeout. Subagents are depth-limited (they cannot
5//! spawn other agents).
6
7use super::Tool;
8use crate::agent::backend::LlmBackend;
9use crate::agent::PawanAgent;
10use crate::config::PawanConfig;
11use crate::tools::{bash, batch, edit, file, git, lsp_tool, mise, native, ToolRegistry, ToolTier};
12use crate::{PawanError, Result};
13use async_trait::async_trait;
14use serde::Deserialize;
15use serde_json::{json, Value};
16use std::path::{Path, PathBuf};
17use std::sync::Arc;
18use std::time::{Duration, Instant};
19use futures::stream::{self, StreamExt};
20use tokio::sync::Semaphore;
21use tokio::time::timeout;
22
23use crate::subagent::SubagentHandle;
24
/// Wall-clock budget for a single subagent run when the caller passes no
/// `timeout`, in seconds (five minutes).
const DEFAULT_TIMEOUT_SECS: u64 = 5 * 60;

/// Upper bound on how many subagents of one batch are in flight at once.
const MAX_PARALLEL_SUBAGENTS: usize = 2 + 2;
28
29#[derive(Debug, Clone, Deserialize)]
30struct TaskItem {
31    agent: String,
32    assignment: String,
33    #[serde(default)]
34    context: Option<String>,
35    #[serde(default)]
36    description: Option<String>,
37}
38
39#[derive(Debug, Clone, Deserialize)]
40struct TaskArgs {
41    /// Single-task mode: agent type.
42    #[serde(default)]
43    agent: Option<String>,
44    /// Single-task mode: assignment text.
45    #[serde(default)]
46    assignment: Option<String>,
47    /// Parallel mode: one or more subagent jobs (max 8).
48    #[serde(default)]
49    tasks: Option<Vec<TaskItem>>,
50    #[serde(default)]
51    context: Option<String>,
52    #[serde(default)]
53    description: Option<String>,
54    #[serde(default)]
55    model: Option<String>,
56    /// Timeout in seconds (default: 300).
57    #[serde(default)]
58    timeout: Option<u64>,
59}
60
/// Tool that delegates an assignment to an in-process subagent (or a batch of
/// them) running with a restricted tool registry.
///
/// Cloning is cheap enough to hand one copy to each concurrently running job.
#[derive(Debug, Clone)]
pub struct TaskTool {
    /// Directory every subagent tool is rooted at.
    workspace_root: PathBuf,
}
65
66impl TaskTool {
67    pub fn new(workspace_root: PathBuf) -> Self {
68        Self { workspace_root }
69    }
70
71    fn known_agent_types() -> &'static [&'static str] {
72        &[
73            "explore",
74            "plan",
75            "task",
76            "reviewer",
77            "designer",
78            "librarian",
79        ]
80    }
81
82    fn validate_agent_type(agent: &str) -> std::result::Result<(), String> {
83        if Self::known_agent_types().contains(&agent) {
84            Ok(())
85        } else {
86            Err(format!(
87                "unknown agent type '{agent}'. Valid types: {}",
88                Self::known_agent_types().join(", ")
89            ))
90        }
91    }
92
93    fn validate_assignment(assignment: &str) -> std::result::Result<(), String> {
94        if assignment.trim().is_empty() {
95            Err("assignment must be non-empty".to_string())
96        } else {
97            Ok(())
98        }
99    }
100
101    fn system_prompt_for(agent: &str) -> String {
102        match agent {
103            "explore" => "You are a read-only exploration subagent. Use only the allowed read/search tools to gather facts. Do not propose or apply code edits. Return concise findings with file paths and evidence.".to_string(),
104            "plan" => "You are an architecture subagent. Do not modify code. Make design decisions and propose an implementation plan with tradeoffs, invariants, and acceptance criteria.".to_string(),
105            "reviewer" => "You are a code review subagent. Do not modify code. Identify bugs, security issues, and quality concerns. Return a structured review report with severity and recommendations.".to_string(),
106            "designer" => "You are a UI/UX design subagent. If editing tools are available, you may implement UI changes carefully. Prioritize accessibility and consistency.".to_string(),
107            "librarian" => "You are a research subagent. Verify details from authoritative sources and the local codebase. Do not modify code. Return actionable guidance.".to_string(),
108            _ => "You are a subagent executing a delegated task. Follow the assignment precisely and return the final result. Do not spawn other agents.".to_string(),
109        }
110    }
111
112    fn build_user_prompt(context: Option<&str>, assignment: &str) -> String {
113        match context {
114            Some(ctx) if !ctx.trim().is_empty() => format!(
115                "{ctx}\n\n[Assignment]\n{assignment}\n\n[Constraints]\n- Subagent depth limit: you cannot spawn other agents.\n"
116            ),
117            _ => format!(
118                "[Assignment]\n{assignment}\n\n[Constraints]\n- Subagent depth limit: you cannot spawn other agents.\n"
119            ),
120        }
121    }
122
123    fn registry_for(agent: &str, workspace_root: &Path) -> ToolRegistry {
124        use ToolTier::*;
125        let workspace_root = workspace_root.to_path_buf();
126        let mut reg = ToolRegistry::new();
127
128        // Read/search tools
129        reg.register_with_tier(
130            Arc::new(file::ReadFileTool::new(workspace_root.clone())),
131            Core,
132        );
133        reg.register_with_tier(
134            Arc::new(file::ListDirectoryTool::new(workspace_root.clone())),
135            Standard,
136        );
137        reg.register_with_tier(
138            Arc::new(native::GlobSearchTool::new(workspace_root.clone())),
139            Core,
140        );
141        reg.register_with_tier(
142            Arc::new(native::GrepSearchTool::new(workspace_root.clone())),
143            Core,
144        );
145        reg.register_with_tier(
146            Arc::new(native::AstGrepTool::new(workspace_root.clone())),
147            Core,
148        );
149        reg.register_with_tier(
150            Arc::new(native::RipgrepTool::new(workspace_root.clone())),
151            Extended,
152        );
153        reg.register_with_tier(
154            Arc::new(native::FdTool::new(workspace_root.clone())),
155            Extended,
156        );
157
158        match agent {
159            "explore" | "plan" | "reviewer" | "librarian" => {
160                reg.register_with_tier(
161                    Arc::new(git::GitStatusTool::new(workspace_root.clone())),
162                    Standard,
163                );
164                reg.register_with_tier(
165                    Arc::new(git::GitDiffTool::new(workspace_root.clone())),
166                    Standard,
167                );
168                reg.register_with_tier(
169                    Arc::new(git::GitLogTool::new(workspace_root.clone())),
170                    Standard,
171                );
172                reg.register_with_tier(
173                    Arc::new(git::GitBlameTool::new(workspace_root.clone())),
174                    Standard,
175                );
176                reg
177            }
178            "task" | "designer" => {
179                reg.register_with_tier(Arc::new(bash::BashTool::new(workspace_root.clone())), Core);
180                reg.register_with_tier(
181                    Arc::new(file::WriteFileTool::new(workspace_root.clone())),
182                    Core,
183                );
184                reg.register_with_tier(
185                    Arc::new(edit::EditFileTool::new(workspace_root.clone())),
186                    Core,
187                );
188                reg.register_with_tier(
189                    Arc::new(edit::EditFileLinesTool::new(workspace_root.clone())),
190                    Standard,
191                );
192                reg.register_with_tier(
193                    Arc::new(edit::InsertAfterTool::new(workspace_root.clone())),
194                    Standard,
195                );
196                reg.register_with_tier(
197                    Arc::new(edit::AppendFileTool::new(workspace_root.clone())),
198                    Standard,
199                );
200
201                reg.register_with_tier(
202                    Arc::new(git::GitStatusTool::new(workspace_root.clone())),
203                    Standard,
204                );
205                reg.register_with_tier(
206                    Arc::new(git::GitDiffTool::new(workspace_root.clone())),
207                    Standard,
208                );
209                reg.register_with_tier(
210                    Arc::new(git::GitAddTool::new(workspace_root.clone())),
211                    Standard,
212                );
213                reg.register_with_tier(
214                    Arc::new(git::GitCommitTool::new(workspace_root.clone())),
215                    Standard,
216                );
217                reg.register_with_tier(
218                    Arc::new(git::GitLogTool::new(workspace_root.clone())),
219                    Standard,
220                );
221                reg.register_with_tier(
222                    Arc::new(git::GitBlameTool::new(workspace_root.clone())),
223                    Standard,
224                );
225                reg.register_with_tier(
226                    Arc::new(git::GitBranchTool::new(workspace_root.clone())),
227                    Standard,
228                );
229                reg.register_with_tier(
230                    Arc::new(git::GitCheckoutTool::new(workspace_root.clone())),
231                    Standard,
232                );
233                reg.register_with_tier(
234                    Arc::new(git::GitStashTool::new(workspace_root.clone())),
235                    Standard,
236                );
237
238                reg.register_with_tier(
239                    Arc::new(batch::BatchTool::new(workspace_root.clone())),
240                    Standard,
241                );
242                reg.register_with_tier(
243                    Arc::new(lsp_tool::LspTool::new(workspace_root.clone())),
244                    Extended,
245                );
246                reg.register_with_tier(
247                    Arc::new(mise::MiseTool::new(workspace_root.clone())),
248                    Extended,
249                );
250                reg.register_with_tier(
251                    Arc::new(native::SdTool::new(workspace_root.clone())),
252                    Extended,
253                );
254                reg.register_with_tier(
255                    Arc::new(native::ErdTool::new(workspace_root.clone())),
256                    Extended,
257                );
258                reg
259            }
260            _ => reg,
261        }
262    }
263
264    fn short_label(description: Option<&str>, assignment: &str) -> String {
265        description
266            .map(str::trim)
267            .filter(|s| !s.is_empty())
268            .map(str::to_string)
269            .unwrap_or_else(|| {
270                let one_line = assignment.lines().next().unwrap_or(assignment).trim();
271                if one_line.chars().count() > 48 {
272                    format!("{}…", one_line.chars().take(45).collect::<String>())
273                } else {
274                    one_line.to_string()
275                }
276            })
277    }
278
279    fn aggregate_batch_results(results: Vec<(usize, Result<Value>)>) -> Value {
280        let mut succeeded = 0usize;
281        let mut failed = 0usize;
282        let mut items = Vec::with_capacity(results.len());
283
284        for (index, result) in results {
285            match result {
286                Ok(v) => {
287                    let status = v.get("status").and_then(|s| s.as_str()).unwrap_or("error");
288                    if status == "completed" {
289                        succeeded += 1;
290                    } else {
291                        failed += 1;
292                    }
293                    items.push(json!({
294                        "index": index,
295                        "status": status,
296                        "agent": v.get("agent"),
297                        "result": v.get("result"),
298                        "duration_ms": v.get("duration_ms"),
299                        "usage": v.get("usage"),
300                        "subagent_id": v.get("subagent_id"),
301                    }));
302                }
303                Err(e) => {
304                    failed += 1;
305                    items.push(json!({
306                        "index": index,
307                        "status": "error",
308                        "result": e.to_string(),
309                    }));
310                }
311            }
312        }
313
314        let total = items.len();
315        json!({
316            "mode": "batch",
317            "success": failed == 0,
318            "total": total,
319            "succeeded": succeeded,
320            "failed": failed,
321            "results": items,
322        })
323    }
324
325    async fn run_subagent(
326        &self,
327        agent_type: &str,
328        assignment: &str,
329        context: Option<&str>,
330        model: Option<&str>,
331        timeout_secs: u64,
332        label: &str,
333        backend_override: Option<Box<dyn LlmBackend>>,
334    ) -> Result<Value> {
335        let started = Instant::now();
336        let handle = SubagentHandle::start(label, "task", Some(agent_type.to_string()));
337
338        let mut config = PawanConfig {
339            system_prompt: Some(Self::system_prompt_for(agent_type)),
340            max_context_tokens: 32_000,
341            max_tool_iterations: 20,
342            ..Default::default()
343        };
344        config.eruka.enabled = false;
345        if let Some(m) = model {
346            config.model = m.to_string();
347        }
348
349        let tools = Self::registry_for(agent_type, &self.workspace_root);
350        let prompt = Self::build_user_prompt(context, assignment);
351
352        let mut agent = PawanAgent::new(config, self.workspace_root.clone()).with_tools(tools);
353        if let Some(backend) = backend_override {
354            agent = agent.with_backend(backend);
355        }
356
357        let progress = handle.clone();
358        let on_tool_start: crate::agent::ToolStartCallback = Box::new(move |name: &str| {
359            progress.set_tool(name);
360        });
361        let progress_done = handle.clone();
362        let on_tool: crate::agent::ToolCallback = Box::new(move |_record| {
363            progress_done.clear_tool();
364        });
365
366        let run = agent.execute_with_callbacks(&prompt, None, Some(on_tool), Some(on_tool_start));
367        let response = match timeout(Duration::from_secs(timeout_secs), run).await {
368            Ok(Ok(res)) => res,
369            Ok(Err(e)) => {
370                handle.complete_err(e.to_string());
371                let duration_ms = started.elapsed().as_millis() as u64;
372                let out = json!({
373                    "agent": agent_type,
374                    "status": "error",
375                    "result": e.to_string(),
376                    "duration_ms": duration_ms,
377                    "subagent_id": handle.id(),
378                });
379                handle.dismiss();
380                return Ok(out);
381            }
382            Err(_) => {
383                handle.complete_err(format!("subagent timeout after {timeout_secs}s"));
384                let duration_ms = started.elapsed().as_millis() as u64;
385                let out = json!({
386                    "agent": agent_type,
387                    "status": "error",
388                    "result": format!("subagent timeout after {timeout_secs}s"),
389                    "duration_ms": duration_ms,
390                    "subagent_id": handle.id(),
391                });
392                handle.dismiss();
393                return Ok(out);
394            }
395        };
396
397        handle.complete_ok();
398        let duration_ms = started.elapsed().as_millis() as u64;
399        let out = json!({
400            "agent": agent_type,
401            "status": "completed",
402            "result": response.content,
403            "duration_ms": duration_ms,
404            "subagent_id": handle.id(),
405            "usage": {
406                "prompt_tokens": response.usage.prompt_tokens,
407                "completion_tokens": response.usage.completion_tokens,
408                "total_tokens": response.usage.total_tokens,
409                "reasoning_tokens": response.usage.reasoning_tokens,
410                "action_tokens": response.usage.action_tokens,
411            }
412        });
413        handle.dismiss();
414        Ok(out)
415    }
416
417    async fn run_tasks_parallel(
418        &self,
419        tasks: Vec<TaskItem>,
420        model: Option<&str>,
421        timeout_secs: u64,
422    ) -> Result<Value> {
423        if tasks.is_empty() {
424            return Ok(json!({
425                "mode": "batch",
426                "success": true,
427                "total": 0,
428                "succeeded": 0,
429                "failed": 0,
430                "results": [],
431            }));
432        }
433
434        let semaphore = Arc::new(Semaphore::new(MAX_PARALLEL_SUBAGENTS));
435        let model = model.map(str::to_string);
436
437        let results: Vec<(usize, Result<Value>)> = stream::iter(tasks.into_iter().enumerate())
438            .map(|(index, item)| {
439                let sem = Arc::clone(&semaphore);
440                let tool = self.clone();
441                let model = model.clone();
442                async move {
443                    let _permit = sem.acquire().await.expect("semaphore");
444                    let label = TaskTool::short_label(item.description.as_deref(), &item.assignment);
445                    let result = tool
446                        .run_subagent(
447                            &item.agent,
448                            &item.assignment,
449                            item.context.as_deref(),
450                            model.as_deref(),
451                            timeout_secs,
452                            &label,
453                            None,
454                        )
455                        .await;
456                    (index, result)
457                }
458            })
459            .buffered(MAX_PARALLEL_SUBAGENTS)
460            .collect()
461            .await;
462
463        Ok(Self::aggregate_batch_results(results))
464    }
465}
466
467#[async_trait]
468impl Tool for TaskTool {
469    fn name(&self) -> &str {
470        "task"
471    }
472
473    fn description(&self) -> &str {
474        "Spawn an in-process subagent with restricted tools to complete an assignment."
475    }
476
477    fn mutating(&self) -> bool {
478        true
479    }
480
481    fn parameters_schema(&self) -> Value {
482        json!({
483            "type": "object",
484            "properties": {
485                "agent": {"type": "string", "description": "Agent type (single-task mode)"},
486                "assignment": {"type": "string", "description": "Assignment (single-task mode)"},
487                "tasks": {
488                    "type": "array",
489                    "description": "Parallel subagents (max 8). Each item needs agent + assignment.",
490                    "items": {
491                        "type": "object",
492                        "properties": {
493                            "agent": {"type": "string"},
494                            "assignment": {"type": "string"},
495                            "context": {"type": "string"},
496                            "description": {"type": "string", "description": "Short label for TUI"}
497                        },
498                        "required": ["agent", "assignment"]
499                    }
500                },
501                "context": {"type": "string"},
502                "description": {"type": "string", "description": "Short label for TUI (single-task)"},
503                "model": {"type": "string"},
504                "timeout": {"type": "integer"}
505            }
506        })
507    }
508
509    async fn execute(&self, args: Value) -> Result<Value> {
510        let parsed: TaskArgs = serde_json::from_value(args)
511            .map_err(|e| PawanError::Tool(format!("invalid task args: {e}")))?;
512
513        let timeout_secs = parsed.timeout.unwrap_or(DEFAULT_TIMEOUT_SECS);
514
515        if let Some(tasks) = parsed.tasks {
516            if tasks.len() > 8 {
517                return Err(PawanError::Tool(
518                    "task tool accepts at most 8 parallel subagents".into(),
519                ));
520            }
521            for item in &tasks {
522                Self::validate_agent_type(&item.agent).map_err(PawanError::Tool)?;
523                Self::validate_assignment(&item.assignment).map_err(PawanError::Tool)?;
524            }
525            return self
526                .run_tasks_parallel(tasks, parsed.model.as_deref(), timeout_secs)
527                .await;
528        }
529
530        let agent = parsed
531            .agent
532            .as_deref()
533            .ok_or_else(|| PawanError::Tool("agent is required (or pass tasks array)".into()))?;
534        let assignment = parsed.assignment.as_deref().ok_or_else(|| {
535            PawanError::Tool("assignment is required (or pass tasks array)".into())
536        })?;
537
538        Self::validate_agent_type(agent).map_err(PawanError::Tool)?;
539        Self::validate_assignment(assignment).map_err(PawanError::Tool)?;
540
541        let label = Self::short_label(parsed.description.as_deref(), assignment);
542        self.run_subagent(
543            agent,
544            assignment,
545            parsed.context.as_deref(),
546            parsed.model.as_deref(),
547            timeout_secs,
548            &label,
549            None,
550        )
551        .await
552    }
553}
554
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::backend::mock::{MockBackend, MockResponse};
    use serde_json::json;

    #[test]
    fn batch_aggregate_counts_failures() {
        let results = vec![
            (0, Ok(json!({"status": "completed"}))),
            (1, Ok(json!({"status": "error"}))),
            (2, Err(PawanError::Tool("boom".into()))),
        ];
        let summary = TaskTool::aggregate_batch_results(results);
        assert_eq!(summary["succeeded"], 1);
        assert_eq!(summary["failed"], 2);
        assert_eq!(summary["success"], false);
    }

    #[test]
    fn validate_agent_type_accepts_known_types_only() {
        for agent in TaskTool::known_agent_types() {
            assert!(TaskTool::validate_agent_type(agent).is_ok());
        }
        let err = TaskTool::validate_agent_type("nope").unwrap_err();
        assert!(err.contains("unknown agent type"));
    }

    #[test]
    fn build_user_prompt_skips_blank_context() {
        let prompt = TaskTool::build_user_prompt(Some("   "), "do it");
        assert!(prompt.starts_with("[Assignment]\ndo it"));

        let prompt = TaskTool::build_user_prompt(Some("ctx"), "do it");
        assert!(prompt.starts_with("ctx\n\n[Assignment]\ndo it"));
    }

    #[test]
    fn short_label_prefers_description_and_truncates_long_assignments() {
        assert_eq!(TaskTool::short_label(Some("  label  "), "ignored"), "label");
        assert_eq!(TaskTool::short_label(Some("   "), "short"), "short");

        let label = TaskTool::short_label(None, &"a".repeat(60));
        assert_eq!(label.chars().count(), 46);
        assert!(label.ends_with('…'));
    }

    #[tokio::test]
    async fn unknown_agent_type_rejects() {
        let dir = tempfile::tempdir().unwrap();
        let tool = TaskTool::new(dir.path().to_path_buf());
        let err = tool
            .execute(json!({"agent": "nope", "assignment": "hi", "timeout": 5}))
            .await
            .unwrap_err();
        assert!(err.to_string().contains("unknown agent type"));
    }

    #[tokio::test]
    async fn batch_rejects_unknown_agent_type() {
        let dir = tempfile::tempdir().unwrap();
        let tool = TaskTool::new(dir.path().to_path_buf());
        let err = tool
            .execute(json!({"tasks": [{"agent": "nope", "assignment": "hi"}]}))
            .await
            .unwrap_err();
        assert!(err.to_string().contains("unknown agent type"));
    }

    #[tokio::test]
    async fn batch_rejects_more_than_eight_tasks() {
        let dir = tempfile::tempdir().unwrap();
        let tool = TaskTool::new(dir.path().to_path_buf());
        let tasks: Vec<Value> = (0..9)
            .map(|i| json!({"agent": "explore", "assignment": format!("job {i}")}))
            .collect();
        let err = tool.execute(json!({ "tasks": tasks })).await.unwrap_err();
        assert!(err.to_string().contains("at most 8"));
    }

    #[tokio::test]
    async fn empty_batch_succeeds_with_no_results() {
        let dir = tempfile::tempdir().unwrap();
        let tool = TaskTool::new(dir.path().to_path_buf());
        let out = tool.execute(json!({"tasks": []})).await.unwrap();
        assert_eq!(out["mode"], "batch");
        assert_eq!(out["total"], 0);
        assert_eq!(out["success"], true);
    }

    #[tokio::test]
    async fn timeout_returns_error_status() {
        let dir = tempfile::tempdir().unwrap();
        let tool = TaskTool::new(dir.path().to_path_buf());

        let out = tool
            .run_subagent(
                "explore",
                "This will time out immediately.",
                None,
                None,
                0,
                "timeout test",
                None,
            )
            .await
            .unwrap();

        assert_eq!(out["status"].as_str().unwrap(), "error");
        assert!(out["result"].as_str().unwrap().contains("timeout"));
    }

    #[tokio::test]
    async fn explore_agent_runs_and_returns_findings_with_mock_backend() {
        let dir = tempfile::tempdir().unwrap();
        let tool = TaskTool::new(dir.path().to_path_buf());

        let backend = Box::new(MockBackend::new(vec![MockResponse::text(
            "Findings: crates/pawan-core/src/lib.rs is the crate root.",
        )]));

        let out = tool
            .run_subagent(
                "explore",
                "Explore the repo and return findings.",
                Some("Context here"),
                None,
                5,
                "explore repo",
                Some(backend),
            )
            .await
            .unwrap();

        assert_eq!(out["agent"].as_str().unwrap(), "explore");
        assert_eq!(out["status"].as_str().unwrap(), "completed");
        assert!(out["result"].as_str().unwrap().contains("Findings:"));
    }
}