1use std::collections::HashMap;
52use std::sync::Arc;
53
54use async_trait::async_trait;
55use bamboo_agent_core::tools::{
56 ToolCall, ToolError, ToolExecutionContext, ToolExecutor, ToolResult, ToolSchema,
57};
58use bamboo_agent_core::Session;
59use bamboo_domain::subagent::{SubagentProfileRegistry, ToolPolicy};
60use tokio::sync::RwLock;
61
62pub struct PolicyAwareToolExecutor {
65 inner: Arc<dyn ToolExecutor>,
66 profiles: Arc<SubagentProfileRegistry>,
67 sessions: Arc<RwLock<HashMap<String, Session>>>,
68}
69
70impl PolicyAwareToolExecutor {
71 pub fn new(
72 inner: Arc<dyn ToolExecutor>,
73 profiles: Arc<SubagentProfileRegistry>,
74 sessions: Arc<RwLock<HashMap<String, Session>>>,
75 ) -> Self {
76 Self {
77 inner,
78 profiles,
79 sessions,
80 }
81 }
82
83 async fn subagent_type_for_session(&self, session_id: &str) -> Option<String> {
87 let sessions = self.sessions.read().await;
88 let value = sessions
89 .get(session_id)?
90 .metadata
91 .get("subagent_type")?
92 .trim();
93 if value.is_empty() {
94 None
95 } else {
96 Some(value.to_string())
97 }
98 }
99
100 fn check_policy(
103 policy: &ToolPolicy,
104 tool_name: &str,
105 subagent_type: &str,
106 ) -> Result<(), String> {
107 match policy {
108 ToolPolicy::Inherit => Ok(()),
109 ToolPolicy::Allowlist { allow } => {
110 if allow.iter().any(|t| t == tool_name) {
111 Ok(())
112 } else {
113 Err(format!(
114 "tool '{tool_name}' is not permitted for subagent_type \
115 '{subagent_type}' (allowlist policy: {allow:?})"
116 ))
117 }
118 }
119 ToolPolicy::Denylist { deny } => {
120 if deny.iter().any(|t| t == tool_name) {
121 Err(format!(
122 "tool '{tool_name}' is denied for subagent_type \
123 '{subagent_type}' (denylist policy: {deny:?})"
124 ))
125 } else {
126 Ok(())
127 }
128 }
129 }
130 }
131
132 async fn evaluate(&self, call: &ToolCall, session_id: Option<&str>) -> Result<(), String> {
137 let Some(session_id) = session_id else {
138 return Ok(());
139 };
140 let Some(subagent_type) = self.subagent_type_for_session(session_id).await else {
141 return Ok(());
142 };
143 let profile = self.profiles.resolve(&subagent_type);
144 Self::check_policy(&profile.tools, call.function.name.trim(), &subagent_type)
145 }
146}
147
148#[async_trait]
149impl ToolExecutor for PolicyAwareToolExecutor {
150 async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
151 self.inner.execute(call).await
156 }
157
158 async fn execute_with_context(
159 &self,
160 call: &ToolCall,
161 ctx: ToolExecutionContext<'_>,
162 ) -> std::result::Result<ToolResult, ToolError> {
163 if let Err(reason) = self.evaluate(call, ctx.session_id).await {
164 return Err(ToolError::Execution(reason));
165 }
166 self.inner.execute_with_context(call, ctx).await
167 }
168
169 fn list_tools(&self) -> Vec<ToolSchema> {
170 self.inner.list_tools()
172 }
173
174 fn tool_mutability(&self, tool_name: &str) -> bamboo_agent_core::tools::ToolMutability {
175 self.inner.tool_mutability(tool_name)
176 }
177
178 fn call_mutability(&self, call: &ToolCall) -> bamboo_agent_core::tools::ToolMutability {
179 self.inner.call_mutability(call)
180 }
181
182 fn tool_concurrency_safe(&self, tool_name: &str) -> bool {
183 self.inner.tool_concurrency_safe(tool_name)
184 }
185
186 fn call_concurrency_safe(&self, call: &ToolCall) -> bool {
187 self.inner.call_concurrency_safe(call)
188 }
189}
190
#[cfg(test)]
mod tests {
    // Each test wires a `PolicyAwareToolExecutor` around a recording stub and checks whether
    // the stub saw the call (allowed) or not (blocked).

    use super::*;
    use bamboo_agent_core::tools::{FunctionCall, ToolMutability};
    use bamboo_domain::subagent::SubagentProfile;

    /// Inner executor stub that records the name of every tool call it receives.
    struct RecordingExecutor {
        executed: Arc<RwLock<Vec<String>>>,
    }

    impl RecordingExecutor {
        /// Returns the stub together with a shared handle to its (initially empty) call log.
        fn new() -> (Arc<Self>, Arc<RwLock<Vec<String>>>) {
            let executed = Arc::new(RwLock::new(Vec::new()));
            let exec = Arc::new(Self {
                executed: Arc::clone(&executed),
            });
            (exec, executed)
        }
    }

    #[async_trait]
    impl ToolExecutor for RecordingExecutor {
        async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
            self.executed.write().await.push(call.function.name.clone());
            Ok(ToolResult {
                success: true,
                result: "ok".to_string(),
                display_preference: None,
            })
        }

        fn list_tools(&self) -> Vec<ToolSchema> {
            Vec::new()
        }

        fn tool_mutability(&self, _tool_name: &str) -> ToolMutability {
            ToolMutability::ReadOnly
        }
    }

    /// Builds a tool call named `name` with empty JSON arguments.
    fn make_call(name: &str) -> ToolCall {
        ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: "{}".to_string(),
            },
        }
    }

    /// Builds an execution context attributed to `session_id`.
    fn ctx_for(session_id: &str) -> ToolExecutionContext<'_> {
        ToolExecutionContext {
            session_id: Some(session_id),
            tool_call_id: "call_1",
            event_tx: None,
            available_tool_schemas: None,
        }
    }

    /// Builds a registry whose only profile is `profile`, which is also the fallback.
    fn registry_with(profile: SubagentProfile) -> Arc<SubagentProfileRegistry> {
        let id = profile.id.clone();
        Arc::new(
            SubagentProfileRegistry::builder()
                .extend(vec![profile])
                .fallback_id(id)
                .build()
                .expect("registry build"),
        )
    }

    /// Builds a minimal profile with the given id and tool policy.
    fn profile(id: &str, tools: ToolPolicy) -> SubagentProfile {
        SubagentProfile {
            id: id.to_string(),
            display_name: id.to_string(),
            description: String::new(),
            system_prompt: "p".to_string(),
            tools,
            model_hint: None,
            default_responsibility: None,
            ui: Default::default(),
        }
    }

    /// Builds an allowlist policy from tool names.
    fn allowlist(tools: &[&str]) -> ToolPolicy {
        ToolPolicy::Allowlist {
            allow: tools.iter().copied().map(String::from).collect(),
        }
    }

    /// Builds a denylist policy from tool names.
    fn denylist(tools: &[&str]) -> ToolPolicy {
        ToolPolicy::Denylist {
            deny: tools.iter().copied().map(String::from).collect(),
        }
    }

    /// Builds a session map holding one child session, tagged with `subagent_type` if given.
    async fn sessions_with(
        session_id: &str,
        subagent_type: Option<&str>,
    ) -> Arc<RwLock<HashMap<String, Session>>> {
        let mut session = Session::new_child(session_id, "root", "test-model", "Child");
        if let Some(t) = subagent_type {
            session
                .metadata
                .insert("subagent_type".to_string(), t.to_string());
        }
        let mut map = HashMap::new();
        map.insert(session_id.to_string(), session);
        Arc::new(RwLock::new(map))
    }

    /// Wires a `PolicyAwareToolExecutor` around a fresh `RecordingExecutor`.
    ///
    /// The registry holds a single profile `profile_id` with policy `tools`, and the session
    /// map holds a single session `session_id` tagged with `subagent_type` (when provided).
    /// Returns the executor and the stub's call log.
    async fn harness(
        profile_id: &str,
        tools: ToolPolicy,
        session_id: &str,
        subagent_type: Option<&str>,
    ) -> (PolicyAwareToolExecutor, Arc<RwLock<Vec<String>>>) {
        let (inner, executed) = RecordingExecutor::new();
        let registry = registry_with(profile(profile_id, tools));
        let sessions = sessions_with(session_id, subagent_type).await;
        (
            PolicyAwareToolExecutor::new(inner, registry, sessions),
            executed,
        )
    }

    #[tokio::test]
    async fn inherit_policy_forwards_all_calls() {
        let (exec, executed) = harness("test", ToolPolicy::Inherit, "s1", Some("test")).await;
        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Read".to_string()]);
    }

    #[tokio::test]
    async fn allowlist_permits_listed_tool() {
        let (exec, executed) = harness(
            "researcher",
            allowlist(&["Read", "Grep"]),
            "s1",
            Some("researcher"),
        )
        .await;
        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Read".to_string()]);
    }

    #[tokio::test]
    async fn allowlist_blocks_unlisted_tool() {
        let (exec, executed) =
            harness("researcher", allowlist(&["Read"]), "s1", Some("researcher")).await;
        let err = exec
            .execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap_err();
        match err {
            ToolError::Execution(msg) => {
                assert!(msg.contains("Edit"), "msg should name tool: {msg}");
                assert!(
                    msg.contains("researcher"),
                    "msg should name subagent_type: {msg}"
                );
                assert!(msg.contains("allowlist"), "msg should name mode: {msg}");
            }
            other => panic!("expected ToolError::Execution, got {other:?}"),
        }
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn allowlist_matches_trimmed_call_name() {
        let (exec, executed) =
            harness("researcher", allowlist(&["Read"]), "s1", Some("researcher")).await;
        exec.execute_with_context(&make_call("  Read "), ctx_for("s1"))
            .await
            .unwrap();
        // The policy check trims the name, but the call itself is forwarded untouched.
        assert_eq!(executed.read().await.as_slice(), &["  Read ".to_string()]);
    }

    #[tokio::test]
    async fn empty_allowlist_blocks_every_tool() {
        let (exec, executed) = harness("locked", allowlist(&[]), "s1", Some("locked")).await;
        let err = exec
            .execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap_err();
        assert!(
            matches!(err, ToolError::Execution(_)),
            "expected ToolError::Execution, got {err:?}"
        );
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn denylist_blocks_listed_tool() {
        let (exec, executed) = harness("coder", denylist(&["SubAgent"]), "s1", Some("coder")).await;
        let err = exec
            .execute_with_context(&make_call("SubAgent"), ctx_for("s1"))
            .await
            .unwrap_err();
        match err {
            ToolError::Execution(msg) => {
                assert!(msg.contains("SubAgent"));
                assert!(msg.contains("denylist"));
            }
            other => panic!("expected ToolError::Execution, got {other:?}"),
        }
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn denylist_permits_unlisted_tool() {
        let (exec, executed) = harness("coder", denylist(&["SubAgent"]), "s1", Some("coder")).await;
        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Read".to_string()]);
    }

    #[tokio::test]
    async fn missing_session_id_falls_through() {
        // No session id in the context: the call cannot be attributed, so it is not checked.
        let (exec, executed) =
            harness("researcher", allowlist(&["Read"]), "s1", Some("researcher")).await;
        exec.execute_with_context(&make_call("Edit"), ToolExecutionContext::none("call_1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn unknown_session_falls_through() {
        // The context names a session that is not in the map.
        let (exec, executed) = harness(
            "researcher",
            allowlist(&["Read"]),
            "other",
            Some("researcher"),
        )
        .await;
        exec.execute_with_context(&make_call("Edit"), ctx_for("missing"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn missing_subagent_type_metadata_falls_through() {
        // The session exists but carries no `subagent_type` metadata.
        let (exec, executed) = harness("researcher", allowlist(&["Read"]), "s1", None).await;
        exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn blank_subagent_type_metadata_falls_through() {
        // Whitespace-only metadata is treated the same as missing metadata.
        let (exec, executed) =
            harness("researcher", allowlist(&["Read"]), "s1", Some("   ")).await;
        exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn execute_without_context_forwards() {
        // Plain `execute` carries no session id, so no policy is applied.
        let (exec, executed) =
            harness("researcher", allowlist(&["Read"]), "s1", Some("researcher")).await;
        exec.execute(&make_call("Edit")).await.unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }
}