1use std::sync::Arc;
52
53use async_trait::async_trait;
54use bamboo_agent_core::tools::{
55 ToolCall, ToolError, ToolExecutionContext, ToolExecutor, ToolResult, ToolSchema,
56};
57use bamboo_agent_core::Session;
58use bamboo_domain::subagent::{SubagentProfileRegistry, ToolPolicy};
59
/// Shared, concurrently accessible map of in-memory sessions keyed by session id.
///
/// [`PolicyAwareToolExecutor`] looks a call's session up here to read its
/// `subagent_type`. Each session is held behind its own `Arc`, which lets a
/// reader clone the handle out of the map (releasing the map entry) before
/// taking the session's read lock.
pub type SessionCache = Arc<dashmap::DashMap<String, Arc<parking_lot::RwLock<Session>>>>;
67
/// [`ToolExecutor`] decorator that enforces per-subagent tool policies.
///
/// For calls made through `execute_with_context` with a session id, the
/// session's `subagent_type` is read from `sessions` and resolved to a profile
/// through `profiles`. The call is then rejected if that profile's
/// [`ToolPolicy`] forbids the tool. Every other call and every metadata query
/// (tool list, mutability, concurrency safety) is delegated to `inner` unchanged.
pub struct PolicyAwareToolExecutor {
    /// Executor that actually runs the calls the policy permits.
    inner: Arc<dyn ToolExecutor>,
    /// Registry used to resolve a `subagent_type` into a subagent profile.
    profiles: Arc<SubagentProfileRegistry>,
    /// Session cache consulted to find the `subagent_type` of a call's session.
    sessions: SessionCache,
}
75
76impl PolicyAwareToolExecutor {
77 pub fn new(
78 inner: Arc<dyn ToolExecutor>,
79 profiles: Arc<SubagentProfileRegistry>,
80 sessions: SessionCache,
81 ) -> Self {
82 Self {
83 inner,
84 profiles,
85 sessions,
86 }
87 }
88
89 async fn subagent_type_for_session(&self, session_id: &str) -> Option<String> {
93 let arc = self.sessions.get(session_id).map(|e| e.value().clone())?;
94 let value = arc.read().subagent_type()?;
95 let trimmed = value.trim();
96 if trimmed.is_empty() {
97 None
98 } else {
99 Some(trimmed.to_string())
100 }
101 }
102
103 fn check_policy(
106 policy: &ToolPolicy,
107 tool_name: &str,
108 subagent_type: &str,
109 ) -> Result<(), String> {
110 match policy {
111 ToolPolicy::Inherit => Ok(()),
112 ToolPolicy::Allowlist { allow } => {
113 if allow.iter().any(|t| t == tool_name) {
114 Ok(())
115 } else {
116 Err(format!(
117 "tool '{tool_name}' is not permitted for subagent_type \
118 '{subagent_type}' (allowlist policy: {allow:?})"
119 ))
120 }
121 }
122 ToolPolicy::Denylist { deny } => {
123 if deny.iter().any(|t| t == tool_name) {
124 Err(format!(
125 "tool '{tool_name}' is denied for subagent_type \
126 '{subagent_type}' (denylist policy: {deny:?})"
127 ))
128 } else {
129 Ok(())
130 }
131 }
132 }
133 }
134
135 async fn evaluate(&self, call: &ToolCall, session_id: Option<&str>) -> Result<(), String> {
140 let Some(session_id) = session_id else {
141 return Ok(());
142 };
143 let Some(subagent_type) = self.subagent_type_for_session(session_id).await else {
144 return Ok(());
145 };
146 let profile = self.profiles.resolve(&subagent_type);
147 Self::check_policy(&profile.tools, call.function.name.trim(), &subagent_type)
148 }
149}
150
151#[async_trait]
152impl ToolExecutor for PolicyAwareToolExecutor {
153 async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
154 self.inner.execute(call).await
159 }
160
161 async fn execute_with_context(
162 &self,
163 call: &ToolCall,
164 ctx: ToolExecutionContext<'_>,
165 ) -> std::result::Result<ToolResult, ToolError> {
166 if let Err(reason) = self.evaluate(call, ctx.session_id).await {
167 return Err(ToolError::Execution(reason));
168 }
169 self.inner.execute_with_context(call, ctx).await
170 }
171
172 fn list_tools(&self) -> Vec<ToolSchema> {
173 self.inner.list_tools()
175 }
176
177 fn tool_mutability(&self, tool_name: &str) -> bamboo_agent_core::tools::ToolMutability {
178 self.inner.tool_mutability(tool_name)
179 }
180
181 fn call_mutability(&self, call: &ToolCall) -> bamboo_agent_core::tools::ToolMutability {
182 self.inner.call_mutability(call)
183 }
184
185 fn tool_concurrency_safe(&self, tool_name: &str) -> bool {
186 self.inner.tool_concurrency_safe(tool_name)
187 }
188
189 fn call_concurrency_safe(&self, call: &ToolCall) -> bool {
190 self.inner.call_concurrency_safe(call)
191 }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::tools::{FunctionCall, ToolMutability};
    use bamboo_domain::subagent::SubagentProfile;
    use tokio::sync::RwLock;

    /// Inner executor that logs the name of every call it executes, so tests
    /// can observe whether the policy layer forwarded or blocked a call.
    struct RecordingExecutor {
        executed: Arc<RwLock<Vec<String>>>,
    }

    impl RecordingExecutor {
        /// Returns the executor together with a shared handle to its call log.
        fn new() -> (Arc<Self>, Arc<RwLock<Vec<String>>>) {
            let executed = Arc::new(RwLock::new(Vec::new()));
            let exec = Arc::new(Self {
                executed: executed.clone(),
            });
            (exec, executed)
        }
    }

    #[async_trait]
    impl ToolExecutor for RecordingExecutor {
        async fn execute(&self, call: &ToolCall) -> std::result::Result<ToolResult, ToolError> {
            self.executed.write().await.push(call.function.name.clone());
            Ok(ToolResult {
                success: true,
                result: "ok".to_string(),
                display_preference: None,
                images: Vec::new(),
            })
        }

        fn list_tools(&self) -> Vec<ToolSchema> {
            Vec::new()
        }

        fn tool_mutability(&self, _tool_name: &str) -> ToolMutability {
            ToolMutability::ReadOnly
        }
    }

    /// Builds a function-style tool call named `name` with empty JSON arguments.
    fn make_call(name: &str) -> ToolCall {
        ToolCall {
            id: "call_1".to_string(),
            tool_type: "function".to_string(),
            function: FunctionCall {
                name: name.to_string(),
                arguments: "{}".to_string(),
            },
        }
    }

    /// Execution context for a call issued from session `session_id`.
    fn ctx_for(session_id: &str) -> ToolExecutionContext<'_> {
        ToolExecutionContext {
            session_id: Some(session_id),
            tool_call_id: "call_1",
            event_tx: None,
            available_tool_schemas: None,
        }
    }

    /// Registry containing only `profile`, which is also set as the fallback.
    fn registry_with(profile: SubagentProfile) -> Arc<SubagentProfileRegistry> {
        let id = profile.id.clone();
        Arc::new(
            SubagentProfileRegistry::builder()
                .extend(vec![profile])
                .fallback_id(id)
                .build()
                .expect("registry build"),
        )
    }

    /// Minimal profile `id` using the given tool policy.
    fn profile(id: &str, tools: ToolPolicy) -> SubagentProfile {
        SubagentProfile {
            id: id.to_string(),
            display_name: id.to_string(),
            description: String::new(),
            system_prompt: "p".to_string(),
            tools,
            model_hint: None,
            default_responsibility: None,
            ui: Default::default(),
        }
    }

    /// Session cache holding a single child session `session_id`, optionally
    /// tagged with a `subagent_type` metadata entry.
    fn sessions_with(session_id: &str, subagent_type: Option<&str>) -> SessionCache {
        let map = dashmap::DashMap::new();
        let mut session = Session::new_child(session_id, "root", "test-model", "Child");
        if let Some(t) = subagent_type {
            session
                .metadata
                .insert("subagent_type".to_string(), t.to_string());
        }
        map.insert(
            session_id.to_string(),
            Arc::new(parking_lot::RwLock::new(session)),
        );
        Arc::new(map)
    }

    #[tokio::test]
    async fn inherit_policy_forwards_all_calls() {
        let (inner, executed) = RecordingExecutor::new();
        let registry = registry_with(profile("test", ToolPolicy::Inherit));
        let sessions = sessions_with("s1", Some("test"));
        let exec = PolicyAwareToolExecutor::new(inner, registry, sessions);

        let call = make_call("Read");
        exec.execute_with_context(&call, ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Read".to_string()]);
    }

    #[tokio::test]
    async fn allowlist_permits_listed_tool() {
        let (inner, executed) = RecordingExecutor::new();
        let registry = registry_with(profile(
            "researcher",
            ToolPolicy::Allowlist {
                allow: vec!["Read".to_string(), "Grep".to_string()],
            },
        ));
        let sessions = sessions_with("s1", Some("researcher"));
        let exec = PolicyAwareToolExecutor::new(inner, registry, sessions);

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Read".to_string()]);
    }

    #[tokio::test]
    async fn allowlist_blocks_unlisted_tool() {
        let (inner, executed) = RecordingExecutor::new();
        let registry = registry_with(profile(
            "researcher",
            ToolPolicy::Allowlist {
                allow: vec!["Read".to_string()],
            },
        ));
        let sessions = sessions_with("s1", Some("researcher"));
        let exec = PolicyAwareToolExecutor::new(inner, registry, sessions);

        let err = exec
            .execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap_err();
        match err {
            ToolError::Execution(msg) => {
                assert!(msg.contains("Edit"), "msg should name tool: {msg}");
                assert!(
                    msg.contains("researcher"),
                    "msg should name subagent_type: {msg}"
                );
                assert!(msg.contains("allowlist"), "msg should name mode: {msg}");
            }
            other => panic!("expected ToolError::Execution, got {other:?}"),
        }
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn denylist_blocks_listed_tool() {
        let (inner, executed) = RecordingExecutor::new();
        let registry = registry_with(profile(
            "coder",
            ToolPolicy::Denylist {
                deny: vec!["SubAgent".to_string()],
            },
        ));
        let sessions = sessions_with("s1", Some("coder"));
        let exec = PolicyAwareToolExecutor::new(inner, registry, sessions);

        let err = exec
            .execute_with_context(&make_call("SubAgent"), ctx_for("s1"))
            .await
            .unwrap_err();
        match err {
            ToolError::Execution(msg) => {
                assert!(msg.contains("SubAgent"));
                assert!(msg.contains("denylist"));
            }
            other => panic!("expected ToolError::Execution, got {other:?}"),
        }
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn denylist_permits_unlisted_tool() {
        let (inner, executed) = RecordingExecutor::new();
        let registry = registry_with(profile(
            "coder",
            ToolPolicy::Denylist {
                deny: vec!["SubAgent".to_string()],
            },
        ));
        let sessions = sessions_with("s1", Some("coder"));
        let exec = PolicyAwareToolExecutor::new(inner, registry, sessions);

        exec.execute_with_context(&make_call("Read"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Read".to_string()]);
    }

    #[tokio::test]
    async fn padded_tool_name_cannot_bypass_denylist() {
        let (inner, executed) = RecordingExecutor::new();
        let registry = registry_with(profile(
            "coder",
            ToolPolicy::Denylist {
                deny: vec!["SubAgent".to_string()],
            },
        ));
        let sessions = sessions_with("s1", Some("coder"));
        let exec = PolicyAwareToolExecutor::new(inner, registry, sessions);

        // The policy check trims the tool name, so surrounding whitespace must
        // not let a denied tool slip through.
        let err = exec
            .execute_with_context(&make_call("  SubAgent "), ctx_for("s1"))
            .await
            .unwrap_err();
        assert!(
            matches!(err, ToolError::Execution(_)),
            "expected ToolError::Execution, got {err:?}"
        );
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn missing_session_id_falls_through() {
        let (inner, executed) = RecordingExecutor::new();
        // Even a restrictive profile is not applied when the call carries no
        // session id: there is nothing to look the subagent type up by.
        let registry = registry_with(profile(
            "researcher",
            ToolPolicy::Allowlist {
                allow: vec!["Read".to_string()],
            },
        ));
        let sessions = sessions_with("s1", Some("researcher"));
        let exec = PolicyAwareToolExecutor::new(inner, registry, sessions);

        let ctx = ToolExecutionContext::none("call_1");
        exec.execute_with_context(&make_call("Edit"), ctx)
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn unknown_session_falls_through() {
        let (inner, executed) = RecordingExecutor::new();
        let registry = registry_with(profile(
            "researcher",
            ToolPolicy::Allowlist {
                allow: vec!["Read".to_string()],
            },
        ));
        // The cache only knows "other", so the call's session is not found.
        let sessions = sessions_with("other", Some("researcher"));
        let exec = PolicyAwareToolExecutor::new(inner, registry, sessions);

        exec.execute_with_context(&make_call("Edit"), ctx_for("missing"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn missing_subagent_type_metadata_falls_through() {
        let (inner, executed) = RecordingExecutor::new();
        let registry = registry_with(profile(
            "researcher",
            ToolPolicy::Allowlist {
                allow: vec!["Read".to_string()],
            },
        ));
        // The session exists but has no subagent_type metadata entry.
        let sessions = sessions_with("s1", None);
        let exec = PolicyAwareToolExecutor::new(inner, registry, sessions);

        exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn blank_subagent_type_falls_through() {
        let (inner, executed) = RecordingExecutor::new();
        let registry = registry_with(profile(
            "researcher",
            ToolPolicy::Allowlist {
                allow: vec!["Read".to_string()],
            },
        ));
        // A whitespace-only subagent_type is treated as absent, so the
        // restrictive fallback profile is not applied.
        let sessions = sessions_with("s1", Some("   "));
        let exec = PolicyAwareToolExecutor::new(inner, registry, sessions);

        exec.execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }

    #[tokio::test]
    async fn padded_subagent_type_is_trimmed() {
        let (inner, executed) = RecordingExecutor::new();
        let registry = registry_with(profile(
            "researcher",
            ToolPolicy::Allowlist {
                allow: vec!["Read".to_string()],
            },
        ));
        let sessions = sessions_with("s1", Some("  researcher  "));
        let exec = PolicyAwareToolExecutor::new(inner, registry, sessions);

        let err = exec
            .execute_with_context(&make_call("Edit"), ctx_for("s1"))
            .await
            .unwrap_err();
        match err {
            ToolError::Execution(msg) => assert!(
                msg.contains("'researcher'"),
                "subagent_type should be reported trimmed: {msg}"
            ),
            other => panic!("expected ToolError::Execution, got {other:?}"),
        }
        assert!(executed.read().await.is_empty());
    }

    #[tokio::test]
    async fn execute_without_context_forwards() {
        let (inner, executed) = RecordingExecutor::new();
        // The context-less entry point has no session id, so it is forwarded
        // without consulting the (restrictive) profile.
        let registry = registry_with(profile(
            "researcher",
            ToolPolicy::Allowlist {
                allow: vec!["Read".to_string()],
            },
        ));
        let sessions = sessions_with("s1", Some("researcher"));
        let exec = PolicyAwareToolExecutor::new(inner, registry, sessions);

        exec.execute(&make_call("Edit")).await.unwrap();
        assert_eq!(executed.read().await.as_slice(), &["Edit".to_string()]);
    }
}