1use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
7use crate::registry::ToolDef;
8
/// Pairs two tool executors so they can be driven as a single executor.
///
/// The [`ToolExecutor`] implementation for this type in this file consults
/// `first` before `second`.
///
/// No trait bounds are declared on the type itself; the `impl` blocks that
/// need `ToolExecutor` state that requirement, which keeps the data
/// definition free of duplicated bounds.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor that is consulted first.
    first: A,
    /// Executor that is consulted after `first`.
    second: B,
}
38
39impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
40 #[must_use]
42 pub fn new(first: A, second: B) -> Self {
43 Self { first, second }
44 }
45}
46
47impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
48 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
49 if let Some(output) = self.first.execute(response).await? {
50 return Ok(Some(output));
51 }
52 self.second.execute(response).await
53 }
54
55 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
56 if let Some(output) = self.first.execute_confirmed(response).await? {
57 return Ok(Some(output));
58 }
59 self.second.execute_confirmed(response).await
60 }
61
62 fn tool_definitions(&self) -> Vec<ToolDef> {
63 let mut defs = self.first.tool_definitions();
64 let seen: std::collections::HashSet<String> =
65 defs.iter().map(|d| d.id.to_string()).collect();
66 for def in self.second.tool_definitions() {
67 if !seen.contains(def.id.as_ref()) {
68 defs.push(def);
69 }
70 }
71 defs
72 }
73
74 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
75 if let Some(output) = self.first.execute_tool_call(call).await? {
76 return Ok(Some(output));
77 }
78 self.second.execute_tool_call(call).await
79 }
80
81 fn is_tool_retryable(&self, tool_id: &str) -> bool {
82 self.first.is_tool_retryable(tool_id) || self.second.is_tool_retryable(tool_id)
83 }
84
85 fn is_tool_speculatable(&self, tool_id: &str) -> bool {
86 self.first.is_tool_speculatable(tool_id) || self.second.is_tool_speculatable(tool_id)
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;

    /// Builds a `ToolOutput` with the given name and summary; all remaining
    /// fields use the fixed values shared by every stub executor below.
    fn stub_output(tool_name: ToolName, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name,
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: None,
        }
    }

    /// Builds a `ToolCall` for `tool_id` with empty params and no caller or
    /// context.
    fn stub_call(tool_id: ToolName) -> ToolCall {
        ToolCall {
            tool_id,
            params: serde_json::Map::new(),
            caller_id: None,
            context: None,
        }
    }

    /// Always produces an output whose summary is `"matched"`.
    #[derive(Debug)]
    struct MatchingExecutor;

    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            let output = stub_output(ToolName::new("test"), "matched");
            Ok(Some(output))
        }
    }

    /// Never produces an output.
    #[derive(Debug)]
    struct NoMatchExecutor;

    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Always fails with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;

    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            let blocked = ToolError::Blocked {
                command: String::from("test"),
            };
            Err(blocked)
        }
    }

    /// Always produces an output whose summary is `"second"`.
    #[derive(Debug)]
    struct SecondExecutor;

    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            let output = stub_output(ToolName::new("test"), "second");
            Ok(Some(output))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let chain = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let output = chain.execute("anything").await.unwrap().unwrap();
        assert_eq!(output.summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let chain = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let output = chain.execute("anything").await.unwrap().unwrap();
        assert_eq!(output.summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let chain = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let outcome = chain.execute("anything").await.unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let chain = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let outcome = chain.execute("anything").await;
        assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let chain = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let outcome = chain.execute("anything").await;
        assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let chain = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let output = chain.execute_confirmed("anything").await.unwrap().unwrap();
        assert_eq!(output.summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let chain = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let output = chain.execute_confirmed("anything").await.unwrap().unwrap();
        assert_eq!(output.summary, "second");
    }

    #[test]
    fn composite_debug() {
        let chain = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let rendered = format!("{chain:?}");
        assert!(rendered.contains("CompositeExecutor"));
    }

    /// Handles only the `read` and `write` structured tool calls.
    #[derive(Debug)]
    struct FileToolExecutor;

    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            let is_file_tool = call.tool_id == "read" || call.tool_id == "write";
            Ok(is_file_tool.then(|| stub_output(call.tool_id.clone(), "file_handler")))
        }
    }

    /// Handles only the `bash` structured tool call.
    #[derive(Debug)]
    struct ShellToolExecutor;

    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            let is_shell_tool = call.tool_id == "bash";
            Ok(is_shell_tool.then(|| stub_output(ToolName::new("bash"), "shell_handler")))
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let chain = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = stub_call(ToolName::new("read"));
        let output = chain.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(output.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let chain = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = stub_call(ToolName::new("bash"));
        let output = chain.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(output.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let chain = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = stub_call(ToolName::new("unknown"));
        let outcome = chain.execute_tool_call(&call).await.unwrap();
        assert!(outcome.is_none());
    }
}