capo_agent/permissions/
extension.rs1use std::path::PathBuf;
2use std::sync::Arc;
3
4use async_trait::async_trait;
5use motosan_agent_loop::core::decision::ToolDecision;
6use motosan_agent_loop::core::ext_error::ExtError;
7use motosan_agent_loop::core::extension::Extension;
8use motosan_agent_loop::core::hook_ctx::HookCtx;
9use motosan_agent_loop::llm::ToolCallItem;
10use motosan_agent_tool::ToolResult;
11use serde_json::Value;
12use tokio::sync::{mpsc, oneshot};
13
14use super::policy::Policy;
15use super::session_cache::SessionCache;
16use super::Decision;
17use crate::events::UiEvent;
18
19pub enum PromptStrategy {
22 Prompt(mpsc::Sender<UiEvent>),
25 HeadlessDeny,
27}
28
29pub struct PermissionExtension {
30 policy: Arc<Policy>,
31 cache: Arc<SessionCache>,
32 project_root: PathBuf,
33 prompt: PromptStrategy,
34}
35
36impl PermissionExtension {
37 pub fn new(
38 policy: Arc<Policy>,
39 cache: Arc<SessionCache>,
40 project_root: PathBuf,
41 ui_tx: mpsc::Sender<UiEvent>,
42 ) -> Self {
43 Self {
44 policy,
45 cache,
46 project_root,
47 prompt: PromptStrategy::Prompt(ui_tx),
48 }
49 }
50
51 pub fn headless(policy: Arc<Policy>, cache: Arc<SessionCache>, project_root: PathBuf) -> Self {
52 Self {
53 policy,
54 cache,
55 project_root,
56 prompt: PromptStrategy::HeadlessDeny,
57 }
58 }
59
60 async fn decide(&self, tool_name: &str, args: &Value) -> Decision {
61 if matches!(tool_name, "write" | "edit") {
64 if let Some(path) = args.get("path").and_then(|v| v.as_str()) {
65 let abs = if std::path::Path::new(path).is_absolute() {
66 PathBuf::from(path)
67 } else {
68 self.project_root.join(path)
69 };
70 let blocked = match tool_name {
71 "edit" => self.policy.edit_is_blocked(&abs, &self.project_root),
72 _ => self.policy.write_is_blocked(&abs, &self.project_root),
73 };
74 if blocked {
75 return Decision::Denied(format!("{} is in a blocked path", abs.display()));
76 }
77 }
78 }
79
80 let policy_allowed = match tool_name {
82 "bash" => args
83 .get("command")
84 .and_then(|v| v.as_str())
85 .map(|c| self.policy.bash_is_allowed(c))
86 .unwrap_or(false),
87 "write" | "edit" => args
88 .get("path")
89 .and_then(|v| v.as_str())
90 .map(|p| {
91 let abs = std::path::PathBuf::from(p);
92 let abs = if abs.is_absolute() {
93 abs
94 } else {
95 self.project_root.join(&abs)
96 };
97 match tool_name {
98 "edit" => self.policy.edit_is_allowed(&abs, &self.project_root),
99 _ => self.policy.write_is_allowed(&abs, &self.project_root),
100 }
101 })
102 .unwrap_or(false),
103 "read" | "grep" | "find" | "ls" => return Decision::Allowed,
104 other if other.contains("__") => {
105 let mut parts = other.splitn(2, "__");
106 let server = parts.next().unwrap_or("");
107 let tool = parts.next().unwrap_or("");
108 self.policy.mcp_auto_allow(server, tool)
109 }
110 _ => false,
111 };
112 if policy_allowed {
113 return Decision::Allowed;
114 }
115
116 let cache_key = SessionCache::key(tool_name, args);
122 if let Some(cached) = self.cache.get(&cache_key) {
123 return cached;
124 }
125
126 match &self.prompt {
128 PromptStrategy::HeadlessDeny => {
129 Decision::Denied("non-interactive: tool requires approval".into())
130 }
131 PromptStrategy::Prompt(ui_tx) => {
132 let (resolver_tx, resolver_rx) = oneshot::channel::<Decision>();
133 if ui_tx
134 .send(UiEvent::PermissionRequested {
135 tool: tool_name.to_string(),
136 args: args.clone(),
137 resolver: resolver_tx,
138 })
139 .await
140 .is_err()
141 {
142 return Decision::Denied("no UI channel to prompt".into());
143 }
144 resolver_rx
145 .await
146 .unwrap_or(Decision::Denied("prompt cancelled".into()))
147 }
148 }
149 }
150}
151
152#[async_trait]
153impl Extension for PermissionExtension {
154 fn name(&self) -> &'static str {
155 "capo-permissions"
156 }
157
158 async fn intercept_tool_call(
159 &mut self,
160 call: ToolCallItem,
161 _ctx: &mut HookCtx<'_>,
162 ) -> Result<ToolDecision, ExtError> {
163 match self.decide(&call.name, &call.args).await {
164 Decision::Allowed => Ok(ToolDecision::Proceed(call)),
165 Decision::Denied(reason) => Ok(ToolDecision::ShortCircuit(ToolResult::error(format!(
166 "Permission denied: {reason}"
167 )))),
168 }
169 }
170}
171
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use tokio::sync::mpsc;

    use super::*;

    /// Project root used by every test: the process cwd, or an empty path if
    /// the cwd cannot be read.
    fn test_root() -> std::path::PathBuf {
        std::env::current_dir().unwrap_or_default()
    }

    /// Builds an interactive extension over a default policy and `cache`.
    /// The UI receiver is returned (and must stay alive) so tests can assert
    /// that no prompt was sent.
    fn interactive(cache: &Arc<SessionCache>) -> (PermissionExtension, mpsc::Receiver<UiEvent>) {
        let (ui_tx, ui_rx) = mpsc::channel::<UiEvent>(4);
        let policy = Arc::new(Policy::default());
        let ext = PermissionExtension::new(policy, Arc::clone(cache), test_root(), ui_tx);
        (ext, ui_rx)
    }

    #[tokio::test]
    async fn session_cache_short_circuits_prompt() {
        let cache = Arc::new(SessionCache::new());
        let args = serde_json::json!({"command": "curl https://example.com"});
        cache.insert(SessionCache::key("bash", &args), Decision::Allowed);

        let (ext, mut ui_rx) = interactive(&cache);

        let decision = ext.decide("bash", &args).await;
        assert!(matches!(decision, Decision::Allowed));
        assert!(ui_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn grep_find_ls_are_auto_allowed() {
        let cache = Arc::new(SessionCache::new());
        let (ext, mut ui_rx) = interactive(&cache);

        let no_args = serde_json::json!({});
        for tool in ["grep", "find", "ls"] {
            let decision = ext.decide(tool, &no_args).await;
            assert!(
                matches!(decision, Decision::Allowed),
                "{tool} not auto-allowed"
            );
        }
        assert!(ui_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn headless_denies_a_would_prompt_tool_but_keeps_auto_allows() {
        let ext = PermissionExtension::headless(
            Arc::new(Policy::default()),
            Arc::new(SessionCache::new()),
            test_root(),
        );

        let curl = serde_json::json!({"command": "curl https://x"});
        let denied = ext.decide("bash", &curl).await;
        assert!(matches!(denied, Decision::Denied(_)));

        let allowed = ext.decide("read", &serde_json::json!({})).await;
        assert!(matches!(allowed, Decision::Allowed));
    }
}