1use std::collections::{HashMap, HashSet};
4use std::sync::Arc;
5
6use forge_error::DispatchError;
7use forge_sandbox::{ResourceDispatcher, ToolDispatcher};
8use serde_json::Value;
9
10pub struct RouterDispatcher {
17 clients: HashMap<String, Arc<dyn ToolDispatcher>>,
18 known_tools: HashMap<String, HashSet<String>>,
20}
21
22impl RouterDispatcher {
23 pub fn new() -> Self {
25 Self {
26 clients: HashMap::new(),
27 known_tools: HashMap::new(),
28 }
29 }
30
31 pub fn add_client(&mut self, name: impl Into<String>, client: Arc<dyn ToolDispatcher>) {
33 let name = name.into();
34 self.clients.insert(name.clone(), client);
35 self.known_tools.entry(name).or_default();
37 }
38
39 pub fn set_known_tools(
41 &mut self,
42 server: impl Into<String>,
43 tools: impl IntoIterator<Item = String>,
44 ) {
45 self.known_tools
46 .insert(server.into(), tools.into_iter().collect());
47 }
48
49 pub fn server_names(&self) -> Vec<&str> {
51 let mut names: Vec<&str> = self.clients.keys().map(|s| s.as_str()).collect();
52 names.sort();
53 names
54 }
55
56 pub fn server_count(&self) -> usize {
58 self.clients.len()
59 }
60}
61
62impl Default for RouterDispatcher {
63 fn default() -> Self {
64 Self::new()
65 }
66}
67
68#[async_trait::async_trait]
69impl ToolDispatcher for RouterDispatcher {
70 #[tracing::instrument(skip(self, args))]
71 async fn call_tool(
72 &self,
73 server: &str,
74 tool: &str,
75 args: Value,
76 ) -> Result<Value, DispatchError> {
77 let client = self
78 .clients
79 .get(server)
80 .ok_or_else(|| DispatchError::ServerNotFound(server.into()))?;
81
82 if let Some(tools) = self.known_tools.get(server) {
85 if !tools.is_empty() && !tools.contains(tool) {
86 return Err(DispatchError::ToolNotFound {
87 server: server.into(),
88 tool: tool.into(),
89 });
90 }
91 }
92
93 client.call_tool(server, tool, args).await
94 }
95}
96
97pub struct RouterResourceDispatcher {
100 clients: HashMap<String, Arc<dyn ResourceDispatcher>>,
101}
102
103impl RouterResourceDispatcher {
104 pub fn new() -> Self {
106 Self {
107 clients: HashMap::new(),
108 }
109 }
110
111 pub fn add_client(&mut self, name: impl Into<String>, client: Arc<dyn ResourceDispatcher>) {
113 self.clients.insert(name.into(), client);
114 }
115
116 pub fn server_names(&self) -> Vec<&str> {
118 let mut names: Vec<&str> = self.clients.keys().map(|s| s.as_str()).collect();
119 names.sort();
120 names
121 }
122}
123
124impl Default for RouterResourceDispatcher {
125 fn default() -> Self {
126 Self::new()
127 }
128}
129
130#[async_trait::async_trait]
131impl ResourceDispatcher for RouterResourceDispatcher {
132 #[tracing::instrument(skip(self), fields(server, uri))]
133 async fn read_resource(&self, server: &str, uri: &str) -> Result<Value, DispatchError> {
134 let client = self
135 .clients
136 .get(server)
137 .ok_or_else(|| DispatchError::ServerNotFound(server.into()))?;
138 client.read_resource(server, uri).await
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    /// Test double that records every call it receives and echoes the routing
    /// information back in its response.
    struct MockDispatcher {
        name: String,
        calls: Mutex<Vec<(String, String, Value)>>,
    }

    impl MockDispatcher {
        fn new(name: &str) -> Self {
            Self {
                name: name.to_string(),
                calls: Mutex::new(Vec::new()),
            }
        }

        /// Number of `call_tool` invocations recorded so far.
        fn call_count(&self) -> usize {
            self.calls.lock().unwrap().len()
        }
    }

    #[async_trait::async_trait]
    impl ToolDispatcher for MockDispatcher {
        async fn call_tool(
            &self,
            server: &str,
            tool: &str,
            args: Value,
        ) -> Result<Value, DispatchError> {
            // `args` is owned and not used afterwards, so move it into the log
            // instead of cloning it. The guard is dropped at the end of this
            // statement, before anything else runs.
            self.calls
                .lock()
                .unwrap()
                .push((server.to_string(), tool.to_string(), args));
            Ok(serde_json::json!({
                "dispatcher": self.name,
                "server": server,
                "tool": tool,
                "status": "ok"
            }))
        }
    }

    /// Test double whose every call fails with an internal error.
    struct FailingDispatcher;

    #[async_trait::async_trait]
    impl ToolDispatcher for FailingDispatcher {
        async fn call_tool(
            &self,
            _server: &str,
            _tool: &str,
            _args: Value,
        ) -> Result<Value, DispatchError> {
            Err(DispatchError::Internal(anyhow::anyhow!(
                "downstream connection failed"
            )))
        }
    }

    #[tokio::test]
    async fn router_dispatches_to_correct_server() {
        let client_a = Arc::new(MockDispatcher::new("client-a"));
        let client_b = Arc::new(MockDispatcher::new("client-b"));

        let mut router = RouterDispatcher::new();
        router.add_client("server-a", client_a.clone());
        router.add_client("server-b", client_b.clone());

        let result = router
            .call_tool("server-a", "tool1", serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result["dispatcher"], "client-a");
        assert_eq!(result["tool"], "tool1");

        let result = router
            .call_tool("server-b", "tool2", serde_json::json!({}))
            .await
            .unwrap();
        assert_eq!(result["dispatcher"], "client-b");
        assert_eq!(result["tool"], "tool2");

        assert_eq!(client_a.call_count(), 1);
        assert_eq!(client_b.call_count(), 1);
    }

    #[tokio::test]
    async fn router_returns_error_for_unknown_server() {
        let mut router = RouterDispatcher::new();
        router.add_client("known", Arc::new(MockDispatcher::new("known")));

        let result = router
            .call_tool("nonexistent", "tool", serde_json::json!({}))
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, DispatchError::ServerNotFound(ref s) if s == "nonexistent"),
            "expected ServerNotFound, got: {err}"
        );
    }

    #[tokio::test]
    async fn router_handles_concurrent_calls_to_same_server() {
        let client = Arc::new(MockDispatcher::new("shared"));
        let mut router = RouterDispatcher::new();
        router.add_client("server", client.clone());

        let router = Arc::new(router);
        let mut handles = Vec::new();

        for i in 0..10 {
            let router = router.clone();
            handles.push(tokio::spawn(async move {
                router
                    .call_tool("server", &format!("tool-{i}"), serde_json::json!({"i": i}))
                    .await
            }));
        }

        for handle in handles {
            let result = handle.await.unwrap();
            assert!(result.is_ok(), "concurrent call should succeed");
        }

        assert_eq!(client.call_count(), 10, "all 10 calls should be recorded");
    }

    #[tokio::test]
    async fn router_handles_client_failure_gracefully() {
        let healthy = Arc::new(MockDispatcher::new("healthy"));
        let failing: Arc<dyn ToolDispatcher> = Arc::new(FailingDispatcher);

        let mut router = RouterDispatcher::new();
        router.add_client("healthy-server", healthy.clone());
        router.add_client("failing-server", failing);

        let result = router
            .call_tool("failing-server", "tool", serde_json::json!({}))
            .await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("downstream connection failed"));

        let result = router
            .call_tool("healthy-server", "tool", serde_json::json!({}))
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap()["dispatcher"], "healthy");
    }

    #[test]
    fn router_server_names_is_sorted() {
        let mut router = RouterDispatcher::new();
        router.add_client("zebra", Arc::new(MockDispatcher::new("z")));
        router.add_client("alpha", Arc::new(MockDispatcher::new("a")));
        router.add_client("middle", Arc::new(MockDispatcher::new("m")));

        assert_eq!(router.server_names(), vec!["alpha", "middle", "zebra"]);
    }

    #[test]
    fn router_server_count() {
        let mut router = RouterDispatcher::new();
        assert_eq!(router.server_count(), 0);

        router.add_client("a", Arc::new(MockDispatcher::new("a")));
        router.add_client("b", Arc::new(MockDispatcher::new("b")));
        assert_eq!(router.server_count(), 2);
    }

    #[tokio::test]
    async fn router_empty_returns_error() {
        let router = RouterDispatcher::new();
        let result = router.call_tool("any", "tool", serde_json::json!({})).await;
        assert!(matches!(result, Err(DispatchError::ServerNotFound(_))));
    }

    #[tokio::test]
    async fn router_returns_tool_not_found_for_unknown_tool() {
        let mut router = RouterDispatcher::new();
        router.set_known_tools("server", vec!["tool_a".into(), "tool_b".into()]);
        router.add_client("server", Arc::new(MockDispatcher::new("server")));

        let result = router
            .call_tool("server", "tool_a", serde_json::json!({}))
            .await;
        assert!(result.is_ok(), "known tool should succeed");

        let result = router
            .call_tool("server", "tool_x", serde_json::json!({}))
            .await;
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(
            matches!(err, DispatchError::ToolNotFound { ref server, ref tool }
                if server == "server" && tool == "tool_x"),
            "expected ToolNotFound, got: {err}"
        );
    }

    #[tokio::test]
    async fn router_skips_tool_validation_when_no_tools_registered() {
        let mut router = RouterDispatcher::new();
        router.add_client("server", Arc::new(MockDispatcher::new("server")));

        let result = router
            .call_tool("server", "anything", serde_json::json!({}))
            .await;
        assert!(
            result.is_ok(),
            "should pass through when no tools registered"
        );
    }

    #[tokio::test]
    async fn router_empty_known_tools_disables_validation() {
        let mut router = RouterDispatcher::new();
        router.add_client("server", Arc::new(MockDispatcher::new("server")));
        // An explicitly empty allow-list behaves like no allow-list at all.
        router.set_known_tools("server", Vec::<String>::new());

        let result = router
            .call_tool("server", "anything", serde_json::json!({}))
            .await;
        assert!(result.is_ok(), "an empty allow-list should not reject tools");
    }

    #[tokio::test]
    async fn router_reports_missing_server_before_tool_validation() {
        let mut router = RouterDispatcher::new();
        // An allow-list without a registered client must not mask the missing server.
        router.set_known_tools("ghost", vec!["tool_a".to_string()]);

        let result = router
            .call_tool("ghost", "tool_x", serde_json::json!({}))
            .await;
        assert!(matches!(result, Err(DispatchError::ServerNotFound(ref s)) if s == "ghost"));
    }

    /// Resource test double that echoes the routing information back.
    struct MockResourceDispatcher {
        name: String,
    }

    #[async_trait::async_trait]
    impl ResourceDispatcher for MockResourceDispatcher {
        async fn read_resource(&self, server: &str, uri: &str) -> Result<Value, DispatchError> {
            Ok(serde_json::json!({
                "dispatcher": self.name,
                "server": server,
                "uri": uri,
                "content": "mock data"
            }))
        }
    }

    #[tokio::test]
    async fn rs_c05_resource_router_dispatches_to_correct_client() {
        let client_a = Arc::new(MockResourceDispatcher {
            name: "client-a".into(),
        });
        let client_b = Arc::new(MockResourceDispatcher {
            name: "client-b".into(),
        });

        let mut router = RouterResourceDispatcher::new();
        router.add_client("server-a", client_a);
        router.add_client("server-b", client_b);

        let result = router
            .read_resource("server-a", "file:///log")
            .await
            .unwrap();
        assert_eq!(result["dispatcher"], "client-a");

        let result = router
            .read_resource("server-b", "db://table")
            .await
            .unwrap();
        assert_eq!(result["dispatcher"], "client-b");
    }

    #[tokio::test]
    async fn rs_c06_resource_router_returns_error_for_unknown_server() {
        let mut router = RouterResourceDispatcher::new();
        router.add_client(
            "known",
            Arc::new(MockResourceDispatcher {
                name: "known".into(),
            }),
        );

        let result = router.read_resource("nonexistent", "uri").await;
        assert!(matches!(result, Err(DispatchError::ServerNotFound(ref s)) if s == "nonexistent"));
    }

    #[test]
    fn resource_router_server_names_is_sorted() {
        let mut router = RouterResourceDispatcher::new();
        router.add_client(
            "zebra",
            Arc::new(MockResourceDispatcher { name: "z".into() }),
        );
        router.add_client(
            "alpha",
            Arc::new(MockResourceDispatcher { name: "a".into() }),
        );

        assert_eq!(router.server_names(), vec!["alpha", "zebra"]);
    }
}