1use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
5use crate::registry::ToolDef;
6
/// Chains two [`ToolExecutor`]s so they can be used as a single executor.
///
/// Routing:
/// - Every request is offered to `first`.
/// - `second` is only consulted when `first` returns `Ok(None)`.
/// - An error from `first` is returned as-is, without trying `second`.
/// - Tool definitions from both executors are concatenated, `first`'s before `second`'s.
///
/// The `ToolExecutor` bounds live on the `impl` blocks rather than on the type itself.
/// Generic code can therefore name `CompositeExecutor<A, B>` without repeating them.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor that is offered every request first.
    first: A,
    /// Fallback executor, used only when `first` declines a request.
    second: B,
}
16
17impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
18 #[must_use]
19 pub fn new(first: A, second: B) -> Self {
20 Self { first, second }
21 }
22}
23
24impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
25 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
26 if let Some(output) = self.first.execute(response).await? {
27 return Ok(Some(output));
28 }
29 self.second.execute(response).await
30 }
31
32 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
33 if let Some(output) = self.first.execute_confirmed(response).await? {
34 return Ok(Some(output));
35 }
36 self.second.execute_confirmed(response).await
37 }
38
39 fn tool_definitions(&self) -> Vec<ToolDef> {
40 let mut defs = self.first.tool_definitions();
41 defs.extend(self.second.tool_definitions());
42 defs
43 }
44
45 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
46 if let Some(output) = self.first.execute_tool_call(call).await? {
47 return Ok(Some(output));
48 }
49 self.second.execute_tool_call(call).await
50 }
51}
52
53#[cfg(test)]
54mod tests {
55 use super::*;
56
57 #[derive(Debug)]
58 struct MatchingExecutor;
59 impl ToolExecutor for MatchingExecutor {
60 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
61 Ok(Some(ToolOutput {
62 tool_name: "test".to_owned(),
63 summary: "matched".to_owned(),
64 blocks_executed: 1,
65 filter_stats: None,
66 diff: None,
67 streamed: false,
68 }))
69 }
70 }
71
72 #[derive(Debug)]
73 struct NoMatchExecutor;
74 impl ToolExecutor for NoMatchExecutor {
75 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
76 Ok(None)
77 }
78 }
79
80 #[derive(Debug)]
81 struct ErrorExecutor;
82 impl ToolExecutor for ErrorExecutor {
83 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
84 Err(ToolError::Blocked {
85 command: "test".to_owned(),
86 })
87 }
88 }
89
90 #[derive(Debug)]
91 struct SecondExecutor;
92 impl ToolExecutor for SecondExecutor {
93 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
94 Ok(Some(ToolOutput {
95 tool_name: "test".to_owned(),
96 summary: "second".to_owned(),
97 blocks_executed: 1,
98 filter_stats: None,
99 diff: None,
100 streamed: false,
101 }))
102 }
103 }
104
105 #[tokio::test]
106 async fn first_matches_returns_first() {
107 let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
108 let result = composite.execute("anything").await.unwrap();
109 assert_eq!(result.unwrap().summary, "matched");
110 }
111
112 #[tokio::test]
113 async fn first_none_falls_through_to_second() {
114 let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
115 let result = composite.execute("anything").await.unwrap();
116 assert_eq!(result.unwrap().summary, "second");
117 }
118
119 #[tokio::test]
120 async fn both_none_returns_none() {
121 let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
122 let result = composite.execute("anything").await.unwrap();
123 assert!(result.is_none());
124 }
125
126 #[tokio::test]
127 async fn first_error_propagates_without_trying_second() {
128 let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
129 let result = composite.execute("anything").await;
130 assert!(matches!(result, Err(ToolError::Blocked { .. })));
131 }
132
133 #[tokio::test]
134 async fn second_error_propagates_when_first_none() {
135 let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
136 let result = composite.execute("anything").await;
137 assert!(matches!(result, Err(ToolError::Blocked { .. })));
138 }
139
140 #[tokio::test]
141 async fn execute_confirmed_first_matches() {
142 let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
143 let result = composite.execute_confirmed("anything").await.unwrap();
144 assert_eq!(result.unwrap().summary, "matched");
145 }
146
147 #[tokio::test]
148 async fn execute_confirmed_falls_through() {
149 let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
150 let result = composite.execute_confirmed("anything").await.unwrap();
151 assert_eq!(result.unwrap().summary, "second");
152 }
153
154 #[test]
155 fn composite_debug() {
156 let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
157 let debug = format!("{composite:?}");
158 assert!(debug.contains("CompositeExecutor"));
159 }
160
161 #[derive(Debug)]
162 struct FileToolExecutor;
163 impl ToolExecutor for FileToolExecutor {
164 async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
165 Ok(None)
166 }
167 async fn execute_tool_call(
168 &self,
169 call: &ToolCall,
170 ) -> Result<Option<ToolOutput>, ToolError> {
171 if call.tool_id == "read" || call.tool_id == "write" {
172 Ok(Some(ToolOutput {
173 tool_name: call.tool_id.clone(),
174 summary: "file_handler".to_owned(),
175 blocks_executed: 1,
176 filter_stats: None,
177 diff: None,
178 streamed: false,
179 }))
180 } else {
181 Ok(None)
182 }
183 }
184 }
185
186 #[derive(Debug)]
187 struct ShellToolExecutor;
188 impl ToolExecutor for ShellToolExecutor {
189 async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
190 Ok(None)
191 }
192 async fn execute_tool_call(
193 &self,
194 call: &ToolCall,
195 ) -> Result<Option<ToolOutput>, ToolError> {
196 if call.tool_id == "bash" {
197 Ok(Some(ToolOutput {
198 tool_name: "bash".to_owned(),
199 summary: "shell_handler".to_owned(),
200 blocks_executed: 1,
201 filter_stats: None,
202 diff: None,
203 streamed: false,
204 }))
205 } else {
206 Ok(None)
207 }
208 }
209 }
210
211 #[tokio::test]
212 async fn tool_call_routes_to_file_executor() {
213 let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
214 let call = ToolCall {
215 tool_id: "read".to_owned(),
216 params: serde_json::Map::new(),
217 };
218 let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
219 assert_eq!(result.summary, "file_handler");
220 }
221
222 #[tokio::test]
223 async fn tool_call_routes_to_shell_executor() {
224 let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
225 let call = ToolCall {
226 tool_id: "bash".to_owned(),
227 params: serde_json::Map::new(),
228 };
229 let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
230 assert_eq!(result.summary, "shell_handler");
231 }
232
233 #[tokio::test]
234 async fn tool_call_unhandled_returns_none() {
235 let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
236 let call = ToolCall {
237 tool_id: "unknown".to_owned(),
238 params: serde_json::Map::new(),
239 };
240 let result = composite.execute_tool_call(&call).await.unwrap();
241 assert!(result.is_none());
242 }
243}