1use std::sync::Arc;
16
17use parking_lot::RwLock;
18use tracing::debug;
19
20use crate::audit::{AuditEntry, AuditLogger, AuditResult, chrono_now};
21use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
22use crate::policy::{PolicyContext, PolicyDecision, PolicyEnforcer};
23use crate::registry::ToolDef;
24
25pub struct PolicyGateExecutor<T: ToolExecutor> {
30 inner: T,
31 enforcer: Arc<PolicyEnforcer>,
32 context: Arc<RwLock<PolicyContext>>,
33 audit: Option<Arc<AuditLogger>>,
34}
35
36impl<T: ToolExecutor + std::fmt::Debug> std::fmt::Debug for PolicyGateExecutor<T> {
37 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38 f.debug_struct("PolicyGateExecutor")
39 .field("inner", &self.inner)
40 .finish_non_exhaustive()
41 }
42}
43
44impl<T: ToolExecutor> PolicyGateExecutor<T> {
45 #[must_use]
47 pub fn new(
48 inner: T,
49 enforcer: Arc<PolicyEnforcer>,
50 context: Arc<RwLock<PolicyContext>>,
51 ) -> Self {
52 Self {
53 inner,
54 enforcer,
55 context,
56 audit: None,
57 }
58 }
59
60 #[must_use]
62 pub fn with_audit(mut self, audit: Arc<AuditLogger>) -> Self {
63 self.audit = Some(audit);
64 self
65 }
66
67 fn read_context(&self) -> PolicyContext {
68 self.context.read().clone()
69 }
70
71 pub fn update_context(&self, new_ctx: PolicyContext) {
73 *self.context.write() = new_ctx;
74 }
75
76 async fn check_policy(&self, call: &ToolCall) -> Result<(), ToolError> {
77 let ctx = self.read_context();
78 let decision = self.enforcer.evaluate(&call.tool_id, &call.params, &ctx);
79
80 match &decision {
81 PolicyDecision::Allow { trace } => {
82 debug!(tool = %call.tool_id, trace = %trace, "policy: allow");
83 if let Some(audit) = &self.audit {
84 let entry = AuditEntry {
85 timestamp: chrono_now(),
86 tool: call.tool_id.clone(),
87 command: truncate_params(&call.params),
88 result: AuditResult::Success,
89 duration_ms: 0,
90 error_category: None,
91 error_domain: None,
92 error_phase: None,
93 claim_source: None,
94 mcp_server_id: None,
95 injection_flagged: false,
96 embedding_anomalous: false,
97 cross_boundary_mcp_to_acp: false,
98 adversarial_policy_decision: None,
99 exit_code: None,
100 truncated: false,
101 caller_id: call.caller_id.clone(),
102 policy_match: Some(trace.clone()),
104 };
105 audit.log(&entry).await;
106 }
107 Ok(())
108 }
109 PolicyDecision::Deny { trace } => {
110 debug!(tool = %call.tool_id, trace = %trace, "policy: deny");
111 if let Some(audit) = &self.audit {
112 let entry = AuditEntry {
113 timestamp: chrono_now(),
114 tool: call.tool_id.clone(),
115 command: truncate_params(&call.params),
116 result: AuditResult::Blocked {
117 reason: trace.clone(),
118 },
119 duration_ms: 0,
120 error_category: Some("policy_blocked".to_owned()),
121 error_domain: Some("action".to_owned()),
122 error_phase: None,
123 claim_source: None,
124 mcp_server_id: None,
125 injection_flagged: false,
126 embedding_anomalous: false,
127 cross_boundary_mcp_to_acp: false,
128 adversarial_policy_decision: None,
129 exit_code: None,
130 truncated: false,
131 caller_id: call.caller_id.clone(),
132 policy_match: Some(trace.clone()),
134 };
135 audit.log(&entry).await;
136 }
137 Err(ToolError::Blocked {
139 command: "Tool call denied by policy".to_owned(),
140 })
141 }
142 }
143 }
144}
145
146impl<T: ToolExecutor> ToolExecutor for PolicyGateExecutor<T> {
147 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
150 Err(ToolError::Blocked {
151 command:
152 "legacy unstructured dispatch is not supported when policy enforcement is enabled"
153 .into(),
154 })
155 }
156
157 async fn execute_confirmed(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
158 Err(ToolError::Blocked {
159 command:
160 "legacy unstructured dispatch is not supported when policy enforcement is enabled"
161 .into(),
162 })
163 }
164
165 fn tool_definitions(&self) -> Vec<ToolDef> {
166 self.inner.tool_definitions()
167 }
168
169 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
170 self.check_policy(call).await?;
171 let result = self.inner.execute_tool_call(call).await;
172 if let Ok(Some(ref output)) = result
175 && let Some(colon) = output.tool_name.find(':')
176 {
177 let server_id = output.tool_name[..colon].to_owned();
178 if let Some(audit) = &self.audit {
179 let entry = AuditEntry {
180 timestamp: chrono_now(),
181 tool: call.tool_id.clone(),
182 command: truncate_params(&call.params),
183 result: AuditResult::Success,
184 duration_ms: 0,
185 error_category: None,
186 error_domain: None,
187 error_phase: None,
188 claim_source: None,
189 mcp_server_id: Some(server_id),
190 injection_flagged: false,
191 embedding_anomalous: false,
192 cross_boundary_mcp_to_acp: false,
193 adversarial_policy_decision: None,
194 exit_code: None,
195 truncated: false,
196 caller_id: call.caller_id.clone(),
197 policy_match: None,
198 };
199 audit.log(&entry).await;
200 }
201 }
202 result
203 }
204
205 async fn execute_tool_call_confirmed(
208 &self,
209 call: &ToolCall,
210 ) -> Result<Option<ToolOutput>, ToolError> {
211 self.check_policy(call).await?;
212 self.inner.execute_tool_call_confirmed(call).await
213 }
214
215 fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
216 self.inner.set_skill_env(env);
217 }
218
219 fn set_effective_trust(&self, level: crate::SkillTrustLevel) {
220 self.context.write().trust_level = level;
221 self.inner.set_effective_trust(level);
222 }
223
224 fn is_tool_retryable(&self, tool_id: &str) -> bool {
225 self.inner.is_tool_retryable(tool_id)
226 }
227}
228
229fn truncate_params(params: &serde_json::Map<String, serde_json::Value>) -> String {
230 let s = serde_json::to_string(params).unwrap_or_default();
231 if s.chars().count() > 500 {
232 let truncated: String = s.chars().take(497).collect();
233 format!("{truncated}…")
234 } else {
235 s
236 }
237}
238
#[cfg(test)]
mod tests {
    use std::collections::HashMap;
    use std::sync::Arc;

    use super::*;
    use crate::SkillTrustLevel;
    use crate::policy::{
        DefaultEffect, PolicyConfig, PolicyEffect, PolicyEnforcer, PolicyRuleConfig,
    };

    /// Inner executor stub: succeeds for every structured call and echoes the
    /// call's `tool_id` back as the output's `tool_name`.
    #[derive(Debug)]
    struct MockExecutor;

    impl ToolExecutor for MockExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            let output = ToolOutput {
                tool_name: call.tool_id.clone(),
                summary: "ok".into(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
                locations: None,
                raw_response: None,
                claim_source: None,
            };
            Ok(Some(output))
        }
    }

    /// Enabled policy config with the given default effect and rules.
    fn enabled_config(
        default_effect: DefaultEffect,
        rules: Vec<PolicyRuleConfig>,
    ) -> PolicyConfig {
        PolicyConfig {
            enabled: true,
            default_effect,
            rules,
            policy_file: None,
        }
    }

    /// A rule for the `shell` tool with optional path globs and trust threshold.
    fn shell_rule(
        effect: PolicyEffect,
        paths: &[&str],
        trust_level: Option<SkillTrustLevel>,
    ) -> PolicyRuleConfig {
        PolicyRuleConfig {
            effect,
            tool: "shell".to_owned(),
            paths: paths.iter().map(|path| (*path).to_owned()).collect(),
            env: vec![],
            trust_level,
            args_match: None,
            capabilities: vec![],
        }
    }

    fn make_gate(config: &PolicyConfig) -> PolicyGateExecutor<MockExecutor> {
        let enforcer = PolicyEnforcer::compile(config).unwrap();
        let context = PolicyContext {
            trust_level: SkillTrustLevel::Trusted,
            env: HashMap::new(),
        };
        PolicyGateExecutor::new(
            MockExecutor,
            Arc::new(enforcer),
            Arc::new(RwLock::new(context)),
        )
    }

    fn call_with_params(
        tool_id: &str,
        params: serde_json::Map<String, serde_json::Value>,
    ) -> ToolCall {
        ToolCall {
            tool_id: tool_id.into(),
            params,
            caller_id: None,
        }
    }

    fn make_call(tool_id: &str) -> ToolCall {
        call_with_params(tool_id, serde_json::Map::new())
    }

    fn make_call_with_path(tool_id: &str, path: &str) -> ToolCall {
        let params = [(
            "file_path".to_owned(),
            serde_json::Value::String(path.to_owned()),
        )]
        .into_iter()
        .collect();
        call_with_params(tool_id, params)
    }

    #[tokio::test]
    async fn allow_by_default_when_default_allow() {
        let gate = make_gate(&enabled_config(DefaultEffect::Allow, vec![]));
        let outcome = gate.execute_tool_call(&make_call("bash")).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn deny_by_default_when_default_deny() {
        let gate = make_gate(&enabled_config(DefaultEffect::Deny, vec![]));
        let outcome = gate.execute_tool_call(&make_call("bash")).await;
        assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn deny_rule_blocks_tool() {
        let rule = shell_rule(PolicyEffect::Deny, &["/etc/*"], None);
        let gate = make_gate(&enabled_config(DefaultEffect::Allow, vec![rule]));
        let call = make_call_with_path("shell", "/etc/passwd");
        let outcome = gate.execute_tool_call(&call).await;
        assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn allow_rule_permits_tool() {
        let rule = shell_rule(PolicyEffect::Allow, &["/tmp/*"], None);
        let gate = make_gate(&enabled_config(DefaultEffect::Deny, vec![rule]));
        let call = make_call_with_path("shell", "/tmp/foo.sh");
        let outcome = gate.execute_tool_call(&call).await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn error_message_is_generic() {
        // The caller-visible message must not reveal anything about the rule set.
        let gate = make_gate(&enabled_config(DefaultEffect::Deny, vec![]));
        let err = gate
            .execute_tool_call(&make_call("bash"))
            .await
            .unwrap_err();
        let ToolError::Blocked { command } = err else {
            panic!("expected Blocked error");
        };
        assert!(!command.contains("rule["), "must not leak rule index");
        assert!(!command.contains("/etc/"), "must not leak path pattern");
    }

    #[tokio::test]
    async fn confirmed_also_enforces_policy() {
        // Confirmation must not bypass the policy gate.
        let gate = make_gate(&enabled_config(DefaultEffect::Deny, vec![]));
        let outcome = gate.execute_tool_call_confirmed(&make_call("bash")).await;
        assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn confirmed_allow_delegates_to_inner() {
        let gate = make_gate(&enabled_config(DefaultEffect::Allow, vec![]));
        let outcome = gate.execute_tool_call_confirmed(&make_call("shell")).await;
        assert!(outcome.is_ok(), "allow path must not return an error");
        let maybe_output = outcome.unwrap();
        assert!(
            maybe_output.is_some(),
            "inner executor must be invoked and return output on allow"
        );
        let output = maybe_output.unwrap();
        assert_eq!(
            output.tool_name,
            "shell",
            "output tool_name must match the confirmed call"
        );
    }

    #[tokio::test]
    async fn legacy_execute_blocked_when_policy_enabled() {
        let gate = make_gate(&enabled_config(DefaultEffect::Deny, vec![]));
        let fenced = "```bash\necho hi\n```";
        let plain = gate.execute(fenced).await;
        assert!(matches!(plain, Err(ToolError::Blocked { .. })));
        let confirmed = gate.execute_confirmed(fenced).await;
        assert!(matches!(confirmed, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn set_effective_trust_quarantined_blocks_verified_threshold_rule() {
        let rule = shell_rule(PolicyEffect::Allow, &[], Some(SkillTrustLevel::Verified));
        let gate = make_gate(&enabled_config(DefaultEffect::Deny, vec![rule]));
        gate.set_effective_trust(SkillTrustLevel::Quarantined);
        let outcome = gate.execute_tool_call(&make_call("shell")).await;
        assert!(
            matches!(outcome, Err(ToolError::Blocked { .. })),
            "Quarantined context must not satisfy a Verified trust threshold allow rule"
        );
    }

    #[tokio::test]
    async fn set_effective_trust_trusted_satisfies_verified_threshold_rule() {
        let rule = shell_rule(PolicyEffect::Allow, &[], Some(SkillTrustLevel::Verified));
        let gate = make_gate(&enabled_config(DefaultEffect::Deny, vec![rule]));
        gate.set_effective_trust(SkillTrustLevel::Trusted);
        let outcome = gate.execute_tool_call(&make_call("shell")).await;
        assert!(
            outcome.is_ok(),
            "Trusted context must satisfy a Verified trust threshold allow rule"
        );
    }
}