1use std::sync::LazyLock;
8
9use crate::types::KnowledgeBase;
10
11static DEFAULT_KB: LazyLock<KnowledgeBase> = LazyLock::new(|| {
12 toml::from_str(include_str!("../config/commands.toml"))
13 .expect("embedded commands.toml is invalid")
14});
15
16pub fn default_knowledge_base() -> &'static KnowledgeBase {
22 &DEFAULT_KB
23}
24
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lookup::classify;
    use crate::types::Effect;
    use agent_shell_parser::parse::types::Word;

    /// Converts string literals into parser `Word`s, one per argument.
    fn words(args: &[&str]) -> Vec<Word> {
        args.iter().map(|s| Word::from(*s)).collect()
    }

    /// Fails the calling test if a table-driven check recorded any mismatch.
    ///
    /// Table tests collect every failing row before asserting. A single run then
    /// reports all regressions instead of stopping at the first bad entry.
    fn assert_no_failures(table: &str, failures: &[String]) {
        assert!(
            failures.is_empty(),
            "{table}: {} failing case(s):\n  {}",
            failures.len(),
            failures.join("\n  ")
        );
    }

    /// The embedded TOML must deserialize and define both commands and wrappers.
    #[test]
    fn embedded_toml_parses_successfully() {
        let kb = default_knowledge_base();
        assert!(
            !kb.commands.is_empty(),
            "knowledge base should have commands"
        );
        assert!(
            !kb.wrappers.is_empty(),
            "knowledge base should have wrappers"
        );
    }

    /// Serializing the knowledge base back to TOML must produce text that parses again.
    #[test]
    fn knowledge_base_round_trips_through_toml() {
        let kb = default_knowledge_base();
        let serialized = toml::to_string(kb).expect("KB should serialize to TOML");
        let _: KnowledgeBase =
            toml::from_str(&serialized).expect("re-parsed KB should deserialize");
    }

    /// Top-level effect of every command the knowledge base is expected to know.
    #[test]
    fn command_effects() {
        #[rustfmt::skip]
        let cases: &[(&str, Effect)] = &[
            // Read-only inspection commands.
            ("ls", Effect::ReadOnly),
            ("tree", Effect::ReadOnly),
            ("cat", Effect::ReadOnly),
            ("head", Effect::ReadOnly),
            ("tail", Effect::ReadOnly),
            ("grep", Effect::ReadOnly),
            ("find", Effect::ReadOnly),
            ("stat", Effect::ReadOnly),
            ("diff", Effect::ReadOnly),
            ("wc", Effect::ReadOnly),
            ("sort", Effect::ReadOnly),
            ("uniq", Effect::ReadOnly),
            ("echo", Effect::ReadOnly),
            ("printf", Effect::ReadOnly),
            ("date", Effect::ReadOnly),
            ("pwd", Effect::ReadOnly),
            ("which", Effect::ReadOnly),
            ("ps", Effect::ReadOnly),
            ("uname", Effect::ReadOnly),
            ("hostname", Effect::ReadOnly),
            ("id", Effect::ReadOnly),
            ("whoami", Effect::ReadOnly),
            ("df", Effect::ReadOnly),
            ("du", Effect::ReadOnly),
            ("free", Effect::ReadOnly),
            ("uptime", Effect::ReadOnly),
            ("printenv", Effect::ReadOnly),
            ("rg", Effect::ReadOnly),
            ("fd", Effect::ReadOnly),
            ("bat", Effect::ReadOnly),
            ("eza", Effect::ReadOnly),
            ("tokei", Effect::ReadOnly),
            ("hyperfine", Effect::ReadOnly),
            ("jq", Effect::ReadOnly),
            // Filesystem / network mutation.
            ("mkdir", Effect::Mutating),
            ("touch", Effect::Mutating),
            ("mv", Effect::Mutating),
            ("cp", Effect::Mutating),
            ("ln", Effect::Mutating),
            ("chmod", Effect::Mutating),
            ("chown", Effect::Mutating),
            ("tee", Effect::Mutating),
            ("curl", Effect::Mutating),
            ("wget", Effect::Mutating),
            // Destructive and system-level commands.
            ("rm", Effect::Mutating),
            ("rmdir", Effect::Mutating),
            ("shred", Effect::Mutating),
            ("dd", Effect::Mutating),
            ("mkfs", Effect::Mutating),
            ("fdisk", Effect::Mutating),
            ("parted", Effect::Mutating),
            ("shutdown", Effect::Mutating),
            ("reboot", Effect::Mutating),
            ("halt", Effect::Mutating),
            ("poweroff", Effect::Mutating),
            // Multiplexers: the effect is decided per subcommand (see `subcommand_effects`).
            ("git", Effect::Unknown),
            ("cargo", Effect::Unknown),
            ("gh", Effect::Unknown),
            ("kubectl", Effect::Unknown),
        ];

        let kb = default_knowledge_base();
        let mut failures = Vec::new();
        for (cmd, expected) in cases {
            match kb.commands.get(*cmd) {
                None => failures.push(format!("'{cmd}' should be in the KB")),
                Some(entry) if entry.effect != *expected => failures.push(format!(
                    "'{cmd}' effect: expected {expected:?}, got {:?}",
                    entry.effect
                )),
                Some(_) => {}
            }
        }
        assert_no_failures("command_effects", &failures);
    }

    /// Effect of each known subcommand of the multiplexer commands.
    #[test]
    fn subcommand_effects() {
        #[rustfmt::skip]
        let cases: &[(&str, &str, Effect)] = &[
            // git: read-only.
            ("git", "status", Effect::ReadOnly),
            ("git", "log", Effect::ReadOnly),
            ("git", "diff", Effect::ReadOnly),
            ("git", "show", Effect::ReadOnly),
            ("git", "branch", Effect::ReadOnly),
            ("git", "tag", Effect::ReadOnly),
            ("git", "remote", Effect::ReadOnly),
            ("git", "rev-parse", Effect::ReadOnly),
            ("git", "ls-files", Effect::ReadOnly),
            ("git", "ls-tree", Effect::ReadOnly),
            ("git", "shortlog", Effect::ReadOnly),
            ("git", "blame", Effect::ReadOnly),
            ("git", "describe", Effect::ReadOnly),
            // NOTE(review): a bare `git stash` means `git stash push`, which modifies
            // the working tree; confirm ReadOnly is the intended classification.
            ("git", "stash", Effect::ReadOnly),
            ("git", "cat-file", Effect::ReadOnly),
            ("git", "for-each-ref", Effect::ReadOnly),
            // git: mutating.
            ("git", "push", Effect::Mutating),
            ("git", "pull", Effect::Mutating),
            ("git", "fetch", Effect::Mutating),
            ("git", "commit", Effect::Mutating),
            ("git", "add", Effect::Mutating),
            ("git", "rebase", Effect::Mutating),
            ("git", "merge", Effect::Mutating),
            ("git", "checkout", Effect::Mutating),
            ("git", "switch", Effect::Mutating),
            ("git", "restore", Effect::Mutating),
            ("git", "init", Effect::Mutating),
            ("git", "clone", Effect::Mutating),
            ("git", "config", Effect::Mutating),
            ("git", "cherry-pick", Effect::Mutating),
            ("git", "revert", Effect::Mutating),
            ("git", "am", Effect::Mutating),
            ("git", "apply", Effect::Mutating),
            ("git", "submodule", Effect::Mutating),
            // git: destructive / low-level.
            ("git", "reset", Effect::Mutating),
            ("git", "clean", Effect::Mutating),
            ("git", "rm", Effect::Mutating),
            ("git", "update-ref", Effect::Mutating),
            ("git", "update-index", Effect::Mutating),
            // cargo: read-only.
            // NOTE(review): several of these write to the working tree (`fmt` rewrites
            // sources, `update`/`generate-lockfile` rewrite Cargo.lock, `clean` deletes
            // target/) and `run` executes arbitrary code; confirm this is the policy.
            ("cargo", "build", Effect::ReadOnly),
            ("cargo", "check", Effect::ReadOnly),
            ("cargo", "test", Effect::ReadOnly),
            ("cargo", "bench", Effect::ReadOnly),
            ("cargo", "run", Effect::ReadOnly),
            ("cargo", "clippy", Effect::ReadOnly),
            ("cargo", "fmt", Effect::ReadOnly),
            ("cargo", "doc", Effect::ReadOnly),
            ("cargo", "clean", Effect::ReadOnly),
            ("cargo", "update", Effect::ReadOnly),
            ("cargo", "fetch", Effect::ReadOnly),
            ("cargo", "tree", Effect::ReadOnly),
            ("cargo", "metadata", Effect::ReadOnly),
            ("cargo", "version", Effect::ReadOnly),
            ("cargo", "verify-project", Effect::ReadOnly),
            ("cargo", "search", Effect::ReadOnly),
            ("cargo", "generate-lockfile", Effect::ReadOnly),
            ("cargo", "nextest", Effect::ReadOnly),
            ("cargo", "deny", Effect::ReadOnly),
            ("cargo", "audit", Effect::ReadOnly),
            ("cargo", "outdated", Effect::ReadOnly),
            ("cargo", "package", Effect::ReadOnly),
            ("cargo", "semver-checks", Effect::ReadOnly),
            ("cargo", "expand", Effect::ReadOnly),
            ("cargo", "insta", Effect::ReadOnly),
            // cargo: mutating.
            ("cargo", "install", Effect::Mutating),
            ("cargo", "uninstall", Effect::Mutating),
            ("cargo", "publish", Effect::Mutating),
            ("cargo", "add", Effect::Mutating),
            ("cargo", "remove", Effect::Mutating),
            ("cargo", "init", Effect::Mutating),
            ("cargo", "new", Effect::Mutating),
            // gh: read-only.
            ("gh", "status", Effect::ReadOnly),
            ("gh", "repo view", Effect::ReadOnly),
            ("gh", "repo list", Effect::ReadOnly),
            // NOTE(review): `git clone` is Mutating above but `gh repo clone` is
            // ReadOnly; confirm the asymmetry is intentional.
            ("gh", "repo clone", Effect::ReadOnly),
            ("gh", "pr list", Effect::ReadOnly),
            ("gh", "pr view", Effect::ReadOnly),
            ("gh", "pr diff", Effect::ReadOnly),
            ("gh", "pr checks", Effect::ReadOnly),
            ("gh", "pr status", Effect::ReadOnly),
            ("gh", "issue list", Effect::ReadOnly),
            ("gh", "issue view", Effect::ReadOnly),
            ("gh", "issue status", Effect::ReadOnly),
            ("gh", "run list", Effect::ReadOnly),
            ("gh", "run view", Effect::ReadOnly),
            ("gh", "run watch", Effect::ReadOnly),
            ("gh", "workflow list", Effect::ReadOnly),
            ("gh", "workflow view", Effect::ReadOnly),
            ("gh", "release list", Effect::ReadOnly),
            ("gh", "release view", Effect::ReadOnly),
            ("gh", "search", Effect::ReadOnly),
            ("gh", "browse", Effect::ReadOnly),
            // NOTE(review): `gh api` can send mutating requests (`-X`/`--method`);
            // confirm those are caught elsewhere (e.g. by flag handling).
            ("gh", "api", Effect::ReadOnly),
            ("gh", "auth status", Effect::ReadOnly),
            ("gh", "auth token", Effect::ReadOnly),
            ("gh", "extension list", Effect::ReadOnly),
            ("gh", "label list", Effect::ReadOnly),
            ("gh", "cache list", Effect::ReadOnly),
            ("gh", "variable list", Effect::ReadOnly),
            ("gh", "variable get", Effect::ReadOnly),
            ("gh", "secret list", Effect::ReadOnly),
            // gh: mutating.
            ("gh", "repo create", Effect::Mutating),
            ("gh", "repo edit", Effect::Mutating),
            ("gh", "repo fork", Effect::Mutating),
            ("gh", "repo rename", Effect::Mutating),
            ("gh", "repo archive", Effect::Mutating),
            ("gh", "pr create", Effect::Mutating),
            ("gh", "pr merge", Effect::Mutating),
            ("gh", "pr close", Effect::Mutating),
            ("gh", "pr reopen", Effect::Mutating),
            ("gh", "pr comment", Effect::Mutating),
            ("gh", "pr review", Effect::Mutating),
            ("gh", "pr edit", Effect::Mutating),
            ("gh", "issue create", Effect::Mutating),
            ("gh", "issue close", Effect::Mutating),
            ("gh", "issue reopen", Effect::Mutating),
            ("gh", "issue comment", Effect::Mutating),
            ("gh", "issue edit", Effect::Mutating),
            ("gh", "issue pin", Effect::Mutating),
            ("gh", "issue unpin", Effect::Mutating),
            ("gh", "run rerun", Effect::Mutating),
            ("gh", "run cancel", Effect::Mutating),
            ("gh", "run delete", Effect::Mutating),
            ("gh", "workflow enable", Effect::Mutating),
            ("gh", "workflow disable", Effect::Mutating),
            ("gh", "workflow run", Effect::Mutating),
            ("gh", "release create", Effect::Mutating),
            ("gh", "release edit", Effect::Mutating),
            ("gh", "auth login", Effect::Mutating),
            ("gh", "auth logout", Effect::Mutating),
            ("gh", "auth refresh", Effect::Mutating),
            ("gh", "extension install", Effect::Mutating),
            ("gh", "extension remove", Effect::Mutating),
            ("gh", "extension upgrade", Effect::Mutating),
            ("gh", "label create", Effect::Mutating),
            ("gh", "label edit", Effect::Mutating),
            ("gh", "variable set", Effect::Mutating),
            ("gh", "variable delete", Effect::Mutating),
            ("gh", "secret set", Effect::Mutating),
            ("gh", "secret delete", Effect::Mutating),
            ("gh", "config set", Effect::Mutating),
            // gh: destructive.
            ("gh", "repo delete", Effect::Mutating),
            ("gh", "issue delete", Effect::Mutating),
            ("gh", "issue transfer", Effect::Mutating),
            ("gh", "release delete", Effect::Mutating),
            ("gh", "label delete", Effect::Mutating),
            ("gh", "cache delete", Effect::Mutating),
            // kubectl: read-only.
            ("kubectl", "get", Effect::ReadOnly),
            ("kubectl", "describe", Effect::ReadOnly),
            ("kubectl", "logs", Effect::ReadOnly),
            ("kubectl", "top", Effect::ReadOnly),
            ("kubectl", "explain", Effect::ReadOnly),
            ("kubectl", "api-resources", Effect::ReadOnly),
            ("kubectl", "api-versions", Effect::ReadOnly),
            ("kubectl", "version", Effect::ReadOnly),
            ("kubectl", "cluster-info", Effect::ReadOnly),
            // kubectl: mutating.
            ("kubectl", "apply", Effect::Mutating),
            ("kubectl", "delete", Effect::Mutating),
            ("kubectl", "rollout", Effect::Mutating),
            ("kubectl", "scale", Effect::Mutating),
            ("kubectl", "autoscale", Effect::Mutating),
            ("kubectl", "patch", Effect::Mutating),
            ("kubectl", "replace", Effect::Mutating),
            ("kubectl", "create", Effect::Mutating),
            ("kubectl", "edit", Effect::Mutating),
            ("kubectl", "drain", Effect::Mutating),
            ("kubectl", "cordon", Effect::Mutating),
            ("kubectl", "uncordon", Effect::Mutating),
            ("kubectl", "taint", Effect::Mutating),
            ("kubectl", "exec", Effect::Mutating),
            ("kubectl", "run", Effect::Mutating),
            ("kubectl", "port-forward", Effect::Mutating),
            ("kubectl", "cp", Effect::Mutating),
        ];

        let kb = default_knowledge_base();
        let mut failures = Vec::new();
        for (cmd, subcmd, expected) in cases {
            let Some(command) = kb.commands.get(*cmd) else {
                failures.push(format!("'{cmd}' should be in the KB"));
                continue;
            };
            match command.subcommands.get(*subcmd) {
                None => failures.push(format!("'{cmd} {subcmd}' should be in the KB")),
                Some(entry) if entry.effect != *expected => failures.push(format!(
                    "'{cmd} {subcmd}' effect: expected {expected:?}, got {:?}",
                    entry.effect
                )),
                Some(_) => {}
            }
        }
        assert_no_failures("subcommand_effects", &failures);
    }

    /// End-to-end `classify` over argv vectors, checking both the effect and the
    /// subcommand that was matched (`None` for commands absent from the KB).
    #[test]
    fn classify_integration() {
        let cases: &[(&[&str], Effect, Option<&str>)] = &[
            (&["git", "status"], Effect::ReadOnly, Some("status")),
            (&["git", "log"], Effect::ReadOnly, Some("log")),
            (&["git", "diff"], Effect::ReadOnly, Some("diff")),
            (&["git", "push"], Effect::Mutating, Some("push")),
            (&["git", "reset"], Effect::Mutating, Some("reset")),
            (&["gh", "pr", "list"], Effect::ReadOnly, Some("pr list")),
            (&["gh", "pr", "create"], Effect::Mutating, Some("pr create")),
            (
                &["gh", "repo", "delete"],
                Effect::Mutating,
                Some("repo delete"),
            ),
            (&["cargo", "test"], Effect::ReadOnly, Some("test")),
            (&["cargo", "publish"], Effect::Mutating, Some("publish")),
            (&["kubectl", "get"], Effect::ReadOnly, Some("get")),
            (&["kubectl", "apply"], Effect::Mutating, Some("apply")),
            (&["frobnicate", "arg"], Effect::Unknown, None),
        ];

        let kb = default_knowledge_base();
        let mut failures = Vec::new();
        for (argv, expected_effect, expected_subcmd) in cases {
            let cmd_word = Word::from(argv[0]);
            let word_vec = words(argv);
            let info = classify(&cmd_word, &word_vec, kb);
            let label = argv.join(" ");
            if info.effect != *expected_effect {
                failures.push(format!(
                    "classify({label:?}) effect: expected {expected_effect:?}, got {:?}",
                    info.effect
                ));
            }
            if info.subcommand.as_deref() != *expected_subcmd {
                failures.push(format!(
                    "classify({label:?}) subcommand: expected {expected_subcmd:?}, got {:?}",
                    info.subcommand
                ));
            }
        }
        assert_no_failures("classify_integration", &failures);
    }

    /// `--force` on `git push` must be reported as an escalation flag.
    #[test]
    fn classify_git_push_force_has_escalation_flag() {
        let kb = default_knowledge_base();
        let info = classify(&Word::from("git"), &words(&["git", "push", "--force"]), kb);
        assert_eq!(info.effect, Effect::Mutating);
        assert!(
            info.has_escalation_flags,
            "git push --force should set has_escalation_flags"
        );
    }

    /// `sudo` must be recognized as a privilege-escalating wrapper.
    #[test]
    fn classify_sudo_wrapper() {
        let kb = default_knowledge_base();
        let info = classify(&Word::from("sudo"), &words(&["sudo", "rm", "-rf", "/"]), kb);
        let wrapper = info.wrapper.expect("sudo should return wrapper info");
        assert_eq!(wrapper.name, "sudo");
        assert!(wrapper.escalates_privilege);
    }

    /// Guard for the most dangerous commands, kept separate from `command_effects`
    /// so a change to them fails under a self-describing test name.
    #[test]
    fn deny_list_commands_are_mutating() {
        let kb = default_knowledge_base();
        for cmd in &["shred", "dd", "shutdown", "reboot"] {
            let entry = kb
                .commands
                .get(*cmd)
                .unwrap_or_else(|| panic!("{cmd} should be in the KB"));
            assert_eq!(entry.effect, Effect::Mutating, "{cmd} should be Mutating");
        }
    }

    /// git's flag schema: force-style flags escalate, and `-C` / `--git-dir` are
    /// listed in `skip_arg`.
    #[test]
    fn git_flag_schema() {
        let kb = default_knowledge_base();
        // `.get()` + descriptive panic, consistent with the other lookups (indexing
        // would panic with a generic message if `git` were missing).
        let git = kb
            .commands
            .get("git")
            .unwrap_or_else(|| panic!("'git' should be in the KB"));
        for flag in &["--force", "-f", "--force-with-lease"] {
            assert!(
                git.flags.escalation.contains(&flag.to_string()),
                "git missing escalation flag {flag}"
            );
        }
        for flag in &["-C", "--git-dir"] {
            assert!(
                git.flags.skip_arg.contains(&flag.to_string()),
                "git missing skip_arg flag {flag}"
            );
        }
    }

    /// Expected wrapper metadata:
    /// `(name, floor_effect, clears_env, escalates_privilege)`.
    const WRAPPER_FIELDS: &[(&str, Effect, bool, bool)] = &[
        // Privilege escalation.
        ("sudo", Effect::Mutating, false, true),
        ("su", Effect::Mutating, true, true),
        ("doas", Effect::Mutating, false, true),
        ("pkexec", Effect::Mutating, true, true),
        // Transparent wrappers: no escalation, no environment reset.
        ("env", Effect::ReadOnly, false, false),
        ("xargs", Effect::ReadOnly, false, false),
        ("nohup", Effect::ReadOnly, false, false),
        ("nice", Effect::ReadOnly, false, false),
        ("timeout", Effect::ReadOnly, false, false),
        ("time", Effect::ReadOnly, false, false),
        ("watch", Effect::ReadOnly, false, false),
        ("strace", Effect::ReadOnly, false, false),
        ("ltrace", Effect::ReadOnly, false, false),
        ("parallel", Effect::ReadOnly, false, false),
    ];

    /// Every wrapper in `WRAPPER_FIELDS` exists with the listed metadata.
    #[test]
    fn wrapper_fields() {
        let kb = default_knowledge_base();
        let mut failures = Vec::new();
        for (name, floor, clears, escalates) in WRAPPER_FIELDS {
            let Some(w) = kb.wrappers.get(*name) else {
                failures.push(format!("{name} should be in wrappers"));
                continue;
            };
            if w.floor_effect != *floor {
                failures.push(format!(
                    "{name} floor_effect: expected {floor:?}, got {:?}",
                    w.floor_effect
                ));
            }
            if w.clears_env != *clears {
                failures.push(format!(
                    "{name} clears_env: expected {clears}, got {}",
                    w.clears_env
                ));
            }
            if w.escalates_privilege != *escalates {
                failures.push(format!(
                    "{name} escalates_privilege: expected {escalates}, got {}",
                    w.escalates_privilege
                ));
            }
        }
        assert_no_failures("wrapper_fields", &failures);
    }
}