Skip to main content

bamboo_tools/
policy_aware.rs

//! `PolicyAwareToolExecutor`: enforces a subagent profile's [`ToolPolicy`]
//! at tool-call time.
//!
//! This module lives in `bamboo-tools` (apart from `async-trait` and `tokio`,
//! it depends only on `bamboo-agent-core` and `bamboo-domain`); `bamboo-server`
//! re-exports it through a thin shim (`crate::tools::PolicyAwareToolExecutor`)
//! for back-compat.
//!
//! This is the runtime half of the SubagentProfile feature. Profile loading
//! and prompt injection live in `bamboo_engine::profiles` and
//! `bamboo_engine::session_app::child_session`; here we wrap the child-session
//! [`ToolExecutor`] so that tool calls are filtered against the calling
//! child's `subagent_type` metadata.
//!
//! ## Why a wrapper instead of changing `SpawnContext`
//!
//! `bamboo-engine`'s `SpawnContext` carries a single `Arc<dyn ToolExecutor>`
//! shared by every child session, yet different child sessions can have
//! different `subagent_type`s and therefore different [`ToolPolicy`]s. Rather
//! than changing the engine to dispatch executors per session (which would
//! ripple into `SpawnContext`, `SpawnScheduler` and `ScheduleManager`), this
//! wrapper sits in front of the existing executor and resolves the policy on
//! the fly from [`ToolExecutionContext::session_id`] and the in-memory
//! sessions cache that the server already maintains.
//!
//! ## Behaviour summary
//!
//! - Tool _execution_ through `execute_with_context` is filtered (the tool
//!   name is compared after trimming surrounding whitespace; the call itself
//!   is forwarded to the inner executor unmodified):
//!   - [`ToolPolicy::Inherit`] → forwarded unchanged.
//!   - [`ToolPolicy::Allowlist`] → forwarded only if the tool name is on
//!     the allow list; otherwise rejected with a clear
//!     `ToolError::Execution` message.
//!   - [`ToolPolicy::Denylist`] → rejected if the tool name is on the
//!     deny list; otherwise forwarded.
//! - Tool _discovery_ (`list_tools`) is **not** filtered here. The
//!   advertised tool surface is filtered upstream via the engine's
//!   `disabled_tools` mechanism, which includes the subagent profile
//!   policy (see `ChildSessionAdapter::enqueue_child_run`). This wrapper
//!   remains as a safety net for execution-time enforcement.
//! - When the wrapper cannot associate the call with a subagent profile (no
//!   `session_id`, which is always the case for plain `execute`; the session
//!   is not in the cache; or the `subagent_type` metadata is missing or
//!   blank), it forwards unchanged. This keeps the change strictly
//!   additive: any existing call path that has not yet adopted subagent
//!   profiles continues to behave exactly as before.
//! - When a `subagent_type` *is* present but unrecognized,
//!   [`SubagentProfileRegistry::resolve`] returns the registry's fallback
//!   profile and that profile's policy is enforced. With the default
//!   `general-purpose` (Inherit) fallback this still forwards unchanged, but a
//!   restrictively-configured fallback profile would apply to unknown types.
50
51use std::collections::HashMap;
52use std::sync::Arc;
53
54use async_trait::async_trait;
55use bamboo_agent_core::tools::{
56    ToolCall, ToolError, ToolExecutionContext, ToolExecutor, ToolResult, ToolSchema,
57};
58use bamboo_agent_core::Session;
59use bamboo_domain::subagent::{SubagentProfileRegistry, ToolPolicy};
60use tokio::sync::RwLock;
61
62/// Tool executor that enforces a subagent profile's [`ToolPolicy`] when
63/// executing tool calls from a child session.
64pub struct PolicyAwareToolExecutor {
65    inner: Arc<dyn ToolExecutor>,
66    profiles: Arc<SubagentProfileRegistry>,
67    sessions: Arc<RwLock<HashMap<String, Session>>>,
68}
69
70impl PolicyAwareToolExecutor {
71    pub fn new(
72        inner: Arc<dyn ToolExecutor>,
73        profiles: Arc<SubagentProfileRegistry>,
74        sessions: Arc<RwLock<HashMap<String, Session>>>,
75    ) -> Self {
76        Self {
77            inner,
78            profiles,
79            sessions,
80        }
81    }
82
83    /// Look up the `subagent_type` metadata for a session id from the
84    /// in-memory cache. Returns `None` when the session is not cached or
85    /// the metadata key is missing / blank.
86    async fn subagent_type_for_session(&self, session_id: &str) -> Option<String> {
87        let sessions = self.sessions.read().await;
88        let value = sessions
89            .get(session_id)?
90            .metadata
91            .get("subagent_type")?
92            .trim();
93        if value.is_empty() {
94            None
95        } else {
96            Some(value.to_string())
97        }
98    }
99
100    /// Check whether a tool call is permitted under the given policy.
101    /// Returns `Ok(())` when allowed, `Err(reason)` when blocked.
102    fn check_policy(
103        policy: &ToolPolicy,
104        tool_name: &str,
105        subagent_type: &str,
106    ) -> Result<(), String> {
107        match policy {
108            ToolPolicy::Inherit => Ok(()),
109            ToolPolicy::Allowlist { allow } => {
110                if allow.iter().any(|t| t == tool_name) {
111                    Ok(())
112                } else {
113                    Err(format!(
114                        "tool '{tool_name}' is not permitted for subagent_type \
115                         '{subagent_type}' (allowlist policy: {allow:?})"
116                    ))
117                }
118            }
119            ToolPolicy::Denylist { deny } => {
120                if deny.iter().any(|t| t == tool_name) {
121                    Err(format!(
122                        "tool '{tool_name}' is denied for subagent_type \
123                         '{subagent_type}' (denylist policy: {deny:?})"
124                    ))
125                } else {
126                    Ok(())
127                }
128            }
129        }
130    }
131
132    /// Resolve the policy for a tool call: returns `Ok(())` when the call
133    /// should proceed, or an `Err` payload to be surfaced as a
134    /// `ToolError::Execution`. When no policy can be resolved the call is
135    /// allowed (legacy / inherit behaviour).
136    async fn evaluate(&self, call: &ToolCall, session_id: Option<&str>) -> Result<(), String> {
137        let Some(session_id) = session_id else {
138            return Ok(());
139        };
140        let Some(subagent_type) = self.subagent_type_for_session(session_id).await else {
141            return Ok(());
142        };
143        let profile = self.profiles.resolve(&subagent_type);
144        Self::check_policy(&profile.tools, call.function.name.trim(), &subagent_type)
145    }
146}
147
148#[async_trait]
149impl ToolExecutor for PolicyAwareToolExecutor {
150    async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
151        // No session context available via plain `execute` — fall through.
152        // The agent loop calls `execute_with_context`; this branch is only
153        // hit by direct callers that don't carry session metadata, in
154        // which case we mirror legacy behaviour exactly.
155        self.inner.execute(call).await
156    }
157
158    async fn execute_with_context(
159        &self,
160        call: &ToolCall,
161        ctx: ToolExecutionContext<'_>,
162    ) -> std::result::Result<ToolResult, ToolError> {
163        if let Err(reason) = self.evaluate(call, ctx.session_id).await {
164            return Err(ToolError::Execution(reason));
165        }
166        self.inner.execute_with_context(call, ctx).await
167    }
168
169    fn list_tools(&self) -> Vec<ToolSchema> {
170        // Discovery is intentionally not filtered here; see module docs.
171        self.inner.list_tools()
172    }
173
174    fn tool_mutability(&self, tool_name: &str) -> bamboo_agent_core::tools::ToolMutability {
175        self.inner.tool_mutability(tool_name)
176    }
177
178    fn call_mutability(&self, call: &ToolCall) -> bamboo_agent_core::tools::ToolMutability {
179        self.inner.call_mutability(call)
180    }
181
182    fn tool_concurrency_safe(&self, tool_name: &str) -> bool {
183        self.inner.tool_concurrency_safe(tool_name)
184    }
185
186    fn call_concurrency_safe(&self, call: &ToolCall) -> bool {
187        self.inner.call_concurrency_safe(call)
188    }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::tools::{FunctionCall, ToolMutability};
    use bamboo_domain::subagent::SubagentProfile;

    /// Shared log of the tool names the inner executor was asked to run.
    type CallLog = Arc<RwLock<Vec<String>>>;

    /// Tiny stand-in executor that records every name it was asked to
    /// execute and always succeeds with a fixed payload. We use it to
    /// assert "forwarded vs blocked" without spinning up the full builtin
    /// tool surface.
    struct RecordingExecutor {
        executed: CallLog,
    }

    impl RecordingExecutor {
        /// Returns the executor plus a handle onto its call log.
        fn new() -> (Arc<Self>, CallLog) {
            let executed: CallLog = Arc::new(RwLock::new(Vec::new()));
            let exec = Arc::new(Self {
                executed: Arc::clone(&executed),
            });
            (exec, executed)
        }
    }

    #[async_trait]
    impl ToolExecutor for RecordingExecutor {
        async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
            self.executed.write().await.push(call.function.name.clone());
            Ok(ToolResult {
                success: true,
                result: "ok".to_string(),
                display_preference: None,
            })
        }

        fn list_tools(&self) -> Vec<ToolSchema> {
            Vec::new()
        }

        fn tool_mutability(&self, _tool_name: &str) -> ToolMutability {
            ToolMutability::ReadOnly
        }
    }

    fn make_call(name: &str) -> ToolCall {
        ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: "{}".to_string(),
            },
        }
    }

    /// Execution context for a call issued from `session_id`, with no event
    /// channel and no schema snapshot.
    fn ctx_for(session_id: &str) -> ToolExecutionContext<'_> {
        ToolExecutionContext {
            session_id: Some(session_id),
            tool_call_id: "call_1",
            event_tx: None,
            available_tool_schemas: None,
        }
    }

    fn registry_with(profile: SubagentProfile) -> Arc<SubagentProfileRegistry> {
        // The registry requires the fallback id to exist in the profile
        // set. Use the profile's own id so each test can declare just one
        // profile and still build a valid registry. This also makes that
        // profile the fallback for any unrecognised `subagent_type`.
        let id = profile.id.clone();
        Arc::new(
            SubagentProfileRegistry::builder()
                .extend(vec![profile])
                .fallback_id(id)
                .build()
                .expect("registry build"),
        )
    }

    fn profile(id: &str, tools: ToolPolicy) -> SubagentProfile {
        SubagentProfile {
            id: id.to_string(),
            display_name: id.to_string(),
            description: String::new(),
            system_prompt: "p".to_string(),
            tools,
            model_hint: None,
            default_responsibility: None,
            ui: Default::default(),
        }
    }

    /// `researcher` profile whose allowlist contains only `Read`.
    fn read_only_researcher() -> SubagentProfile {
        profile(
            "researcher",
            ToolPolicy::Allowlist {
                allow: vec!["Read".to_string()],
            },
        )
    }

    /// `coder` profile whose denylist contains only `SubAgent`.
    fn coder_without_subagent() -> SubagentProfile {
        profile(
            "coder",
            ToolPolicy::Denylist {
                deny: vec!["SubAgent".to_string()],
            },
        )
    }

    async fn sessions_with(
        session_id: &str,
        subagent_type: Option<&str>,
    ) -> Arc<RwLock<HashMap<String, Session>>> {
        let mut map = HashMap::new();
        let mut session = Session::new_child(session_id, "root", "test-model", "Child");
        if let Some(t) = subagent_type {
            session
                .metadata
                .insert("subagent_type".to_string(), t.to_string());
        }
        map.insert(session_id.to_string(), session);
        Arc::new(RwLock::new(map))
    }

    /// Builds a wrapper over a fresh [`RecordingExecutor`].
    ///
    /// The registry holds only `profile`, and the session cache holds only
    /// `cached_session`, tagged with `subagent_type` when given. Returns the
    /// wrapper and the inner executor's call log.
    async fn wrapper_for(
        profile: SubagentProfile,
        cached_session: &str,
        subagent_type: Option<&str>,
    ) -> (PolicyAwareToolExecutor, CallLog) {
        let (inner, executed) = RecordingExecutor::new();
        let exec = PolicyAwareToolExecutor::new(
            inner,
            registry_with(profile),
            sessions_with(cached_session, subagent_type).await,
        );
        (exec, executed)
    }

    /// Extracts the message of a `ToolError::Execution`, panicking on any
    /// other variant.
    fn execution_message(err: ToolError) -> String {
        match err {
            ToolError::Execution(msg) => msg,
            other => panic!("expected ToolError::Execution, got {other:?}"),
        }
    }

    /// Snapshot of the names the inner executor has been asked to run.
    async fn recorded(log: &RwLock<Vec<String>>) -> Vec<String> {
        log.read().await.clone()
    }

    #[tokio::test]
    async fn inherit_policy_forwards_all_calls() {
        let (exec, executed) =
            wrapper_for(profile("test", ToolPolicy::Inherit), "s1", Some("test")).await;

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(recorded(&executed).await, ["Read"]);
    }

    #[tokio::test]
    async fn allowlist_permits_listed_tool() {
        let researcher = profile(
            "researcher",
            ToolPolicy::Allowlist {
                allow: vec!["Read".to_string(), "Grep".to_string()],
            },
        );
        let (exec, executed) = wrapper_for(researcher, "s1", Some("researcher")).await;

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(recorded(&executed).await, ["Read"]);
    }

    #[tokio::test]
    async fn allowlist_blocks_unlisted_tool() {
        let (exec, executed) = wrapper_for(read_only_researcher(), "s1", Some("researcher")).await;

        let err = exec
            .execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap_err();
        let msg = execution_message(err);
        assert!(msg.contains("Edit"), "msg should name tool: {msg}");
        assert!(
            msg.contains("researcher"),
            "msg should name subagent_type: {msg}"
        );
        assert!(msg.contains("allowlist"), "msg should name mode: {msg}");
        assert!(recorded(&executed).await.is_empty());
    }

    #[tokio::test]
    async fn allowlist_forwards_padded_name_without_rewriting_the_call() {
        // The name is trimmed only for the policy comparison; the inner
        // executor must still receive the original call verbatim.
        let (exec, executed) = wrapper_for(read_only_researcher(), "s1", Some("researcher")).await;

        exec.execute_with_context(&make_call(" Read "), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(recorded(&executed).await, [" Read "]);
    }

    #[tokio::test]
    async fn denylist_blocks_listed_tool() {
        let (exec, executed) = wrapper_for(coder_without_subagent(), "s1", Some("coder")).await;

        let err = exec
            .execute_with_context(&make_call("SubAgent"), ctx_for("s1"))
            .await
            .unwrap_err();
        let msg = execution_message(err);
        assert!(msg.contains("SubAgent"), "msg should name tool: {msg}");
        assert!(msg.contains("denylist"), "msg should name mode: {msg}");
        assert!(recorded(&executed).await.is_empty());
    }

    #[tokio::test]
    async fn whitespace_padded_tool_name_cannot_dodge_denylist() {
        let (exec, executed) = wrapper_for(coder_without_subagent(), "s1", Some("coder")).await;

        let err = exec
            .execute_with_context(&make_call("  SubAgent\t"), ctx_for("s1"))
            .await
            .unwrap_err();
        let msg = execution_message(err);
        assert!(msg.contains("denylist"), "msg should name mode: {msg}");
        assert!(recorded(&executed).await.is_empty());
    }

    #[tokio::test]
    async fn denylist_permits_unlisted_tool() {
        let (exec, executed) = wrapper_for(coder_without_subagent(), "s1", Some("coder")).await;

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(recorded(&executed).await, ["Read"]);
    }

    #[tokio::test]
    async fn missing_session_id_falls_through() {
        // No session_id in context → wrapper must not reject. This is the
        // "legacy / direct caller" path and must keep working.
        let (exec, executed) = wrapper_for(read_only_researcher(), "s1", Some("researcher")).await;

        exec.execute_with_context(&make_call("Edit"), ToolExecutionContext::none("call_1"))
            .await
            .unwrap();
        assert_eq!(recorded(&executed).await, ["Edit"]);
    }

    #[tokio::test]
    async fn unknown_session_falls_through() {
        // Cache contains a different session id than the one referenced.
        let (exec, executed) =
            wrapper_for(read_only_researcher(), "other", Some("researcher")).await;

        exec.execute_with_context(&make_call("Edit"), ctx_for("missing"))
            .await
            .unwrap();
        assert_eq!(recorded(&executed).await, ["Edit"]);
    }

    #[tokio::test]
    async fn missing_subagent_type_metadata_falls_through() {
        // Session is cached but has no subagent_type metadata.
        let (exec, executed) = wrapper_for(read_only_researcher(), "s1", None).await;

        exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(recorded(&executed).await, ["Edit"]);
    }

    #[tokio::test]
    async fn blank_subagent_type_metadata_falls_through() {
        // Whitespace-only metadata is treated exactly like a missing key.
        let (exec, executed) = wrapper_for(read_only_researcher(), "s1", Some("   ")).await;

        exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(recorded(&executed).await, ["Edit"]);
    }

    #[tokio::test]
    async fn unknown_subagent_type_uses_fallback_profile_policy() {
        // `registry_with` makes the single profile its own fallback, so an
        // unrecognised type must inherit the researcher allowlist.
        let (exec, executed) = wrapper_for(read_only_researcher(), "s1", Some("mystery")).await;

        let err = exec
            .execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap_err();
        let msg = execution_message(err);
        assert!(
            msg.contains("mystery"),
            "msg should name the requested subagent_type: {msg}"
        );
        assert!(msg.contains("allowlist"), "msg should name mode: {msg}");
        assert!(recorded(&executed).await.is_empty());
    }

    #[tokio::test]
    async fn execute_without_context_forwards() {
        // The plain `execute` path has no session context. We document this
        // by forwarding unconditionally.
        let (exec, executed) = wrapper_for(read_only_researcher(), "s1", Some("researcher")).await;

        exec.execute(&make_call("Edit")).await.unwrap();
        assert_eq!(recorded(&executed).await, ["Edit"]);
    }
}