1use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
2use crate::registry::ToolDef;
3
/// Chains two executors so that `first` gets the first chance at every request.
///
/// When used through its `ToolExecutor` impl, each dispatch method is tried on
/// `first`. `second` is consulted only if `first` yields `Ok(None)`, and an
/// `Err` from `first` short-circuits. Tool definitions from both are
/// concatenated, `first`'s before `second`'s.
///
/// No trait bounds are placed on the type parameters here. They live on the
/// impl blocks that actually need `ToolExecutor`, which keeps the type itself
/// unconstrained (Rust API guideline C-STRUCT-BOUNDS).
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    first: A,
    second: B,
}
13
14impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
15 #[must_use]
16 pub fn new(first: A, second: B) -> Self {
17 Self { first, second }
18 }
19}
20
21impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
22 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
23 if let Some(output) = self.first.execute(response).await? {
24 return Ok(Some(output));
25 }
26 self.second.execute(response).await
27 }
28
29 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
30 if let Some(output) = self.first.execute_confirmed(response).await? {
31 return Ok(Some(output));
32 }
33 self.second.execute_confirmed(response).await
34 }
35
36 fn tool_definitions(&self) -> Vec<ToolDef> {
37 let mut defs = self.first.tool_definitions();
38 defs.extend(self.second.tool_definitions());
39 defs
40 }
41
42 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
43 if let Some(output) = self.first.execute_tool_call(call).await? {
44 return Ok(Some(output));
45 }
46 self.second.execute_tool_call(call).await
47 }
48}
49
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-block `ToolOutput` with no filter stats, no diff, and no streaming.
    fn stub_output(tool_name: &str, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: tool_name.to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
        }
    }

    /// Builds a `ToolCall` for `tool_id` with an empty parameter map.
    fn stub_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.to_owned(),
            params: serde_json::Map::new(),
        }
    }

    /// Claims every response, reporting the summary "matched".
    #[derive(Debug)]
    struct MatchingExecutor;

    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(stub_output("test", "matched")))
        }
    }

    /// Declines every response.
    #[derive(Debug)]
    struct NoMatchExecutor;

    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Fails every response with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;

    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    /// Claims every response, reporting the summary "second" so the fallback is observable.
    #[derive(Debug)]
    struct SecondExecutor;

    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(stub_output("test", "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let executor = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let output = executor.execute("anything").await.unwrap().unwrap();
        assert_eq!(output.summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let executor = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let output = executor.execute("anything").await.unwrap().unwrap();
        assert_eq!(output.summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let executor = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let outcome = executor.execute("anything").await.unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let executor = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let outcome = executor.execute("anything").await;
        assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let executor = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let outcome = executor.execute("anything").await;
        assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let executor = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let output = executor.execute_confirmed("anything").await.unwrap().unwrap();
        assert_eq!(output.summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let executor = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let output = executor.execute_confirmed("anything").await.unwrap().unwrap();
        assert_eq!(output.summary, "second");
    }

    #[test]
    fn composite_debug() {
        let executor = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        assert!(format!("{executor:?}").contains("CompositeExecutor"));
    }

    /// Structured-call executor that only handles the "read" and "write" tool ids.
    #[derive(Debug)]
    struct FileToolExecutor;

    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            let handles = matches!(call.tool_id.as_str(), "read" | "write");
            Ok(handles.then(|| stub_output(&call.tool_id, "file_handler")))
        }
    }

    /// Structured-call executor that only handles the "bash" tool id.
    #[derive(Debug)]
    struct ShellToolExecutor;

    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }

        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok((call.tool_id == "bash").then(|| stub_output("bash", "shell_handler")))
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let executor = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = stub_call("read");
        let output = executor.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(output.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let executor = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = stub_call("bash");
        let output = executor.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(output.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let executor = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = stub_call("unknown");
        let outcome = executor.execute_tool_call(&call).await.unwrap();
        assert!(outcome.is_none());
    }
}