1use std::collections::HashMap;
44use std::sync::Arc;
45
46use async_trait::async_trait;
47use bamboo_agent_core::tools::{
48 ToolCall, ToolError, ToolExecutionContext, ToolExecutor, ToolResult, ToolSchema,
49};
50use bamboo_agent_core::Session;
51use bamboo_domain::subagent::{SubagentProfileRegistry, ToolPolicy};
52use tokio::sync::RwLock;
53
54pub struct PolicyAwareToolExecutor {
57 inner: Arc<dyn ToolExecutor>,
58 profiles: Arc<SubagentProfileRegistry>,
59 sessions: Arc<RwLock<HashMap<String, Session>>>,
60}
61
62impl PolicyAwareToolExecutor {
63 pub fn new(
64 inner: Arc<dyn ToolExecutor>,
65 profiles: Arc<SubagentProfileRegistry>,
66 sessions: Arc<RwLock<HashMap<String, Session>>>,
67 ) -> Self {
68 Self {
69 inner,
70 profiles,
71 sessions,
72 }
73 }
74
75 async fn subagent_type_for_session(&self, session_id: &str) -> Option<String> {
79 let sessions = self.sessions.read().await;
80 let value = sessions
81 .get(session_id)?
82 .metadata
83 .get("subagent_type")?
84 .trim();
85 if value.is_empty() {
86 None
87 } else {
88 Some(value.to_string())
89 }
90 }
91
92 fn check_policy(
95 policy: &ToolPolicy,
96 tool_name: &str,
97 subagent_type: &str,
98 ) -> Result<(), String> {
99 match policy {
100 ToolPolicy::Inherit => Ok(()),
101 ToolPolicy::Allowlist { allow } => {
102 if allow.iter().any(|t| t == tool_name) {
103 Ok(())
104 } else {
105 Err(format!(
106 "tool '{tool_name}' is not permitted for subagent_type \
107 '{subagent_type}' (allowlist policy: {allow:?})"
108 ))
109 }
110 }
111 ToolPolicy::Denylist { deny } => {
112 if deny.iter().any(|t| t == tool_name) {
113 Err(format!(
114 "tool '{tool_name}' is denied for subagent_type \
115 '{subagent_type}' (denylist policy: {deny:?})"
116 ))
117 } else {
118 Ok(())
119 }
120 }
121 }
122 }
123
124 async fn evaluate(&self, call: &ToolCall, session_id: Option<&str>) -> Result<(), String> {
129 let Some(session_id) = session_id else {
130 return Ok(());
131 };
132 let Some(subagent_type) = self.subagent_type_for_session(session_id).await else {
133 return Ok(());
134 };
135 let profile = self.profiles.resolve(&subagent_type);
136 Self::check_policy(&profile.tools, call.function.name.trim(), &subagent_type)
137 }
138}
139
140#[async_trait]
141impl ToolExecutor for PolicyAwareToolExecutor {
142 async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
143 self.inner.execute(call).await
148 }
149
150 async fn execute_with_context(
151 &self,
152 call: &ToolCall,
153 ctx: ToolExecutionContext<'_>,
154 ) -> std::result::Result<ToolResult, ToolError> {
155 if let Err(reason) = self.evaluate(call, ctx.session_id).await {
156 return Err(ToolError::Execution(reason));
157 }
158 self.inner.execute_with_context(call, ctx).await
159 }
160
161 fn list_tools(&self) -> Vec<ToolSchema> {
162 self.inner.list_tools()
164 }
165
166 fn tool_mutability(&self, tool_name: &str) -> bamboo_agent_core::tools::ToolMutability {
167 self.inner.tool_mutability(tool_name)
168 }
169
170 fn call_mutability(&self, call: &ToolCall) -> bamboo_agent_core::tools::ToolMutability {
171 self.inner.call_mutability(call)
172 }
173
174 fn tool_concurrency_safe(&self, tool_name: &str) -> bool {
175 self.inner.tool_concurrency_safe(tool_name)
176 }
177
178 fn call_concurrency_safe(&self, call: &ToolCall) -> bool {
179 self.inner.call_concurrency_safe(call)
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::tools::{FunctionCall, ToolMutability};
    use bamboo_domain::subagent::SubagentProfile;

    /// Shared record of the tool names the inner executor was asked to run.
    type CallLog = Arc<RwLock<Vec<String>>>;

    /// Inner executor double: records every executed tool name and always
    /// reports success.
    struct RecordingExecutor {
        executed: CallLog,
    }

    impl RecordingExecutor {
        /// Returns the executor together with a handle to its call log.
        fn new() -> (Arc<Self>, CallLog) {
            let log: CallLog = Arc::new(RwLock::new(Vec::new()));
            let exec = Arc::new(Self {
                executed: Arc::clone(&log),
            });
            (exec, log)
        }
    }

    #[async_trait]
    impl ToolExecutor for RecordingExecutor {
        async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
            let mut log = self.executed.write().await;
            log.push(call.function.name.clone());
            Ok(ToolResult {
                success: true,
                result: "ok".to_string(),
                display_preference: None,
            })
        }

        fn list_tools(&self) -> Vec<ToolSchema> {
            Vec::new()
        }

        fn tool_mutability(&self, _tool_name: &str) -> ToolMutability {
            ToolMutability::ReadOnly
        }
    }

    /// Builds a function tool call named `name` with empty JSON arguments.
    fn make_call(name: &str) -> ToolCall {
        let function = FunctionCall {
            name: name.to_string(),
            arguments: "{}".to_string(),
        };
        ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function,
        }
    }

    /// Registry holding only `profile`, which is also registered as the
    /// fallback id.
    fn registry_with(profile: SubagentProfile) -> Arc<SubagentProfileRegistry> {
        let fallback = profile.id.clone();
        let registry = SubagentProfileRegistry::builder()
            .extend(vec![profile])
            .fallback_id(fallback)
            .build()
            .expect("registry build");
        Arc::new(registry)
    }

    /// Minimal profile with the given id and tool policy.
    fn profile(id: &str, tools: ToolPolicy) -> SubagentProfile {
        SubagentProfile {
            id: id.to_string(),
            display_name: id.to_string(),
            description: String::new(),
            system_prompt: "p".to_string(),
            tools,
            model_hint: None,
            default_responsibility: None,
            ui: Default::default(),
        }
    }

    /// Session map with a single child session; `subagent_type`, when given,
    /// is stored in that session's metadata.
    async fn sessions_with(
        session_id: &str,
        subagent_type: Option<&str>,
    ) -> Arc<RwLock<HashMap<String, Session>>> {
        let mut session = Session::new_child(session_id, "root", "test-model", "Child");
        if let Some(kind) = subagent_type {
            session
                .metadata
                .insert("subagent_type".to_string(), kind.to_string());
        }
        let map = HashMap::from([(session_id.to_string(), session)]);
        Arc::new(RwLock::new(map))
    }

    /// Allowlist policy over the given tool names.
    fn allowlist(names: &[&str]) -> ToolPolicy {
        ToolPolicy::Allowlist {
            allow: names.iter().copied().map(String::from).collect(),
        }
    }

    /// Denylist policy over the given tool names.
    fn denylist(names: &[&str]) -> ToolPolicy {
        ToolPolicy::Denylist {
            deny: names.iter().copied().map(String::from).collect(),
        }
    }

    /// Execution context for `session_id` with the fixed call id `call_1`.
    fn ctx_for(session_id: &str) -> ToolExecutionContext<'_> {
        ToolExecutionContext {
            session_id: Some(session_id),
            tool_call_id: "call_1",
            event_tx: None,
            available_tool_schemas: None,
        }
    }

    /// Wraps a fresh `RecordingExecutor` behind a registry that contains only
    /// `subagent_profile`, returning the wrapper and the inner call log.
    fn policy_executor(
        subagent_profile: SubagentProfile,
        sessions: Arc<RwLock<HashMap<String, Session>>>,
    ) -> (PolicyAwareToolExecutor, CallLog) {
        let (inner, log) = RecordingExecutor::new();
        let exec = PolicyAwareToolExecutor::new(inner, registry_with(subagent_profile), sessions);
        (exec, log)
    }

    /// Snapshot of the tool names recorded so far.
    async fn recorded(log: &CallLog) -> Vec<String> {
        let guard = log.read().await;
        guard.clone()
    }

    /// Extracts the message of a `ToolError::Execution`, panicking on any
    /// other variant.
    fn denial_message(err: ToolError) -> String {
        match err {
            ToolError::Execution(msg) => msg,
            other => panic!("expected ToolError::Execution, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn inherit_policy_forwards_all_calls() {
        let sessions = sessions_with("s1", Some("test")).await;
        let (exec, log) = policy_executor(profile("test", ToolPolicy::Inherit), sessions);

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();

        assert_eq!(recorded(&log).await, ["Read"]);
    }

    #[tokio::test]
    async fn allowlist_permits_listed_tool() {
        let sessions = sessions_with("s1", Some("researcher")).await;
        let (exec, log) =
            policy_executor(profile("researcher", allowlist(&["Read", "Grep"])), sessions);

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();

        assert_eq!(recorded(&log).await, ["Read"]);
    }

    #[tokio::test]
    async fn allowlist_blocks_unlisted_tool() {
        let sessions = sessions_with("s1", Some("researcher")).await;
        let (exec, log) = policy_executor(profile("researcher", allowlist(&["Read"])), sessions);

        let err = exec
            .execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap_err();

        let msg = denial_message(err);
        assert!(msg.contains("Edit"), "msg should name tool: {msg}");
        assert!(
            msg.contains("researcher"),
            "msg should name subagent_type: {msg}"
        );
        assert!(msg.contains("allowlist"), "msg should name mode: {msg}");
        assert!(recorded(&log).await.is_empty());
    }

    #[tokio::test]
    async fn denylist_blocks_listed_tool() {
        let sessions = sessions_with("s1", Some("coder")).await;
        let (exec, log) = policy_executor(profile("coder", denylist(&["SubSession"])), sessions);

        let err = exec
            .execute_with_context(&make_call("SubSession"), ctx_for("s1"))
            .await
            .unwrap_err();

        let msg = denial_message(err);
        assert!(msg.contains("SubSession"));
        assert!(msg.contains("denylist"));
        assert!(recorded(&log).await.is_empty());
    }

    #[tokio::test]
    async fn denylist_permits_unlisted_tool() {
        let sessions = sessions_with("s1", Some("coder")).await;
        let (exec, log) = policy_executor(profile("coder", denylist(&["SubSession"])), sessions);

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();

        assert_eq!(recorded(&log).await, ["Read"]);
    }

    #[tokio::test]
    async fn missing_session_id_falls_through() {
        let sessions = sessions_with("s1", Some("researcher")).await;
        let (exec, log) = policy_executor(profile("researcher", allowlist(&["Read"])), sessions);

        // A context without a session id gives the wrapper nothing to look up.
        let ctx = ToolExecutionContext::none("call_1");
        exec.execute_with_context(&make_call("Edit"), ctx)
            .await
            .unwrap();

        assert_eq!(recorded(&log).await, ["Edit"]);
    }

    #[tokio::test]
    async fn unknown_session_falls_through() {
        // The map only knows "other"; the call arrives for "missing".
        let sessions = sessions_with("other", Some("researcher")).await;
        let (exec, log) = policy_executor(profile("researcher", allowlist(&["Read"])), sessions);

        exec.execute_with_context(&make_call("Edit"), ctx_for("missing"))
            .await
            .unwrap();

        assert_eq!(recorded(&log).await, ["Edit"]);
    }

    #[tokio::test]
    async fn missing_subagent_type_metadata_falls_through() {
        // The session exists but carries no `subagent_type` entry.
        let sessions = sessions_with("s1", None).await;
        let (exec, log) = policy_executor(profile("researcher", allowlist(&["Read"])), sessions);

        exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap();

        assert_eq!(recorded(&log).await, ["Edit"]);
    }

    #[tokio::test]
    async fn execute_without_context_forwards() {
        let sessions = sessions_with("s1", Some("researcher")).await;
        let (exec, log) = policy_executor(profile("researcher", allowlist(&["Read"])), sessions);

        // The context-free entry point forwards even a tool the allowlist omits.
        exec.execute(&make_call("Edit")).await.unwrap();

        assert_eq!(recorded(&log).await, ["Edit"]);
    }
}