// bamboo_server/server_tools/overlay_executor.rs

1use async_trait::async_trait;
2
3use bamboo_agent_core::tools::{
4    normalize_tool_name, parse_tool_args_best_effort, Tool, ToolCall, ToolError,
5    ToolExecutionContext, ToolExecutor, ToolResult, ToolSchema,
6};
7use bamboo_tools::normalize_tool_ref;
8
9/// Tool executor that overlays a single tool on top of an existing executor.
10///
11/// This is used to add server-only tools (like `SubSession`) without mutating the
12/// underlying built-in/MCP executor.
13pub struct OverlayToolExecutor {
14    base: std::sync::Arc<dyn ToolExecutor>,
15    overlay: std::sync::Arc<dyn Tool>,
16}
17
18impl OverlayToolExecutor {
19    pub fn new(base: std::sync::Arc<dyn ToolExecutor>, overlay: std::sync::Arc<dyn Tool>) -> Self {
20        Self { base, overlay }
21    }
22}
23
24#[async_trait]
25impl ToolExecutor for OverlayToolExecutor {
26    async fn execute(&self, call: &ToolCall) -> Result<ToolResult, ToolError> {
27        self.execute_with_context(call, ToolExecutionContext::none(&call.id))
28            .await
29    }
30
31    async fn execute_with_context(
32        &self,
33        call: &ToolCall,
34        ctx: ToolExecutionContext<'_>,
35    ) -> Result<ToolResult, ToolError> {
36        let name = normalize_tool_name(&call.function.name);
37        let is_overlay_call = name == self.overlay.name()
38            || normalize_tool_ref(name)
39                .as_deref()
40                .is_some_and(|normalized| normalized == self.overlay.name());
41        if is_overlay_call {
42            let args_raw = call.function.arguments.trim();
43            let (args, parse_warning) = parse_tool_args_best_effort(&call.function.arguments);
44            if let Some(warning) = parse_warning {
45                tracing::warn!(
46                    "Overlay tool argument parsing fallback applied: tool_call_id={}, tool_name={}, args_len={}, warning={}",
47                    call.id,
48                    call.function.name,
49                    args_raw.len(),
50                    warning
51                );
52            }
53            return self.overlay.execute_with_context(args, ctx).await;
54        }
55        self.base.execute_with_context(call, ctx).await
56    }
57
58    fn list_tools(&self) -> Vec<ToolSchema> {
59        let mut tools = self.base.list_tools();
60
61        // Ensure overlay tool is present exactly once.
62        let overlay_schema = self.overlay.to_schema();
63        let overlay_name = overlay_schema.function.name.clone();
64        tools.retain(|t| t.function.name != overlay_name);
65        tools.push(overlay_schema);
66
67        tools.sort_by_key(|t| t.function.name.clone());
68        tools
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;

    use serde_json::json;

    use bamboo_agent_core::tools::FunctionCall;

    /// Base executor stub: every call fails with an error naming the requested
    /// tool, which lets a test detect that the call reached the base executor.
    struct BaseExecutor;

    #[async_trait]
    impl ToolExecutor for BaseExecutor {
        async fn execute(&self, call: &ToolCall) -> Result<ToolResult, ToolError> {
            let message = format!("base executor called for {}", call.function.name);
            Err(ToolError::Execution(message))
        }

        async fn execute_with_context(
            &self,
            call: &ToolCall,
            _ctx: ToolExecutionContext<'_>,
        ) -> Result<ToolResult, ToolError> {
            self.execute(call).await
        }

        fn list_tools(&self) -> Vec<ToolSchema> {
            vec![]
        }
    }

    /// Overlay tool stub named `SubSession` that always succeeds with the
    /// result text `overlay`.
    struct SubSessionOverlayTool;

    #[async_trait]
    impl Tool for SubSessionOverlayTool {
        fn name(&self) -> &str {
            "SubSession"
        }

        fn description(&self) -> &str {
            "overlay sub session"
        }

        fn parameters_schema(&self) -> serde_json::Value {
            json!({"type":"object","properties":{}})
        }

        async fn execute(&self, _args: serde_json::Value) -> Result<ToolResult, ToolError> {
            let outcome = ToolResult {
                success: true,
                result: String::from("overlay"),
                display_preference: None,
            };
            Ok(outcome)
        }
    }

    /// Builds a function tool call for `name` with empty JSON-object arguments.
    fn make_call(name: &str) -> ToolCall {
        let function = FunctionCall {
            name: name.to_owned(),
            arguments: String::from("{}"),
        };
        ToolCall {
            id: String::from("call_1"),
            tool_type: String::from("function"),
            function,
        }
    }

    /// Builds the executor under test: the failing base stub with the
    /// `SubSession` overlay stub on top.
    fn build_executor() -> OverlayToolExecutor {
        OverlayToolExecutor::new(Arc::new(BaseExecutor), Arc::new(SubSessionOverlayTool))
    }

    #[tokio::test]
    async fn overlay_executor_routes_spawn_alias_to_overlay_tool() {
        let executor = build_executor();
        let call = make_call("sub_task");

        let outcome = executor
            .execute(&call)
            .await
            .expect("spawn alias should route to overlay");

        assert!(outcome.success);
        assert_eq!(outcome.result, "overlay");
    }

    #[tokio::test]
    async fn overlay_executor_keeps_non_overlay_calls_on_base_executor() {
        let executor = build_executor();
        let call = make_call("Read");

        let err = executor
            .execute(&call)
            .await
            .expect_err("non-overlay call should stay on base executor");

        let reached_base = matches!(
            err,
            ToolError::Execution(msg) if msg.contains("base executor called for Read")
        );
        assert!(reached_base);
    }
}