Skip to main content

bamboo_server/tools/
policy_aware.rs

1//! `PolicyAwareToolExecutor`: enforces a subagent profile's [`ToolPolicy`]
2//! at tool-call time.
3//!
4//! This is the runtime half of the SubagentProfile feature. Profile loading
5//! and prompt injection live in `subagent_profiles` and
6//! `session_app::child_session`; here we wrap the child-session
7//! [`ToolExecutor`] so that tool calls are filtered against the calling
8//! child's `subagent_type` metadata.
9//!
10//! ## Why a wrapper instead of changing `SpawnContext`
11//!
12//! `bamboo-engine`'s `SpawnContext` carries a single `Arc<dyn ToolExecutor>`
13//! shared by every child session. Different child sessions can have
14//! different `subagent_type`s and therefore different
15//! [`ToolPolicy`]s. Rather than changing the engine to dispatch executors
16//! per session (which would ripple into `SpawnContext`, `SpawnScheduler`
17//! and `ScheduleManager`), this wrapper sits in front of the existing
18//! executor and resolves the policy on the fly using
19//! [`ToolExecutionContext::session_id`] and the in-memory sessions cache
20//! that the server already maintains.
21//!
22//! ## Behaviour summary
23//!
24//! - Tool _execution_ is filtered:
25//!   - [`ToolPolicy::Inherit`] → forwarded unchanged.
26//!   - [`ToolPolicy::Allowlist`] → forwarded only if the tool name is on
27//!     the allow list; otherwise rejected with a clear
28//!     `ToolError::Execution` message.
29//!   - [`ToolPolicy::Denylist`] → rejected if the tool name is on the
30//!     deny list; otherwise forwarded.
31//! - Tool _discovery_ (`list_tools`) is **not** filtered here. The
32//!   advertised tool surface is per-session and is filtered upstream via
33//!   the engine's `disabled_tools` mechanism (a later PR will wire the
34//!   profile policy into that list as well). Filtering at the wrapper
35//!   level would over-restrict callers that share the wrapper but are not
36//!   actually subject to a child policy.
37//! - When the wrapper cannot determine a policy (no `session_id`, the
38//!   session is not in cache, no `subagent_type` metadata, or the
39//!   profile is unknown), it forwards unchanged. This keeps the change
40//!   strictly additive: any existing call path that has not yet adopted
41//!   subagent profiles continues to behave exactly as before.
42
43use std::collections::HashMap;
44use std::sync::Arc;
45
46use async_trait::async_trait;
47use bamboo_agent_core::tools::{
48    ToolCall, ToolError, ToolExecutionContext, ToolExecutor, ToolResult, ToolSchema,
49};
50use bamboo_agent_core::Session;
51use bamboo_domain::subagent::{SubagentProfileRegistry, ToolPolicy};
52use tokio::sync::RwLock;
53
54/// Tool executor that enforces a subagent profile's [`ToolPolicy`] when
55/// executing tool calls from a child session.
56pub struct PolicyAwareToolExecutor {
57    inner: Arc<dyn ToolExecutor>,
58    profiles: Arc<SubagentProfileRegistry>,
59    sessions: Arc<RwLock<HashMap<String, Session>>>,
60}
61
62impl PolicyAwareToolExecutor {
63    pub fn new(
64        inner: Arc<dyn ToolExecutor>,
65        profiles: Arc<SubagentProfileRegistry>,
66        sessions: Arc<RwLock<HashMap<String, Session>>>,
67    ) -> Self {
68        Self {
69            inner,
70            profiles,
71            sessions,
72        }
73    }
74
75    /// Look up the `subagent_type` metadata for a session id from the
76    /// in-memory cache. Returns `None` when the session is not cached or
77    /// the metadata key is missing / blank.
78    async fn subagent_type_for_session(&self, session_id: &str) -> Option<String> {
79        let sessions = self.sessions.read().await;
80        let value = sessions
81            .get(session_id)?
82            .metadata
83            .get("subagent_type")?
84            .trim();
85        if value.is_empty() {
86            None
87        } else {
88            Some(value.to_string())
89        }
90    }
91
92    /// Check whether a tool call is permitted under the given policy.
93    /// Returns `Ok(())` when allowed, `Err(reason)` when blocked.
94    fn check_policy(
95        policy: &ToolPolicy,
96        tool_name: &str,
97        subagent_type: &str,
98    ) -> Result<(), String> {
99        match policy {
100            ToolPolicy::Inherit => Ok(()),
101            ToolPolicy::Allowlist { allow } => {
102                if allow.iter().any(|t| t == tool_name) {
103                    Ok(())
104                } else {
105                    Err(format!(
106                        "tool '{tool_name}' is not permitted for subagent_type \
107                         '{subagent_type}' (allowlist policy: {allow:?})"
108                    ))
109                }
110            }
111            ToolPolicy::Denylist { deny } => {
112                if deny.iter().any(|t| t == tool_name) {
113                    Err(format!(
114                        "tool '{tool_name}' is denied for subagent_type \
115                         '{subagent_type}' (denylist policy: {deny:?})"
116                    ))
117                } else {
118                    Ok(())
119                }
120            }
121        }
122    }
123
124    /// Resolve the policy for a tool call: returns `Ok(())` when the call
125    /// should proceed, or an `Err` payload to be surfaced as a
126    /// `ToolError::Execution`. When no policy can be resolved the call is
127    /// allowed (legacy / inherit behaviour).
128    async fn evaluate(&self, call: &ToolCall, session_id: Option<&str>) -> Result<(), String> {
129        let Some(session_id) = session_id else {
130            return Ok(());
131        };
132        let Some(subagent_type) = self.subagent_type_for_session(session_id).await else {
133            return Ok(());
134        };
135        let profile = self.profiles.resolve(&subagent_type);
136        Self::check_policy(&profile.tools, call.function.name.trim(), &subagent_type)
137    }
138}
139
140#[async_trait]
141impl ToolExecutor for PolicyAwareToolExecutor {
142    async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
143        // No session context available via plain `execute` — fall through.
144        // The agent loop calls `execute_with_context`; this branch is only
145        // hit by direct callers that don't carry session metadata, in
146        // which case we mirror legacy behaviour exactly.
147        self.inner.execute(call).await
148    }
149
150    async fn execute_with_context(
151        &self,
152        call: &ToolCall,
153        ctx: ToolExecutionContext<'_>,
154    ) -> std::result::Result<ToolResult, ToolError> {
155        if let Err(reason) = self.evaluate(call, ctx.session_id).await {
156            return Err(ToolError::Execution(reason));
157        }
158        self.inner.execute_with_context(call, ctx).await
159    }
160
161    fn list_tools(&self) -> Vec<ToolSchema> {
162        // Discovery is intentionally not filtered here; see module docs.
163        self.inner.list_tools()
164    }
165
166    fn tool_mutability(&self, tool_name: &str) -> bamboo_agent_core::tools::ToolMutability {
167        self.inner.tool_mutability(tool_name)
168    }
169
170    fn call_mutability(&self, call: &ToolCall) -> bamboo_agent_core::tools::ToolMutability {
171        self.inner.call_mutability(call)
172    }
173
174    fn tool_concurrency_safe(&self, tool_name: &str) -> bool {
175        self.inner.tool_concurrency_safe(tool_name)
176    }
177
178    fn call_concurrency_safe(&self, call: &ToolCall) -> bool {
179        self.inner.call_concurrency_safe(call)
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::tools::{FunctionCall, ToolMutability};
    use bamboo_domain::subagent::SubagentProfile;

    /// Shared log of the tool names the stub executor was asked to run.
    type ExecutionLog = Arc<RwLock<Vec<String>>>;

    /// Stub executor that appends every executed tool name to a shared log
    /// and always reports success with a fixed payload. It lets the tests
    /// distinguish "forwarded" from "blocked" without the real tool set.
    struct RecordingExecutor {
        executed: ExecutionLog,
    }

    impl RecordingExecutor {
        /// Create the stub together with a handle to its execution log.
        fn new() -> (Arc<Self>, ExecutionLog) {
            let executed: ExecutionLog = Arc::new(RwLock::new(Vec::new()));
            let exec = Arc::new(Self {
                executed: Arc::clone(&executed),
            });
            (exec, executed)
        }
    }

    #[async_trait]
    impl ToolExecutor for RecordingExecutor {
        async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
            let mut log = self.executed.write().await;
            log.push(call.function.name.clone());
            Ok(ToolResult {
                success: true,
                result: "ok".to_string(),
                display_preference: None,
            })
        }

        fn list_tools(&self) -> Vec<ToolSchema> {
            Vec::new()
        }

        fn tool_mutability(&self, _tool_name: &str) -> ToolMutability {
            ToolMutability::ReadOnly
        }
    }

    /// Build a function-style tool call with the given tool name.
    fn make_call(name: &str) -> ToolCall {
        ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: "{}".to_string(),
            },
        }
    }

    /// Build a registry holding exactly one profile. The registry needs its
    /// fallback id to exist among the profiles, so the profile's own id is
    /// used as the fallback.
    fn registry_with(profile: SubagentProfile) -> Arc<SubagentProfileRegistry> {
        let fallback = profile.id.clone();
        let registry = SubagentProfileRegistry::builder()
            .extend(vec![profile])
            .fallback_id(fallback)
            .build()
            .expect("registry build");
        Arc::new(registry)
    }

    /// Minimal profile with the given id and tool policy.
    fn profile(id: &str, tools: ToolPolicy) -> SubagentProfile {
        SubagentProfile {
            id: id.to_string(),
            display_name: id.to_string(),
            description: String::new(),
            system_prompt: "p".to_string(),
            tools,
            model_hint: None,
            default_responsibility: None,
            ui: Default::default(),
        }
    }

    /// Allowlist policy over the given tool names.
    fn allowlist(names: &[&str]) -> ToolPolicy {
        ToolPolicy::Allowlist {
            allow: names.iter().map(|name| name.to_string()).collect(),
        }
    }

    /// Denylist policy over the given tool names.
    fn denylist(names: &[&str]) -> ToolPolicy {
        ToolPolicy::Denylist {
            deny: names.iter().map(|name| name.to_string()).collect(),
        }
    }

    /// Session map containing one child session, optionally tagged with a
    /// `subagent_type` metadata entry.
    async fn sessions_with(
        session_id: &str,
        subagent_type: Option<&str>,
    ) -> Arc<RwLock<HashMap<String, Session>>> {
        let mut session = Session::new_child(session_id, "root", "test-model", "Child");
        if let Some(kind) = subagent_type {
            session
                .metadata
                .insert("subagent_type".to_string(), kind.to_string());
        }
        let mut map = HashMap::new();
        map.insert(session_id.to_string(), session);
        Arc::new(RwLock::new(map))
    }

    /// Wire a `PolicyAwareToolExecutor` around a fresh recording stub, a
    /// single-profile registry, and a map with one cached session.
    async fn executor_for(
        only_profile: SubagentProfile,
        cached_session: &str,
        subagent_type: Option<&str>,
    ) -> (PolicyAwareToolExecutor, ExecutionLog) {
        let (inner, executed) = RecordingExecutor::new();
        let registry = registry_with(only_profile);
        let sessions = sessions_with(cached_session, subagent_type).await;
        (
            PolicyAwareToolExecutor::new(inner, registry, sessions),
            executed,
        )
    }

    /// Execution context that names `session_id` as the calling session.
    fn ctx_for(session_id: &str) -> ToolExecutionContext<'_> {
        ToolExecutionContext {
            session_id: Some(session_id),
            tool_call_id: "call_1",
            event_tx: None,
            available_tool_schemas: None,
        }
    }

    /// Assert that exactly `expected` tool names reached the stub, in order.
    async fn assert_executed(executed: &RwLock<Vec<String>>, expected: &[&str]) {
        let seen = executed.read().await;
        assert_eq!(seen.as_slice(), expected);
    }

    /// Unwrap the message of a `ToolError::Execution`, failing otherwise.
    fn execution_message(err: ToolError) -> String {
        match err {
            ToolError::Execution(msg) => msg,
            other => panic!("expected ToolError::Execution, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn inherit_policy_forwards_all_calls() {
        let (exec, executed) =
            executor_for(profile("test", ToolPolicy::Inherit), "s1", Some("test")).await;

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_executed(&executed, &["Read"]).await;
    }

    #[tokio::test]
    async fn allowlist_permits_listed_tool() {
        let researcher = profile("researcher", allowlist(&["Read", "Grep"]));
        let (exec, executed) = executor_for(researcher, "s1", Some("researcher")).await;

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_executed(&executed, &["Read"]).await;
    }

    #[tokio::test]
    async fn allowlist_blocks_unlisted_tool() {
        let researcher = profile("researcher", allowlist(&["Read"]));
        let (exec, executed) = executor_for(researcher, "s1", Some("researcher")).await;

        let err = exec
            .execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap_err();
        let msg = execution_message(err);
        assert!(msg.contains("Edit"), "msg should name tool: {msg}");
        assert!(
            msg.contains("researcher"),
            "msg should name subagent_type: {msg}"
        );
        assert!(msg.contains("allowlist"), "msg should name mode: {msg}");
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn denylist_blocks_listed_tool() {
        let coder = profile("coder", denylist(&["SubSession"]));
        let (exec, executed) = executor_for(coder, "s1", Some("coder")).await;

        let err = exec
            .execute_with_context(&make_call("SubSession"), ctx_for("s1"))
            .await
            .unwrap_err();
        let msg = execution_message(err);
        assert!(msg.contains("SubSession"));
        assert!(msg.contains("denylist"));
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn denylist_permits_unlisted_tool() {
        let coder = profile("coder", denylist(&["SubSession"]));
        let (exec, executed) = executor_for(coder, "s1", Some("coder")).await;

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_executed(&executed, &["Read"]).await;
    }

    #[tokio::test]
    async fn missing_session_id_falls_through() {
        // A context without a session id is the legacy / direct-caller
        // path: the wrapper must forward even a tool the allowlist omits.
        let researcher = profile("researcher", allowlist(&["Read"]));
        let (exec, executed) = executor_for(researcher, "s1", Some("researcher")).await;

        exec.execute_with_context(&make_call("Edit"), ToolExecutionContext::none("call_1"))
            .await
            .unwrap();
        assert_executed(&executed, &["Edit"]).await;
    }

    #[tokio::test]
    async fn unknown_session_falls_through() {
        // The map holds session "other"; the call names session "missing".
        let researcher = profile("researcher", allowlist(&["Read"]));
        let (exec, executed) = executor_for(researcher, "other", Some("researcher")).await;

        exec.execute_with_context(&make_call("Edit"), ctx_for("missing"))
            .await
            .unwrap();
        assert_executed(&executed, &["Edit"]).await;
    }

    #[tokio::test]
    async fn missing_subagent_type_metadata_falls_through() {
        // The session is cached but carries no `subagent_type` entry.
        let researcher = profile("researcher", allowlist(&["Read"]));
        let (exec, executed) = executor_for(researcher, "s1", None).await;

        exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap();
        assert_executed(&executed, &["Edit"]).await;
    }

    #[tokio::test]
    async fn execute_without_context_forwards() {
        // Plain `execute` has no session context, so it forwards
        // unconditionally even for a tool the allowlist omits.
        let researcher = profile("researcher", allowlist(&["Read"]));
        let (exec, executed) = executor_for(researcher, "s1", Some("researcher")).await;

        exec.execute(&make_call("Edit")).await.unwrap();
        assert_executed(&executed, &["Edit"]).await;
    }
}