1use std::path::Path;
8
9use super::{GrantTarget, PermissionLevel, PermissionRequest};
10
/// Coarse classification of tools by the kind of resource they touch.
///
/// Each category has a default [`PermissionLevel`] and a flag saying whether it needs a
/// permission check at all; both are provided by the inherent methods on this type.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ToolCategory {
    /// Reading files and listing or searching directories (`read_file`, `glob`, `grep`, `ls`).
    FileRead,
    /// Creating or modifying files (`write_file`, `edit_file`, `multi_edit`).
    FileWrite,
    /// Running shell commands (`bash`).
    CommandExec,
    /// Network access (`web_fetch`, `web_search`).
    Network,
    /// Asking the user questions (`ask_user_questions`).
    UserInteraction,
    /// Requesting permissions (`ask_for_permissions`).
    PermissionManagement,
}
27
28impl ToolCategory {
29 pub fn default_level(&self) -> PermissionLevel {
31 match self {
32 ToolCategory::FileRead => PermissionLevel::Read,
33 ToolCategory::FileWrite => PermissionLevel::Write,
34 ToolCategory::CommandExec => PermissionLevel::Execute,
35 ToolCategory::Network => PermissionLevel::Read,
36 ToolCategory::UserInteraction => PermissionLevel::None,
37 ToolCategory::PermissionManagement => PermissionLevel::None,
38 }
39 }
40
41 pub fn requires_permission(&self) -> bool {
43 match self {
44 ToolCategory::FileRead => true,
45 ToolCategory::FileWrite => true,
46 ToolCategory::CommandExec => true,
47 ToolCategory::Network => true,
48 ToolCategory::UserInteraction => false,
49 ToolCategory::PermissionManagement => false,
50 }
51 }
52}
53
54pub struct ToolPermissions;
56
57impl ToolPermissions {
58 pub fn file_read(tool_use_id: &str, path: impl AsRef<Path>) -> PermissionRequest {
60 let path = path.as_ref();
61 PermissionRequest::new(
62 tool_use_id,
63 GrantTarget::path(path, false),
64 PermissionLevel::Read,
65 format!("Read file: {}", path.display()),
66 )
67 .with_tool("read_file")
68 }
69
70 pub fn file_write(tool_use_id: &str, path: impl AsRef<Path>, is_create: bool) -> PermissionRequest {
72 let path = path.as_ref();
73 let action = if is_create { "Create" } else { "Write" };
74 PermissionRequest::new(
75 tool_use_id,
76 GrantTarget::path(path, false),
77 PermissionLevel::Write,
78 format!("{} file: {}", action, path.display()),
79 )
80 .with_tool("write_file")
81 }
82
83 pub fn file_edit(tool_use_id: &str, path: impl AsRef<Path>) -> PermissionRequest {
85 let path = path.as_ref();
86 PermissionRequest::new(
87 tool_use_id,
88 GrantTarget::path(path, false),
89 PermissionLevel::Write,
90 format!("Edit file: {}", path.display()),
91 )
92 .with_tool("edit_file")
93 }
94
95 pub fn multi_edit(tool_use_id: &str, paths: &[impl AsRef<Path>]) -> Vec<PermissionRequest> {
97 paths
98 .iter()
99 .enumerate()
100 .map(|(i, path)| {
101 let path = path.as_ref();
102 PermissionRequest::new(
103 format!("{}-{}", tool_use_id, i),
104 GrantTarget::path(path, false),
105 PermissionLevel::Write,
106 format!("Edit file: {}", path.display()),
107 )
108 .with_tool("multi_edit")
109 })
110 .collect()
111 }
112
113 pub fn glob_search(tool_use_id: &str, directory: impl AsRef<Path>) -> PermissionRequest {
115 let directory = directory.as_ref();
116 PermissionRequest::new(
117 tool_use_id,
118 GrantTarget::path(directory, true),
119 PermissionLevel::Read,
120 format!("Search in: {}", directory.display()),
121 )
122 .with_tool("glob")
123 }
124
125 pub fn grep_search(tool_use_id: &str, directory: impl AsRef<Path>) -> PermissionRequest {
127 let directory = directory.as_ref();
128 PermissionRequest::new(
129 tool_use_id,
130 GrantTarget::path(directory, true),
131 PermissionLevel::Read,
132 format!("Search content in: {}", directory.display()),
133 )
134 .with_tool("grep")
135 }
136
137 pub fn list_directory(tool_use_id: &str, directory: impl AsRef<Path>) -> PermissionRequest {
139 let directory = directory.as_ref();
140 PermissionRequest::new(
141 tool_use_id,
142 GrantTarget::path(directory, false),
143 PermissionLevel::Read,
144 format!("List directory: {}", directory.display()),
145 )
146 .with_tool("ls")
147 }
148
149 pub fn bash_command(tool_use_id: &str, command: &str) -> PermissionRequest {
156 let level = classify_bash_command(command);
157 PermissionRequest::new(
158 tool_use_id,
159 GrantTarget::command(command),
160 level,
161 format!("Execute: {}", truncate_command(command, 60)),
162 )
163 .with_tool("bash")
164 }
165
166 pub fn network_access(tool_use_id: &str, domain: &str, method: &str) -> PermissionRequest {
168 let level = match method.to_uppercase().as_str() {
169 "GET" | "HEAD" | "OPTIONS" => PermissionLevel::Read,
170 "POST" | "PUT" | "PATCH" => PermissionLevel::Write,
171 "DELETE" => PermissionLevel::Execute,
172 _ => PermissionLevel::Execute,
173 };
174 PermissionRequest::new(
175 tool_use_id,
176 GrantTarget::domain(domain),
177 level,
178 format!("{} {}", method.to_uppercase(), domain),
179 )
180 .with_tool("web_fetch")
181 }
182
183 pub fn web_search(tool_use_id: &str, query: &str) -> PermissionRequest {
185 PermissionRequest::new(
186 tool_use_id,
187 GrantTarget::domain("*"),
188 PermissionLevel::Read,
189 format!("Web search: {}", truncate_command(query, 40)),
190 )
191 .with_tool("web_search")
192 }
193}
194
195fn classify_bash_command(command: &str) -> PermissionLevel {
197 let command_lower = command.to_lowercase();
198 let first_word = command_lower.split_whitespace().next().unwrap_or("");
199
200 let dangerous_patterns = [
202 "rm -rf",
203 "rm -fr",
204 "sudo",
205 "chmod -R",
206 "chown -R",
207 "mkfs",
208 "dd if=",
209 ":(){ :|:& };:",
210 "> /dev/",
211 "shutdown",
212 "reboot",
213 "init ",
214 "systemctl",
215 ];
216
217 for pattern in dangerous_patterns {
218 if command_lower.contains(pattern) {
219 return PermissionLevel::Admin;
220 }
221 }
222
223 if first_word == "rm" || command_lower.contains("--delete") {
225 return PermissionLevel::Admin;
226 }
227
228 let readonly_commands = [
230 "ls", "cat", "head", "tail", "less", "more", "pwd", "whoami", "echo", "printf", "date",
231 "which", "whereis", "file", "stat", "wc", "grep", "find", "locate", "tree", "df", "du",
232 "git status", "git log", "git diff", "git show", "git branch",
233 ];
234
235 for readonly in readonly_commands {
236 if command_lower.starts_with(readonly) {
237 return PermissionLevel::Read;
238 }
239 }
240
241 PermissionLevel::Execute
243}
244
/// Shortens `command` (or any text) for display in a permission prompt.
///
/// Input of at most `max_len` characters is returned unchanged. Longer input becomes its
/// first `max_len - 3` characters followed by `"..."`, so the result is exactly `max_len`
/// characters long. Lengths are counted in `char`s and the cut always falls on a character
/// boundary, so non-ASCII text such as a web-search query cannot cause a slicing panic.
/// When `max_len` is below 3 there is no room for any text, and the result is just `"..."`.
fn truncate_command(command: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    // `nth(max_len)` is `None` exactly when the input has at most `max_len` chars; this
    // avoids counting the whole string.
    if command.char_indices().nth(max_len).is_none() {
        return command.to_string();
    }

    // Saturating so tiny limits cannot underflow. The `keep`-th char exists because the
    // input is longer than `max_len`, so the fallback is never used in practice.
    let keep = max_len.saturating_sub(ELLIPSIS.len());
    let cut = command
        .char_indices()
        .nth(keep)
        .map_or(command.len(), |(offset, _)| offset);
    format!("{}{}", &command[..cut], ELLIPSIS)
}
253
254pub fn get_tool_category(tool_name: &str) -> ToolCategory {
256 match tool_name {
257 "read_file" | "glob" | "grep" | "ls" => ToolCategory::FileRead,
258 "write_file" | "edit_file" | "multi_edit" => ToolCategory::FileWrite,
259 "bash" => ToolCategory::CommandExec,
260 "web_search" | "web_fetch" => ToolCategory::Network,
261 "ask_user_questions" => ToolCategory::UserInteraction,
262 "ask_for_permissions" => ToolCategory::PermissionManagement,
263 _ => ToolCategory::FileRead, }
265}
266
#[cfg(test)]
mod tests {
    //! Unit tests for the request builders, the bash classifier and the category tables.
    //! Every assertion here reflects current, intended behavior of the module.

    use super::*;

    #[test]
    fn test_file_read_request() {
        let request = ToolPermissions::file_read("tool-1", "/project/src/main.rs");
        assert_eq!(request.required_level, PermissionLevel::Read);
        assert_eq!(request.tool_name, Some("read_file".to_string()));
    }

    #[test]
    fn test_file_write_request() {
        let create = ToolPermissions::file_write("tool-1", "/project/new_file.rs", true);
        assert_eq!(create.required_level, PermissionLevel::Write);
        assert!(create.description.contains("Create"));

        let overwrite = ToolPermissions::file_write("tool-1", "/project/new_file.rs", false);
        assert_eq!(overwrite.required_level, PermissionLevel::Write);
        assert!(overwrite.description.starts_with("Write file:"));
        assert_eq!(overwrite.tool_name, Some("write_file".to_string()));
    }

    #[test]
    fn test_file_edit_request() {
        let request = ToolPermissions::file_edit("tool-1", "/project/src/lib.rs");
        assert_eq!(request.required_level, PermissionLevel::Write);
        assert!(request.description.starts_with("Edit file:"));
        assert_eq!(request.tool_name, Some("edit_file".to_string()));
    }

    #[test]
    fn test_directory_requests_need_read() {
        let cases = [
            (ToolPermissions::glob_search("tool-1", "/project"), "glob"),
            (ToolPermissions::grep_search("tool-1", "/project"), "grep"),
            (ToolPermissions::list_directory("tool-1", "/project"), "ls"),
        ];
        for (request, tool) in cases {
            assert_eq!(request.required_level, PermissionLevel::Read);
            assert_eq!(request.tool_name, Some(tool.to_string()));
        }
    }

    #[test]
    fn test_bash_command_readonly() {
        for command in ["ls -la", "git status", "cat README.md"] {
            let request = ToolPermissions::bash_command("tool-1", command);
            assert_eq!(request.required_level, PermissionLevel::Read, "{}", command);
        }
    }

    #[test]
    fn test_bash_command_execute() {
        let request = ToolPermissions::bash_command("tool-1", "cargo build");
        assert_eq!(request.required_level, PermissionLevel::Execute);
        assert_eq!(request.tool_name, Some("bash".to_string()));
    }

    #[test]
    fn test_bash_command_admin() {
        let request = ToolPermissions::bash_command("tool-1", "sudo apt install foo");
        assert_eq!(request.required_level, PermissionLevel::Admin);

        let request2 = ToolPermissions::bash_command("tool-1", "rm -rf /tmp/foo");
        assert_eq!(request2.required_level, PermissionLevel::Admin);
    }

    #[test]
    fn test_bash_command_description_is_truncated() {
        // 100 chars are cut to 57 chars plus "...", i.e. 60 chars after the "Execute: " prefix.
        let request = ToolPermissions::bash_command("tool-1", &"a".repeat(100));
        assert_eq!(request.description.len(), "Execute: ".len() + 60);
        assert!(request.description.ends_with("..."));
    }

    #[test]
    fn test_network_access() {
        let cases = [
            ("GET", PermissionLevel::Read),
            ("get", PermissionLevel::Read),
            ("HEAD", PermissionLevel::Read),
            ("POST", PermissionLevel::Write),
            ("PATCH", PermissionLevel::Write),
            ("DELETE", PermissionLevel::Execute),
            ("CONNECT", PermissionLevel::Execute),
        ];
        for (method, level) in cases {
            let request = ToolPermissions::network_access("tool-1", "api.github.com", method);
            assert_eq!(request.required_level, level, "{}", method);
            assert_eq!(request.tool_name, Some("web_fetch".to_string()));
        }

        let lowercase = ToolPermissions::network_access("tool-1", "api.github.com", "get");
        assert!(lowercase.description.starts_with("GET api.github.com"));
    }

    #[test]
    fn test_web_search() {
        let request = ToolPermissions::web_search("tool-1", "rust ownership rules");
        assert_eq!(request.required_level, PermissionLevel::Read);
        assert_eq!(request.tool_name, Some("web_search".to_string()));
        assert!(request.description.starts_with("Web search: rust ownership rules"));
    }

    #[test]
    fn test_multi_edit() {
        let paths = vec!["/file1.rs", "/file2.rs"];
        let requests = ToolPermissions::multi_edit("tool-1", &paths);
        assert_eq!(requests.len(), 2);
        assert_eq!(requests[0].id, "tool-1-0");
        assert_eq!(requests[1].id, "tool-1-1");
        assert!(requests
            .iter()
            .all(|request| request.required_level == PermissionLevel::Write));

        let no_paths: Vec<&str> = Vec::new();
        assert!(ToolPermissions::multi_edit("tool-1", &no_paths).is_empty());
    }

    #[test]
    fn test_tool_category() {
        let cases = [
            ("read_file", ToolCategory::FileRead),
            ("glob", ToolCategory::FileRead),
            ("grep", ToolCategory::FileRead),
            ("ls", ToolCategory::FileRead),
            ("write_file", ToolCategory::FileWrite),
            ("edit_file", ToolCategory::FileWrite),
            ("multi_edit", ToolCategory::FileWrite),
            ("bash", ToolCategory::CommandExec),
            ("web_search", ToolCategory::Network),
            ("web_fetch", ToolCategory::Network),
            ("ask_user_questions", ToolCategory::UserInteraction),
            ("ask_for_permissions", ToolCategory::PermissionManagement),
        ];
        for (name, category) in cases {
            assert_eq!(get_tool_category(name), category, "{}", name);
        }
    }

    #[test]
    fn test_category_requires_permission() {
        assert!(ToolCategory::FileRead.requires_permission());
        assert!(ToolCategory::FileWrite.requires_permission());
        assert!(ToolCategory::CommandExec.requires_permission());
        assert!(ToolCategory::Network.requires_permission());
        assert!(!ToolCategory::UserInteraction.requires_permission());
        assert!(!ToolCategory::PermissionManagement.requires_permission());
    }

    #[test]
    fn test_category_default_level() {
        let cases = [
            (ToolCategory::FileRead, PermissionLevel::Read),
            (ToolCategory::FileWrite, PermissionLevel::Write),
            (ToolCategory::CommandExec, PermissionLevel::Execute),
            (ToolCategory::Network, PermissionLevel::Read),
            (ToolCategory::UserInteraction, PermissionLevel::None),
            (ToolCategory::PermissionManagement, PermissionLevel::None),
        ];
        for (category, level) in cases {
            assert_eq!(category.default_level(), level, "{:?}", category);
        }
    }
}