Skip to main content

zeph_tools/
composite.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput};
5use crate::registry::ToolDef;
6
/// Chains two [`ToolExecutor`] implementations with first-match-wins dispatch.
///
/// Every request is offered to `first`; `second` is consulted only when
/// `first` returns `Ok(None)`. Errors from `first` propagate immediately
/// without trying `second`. When both executors expose a tool definition with
/// the same id, the one from `first` takes precedence.
///
/// The struct itself places no trait bounds on `A`/`B` (storing them needs
/// none); the `ToolExecutor` requirement lives on the `impl` blocks that
/// actually construct and dispatch through the composite.
#[derive(Debug)]
pub struct CompositeExecutor<A, B> {
    /// Executor that is always tried first.
    first: A,
    /// Fallback executor, tried only when `first` declines with `Ok(None)`.
    second: B,
}
16
17impl<A: ToolExecutor, B: ToolExecutor> CompositeExecutor<A, B> {
18    #[must_use]
19    pub fn new(first: A, second: B) -> Self {
20        Self { first, second }
21    }
22}
23
24impl<A: ToolExecutor, B: ToolExecutor> ToolExecutor for CompositeExecutor<A, B> {
25    async fn execute(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
26        if let Some(output) = self.first.execute(response).await? {
27            return Ok(Some(output));
28        }
29        self.second.execute(response).await
30    }
31
32    async fn execute_confirmed(&self, response: &str) -> Result<Option<ToolOutput>, ToolError> {
33        if let Some(output) = self.first.execute_confirmed(response).await? {
34            return Ok(Some(output));
35        }
36        self.second.execute_confirmed(response).await
37    }
38
39    fn tool_definitions(&self) -> Vec<ToolDef> {
40        let mut defs = self.first.tool_definitions();
41        let seen: std::collections::HashSet<String> =
42            defs.iter().map(|d| d.id.to_string()).collect();
43        for def in self.second.tool_definitions() {
44            if !seen.contains(def.id.as_ref()) {
45                defs.push(def);
46            }
47        }
48        defs
49    }
50
51    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
52        if let Some(output) = self.first.execute_tool_call(call).await? {
53            return Ok(Some(output));
54        }
55        self.second.execute_tool_call(call).await
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single-block `ToolOutput` carrying `tool_name` and `summary`;
    /// every optional field is left empty.
    fn output(tool_name: &str, summary: &str) -> ToolOutput {
        ToolOutput {
            tool_name: tool_name.to_owned(),
            summary: summary.to_owned(),
            blocks_executed: 1,
            filter_stats: None,
            diff: None,
            streamed: false,
            terminal_id: None,
            locations: None,
            raw_response: None,
        }
    }

    /// Builds a parameterless `ToolCall` addressed to `tool_id`.
    fn tool_call(tool_id: &str) -> ToolCall {
        ToolCall {
            tool_id: tool_id.to_owned(),
            params: serde_json::Map::new(),
        }
    }

    // Mock executors below override only the methods they need; every other
    // trait method falls back to the `ToolExecutor` defaults.

    /// Claims every text response, answering with summary "matched".
    #[derive(Debug)]
    struct MatchingExecutor;
    impl ToolExecutor for MatchingExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "matched")))
        }
    }

    /// Declines every text response.
    #[derive(Debug)]
    struct NoMatchExecutor;
    impl ToolExecutor for NoMatchExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
    }

    /// Fails every text response with `ToolError::Blocked`.
    #[derive(Debug)]
    struct ErrorExecutor;
    impl ToolExecutor for ErrorExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            let command = "test".to_owned();
            Err(ToolError::Blocked { command })
        }
    }

    /// Claims every text response, answering with summary "second".
    #[derive(Debug)]
    struct SecondExecutor;
    impl ToolExecutor for SecondExecutor {
        async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(Some(output("test", "second")))
        }
    }

    #[tokio::test]
    async fn first_matches_returns_first() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let summary = composite.execute("anything").await.unwrap().unwrap().summary;
        assert_eq!(summary, "matched");
    }

    #[tokio::test]
    async fn first_none_falls_through_to_second() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let summary = composite.execute("anything").await.unwrap().unwrap().summary;
        assert_eq!(summary, "second");
    }

    #[tokio::test]
    async fn both_none_returns_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, NoMatchExecutor);
        let outcome = composite.execute("anything").await.unwrap();
        assert!(outcome.is_none());
    }

    #[tokio::test]
    async fn first_error_propagates_without_trying_second() {
        let composite = CompositeExecutor::new(ErrorExecutor, SecondExecutor);
        let outcome = composite.execute("anything").await;
        assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn second_error_propagates_when_first_none() {
        let composite = CompositeExecutor::new(NoMatchExecutor, ErrorExecutor);
        let outcome = composite.execute("anything").await;
        assert!(matches!(outcome, Err(ToolError::Blocked { .. })));
    }

    #[tokio::test]
    async fn execute_confirmed_first_matches() {
        let composite = CompositeExecutor::new(MatchingExecutor, SecondExecutor);
        let summary = composite
            .execute_confirmed("anything")
            .await
            .unwrap()
            .unwrap()
            .summary;
        assert_eq!(summary, "matched");
    }

    #[tokio::test]
    async fn execute_confirmed_falls_through() {
        let composite = CompositeExecutor::new(NoMatchExecutor, SecondExecutor);
        let summary = composite
            .execute_confirmed("anything")
            .await
            .unwrap()
            .unwrap()
            .summary;
        assert_eq!(summary, "second");
    }

    #[test]
    fn composite_debug() {
        let rendered = format!(
            "{:?}",
            CompositeExecutor::new(MatchingExecutor, SecondExecutor)
        );
        assert!(rendered.contains("CompositeExecutor"));
    }

    /// Handles structured calls to `read` and `write`; declines everything else.
    #[derive(Debug)]
    struct FileToolExecutor;
    impl ToolExecutor for FileToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            let handled = matches!(call.tool_id.as_str(), "read" | "write");
            Ok(handled.then(|| output(&call.tool_id, "file_handler")))
        }
    }

    /// Handles structured calls to `bash`; declines everything else.
    #[derive(Debug)]
    struct ShellToolExecutor;
    impl ToolExecutor for ShellToolExecutor {
        async fn execute(&self, _: &str) -> Result<Option<ToolOutput>, ToolError> {
            Ok(None)
        }
        async fn execute_tool_call(
            &self,
            call: &ToolCall,
        ) -> Result<Option<ToolOutput>, ToolError> {
            Ok((call.tool_id == "bash").then(|| output("bash", "shell_handler")))
        }
    }

    #[tokio::test]
    async fn tool_call_routes_to_file_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let handled = composite
            .execute_tool_call(&tool_call("read"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(handled.summary, "file_handler");
    }

    #[tokio::test]
    async fn tool_call_routes_to_shell_executor() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let handled = composite
            .execute_tool_call(&tool_call("bash"))
            .await
            .unwrap()
            .unwrap();
        assert_eq!(handled.summary, "shell_handler");
    }

    #[tokio::test]
    async fn tool_call_unhandled_returns_none() {
        let composite = CompositeExecutor::new(FileToolExecutor, ShellToolExecutor);
        let outcome = composite
            .execute_tool_call(&tool_call("unknown"))
            .await
            .unwrap();
        assert!(outcome.is_none());
    }
}