1use std::path::Path;
8
9use super::{GrantTarget, PermissionLevel, PermissionRequest};
10
/// Broad class a tool belongs to.
///
/// The category decides the tool's default permission level and whether using it
/// requires permission at all.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ToolCategory {
    /// Tools that read files or list/search directories (`read_file`, `glob`, `grep`, `ls`).
    FileRead,
    /// Tools that create or modify files (`write_file`, `edit_file`, `multi_edit`).
    FileWrite,
    /// Tools that run shell commands (`bash`).
    CommandExec,
    /// Tools that reach the network (`web_search`, `web_fetch`).
    Network,
    /// Tools that ask the user something (`ask_user_questions`).
    UserInteraction,
    /// Tools that request permissions themselves (`ask_for_permissions`).
    PermissionManagement,
}
27
28impl ToolCategory {
29 pub fn default_level(&self) -> PermissionLevel {
31 match self {
32 ToolCategory::FileRead => PermissionLevel::Read,
33 ToolCategory::FileWrite => PermissionLevel::Write,
34 ToolCategory::CommandExec => PermissionLevel::Execute,
35 ToolCategory::Network => PermissionLevel::Read,
36 ToolCategory::UserInteraction => PermissionLevel::None,
37 ToolCategory::PermissionManagement => PermissionLevel::None,
38 }
39 }
40
41 pub fn requires_permission(&self) -> bool {
43 match self {
44 ToolCategory::FileRead => true,
45 ToolCategory::FileWrite => true,
46 ToolCategory::CommandExec => true,
47 ToolCategory::Network => true,
48 ToolCategory::UserInteraction => false,
49 ToolCategory::PermissionManagement => false,
50 }
51 }
52}
53
54pub struct ToolPermissions;
56
57impl ToolPermissions {
58 pub fn file_read(tool_use_id: &str, path: impl AsRef<Path>) -> PermissionRequest {
60 let path = path.as_ref();
61 PermissionRequest::new(
62 tool_use_id,
63 GrantTarget::path(path, false),
64 PermissionLevel::Read,
65 format!("Read file: {}", path.display()),
66 )
67 .with_tool("read_file")
68 }
69
70 pub fn file_write(
72 tool_use_id: &str,
73 path: impl AsRef<Path>,
74 is_create: bool,
75 ) -> PermissionRequest {
76 let path = path.as_ref();
77 let action = if is_create { "Create" } else { "Write" };
78 PermissionRequest::new(
79 tool_use_id,
80 GrantTarget::path(path, false),
81 PermissionLevel::Write,
82 format!("{} file: {}", action, path.display()),
83 )
84 .with_tool("write_file")
85 }
86
87 pub fn file_edit(tool_use_id: &str, path: impl AsRef<Path>) -> PermissionRequest {
89 let path = path.as_ref();
90 PermissionRequest::new(
91 tool_use_id,
92 GrantTarget::path(path, false),
93 PermissionLevel::Write,
94 format!("Edit file: {}", path.display()),
95 )
96 .with_tool("edit_file")
97 }
98
99 pub fn multi_edit(tool_use_id: &str, paths: &[impl AsRef<Path>]) -> Vec<PermissionRequest> {
101 paths
102 .iter()
103 .enumerate()
104 .map(|(i, path)| {
105 let path = path.as_ref();
106 PermissionRequest::new(
107 format!("{}-{}", tool_use_id, i),
108 GrantTarget::path(path, false),
109 PermissionLevel::Write,
110 format!("Edit file: {}", path.display()),
111 )
112 .with_tool("multi_edit")
113 })
114 .collect()
115 }
116
117 pub fn glob_search(tool_use_id: &str, directory: impl AsRef<Path>) -> PermissionRequest {
119 let directory = directory.as_ref();
120 PermissionRequest::new(
121 tool_use_id,
122 GrantTarget::path(directory, true),
123 PermissionLevel::Read,
124 format!("Search in: {}", directory.display()),
125 )
126 .with_tool("glob")
127 }
128
129 pub fn grep_search(tool_use_id: &str, directory: impl AsRef<Path>) -> PermissionRequest {
131 let directory = directory.as_ref();
132 PermissionRequest::new(
133 tool_use_id,
134 GrantTarget::path(directory, true),
135 PermissionLevel::Read,
136 format!("Search content in: {}", directory.display()),
137 )
138 .with_tool("grep")
139 }
140
141 pub fn list_directory(tool_use_id: &str, directory: impl AsRef<Path>) -> PermissionRequest {
143 let directory = directory.as_ref();
144 PermissionRequest::new(
145 tool_use_id,
146 GrantTarget::path(directory, false),
147 PermissionLevel::Read,
148 format!("List directory: {}", directory.display()),
149 )
150 .with_tool("ls")
151 }
152
153 pub fn bash_command(tool_use_id: &str, command: &str) -> PermissionRequest {
160 let level = classify_bash_command(command);
161 PermissionRequest::new(
162 tool_use_id,
163 GrantTarget::command(command),
164 level,
165 format!("Execute: {}", truncate_command(command, 60)),
166 )
167 .with_tool("bash")
168 }
169
170 pub fn network_access(tool_use_id: &str, domain: &str, method: &str) -> PermissionRequest {
172 let level = match method.to_uppercase().as_str() {
173 "GET" | "HEAD" | "OPTIONS" => PermissionLevel::Read,
174 "POST" | "PUT" | "PATCH" => PermissionLevel::Write,
175 "DELETE" => PermissionLevel::Execute,
176 _ => PermissionLevel::Execute,
177 };
178 PermissionRequest::new(
179 tool_use_id,
180 GrantTarget::domain(domain),
181 level,
182 format!("{} {}", method.to_uppercase(), domain),
183 )
184 .with_tool("web_fetch")
185 }
186
187 pub fn web_search(tool_use_id: &str, query: &str) -> PermissionRequest {
189 PermissionRequest::new(
190 tool_use_id,
191 GrantTarget::domain("*"),
192 PermissionLevel::Read,
193 format!("Web search: {}", truncate_command(query, 40)),
194 )
195 .with_tool("web_search")
196 }
197}
198
199fn classify_bash_command(command: &str) -> PermissionLevel {
201 let command_lower = command.to_lowercase();
202 let first_word = command_lower.split_whitespace().next().unwrap_or("");
203
204 let dangerous_patterns = [
206 "rm -rf",
207 "rm -fr",
208 "sudo",
209 "chmod -R",
210 "chown -R",
211 "mkfs",
212 "dd if=",
213 ":(){ :|:& };:",
214 "> /dev/",
215 "shutdown",
216 "reboot",
217 "init ",
218 "systemctl",
219 ];
220
221 for pattern in dangerous_patterns {
222 if command_lower.contains(pattern) {
223 return PermissionLevel::Admin;
224 }
225 }
226
227 if first_word == "rm" || command_lower.contains("--delete") {
229 return PermissionLevel::Admin;
230 }
231
232 let readonly_commands = [
234 "ls",
235 "cat",
236 "head",
237 "tail",
238 "less",
239 "more",
240 "pwd",
241 "whoami",
242 "echo",
243 "printf",
244 "date",
245 "which",
246 "whereis",
247 "file",
248 "stat",
249 "wc",
250 "grep",
251 "find",
252 "locate",
253 "tree",
254 "df",
255 "du",
256 "git status",
257 "git log",
258 "git diff",
259 "git show",
260 "git branch",
261 ];
262
263 for readonly in readonly_commands {
264 if command_lower.starts_with(readonly) {
265 return PermissionLevel::Read;
266 }
267 }
268
269 PermissionLevel::Execute
271}
272
/// Shortens `command` for display.
///
/// Input of at most `max_len` characters is returned unchanged. Longer input keeps its
/// first `max_len - 3` characters followed by `"..."`, so the result is `max_len`
/// characters long. If `max_len < 3`, the result is just `"..."`.
///
/// Lengths are counted in `char`s and the cut always lands on a character boundary.
/// Byte slicing would panic on multi-byte UTF-8 input (e.g. a non-ASCII search query).
fn truncate_command(command: &str, max_len: usize) -> String {
    if command.chars().count() <= max_len {
        return command.to_string();
    }

    // `saturating_sub` avoids the underflow `max_len - 3` would hit for tiny limits.
    let keep = max_len.saturating_sub(3);
    // Byte offset of the first character that no longer fits.
    let cut = command
        .char_indices()
        .nth(keep)
        .map_or(command.len(), |(idx, _)| idx);
    format!("{}...", &command[..cut])
}
281
282pub fn get_tool_category(tool_name: &str) -> ToolCategory {
284 match tool_name {
285 "read_file" | "glob" | "grep" | "ls" => ToolCategory::FileRead,
286 "write_file" | "edit_file" | "multi_edit" => ToolCategory::FileWrite,
287 "bash" => ToolCategory::CommandExec,
288 "web_search" | "web_fetch" => ToolCategory::Network,
289 "ask_user_questions" => ToolCategory::UserInteraction,
290 "ask_for_permissions" => ToolCategory::PermissionManagement,
291 _ => ToolCategory::FileRead, }
293}
294
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_file_read_request() {
        let request = ToolPermissions::file_read("tool-1", "/project/src/main.rs");
        assert_eq!(request.required_level, PermissionLevel::Read);
        assert_eq!(request.tool_name, Some("read_file".to_string()));
    }

    #[test]
    fn test_file_write_request() {
        let request = ToolPermissions::file_write("tool-1", "/project/new_file.rs", true);
        assert_eq!(request.required_level, PermissionLevel::Write);
        assert!(request.description.contains("Create"));
    }

    #[test]
    fn test_bash_command_readonly() {
        let request = ToolPermissions::bash_command("tool-1", "ls -la");
        assert_eq!(request.required_level, PermissionLevel::Read);

        let git = ToolPermissions::bash_command("tool-1", "git status");
        assert_eq!(git.required_level, PermissionLevel::Read);
    }

    #[test]
    fn test_bash_command_execute() {
        let request = ToolPermissions::bash_command("tool-1", "cargo build");
        assert_eq!(request.required_level, PermissionLevel::Execute);
    }

    #[test]
    fn test_bash_command_admin() {
        let request = ToolPermissions::bash_command("tool-1", "sudo apt install foo");
        assert_eq!(request.required_level, PermissionLevel::Admin);

        let request2 = ToolPermissions::bash_command("tool-1", "rm -rf /tmp/foo");
        assert_eq!(request2.required_level, PermissionLevel::Admin);
    }

    #[test]
    fn test_bash_command_description_is_truncated() {
        let short = ToolPermissions::bash_command("tool-1", "ls -la");
        assert_eq!(short.description, "Execute: ls -la");

        // Commands longer than 60 characters keep 57 of them plus "...".
        let long = "a".repeat(100);
        let request = ToolPermissions::bash_command("tool-1", &long);
        assert_eq!(request.description, format!("Execute: {}...", "a".repeat(57)));
    }

    #[test]
    fn test_network_access() {
        let get_request = ToolPermissions::network_access("tool-1", "api.github.com", "GET");
        assert_eq!(get_request.required_level, PermissionLevel::Read);

        let post_request = ToolPermissions::network_access("tool-1", "api.github.com", "POST");
        assert_eq!(post_request.required_level, PermissionLevel::Write);

        let delete_request = ToolPermissions::network_access("tool-1", "api.github.com", "DELETE");
        assert_eq!(delete_request.required_level, PermissionLevel::Execute);
    }

    #[test]
    fn test_network_access_method_handling() {
        // Methods are matched case-insensitively and echoed upper-cased.
        let lower = ToolPermissions::network_access("tool-1", "api.github.com", "get");
        assert_eq!(lower.required_level, PermissionLevel::Read);
        assert_eq!(lower.description, "GET api.github.com");
        assert_eq!(lower.tool_name, Some("web_fetch".to_string()));

        let patch = ToolPermissions::network_access("tool-1", "api.github.com", "patch");
        assert_eq!(patch.required_level, PermissionLevel::Write);

        // Unrecognised methods fall back to the most restrictive network level.
        let unknown = ToolPermissions::network_access("tool-1", "api.github.com", "CONNECT");
        assert_eq!(unknown.required_level, PermissionLevel::Execute);
    }

    #[test]
    fn test_search_and_list_requests() {
        let glob = ToolPermissions::glob_search("tool-1", "/project");
        assert_eq!(glob.required_level, PermissionLevel::Read);
        assert_eq!(glob.tool_name, Some("glob".to_string()));

        let grep = ToolPermissions::grep_search("tool-1", "/project");
        assert_eq!(grep.required_level, PermissionLevel::Read);
        assert_eq!(grep.tool_name, Some("grep".to_string()));

        let ls = ToolPermissions::list_directory("tool-1", "/project");
        assert_eq!(ls.required_level, PermissionLevel::Read);
        assert_eq!(ls.tool_name, Some("ls".to_string()));

        let search = ToolPermissions::web_search("tool-1", "rust permissions");
        assert_eq!(search.required_level, PermissionLevel::Read);
        assert_eq!(search.tool_name, Some("web_search".to_string()));
    }

    #[test]
    fn test_multi_edit() {
        let paths = vec!["/file1.rs", "/file2.rs"];
        let requests = ToolPermissions::multi_edit("tool-1", &paths);
        assert_eq!(requests.len(), 2);
        assert_eq!(requests[0].id, "tool-1-0");
        assert_eq!(requests[1].id, "tool-1-1");
    }

    #[test]
    fn test_tool_category() {
        assert_eq!(get_tool_category("read_file"), ToolCategory::FileRead);
        assert_eq!(get_tool_category("write_file"), ToolCategory::FileWrite);
        assert_eq!(get_tool_category("bash"), ToolCategory::CommandExec);
        assert_eq!(get_tool_category("web_search"), ToolCategory::Network);
    }

    #[test]
    fn test_tool_category_full_mapping() {
        assert_eq!(get_tool_category("glob"), ToolCategory::FileRead);
        assert_eq!(get_tool_category("grep"), ToolCategory::FileRead);
        assert_eq!(get_tool_category("ls"), ToolCategory::FileRead);
        assert_eq!(get_tool_category("edit_file"), ToolCategory::FileWrite);
        assert_eq!(get_tool_category("multi_edit"), ToolCategory::FileWrite);
        assert_eq!(get_tool_category("web_fetch"), ToolCategory::Network);
        assert_eq!(get_tool_category("ask_user_questions"), ToolCategory::UserInteraction);
        assert_eq!(
            get_tool_category("ask_for_permissions"),
            ToolCategory::PermissionManagement
        );
        // Unknown tools fall back to FileRead.
        assert_eq!(get_tool_category("some_unknown_tool"), ToolCategory::FileRead);
    }

    #[test]
    fn test_category_default_levels() {
        assert_eq!(ToolCategory::FileRead.default_level(), PermissionLevel::Read);
        assert_eq!(ToolCategory::FileWrite.default_level(), PermissionLevel::Write);
        assert_eq!(ToolCategory::CommandExec.default_level(), PermissionLevel::Execute);
        assert_eq!(ToolCategory::Network.default_level(), PermissionLevel::Read);
        assert_eq!(ToolCategory::UserInteraction.default_level(), PermissionLevel::None);
        assert_eq!(
            ToolCategory::PermissionManagement.default_level(),
            PermissionLevel::None
        );
    }

    #[test]
    fn test_category_requires_permission() {
        assert!(ToolCategory::FileRead.requires_permission());
        assert!(ToolCategory::FileWrite.requires_permission());
        assert!(ToolCategory::CommandExec.requires_permission());
        assert!(ToolCategory::Network.requires_permission());
        assert!(!ToolCategory::UserInteraction.requires_permission());
        assert!(!ToolCategory::PermissionManagement.requires_permission());
    }
}