1#![cfg_attr(test, allow(clippy::expect_used, clippy::unwrap_used))]
2
3use std::path::PathBuf;
4use std::sync::Arc;
5
6use async_trait::async_trait;
7use motosan_agent_loop::core::decision::ToolDecision;
8use motosan_agent_loop::core::ext_error::ExtError;
9use motosan_agent_loop::core::extension::Extension;
10use motosan_agent_loop::core::hook_ctx::HookCtx;
11use motosan_agent_loop::llm::ToolCallItem;
12use motosan_agent_tool::ToolResult;
13use serde_json::Value;
14use tokio::sync::{mpsc, oneshot, RwLock};
15
16use super::policy::Policy;
17use super::session_cache::SessionCache;
18use super::Decision;
19use crate::events::UiEvent;
20
/// How [`PermissionExtension`] resolves a tool call that is neither blocked
/// nor auto-allowed by the policy (and has no cached decision).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PromptStrategy {
    /// Ask the user through the UI channel and wait for their answer.
    Prompt,
    /// Non-interactive: deny any call that would otherwise need a prompt.
    HeadlessDeny,
    /// Allow every call, except `write`/`edit` targeting a blocked path.
    AllowAll,
    /// Allow `write`/`edit` without asking (blocked paths are still denied);
    /// every other tool behaves as under [`PromptStrategy::Prompt`].
    AcceptEdits,
}
39
/// Tool-call gate that enforces blocked paths, the configured [`Policy`],
/// a per-session decision cache, and the active [`PromptStrategy`].
pub struct PermissionExtension {
    /// Allow/block rules consulted for bash commands, file writes/edits and
    /// MCP (`server__tool`) calls.
    policy: Arc<Policy>,
    /// Previously recorded decisions, looked up by a key derived from the
    /// tool name and its arguments.
    cache: Arc<SessionCache>,
    /// Base directory used to resolve relative `path` arguments and to scope
    /// the built-in hard-block check.
    project_root: PathBuf,
    /// Channel for sending permission prompts to the UI; `None` when no UI is
    /// attached (prompting then results in a denial).
    ui_tx: Option<mpsc::Sender<UiEvent>>,
    /// Active strategy behind a lock so it can be switched at runtime; the
    /// `Arc` allows the same lock to be shared by several holders.
    strategy: Arc<RwLock<PromptStrategy>>,
}
54
55impl PermissionExtension {
56 pub fn new(
57 policy: Arc<Policy>,
58 cache: Arc<SessionCache>,
59 project_root: PathBuf,
60 ui_tx: mpsc::Sender<UiEvent>,
61 ) -> Self {
62 Self {
63 policy,
64 cache,
65 project_root,
66 ui_tx: Some(ui_tx),
67 strategy: Arc::new(RwLock::new(PromptStrategy::Prompt)),
68 }
69 }
70
71 pub fn headless(policy: Arc<Policy>, cache: Arc<SessionCache>, project_root: PathBuf) -> Self {
72 Self {
73 policy,
74 cache,
75 project_root,
76 ui_tx: None,
77 strategy: Arc::new(RwLock::new(PromptStrategy::HeadlessDeny)),
78 }
79 }
80
81 pub fn accept_edits(
83 policy: Arc<Policy>,
84 cache: Arc<SessionCache>,
85 project_root: PathBuf,
86 ui_tx: mpsc::Sender<UiEvent>,
87 ) -> Self {
88 Self {
89 policy,
90 cache,
91 project_root,
92 ui_tx: Some(ui_tx),
93 strategy: Arc::new(RwLock::new(PromptStrategy::AcceptEdits)),
94 }
95 }
96
97 pub fn allow_all(policy: Arc<Policy>, cache: Arc<SessionCache>, project_root: PathBuf) -> Self {
99 Self {
100 policy,
101 cache,
102 project_root,
103 ui_tx: None,
104 strategy: Arc::new(RwLock::new(PromptStrategy::AllowAll)),
105 }
106 }
107
108 pub async fn set_strategy(&self, strategy: PromptStrategy) {
111 *self.strategy.write().await = strategy;
112 }
113
114 pub async fn current_strategy(&self) -> PromptStrategy {
116 *self.strategy.read().await
117 }
118
119 pub(crate) fn with_strategy_handle(
122 policy: Arc<Policy>,
123 cache: Arc<SessionCache>,
124 project_root: PathBuf,
125 ui_tx: Option<mpsc::Sender<UiEvent>>,
126 strategy: Arc<RwLock<PromptStrategy>>,
127 ) -> Self {
128 Self {
129 policy,
130 cache,
131 project_root,
132 ui_tx,
133 strategy,
134 }
135 }
136
137 async fn decide(&self, tool_name: &str, args: &Value) -> Decision {
138 if matches!(tool_name, "write" | "edit") {
141 if let Some(path) = args.get("path").and_then(|v| v.as_str()) {
142 let abs = if std::path::Path::new(path).is_absolute() {
143 PathBuf::from(path)
144 } else {
145 self.project_root.join(path)
146 };
147 let blocked = path_is_builtin_hard_blocked(&abs, &self.project_root)
148 || match tool_name {
149 "edit" => self.policy.edit_is_blocked(&abs, &self.project_root),
150 _ => self.policy.write_is_blocked(&abs, &self.project_root),
151 };
152 if blocked {
153 return Decision::Denied(format!("{} is in a blocked path", abs.display()));
154 }
155 }
156 }
157
158 let strategy = *self.strategy.read().await;
160 match (strategy, tool_name) {
161 (PromptStrategy::AllowAll, _) => return Decision::Allowed,
162 (PromptStrategy::AcceptEdits, "write" | "edit") => return Decision::Allowed,
163 _ => {}
164 }
165
166 let policy_allowed = match tool_name {
168 "bash" => args
169 .get("command")
170 .and_then(|v| v.as_str())
171 .map(|c| self.policy.bash_is_allowed(c))
172 .unwrap_or(false),
173 "write" | "edit" => args
174 .get("path")
175 .and_then(|v| v.as_str())
176 .map(|p| {
177 let abs = std::path::PathBuf::from(p);
178 let abs = if abs.is_absolute() {
179 abs
180 } else {
181 self.project_root.join(&abs)
182 };
183 match tool_name {
184 "edit" => self.policy.edit_is_allowed(&abs, &self.project_root),
185 _ => self.policy.write_is_allowed(&abs, &self.project_root),
186 }
187 })
188 .unwrap_or(false),
189 "read" | "grep" | "find" | "ls" => return Decision::Allowed,
190 other if other.contains("__") => {
191 let mut parts = other.splitn(2, "__");
192 let server = parts.next().unwrap_or("");
193 let tool = parts.next().unwrap_or("");
194 self.policy.mcp_auto_allow(server, tool)
195 }
196 _ => false,
197 };
198 if policy_allowed {
199 return Decision::Allowed;
200 }
201
202 let cache_key = SessionCache::key(tool_name, args);
208 if let Some(cached) = self.cache.get(&cache_key) {
209 return cached;
210 }
211
212 let strategy = *self.strategy.read().await;
214 match strategy {
215 PromptStrategy::HeadlessDeny => {
216 Decision::Denied("non-interactive: tool requires approval".into())
217 }
218 PromptStrategy::Prompt | PromptStrategy::AcceptEdits => {
219 let Some(ui_tx) = &self.ui_tx else {
220 return Decision::Denied(
221 "no ui_tx attached to PermissionExtension; cannot prompt".into(),
222 );
223 };
224 let (resolver_tx, resolver_rx) = oneshot::channel::<Decision>();
225 if ui_tx
226 .send(UiEvent::PermissionRequested {
227 tool: tool_name.to_string(),
228 args: args.clone(),
229 resolver: resolver_tx,
230 })
231 .await
232 .is_err()
233 {
234 return Decision::Denied("no UI channel to prompt".into());
235 }
236 resolver_rx
237 .await
238 .unwrap_or(Decision::Denied("prompt cancelled".into()))
239 }
240 PromptStrategy::AllowAll => Decision::Allowed,
241 }
242 }
243}
244
/// Returns `true` when `path` points into a location that must never be
/// written, regardless of policy or prompt strategy.
///
/// When `path` lies under `project_root`, only the components below the root
/// are inspected, so the root's own ancestors never cause a match. A path is
/// blocked if any normal component is `.git`, `.ssh`, `node_modules` or
/// `target`, or starts with `.env`.
///
/// Matching is ASCII case-insensitive. On case-insensitive filesystems (common
/// defaults on macOS and Windows), `.Git/hooks/...` is the real `.git`
/// directory, so an exact-case comparison could be bypassed. Names that are
/// not valid UTF-8 are compared lossily, so they fail closed instead of
/// skipping the check.
fn path_is_builtin_hard_blocked(path: &std::path::Path, project_root: &std::path::Path) -> bool {
    const BLOCKED_NAMES: [&str; 4] = [".git", ".ssh", "node_modules", "target"];

    let rel = path.strip_prefix(project_root).unwrap_or(path);
    rel.components().any(|component| {
        let std::path::Component::Normal(name) = component else {
            return false;
        };
        let name = name.to_string_lossy().to_ascii_lowercase();
        BLOCKED_NAMES.contains(&name.as_str()) || name.starts_with(".env")
    })
}
261
262#[async_trait]
263impl Extension for PermissionExtension {
264 fn name(&self) -> &'static str {
265 "capo-permissions"
266 }
267
268 async fn intercept_tool_call(
269 &mut self,
270 call: ToolCallItem,
271 _ctx: &mut HookCtx<'_>,
272 ) -> Result<ToolDecision, ExtError> {
273 match self.decide(&call.name, &call.args).await {
274 Decision::Allowed => Ok(ToolDecision::Proceed(call)),
275 Decision::Denied(reason) => Ok(ToolDecision::ShortCircuit(ToolResult::error(format!(
276 "Permission denied: {reason}"
277 )))),
278 }
279 }
280}
281
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use tokio::sync::mpsc;

    use super::*;

    #[tokio::test]
    async fn session_cache_short_circuits_prompt() {
        let policy = Arc::new(Policy::default());
        let cache = Arc::new(SessionCache::new());
        let args = serde_json::json!({"command": "curl https://example.com"});
        cache.insert(SessionCache::key("bash", &args), Decision::Allowed);

        let (ui_tx, mut ui_rx) = mpsc::channel::<UiEvent>(4);
        let ext = PermissionExtension::new(
            Arc::clone(&policy),
            Arc::clone(&cache),
            std::env::current_dir().unwrap_or_default(),
            ui_tx,
        );

        let decision = ext.decide("bash", &args).await;
        assert!(matches!(decision, Decision::Allowed));
        assert!(ui_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn grep_find_ls_are_auto_allowed() {
        let policy = Arc::new(Policy::default());
        let cache = Arc::new(SessionCache::new());
        let (ui_tx, mut ui_rx) = mpsc::channel::<UiEvent>(4);
        let ext = PermissionExtension::new(
            Arc::clone(&policy),
            Arc::clone(&cache),
            std::env::current_dir().unwrap_or_default(),
            ui_tx,
        );

        for tool in ["grep", "find", "ls"] {
            let decision = ext.decide(tool, &serde_json::json!({})).await;
            assert!(
                matches!(decision, Decision::Allowed),
                "{tool} not auto-allowed"
            );
        }
        assert!(ui_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn accept_edits_mode_auto_allows_write_and_edit() {
        let dir = tempfile::tempdir().expect("tempdir");
        let policy = Arc::new(Policy::default());
        let cache = Arc::new(SessionCache::new());
        let (tx, _rx) = mpsc::channel::<UiEvent>(16);
        let ext = PermissionExtension::accept_edits(
            Arc::clone(&policy),
            Arc::clone(&cache),
            dir.path().to_path_buf(),
            tx,
        );

        let args = serde_json::json!({ "path": "src/foo.rs", "content": "..." });
        let decision = ext.decide("write", &args).await;
        assert!(
            matches!(decision, Decision::Allowed),
            "accept_edits should auto-allow write; got {decision:?}"
        );

        let args =
            serde_json::json!({ "path": "src/foo.rs", "old_string": "x", "new_string": "y" });
        let decision = ext.decide("edit", &args).await;
        assert!(matches!(decision, Decision::Allowed));
    }

    #[tokio::test]
    async fn accept_edits_mode_enforces_hard_blocked_paths_for_write() {
        let dir = tempfile::tempdir().expect("tempdir");
        let policy = Arc::new(Policy::default());
        let cache = Arc::new(SessionCache::new());
        let (tx, _rx) = mpsc::channel::<UiEvent>(16);
        let ext = PermissionExtension::accept_edits(
            Arc::clone(&policy),
            Arc::clone(&cache),
            dir.path().to_path_buf(),
            tx,
        );

        let args = serde_json::json!({ "path": ".git/config", "content": "x" });
        let decision = ext.decide("write", &args).await;
        assert!(
            matches!(decision, Decision::Denied(_)),
            "accept_edits must still enforce hard-blocked paths; got {decision:?}"
        );
    }

    #[tokio::test]
    async fn allow_all_mode_auto_allows_bash_too() {
        let dir = tempfile::tempdir().expect("tempdir");
        let policy = Arc::new(Policy::default());
        let cache = Arc::new(SessionCache::new());
        let ext = PermissionExtension::allow_all(
            Arc::clone(&policy),
            Arc::clone(&cache),
            dir.path().to_path_buf(),
        );

        let args = serde_json::json!({ "command": "ls -la" });
        let decision = ext.decide("bash", &args).await;
        assert!(matches!(decision, Decision::Allowed));
    }

    #[tokio::test]
    async fn allow_all_mode_still_enforces_hard_blocked_write() {
        let dir = tempfile::tempdir().expect("tempdir");
        let policy = Arc::new(Policy::default());
        let cache = Arc::new(SessionCache::new());
        let ext = PermissionExtension::allow_all(
            Arc::clone(&policy),
            Arc::clone(&cache),
            dir.path().to_path_buf(),
        );

        let args = serde_json::json!({ "path": ".env", "content": "x" });
        let decision = ext.decide("write", &args).await;
        assert!(
            matches!(decision, Decision::Denied(_)),
            "allow_all must still enforce hard-blocked .env*; got {decision:?}"
        );
    }

    #[tokio::test]
    async fn set_strategy_runtime_switches_mode() {
        let dir = tempfile::tempdir().expect("tempdir");
        let policy = Arc::new(Policy::default());
        let cache = Arc::new(SessionCache::new());

        let ext = PermissionExtension::headless(
            Arc::clone(&policy),
            Arc::clone(&cache),
            dir.path().to_path_buf(),
        );
        let args = serde_json::json!({ "command": "ls" });
        assert!(matches!(
            ext.decide("bash", &args).await,
            Decision::Denied(_)
        ));

        ext.set_strategy(PromptStrategy::AllowAll).await;
        assert!(matches!(ext.decide("bash", &args).await, Decision::Allowed));

        ext.set_strategy(PromptStrategy::HeadlessDeny).await;
        assert!(matches!(
            ext.decide("bash", &args).await,
            Decision::Denied(_)
        ));
    }

    #[tokio::test]
    async fn headless_denies_a_would_prompt_tool_but_keeps_auto_allows() {
        let policy = Arc::new(Policy::default());
        let cache = Arc::new(SessionCache::new());
        let ext = PermissionExtension::headless(
            Arc::clone(&policy),
            Arc::clone(&cache),
            std::env::current_dir().unwrap_or_default(),
        );
        let denied = ext
            .decide("bash", &serde_json::json!({"command": "curl https://x"}))
            .await;
        assert!(matches!(denied, Decision::Denied(_)));
        let allowed = ext.decide("read", &serde_json::json!({})).await;
        assert!(matches!(allowed, Decision::Allowed));
    }

    #[tokio::test]
    async fn prompt_mode_forwards_request_and_returns_ui_answer() {
        let policy = Arc::new(Policy::default());
        let cache = Arc::new(SessionCache::new());
        let (ui_tx, mut ui_rx) = mpsc::channel::<UiEvent>(4);
        let ext = PermissionExtension::new(
            Arc::clone(&policy),
            Arc::clone(&cache),
            std::env::current_dir().unwrap_or_default(),
            ui_tx,
        );
        let args = serde_json::json!({"command": "curl https://example.com"});

        // Plays the UI: answers the single prompt with `Allowed`.
        let answer_prompt = async {
            match ui_rx.recv().await {
                Some(UiEvent::PermissionRequested { tool, resolver, .. }) => {
                    assert_eq!(tool, "bash");
                    let _ = resolver.send(Decision::Allowed);
                }
                _ => panic!("expected a PermissionRequested event"),
            }
        };
        let (decision, ()) = tokio::join!(ext.decide("bash", &args), answer_prompt);
        assert!(matches!(decision, Decision::Allowed), "got {decision:?}");
    }

    #[tokio::test]
    async fn prompt_mode_denies_when_resolver_is_dropped() {
        let policy = Arc::new(Policy::default());
        let cache = Arc::new(SessionCache::new());
        let (ui_tx, mut ui_rx) = mpsc::channel::<UiEvent>(4);
        let ext = PermissionExtension::new(
            Arc::clone(&policy),
            Arc::clone(&cache),
            std::env::current_dir().unwrap_or_default(),
            ui_tx,
        );
        let args = serde_json::json!({"command": "curl https://example.com"});

        // Receives the prompt and drops it (and its resolver) unanswered.
        let drop_prompt = async {
            let event = ui_rx.recv().await;
            assert!(event.is_some(), "expected a prompt to be sent");
        };
        let (decision, ()) = tokio::join!(ext.decide("bash", &args), drop_prompt);
        assert!(matches!(decision, Decision::Denied(_)), "got {decision:?}");
    }

    #[tokio::test]
    async fn prompt_mode_without_ui_channel_denies() {
        let policy = Arc::new(Policy::default());
        let cache = Arc::new(SessionCache::new());
        let ext = PermissionExtension::with_strategy_handle(
            Arc::clone(&policy),
            Arc::clone(&cache),
            std::env::current_dir().unwrap_or_default(),
            None,
            Arc::new(RwLock::new(PromptStrategy::Prompt)),
        );
        let decision = ext
            .decide("bash", &serde_json::json!({"command": "curl https://x"}))
            .await;
        assert!(matches!(decision, Decision::Denied(_)), "got {decision:?}");
    }
}