1use std::collections::HashMap;
12
13use serde_json::Value;
14
/// Coarse category of a tool call, inferred from the tool's name by
/// [`classify_tool_name`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolResultKind {
    /// Tools whose names look like file readers (`read_file`, `view_file`, `Read`, ...).
    FileRead,
    /// Tools whose names look like shell/terminal runners (`bash`, `run_command`, ...).
    Shell,
    /// Tools whose names look like search or listing tools (`grep`, `glob`, `list_dir`, `ls`, ...).
    Search,
    /// Any tool whose name matches none of the known patterns.
    Other,
}

/// Classifies a tool by name into a [`ToolResultKind`].
///
/// The name is compared after ASCII-lowercasing. Categories are tried in a
/// fixed order — file read, then search, then shell — and the first match
/// wins; names matching nothing are [`ToolResultKind::Other`].
///
/// Most patterns match as substrings, so namespaced names such as
/// `mcp__fs__readFile` are recognised. A few generic verbs (`read`, `view`,
/// `run`, `sh`, ...) only match when they are the entire name. The two-letter
/// `ls` pattern only matches as a whole token delimited by non-alphanumeric
/// characters (`ls`, `mcp__fs__ls`): as a plain substring it misclassified
/// names like `list_tools` or `get_pull_request_details` as searches.
pub fn classify_tool_name(name: &str) -> ToolResultKind {
    // Substring patterns for file-reading tools.
    const FILE_READ: &[&str] = &[
        "read_file",
        "readfile",
        "file_read",
        "fsread",
        "fs_read",
        "view_file",
        "viewfile",
        "open_file",
        "notebookread",
        "notebook_read",
        "cat_file",
        "get_file",
        "fetch_file",
        "ctx_read",
        "str_replace_editor",
    ];
    // Substring patterns for search / listing tools.
    const SEARCH: &[&str] = &[
        "grep",
        "ripgrep",
        "search",
        "find",
        "glob",
        "list_dir",
        "listdir",
        "list_files",
        "listfiles",
        "codebase_search",
        "ctx_search",
        "ctx_tree",
    ];
    // Search patterns too short to match safely as substrings; they must be
    // a whole token of the name.
    const SEARCH_TOKENS: &[&str] = &["ls"];
    // Substring patterns for shell / command-execution tools.
    const SHELL: &[&str] = &[
        "bash",
        "shell",
        "terminal",
        "run_command",
        "run_terminal",
        "runterminal",
        "execute_command",
        "exec_command",
        "command_exec",
        "ctx_shell",
    ];

    // True when `haystack` contains any of `patterns` as a substring.
    fn contains_any(haystack: &str, patterns: &[&str]) -> bool {
        patterns.iter().any(|p| haystack.contains(p))
    }

    // True when any token of `haystack` (split on non-alphanumeric
    // characters such as `_`, `-`, `.`) equals one of `tokens`.
    fn has_token(haystack: &str, tokens: &[&str]) -> bool {
        haystack
            .split(|c: char| !c.is_ascii_alphanumeric())
            .any(|part| tokens.iter().any(|t| *t == part))
    }

    let n = name.to_ascii_lowercase();

    if contains_any(&n, FILE_READ) || matches!(n.as_str(), "read" | "view" | "cat" | "open") {
        return ToolResultKind::FileRead;
    }
    if contains_any(&n, SEARCH) || has_token(&n, SEARCH_TOKENS) {
        return ToolResultKind::Search;
    }
    if contains_any(&n, SHELL)
        || matches!(n.as_str(), "run" | "exec" | "execute" | "command" | "sh")
    {
        return ToolResultKind::Shell;
    }

    ToolResultKind::Other
}
101
102pub fn anthropic_tool_names(messages: &[Value]) -> HashMap<String, String> {
106 let mut map = HashMap::new();
107 for msg in messages {
108 let Some(blocks) = msg.get("content").and_then(|c| c.as_array()) else {
109 continue;
110 };
111 for block in blocks {
112 if block.get("type").and_then(|t| t.as_str()) != Some("tool_use") {
113 continue;
114 }
115 if let (Some(id), Some(name)) = (
116 block.get("id").and_then(|v| v.as_str()),
117 block.get("name").and_then(|v| v.as_str()),
118 ) {
119 map.insert(id.to_string(), name.to_string());
120 }
121 }
122 }
123 map
124}
125
126pub fn openai_tool_names(messages: &[Value]) -> HashMap<String, String> {
129 let mut map = HashMap::new();
130 for msg in messages {
131 let Some(calls) = msg.get("tool_calls").and_then(|c| c.as_array()) else {
132 continue;
133 };
134 for call in calls {
135 let id = call.get("id").and_then(|v| v.as_str());
136 let name = call
137 .get("function")
138 .and_then(|f| f.get("name"))
139 .and_then(|v| v.as_str());
140 if let (Some(id), Some(name)) = (id, name) {
141 map.insert(id.to_string(), name.to_string());
142 }
143 }
144 }
145 map
146}
147
148pub fn responses_tool_names(input: &[Value]) -> HashMap<String, String> {
151 let mut map = HashMap::new();
152 for item in input {
153 if item.get("type").and_then(|t| t.as_str()) != Some("function_call") {
154 continue;
155 }
156 if let (Some(id), Some(name)) = (
157 item.get("call_id").and_then(|v| v.as_str()),
158 item.get("name").and_then(|v| v.as_str()),
159 ) {
160 map.insert(id.to_string(), name.to_string());
161 }
162 }
163 map
164}
165
166pub fn should_protect(kind: ToolResultKind, content: &str) -> bool {
173 match kind {
174 ToolResultKind::FileRead => true,
175 ToolResultKind::Other => looks_like_source_code(content),
176 ToolResultKind::Shell | ToolResultKind::Search => false,
177 }
178}
179
/// Heuristically decides whether `content` looks like source code.
///
/// Only the first 200 lines are examined and blank lines are skipped.
/// Returns `false` when:
/// - fewer than 5 non-blank lines were examined, or
/// - any examined line, after trimming, starts with a shell/REPL prompt
///   (`$ `, `% `, `>>> `), a diagnostic or log-level prefix (`warning:`,
///   `error:`, `error[`, `INFO `, `WARN `, `DEBUG `, `ERROR `), or a cargo
///   status prefix (`Compiling `, `Downloaded `, `test result:`).
///
/// Otherwise returns `true` when at least half of the examined lines carry a
/// code signal: either the line is indented and ends with `{`, `}`, `;`,
/// `=>`, `->` or `:`, or it contains a language keyword.
///
/// Keywords come in two groups. Tokens that are rarely ordinary English
/// (`fn `, `struct `, `impl `, `#include`, ...) count anywhere in the line,
/// which still catches forms such as `pub(crate) fn` or `unsafe impl`.
/// Keywords that double as common English words (`from `, `let `,
/// `return `, `class `, `import `, `function `, ...) only count at the start
/// of the trimmed line; matching them anywhere made plain prose such as
/// "data comes from the server" register as code.
pub fn looks_like_source_code(content: &str) -> bool {
    // Maximum number of lines inspected.
    const MAX_LINES: usize = 200;
    // Minimum number of non-blank lines required before anything counts as code.
    const MIN_LINES: usize = 5;
    // Line prefixes typical of command output; a single hit disqualifies the input.
    const SHELL_PREFIXES: &[&str] = &[
        "$ ",
        "% ",
        ">>> ",
        "warning:",
        "error:",
        "error[",
        "INFO ",
        "WARN ",
        "DEBUG ",
        "ERROR ",
        "Compiling ",
        "Downloaded ",
        "test result:",
    ];
    // Line endings that signal code when the line is also indented.
    const CODE_ENDINGS: &[&str] = &["{", "}", ";", "=>", "->", ":"];
    // Keywords unlikely to appear in ordinary prose: matched anywhere in the line.
    const ANYWHERE_KEYWORDS: &[&str] = &[
        "fn ", "def ", "func ", "pub ", "const ", "var ", "struct ", "enum ", "impl ", "#include",
    ];
    // Keywords that are also common English words: matched only at line start.
    const LEADING_KEYWORDS: &[&str] = &[
        "class ",
        "import ",
        "from ",
        "function ",
        "let ",
        "package ",
        "public ",
        "private ",
        "return ",
        "async ",
        "export ",
    ];

    let mut considered = 0usize;
    let mut code_signals = 0usize;

    for raw in content.lines().take(MAX_LINES) {
        let line = raw.trim_end();
        let trimmed = line.trim_start();
        if trimmed.is_empty() {
            continue;
        }
        considered += 1;

        // One command-output marker is enough to reject the whole input,
        // so there is no point scanning further.
        if SHELL_PREFIXES.iter().any(|p| trimmed.starts_with(*p)) {
            return false;
        }

        let is_indented = line.len() != trimmed.len();
        let has_code_ending = CODE_ENDINGS.iter().any(|e| trimmed.ends_with(*e));
        let has_keyword = ANYWHERE_KEYWORDS.iter().any(|k| trimmed.contains(*k))
            || LEADING_KEYWORDS.iter().any(|k| trimmed.starts_with(*k));

        if (is_indented && has_code_ending) || has_keyword {
            code_signals += 1;
        }
    }

    considered >= MIN_LINES && code_signals * 2 >= considered
}
262
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn classifies_file_read_tools() {
        let readers = ["Read", "read_file", "view_file", "ctx_read", "mcp__fs__readFile"];
        for name in readers {
            let kind = classify_tool_name(name);
            assert_eq!(kind, ToolResultKind::FileRead, "{name} should be FileRead");
        }
    }

    #[test]
    fn classifies_shell_and_search() {
        let expectations = [
            ("Bash", ToolResultKind::Shell),
            ("run_terminal_cmd", ToolResultKind::Shell),
            ("Grep", ToolResultKind::Search),
            ("codebase_search", ToolResultKind::Search),
        ];
        for (name, expected) in expectations {
            assert_eq!(classify_tool_name(name), expected);
        }
    }

    #[test]
    fn unknown_tool_is_other() {
        let kind = classify_tool_name("submit_pr");
        assert_eq!(kind, ToolResultKind::Other);
    }

    #[test]
    fn anthropic_names_resolve_from_tool_use() {
        let assistant = serde_json::json!({
            "role": "assistant",
            "content": [
                {"type": "text", "text": "reading"},
                {"type": "tool_use", "id": "toolu_1", "name": "Read", "input": {}}
            ]
        });
        let user = serde_json::json!({
            "role": "user",
            "content": [{"type": "tool_result", "tool_use_id": "toolu_1", "content": "x"}]
        });
        let names = anthropic_tool_names(&[assistant, user]);
        let resolved = names.get("toolu_1").map(String::as_str);
        assert_eq!(resolved, Some("Read"));
    }

    #[test]
    fn openai_names_resolve_from_tool_calls() {
        let assistant = serde_json::json!({
            "role": "assistant",
            "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "read_file"}}]
        });
        let names = openai_tool_names(&[assistant]);
        let resolved = names.get("call_1").map(String::as_str);
        assert_eq!(resolved, Some("read_file"));
    }

    #[test]
    fn responses_names_resolve_from_function_call() {
        let call = serde_json::json!({
            "type": "function_call", "call_id": "call_1", "name": "Read", "arguments": "{}"
        });
        let names = responses_tool_names(&[call]);
        let resolved = names.get("call_1").map(String::as_str);
        assert_eq!(resolved, Some("Read"));
    }

    #[test]
    fn source_code_detected() {
        let snippet = "pub fn build(cfg: &Config) -> Result<App> {\n let mut app = App::new();\n app.configure(cfg);\n for route in cfg.routes() {\n app.register(route);\n }\n Ok(app)\n}";
        let detected = looks_like_source_code(snippet);
        assert!(detected);
    }

    #[test]
    fn command_output_not_code() {
        let build_log = "$ cargo build\n Compiling foo v0.1.0\n Compiling bar v0.2.0\nwarning: unused variable\n Finished dev target\nerror: could not compile";
        let detected = looks_like_source_code(build_log);
        assert!(!detected);
    }

    #[test]
    fn plain_prose_not_code() {
        let paragraph = "This is a normal paragraph of text.\nIt has several sentences.\nNone of them are code.\nThey are just words on lines.\nMore words follow here.";
        let detected = looks_like_source_code(paragraph);
        assert!(!detected);
    }
}