Skip to main content

bamboo_tools/
policy_aware.rs

1//! `PolicyAwareToolExecutor`: enforces a subagent profile's [`ToolPolicy`]
2//! at tool-call time.
3//!
//! This module lives in `bamboo-tools` (of the `bamboo-*` crates it depends
//! only on `bamboo-agent-core` and `bamboo-domain`); `bamboo-server`
//! re-exports it through a thin shim (`crate::tools::PolicyAwareToolExecutor`)
//! for back-compat.
7//!
8//! This is the runtime half of the SubagentProfile feature. Profile loading
9//! and prompt injection live in `bamboo_engine::profiles` and
10//! `bamboo_engine::session_app::child_session`; here we wrap the child-session
11//! [`ToolExecutor`] so that tool calls are filtered against the calling
12//! child's `subagent_type` metadata.
13//!
14//! ## Why a wrapper instead of changing `SpawnContext`
15//!
16//! `bamboo-engine`'s `SpawnContext` carries a single `Arc<dyn ToolExecutor>`
17//! shared by every child session. Different child sessions can have
18//! different `subagent_type`s and therefore different
19//! [`ToolPolicy`]s. Rather than changing the engine to dispatch executors
20//! per session (which would ripple into `SpawnContext`, `SpawnScheduler`
21//! and `ScheduleManager`), this wrapper sits in front of the existing
22//! executor and resolves the policy on the fly using
23//! [`ToolExecutionContext::session_id`] and the in-memory sessions cache
24//! that the server already maintains.
25//!
26//! ## Behaviour summary
27//!
28//! - Tool _execution_ is filtered:
29//!   - [`ToolPolicy::Inherit`] → forwarded unchanged.
30//!   - [`ToolPolicy::Allowlist`] → forwarded only if the tool name is on
31//!     the allow list; otherwise rejected with a clear
32//!     `ToolError::Execution` message.
33//!   - [`ToolPolicy::Denylist`] → rejected if the tool name is on the
34//!     deny list; otherwise forwarded.
35//! - Tool _discovery_ (`list_tools`) is **not** filtered here. The
36//!   advertised tool surface is filtered upstream via the engine's
37//!   `disabled_tools` mechanism, which includes the subagent profile
38//!   policy (see `ChildSessionAdapter::enqueue_child_run`). This wrapper
39//!   remains as a safety net for execution-time enforcement.
//! - When the wrapper cannot associate the call with a subagent profile (the
//!   call arrives via plain `execute`, which carries no context; there is no
//!   `session_id`; the session is not in the cache; or the `subagent_type`
//!   metadata is missing or blank), it forwards unchanged. This keeps the
//!   change strictly additive: any existing call path that has not yet
//!   adopted subagent profiles continues to behave exactly as before.
45//! - When a `subagent_type` *is* present but unrecognized,
46//!   [`SubagentProfileRegistry::resolve`] returns the registry's fallback
47//!   profile and that profile's policy is enforced. With the default
48//!   `general-purpose` (Inherit) fallback this still forwards unchanged, but a
49//!   restrictively-configured fallback profile would apply to unknown types.
50
51use std::sync::Arc;
52
53use async_trait::async_trait;
54use bamboo_agent_core::tools::{
55    ToolCall, ToolError, ToolExecutionContext, ToolExecutor, ToolResult, ToolSchema,
56};
57use bamboo_agent_core::Session;
58use bamboo_domain::subagent::{SubagentProfileRegistry, ToolPolicy};
59
60/// Shared, per-session-locked session cache.
61///
62/// Concretely identical to `bamboo_engine::SessionCache`; defined here (rather
63/// than imported) because `bamboo-tools` sits *below* `bamboo-engine` in the
64/// dependency graph. Both are transparent aliases of the same type, so values
65/// flow between them without conversion.
66pub type SessionCache = Arc<dashmap::DashMap<String, Arc<parking_lot::RwLock<Session>>>>;
67
68/// Tool executor that enforces a subagent profile's [`ToolPolicy`] when
69/// executing tool calls from a child session.
70pub struct PolicyAwareToolExecutor {
71    inner: Arc<dyn ToolExecutor>,
72    profiles: Arc<SubagentProfileRegistry>,
73    sessions: SessionCache,
74}
75
76impl PolicyAwareToolExecutor {
77    pub fn new(
78        inner: Arc<dyn ToolExecutor>,
79        profiles: Arc<SubagentProfileRegistry>,
80        sessions: SessionCache,
81    ) -> Self {
82        Self {
83            inner,
84            profiles,
85            sessions,
86        }
87    }
88
89    /// Look up the `subagent_type` metadata for a session id from the
90    /// in-memory cache. Returns `None` when the session is not cached or
91    /// the metadata key is missing / blank.
92    async fn subagent_type_for_session(&self, session_id: &str) -> Option<String> {
93        let arc = self.sessions.get(session_id).map(|e| e.value().clone())?;
94        let value = arc.read().subagent_type()?;
95        let trimmed = value.trim();
96        if trimmed.is_empty() {
97            None
98        } else {
99            Some(trimmed.to_string())
100        }
101    }
102
103    /// Check whether a tool call is permitted under the given policy.
104    /// Returns `Ok(())` when allowed, `Err(reason)` when blocked.
105    fn check_policy(
106        policy: &ToolPolicy,
107        tool_name: &str,
108        subagent_type: &str,
109    ) -> Result<(), String> {
110        match policy {
111            ToolPolicy::Inherit => Ok(()),
112            ToolPolicy::Allowlist { allow } => {
113                if allow.iter().any(|t| t == tool_name) {
114                    Ok(())
115                } else {
116                    Err(format!(
117                        "tool '{tool_name}' is not permitted for subagent_type \
118                         '{subagent_type}' (allowlist policy: {allow:?})"
119                    ))
120                }
121            }
122            ToolPolicy::Denylist { deny } => {
123                if deny.iter().any(|t| t == tool_name) {
124                    Err(format!(
125                        "tool '{tool_name}' is denied for subagent_type \
126                         '{subagent_type}' (denylist policy: {deny:?})"
127                    ))
128                } else {
129                    Ok(())
130                }
131            }
132        }
133    }
134
135    /// Resolve the policy for a tool call: returns `Ok(())` when the call
136    /// should proceed, or an `Err` payload to be surfaced as a
137    /// `ToolError::Execution`. When no policy can be resolved the call is
138    /// allowed (legacy / inherit behaviour).
139    async fn evaluate(&self, call: &ToolCall, session_id: Option<&str>) -> Result<(), String> {
140        let Some(session_id) = session_id else {
141            return Ok(());
142        };
143        let Some(subagent_type) = self.subagent_type_for_session(session_id).await else {
144            return Ok(());
145        };
146        let profile = self.profiles.resolve(&subagent_type);
147        Self::check_policy(&profile.tools, call.function.name.trim(), &subagent_type)
148    }
149}
150
151#[async_trait]
152impl ToolExecutor for PolicyAwareToolExecutor {
153    async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
154        // No session context available via plain `execute` — fall through.
155        // The agent loop calls `execute_with_context`; this branch is only
156        // hit by direct callers that don't carry session metadata, in
157        // which case we mirror legacy behaviour exactly.
158        self.inner.execute(call).await
159    }
160
161    async fn execute_with_context(
162        &self,
163        call: &ToolCall,
164        ctx: ToolExecutionContext<'_>,
165    ) -> std::result::Result<ToolResult, ToolError> {
166        if let Err(reason) = self.evaluate(call, ctx.session_id).await {
167            return Err(ToolError::Execution(reason));
168        }
169        self.inner.execute_with_context(call, ctx).await
170    }
171
172    fn list_tools(&self) -> Vec<ToolSchema> {
173        // Discovery is intentionally not filtered here; see module docs.
174        self.inner.list_tools()
175    }
176
177    fn tool_mutability(&self, tool_name: &str) -> bamboo_agent_core::tools::ToolMutability {
178        self.inner.tool_mutability(tool_name)
179    }
180
181    fn call_mutability(&self, call: &ToolCall) -> bamboo_agent_core::tools::ToolMutability {
182        self.inner.call_mutability(call)
183    }
184
185    fn tool_concurrency_safe(&self, tool_name: &str) -> bool {
186        self.inner.tool_concurrency_safe(tool_name)
187    }
188
189    fn call_concurrency_safe(&self, call: &ToolCall) -> bool {
190        self.inner.call_concurrency_safe(call)
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::tools::{FunctionCall, ToolMutability};
    use bamboo_domain::subagent::SubagentProfile;
    use tokio::sync::RwLock;

    /// Tiny stand-in executor that records every name it was asked to
    /// execute and always succeeds with a fixed payload. We use it to
    /// assert "forwarded vs blocked" without spinning up the full builtin
    /// tool surface.
    struct RecordingExecutor {
        executed: Arc<RwLock<Vec<String>>>,
    }

    impl RecordingExecutor {
        /// Create the executor together with a handle to its call log.
        fn new() -> (Arc<Self>, Arc<RwLock<Vec<String>>>) {
            let executed = Arc::new(RwLock::new(Vec::new()));
            let exec = Arc::new(Self {
                executed: executed.clone(),
            });
            (exec, executed)
        }
    }

    #[async_trait]
    impl ToolExecutor for RecordingExecutor {
        async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
            self.executed.write().await.push(call.function.name.clone());
            Ok(ToolResult {
                success: true,
                result: "ok".to_string(),
                display_preference: None,
                images: Vec::new(),
            })
        }

        fn list_tools(&self) -> Vec<ToolSchema> {
            Vec::new()
        }

        fn tool_mutability(&self, _tool_name: &str) -> ToolMutability {
            ToolMutability::ReadOnly
        }
    }

    /// Wrap a fresh `RecordingExecutor` with `registry` and `sessions`, and
    /// return the wrapper together with the recorder's call log.
    fn wrapped(
        registry: Arc<SubagentProfileRegistry>,
        sessions: SessionCache,
    ) -> (PolicyAwareToolExecutor, Arc<RwLock<Vec<String>>>) {
        let (inner, executed) = RecordingExecutor::new();
        (PolicyAwareToolExecutor::new(inner, registry, sessions), executed)
    }

    fn make_call(name: &str) -> ToolCall {
        ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: "{}".to_string(),
            },
        }
    }

    /// Execution context for a call issued from `session_id`, using the same
    /// call id as `make_call`.
    fn ctx_for(session_id: &str) -> ToolExecutionContext<'_> {
        ToolExecutionContext {
            session_id: Some(session_id),
            tool_call_id: "call_1",
            event_tx: None,
            available_tool_schemas: None,
        }
    }

    /// Build a registry over `profiles`; `fallback` must be one of their ids.
    fn registry_of(
        profiles: Vec<SubagentProfile>,
        fallback: &str,
    ) -> Arc<SubagentProfileRegistry> {
        Arc::new(
            SubagentProfileRegistry::builder()
                .extend(profiles)
                .fallback_id(fallback.to_string())
                .build()
                .expect("registry build"),
        )
    }

    /// Single-profile registry in which that profile is also the fallback.
    fn registry_with(profile: SubagentProfile) -> Arc<SubagentProfileRegistry> {
        // The registry requires the fallback id to exist in the profile
        // set. Use the profile's own id so each test can declare just one
        // profile and still build a valid registry.
        let id = profile.id.clone();
        registry_of(vec![profile], &id)
    }

    fn profile(id: &str, tools: ToolPolicy) -> SubagentProfile {
        SubagentProfile {
            id: id.to_string(),
            display_name: id.to_string(),
            description: String::new(),
            system_prompt: "p".to_string(),
            tools,
            model_hint: None,
            default_responsibility: None,
            ui: Default::default(),
        }
    }

    fn allowlist(names: &[&str]) -> ToolPolicy {
        ToolPolicy::Allowlist {
            allow: names.iter().copied().map(String::from).collect(),
        }
    }

    fn denylist(names: &[&str]) -> ToolPolicy {
        ToolPolicy::Denylist {
            deny: names.iter().copied().map(String::from).collect(),
        }
    }

    /// Session cache holding one child session `session_id`, tagged with
    /// `subagent_type` metadata when one is given. (Synchronous: nothing here
    /// needs to await.)
    fn sessions_with(session_id: &str, subagent_type: Option<&str>) -> SessionCache {
        let map = dashmap::DashMap::new();
        let mut session = Session::new_child(session_id, "root", "test-model", "Child");
        if let Some(t) = subagent_type {
            session
                .metadata
                .insert("subagent_type".to_string(), t.to_string());
        }
        map.insert(
            session_id.to_string(),
            Arc::new(parking_lot::RwLock::new(session)),
        );
        Arc::new(map)
    }

    /// Extract the message of a `ToolError::Execution`; panics on any other
    /// outcome, including success.
    fn execution_error(result: std::result::Result<ToolResult, ToolError>) -> String {
        match result {
            Err(ToolError::Execution(msg)) => msg,
            Err(other) => panic!("expected ToolError::Execution, got {other:?}"),
            Ok(_) => panic!("expected the call to be blocked, but it succeeded"),
        }
    }

    #[tokio::test]
    async fn inherit_policy_forwards_all_calls() {
        let (exec, executed) = wrapped(
            registry_with(profile("test", ToolPolicy::Inherit)),
            sessions_with("s1", Some("test")),
        );

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Read".to_string()]);
    }

    #[tokio::test]
    async fn allowlist_permits_listed_tool() {
        let (exec, executed) = wrapped(
            registry_with(profile("researcher", allowlist(&["Read", "Grep"]))),
            sessions_with("s1", Some("researcher")),
        );

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Read".to_string()]);
    }

    #[tokio::test]
    async fn allowlist_blocks_unlisted_tool() {
        let (exec, executed) = wrapped(
            registry_with(profile("researcher", allowlist(&["Read"]))),
            sessions_with("s1", Some("researcher")),
        );

        let msg = execution_error(
            exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
                .await,
        );
        assert!(msg.contains("Edit"), "msg should name tool: {msg}");
        assert!(
            msg.contains("researcher"),
            "msg should name subagent_type: {msg}"
        );
        assert!(msg.contains("allowlist"), "msg should name mode: {msg}");
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn denylist_blocks_listed_tool() {
        let (exec, executed) = wrapped(
            registry_with(profile("coder", denylist(&["SubAgent"]))),
            sessions_with("s1", Some("coder")),
        );

        let msg = execution_error(
            exec.execute_with_context(&make_call("SubAgent"), ctx_for("s1"))
                .await,
        );
        assert!(msg.contains("SubAgent"));
        assert!(msg.contains("denylist"));
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn denylist_permits_unlisted_tool() {
        let (exec, executed) = wrapped(
            registry_with(profile("coder", denylist(&["SubAgent"]))),
            sessions_with("s1", Some("coder")),
        );

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Read".to_string()]);
    }

    #[tokio::test]
    async fn missing_session_id_falls_through() {
        // No session_id in context → wrapper must not reject. This is the
        // "legacy / direct caller" path and must keep working.
        let (exec, executed) = wrapped(
            registry_with(profile("researcher", allowlist(&["Read"]))),
            sessions_with("s1", Some("researcher")),
        );

        let ctx = ToolExecutionContext::none("call_1");
        exec.execute_with_context(&make_call("Edit"), ctx)
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn unknown_session_falls_through() {
        // Cache contains a different session id than the one referenced.
        let (exec, executed) = wrapped(
            registry_with(profile("researcher", allowlist(&["Read"]))),
            sessions_with("other", Some("researcher")),
        );

        exec.execute_with_context(&make_call("Edit"), ctx_for("missing"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn missing_subagent_type_metadata_falls_through() {
        // Session is cached but has no subagent_type metadata.
        let (exec, executed) = wrapped(
            registry_with(profile("researcher", allowlist(&["Read"]))),
            sessions_with("s1", None),
        );

        exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn blank_subagent_type_metadata_falls_through() {
        // Whitespace-only metadata is treated like missing metadata, even
        // though the fallback profile here is restrictive.
        let (exec, executed) = wrapped(
            registry_with(profile("researcher", allowlist(&["Read"]))),
            sessions_with("s1", Some("   ")),
        );

        exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn unknown_subagent_type_enforces_fallback_profile() {
        // `registry_with` makes the single restrictive profile the fallback,
        // so an unrecognised type must still be restricted by it.
        let (exec, executed) = wrapped(
            registry_with(profile("researcher", allowlist(&["Read"]))),
            sessions_with("s1", Some("mystery")),
        );

        let msg = execution_error(
            exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
                .await,
        );
        assert!(msg.contains("Edit"), "msg should name tool: {msg}");
        assert!(
            msg.contains("mystery"),
            "msg should name subagent_type: {msg}"
        );
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn padded_subagent_type_is_trimmed_before_lookup() {
        // The fallback is permissive, so the call can only be blocked if the
        // padded metadata resolves to the restrictive "researcher" profile.
        let registry = registry_of(
            vec![
                profile("researcher", allowlist(&["Read"])),
                profile("general", ToolPolicy::Inherit),
            ],
            "general",
        );
        let (exec, executed) = wrapped(registry, sessions_with("s1", Some("  researcher  ")));

        let msg = execution_error(
            exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
                .await,
        );
        assert!(msg.contains("'researcher'"), "msg should use trimmed type: {msg}");
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn padded_tool_name_cannot_bypass_denylist() {
        let (exec, executed) = wrapped(
            registry_with(profile("coder", denylist(&["SubAgent"]))),
            sessions_with("s1", Some("coder")),
        );

        let msg = execution_error(
            exec.execute_with_context(&make_call(" SubAgent "), ctx_for("s1"))
                .await,
        );
        assert!(msg.contains("denylist"), "msg should name mode: {msg}");
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn padded_tool_name_is_checked_trimmed_but_forwarded_verbatim() {
        let (exec, executed) = wrapped(
            registry_with(profile("researcher", allowlist(&["Read"]))),
            sessions_with("s1", Some("researcher")),
        );

        exec.execute_with_context(&make_call(" Read "), ctx_for("s1"))
            .await
            .unwrap();
        // The policy check trims the name; the inner executor sees the call
        // exactly as it was issued.
        assert_eq!(executed.read().await.as_slice(), &[" Read ".to_string()]);
    }

    #[tokio::test]
    async fn execute_without_context_forwards() {
        // The plain `execute` path has no session context. We document this
        // by forwarding unconditionally.
        let (exec, executed) = wrapped(
            registry_with(profile("researcher", allowlist(&["Read"]))),
            sessions_with("s1", Some("researcher")),
        );

        exec.execute(&make_call("Edit")).await.unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }
}