1use std::collections::HashMap;
43use std::sync::Arc;
44
45use async_trait::async_trait;
46use bamboo_agent_core::tools::{
47 ToolCall, ToolError, ToolExecutionContext, ToolExecutor, ToolResult, ToolSchema,
48};
49use bamboo_agent_core::Session;
50use bamboo_domain::subagent::{SubagentProfileRegistry, ToolPolicy};
51use tokio::sync::RwLock;
52
53pub struct PolicyAwareToolExecutor {
56 inner: Arc<dyn ToolExecutor>,
57 profiles: Arc<SubagentProfileRegistry>,
58 sessions: Arc<RwLock<HashMap<String, Session>>>,
59}
60
61impl PolicyAwareToolExecutor {
62 pub fn new(
63 inner: Arc<dyn ToolExecutor>,
64 profiles: Arc<SubagentProfileRegistry>,
65 sessions: Arc<RwLock<HashMap<String, Session>>>,
66 ) -> Self {
67 Self {
68 inner,
69 profiles,
70 sessions,
71 }
72 }
73
74 async fn subagent_type_for_session(&self, session_id: &str) -> Option<String> {
78 let sessions = self.sessions.read().await;
79 let value = sessions
80 .get(session_id)?
81 .metadata
82 .get("subagent_type")?
83 .trim();
84 if value.is_empty() {
85 None
86 } else {
87 Some(value.to_string())
88 }
89 }
90
91 fn check_policy(
94 policy: &ToolPolicy,
95 tool_name: &str,
96 subagent_type: &str,
97 ) -> Result<(), String> {
98 match policy {
99 ToolPolicy::Inherit => Ok(()),
100 ToolPolicy::Allowlist { allow } => {
101 if allow.iter().any(|t| t == tool_name) {
102 Ok(())
103 } else {
104 Err(format!(
105 "tool '{tool_name}' is not permitted for subagent_type \
106 '{subagent_type}' (allowlist policy: {allow:?})"
107 ))
108 }
109 }
110 ToolPolicy::Denylist { deny } => {
111 if deny.iter().any(|t| t == tool_name) {
112 Err(format!(
113 "tool '{tool_name}' is denied for subagent_type \
114 '{subagent_type}' (denylist policy: {deny:?})"
115 ))
116 } else {
117 Ok(())
118 }
119 }
120 }
121 }
122
123 async fn evaluate(&self, call: &ToolCall, session_id: Option<&str>) -> Result<(), String> {
128 let Some(session_id) = session_id else {
129 return Ok(());
130 };
131 let Some(subagent_type) = self.subagent_type_for_session(session_id).await else {
132 return Ok(());
133 };
134 let profile = self.profiles.resolve(&subagent_type);
135 Self::check_policy(&profile.tools, call.function.name.trim(), &subagent_type)
136 }
137}
138
139#[async_trait]
140impl ToolExecutor for PolicyAwareToolExecutor {
141 async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
142 self.inner.execute(call).await
147 }
148
149 async fn execute_with_context(
150 &self,
151 call: &ToolCall,
152 ctx: ToolExecutionContext<'_>,
153 ) -> std::result::Result<ToolResult, ToolError> {
154 if let Err(reason) = self.evaluate(call, ctx.session_id).await {
155 return Err(ToolError::Execution(reason));
156 }
157 self.inner.execute_with_context(call, ctx).await
158 }
159
160 fn list_tools(&self) -> Vec<ToolSchema> {
161 self.inner.list_tools()
163 }
164
165 fn tool_mutability(&self, tool_name: &str) -> bamboo_agent_core::tools::ToolMutability {
166 self.inner.tool_mutability(tool_name)
167 }
168
169 fn call_mutability(&self, call: &ToolCall) -> bamboo_agent_core::tools::ToolMutability {
170 self.inner.call_mutability(call)
171 }
172
173 fn tool_concurrency_safe(&self, tool_name: &str) -> bool {
174 self.inner.tool_concurrency_safe(tool_name)
175 }
176
177 fn call_concurrency_safe(&self, call: &ToolCall) -> bool {
178 self.inner.call_concurrency_safe(call)
179 }
180}
181
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::tools::{FunctionCall, ToolMutability};
    use bamboo_domain::subagent::SubagentProfile;

    /// Shared log of the tool names seen by the stub executor.
    type CallLog = Arc<RwLock<Vec<String>>>;

    /// Stub inner executor: records each call's tool name and always succeeds.
    struct RecordingExecutor {
        executed: CallLog,
    }

    impl RecordingExecutor {
        /// Returns the stub together with a handle to its call log.
        fn new() -> (Arc<Self>, CallLog) {
            let log: CallLog = Arc::new(RwLock::new(Vec::new()));
            let stub = Arc::new(Self {
                executed: Arc::clone(&log),
            });
            (stub, log)
        }
    }

    #[async_trait]
    impl ToolExecutor for RecordingExecutor {
        async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
            let mut log = self.executed.write().await;
            log.push(call.function.name.clone());
            Ok(ToolResult {
                success: true,
                result: "ok".to_string(),
                display_preference: None,
            })
        }

        fn list_tools(&self) -> Vec<ToolSchema> {
            Vec::new()
        }

        fn tool_mutability(&self, _tool_name: &str) -> ToolMutability {
            ToolMutability::ReadOnly
        }
    }

    /// Builds a function tool call named `name` with empty JSON arguments.
    fn make_call(name: &str) -> ToolCall {
        let function = FunctionCall {
            name: name.to_string(),
            arguments: "{}".to_string(),
        };
        ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function,
        }
    }

    /// Builds a registry holding only `profile`, which is also the fallback.
    fn registry_with(profile: SubagentProfile) -> Arc<SubagentProfileRegistry> {
        let fallback = profile.id.clone();
        let registry = SubagentProfileRegistry::builder()
            .extend(vec![profile])
            .fallback_id(fallback)
            .build()
            .expect("registry build");
        Arc::new(registry)
    }

    /// Builds a minimal profile with the given id and tool policy.
    fn profile(id: &str, tools: ToolPolicy) -> SubagentProfile {
        SubagentProfile {
            id: id.to_string(),
            display_name: id.to_string(),
            description: String::new(),
            system_prompt: "p".to_string(),
            tools,
            model_hint: None,
            default_responsibility: None,
            ui: Default::default(),
        }
    }

    /// Builds a session map with one child session, optionally tagged with a
    /// `subagent_type` metadata entry.
    async fn sessions_with(
        session_id: &str,
        subagent_type: Option<&str>,
    ) -> Arc<RwLock<HashMap<String, Session>>> {
        let mut session = Session::new_child(session_id, "root", "test-model", "Child");
        if let Some(kind) = subagent_type {
            session
                .metadata
                .insert("subagent_type".to_string(), kind.to_string());
        }
        let map = HashMap::from([(session_id.to_string(), session)]);
        Arc::new(RwLock::new(map))
    }

    /// Allowlist policy permitting exactly `tools`.
    fn allow_only(tools: &[&str]) -> ToolPolicy {
        ToolPolicy::Allowlist {
            allow: tools.iter().copied().map(String::from).collect(),
        }
    }

    /// Denylist policy rejecting exactly `tools`.
    fn deny_only(tools: &[&str]) -> ToolPolicy {
        ToolPolicy::Denylist {
            deny: tools.iter().copied().map(String::from).collect(),
        }
    }

    /// Execution context for `session_id` with the fixed call id `call_1`.
    fn ctx_for(session_id: &str) -> ToolExecutionContext<'_> {
        ToolExecutionContext {
            session_id: Some(session_id),
            tool_call_id: "call_1",
            event_tx: None,
            available_tool_schemas: None,
        }
    }

    /// Wires a recording stub, a single-profile registry and a single-session
    /// map into a `PolicyAwareToolExecutor`, returning it with the call log.
    async fn harness(
        profile_id: &str,
        policy: ToolPolicy,
        session_id: &str,
        subagent_type: Option<&str>,
    ) -> (PolicyAwareToolExecutor, CallLog) {
        let (inner, log) = RecordingExecutor::new();
        let registry = registry_with(profile(profile_id, policy));
        let sessions = sessions_with(session_id, subagent_type).await;
        (PolicyAwareToolExecutor::new(inner, registry, sessions), log)
    }

    /// Snapshot of the tool names the stub has executed so far.
    async fn recorded(log: &RwLock<Vec<String>>) -> Vec<String> {
        log.read().await.clone()
    }

    /// Unwraps the message of a `ToolError::Execution`, panicking otherwise.
    fn execution_message(err: ToolError) -> String {
        match err {
            ToolError::Execution(msg) => msg,
            other => panic!("expected ToolError::Execution, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn inherit_policy_forwards_all_calls() {
        let (exec, log) = harness("test", ToolPolicy::Inherit, "s1", Some("test")).await;

        let call = make_call("Read");
        exec.execute_with_context(&call, ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(recorded(&log).await, ["Read"]);
    }

    #[tokio::test]
    async fn allowlist_permits_listed_tool() {
        let policy = allow_only(&["Read", "Grep"]);
        let (exec, log) = harness("researcher", policy, "s1", Some("researcher")).await;

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(recorded(&log).await, ["Read"]);
    }

    #[tokio::test]
    async fn allowlist_blocks_unlisted_tool() {
        let policy = allow_only(&["Read"]);
        let (exec, log) = harness("researcher", policy, "s1", Some("researcher")).await;

        let err = exec
            .execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap_err();
        let msg = execution_message(err);
        assert!(msg.contains("Edit"), "msg should name tool: {msg}");
        assert!(
            msg.contains("researcher"),
            "msg should name subagent_type: {msg}"
        );
        assert!(msg.contains("allowlist"), "msg should name mode: {msg}");
        assert!(recorded(&log).await.is_empty());
    }

    #[tokio::test]
    async fn denylist_blocks_listed_tool() {
        let policy = deny_only(&["SubAgent"]);
        let (exec, log) = harness("coder", policy, "s1", Some("coder")).await;

        let err = exec
            .execute_with_context(&make_call("SubAgent"), ctx_for("s1"))
            .await
            .unwrap_err();
        let msg = execution_message(err);
        assert!(msg.contains("SubAgent"));
        assert!(msg.contains("denylist"));
        assert!(recorded(&log).await.is_empty());
    }

    #[tokio::test]
    async fn denylist_permits_unlisted_tool() {
        let policy = deny_only(&["SubAgent"]);
        let (exec, log) = harness("coder", policy, "s1", Some("coder")).await;

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(recorded(&log).await, ["Read"]);
    }

    #[tokio::test]
    async fn missing_session_id_falls_through() {
        let policy = allow_only(&["Read"]);
        let (exec, log) = harness("researcher", policy, "s1", Some("researcher")).await;

        // A context without a session id gives the wrapper nothing to look up.
        let ctx = ToolExecutionContext::none("call_1");
        exec.execute_with_context(&make_call("Edit"), ctx)
            .await
            .unwrap();
        assert_eq!(recorded(&log).await, ["Edit"]);
    }

    #[tokio::test]
    async fn unknown_session_falls_through() {
        let policy = allow_only(&["Read"]);
        // The only stored session is "other"; the call names "missing".
        let (exec, log) = harness("researcher", policy, "other", Some("researcher")).await;

        exec.execute_with_context(&make_call("Edit"), ctx_for("missing"))
            .await
            .unwrap();
        assert_eq!(recorded(&log).await, ["Edit"]);
    }

    #[tokio::test]
    async fn missing_subagent_type_metadata_falls_through() {
        let policy = allow_only(&["Read"]);
        // The session exists but carries no `subagent_type` entry.
        let (exec, log) = harness("researcher", policy, "s1", None).await;

        exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(recorded(&log).await, ["Edit"]);
    }

    #[tokio::test]
    async fn execute_without_context_forwards() {
        let policy = allow_only(&["Read"]);
        let (exec, log) = harness("researcher", policy, "s1", Some("researcher")).await;

        // The context-free entry point performs no policy check.
        exec.execute(&make_call("Edit")).await.unwrap();
        assert_eq!(recorded(&log).await, ["Edit"]);
    }
}