Skip to main content

bamboo_server/tools/
policy_aware.rs

//! `PolicyAwareToolExecutor`: enforces a subagent profile's [`ToolPolicy`]
//! at tool-call time.
//!
//! This is the runtime half of the SubagentProfile feature. Profile loading
//! and prompt injection live in `subagent_profiles` and
//! `session_app::child_session`; here we wrap the child-session
//! [`ToolExecutor`] so that tool calls are filtered against the calling
//! child's `subagent_type` metadata.
//!
//! ## Why a wrapper instead of changing `SpawnContext`
//!
//! `bamboo-engine`'s `SpawnContext` carries a single `Arc<dyn ToolExecutor>`
//! shared by every child session. Different child sessions can have
//! different `subagent_type`s and therefore different
//! [`ToolPolicy`]s. Rather than changing the engine to dispatch executors
//! per session (which would ripple into `SpawnContext`, `SpawnScheduler`
//! and `ScheduleManager`), this wrapper sits in front of the existing
//! executor and resolves the policy on the fly using
//! [`ToolExecutionContext::session_id`] and the in-memory sessions cache
//! that the server already maintains.
//!
//! ## Behaviour summary
//!
//! - Tool _execution_ through `execute_with_context` is filtered:
//!   - [`ToolPolicy::Inherit`] → forwarded unchanged.
//!   - [`ToolPolicy::Allowlist`] → forwarded only if the tool name is on
//!     the allow list; otherwise rejected with a clear
//!     `ToolError::Execution` message.
//!   - [`ToolPolicy::Denylist`] → rejected if the tool name is on the
//!     deny list; otherwise forwarded.
//! - Plain `execute` carries no session context and is always forwarded
//!   unchanged.
//! - Tool _discovery_ (`list_tools`) is **not** filtered here. The
//!   advertised tool surface is filtered upstream via the engine's
//!   `disabled_tools` mechanism, which includes the subagent profile
//!   policy (see `ChildSessionAdapter::enqueue_child_run`). This wrapper
//!   remains as a safety net for execution-time enforcement.
//! - When the wrapper cannot determine the calling child's `subagent_type`,
//!   it forwards unchanged. That is the case when there is no `session_id`,
//!   the session is not in the cache, or the `subagent_type` metadata is
//!   missing or blank. This keeps the change strictly additive: any existing
//!   call path that has not yet adopted subagent profiles continues to
//!   behave exactly as before.
//! - A `subagent_type` that names no registered profile is *not* treated as
//!   "no policy". It is passed to `SubagentProfileRegistry::resolve`, and the
//!   policy of whichever profile that returns (presumably the registry's
//!   fallback profile) is enforced.

42use std::collections::HashMap;
43use std::sync::Arc;
44
45use async_trait::async_trait;
46use bamboo_agent_core::tools::{
47    ToolCall, ToolError, ToolExecutionContext, ToolExecutor, ToolResult, ToolSchema,
48};
49use bamboo_agent_core::Session;
50use bamboo_domain::subagent::{SubagentProfileRegistry, ToolPolicy};
51use tokio::sync::RwLock;
52
53/// Tool executor that enforces a subagent profile's [`ToolPolicy`] when
54/// executing tool calls from a child session.
55pub struct PolicyAwareToolExecutor {
56    inner: Arc<dyn ToolExecutor>,
57    profiles: Arc<SubagentProfileRegistry>,
58    sessions: Arc<RwLock<HashMap<String, Session>>>,
59}
60
61impl PolicyAwareToolExecutor {
62    pub fn new(
63        inner: Arc<dyn ToolExecutor>,
64        profiles: Arc<SubagentProfileRegistry>,
65        sessions: Arc<RwLock<HashMap<String, Session>>>,
66    ) -> Self {
67        Self {
68            inner,
69            profiles,
70            sessions,
71        }
72    }
73
74    /// Look up the `subagent_type` metadata for a session id from the
75    /// in-memory cache. Returns `None` when the session is not cached or
76    /// the metadata key is missing / blank.
77    async fn subagent_type_for_session(&self, session_id: &str) -> Option<String> {
78        let sessions = self.sessions.read().await;
79        let value = sessions
80            .get(session_id)?
81            .metadata
82            .get("subagent_type")?
83            .trim();
84        if value.is_empty() {
85            None
86        } else {
87            Some(value.to_string())
88        }
89    }
90
91    /// Check whether a tool call is permitted under the given policy.
92    /// Returns `Ok(())` when allowed, `Err(reason)` when blocked.
93    fn check_policy(
94        policy: &ToolPolicy,
95        tool_name: &str,
96        subagent_type: &str,
97    ) -> Result<(), String> {
98        match policy {
99            ToolPolicy::Inherit => Ok(()),
100            ToolPolicy::Allowlist { allow } => {
101                if allow.iter().any(|t| t == tool_name) {
102                    Ok(())
103                } else {
104                    Err(format!(
105                        "tool '{tool_name}' is not permitted for subagent_type \
106                         '{subagent_type}' (allowlist policy: {allow:?})"
107                    ))
108                }
109            }
110            ToolPolicy::Denylist { deny } => {
111                if deny.iter().any(|t| t == tool_name) {
112                    Err(format!(
113                        "tool '{tool_name}' is denied for subagent_type \
114                         '{subagent_type}' (denylist policy: {deny:?})"
115                    ))
116                } else {
117                    Ok(())
118                }
119            }
120        }
121    }
122
123    /// Resolve the policy for a tool call: returns `Ok(())` when the call
124    /// should proceed, or an `Err` payload to be surfaced as a
125    /// `ToolError::Execution`. When no policy can be resolved the call is
126    /// allowed (legacy / inherit behaviour).
127    async fn evaluate(&self, call: &ToolCall, session_id: Option<&str>) -> Result<(), String> {
128        let Some(session_id) = session_id else {
129            return Ok(());
130        };
131        let Some(subagent_type) = self.subagent_type_for_session(session_id).await else {
132            return Ok(());
133        };
134        let profile = self.profiles.resolve(&subagent_type);
135        Self::check_policy(&profile.tools, call.function.name.trim(), &subagent_type)
136    }
137}
138
139#[async_trait]
140impl ToolExecutor for PolicyAwareToolExecutor {
141    async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
142        // No session context available via plain `execute` — fall through.
143        // The agent loop calls `execute_with_context`; this branch is only
144        // hit by direct callers that don't carry session metadata, in
145        // which case we mirror legacy behaviour exactly.
146        self.inner.execute(call).await
147    }
148
149    async fn execute_with_context(
150        &self,
151        call: &ToolCall,
152        ctx: ToolExecutionContext<'_>,
153    ) -> std::result::Result<ToolResult, ToolError> {
154        if let Err(reason) = self.evaluate(call, ctx.session_id).await {
155            return Err(ToolError::Execution(reason));
156        }
157        self.inner.execute_with_context(call, ctx).await
158    }
159
160    fn list_tools(&self) -> Vec<ToolSchema> {
161        // Discovery is intentionally not filtered here; see module docs.
162        self.inner.list_tools()
163    }
164
165    fn tool_mutability(&self, tool_name: &str) -> bamboo_agent_core::tools::ToolMutability {
166        self.inner.tool_mutability(tool_name)
167    }
168
169    fn call_mutability(&self, call: &ToolCall) -> bamboo_agent_core::tools::ToolMutability {
170        self.inner.call_mutability(call)
171    }
172
173    fn tool_concurrency_safe(&self, tool_name: &str) -> bool {
174        self.inner.tool_concurrency_safe(tool_name)
175    }
176
177    fn call_concurrency_safe(&self, call: &ToolCall) -> bool {
178        self.inner.call_concurrency_safe(call)
179    }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::tools::{FunctionCall, ToolMutability};
    use bamboo_domain::subagent::SubagentProfile;

    /// Shared log of the tool names the recording executor was asked to run.
    type ExecutedLog = Arc<RwLock<Vec<String>>>;

    /// Tiny stand-in executor that records every name it was asked to
    /// execute and always succeeds with a fixed payload. We use it to
    /// assert "forwarded vs blocked" without spinning up the full builtin
    /// tool surface.
    struct RecordingExecutor {
        executed: ExecutedLog,
    }

    impl RecordingExecutor {
        fn new() -> (Arc<Self>, ExecutedLog) {
            let executed: ExecutedLog = Arc::new(RwLock::new(Vec::new()));
            let exec = Arc::new(Self {
                executed: executed.clone(),
            });
            (exec, executed)
        }
    }

    #[async_trait]
    impl ToolExecutor for RecordingExecutor {
        async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
            self.executed.write().await.push(call.function.name.clone());
            Ok(ToolResult {
                success: true,
                result: "ok".to_string(),
                display_preference: None,
            })
        }

        fn list_tools(&self) -> Vec<ToolSchema> {
            Vec::new()
        }

        fn tool_mutability(&self, _tool_name: &str) -> ToolMutability {
            ToolMutability::ReadOnly
        }
    }

    fn make_call(name: &str) -> ToolCall {
        ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: "{}".to_string(),
            },
        }
    }

    /// Execution context for a call issued by `session_id`, with no event
    /// channel and no schema snapshot.
    fn ctx_for(session_id: &str) -> ToolExecutionContext<'_> {
        ToolExecutionContext {
            session_id: Some(session_id),
            tool_call_id: "call_1",
            event_tx: None,
            available_tool_schemas: None,
        }
    }

    fn registry_with(profile: SubagentProfile) -> Arc<SubagentProfileRegistry> {
        // The registry requires the fallback id to exist in the profile
        // set. Use the profile's own id so each test can declare just one
        // profile and still build a valid registry.
        let id = profile.id.clone();
        Arc::new(
            SubagentProfileRegistry::builder()
                .extend(vec![profile])
                .fallback_id(id)
                .build()
                .expect("registry build"),
        )
    }

    fn profile(id: &str, tools: ToolPolicy) -> SubagentProfile {
        SubagentProfile {
            id: id.to_string(),
            display_name: id.to_string(),
            description: String::new(),
            system_prompt: "p".to_string(),
            tools,
            model_hint: None,
            default_responsibility: None,
            ui: Default::default(),
        }
    }

    /// Registry holding a single `researcher` profile that may only `Read`.
    fn read_only_researcher() -> Arc<SubagentProfileRegistry> {
        registry_with(profile(
            "researcher",
            ToolPolicy::Allowlist {
                allow: vec!["Read".to_string()],
            },
        ))
    }

    /// Session cache containing one child session, optionally tagged with
    /// `subagent_type` metadata. Nothing here awaits, so it is synchronous.
    fn sessions_with(
        session_id: &str,
        subagent_type: Option<&str>,
    ) -> Arc<RwLock<HashMap<String, Session>>> {
        let mut map = HashMap::new();
        let mut session = Session::new_child(session_id, "root", "test-model", "Child");
        if let Some(t) = subagent_type {
            session
                .metadata
                .insert("subagent_type".to_string(), t.to_string());
        }
        map.insert(session_id.to_string(), session);
        Arc::new(RwLock::new(map))
    }

    /// Wrap a fresh [`RecordingExecutor`] and return the wrapper together
    /// with the recorder's log.
    fn wrapper(
        registry: Arc<SubagentProfileRegistry>,
        sessions: Arc<RwLock<HashMap<String, Session>>>,
    ) -> (PolicyAwareToolExecutor, ExecutedLog) {
        let (inner, executed) = RecordingExecutor::new();
        (
            PolicyAwareToolExecutor::new(inner, registry, sessions),
            executed,
        )
    }

    #[tokio::test]
    async fn inherit_policy_forwards_all_calls() {
        let (exec, executed) = wrapper(
            registry_with(profile("test", ToolPolicy::Inherit)),
            sessions_with("s1", Some("test")),
        );

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Read".to_string()]);
    }

    #[tokio::test]
    async fn allowlist_permits_listed_tool() {
        let (exec, executed) = wrapper(
            registry_with(profile(
                "researcher",
                ToolPolicy::Allowlist {
                    allow: vec!["Read".to_string(), "Grep".to_string()],
                },
            )),
            sessions_with("s1", Some("researcher")),
        );

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Read".to_string()]);
    }

    #[tokio::test]
    async fn allowlist_blocks_unlisted_tool() {
        let (exec, executed) = wrapper(
            read_only_researcher(),
            sessions_with("s1", Some("researcher")),
        );

        let err = exec
            .execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap_err();
        match err {
            ToolError::Execution(msg) => {
                assert!(msg.contains("Edit"), "msg should name tool: {msg}");
                assert!(
                    msg.contains("researcher"),
                    "msg should name subagent_type: {msg}"
                );
                assert!(msg.contains("allowlist"), "msg should name mode: {msg}");
            }
            other => panic!("expected ToolError::Execution, got {other:?}"),
        }
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn denylist_blocks_listed_tool() {
        let (exec, executed) = wrapper(
            registry_with(profile(
                "coder",
                ToolPolicy::Denylist {
                    deny: vec!["SubAgent".to_string()],
                },
            )),
            sessions_with("s1", Some("coder")),
        );

        let err = exec
            .execute_with_context(&make_call("SubAgent"), ctx_for("s1"))
            .await
            .unwrap_err();
        match err {
            ToolError::Execution(msg) => {
                assert!(msg.contains("SubAgent"));
                assert!(msg.contains("denylist"));
            }
            other => panic!("expected ToolError::Execution, got {other:?}"),
        }
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn denylist_permits_unlisted_tool() {
        let (exec, executed) = wrapper(
            registry_with(profile(
                "coder",
                ToolPolicy::Denylist {
                    deny: vec!["SubAgent".to_string()],
                },
            )),
            sessions_with("s1", Some("coder")),
        );

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Read".to_string()]);
    }

    #[tokio::test]
    async fn padded_tool_name_cannot_bypass_allowlist() {
        // The policy check trims the requested tool name, so surrounding
        // whitespace must not let an unlisted tool slip past the allowlist.
        let (exec, executed) = wrapper(
            read_only_researcher(),
            sessions_with("s1", Some("researcher")),
        );

        let err = exec
            .execute_with_context(&make_call("  Edit "), ctx_for("s1"))
            .await
            .unwrap_err();
        assert!(
            matches!(&err, ToolError::Execution(_)),
            "expected ToolError::Execution, got {err:?}"
        );
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn missing_session_id_falls_through() {
        // No session_id in context → wrapper must not reject. This is the
        // "legacy / direct caller" path and must keep working.
        let (exec, executed) = wrapper(
            read_only_researcher(),
            sessions_with("s1", Some("researcher")),
        );

        exec.execute_with_context(&make_call("Edit"), ToolExecutionContext::none("call_1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn unknown_session_falls_through() {
        // Cache contains a different session id than the one referenced.
        let (exec, executed) = wrapper(
            read_only_researcher(),
            sessions_with("other", Some("researcher")),
        );

        exec.execute_with_context(&make_call("Edit"), ctx_for("missing"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn missing_subagent_type_metadata_falls_through() {
        // Session is cached but has no subagent_type metadata.
        let (exec, executed) = wrapper(read_only_researcher(), sessions_with("s1", None));

        exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn blank_subagent_type_metadata_falls_through() {
        // Whitespace-only metadata is treated the same as missing metadata.
        let (exec, executed) = wrapper(read_only_researcher(), sessions_with("s1", Some("   ")));

        exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn execute_without_context_forwards() {
        // The plain `execute` path has no session context. We document this
        // by forwarding unconditionally.
        let (exec, executed) = wrapper(
            read_only_researcher(),
            sessions_with("s1", Some("researcher")),
        );

        exec.execute(&make_call("Edit")).await.unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }
}