Skip to main content

zeph_tools/
composite.rs

1use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
2use crate::registry::ToolDef;
3
/// Chains two [`ToolExecutor`] implementations with first-match-wins dispatch.
///
/// Each dispatch method (`execute`, `execute_confirmed`, `execute_tool_call`)
/// first asks `first`. Only when `first` returns `Ok(None)` is the same request
/// forwarded to `second`. An `Ok(Some(_))` or an `Err(_)` from `first` is final
/// and is returned without consulting `second`.
///
/// `tool_definitions` returns `first`'s definitions followed by `second`'s.
///
/// Because the composite itself implements [`ToolExecutor`], composites nest,
/// e.g. `CompositeExecutor::new(a, CompositeExecutor::new(b, c))` tries `a`,
/// then `b`, then `c`.
///
/// The type carries no trait bounds of its own; the [`ToolExecutor`] bounds live
/// on the trait impl. This means generic code that merely names
/// `CompositeExecutor<A, B>` does not have to repeat them.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor consulted first; its answer wins whenever it is not `Ok(None)`.
    first: A,
    /// Fallback executor, consulted only when `first` returns `Ok(None)`.
    second: B,
}

impl<A, B> CompositeExecutor<A, B> {
    /// Creates a composite that dispatches to `first` before `second`.
    #[must_use]
    pub fn new(first: A, second: B) -> Self {
        Self { first, second }
    }
}
20
21impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
22    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
23        if let Some(output) = self.first.execute(response).await? {
24            return Ok(Some(output));
25        }
26        self.second.execute(response).await
27    }
28
29    async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
30        if let Some(output) = self.first.execute_confirmed(response).await? {
31            return Ok(Some(output));
32        }
33        self.second.execute_confirmed(response).await
34    }
35
36    fn tool_definitions(&self) -> Vec<ToolDef> {
37        let mut defs = self.first.tool_definitions();
38        defs.extend(self.second.tool_definitions());
39        defs
40    }
41
42    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
43        if let Some(output) = self.first.execute_tool_call(call).await? {
44            return Ok(Some(output));
45        }
46        self.second.execute_tool_call(call).await
47    }
48}
49
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a successful single-block `ToolOutput` for the stub executors below.
    fn output(tool_name: &str, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: tool_name.to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
        }
    }

    /// Builds a parameterless `ToolCall` for the given tool id.
    fn tool_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.to_owned(),
            params: serde_json::Map::new(),
        }
    }

    /// Handles every text response with summary `"matched"`.
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "matched")))
        }
    }

    /// Declines every text response.
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Fails every text response with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: "test".to_owned(),
            })
        }
    }

    /// Handles every text response with summary `"second"`.
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "second")))
        }
    }

    /// Handles the `read` and `write` structured tool calls only.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "read" || call.tool_id == "write" {
                Ok(Some(output(&call.tool_id, "file_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Handles the `bash` structured tool call only.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            if call.tool_id == "bash" {
                Ok(Some(output("bash", "shell_handler")))
            } else {
                Ok(None)
            }
        }
    }

    /// Fails every structured tool call with `ToolError::Blocked`, whatever its id.
    #[derive(Debug)]
    struct ToolCallErrorExecutor;
    impl ToolExecutor for ToolCallErrorExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Err(ToolError::Blocked {
                command: call.tool_id.clone(),
            })
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let result = composite.execute("anything").await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        // SecondExecutor would return Ok(Some(_)); getting Err proves it was skipped.
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let result = composite.execute("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[tokio::test]
    async fn execute_confirmed_first_error_propagates() {
        // Like the execute_confirmed tests above, this relies on the trait's default
        // `execute_confirmed` delegating to `execute`.
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let result = composite.execute_confirmed("anything").await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn nested_composites_dispatch_in_order() {
        let composite = CompositeExecutor::new(
            NoMatchExecutor,
            CompositeExecutor::new(NoMatchExecutor, SecondExecutor),
        );
        let result = composite.execute("anything").await.unwrap();
        assert_eq!(result.unwrap().summary, "second");
    }

    #[test]
    fn composite_debug() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let debug = format!("{composite:?}");
        assert!(debug.contains("CompositeExecutor"));
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&tool_call("read"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&tool_call("bash"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(result.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let result = composite
            .execute_tool_call(&tool_call("unknown"))
            .await
            .unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn tool_call_first_error_propagates_without_trying_second() {
        // ShellToolExecutor would handle "bash"; the error from `first` must win.
        let composite = CompositeExecutor::new(ToolCallErrorExecutor, ShellToolExecutor);
        let result = composite.execute_tool_call(&tool_call("bash")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn tool_call_second_error_propagates_when_first_none() {
        // FileToolExecutor declines "bash", so the call reaches the failing fallback.
        let composite = CompositeExecutor::new(FileToolExecutor, ToolCallErrorExecutor);
        let result = composite.execute_tool_call(&tool_call("bash")).await;
        assert!(matches!(result, Err(ToolError::Blocked { .. })));
    }
}