1use std::collections::HashMap;
10use std::path::PathBuf;
11
12use claude_agent_sdk_rs::{
13 ClaudeAgentOptions, HookContext, HookEvent, HookInput, HookJsonOutput, HookMatcher,
14 HookSpecificOutput, Hooks, PermissionMode, PreToolUseHookSpecificOutput, SyncHookJsonOutput,
15 SystemPrompt, SystemPromptPreset, Tools,
16};
17use regex::Regex;
18use tracing::debug;
19
/// Capability profile used to configure an agent session.
///
/// Each variant selects a different tool set (and hook configuration) when
/// converted with [`AgentProfile::to_options`].
#[non_exhaustive]
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq)]
pub enum AgentProfile {
    /// Read-only profile: receives only the `Read`, `Glob` and `Grep` tools.
    Planner,

    /// Read-write profile: additionally receives `Write` and `Bash`, with the
    /// safety hooks installed.
    Coder,
}
35
36impl AgentProfile {
37 pub fn to_options(
42 &self,
43 system_append: &str,
44 cwd: PathBuf,
45 max_turns: u32,
46 max_budget_usd: f64,
47 model: &str,
48 ) -> ClaudeAgentOptions {
49 let system_prompt = SystemPrompt::Preset(SystemPromptPreset::with_append(
50 "claude_code",
51 system_append,
52 ));
53
54 match self {
55 Self::Planner => ClaudeAgentOptions::builder()
56 .system_prompt(system_prompt)
57 .permission_mode(PermissionMode::BypassPermissions)
58 .cwd(cwd)
59 .max_turns(max_turns)
60 .max_budget_usd(max_budget_usd)
61 .model(model.to_string())
62 .tools(Tools::from(["Read", "Glob", "Grep"]))
63 .build(),
64
65 Self::Coder => ClaudeAgentOptions::builder()
66 .system_prompt(system_prompt)
67 .permission_mode(PermissionMode::BypassPermissions)
68 .cwd(cwd)
69 .max_turns(max_turns)
70 .max_budget_usd(max_budget_usd)
71 .model(model.to_string())
72 .tools(Tools::from(["Read", "Write", "Bash", "Glob", "Grep"]))
73 .hooks(build_safety_hooks())
74 .build(),
75 }
76 }
77}
78
/// Regex patterns for shell commands the Coder agent must never run.
///
/// Each entry is compiled once into [`DANGEROUS_REGEXES`]. Patterns are
/// unanchored and searched anywhere in the raw Bash `command` string, so they
/// still fire when the command is prefixed (e.g. `sudo ...`) or chained.
const DANGEROUS_PATTERNS: &[&str] = &[
    // `rm` with recursive + force flags in any common spelling (`-rf`, `-fr`,
    // `-Rf`, `-rfv`, ...) followed by an absolute path.
    r"rm\s+-[a-zA-Z]*(?:rf|fr|Rf|fR)[a-zA-Z]*\s+/",
    // Force push with the flag anywhere in the same command segment, e.g.
    // `git push --force`, `git push -f origin main`, `git push origin main --force`.
    // `[^\n;&|]*` keeps the match from spilling into a following chained command.
    r"git\s+push\b[^\n;&|]*\s(?:--force|-f\b)",
    // SQL keywords are case-insensitive, so `drop table` must be caught as well.
    r"(?i)DROP\s+TABLE",
    r"(?i)DROP\s+DATABASE",
    r"mkfs\.",
    r"dd\s+if=.+of=/dev/",
    r">\s*/dev/sda",
    r"chmod\s+-R\s+777\s+/",
    // Classic bash fork bomb: `:(){ :|:& };`
    r":\(\)\s*\{\s*:\|:\s*&\s*\}\s*;",
];
92
93pub fn build_safety_hooks() -> HashMap<HookEvent, Vec<HookMatcher>> {
99 let mut hooks = Hooks::new();
100
101 hooks.add_pre_tool_use_with_matcher("Bash", |input, _tool_use_id, _context| async move {
103 if let HookInput::PreToolUse(ref pre) = input {
104 let command = pre
105 .tool_input
106 .get("command")
107 .and_then(|v| v.as_str())
108 .unwrap_or("");
109
110 if is_dangerous_command(command) {
111 debug!(command = command, "Blocked dangerous Bash command");
112 return HookJsonOutput::Sync(SyncHookJsonOutput {
113 decision: Some("deny".to_string()),
114 reason: Some(format!("Command blocked by safety hook: {command}")),
115 hook_specific_output: Some(HookSpecificOutput::PreToolUse(
116 PreToolUseHookSpecificOutput {
117 permission_decision: Some("deny".to_string()),
118 permission_decision_reason: Some(
119 "Dangerous command detected by CODA safety hook".to_string(),
120 ),
121 updated_input: None,
122 },
123 )),
124 ..SyncHookJsonOutput::default()
125 });
126 }
127 }
128
129 HookJsonOutput::Sync(SyncHookJsonOutput::default())
131 });
132
133 hooks.add_post_tool_use(
135 |input: HookInput, _tool_use_id: Option<String>, _context: HookContext| async move {
136 if let HookInput::PostToolUse(ref post) = input {
137 debug!(
138 tool_name = post.tool_name.as_str(),
139 "Tool execution completed"
140 );
141 }
142 HookJsonOutput::Sync(SyncHookJsonOutput::default())
143 },
144 );
145
146 hooks.build()
147}
148
149static DANGEROUS_REGEXES: std::sync::LazyLock<Vec<Regex>> = std::sync::LazyLock::new(|| {
151 DANGEROUS_PATTERNS
152 .iter()
153 .filter_map(|p| Regex::new(p).ok())
154 .collect()
155});
156
157fn is_dangerous_command(command: &str) -> bool {
159 DANGEROUS_REGEXES.iter().any(|re| re.is_match(command))
160}
161
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every entry in `commands` is flagged as dangerous.
    fn assert_all_dangerous(commands: &[&str]) {
        for command in commands {
            assert!(is_dangerous_command(command));
        }
    }

    #[test]
    fn test_should_detect_dangerous_rm_rf() {
        assert_all_dangerous(&["rm -rf /", "sudo rm -rf / --no-preserve-root"]);
    }

    #[test]
    fn test_should_detect_dangerous_git_force_push() {
        assert_all_dangerous(&["git push --force", "git push -f origin main"]);
    }

    #[test]
    fn test_should_detect_dangerous_drop_table() {
        assert_all_dangerous(&["DROP TABLE users", "DROP DATABASE production"]);
    }

    #[test]
    fn test_should_allow_safe_commands() {
        let safe = ["cargo build", "git status", "ls -la", "echo hello"];
        for command in safe {
            assert!(!is_dangerous_command(command));
        }
    }

    #[test]
    fn test_should_build_safety_hooks() {
        let hooks = build_safety_hooks();
        for event in [HookEvent::PreToolUse, HookEvent::PostToolUse] {
            assert!(hooks.contains_key(&event));
        }

        // Exactly one PreToolUse matcher, scoped to the Bash tool.
        let pre = &hooks[&HookEvent::PreToolUse];
        assert_eq!(pre.len(), 1);
        assert_eq!(pre[0].matcher, Some("Bash".to_string()));

        // Exactly one PostToolUse matcher, applying to every tool.
        let post = &hooks[&HookEvent::PostToolUse];
        assert_eq!(post.len(), 1);
        assert!(post[0].matcher.is_none());
    }

    #[test]
    fn test_should_create_planner_options() {
        let options = AgentProfile::Planner.to_options(
            "Test append",
            PathBuf::from("/tmp"),
            10,
            5.0,
            "claude-opus-4-6",
        );

        assert_eq!(options.max_turns, Some(10));
        assert_eq!(options.max_budget_usd, Some(5.0));
        assert_eq!(options.model, Some("claude-opus-4-6".to_string()));
        assert_eq!(
            options.permission_mode,
            Some(PermissionMode::BypassPermissions)
        );
        assert!(options.hooks.is_none());

        let Some(Tools::List(tools)) = options.tools else {
            panic!("Expected Tools::List for Planner");
        };
        for allowed in ["Read", "Glob", "Grep"] {
            assert!(tools.contains(&allowed.to_string()));
        }
        for forbidden in ["Write", "Bash"] {
            assert!(!tools.contains(&forbidden.to_string()));
        }
    }

    #[test]
    fn test_should_create_coder_options() {
        let options = AgentProfile::Coder.to_options(
            "Test append",
            PathBuf::from("/tmp"),
            20,
            10.0,
            "claude-opus-4-6",
        );

        assert_eq!(options.max_turns, Some(20));
        assert_eq!(options.max_budget_usd, Some(10.0));
        assert_eq!(options.model, Some("claude-opus-4-6".to_string()));
        assert!(options.hooks.is_some());

        let Some(Tools::List(tools)) = options.tools else {
            panic!("Expected Tools::List for Coder");
        };
        for allowed in ["Read", "Write", "Bash", "Glob", "Grep"] {
            assert!(tools.contains(&allowed.to_string()));
        }
    }
}