1use std::sync::Arc;
22
23use tracing::info_span;
24
25use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
26use crate::registry::ToolDef;
27
28pub trait ProbeGate: Send + Sync {
36 fn probe<'a>(
38 &'a self,
39 qualified_tool_id: &'a str,
40 args: &'a serde_json::Value,
41 turn_number: u64,
42 risk_level: &'a str,
43 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeOutcome> + Send + 'a>>;
44}
45
/// Verdict returned by a [`ProbeGate`] for a single tool call.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ProbeOutcome {
    /// The call may proceed to the wrapped executor.
    Allow,
    /// The call must be blocked.
    Deny {
        /// Explanation for the denial; `ShadowProbeExecutor` forwards it in
        /// `ToolError::SafetyDenied`.
        reason: String,
    },
    /// The probe declined to rule on the call; `ShadowProbeExecutor` lets it
    /// proceed exactly as it would for `Allow`.
    Skip,
}
59
60pub struct ShadowProbeExecutor<T: ToolExecutor> {
70 inner: T,
71 probe: Arc<dyn ProbeGate>,
72 turn_number: Arc<std::sync::atomic::AtomicU64>,
75 risk_level: Arc<parking_lot::RwLock<String>>,
77}
78
79impl<T: ToolExecutor + std::fmt::Debug> std::fmt::Debug for ShadowProbeExecutor<T> {
80 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
81 f.debug_struct("ShadowProbeExecutor")
82 .field("inner", &self.inner)
83 .finish_non_exhaustive()
84 }
85}
86
87impl<T: ToolExecutor> ShadowProbeExecutor<T> {
88 #[must_use]
97 pub fn new(
98 inner: T,
99 probe: Arc<dyn ProbeGate>,
100 turn_number: Arc<std::sync::atomic::AtomicU64>,
101 risk_level: Arc<parking_lot::RwLock<String>>,
102 ) -> Self {
103 Self {
104 inner,
105 probe,
106 turn_number,
107 risk_level,
108 }
109 }
110
111 fn current_turn(&self) -> u64 {
112 self.turn_number.load(std::sync::atomic::Ordering::Acquire)
113 }
114
115 fn current_risk_level(&self) -> String {
116 self.risk_level.read().clone()
117 }
118}
119
120impl<T: ToolExecutor> ToolExecutor for ShadowProbeExecutor<T> {
121 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
123 self.inner.execute(response).await
124 }
125
126 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
128 self.inner.execute_confirmed(response).await
129 }
130
131 fn tool_definitions(&self) -> Vec<ToolDef> {
132 self.inner.tool_definitions()
133 }
134
135 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
140 let span = info_span!(
141 "security.shadow.probe_executor",
142 tool_id = %call.tool_id
143 );
144 let _enter = span.enter();
145
146 let args = serde_json::Value::Object(call.params.clone());
147 let turn = self.current_turn();
148 let risk = self.current_risk_level();
149
150 let outcome = self
151 .probe
152 .probe(call.tool_id.as_str(), &args, turn, &risk)
153 .await;
154
155 match outcome {
156 ProbeOutcome::Allow | ProbeOutcome::Skip => self.inner.execute_tool_call(call).await,
157 ProbeOutcome::Deny { reason } => {
158 tracing::warn!(
159 tool_id = %call.tool_id,
160 reason = %reason,
161 "ShadowProbeExecutor: safety probe denied tool call"
162 );
163 Err(ToolError::SafetyDenied { reason })
164 }
165 }
166 }
167
168 async fn execute_tool_call_confirmed(
172 &self,
173 call: &ToolCall,
174 ) -> Result<Option<ToolOutput>, ToolError> {
175 let span = info_span!(
176 "security.shadow.probe_executor_confirmed",
177 tool_id = %call.tool_id
178 );
179 let _enter = span.enter();
180
181 let args = serde_json::Value::Object(call.params.clone());
182 let turn = self.current_turn();
183 let risk = self.current_risk_level();
184
185 let outcome = self
186 .probe
187 .probe(call.tool_id.as_str(), &args, turn, &risk)
188 .await;
189
190 match outcome {
191 ProbeOutcome::Allow | ProbeOutcome::Skip => {
192 self.inner.execute_tool_call_confirmed(call).await
193 }
194 ProbeOutcome::Deny { reason } => {
195 tracing::warn!(
196 tool_id = %call.tool_id,
197 reason = %reason,
198 "ShadowProbeExecutor: safety probe denied confirmed tool call"
199 );
200 Err(ToolError::SafetyDenied { reason })
201 }
202 }
203 }
204
205 fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
206 self.inner.set_skill_env(env);
207 }
208
209 fn set_effective_trust(&self, level: crate::SkillTrustLevel) {
210 self.inner.set_effective_trust(level);
211 }
212
213 fn is_tool_retryable(&self, tool_id: &str) -> bool {
214 self.inner.is_tool_retryable(tool_id)
215 }
216
217 fn is_tool_speculatable(&self, tool_id: &str) -> bool {
218 let _ = tool_id;
221 false
222 }
223
224 fn requires_confirmation(&self, call: &ToolCall) -> bool {
225 self.inner.requires_confirmation(call)
226 }
227}
228
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicU64, Ordering};

    use super::*;
    use crate::executor::{ToolError, ToolOutput};
    use crate::{ToolCall, ToolExecutor};
    use zeph_common::ToolName;

    /// Probe that resolves every call to the same, fixed outcome.
    struct FixedProbe(ProbeOutcome);

    impl ProbeGate for FixedProbe {
        fn probe<'a>(
            &'a self,
            _: &'a str,
            _: &'a serde_json::Value,
            _: u64,
            _: &'a str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeOutcome> + Send + 'a>>
        {
            Box::pin(std::future::ready(self.0.clone()))
        }
    }

    fn allow_probe() -> FixedProbe {
        FixedProbe(ProbeOutcome::Allow)
    }

    fn deny_probe() -> FixedProbe {
        FixedProbe(ProbeOutcome::Deny {
            reason: "test denial".to_owned(),
        })
    }

    fn skip_probe() -> FixedProbe {
        FixedProbe(ProbeOutcome::Skip)
    }

    /// Arguments one `ProbeGate::probe` invocation received.
    #[derive(Debug, PartialEq)]
    struct SeenProbe {
        tool_id: String,
        args: serde_json::Value,
        turn: u64,
        risk: String,
    }

    /// Probe that allows every call and records what it was asked about.
    #[derive(Default)]
    struct RecordingProbe {
        seen: std::sync::Mutex<Vec<SeenProbe>>,
    }

    impl ProbeGate for RecordingProbe {
        fn probe<'a>(
            &'a self,
            qualified_tool_id: &'a str,
            args: &'a serde_json::Value,
            turn_number: u64,
            risk_level: &'a str,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ProbeOutcome> + Send + 'a>>
        {
            self.seen
                .lock()
                .expect("recording mutex poisoned")
                .push(SeenProbe {
                    tool_id: qualified_tool_id.to_owned(),
                    args: args.clone(),
                    turn: turn_number,
                    risk: risk_level.to_owned(),
                });
            Box::pin(std::future::ready(ProbeOutcome::Allow))
        }
    }

    struct OkInner;
    impl ToolExecutor for OkInner {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(ToolOutput {
                tool_name: call.tool_id.clone(),
                summary: "ok".to_owned(),
                blocks_executed: 1,
                filter_stats: None,
                diff: None,
                streamed: false,
                terminal_id: None,
                locations: None,
                raw_response: None,
                claim_source: None,
            }))
        }
    }

    fn make_call(tool: &str) -> ToolCall {
        ToolCall {
            tool_id: ToolName::new(tool),
            params: serde_json::Map::new(),
            caller_id: None,
            context: None,
            tool_call_id: String::new(),
        }
    }

    fn make_executor<P: ProbeGate + 'static>(probe: P) -> ShadowProbeExecutor<OkInner> {
        ShadowProbeExecutor::new(
            OkInner,
            Arc::new(probe),
            Arc::new(AtomicU64::new(1)),
            Arc::new(parking_lot::RwLock::new("calm".to_owned())),
        )
    }

    #[tokio::test]
    async fn allow_probe_delegates_to_inner() {
        let exec = make_executor(allow_probe());
        let result = exec.execute_tool_call(&make_call("builtin:shell")).await;
        assert!(result.unwrap().is_some());
    }

    #[tokio::test]
    async fn deny_probe_returns_safety_denied() {
        let exec = make_executor(deny_probe());
        let result = exec.execute_tool_call(&make_call("builtin:shell")).await;
        match result {
            Err(ToolError::SafetyDenied { reason }) => {
                assert_eq!(reason, "test denial");
            }
            other => panic!("expected SafetyDenied, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn skip_probe_delegates_to_inner() {
        let exec = make_executor(skip_probe());
        let result = exec.execute_tool_call(&make_call("builtin:read")).await;
        assert!(result.unwrap().is_some());
    }

    #[tokio::test]
    async fn legacy_execute_bypasses_probe() {
        // The denying probe must not be consulted on the legacy text path.
        let exec = make_executor(deny_probe());
        let result = exec.execute("some text").await;
        assert!(result.unwrap().is_none());
    }

    #[tokio::test]
    async fn deny_probe_blocks_confirmed_call() {
        // A prior confirmation must not bypass the probe.
        let exec = make_executor(deny_probe());
        let result = exec
            .execute_tool_call_confirmed(&make_call("builtin:shell"))
            .await;
        match result {
            Err(ToolError::SafetyDenied { reason }) => {
                assert_eq!(reason, "test denial");
            }
            other => panic!("expected SafetyDenied on confirmed call, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn probe_sees_call_args_and_current_turn_and_risk_level() {
        let probe = Arc::new(RecordingProbe::default());
        let turn = Arc::new(AtomicU64::new(1));
        let risk = Arc::new(parking_lot::RwLock::new("calm".to_owned()));
        let exec = ShadowProbeExecutor::new(
            OkInner,
            Arc::clone(&probe) as Arc<dyn ProbeGate>,
            Arc::clone(&turn),
            Arc::clone(&risk),
        );

        // Change the shared state after construction: the executor must read it per call.
        turn.store(7, Ordering::Release);
        *risk.write() = "elevated".to_owned();

        let mut call = make_call("builtin:shell");
        call.params.insert(
            "cmd".to_owned(),
            serde_json::Value::String("ls".to_owned()),
        );
        let result = exec.execute_tool_call(&call).await;
        assert!(result.unwrap().is_some());

        let seen = probe.seen.lock().expect("recording mutex poisoned");
        assert_eq!(
            *seen,
            vec![SeenProbe {
                tool_id: call.tool_id.as_str().to_owned(),
                args: serde_json::Value::Object(call.params.clone()),
                turn: 7,
                risk: "elevated".to_owned(),
            }]
        );
    }

    #[test]
    fn is_tool_speculatable_always_false() {
        let exec = make_executor(allow_probe());
        assert!(!exec.is_tool_speculatable("builtin:read"));
        assert!(!exec.is_tool_speculatable("builtin:shell"));
    }
}