1#![allow(dead_code)]
3use std::sync::Arc;
4
5use std::collections::{HashMap, HashSet};
6
7use super::constants::{
8 AGENT_TOOL_NAME, ALL_AGENT_DISALLOWED_TOOLS, ASYNC_AGENT_ALLOWED_TOOLS,
9 CUSTOM_AGENT_DISALLOWED_TOOLS, FORK_BOILERPLATE_TAG, FORK_DIRECTIVE_PREFIX,
10};
11use super::load_agents_dir::AgentDefinition;
12
/// Outcome of matching an agent definition's tool specs against the available tools.
#[derive(Debug, Clone)]
pub struct ResolvedAgentTools {
    /// True when the definition lists no tools or only `"*"`, i.e. every permitted tool is granted.
    pub has_wildcard: bool,
    /// The original tool specs (rule suffix included) that were accepted.
    pub valid_tools: Vec<String>,
    /// The original tool specs whose tool name was not among the permitted tools.
    pub invalid_tools: Vec<String>,
    /// Bare tool names the agent ends up with, without duplicates.
    pub resolved_tool_names: Vec<String>,
    /// Agent types taken from the rule content of an agent-tool spec; `None` when no
    /// such restriction was recorded.
    pub allowed_agent_types: Option<Vec<String>>,
}
22
23pub fn filter_tools_for_agent(
25 available_tools: &[String],
26 is_built_in: bool,
27 is_async: bool,
28) -> Vec<String> {
29 available_tools
30 .iter()
31 .filter(|tool| {
32 if tool.starts_with("mcp__") {
34 return true;
35 }
36 if ALL_AGENT_DISALLOWED_TOOLS.contains(&tool.as_str()) {
38 return false;
39 }
40 if !is_built_in && CUSTOM_AGENT_DISALLOWED_TOOLS.contains(&tool.as_str()) {
42 return false;
43 }
44 if is_async && !ASYNC_AGENT_ALLOWED_TOOLS.contains(&tool.as_str()) {
46 return false;
47 }
48 true
49 })
50 .cloned()
51 .collect()
52}
53
/// Splits a tool spec such as `Bash(git *)` into its tool name and optional rule part.
///
/// The name is everything before the first `(`, trimmed. When a `(` is present, the rule
/// part is the rest of the spec starting at (and including) that `(`, trimmed; otherwise
/// the rule part is `None`.
fn parse_tool_spec(spec: &str) -> (String, Option<String>) {
    match spec.find('(') {
        None => (spec.trim().to_string(), None),
        Some(open) => {
            let (name, rule) = spec.split_at(open);
            (name.trim().to_string(), Some(rule.trim().to_string()))
        }
    }
}
64
65pub fn resolve_agent_tools(
68 agent_definition: &AgentDefinition,
69 available_tools: &[String],
70 is_async: bool,
71) -> ResolvedAgentTools {
72 let filtered_available = filter_tools_for_agent(
74 available_tools,
75 agent_definition.source == "built-in",
76 is_async,
77 );
78
79 let disallowed_set: HashSet<&str> = agent_definition
81 .disallowed_tools
82 .iter()
83 .map(|s| s.as_str())
84 .collect();
85
86 let allowed_available: Vec<String> = filtered_available
88 .into_iter()
89 .filter(|t| !disallowed_set.contains(t.as_str()))
90 .collect();
91
92 let has_wildcard = agent_definition.tools.is_empty()
94 || agent_definition.tools == vec!["*"]
95 || (agent_definition.tools.len() == 1 && agent_definition.tools[0] == "*");
96
97 if has_wildcard {
98 return ResolvedAgentTools {
99 has_wildcard: true,
100 valid_tools: vec![],
101 invalid_tools: vec![],
102 resolved_tool_names: allowed_available,
103 allowed_agent_types: None,
104 };
105 }
106
107 let available_map: HashMap<&str, &String> =
108 allowed_available.iter().map(|t| (t.as_str(), t)).collect();
109
110 let mut valid_tools: Vec<String> = Vec::new();
111 let mut invalid_tools: Vec<String> = Vec::new();
112 let mut resolved: Vec<String> = Vec::new();
113 let mut resolved_set: HashSet<String> = HashSet::new();
114 let mut allowed_agent_types: Option<Vec<String>> = None;
115
116 for tool_spec in &agent_definition.tools {
117 let (tool_name, rule_content) = parse_tool_spec(tool_spec);
118
119 if tool_name == AGENT_TOOL_NAME {
121 if let Some(ref rules) = rule_content {
122 let types: Vec<String> = rules
124 .trim_matches(|c: char| c == '(' || c == ')')
125 .split(',')
126 .map(|s| s.trim().to_string())
127 .collect();
128 allowed_agent_types = Some(types);
129 }
130 valid_tools.push(tool_spec.clone());
131 continue;
132 }
133
134 if available_map.contains_key(tool_name.as_str()) {
135 valid_tools.push(tool_spec.clone());
136 if resolved_set.insert(tool_name.clone()) {
137 resolved.push(tool_name);
138 }
139 } else {
140 invalid_tools.push(tool_spec.clone());
141 }
142 }
143
144 ResolvedAgentTools {
145 has_wildcard: false,
146 valid_tools,
147 invalid_tools,
148 allowed_agent_types,
149 resolved_tool_names: resolved,
150 }
151}
152
153pub fn count_tool_uses(messages: &[serde_json::Value]) -> usize {
155 let mut count = 0;
156 for msg in messages {
157 if msg.get("type").and_then(|t| t.as_str()) == Some("assistant") {
158 if let Some(content) = msg.get("message").and_then(|m| m.get("content")) {
159 if let Some(arr) = content.as_array() {
160 for block in arr {
161 if block.get("type").and_then(|t| t.as_str()) == Some("tool_use") {
162 count += 1;
163 }
164 }
165 }
166 }
167 }
168 }
169 count
170}
171
172pub fn extract_text_content(content: &[serde_json::Value], separator: &str) -> String {
174 let texts: Vec<String> = content
175 .iter()
176 .filter(|block| block.get("type").and_then(|t| t.as_str()) == Some("text"))
177 .filter_map(|block| block.get("text").and_then(|t| t.as_str()))
178 .map(|t| t.to_string())
179 .collect();
180 texts.join(separator)
181}
182
183pub fn get_last_assistant_message(messages: &[serde_json::Value]) -> Option<&serde_json::Value> {
185 messages
186 .iter()
187 .rev()
188 .find(|msg| msg.get("type").and_then(|t| t.as_str()) == Some("assistant"))
189}
190
191pub fn extract_partial_result(messages: &[serde_json::Value]) -> Option<String> {
194 for msg in messages.iter().rev() {
195 if msg.get("type").and_then(|t| t.as_str()) != Some("assistant") {
196 continue;
197 }
198 if let Some(content) = msg.get("message").and_then(|m| m.get("content")) {
199 if let Some(arr) = content.as_array() {
200 let text = extract_text_content(arr, "\n");
201 if !text.is_empty() {
202 return Some(text);
203 }
204 }
205 }
206 }
207 None
208}
209
210pub fn extract_partial_result_from_engine(messages: &[crate::types::Message]) -> Option<String> {
214 for msg in messages.iter().rev() {
215 if msg.role != crate::types::MessageRole::Assistant {
216 continue;
217 }
218 if !msg.content.is_empty() {
219 return Some(msg.content.clone());
220 }
221 }
222 None
223}
224
225pub fn get_last_tool_use_name(message: &serde_json::Value) -> Option<String> {
227 if message.get("type").and_then(|t| t.as_str()) != Some("assistant") {
228 return None;
229 }
230 let content = message.get("message").and_then(|m| m.get("content"))?;
231 let arr = content.as_array()?;
232 for block in arr.iter().rev() {
233 if block.get("type").and_then(|t| t.as_str()) == Some("tool_use") {
234 return block
235 .get("name")
236 .and_then(|n| n.as_str())
237 .map(|s| s.to_string());
238 }
239 }
240 None
241}
242
/// Token counters for a model response; every counter defaults to zero.
#[derive(Debug, Clone, Default)]
pub struct TokenUsage {
    /// Value of the usage object's `input_tokens` field.
    pub input_tokens: usize,
    /// Value of the usage object's `output_tokens` field.
    pub output_tokens: usize,
    /// Value of the usage object's `cache_creation_input_tokens` field.
    pub cache_creation_input_tokens: usize,
    /// Value of the usage object's `cache_read_input_tokens` field.
    pub cache_read_input_tokens: usize,
}
251
/// Summary of a finished sub-agent run, as produced by `finalize_agent_tool`.
#[derive(Debug, Clone)]
pub struct AgentToolResult {
    /// Identifier of the agent run.
    pub agent_id: String,
    /// Agent type the run was started with.
    pub agent_type: Option<String>,
    /// Text output of the agent (text blocks joined with newlines).
    pub content: String,
    /// Number of `tool_use` blocks across all assistant messages of the run.
    pub total_tool_use_count: usize,
    /// Wall-clock milliseconds elapsed since the run's start timestamp.
    pub total_duration_ms: u64,
    /// Sum of the four counters in `usage`.
    pub total_tokens: usize,
    /// Token usage reported on the last assistant message.
    pub usage: TokenUsage,
}
263
264pub fn finalize_agent_tool(
266 messages: &[serde_json::Value],
267 agent_id: &str,
268 agent_type: &str,
269 start_time_ms: u64,
270) -> Result<AgentToolResult, String> {
271 let last_assistant = get_last_assistant_message(messages)
272 .ok_or_else(|| "No assistant messages found".to_string())?;
273
274 let content = last_assistant
276 .get("message")
277 .and_then(|m| m.get("content"))
278 .and_then(|c| c.as_array())
279 .map(|arr| extract_text_content(arr, "\n"))
280 .unwrap_or_default();
281
282 let total_tool_use_count = count_tool_uses(messages);
283
284 let usage = last_assistant
286 .get("message")
287 .and_then(|m| m.get("usage"))
288 .map(|u| TokenUsage {
289 input_tokens: u.get("input_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as usize,
290 output_tokens: u.get("output_tokens").and_then(|v| v.as_u64()).unwrap_or(0) as usize,
291 cache_creation_input_tokens: u
292 .get("cache_creation_input_tokens")
293 .and_then(|v| v.as_u64())
294 .unwrap_or(0) as usize,
295 cache_read_input_tokens: u
296 .get("cache_read_input_tokens")
297 .and_then(|v| v.as_u64())
298 .unwrap_or(0) as usize,
299 })
300 .unwrap_or_default();
301
302 let total_tokens = usage.input_tokens
303 + usage.output_tokens
304 + usage.cache_creation_input_tokens
305 + usage.cache_read_input_tokens;
306
307 Ok(AgentToolResult {
308 agent_id: agent_id.to_string(),
309 agent_type: Some(agent_type.to_string()),
310 content,
311 total_tool_use_count,
312 total_duration_ms: (std::time::SystemTime::now()
313 .duration_since(std::time::UNIX_EPOCH)
314 .unwrap_or_default()
315 .as_millis() as u64)
316 .saturating_sub(start_time_ms),
317 total_tokens,
318 usage,
319 })
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a built-in `AgentDefinition` carrying the given tool specs. Every other
    /// field is set to an empty or neutral value.
    fn agent_with_tools(tools: Vec<&str>) -> AgentDefinition {
        AgentDefinition {
            agent_type: "test".to_string(),
            when_to_use: "test".to_string(),
            tools: tools.iter().map(|spec| spec.to_string()).collect(),
            source: "built-in".to_string(),
            base_dir: "built-in".to_string(),
            get_system_prompt: Arc::new(|| String::new()),
            model: None,
            disallowed_tools: vec![],
            max_turns: None,
            permission_mode: None,
            effort: None,
            color: None,
            mcp_servers: vec![],
            hooks: None,
            skills: vec![],
            background: false,
            initial_prompt: None,
            memory: None,
            isolation: None,
            required_mcp_servers: vec![],
            omit_claude_md: false,
            critical_system_reminder_experimental: None,
        }
    }

    /// Owned tool names, in the shape `resolve_agent_tools` expects for `available_tools`.
    fn names(list: &[&str]) -> Vec<String> {
        list.iter().map(|name| name.to_string()).collect()
    }

    #[test]
    fn test_resolve_wildcard() {
        let definition = agent_with_tools(vec!["*"]);
        let outcome = resolve_agent_tools(&definition, &names(&["Bash", "Read"]), false);
        assert!(outcome.has_wildcard);
        assert_eq!(outcome.resolved_tool_names.len(), 2);
    }

    #[test]
    fn test_resolve_specific_tools() {
        let definition = agent_with_tools(vec!["Bash"]);
        let outcome = resolve_agent_tools(&definition, &names(&["Bash", "Read"]), false);
        assert!(!outcome.has_wildcard);
        assert_eq!(outcome.resolved_tool_names, vec!["Bash"]);
    }

    #[test]
    fn test_extract_text_content() {
        let blocks = [
            serde_json::json!({"type": "text", "text": "hello"}),
            serde_json::json!({"type": "tool_use", "name": "Bash"}),
            serde_json::json!({"type": "text", "text": "world"}),
        ];
        assert_eq!(extract_text_content(&blocks, " "), "hello world");
    }

    #[test]
    fn test_count_tool_uses() {
        let transcript = vec![serde_json::json!({
            "type": "assistant",
            "message": {
                "content": [
                    {"type": "tool_use", "id": "1", "name": "Bash"},
                    {"type": "tool_use", "id": "2", "name": "Read"},
                ]
            }
        })];
        assert_eq!(count_tool_uses(&transcript), 2);
    }
}