Skip to main content

ai_agent/tools/agent/
agent_tool_utils.rs

1// Source: ~/claudecode/openclaudecode/src/tools/AgentTool/agentToolUtils.ts
2#![allow(dead_code)]
3use std::sync::Arc;
4
5use std::collections::{HashMap, HashSet};
6
7use super::constants::{
8    AGENT_TOOL_NAME, ALL_AGENT_DISALLOWED_TOOLS, ASYNC_AGENT_ALLOWED_TOOLS,
9    CUSTOM_AGENT_DISALLOWED_TOOLS, FORK_BOILERPLATE_TAG, FORK_DIRECTIVE_PREFIX,
10};
11use super::load_agents_dir::AgentDefinition;
12
/// Outcome of resolving an agent's declared tool specs against the tools available to it.
///
/// Produced by `resolve_agent_tools`. Derives `PartialEq`/`Eq` so resolution results can be
/// compared directly (e.g. with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ResolvedAgentTools {
    /// `true` when the agent's tool list granted every permitted tool (wildcard).
    pub has_wildcard: bool,
    /// Declared specs, verbatim, that were accepted; empty for a wildcard.
    pub valid_tools: Vec<String>,
    /// Declared specs, verbatim, that did not match any available tool; empty for a wildcard.
    pub invalid_tools: Vec<String>,
    /// Deduplicated tool names the agent may use, in first-seen order.
    pub resolved_tool_names: Vec<String>,
    /// Agent types listed in an `Agent(...)` spec, if one restricted them.
    pub allowed_agent_types: Option<Vec<String>>,
}
22
23/// Filter tools available to an agent based on built-in status and async mode.
24pub fn filter_tools_for_agent(
25    available_tools: &[String],
26    is_built_in: bool,
27    is_async: bool,
28) -> Vec<String> {
29    available_tools
30        .iter()
31        .filter(|tool| {
32            // Allow MCP tools for all agents
33            if tool.starts_with("mcp__") {
34                return true;
35            }
36            // Block globally disallowed tools
37            if ALL_AGENT_DISALLOWED_TOOLS.contains(&tool.as_str()) {
38                return false;
39            }
40            // Block custom-agent-specific tools for non-built-in agents
41            if !is_built_in && CUSTOM_AGENT_DISALLOWED_TOOLS.contains(&tool.as_str()) {
42                return false;
43            }
44            // Block async-restricted tools for async agents
45            if is_async && !ASYNC_AGENT_ALLOWED_TOOLS.contains(&tool.as_str()) {
46                return false;
47            }
48            true
49        })
50        .cloned()
51        .collect()
52}
53
/// Split a tool spec such as `"Bash(git status)"` into its tool name and rule part.
///
/// The name is everything before the first `(`, trimmed. When a `(` is present, the rule
/// part is the rest of the spec *including* that opening parenthesis, trimmed; for example,
/// `"Bash(ls)"` yields `("Bash", Some("(ls)"))`. A spec with no `(` yields the trimmed spec
/// and `None`.
fn parse_tool_spec(spec: &str) -> (String, Option<String>) {
    match spec.find('(') {
        None => (spec.trim().to_string(), None),
        Some(open) => {
            let (name, rest) = spec.split_at(open);
            (name.trim().to_string(), Some(rest.trim().to_string()))
        }
    }
}
64
65/// Resolves and validates agent tools against available tools.
66/// Handles wildcard expansion and validation.
67pub fn resolve_agent_tools(
68    agent_definition: &AgentDefinition,
69    available_tools: &[String],
70    is_async: bool,
71) -> ResolvedAgentTools {
72    // Filter available tools based on agent's built-in status and async mode
73    let filtered_available = filter_tools_for_agent(
74        available_tools,
75        agent_definition.source == "built-in",
76        is_async,
77    );
78
79    // Create a set of disallowed tool names
80    let disallowed_set: HashSet<&str> = agent_definition
81        .disallowed_tools
82        .iter()
83        .map(|s| s.as_str())
84        .collect();
85
86    // Filter out disallowed tools
87    let allowed_available: Vec<String> = filtered_available
88        .into_iter()
89        .filter(|t| !disallowed_set.contains(t.as_str()))
90        .collect();
91
92    // Check for wildcard
93    let has_wildcard = agent_definition.tools.is_empty()
94        || agent_definition.tools == vec!["*"]
95        || (agent_definition.tools.len() == 1 && agent_definition.tools[0] == "*");
96
97    if has_wildcard {
98        return ResolvedAgentTools {
99            has_wildcard: true,
100            valid_tools: vec![],
101            invalid_tools: vec![],
102            resolved_tool_names: allowed_available,
103            allowed_agent_types: None,
104        };
105    }
106
107    let available_map: HashMap<&str, &String> =
108        allowed_available.iter().map(|t| (t.as_str(), t)).collect();
109
110    let mut valid_tools: Vec<String> = Vec::new();
111    let mut invalid_tools: Vec<String> = Vec::new();
112    let mut resolved: Vec<String> = Vec::new();
113    let mut resolved_set: HashSet<String> = HashSet::new();
114    let mut allowed_agent_types: Option<Vec<String>> = None;
115
116    for tool_spec in &agent_definition.tools {
117        let (tool_name, rule_content) = parse_tool_spec(tool_spec);
118
119        // Special case: Agent tool carries allowedAgentTypes metadata
120        if tool_name == AGENT_TOOL_NAME {
121            if let Some(ref rules) = rule_content {
122                // Parse comma-separated agent types: "worker, researcher" -> ["worker", "researcher"]
123                let types: Vec<String> = rules
124                    .trim_matches(|c: char| c == '(' || c == ')')
125                    .split(',')
126                    .map(|s| s.trim().to_string())
127                    .collect();
128                allowed_agent_types = Some(types);
129            }
130            valid_tools.push(tool_spec.clone());
131            continue;
132        }
133
134        if available_map.contains_key(tool_name.as_str()) {
135            valid_tools.push(tool_spec.clone());
136            if resolved_set.insert(tool_name.clone()) {
137                resolved.push(tool_name);
138            }
139        } else {
140            invalid_tools.push(tool_spec.clone());
141        }
142    }
143
144    ResolvedAgentTools {
145        has_wildcard: false,
146        valid_tools,
147        invalid_tools,
148        allowed_agent_types,
149        resolved_tool_names: resolved,
150    }
151}
152
153/// Count tool uses in a list of messages (represented as JSON values).
154pub fn count_tool_uses(messages: &[serde_json::Value]) -> usize {
155    let mut count = 0;
156    for msg in messages {
157        if msg.get("type").and_then(|t| t.as_str()) == Some("assistant") {
158            if let Some(content) = msg.get("message").and_then(|m| m.get("content")) {
159                if let Some(arr) = content.as_array() {
160                    for block in arr {
161                        if block.get("type").and_then(|t| t.as_str()) == Some("tool_use") {
162                            count += 1;
163                        }
164                    }
165                }
166            }
167        }
168    }
169    count
170}
171
172/// Extract text content from a message's content array.
173pub fn extract_text_content(content: &[serde_json::Value], separator: &str) -> String {
174    let texts: Vec<String> = content
175        .iter()
176        .filter(|block| block.get("type").and_then(|t| t.as_str()) == Some("text"))
177        .filter_map(|block| block.get("text").and_then(|t| t.as_str()))
178        .map(|t| t.to_string())
179        .collect();
180    texts.join(separator)
181}
182
183/// Get the last assistant message from a list of messages.
184pub fn get_last_assistant_message(messages: &[serde_json::Value]) -> Option<&serde_json::Value> {
185    messages
186        .iter()
187        .rev()
188        .find(|msg| msg.get("type").and_then(|t| t.as_str()) == Some("assistant"))
189}
190
191/// Extract a partial result string from an agent's accumulated messages.
192/// Used when an async agent is killed to preserve what it accomplished.
193pub fn extract_partial_result(messages: &[serde_json::Value]) -> Option<String> {
194    for msg in messages.iter().rev() {
195        if msg.get("type").and_then(|t| t.as_str()) != Some("assistant") {
196            continue;
197        }
198        if let Some(content) = msg.get("message").and_then(|m| m.get("content")) {
199            if let Some(arr) = content.as_array() {
200                let text = extract_text_content(arr, "\n");
201                if !text.is_empty() {
202                    return Some(text);
203                }
204            }
205        }
206    }
207    None
208}
209
210/// Extract a partial result string from a QueryEngine's message history.
211/// Used when a subagent is killed to preserve what it accomplished.
212/// Matches TypeScript's extractPartialResult but operates on engine Message type.
213pub fn extract_partial_result_from_engine(messages: &[crate::types::Message]) -> Option<String> {
214    for msg in messages.iter().rev() {
215        if msg.role != crate::types::MessageRole::Assistant {
216            continue;
217        }
218        if !msg.content.is_empty() {
219            return Some(msg.content.clone());
220        }
221    }
222    None
223}
224
225/// Get the name of the last tool_use block in a message.
226pub fn get_last_tool_use_name(message: &serde_json::Value) -> Option<String> {
227    if message.get("type").and_then(|t| t.as_str()) != Some("assistant") {
228        return None;
229    }
230    let content = message.get("message").and_then(|m| m.get("content"))?;
231    let arr = content.as_array()?;
232    for block in arr.iter().rev() {
233        if block.get("type").and_then(|t| t.as_str()) == Some("tool_use") {
234            return block
235                .get("name")
236                .and_then(|n| n.as_str())
237                .map(|s| s.to_string());
238        }
239    }
240    None
241}
242
/// Token counters for an agent run.
///
/// Each field is named after the matching key of a message's `usage` object. Derives
/// `PartialEq`/`Eq` so usage values can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TokenUsage {
    /// Value of the `input_tokens` usage field.
    pub input_tokens: usize,
    /// Value of the `output_tokens` usage field.
    pub output_tokens: usize,
    /// Value of the `cache_creation_input_tokens` usage field.
    pub cache_creation_input_tokens: usize,
    /// Value of the `cache_read_input_tokens` usage field.
    pub cache_read_input_tokens: usize,
}
251
252/// Result returned when an agent completes.
253#[derive(Debug, Clone)]
254pub struct AgentToolResult {
255    pub agent_id: String,
256    pub agent_type: Option<String>,
257    pub content: String,
258    pub total_tool_use_count: usize,
259    pub total_duration_ms: u64,
260    pub total_tokens: usize,
261    pub usage: TokenUsage,
262}
263
264/// Finalize an agent run and produce a result.
265pub fn finalize_agent_tool(
266    messages: &[serde_json::Value],
267    agent_id: &str,
268    agent_type: &str,
269    start_time_ms: u64,
270) -> Result<AgentToolResult, String> {
271    let last_assistant = get_last_assistant_message(messages)
272        .ok_or_else(|| "No assistant messages found".to_string())?;
273
274    // Extract text content
275    let content = last_assistant
276        .get("message")
277        .and_then(|m| m.get("content"))
278        .and_then(|c| c.as_array())
279        .map(|arr| extract_text_content(arr, "\n"))
280        .unwrap_or_default();
281
282    let total_tool_use_count = count_tool_uses(messages);
283
284    // Extract usage from last assistant message
285    let usage = last_assistant
286        .get("message")
287        .and_then(|m| m.get("usage"))
288        .map(|u| TokenUsage {
289            input_tokens: u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as usize,
290            output_tokens: u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as usize,
291            cache_creation_input_tokens: u
292                .get("cache_creation_input_tokens")
293                .and_then(|v| v.as_u64())
294                .unwrap_or(0) as usize,
295            cache_read_input_tokens: u
296                .get("cache_read_input_tokens")
297                .and_then(|v| v.as_u64())
298                .unwrap_or(0) as usize,
299        })
300        .unwrap_or_default();
301
302    let total_tokens = usage.input_tokens
303        + usage.output_tokens
304        + usage.cache_creation_input_tokens
305        + usage.cache_read_input_tokens;
306
307    Ok(AgentToolResult {
308        agent_id: agent_id.to_string(),
309        agent_type: Some(agent_type.to_string()),
310        content,
311        total_tool_use_count,
312        total_duration_ms: (std::time::SystemTime::now()
313            .duration_since(std::time::UNIX_EPOCH)
314            .unwrap_or_default()
315            .as_millis() as u64)
316            .saturating_sub(start_time_ms),
317        total_tokens,
318        usage,
319    })
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a minimal built-in agent definition declaring `tools`.
    fn make_agent_def(tools: Vec<&str>) -> AgentDefinition {
        AgentDefinition {
            agent_type: "test".to_string(),
            when_to_use: "test".to_string(),
            tools: tools.into_iter().map(|s| s.to_string()).collect(),
            source: "built-in".to_string(),
            base_dir: "built-in".to_string(),
            get_system_prompt: Arc::new(|| String::new()),
            model: None,
            disallowed_tools: vec![],
            max_turns: None,
            permission_mode: None,
            effort: None,
            color: None,
            mcp_servers: vec![],
            hooks: None,
            skills: vec![],
            background: false,
            initial_prompt: None,
            memory: None,
            isolation: None,
            required_mcp_servers: vec![],
            omit_claude_md: false,
            critical_system_reminder_experimental: None,
        }
    }

    #[test]
    fn test_resolve_wildcard() {
        let agent = make_agent_def(vec!["*"]);
        let available = vec!["Bash".to_string(), "Read".to_string()];
        let resolved = resolve_agent_tools(&agent, &available, false);
        assert!(resolved.has_wildcard);
        assert_eq!(resolved.resolved_tool_names.len(), 2);
    }

    #[test]
    fn test_resolve_empty_tools_is_wildcard() {
        let agent = make_agent_def(vec![]);
        let available = vec!["Bash".to_string(), "Read".to_string()];
        let resolved = resolve_agent_tools(&agent, &available, false);
        assert!(resolved.has_wildcard);
        assert_eq!(resolved.resolved_tool_names, vec!["Bash", "Read"]);
    }

    #[test]
    fn test_resolve_specific_tools() {
        let agent = make_agent_def(vec!["Bash"]);
        let available = vec!["Bash".to_string(), "Read".to_string()];
        let resolved = resolve_agent_tools(&agent, &available, false);
        assert!(!resolved.has_wildcard);
        assert_eq!(resolved.resolved_tool_names, vec!["Bash"]);
    }

    #[test]
    fn test_resolve_invalid_and_duplicate_tools() {
        let agent = make_agent_def(vec!["Bash", "Nope", "Bash"]);
        let available = vec!["Bash".to_string(), "Read".to_string()];
        let resolved = resolve_agent_tools(&agent, &available, false);
        assert_eq!(resolved.valid_tools, vec!["Bash", "Bash"]);
        assert_eq!(resolved.invalid_tools, vec!["Nope"]);
        assert_eq!(resolved.resolved_tool_names, vec!["Bash"]);
    }

    #[test]
    fn test_resolve_disallowed_tools_removed() {
        let mut agent = make_agent_def(vec!["*"]);
        agent.disallowed_tools = vec!["Read".to_string()];
        let available = vec!["Bash".to_string(), "Read".to_string()];
        let resolved = resolve_agent_tools(&agent, &available, false);
        assert_eq!(resolved.resolved_tool_names, vec!["Bash"]);
    }

    #[test]
    fn test_resolve_agent_tool_allowed_types() {
        let spec = format!("{}(worker, researcher)", AGENT_TOOL_NAME);
        let agent = make_agent_def(vec![spec.as_str()]);
        let available = vec!["Bash".to_string()];
        let resolved = resolve_agent_tools(&agent, &available, false);
        assert_eq!(resolved.valid_tools, vec![spec.clone()]);
        assert!(resolved.invalid_tools.is_empty());
        assert_eq!(
            resolved.allowed_agent_types,
            Some(vec!["worker".to_string(), "researcher".to_string()])
        );
    }

    #[test]
    fn test_parse_tool_spec() {
        assert_eq!(parse_tool_spec(" Read "), ("Read".to_string(), None));
        let (name, rule) = parse_tool_spec("Bash(git status)");
        assert_eq!(name, "Bash");
        assert!(rule.map_or(false, |r| r.contains("git status")));
    }

    #[test]
    fn test_extract_text_content() {
        let content = vec![
            serde_json::json!({"type": "text", "text": "hello"}),
            serde_json::json!({"type": "tool_use", "name": "Bash"}),
            serde_json::json!({"type": "text", "text": "world"}),
        ];
        assert_eq!(extract_text_content(&content, " "), "hello world");
    }

    #[test]
    fn test_count_tool_uses() {
        let messages = vec![serde_json::json!({
            "type": "assistant",
            "message": {
                "content": [
                    {"type": "tool_use", "id": "1", "name": "Bash"},
                    {"type": "tool_use", "id": "2", "name": "Read"},
                ]
            }
        })];
        assert_eq!(count_tool_uses(&messages), 2);
    }

    #[test]
    fn test_get_last_tool_use_name() {
        let msg = serde_json::json!({
            "type": "assistant",
            "message": {
                "content": [
                    {"type": "tool_use", "name": "Bash"},
                    {"type": "text", "text": "thinking"},
                    {"type": "tool_use", "name": "Read"},
                ]
            }
        });
        assert_eq!(get_last_tool_use_name(&msg), Some("Read".to_string()));
        let user = serde_json::json!({"type": "user", "message": {"content": []}});
        assert_eq!(get_last_tool_use_name(&user), None);
    }

    #[test]
    fn test_extract_partial_result_skips_tool_only_turns() {
        let messages = vec![
            serde_json::json!({
                "type": "assistant",
                "message": {"content": [{"type": "text", "text": "first"}]}
            }),
            serde_json::json!({
                "type": "user",
                "message": {"content": [{"type": "text", "text": "ignored"}]}
            }),
            serde_json::json!({
                "type": "assistant",
                "message": {"content": [{"type": "tool_use", "name": "Bash"}]}
            }),
        ];
        assert_eq!(extract_partial_result(&messages), Some("first".to_string()));
        assert_eq!(extract_partial_result(&[]), None);
    }

    #[test]
    fn test_finalize_agent_tool() {
        let messages = vec![serde_json::json!({
            "type": "assistant",
            "message": {
                "content": [
                    {"type": "text", "text": "done"},
                    {"type": "tool_use", "id": "1", "name": "Bash"},
                ],
                "usage": {"input_tokens": 10, "output_tokens": 5, "cache_read_input_tokens": 2}
            }
        })];
        let result = finalize_agent_tool(&messages, "id-1", "worker", 0).unwrap();
        assert_eq!(result.agent_id, "id-1");
        assert_eq!(result.agent_type.as_deref(), Some("worker"));
        assert_eq!(result.content, "done");
        assert_eq!(result.total_tool_use_count, 1);
        assert_eq!(result.total_tokens, 17);
        assert_eq!(result.usage.cache_creation_input_tokens, 0);
    }

    #[test]
    fn test_finalize_agent_tool_without_assistant_fails() {
        let messages = vec![serde_json::json!({"type": "user"})];
        assert!(finalize_agent_tool(&messages, "id", "worker", 0).is_err());
    }
}