1use std::collections::HashMap;
24use std::sync::Arc;
25
26use crate::config::{ServerRouteConfig, TransportType};
27use crate::transport::{
28 HttpTransport, Message, SseTransport, StdioTransport, Transport, TransportError,
29};
30
/// Errors produced while building a `ServerRouter` or dispatching through it.
#[derive(Debug, thiserror::Error)]
pub enum RouterError {
    /// Neither a prefix route nor a default route matched the requested path.
    /// The payload is the path that failed to match.
    #[error("No route found for path: {0}")]
    NoRoute(String),

    /// A transport could not be created for a route.
    /// The payload is `(server name, human-readable reason)`.
    #[error("Failed to initialize transport for server '{0}': {1}")]
    TransportInit(String, String),

    /// An error reported by a transport; converted automatically from
    /// `TransportError` via `#[from]`.
    #[error("Transport error: {0}")]
    Transport(#[from] TransportError),
}
43
/// One routing entry: a server's route configuration paired with the
/// transport used to exchange messages with that server.
pub struct ServerRoute {
    /// Route settings; the router reads `name`, `path_prefix` and
    /// `strip_prefix` from it.
    pub config: ServerRouteConfig,
    /// Shared handle to the transport that carries this route's messages.
    pub transport: Arc<dyn Transport>,
}
51
/// Dispatches request paths to server routes by path prefix.
pub struct ServerRouter {
    /// Prefix routes, scanned in order by `find_route`; the constructors
    /// sort them longest-prefix-first.
    routes: Vec<ServerRoute>,
    /// Fallback route used when no entry in `routes` matches.
    default_route: Option<ServerRoute>,
}
59
60impl std::fmt::Debug for ServerRouter {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 f.debug_struct("ServerRouter")
63 .field("route_count", &self.routes.len())
64 .field("has_default", &self.default_route.is_some())
65 .finish()
66 }
67}
68
69impl ServerRouter {
70 pub async fn new(configs: Vec<ServerRouteConfig>) -> Result<Self, RouterError> {
75 Self::new_internal(configs, true).await
76 }
77
78 pub async fn new_unchecked(configs: Vec<ServerRouteConfig>) -> Result<Self, RouterError> {
84 Self::new_internal(configs, false).await
85 }
86
87 async fn new_internal(
89 configs: Vec<ServerRouteConfig>,
90 validate_ssrf: bool,
91 ) -> Result<Self, RouterError> {
92 let mut routes = Vec::new();
93
94 for config in configs {
95 let transport = Self::create_transport(&config, validate_ssrf).await?;
96 routes.push(ServerRoute { config, transport });
97 }
98
99 routes.sort_by(|a, b| b.config.path_prefix.len().cmp(&a.config.path_prefix.len()));
101
102 Ok(Self {
103 routes,
104 default_route: None,
105 })
106 }
107
108 async fn create_transport(
110 config: &ServerRouteConfig,
111 validate_ssrf: bool,
112 ) -> Result<Arc<dyn Transport>, RouterError> {
113 match config.transport {
114 TransportType::Stdio => {
115 let command = config.command.as_ref().ok_or_else(|| {
116 RouterError::TransportInit(
117 config.name.clone(),
118 "stdio transport requires 'command'".to_string(),
119 )
120 })?;
121 let transport = StdioTransport::spawn(command, &config.args)
122 .await
123 .map_err(|e| RouterError::TransportInit(config.name.clone(), e.to_string()))?;
124 Ok(Arc::new(transport))
125 }
126 TransportType::Http => {
127 let url = config.url.as_ref().ok_or_else(|| {
128 RouterError::TransportInit(
129 config.name.clone(),
130 "http transport requires 'url'".to_string(),
131 )
132 })?;
133 let transport = if validate_ssrf {
134 HttpTransport::new(url.clone()).await.map_err(|e| {
135 RouterError::TransportInit(config.name.clone(), e.to_string())
136 })?
137 } else {
138 HttpTransport::new_unchecked(url.clone())
139 };
140 Ok(Arc::new(transport))
141 }
142 TransportType::Sse => {
143 let url = config.url.as_ref().ok_or_else(|| {
144 RouterError::TransportInit(
145 config.name.clone(),
146 "sse transport requires 'url'".to_string(),
147 )
148 })?;
149 let transport = if validate_ssrf {
150 SseTransport::connect(url.clone()).await.map_err(|e| {
151 RouterError::TransportInit(config.name.clone(), e.to_string())
152 })?
153 } else {
154 SseTransport::connect_unchecked(url.clone())
155 .await
156 .map_err(|e| {
157 RouterError::TransportInit(config.name.clone(), e.to_string())
158 })?
159 };
160 Ok(Arc::new(transport))
161 }
162 }
163 }
164
165 pub fn with_default(mut self, route: ServerRoute) -> Self {
167 self.default_route = Some(route);
168 self
169 }
170
171 pub fn find_route(&self, path: &str) -> Option<&ServerRoute> {
173 for route in &self.routes {
175 if path.starts_with(&route.config.path_prefix) {
176 return Some(route);
177 }
178 }
179
180 self.default_route.as_ref()
182 }
183
184 pub fn get_transport(&self, path: &str) -> Option<Arc<dyn Transport>> {
186 self.find_route(path).map(|r| r.transport.clone())
187 }
188
189 pub fn get_route_name(&self, path: &str) -> Option<&str> {
191 self.find_route(path).map(|r| r.config.name.as_str())
192 }
193
194 pub fn transform_path(&self, path: &str) -> String {
196 if let Some(route) = self.find_route(path) {
197 if route.config.strip_prefix {
198 return path
199 .strip_prefix(&route.config.path_prefix)
200 .unwrap_or(path)
201 .to_string();
202 }
203 }
204 path.to_string()
205 }
206
207 pub async fn send(&self, path: &str, message: Message) -> Result<(), RouterError> {
209 let route = self
210 .find_route(path)
211 .ok_or_else(|| RouterError::NoRoute(path.to_string()))?;
212
213 route
214 .transport
215 .send(message)
216 .await
217 .map_err(RouterError::from)
218 }
219
220 pub async fn receive(&self, path: &str) -> Result<Message, RouterError> {
222 let route = self
223 .find_route(path)
224 .ok_or_else(|| RouterError::NoRoute(path.to_string()))?;
225
226 route.transport.receive().await.map_err(RouterError::from)
227 }
228
229 pub fn route_names(&self) -> Vec<&str> {
231 self.routes.iter().map(|r| r.config.name.as_str()).collect()
232 }
233
234 pub fn has_routes(&self) -> bool {
236 !self.routes.is_empty() || self.default_route.is_some()
237 }
238
239 pub fn route_count(&self) -> usize {
241 self.routes.len()
242 }
243}
244
245pub struct RouteMatcher {
247 prefixes: HashMap<String, String>,
249}
250
251impl RouteMatcher {
252 pub fn new(routes: &[ServerRouteConfig]) -> Self {
254 let mut prefixes = HashMap::new();
255 for route in routes {
256 prefixes.insert(route.path_prefix.clone(), route.name.clone());
257 }
258 Self { prefixes }
259 }
260
261 pub fn match_path(&self, path: &str) -> Option<&str> {
263 let mut best_match: Option<(&str, &String)> = None;
265 for (prefix, name) in &self.prefixes {
266 if path.starts_with(prefix) {
267 let dominated = match &best_match {
268 Some((best_prefix, _)) => prefix.len() > best_prefix.len(),
269 None => true,
270 };
271 if dominated {
272 best_match = Some((prefix, name));
273 }
274 }
275 }
276 best_match.map(|(_, name)| name.as_str())
277 }
278}
279
// Unit tests for `RouteMatcher`, `RouterError`, config validation and
// `ServerRouter`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::TransportType;

    /// Builds an HTTP route config with the given name, prefix and
    /// `strip_prefix` flag, pointing at a fixed localhost URL.
    fn create_test_route(name: &str, path_prefix: &str, strip: bool) -> ServerRouteConfig {
        ServerRouteConfig {
            name: name.to_string(),
            path_prefix: path_prefix.to_string(),
            transport: TransportType::Http,
            command: None,
            args: vec![],
            url: Some("http://localhost:8080".to_string()),
            strip_prefix: strip,
        }
    }

    /// Paths resolve to the route whose prefix they start with; unrelated
    /// paths resolve to nothing.
    #[test]
    fn test_route_matcher_exact() {
        let routes = vec![
            create_test_route("github", "/github", false),
            create_test_route("filesystem", "/filesystem", false),
        ];
        let matcher = RouteMatcher::new(&routes);

        assert_eq!(matcher.match_path("/github/repos"), Some("github"));
        assert_eq!(matcher.match_path("/filesystem/read"), Some("filesystem"));
        assert_eq!(matcher.match_path("/unknown/path"), None);
    }

    /// With nested prefixes, the longest matching prefix wins.
    #[test]
    fn test_route_matcher_longest_prefix() {
        let routes = vec![
            create_test_route("api", "/api", false),
            create_test_route("api-v2", "/api/v2", false),
        ];
        let matcher = RouteMatcher::new(&routes);

        // The more specific "/api/v2" beats "/api".
        assert_eq!(matcher.match_path("/api/v2/users"), Some("api-v2"));
        // Only "/api" matches here.
        assert_eq!(matcher.match_path("/api/v1/users"), Some("api"));
    }

    /// `validate()` accepts a well-formed route and rejects a prefix without
    /// a leading slash as well as an empty name.
    #[test]
    fn test_config_validation() {
        let valid = create_test_route("test", "/test", false);
        assert!(valid.validate().is_ok());

        let mut invalid = create_test_route("test", "no-slash", false);
        assert!(invalid.validate().is_err());

        // Fix the prefix so that the empty name is the only problem left.
        invalid.path_prefix = "/test".to_string();
        invalid.name = "".to_string();
        assert!(invalid.validate().is_err());
    }

    /// A matcher built from no routes never matches.
    #[test]
    fn test_route_matcher_empty() {
        let routes: Vec<ServerRouteConfig> = vec![];
        let matcher = RouteMatcher::new(&routes);
        assert_eq!(matcher.match_path("/any/path"), None);
    }

    /// A "/" route acts as a catch-all but still loses to longer prefixes.
    #[test]
    fn test_route_matcher_root_path() {
        let routes = vec![
            create_test_route("root", "/", false),
            create_test_route("api", "/api", false),
        ];
        let matcher = RouteMatcher::new(&routes);

        // "/api" is longer than "/", so it wins.
        assert_eq!(matcher.match_path("/api/users"), Some("api"));
        // Only "/" matches.
        assert_eq!(matcher.match_path("/other"), Some("root"));
    }

    /// Pins the plain string-prefix semantics of `match_path`.
    #[test]
    fn test_route_matcher_exact_match() {
        let routes = vec![create_test_route("exact", "/exact", false)];
        let matcher = RouteMatcher::new(&routes);

        assert_eq!(matcher.match_path("/exact"), Some("exact"));
        assert_eq!(matcher.match_path("/exact/sub"), Some("exact"));
        // Matching is not segment-aware: "/exactnot" starts with "/exact".
        assert_eq!(matcher.match_path("/exactnot"), Some("exact"));
        assert_eq!(matcher.match_path("/other"), None);
    }

    /// The `NoRoute` message includes the offending path.
    #[test]
    fn test_router_error_no_route() {
        let err = RouterError::NoRoute("/unknown".to_string());
        let msg = format!("{}", err);
        assert!(msg.contains("/unknown"));
    }

    /// The `TransportInit` message includes both the server name and reason.
    #[test]
    fn test_router_error_transport_init() {
        let err =
            RouterError::TransportInit("server1".to_string(), "connection failed".to_string());
        let msg = format!("{}", err);
        assert!(msg.contains("server1"));
        assert!(msg.contains("connection failed"));
    }

    /// `TransportError` converts into `RouterError::Transport` via `From`.
    #[test]
    fn test_router_error_from_transport() {
        let transport_err = TransportError::Timeout;
        let router_err: RouterError = transport_err.into();
        assert!(matches!(router_err, RouterError::Transport(_)));
    }

    /// A stdio route fails validation without a command and passes once one
    /// is set.
    #[test]
    fn test_config_validation_stdio_missing_command() {
        let mut config = ServerRouteConfig {
            name: "stdio".to_string(),
            path_prefix: "/stdio".to_string(),
            transport: TransportType::Stdio,
            command: None,
            args: vec![],
            url: None,
            strip_prefix: false,
        };
        assert!(config.validate().is_err());

        config.command = Some("node".to_string());
        assert!(config.validate().is_ok());
    }

    /// An HTTP route without a URL fails validation.
    #[test]
    fn test_config_validation_http_missing_url() {
        let config = ServerRouteConfig {
            name: "http".to_string(),
            path_prefix: "/http".to_string(),
            transport: TransportType::Http,
            command: None,
            args: vec![],
            url: None,
            strip_prefix: false,
        };
        assert!(config.validate().is_err());
    }

    /// An SSE route without a URL fails validation.
    #[test]
    fn test_config_validation_sse_missing_url() {
        let config = ServerRouteConfig {
            name: "sse".to_string(),
            path_prefix: "/sse".to_string(),
            transport: TransportType::Sse,
            command: None,
            args: vec![],
            url: None,
            strip_prefix: false,
        };
        assert!(config.validate().is_err());
    }

    /// `ServerRouter::new` reports an unusable URL as `TransportInit`.
    #[test]
    fn test_router_new_validation() {
        let invalid_config = ServerRouteConfig {
            name: "invalid".to_string(),
            path_prefix: "/invalid".to_string(),
            transport: TransportType::Http,
            command: None,
            args: vec![],
            url: Some("not-a-url".to_string()),
            strip_prefix: false,
        };

        // Drive the async constructor from a synchronous test.
        let result = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(ServerRouter::new(vec![invalid_config]));
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            RouterError::TransportInit(_, _)
        ));
    }

    /// Sending through a router with no routes yields `NoRoute`.
    #[test]
    fn test_router_send_no_route() {
        let router = ServerRouter {
            routes: vec![],
            default_route: None,
        };

        let test_message = Message::request(1, "ping", None);
        let result = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(router.send("/unknown", test_message));
        assert!(matches!(result, Err(RouterError::NoRoute(_))));
    }

    /// Receiving through a router with no routes yields `NoRoute`.
    #[test]
    fn test_router_receive_no_route() {
        let router = ServerRouter {
            routes: vec![],
            default_route: None,
        };

        let result = tokio::runtime::Runtime::new()
            .unwrap()
            .block_on(router.receive("/unknown"));
        assert!(matches!(result, Err(RouterError::NoRoute(_))));
    }

    /// `transform_path` strips the prefix only for matching routes that have
    /// `strip_prefix` enabled.
    #[test]
    fn test_router_transform_path() {
        use crate::mocks::MockTransport;
        let mut config = create_test_route("strip", "/strip", true);
        // NOTE(review): redundant, `create_test_route` already set this to true.
        config.strip_prefix = true;

        let router = ServerRouter {
            routes: vec![ServerRoute {
                // NOTE(review): `config` is not used afterwards, so this clone
                // looks unnecessary.
                config: config.clone(),
                transport: Arc::new(MockTransport::new()),
            }],
            default_route: None,
        };

        // Matching route with strip enabled: the prefix is removed.
        assert_eq!(router.transform_path("/strip/foo"), "/foo");

        // No matching route: the path is returned unchanged.
        assert_eq!(router.transform_path("/other/foo"), "/other/foo");

        // Matching route with strip disabled: the path is returned unchanged.
        let config_no_strip = create_test_route("no-strip", "/no-strip", false);
        let router_no_strip = ServerRouter {
            routes: vec![ServerRoute {
                config: config_no_strip,
                transport: Arc::new(MockTransport::new()),
            }],
            default_route: None,
        };
        assert_eq!(
            router_no_strip.transform_path("/no-strip/foo"),
            "/no-strip/foo"
        );
    }

    /// `route_count`, `has_routes` and `route_names` reflect the prefix routes.
    #[test]
    fn test_router_route_count() {
        use crate::mocks::MockTransport;
        let router = ServerRouter {
            routes: vec![
                ServerRoute {
                    config: create_test_route("s1", "/s1", false),
                    transport: Arc::new(MockTransport::new()),
                },
                ServerRoute {
                    config: create_test_route("s2", "/s2", false),
                    transport: Arc::new(MockTransport::new()),
                },
            ],
            default_route: None,
        };

        assert_eq!(router.route_count(), 2);
        assert!(router.has_routes());
        assert_eq!(router.route_names(), vec!["s1", "s2"]);
    }

    /// Prefix routes take precedence; the default route catches the rest.
    #[test]
    fn test_router_with_default_route() {
        use crate::mocks::MockTransport;

        let default_config = create_test_route("default", "/", false);
        let default_route = ServerRoute {
            config: default_config,
            transport: Arc::new(MockTransport::new()),
        };

        let router = ServerRouter {
            routes: vec![ServerRoute {
                config: create_test_route("api", "/api", false),
                transport: Arc::new(MockTransport::new()),
            }],
            default_route: None,
        }
        .with_default(default_route);

        assert!(router.has_routes());

        // A path under "/api" resolves to the prefix route.
        let route = router.find_route("/api/users");
        assert!(route.is_some());
        assert_eq!(route.unwrap().config.name, "api");

        // Anything else falls back to the default route.
        let route = router.find_route("/unknown");
        assert!(route.is_some());
        assert_eq!(route.unwrap().config.name, "default");
    }

    /// `get_route_name` returns the matching route's name, or `None`.
    #[test]
    fn test_router_get_route_name() {
        use crate::mocks::MockTransport;

        let router = ServerRouter {
            routes: vec![ServerRoute {
                config: create_test_route("github", "/github", false),
                transport: Arc::new(MockTransport::new()),
            }],
            default_route: None,
        };

        assert_eq!(router.get_route_name("/github/repos"), Some("github"));
        assert_eq!(router.get_route_name("/unknown"), None);
    }

    /// `get_transport` returns a transport only for matching paths.
    #[test]
    fn test_router_get_transport() {
        use crate::mocks::MockTransport;

        let router = ServerRouter {
            routes: vec![ServerRoute {
                config: create_test_route("test", "/test", false),
                transport: Arc::new(MockTransport::new()),
            }],
            default_route: None,
        };

        assert!(router.get_transport("/test/path").is_some());
        assert!(router.get_transport("/other/path").is_none());
    }

    /// The custom `Debug` impl prints the route count and default-route flag.
    #[test]
    fn test_router_debug_formatting() {
        use crate::mocks::MockTransport;

        let router = ServerRouter {
            routes: vec![ServerRoute {
                config: create_test_route("s1", "/s1", false),
                transport: Arc::new(MockTransport::new()),
            }],
            default_route: None,
        };

        let debug_str = format!("{:?}", router);
        assert!(debug_str.contains("route_count: 1"));
        assert!(debug_str.contains("has_default: false"));
    }

    /// A router with neither prefix routes nor a default route is empty.
    #[test]
    fn test_router_empty_has_no_routes() {
        let router = ServerRouter {
            routes: vec![],
            default_route: None,
        };

        assert!(!router.has_routes());
        assert_eq!(router.route_count(), 0);
        assert!(router.route_names().is_empty());
    }

    /// A default route alone makes `has_routes` true without being counted
    /// by `route_count`.
    #[test]
    fn test_router_empty_with_default_has_routes() {
        use crate::mocks::MockTransport;

        let default_config = create_test_route("default", "/", false);
        let default_route = ServerRoute {
            config: default_config,
            transport: Arc::new(MockTransport::new()),
        };

        let router = ServerRouter {
            routes: vec![],
            default_route: Some(default_route),
        };

        assert!(router.has_routes());
        // `route_count` counts prefix routes only.
        assert_eq!(router.route_count(), 0);
    }
}