1use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
7use crate::registry::ToolDef;
8
/// Chains two tool executors: `first` is consulted before `second`.
///
/// The `ToolExecutor` impl for this type gives `first` priority on every
/// dispatch path and falls through to `second` only when `first` returns
/// `Ok(None)`; an error from `first` is returned without consulting `second`.
///
/// The struct itself carries no trait bounds — they live on the `impl` blocks
/// that need them — so the type can be named (and `Debug`-formatted) in
/// generic code without repeating `ToolExecutor` bounds.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor that is tried first.
    first: A,
    /// Fallback executor, consulted only when `first` does not handle the input.
    second: B,
}
38
39impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
40 #[must_use]
42 pub fn new(first: A, second: B) -> Self {
43 Self { first, second }
44 }
45}
46
47impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
48 async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
49 if let Some(output) = self.first.execute(response).await? {
50 return Ok(Some(output));
51 }
52 self.second.execute(response).await
53 }
54
55 async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
56 if let Some(output) = self.first.execute_confirmed(response).await? {
57 return Ok(Some(output));
58 }
59 self.second.execute_confirmed(response).await
60 }
61
62 fn tool_definitions(&self) -> Vec<ToolDef> {
63 let mut defs = self.first.tool_definitions();
64 let seen: std::collections::HashSet<String> =
65 defs.iter().map(|d| d.id.to_string()).collect();
66 for def in self.second.tool_definitions() {
67 if !seen.contains(def.id.as_ref()) {
68 defs.push(def);
69 }
70 }
71 defs
72 }
73
74 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
75 if let Some(output) = self.first.execute_tool_call(call).await? {
76 return Ok(Some(output));
77 }
78 self.second.execute_tool_call(call).await
79 }
80
81 fn is_tool_retryable(&self, tool_id: &str) -> bool {
82 self.first.is_tool_retryable(tool_id) || self.second.is_tool_retryable(tool_id)
83 }
84
85 fn is_tool_speculatable(&self, tool_id: &str) -> bool {
86 self.first.is_tool_speculatable(tool_id) || self.second.is_tool_speculatable(tool_id)
87 }
88
89 fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
97 self.first.set_skill_env(env.clone());
98 self.second.set_skill_env(env);
99 }
100
101 fn set_effective_trust(&self, level: crate::SkillTrustLevel) {
107 self.first.set_effective_trust(level);
108 self.second.set_effective_trust(level);
109 }
110}
111
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ToolName;

    /// Builds a one-block `ToolOutput` with every optional field left empty.
    fn make_output(tool_name: ToolName, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name,
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
            claim_source: None,
        }
    }

    /// Builds a `ToolCall` for `tool_id` with empty params and no caller metadata.
    fn make_call(tool_id: &'static str) -> ToolCall {
        ToolCall {
            tool_id: ToolName::new(tool_id),
            params: serde_json::Map::new(),
            caller_id: None,
            context: None,
            tool_call_id: String::new(),
            skill_name: None,
        }
    }

    /// Always handles the response, reporting the summary "matched".
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(make_output(ToolName::new("test"), "matched")))
        }
    }

    /// Never handles anything.
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Always fails with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    /// Always handles the response, reporting the summary "second".
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(make_output(ToolName::new("test"), "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    /// Handles structured calls to "read" and "write" only.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "read" || call.tool_id == "write" {
                Ok(Some(make_output(call.tool_id.clone(), "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Handles structured calls to "bash" only.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(make_output(ToolName::new("bash"), "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = make_call("read");
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = make_call("bash");
        let result = composite.execute_tool_call(&call).await.unwrap().unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let call = make_call("unknown");
        let result = composite.execute_tool_call(&call).await.unwrap();
        assert!(result.is_none());
    }

    /// Reports retryable/speculatable `true` for exactly one tool id, so the
    /// composite's answers do not depend on the trait's default methods.
    #[derive(Debug)]
    struct FlagExecutor {
        tool: &'static str,
    }
    impl ToolExecutor for FlagExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        fn is_tool_retryable(&self, tool_id: &str) -> bool {
            tool_id == self.tool
        }
        fn is_tool_speculatable(&self, tool_id: &str) -> bool {
            tool_id == self.tool
        }
    }

    #[test]
    fn retry_and_speculation_flags_are_true_if_either_executor_says_so() {
        let composite =
            CompositeExecutor::new(FlagExecutor { tool: "read" }, FlagExecutor { tool: "bash" });

        assert!(composite.is_tool_retryable("read"));
        assert!(composite.is_tool_retryable("bash"));
        assert!(!composite.is_tool_retryable("unknown"));

        assert!(composite.is_tool_speculatable("read"));
        assert!(composite.is_tool_speculatable("bash"));
        assert!(!composite.is_tool_speculatable("unknown"));
    }

    mod state_forwarding {
        use super::*;
        use crate::SkillTrustLevel;
        use std::sync::Mutex;

        /// Records the last skill env and trust level it was given.
        #[derive(Debug, Default)]
        struct SpyExecutor {
            last_env: Mutex<Option<std::collections::HashMap<String, String>>>,
            last_trust: Mutex<Option<SkillTrustLevel>>,
        }
        impl ToolExecutor for SpyExecutor {
            async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
                Ok(None)
            }
            fn set_skill_env(&self, env: Option<std::collections::HashMap<String, String>>) {
                *self.last_env.lock().unwrap() = env;
            }
            fn set_effective_trust(&self, level: SkillTrustLevel) {
                *self.last_trust.lock().unwrap() = Some(level);
            }
        }

        #[test]
        fn set_skill_env_reaches_both_inner_executors_in_nested_composition() {
            let nested = CompositeExecutor::new(SpyExecutor::default(), SpyExecutor::default());
            let outer = CompositeExecutor::new(nested, SpyExecutor::default());

            let mut env = std::collections::HashMap::new();
            env.insert("GITHUB_TOKEN".to_owned(), "tok".to_owned());
            outer.set_skill_env(Some(env.clone()));

            for spy in [&outer.first.first, &outer.first.second, &outer.second] {
                assert_eq!(spy.last_env.lock().unwrap().as_ref(), Some(&env));
            }
        }

        #[test]
        fn set_effective_trust_reaches_both_inner_executors_in_nested_composition() {
            // Genuinely nested: the trust level must travel through the inner
            // composite to both of its leaves as well as the outer leaf.
            let nested = CompositeExecutor::new(SpyExecutor::default(), SpyExecutor::default());
            let outer = CompositeExecutor::new(nested, SpyExecutor::default());

            outer.set_effective_trust(SkillTrustLevel::Quarantined);

            for spy in [&outer.first.first, &outer.first.second, &outer.second] {
                assert_eq!(
                    *spy.last_trust.lock().unwrap(),
                    Some(SkillTrustLevel::Quarantined)
                );
            }
        }
    }
}