Skip to main content

rho_tools/
task.rs

use std::path::{Path, PathBuf};

use async_trait::async_trait;
use tokio_util::sync::CancellationToken;

use rho_core::config::AgentDef;
use rho_core::tool::{AgentTool, ToolError};
use rho_core::types::{Content, ToolResult};
/// Subagent nesting limit used when `TaskTool::new` is given `max_depth: None`.
const DEFAULT_MAX_AGENT_DEPTH: usize = 5_usize;

12pub struct TaskTool {
13    rho_binary: PathBuf,
14    cwd: PathBuf,
15    max_depth: usize,
16    current_depth: usize,
17    allowed_agents: Option<Vec<AgentDef>>,
18}
19
20impl TaskTool {
21    pub fn new(cwd: PathBuf, max_depth: Option<usize>, allowed_agents: Vec<AgentDef>) -> Self {
22        // Find the rho binary — prefer the one next to current exe
23        let rho_binary = std::env::current_exe()
24            .ok()
25            .and_then(|p| {
26                let dir = p.parent()?;
27                let candidate = dir.join("rho");
28                if candidate.exists() {
29                    Some(candidate)
30                } else {
31                    None
32                }
33            })
34            .unwrap_or_else(|| PathBuf::from("rho"));
35
36        let current_depth: usize = std::env::var("RHO_AGENT_DEPTH")
37            .ok()
38            .and_then(|v| v.parse().ok())
39            .unwrap_or(0);
40
41        let allowed_agents = if allowed_agents.is_empty() {
42            None
43        } else {
44            Some(allowed_agents)
45        };
46
47        Self {
48            rho_binary,
49            cwd,
50            max_depth: max_depth.unwrap_or(DEFAULT_MAX_AGENT_DEPTH),
51            current_depth,
52            allowed_agents,
53        }
54    }
55}
56
57#[async_trait]
58impl AgentTool for TaskTool {
59    fn name(&self) -> &str {
60        "task"
61    }
62
63    fn label(&self) -> String {
64        "Task (subagent)".into()
65    }
66
67    fn description(&self) -> String {
68        let mut desc = "Launch a subagent to handle a task. The subagent runs as a separate process with \
69         its own context. Use this for research, analysis, or delegating work that should \
70         not pollute the current conversation context.".to_string();
71        if let Some(ref agents) = self.allowed_agents {
72            desc.push_str("\n\nAvailable agents:");
73            for a in agents {
74                desc.push_str(&format!("\n- {}", a.name));
75                if let Some(ref d) = a.description {
76                    desc.push_str(&format!(": {}", d));
77                }
78            }
79        }
80        desc
81    }
82
83    fn parameters_schema(&self) -> serde_json::Value {
84        let agent_prop = if let Some(ref agents) = self.allowed_agents {
85            let names: Vec<String> = agents.iter().map(|a| a.name.clone()).collect();
86            serde_json::json!({
87                "type": "string",
88                "description": "Agent to use for this task",
89                "enum": names
90            })
91        } else {
92            serde_json::json!({
93                "type": "string",
94                "description": "Name of agent config from .rho/agents/ (optional)"
95            })
96        };
97
98        serde_json::json!({
99            "type": "object",
100            "required": ["prompt"],
101            "properties": {
102                "prompt": {
103                    "type": "string",
104                    "description": "The task prompt for the subagent"
105                },
106                "agent": agent_prop,
107                "tools": {
108                    "type": "string",
109                    "description": "Comma-separated list of allowed tools (e.g. 'read,grep,find')"
110                }
111            }
112        })
113    }
114
115    async fn execute(
116        &self,
117        _tool_call_id: &str,
118        params: serde_json::Value,
119        cancel: CancellationToken,
120    ) -> Result<ToolResult, ToolError> {
121        // Depth limit check
122        if self.current_depth >= self.max_depth {
123            return Ok(ToolResult {
124                content: vec![Content::Text {
125                    text: format!(
126                        "Subagent depth limit reached ({}/{}). Cannot spawn further subagents.",
127                        self.current_depth, self.max_depth
128                    ),
129                }],
130                details: serde_json::json!({"error": "depth_limit"}),
131            });
132        }
133
134        let prompt = params["prompt"]
135            .as_str()
136            .ok_or_else(|| ToolError::InvalidParameters("prompt is required".into()))?;
137
138        let agent_name = params["agent"].as_str();
139        let tools_override = params["tools"].as_str();
140
141        // Agent allowlist check
142        let matched_agent_def = if let Some(ref allowed) = self.allowed_agents {
143            if let Some(name) = agent_name {
144                let found = allowed.iter().find(|a| a.name == name);
145                if found.is_none() {
146                    let available: Vec<&str> = allowed.iter().map(|a| a.name.as_str()).collect();
147                    return Ok(ToolResult {
148                        content: vec![Content::Text {
149                            text: format!(
150                                "Agent '{}' is not in the allowed agents list. Available agents: {}",
151                                name,
152                                available.join(", ")
153                            ),
154                        }],
155                        details: serde_json::json!({"error": "agent_not_allowed"}),
156                    });
157                }
158                found
159            } else {
160                None
161            }
162        } else {
163            None
164        };
165
166        // Load agent config file if specified (file-based config)
167        let agent_config = if let Some(name) = agent_name {
168            load_agent_config(&self.cwd, name)
169        } else {
170            None
171        };
172
173        let mut cmd = tokio::process::Command::new(&self.rho_binary);
174        cmd.current_dir(&self.cwd);
175
176        // Set depth env for child process
177        cmd.env("RHO_AGENT_DEPTH", (self.current_depth + 1).to_string());
178
179        // Apply tools restriction
180        // Priority: explicit tools param > AgentDef from RHO.md > agent config file
181        if let Some(tools) = tools_override {
182            cmd.arg("--tools").arg(tools);
183        } else if let Some(agent_def) = matched_agent_def {
184            if let Some(ref tools) = agent_def.tools {
185                cmd.arg("--tools").arg(tools);
186            }
187        } else if let Some(ref ac) = agent_config {
188            if let Some(ref tools) = ac.tools {
189                cmd.arg("--tools").arg(tools);
190            }
191        }
192
193        // Apply model override
194        // Priority: AgentDef from RHO.md > agent config file
195        if let Some(agent_def) = matched_agent_def {
196            if let Some(ref model) = agent_def.model {
197                cmd.arg("--model").arg(model);
198            }
199        } else if let Some(ref ac) = agent_config {
200            if let Some(ref model) = ac.model {
201                cmd.arg("--model").arg(model);
202            }
203        }
204
205        // Apply system prompt append (only from agent config file)
206        if let Some(ref ac) = agent_config {
207            if let Some(ref append) = ac.system_prompt_append {
208                cmd.arg("--system-append").arg(append);
209            }
210        }
211
212        // The prompt is the positional argument
213        cmd.arg(prompt);
214
215        // Capture output
216        cmd.stdout(std::process::Stdio::piped());
217        cmd.stderr(std::process::Stdio::piped());
218
219        let child = cmd.spawn().map_err(|e| {
220            ToolError::ExecutionFailed(format!("Failed to spawn subagent: {}", e))
221        })?;
222
223        let output: std::process::Output = tokio::select! {
224            result = child.wait_with_output() => {
225                result.map_err(|e| ToolError::ExecutionFailed(format!("Subagent error: {}", e)))?
226            }
227            _ = cancel.cancelled() => {
228                return Ok(ToolResult {
229                    content: vec![Content::Text {
230                        text: "Subagent cancelled".into(),
231                    }],
232                    details: serde_json::json!({}),
233                });
234            }
235        };
236
237        let stdout = String::from_utf8_lossy(&output.stdout);
238        let stderr = String::from_utf8_lossy(&output.stderr);
239
240        let mut result_text = stdout.to_string();
241        if !stderr.is_empty() && !output.status.success() {
242            result_text.push_str("\n\n[stderr]\n");
243            result_text.push_str(&stderr);
244        }
245
246        // Truncate to 20KB
247        if result_text.len() > 20_000 {
248            result_text.truncate(20_000);
249            result_text.push_str("\n... [truncated]");
250        }
251
252        Ok(ToolResult {
253            content: vec![Content::Text { text: result_text }],
254            details: serde_json::json!({
255                "exit_code": output.status.code(),
256            }),
257        })
258    }
259}
260
/// Settings parsed from a file-based agent definition (`.rho/agents/{name}.md`
/// or similar).
#[derive(Clone, Debug)]
struct AgentConfig {
    /// Comma-separated tool list, forwarded as `--tools`.
    tools: Option<String>,
    /// Model name, forwarded as `--model`.
    model: Option<String>,
    /// Extra system-prompt text, forwarded as `--system-append`.
    system_prompt_append: Option<String>,
}

268/// Load agent config from .rho/agents/{name}.md or .claude/agents/{name}.md
269fn load_agent_config(cwd: &Path, name: &str) -> Option<AgentConfig> {
270    let candidates = [
271        cwd.join(format!(".rho/agents/{}.md", name)),
272        cwd.join(format!(".claude/agents/{}.md", name)),
273    ];
274
275    for path in &candidates {
276        if let Ok(content) = std::fs::read_to_string(path) {
277            return Some(parse_agent_config(&content));
278        }
279    }
280
281    // Check home directory
282    if let Some(home) = dirs::home_dir() {
283        let path = home.join(format!(".rho/agents/{}.md", name));
284        if let Ok(content) = std::fs::read_to_string(&path) {
285            return Some(parse_agent_config(&content));
286        }
287    }
288
289    None
290}
291
292fn parse_agent_config(content: &str) -> AgentConfig {
293    let trimmed = content.trim_start();
294    if !trimmed.starts_with("---") {
295        return AgentConfig {
296            tools: None,
297            model: None,
298            system_prompt_append: Some(content.to_string()),
299        };
300    }
301
302    let after_first = &trimmed[3..];
303    let Some(end) = after_first.find("\n---") else {
304        return AgentConfig {
305            tools: None,
306            model: None,
307            system_prompt_append: Some(content.to_string()),
308        };
309    };
310
311    let frontmatter = &after_first[..end];
312    let body_start = 3 + end + 4;
313    let body = trimmed[body_start..].trim().to_string();
314
315    let mut config = AgentConfig {
316        tools: None,
317        model: None,
318        system_prompt_append: if body.is_empty() { None } else { Some(body) },
319    };
320
321    for line in frontmatter.lines() {
322        let line = line.trim();
323        if let Some(val) = line.strip_prefix("tools:") {
324            config.tools = Some(val.trim().to_string());
325        } else if let Some(val) = line.strip_prefix("model:") {
326            config.model = Some(val.trim().to_string());
327        }
328    }
329
330    config
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_agent_config_with_frontmatter() {
        let content = "\
---
name: researcher
tools: read,grep,find
model: claude-sonnet-4-5-20250929
---
You are a research agent. Analyze code and return findings.
Do not modify any files.";

        let config = parse_agent_config(content);
        assert_eq!(config.tools.as_deref(), Some("read,grep,find"));
        assert_eq!(config.model.as_deref(), Some("claude-sonnet-4-5-20250929"));
        assert!(config
            .system_prompt_append
            .unwrap()
            .contains("research agent"));
    }

    #[test]
    fn parse_agent_config_no_frontmatter() {
        let content = "Just do research.";
        let config = parse_agent_config(content);
        assert!(config.tools.is_none());
        assert!(config.model.is_none());
        assert_eq!(config.system_prompt_append.as_deref(), Some(content));
    }

    #[test]
    fn task_tool_schema_no_agents() {
        let tool = TaskTool::new(PathBuf::from("."), None, vec![]);
        let schema = tool.parameters_schema();
        assert_eq!(schema["required"][0], "prompt");
        assert!(schema["properties"]["agent"].is_object());
        assert!(schema["properties"]["agent"]["enum"].is_null());
        assert!(schema["properties"]["tools"].is_object());
    }

    #[test]
    fn task_tool_schema_with_agents() {
        let agents = vec![
            AgentDef {
                name: "researcher".into(),
                tools: Some("read,grep".into()),
                model: None,
                description: Some("Research agent".into()),
            },
            AgentDef {
                name: "coder".into(),
                tools: None,
                model: Some("claude-opus".into()),
                description: None,
            },
        ];
        let tool = TaskTool::new(PathBuf::from("."), None, agents);
        let schema = tool.parameters_schema();
        let agent_enum = &schema["properties"]["agent"]["enum"];
        assert_eq!(agent_enum[0], "researcher");
        assert_eq!(agent_enum[1], "coder");
    }

    #[test]
    fn task_tool_description_lists_agents() {
        let agents = vec![AgentDef {
            name: "researcher".into(),
            tools: None,
            model: None,
            description: Some("Finds stuff".into()),
        }];
        let tool = TaskTool::new(PathBuf::from("."), None, agents);
        let desc = tool.description();
        assert!(desc.contains("Available agents:"));
        assert!(desc.contains("- researcher: Finds stuff"));
    }

    #[tokio::test]
    async fn task_tool_depth_limit() {
        // Simulate being at max depth. Set the private field directly instead of
        // exporting RHO_AGENT_DEPTH: tests run on parallel threads, and mutating the
        // process environment would race with every other test that builds a TaskTool.
        let mut tool = TaskTool::new(PathBuf::from("."), Some(5), vec![]);
        tool.current_depth = 5;
        let cancel = CancellationToken::new();
        let params = serde_json::json!({"prompt": "do something"});
        let result = tool.execute("test", params, cancel).await.unwrap();
        let text = match &result.content[0] {
            Content::Text { text } => text.as_str(),
            _ => panic!("expected text"),
        };
        assert!(text.contains("depth limit reached"));
        assert_eq!(result.details["error"], "depth_limit");
    }

    #[tokio::test]
    async fn task_tool_agent_not_allowed() {
        let agents = vec![AgentDef {
            name: "researcher".into(),
            tools: None,
            model: None,
            description: None,
        }];
        let mut tool = TaskTool::new(PathBuf::from("."), None, agents);
        // Pin the depth so an inherited RHO_AGENT_DEPTH (e.g. when the suite itself runs
        // inside a subagent) cannot trip the depth limit before the allowlist check.
        tool.current_depth = 0;
        let cancel = CancellationToken::new();
        let params = serde_json::json!({"prompt": "do something", "agent": "hacker"});
        let result = tool.execute("test", params, cancel).await.unwrap();
        let text = match &result.content[0] {
            Content::Text { text } => text.as_str(),
            _ => panic!("expected text"),
        };
        assert!(text.contains("not in the allowed agents list"));
        assert!(text.contains("researcher"));
        assert_eq!(result.details["error"], "agent_not_allowed");
    }
}