Skip to main content

forge_client/
router.rs

//! Router dispatchers that route tool calls and resource reads to the correct downstream MCP client.
2
3use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5
6use forge_error::DispatchError;
7use forge_sandbox::{ResourceDispatcher, ToolDispatcher};
8use serde_json::Value;
9
10/// A [`ToolDispatcher`] that routes `call_tool(server, tool, args)` to the
11/// correct downstream MCP client based on server name.
12///
13/// Validates both server names and tool names before dispatching, returning
14/// [`DispatchError::ServerNotFound`] or [`DispatchError::ToolNotFound`] with
15/// fuzzy-match suggestions when appropriate.
16pub struct RouterDispatcher {
17    clients: HashMap<String, Arc<dyn ToolDispatcher>>,
18    /// Known tool names per server, for pre-dispatch validation.
19    known_tools: HashMap<String, HashSet<String>>,
20}
21
22impl RouterDispatcher {
23    /// Create a new empty router.
24    pub fn new() -> Self {
25        Self {
26            clients: HashMap::new(),
27            known_tools: HashMap::new(),
28        }
29    }
30
31    /// Register a dispatcher for a named server.
32    pub fn add_client(&mut self, name: impl Into<String>, client: Arc<dyn ToolDispatcher>) {
33        let name = name.into();
34        self.clients.insert(name.clone(), client);
35        // Ensure a tools entry exists even if no tools are registered yet
36        self.known_tools.entry(name).or_default();
37    }
38
39    /// Register known tool names for a server (for pre-dispatch validation).
40    pub fn set_known_tools(
41        &mut self,
42        server: impl Into<String>,
43        tools: impl IntoIterator<Item = String>,
44    ) {
45        self.known_tools
46            .insert(server.into(), tools.into_iter().collect());
47    }
48
49    /// List all registered server names.
50    pub fn server_names(&self) -> Vec<&str> {
51        let mut names: Vec<&str> = self.clients.keys().map(|s| s.as_str()).collect();
52        names.sort();
53        names
54    }
55
56    /// Number of registered servers.
57    pub fn server_count(&self) -> usize {
58        self.clients.len()
59    }
60}
61
62impl Default for RouterDispatcher {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68#[async_trait::async_trait]
69impl ToolDispatcher for RouterDispatcher {
70    #[tracing::instrument(skip(self, args))]
71    async fn call_tool(
72        &self,
73        server: &str,
74        tool: &str,
75        args: Value,
76    ) -> Result<Value, DispatchError> {
77        let client = self
78            .clients
79            .get(server)
80            .ok_or_else(|| DispatchError::ServerNotFound(server.into()))?;
81
82        // Pre-dispatch tool name validation: if we know the server's tools,
83        // check the tool exists before sending to the upstream server.
84        if let Some(tools) = self.known_tools.get(server) {
85            if !tools.is_empty() && !tools.contains(tool) {
86                return Err(DispatchError::ToolNotFound {
87                    server: server.into(),
88                    tool: tool.into(),
89                });
90            }
91        }
92
93        client.call_tool(server, tool, args).await
94    }
95}
96
97/// A [`ResourceDispatcher`] that routes `read_resource(server, uri)` to the
98/// correct downstream MCP client based on server name.
99pub struct RouterResourceDispatcher {
100    clients: HashMap<String, Arc<dyn ResourceDispatcher>>,
101}
102
103impl RouterResourceDispatcher {
104    /// Create a new empty resource router.
105    pub fn new() -> Self {
106        Self {
107            clients: HashMap::new(),
108        }
109    }
110
111    /// Register a resource dispatcher for a named server.
112    pub fn add_client(&mut self, name: impl Into<String>, client: Arc<dyn ResourceDispatcher>) {
113        self.clients.insert(name.into(), client);
114    }
115
116    /// List all registered server names.
117    pub fn server_names(&self) -> Vec<&str> {
118        let mut names: Vec<&str> = self.clients.keys().map(|s| s.as_str()).collect();
119        names.sort();
120        names
121    }
122}
123
124impl Default for RouterResourceDispatcher {
125    fn default() -> Self {
126        Self::new()
127    }
128}
129
130#[async_trait::async_trait]
131impl ResourceDispatcher for RouterResourceDispatcher {
132    #[tracing::instrument(skip(self), fields(server, uri))]
133    async fn read_resource(&self, server: &str, uri: &str) -> Result<Value, DispatchError> {
134        let client = self
135            .clients
136            .get(server)
137            .ok_or_else(|| DispatchError::ServerNotFound(server.into()))?;
138        client.read_resource(server, uri).await
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// A mock dispatcher that records calls and returns a fixed response.
    struct MockDispatcher {
        name: String,
        calls: Mutex<Vec<(String, String, Value)>>,
    }

    impl MockDispatcher {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                calls: Mutex::new(Vec::new()),
            }
        }

        fn call_count(&self) -> usize {
            self.calls.lock().unwrap().len()
        }
    }

    #[async_trait::async_trait]
    impl ToolDispatcher for MockDispatcher {
        async fn call_tool(
            &self,
            server: &str,
            tool: &str,
            args: Value,
        ) -> Result<Value, DispatchError> {
            // `args` is owned and not used again, so move it in instead of cloning.
            self.calls
                .lock()
                .unwrap()
                .push((server.to_string(), tool.to_string(), args));
            Ok(serde_json::json!({
                "dispatcher": self.name,
                "server": server,
                "tool": tool,
                "status": "ok"
            }))
        }
    }

    /// A dispatcher that always fails.
    struct FailingDispatcher;

    #[async_trait::async_trait]
    impl ToolDispatcher for FailingDispatcher {
        async fn call_tool(
            &self,
            _server: &str,
            _tool: &str,
            _args: Value,
        ) -> Result<Value, DispatchError> {
            Err(DispatchError::Internal(anyhow::anyhow!(
                "downstream connection failed"
            )))
        }
    }

    #[tokio::test]
    async fn router_dispatches_to_correct_server() {
        let client_a = Arc::new(MockDispatcher::new("client-a"));
        let client_b = Arc::new(MockDispatcher::new("client-b"));

        let mut router = RouterDispatcher::new();
        router.add_client("server-a", client_a.clone());
        router.add_client("server-b", client_b.clone());

        // Call server-a
        let result = router
            .call_tool("server-a", "tool1", serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result["dispatcher"], "client-a");
        assert_eq!(result["tool"], "tool1");

        // Call server-b
        let result = router
            .call_tool("server-b", "tool2", serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result["dispatcher"], "client-b");
        assert_eq!(result["tool"], "tool2");

        // Each client received exactly one call
        assert_eq!(client_a.call_count(), 1);
        assert_eq!(client_b.call_count(), 1);
    }

    #[tokio::test]
    async fn router_returns_error_for_unknown_server() {
        let mut router = RouterDispatcher::new();
        router.add_client("known", Arc::new(MockDispatcher::new("known")));

        let result = router
            .call_tool("nonexistent", "tool", serde_json::json!({}))
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, DispatchError::ServerNotFound(ref s) if s == "nonexistent"),
            "expected ServerNotFound, got: {err}"
        );
    }

    #[tokio::test]
    async fn router_handles_concurrent_calls_to_same_server() {
        let client = Arc::new(MockDispatcher::new("shared"));
        let mut router = RouterDispatcher::new();
        router.add_client("server", client.clone());

        let router = Arc::new(router);
        let mut handles = Vec::new();

        for i in 0..10 {
            let router = router.clone();
            handles.push(tokio::spawn(async move {
                router
                    .call_tool("server", &format!("tool-{i}"), serde_json::json!({"i": i}))
                    .await
            }));
        }

        for handle in handles {
            let result = handle.await.unwrap();
            assert!(result.is_ok(), "concurrent call should succeed");
        }

        assert_eq!(client.call_count(), 10, "all 10 calls should be recorded");
    }

    #[tokio::test]
    async fn router_handles_client_failure_gracefully() {
        let healthy = Arc::new(MockDispatcher::new("healthy"));
        let failing: Arc<dyn ToolDispatcher> = Arc::new(FailingDispatcher);

        let mut router = RouterDispatcher::new();
        router.add_client("healthy-server", healthy.clone());
        router.add_client("failing-server", failing);

        // Failing server returns error
        let result = router
            .call_tool("failing-server", "tool", serde_json::json!({}))
            .await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("downstream connection failed"));

        // Healthy server still works
        let result = router
            .call_tool("healthy-server", "tool", serde_json::json!({}))
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap()["dispatcher"], "healthy");
    }

    #[test]
    fn router_server_names_is_sorted() {
        let mut router = RouterDispatcher::new();
        router.add_client("zebra", Arc::new(MockDispatcher::new("z")));
        router.add_client("alpha", Arc::new(MockDispatcher::new("a")));
        router.add_client("middle", Arc::new(MockDispatcher::new("m")));

        assert_eq!(router.server_names(), vec!["alpha", "middle", "zebra"]);
    }

    #[test]
    fn router_server_count() {
        let mut router = RouterDispatcher::new();
        assert_eq!(router.server_count(), 0);

        router.add_client("a", Arc::new(MockDispatcher::new("a")));
        router.add_client("b", Arc::new(MockDispatcher::new("b")));
        assert_eq!(router.server_count(), 2);
    }

    #[tokio::test]
    async fn router_empty_returns_error() {
        let router = RouterDispatcher::new();
        let result = router.call_tool("any", "tool", serde_json::json!({})).await;
        assert!(matches!(result, Err(DispatchError::ServerNotFound(_))));
    }

    #[tokio::test]
    async fn router_returns_tool_not_found_for_unknown_tool() {
        let mut router = RouterDispatcher::new();
        router.set_known_tools("server", vec!["tool_a".into(), "tool_b".into()]);
        router.add_client("server", Arc::new(MockDispatcher::new("server")));

        // Known tool works
        let result = router
            .call_tool("server", "tool_a", serde_json::json!({}))
            .await;
        assert!(result.is_ok(), "known tool should succeed");

        // Unknown tool returns ToolNotFound
        let result = router
            .call_tool("server", "tool_x", serde_json::json!({}))
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, DispatchError::ToolNotFound { ref server, ref tool }
                if server == "server" && tool == "tool_x"),
            "expected ToolNotFound, got: {err}"
        );
    }

    #[tokio::test]
    async fn router_rejects_unknown_tool_without_reaching_client() {
        let client = Arc::new(MockDispatcher::new("server"));
        let mut router = RouterDispatcher::new();
        // Client first, tools second: the registration order must not matter.
        router.add_client("server", client.clone());
        router.set_known_tools("server", vec!["tool_a".to_string()]);

        let result = router
            .call_tool("server", "tool_x", serde_json::json!({}))
            .await;
        assert!(matches!(result, Err(DispatchError::ToolNotFound { .. })));
        assert_eq!(
            client.call_count(),
            0,
            "rejected call must not be forwarded downstream"
        );

        let result = router
            .call_tool("server", "tool_a", serde_json::json!({}))
            .await;
        assert!(result.is_ok(), "known tool should succeed");
        assert_eq!(client.call_count(), 1);
    }

    #[tokio::test]
    async fn router_keeps_known_tools_when_client_is_re_registered() {
        let mut router = RouterDispatcher::new();
        router.add_client("server", Arc::new(MockDispatcher::new("first")));
        router.set_known_tools("server", vec!["tool_a".to_string()]);
        // Replacing the client must not wipe the registered tool list.
        router.add_client("server", Arc::new(MockDispatcher::new("second")));

        let result = router
            .call_tool("server", "tool_x", serde_json::json!({}))
            .await;
        assert!(matches!(result, Err(DispatchError::ToolNotFound { .. })));

        let result = router
            .call_tool("server", "tool_a", serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result["dispatcher"], "second");
        assert_eq!(router.server_count(), 1);
    }

    #[tokio::test]
    async fn router_skips_tool_validation_when_no_tools_registered() {
        let mut router = RouterDispatcher::new();
        // No set_known_tools call — tools list is empty
        router.add_client("server", Arc::new(MockDispatcher::new("server")));

        // Should pass through to the client even though tool name is unknown
        let result = router
            .call_tool("server", "anything", serde_json::json!({}))
            .await;
        assert!(
            result.is_ok(),
            "should pass through when no tools registered"
        );
    }

    // --- v0.2 Resource Router Tests (RS-C05..RS-C06) ---

    struct MockResourceDispatcher {
        name: String,
    }

    #[async_trait::async_trait]
    impl ResourceDispatcher for MockResourceDispatcher {
        async fn read_resource(&self, server: &str, uri: &str) -> Result<Value, DispatchError> {
            Ok(serde_json::json!({
                "dispatcher": self.name,
                "server": server,
                "uri": uri,
                "content": "mock data"
            }))
        }
    }

    #[tokio::test]
    async fn rs_c05_resource_router_dispatches_to_correct_client() {
        let client_a = Arc::new(MockResourceDispatcher {
            name: "client-a".into(),
        });
        let client_b = Arc::new(MockResourceDispatcher {
            name: "client-b".into(),
        });

        let mut router = RouterResourceDispatcher::new();
        router.add_client("server-a", client_a);
        router.add_client("server-b", client_b);

        let result = router
            .read_resource("server-a", "file:///log")
            .await
            .unwrap();
        assert_eq!(result["dispatcher"], "client-a");

        let result = router
            .read_resource("server-b", "db://table")
            .await
            .unwrap();
        assert_eq!(result["dispatcher"], "client-b");
    }

    #[tokio::test]
    async fn rs_c06_resource_router_returns_error_for_unknown_server() {
        let mut router = RouterResourceDispatcher::new();
        router.add_client(
            "known",
            Arc::new(MockResourceDispatcher {
                name: "known".into(),
            }),
        );

        let result = router.read_resource("nonexistent", "uri").await;
        assert!(matches!(result, Err(DispatchError::ServerNotFound(ref s)) if s == "nonexistent"));
    }

    #[test]
    fn resource_router_server_names_is_sorted() {
        let mut router = RouterResourceDispatcher::new();
        for name in ["zebra", "alpha", "middle"] {
            router.add_client(name, Arc::new(MockResourceDispatcher { name: name.into() }));
        }

        assert_eq!(router.server_names(), vec!["alpha", "middle", "zebra"]);
    }
}