use std::collections::HashMap;
use serde_json::Value;
/// Coarse category of a tool, derived from its name by [`classify_tool_name`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolResultKind {
    /// Name matches a file-reading keyword (`read_file`, `view_file`, `Read`, ...).
    FileRead,
    /// Name matches a shell/terminal keyword (`bash`, `shell`, `run_command`, ...).
    Shell,
    /// Name matches a search/listing keyword (`grep`, `search`, `list_dir`, `ls`, ...).
    Search,
    /// Name matched none of the known keywords.
    Other,
}

/// Splits `name` into lowercase ASCII word tokens.
///
/// Tokens break on any non-alphanumeric character and on lower-to-upper
/// camelCase boundaries, e.g. `mcp__fs__globFiles` -> `["mcp", "fs", "glob", "files"]`.
fn name_tokens(name: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current = String::new();
    let mut prev_lower_or_digit = false;
    for ch in name.chars() {
        if !ch.is_ascii_alphanumeric() {
            if !current.is_empty() {
                tokens.push(std::mem::take(&mut current));
            }
            prev_lower_or_digit = false;
            continue;
        }
        // `fooBar` -> `foo`, `bar`; the previous char was pushed, so `current` is non-empty.
        if ch.is_ascii_uppercase() && prev_lower_or_digit {
            tokens.push(std::mem::take(&mut current));
        }
        prev_lower_or_digit = ch.is_ascii_lowercase() || ch.is_ascii_digit();
        current.push(ch.to_ascii_lowercase());
    }
    if !current.is_empty() {
        tokens.push(current);
    }
    tokens
}

/// Classifies a tool by its name into a [`ToolResultKind`].
///
/// Matching is ASCII case-insensitive and checked in priority order:
/// file-read, then search, then shell; anything else is
/// [`ToolResultKind::Other`].
///
/// Most keywords are matched as substrings so that prefixed names such as
/// `mcp__fs__readFile` are recognised. A few bare names (`read`, `cat`,
/// `run`, `sh`, ...) must match the whole name exactly.
///
/// The short search keywords `ls` and `glob` are matched only as whole name
/// tokens. As substrings they would hit unrelated names such as `list_tools`,
/// `get_details` or `set_global_config`.
pub fn classify_tool_name(name: &str) -> ToolResultKind {
    const FILE_READ: &[&str] = &[
        "read_file",
        "readfile",
        "file_read",
        "fsread",
        "fs_read",
        "view_file",
        "viewfile",
        "open_file",
        "notebookread",
        "notebook_read",
        "cat_file",
        "get_file",
        "fetch_file",
        "ctx_read",
        "str_replace_editor",
    ];
    const SEARCH: &[&str] = &[
        "grep",
        "ripgrep",
        "search",
        "find",
        "list_dir",
        "listdir",
        "list_files",
        "listfiles",
        "codebase_search",
        "ctx_search",
        "ctx_tree",
    ];
    // Too short/ambiguous for substring matching; must be a whole token.
    const SEARCH_TOKENS: &[&str] = &["ls", "glob"];
    const SHELL: &[&str] = &[
        "bash",
        "shell",
        "terminal",
        "run_command",
        "run_terminal",
        "runterminal",
        "execute_command",
        "exec_command",
        "command_exec",
        "ctx_shell",
    ];

    let n = name.to_ascii_lowercase();
    let contains_any = |keys: &[&str]| keys.iter().any(|k| n.contains(k));

    if contains_any(FILE_READ) || matches!(n.as_str(), "read" | "view" | "cat" | "open") {
        return ToolResultKind::FileRead;
    }
    if contains_any(SEARCH)
        || name_tokens(name)
            .iter()
            .any(|t| SEARCH_TOKENS.contains(&t.as_str()))
    {
        return ToolResultKind::Search;
    }
    if contains_any(SHELL)
        || matches!(n.as_str(), "run" | "exec" | "execute" | "command" | "sh")
    {
        return ToolResultKind::Shell;
    }
    ToolResultKind::Other
}
pub fn anthropic_tool_names(messages: &[Value]) -> HashMap<String, String> {
let mut map = HashMap::new();
for msg in messages {
let Some(blocks) = msg.get("content").and_then(|c| c.as_array()) else {
continue;
};
for block in blocks {
if block.get("type").and_then(|t| t.as_str()) != Some("tool_use") {
continue;
}
if let (Some(id), Some(name)) = (
block.get("id").and_then(|v| v.as_str()),
block.get("name").and_then(|v| v.as_str()),
) {
map.insert(id.to_string(), name.to_string());
}
}
}
map
}
pub fn openai_tool_names(messages: &[Value]) -> HashMap<String, String> {
let mut map = HashMap::new();
for msg in messages {
let Some(calls) = msg.get("tool_calls").and_then(|c| c.as_array()) else {
continue;
};
for call in calls {
let id = call.get("id").and_then(|v| v.as_str());
let name = call
.get("function")
.and_then(|f| f.get("name"))
.and_then(|v| v.as_str());
if let (Some(id), Some(name)) = (id, name) {
map.insert(id.to_string(), name.to_string());
}
}
}
map
}
pub fn responses_tool_names(input: &[Value]) -> HashMap<String, String> {
let mut map = HashMap::new();
for item in input {
if item.get("type").and_then(|t| t.as_str()) != Some("function_call") {
continue;
}
if let (Some(id), Some(name)) = (
item.get("call_id").and_then(|v| v.as_str()),
item.get("name").and_then(|v| v.as_str()),
) {
map.insert(id.to_string(), name.to_string());
}
}
map
}
/// Returns `true` when a tool result of `kind` with body `content` should be protected.
///
/// Shell and search output is never protected and file reads always are. For
/// any other tool the decision falls back to [`looks_like_source_code`] on
/// `content`.
pub fn should_protect(kind: ToolResultKind, content: &str) -> bool {
    match kind {
        ToolResultKind::Shell | ToolResultKind::Search => false,
        ToolResultKind::Other => looks_like_source_code(content),
        ToolResultKind::FileRead => true,
    }
}
/// Heuristically decides whether `content` looks like source code rather than
/// command/log output or prose.
///
/// Only the first 200 lines are inspected and blank lines are ignored. A line
/// counts as a code signal in either of two cases:
/// - it is indented and ends in `{`, `}`, `;`, `=>`, `->` or `:`;
/// - it contains a common declaration/statement keyword such as `fn `,
///   `def ` or `return `.
///
/// A single shell signal vetoes detection outright. Shell signals are prompts,
/// compiler diagnostics, log-level prefixes and cargo status lines.
/// Terminals and compilers print prompts, diagnostics and log prefixes at
/// column 0, so those only count on unindented lines. Otherwise an indented
/// struct field like `error: Option<String>,`, an enum member like
/// `INFO = 20`, or a doctest `>>> ` line would wrongly veto real code. Cargo
/// right-aligns its `Compiling`/`Downloaded` status lines, so those count at
/// any indentation.
///
/// Returns `true` only when all of the following hold:
/// - at least 5 non-blank lines were inspected;
/// - no shell signal was seen;
/// - at least half of the inspected lines are code signals.
pub fn looks_like_source_code(content: &str) -> bool {
    // Prefixes that only count as command/log output on unindented lines.
    const COLUMN0_SHELL_PREFIXES: &[&str] = &[
        "$ ",
        "% ",
        ">>> ",
        "warning:",
        "error:",
        "error[",
        "INFO ",
        "WARN ",
        "DEBUG ",
        "ERROR ",
        "test result:",
    ];
    // Cargo status verbs, which cargo right-aligns (so they appear indented).
    const ANY_INDENT_SHELL_PREFIXES: &[&str] = &["Compiling ", "Downloaded "];
    // Line endings typical of code statements/blocks.
    const CODE_SUFFIXES: &[&str] = &["{", "}", ";", "=>", "->", ":"];
    // Substrings that commonly appear in declarations and statements.
    const CODE_KEYWORDS: &[&str] = &[
        "fn ",
        "def ",
        "class ",
        "import ",
        "from ",
        "function ",
        "func ",
        "pub ",
        "const ",
        "let ",
        "var ",
        "package ",
        "public ",
        "private ",
        "struct ",
        "enum ",
        "impl ",
        "#include",
        "return ",
        "async ",
        "export ",
    ];

    fn starts_with_any(s: &str, prefixes: &[&str]) -> bool {
        prefixes.iter().any(|p| s.starts_with(p))
    }

    let mut code_signals = 0usize;
    let mut considered = 0usize;
    for raw in content.lines().take(200) {
        let line = raw.trim_end();
        let trimmed = line.trim_start();
        if trimmed.is_empty() {
            continue;
        }
        considered += 1;
        let is_indented = line.len() != trimmed.len();

        // Any command/log output line disqualifies the whole content.
        if starts_with_any(trimmed, ANY_INDENT_SHELL_PREFIXES)
            || (!is_indented && starts_with_any(trimmed, COLUMN0_SHELL_PREFIXES))
        {
            return false;
        }

        let has_code_punct = CODE_SUFFIXES.iter().any(|s| trimmed.ends_with(s));
        let has_keyword = CODE_KEYWORDS.iter().any(|k| trimmed.contains(k));
        if (is_indented && has_code_punct) || has_keyword {
            code_signals += 1;
        }
    }
    considered >= 5 && code_signals * 2 >= considered
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn classifies_file_read_tools() {
        const NAMES: [&str; 5] = [
            "Read",
            "read_file",
            "view_file",
            "ctx_read",
            "mcp__fs__readFile",
        ];
        for name in NAMES {
            let kind = classify_tool_name(name);
            assert_eq!(kind, ToolResultKind::FileRead, "{name} should be FileRead");
        }
    }

    #[test]
    fn classifies_shell_and_search() {
        let expectations = [
            ("Bash", ToolResultKind::Shell),
            ("run_terminal_cmd", ToolResultKind::Shell),
            ("Grep", ToolResultKind::Search),
            ("codebase_search", ToolResultKind::Search),
        ];
        for (name, expected) in expectations {
            assert_eq!(classify_tool_name(name), expected);
        }
    }

    #[test]
    fn unknown_tool_is_other() {
        let kind = classify_tool_name("submit_pr");
        assert_eq!(kind, ToolResultKind::Other);
    }

    #[test]
    fn anthropic_names_resolve_from_tool_use() {
        let assistant = json!({
            "role": "assistant",
            "content": [
                {"type": "text", "text": "reading"},
                {"type": "tool_use", "id": "toolu_1", "name": "Read", "input": {}}
            ]
        });
        let user = json!({
            "role": "user",
            "content": [{"type": "tool_result", "tool_use_id": "toolu_1", "content": "x"}]
        });
        let resolved = anthropic_tool_names(&[assistant, user]);
        assert_eq!(resolved.get("toolu_1").map(String::as_str), Some("Read"));
    }

    #[test]
    fn openai_names_resolve_from_tool_calls() {
        let assistant = json!({
            "role": "assistant",
            "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "read_file"}}]
        });
        let resolved = openai_tool_names(&[assistant]);
        assert_eq!(resolved.get("call_1").map(String::as_str), Some("read_file"));
    }

    #[test]
    fn responses_names_resolve_from_function_call() {
        let item = json!({
            "type": "function_call", "call_id": "call_1", "name": "Read", "arguments": "{}"
        });
        let resolved = responses_tool_names(&[item]);
        assert_eq!(resolved.get("call_1").map(String::as_str), Some("Read"));
    }

    #[test]
    fn source_code_detected() {
        let snippet = "pub fn build(cfg: &Config) -> Result<App> {\n let mut app = App::new();\n app.configure(cfg);\n for route in cfg.routes() {\n app.register(route);\n }\n Ok(app)\n}";
        let detected = looks_like_source_code(snippet);
        assert!(detected);
    }

    #[test]
    fn command_output_not_code() {
        let transcript = "$ cargo build\n Compiling foo v0.1.0\n Compiling bar v0.2.0\nwarning: unused variable\n Finished dev target\nerror: could not compile";
        let detected = looks_like_source_code(transcript);
        assert!(!detected);
    }

    #[test]
    fn plain_prose_not_code() {
        let paragraph = "This is a normal paragraph of text.\nIt has several sentences.\nNone of them are code.\nThey are just words on lines.\nMore words follow here.";
        let detected = looks_like_source_code(paragraph);
        assert!(!detected);
    }
}