1use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
5use crate::registry::ToolDef;
6
/// Chains two tool executors, consulting `first` before `second`.
///
/// For every request, `second` is only asked when `first` declines it (returns `Ok(None)`).
/// The type parameters are deliberately unbounded here; the constructor and the
/// `ToolExecutor` impl state the bounds they actually need.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor that gets the first chance to handle each request.
    first: A,
    /// Fallback executor, used only when `first` declines.
    second: B,
}
16
17impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
18 #[must_use]
19 pub fn new(first: A, second: B) -> Self {
20 Self { first, second }
21 }
22}
23
24impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
25 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
26 if let Some(output) = self.first.execute(response).await? {
27 return Ok(Some(output));
28 }
29 self.second.execute(response).await
30 }
31
32 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
33 if let Some(output) = self.first.execute_confirmed(response).await? {
34 return Ok(Some(output));
35 }
36 self.second.execute_confirmed(response).await
37 }
38
39 fn tool_definitions(&self) -> Vec<ToolDef> {
40 let mut defs = self.first.tool_definitions();
41 let seen: std::collections::HashSet<String> =
42 defs.iter().map(|d| d.id.to_string()).collect();
43 for def in self.second.tool_definitions() {
44 if !seen.contains(def.id.as_ref()) {
45 defs.push(def);
46 }
47 }
48 defs
49 }
50
51 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
52 if let Some(output) = self.first.execute_tool_call(call).await? {
53 return Ok(Some(output));
54 }
55 self.second.execute_tool_call(call).await
56 }
57
58 fn is_tool_retryable(&self, tool_id: &str) -> bool {
59 self.first.is_tool_retryable(tool_id) || self.second.is_tool_retryable(tool_id)
60 }
61}
62
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-block `ToolOutput` with the given name and summary and
    /// every optional field unset.
    fn output(tool_name: &str, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: tool_name.to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
        }
    }

    /// Builds a `ToolCall` for `tool_id` with an empty parameter map.
    fn tool_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.to_owned(),
            params: serde_json::Map::new(),
        }
    }

    /// Claims every response, reporting summary "matched".
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "matched")))
        }
    }

    /// Declines every response.
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Fails every `execute` with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    /// Claims every response, reporting summary "second".
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    /// Handles structured calls for "read" and "write" only.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "read" || call.tool_id == "write" {
                Ok(Some(output(&call.tool_id, "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Handles structured calls for "bash" only.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(output("bash", "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Fails every structured call with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ToolCallErrorExecutor;
    impl ToolExecutor for ToolCallErrorExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            _call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    /// Reports only the "fetch" tool as retryable.
    #[derive(Debug)]
    struct RetryableExecutor;
    impl ToolExecutor for RetryableExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        fn is_tool_retryable(&self, tool_id: &str) -> bool {
            tool_id == "fetch"
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&tool_call("read"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&tool_call("bash"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&tool_call("unknown"))
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn tool_call_first_error_propagates_without_trying_second() {
        // ShellToolExecutor would handle "bash", so an Err proves `second` was skipped.
        let composite = CompositeExecutor::new(ToolCallErrorExecutor, ShellToolExecutor);
        let result = composite.execute_tool_call(&tool_call("bash")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn tool_call_second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ToolCallErrorExecutor);
        let result = composite.execute_tool_call(&tool_call("bash")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[test]
    fn retryable_if_either_executor_says_so() {
        let via_first = CompositeExecutor::new(RetryableExecutor, NoMatchExecutor);
        let via_second = CompositeExecutor::new(NoMatchExecutor, RetryableExecutor);
        assert!(via_first.is_tool_retryable("fetch"));
        assert!(via_second.is_tool_retryable("fetch"));
    }

    #[test]
    fn not_retryable_when_neither_executor_says_so() {
        // Both sides are explicit about "bash", so this does not depend on the
        // trait's default `is_tool_retryable`.
        let composite = CompositeExecutor::new(RetryableExecutor, RetryableExecutor);
        assert!(!composite.is_tool_retryable("bash"));
    }
}
265}