mcp_guard_core/router/
mod.rs

// Copyright (c) 2025 Austin Green
// SPDX-License-Identifier: AGPL-3.0
//
// This file is part of MCP-Guard.
//
// MCP-Guard is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// MCP-Guard is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with MCP-Guard. If not, see <https://www.gnu.org/licenses/>.
//! Multi-server routing for mcp-guard
//!
//! Routes requests to different upstream MCP servers based on path prefix.
//! This enables organizations to run multiple MCP servers behind a single gateway.
22
23use std::collections::HashMap;
24use std::sync::Arc;
25
26use crate::config::{ServerRouteConfig, TransportType};
27use crate::transport::{
28    HttpTransport, Message, SseTransport, StdioTransport, Transport, TransportError,
29};
30
31/// Router error types
32#[derive(Debug, thiserror::Error)]
33pub enum RouterError {
34    #[error("No route found for path: {0}")]
35    NoRoute(String),
36
37    #[error("Failed to initialize transport for server '{0}': {1}")]
38    TransportInit(String, String),
39
40    #[error("Transport error: {0}")]
41    Transport(#[from] TransportError),
42}
43
44/// Server route with initialized transport
45pub struct ServerRoute {
46    /// Route configuration
47    pub config: ServerRouteConfig,
48    /// Initialized transport
49    pub transport: Arc<dyn Transport>,
50}
51
52/// Multi-server router that routes requests to different upstreams based on path
53pub struct ServerRouter {
54    /// Routes indexed by path prefix (sorted by specificity)
55    routes: Vec<ServerRoute>,
56    /// Default route (optional, used when no path prefix matches)
57    default_route: Option<ServerRoute>,
58}
59
60impl std::fmt::Debug for ServerRouter {
61    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62        f.debug_struct("ServerRouter")
63            .field("route_count", &self.routes.len())
64            .field("has_default", &self.default_route.is_some())
65            .finish()
66    }
67}
68
69impl ServerRouter {
70    /// Create a new server router from configuration
71    ///
72    /// This performs SSRF validation on HTTP/SSE URLs. Use `new_unchecked` to bypass
73    /// SSRF validation for trusted configurations (e.g., in tests).
74    pub async fn new(configs: Vec<ServerRouteConfig>) -> Result<Self, RouterError> {
75        Self::new_internal(configs, true).await
76    }
77
78    /// Create a new server router without SSRF validation
79    ///
80    /// # Safety
81    /// This bypasses SSRF protection. Only use when URLs are from a trusted source
82    /// (e.g., hardcoded in the application) or when connecting to localhost for testing.
83    pub async fn new_unchecked(configs: Vec<ServerRouteConfig>) -> Result<Self, RouterError> {
84        Self::new_internal(configs, false).await
85    }
86
87    /// Internal constructor with configurable SSRF validation
88    async fn new_internal(
89        configs: Vec<ServerRouteConfig>,
90        validate_ssrf: bool,
91    ) -> Result<Self, RouterError> {
92        let mut routes = Vec::new();
93
94        for config in configs {
95            let transport = Self::create_transport(&config, validate_ssrf).await?;
96            routes.push(ServerRoute { config, transport });
97        }
98
99        // Sort routes by path prefix length (longer = more specific = higher priority)
100        routes.sort_by(|a, b| b.config.path_prefix.len().cmp(&a.config.path_prefix.len()));
101
102        Ok(Self {
103            routes,
104            default_route: None,
105        })
106    }
107
108    /// Create a transport from server route configuration
109    async fn create_transport(
110        config: &ServerRouteConfig,
111        validate_ssrf: bool,
112    ) -> Result<Arc<dyn Transport>, RouterError> {
113        match config.transport {
114            TransportType::Stdio => {
115                let command = config.command.as_ref().ok_or_else(|| {
116                    RouterError::TransportInit(
117                        config.name.clone(),
118                        "stdio transport requires 'command'".to_string(),
119                    )
120                })?;
121                let transport = StdioTransport::spawn(command, &config.args)
122                    .await
123                    .map_err(|e| RouterError::TransportInit(config.name.clone(), e.to_string()))?;
124                Ok(Arc::new(transport))
125            }
126            TransportType::Http => {
127                let url = config.url.as_ref().ok_or_else(|| {
128                    RouterError::TransportInit(
129                        config.name.clone(),
130                        "http transport requires 'url'".to_string(),
131                    )
132                })?;
133                let transport = if validate_ssrf {
134                    HttpTransport::new(url.clone()).await.map_err(|e| {
135                        RouterError::TransportInit(config.name.clone(), e.to_string())
136                    })?
137                } else {
138                    HttpTransport::new_unchecked(url.clone())
139                };
140                Ok(Arc::new(transport))
141            }
142            TransportType::Sse => {
143                let url = config.url.as_ref().ok_or_else(|| {
144                    RouterError::TransportInit(
145                        config.name.clone(),
146                        "sse transport requires 'url'".to_string(),
147                    )
148                })?;
149                let transport = if validate_ssrf {
150                    SseTransport::connect(url.clone()).await.map_err(|e| {
151                        RouterError::TransportInit(config.name.clone(), e.to_string())
152                    })?
153                } else {
154                    SseTransport::connect_unchecked(url.clone())
155                        .await
156                        .map_err(|e| {
157                            RouterError::TransportInit(config.name.clone(), e.to_string())
158                        })?
159                };
160                Ok(Arc::new(transport))
161            }
162        }
163    }
164
165    /// Set a default route for unmatched requests
166    pub fn with_default(mut self, route: ServerRoute) -> Self {
167        self.default_route = Some(route);
168        self
169    }
170
171    /// Find the route for a given path
172    pub fn find_route(&self, path: &str) -> Option<&ServerRoute> {
173        // Try to match a specific route first
174        for route in &self.routes {
175            if path.starts_with(&route.config.path_prefix) {
176                return Some(route);
177            }
178        }
179
180        // Fall back to default route
181        self.default_route.as_ref()
182    }
183
184    /// Get the transport for a given path
185    pub fn get_transport(&self, path: &str) -> Option<Arc<dyn Transport>> {
186        self.find_route(path).map(|r| r.transport.clone())
187    }
188
189    /// Get the route name for a given path (for logging/metrics)
190    pub fn get_route_name(&self, path: &str) -> Option<&str> {
191        self.find_route(path).map(|r| r.config.name.as_str())
192    }
193
194    /// Transform the path if strip_prefix is enabled for the route
195    pub fn transform_path(&self, path: &str) -> String {
196        if let Some(route) = self.find_route(path) {
197            if route.config.strip_prefix {
198                return path
199                    .strip_prefix(&route.config.path_prefix)
200                    .unwrap_or(path)
201                    .to_string();
202            }
203        }
204        path.to_string()
205    }
206
207    /// Send a message to the appropriate server based on path
208    pub async fn send(&self, path: &str, message: Message) -> Result<(), RouterError> {
209        let route = self
210            .find_route(path)
211            .ok_or_else(|| RouterError::NoRoute(path.to_string()))?;
212
213        route
214            .transport
215            .send(message)
216            .await
217            .map_err(RouterError::from)
218    }
219
220    /// Receive a message from the appropriate server based on path
221    pub async fn receive(&self, path: &str) -> Result<Message, RouterError> {
222        let route = self
223            .find_route(path)
224            .ok_or_else(|| RouterError::NoRoute(path.to_string()))?;
225
226        route.transport.receive().await.map_err(RouterError::from)
227    }
228
229    /// Get all route names for metrics/debugging
230    pub fn route_names(&self) -> Vec<&str> {
231        self.routes.iter().map(|r| r.config.name.as_str()).collect()
232    }
233
234    /// Check if any routes are configured
235    pub fn has_routes(&self) -> bool {
236        !self.routes.is_empty() || self.default_route.is_some()
237    }
238
239    /// Get the number of configured routes
240    pub fn route_count(&self) -> usize {
241        self.routes.len()
242    }
243}
244
245/// Route matcher for extracting server name from path
246pub struct RouteMatcher {
247    /// Map of path prefixes to server names
248    prefixes: HashMap<String, String>,
249}
250
251impl RouteMatcher {
252    /// Create a new route matcher from server routes
253    pub fn new(routes: &[ServerRouteConfig]) -> Self {
254        let mut prefixes = HashMap::new();
255        for route in routes {
256            prefixes.insert(route.path_prefix.clone(), route.name.clone());
257        }
258        Self { prefixes }
259    }
260
261    /// Match a path to a server name
262    pub fn match_path(&self, path: &str) -> Option<&str> {
263        // Find the longest matching prefix
264        let mut best_match: Option<(&str, &String)> = None;
265        for (prefix, name) in &self.prefixes {
266            if path.starts_with(prefix) {
267                let dominated = match &best_match {
268                    Some((best_prefix, _)) => prefix.len() > best_prefix.len(),
269                    None => true,
270                };
271                if dominated {
272                    best_match = Some((prefix, name));
273                }
274            }
275        }
276        best_match.map(|(_, name)| name.as_str())
277    }
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::TransportType;
    use crate::mocks::MockTransport;

    // ------------------------------------------------------------------------
    // Shared fixtures
    // ------------------------------------------------------------------------

    /// HTTP route config pointing at a local URL.
    fn create_test_route(name: &str, path_prefix: &str, strip: bool) -> ServerRouteConfig {
        ServerRouteConfig {
            name: name.to_string(),
            path_prefix: path_prefix.to_string(),
            transport: TransportType::Http,
            command: None,
            args: vec![],
            url: Some("http://localhost:8080".to_string()),
            strip_prefix: strip,
        }
    }

    /// Route config of the given transport type with neither `command` nor `url` set.
    fn bare_route(name: &str, path_prefix: &str, transport: TransportType) -> ServerRouteConfig {
        ServerRouteConfig {
            name: name.to_string(),
            path_prefix: path_prefix.to_string(),
            transport,
            command: None,
            args: vec![],
            url: None,
            strip_prefix: false,
        }
    }

    /// Test route config paired with a mock transport.
    fn mock_route(name: &str, path_prefix: &str, strip: bool) -> ServerRoute {
        ServerRoute {
            config: create_test_route(name, path_prefix, strip),
            transport: Arc::new(MockTransport::new()),
        }
    }

    /// Router holding exactly `routes` (in the given order) and no default route.
    fn router_with(routes: Vec<ServerRoute>) -> ServerRouter {
        ServerRouter {
            routes,
            default_route: None,
        }
    }

    /// Drives `future` to completion on a fresh tokio runtime.
    fn block_on<F: std::future::Future>(future: F) -> F::Output {
        tokio::runtime::Runtime::new()
            .expect("failed to build tokio runtime for test")
            .block_on(future)
    }

    // ------------------------------------------------------------------------
    // RouteMatcher Tests
    // ------------------------------------------------------------------------

    #[test]
    fn test_route_matcher_exact() {
        let routes = vec![
            create_test_route("github", "/github", false),
            create_test_route("filesystem", "/filesystem", false),
        ];
        let matcher = RouteMatcher::new(&routes);

        assert_eq!(matcher.match_path("/github/repos"), Some("github"));
        assert_eq!(matcher.match_path("/filesystem/read"), Some("filesystem"));
        assert_eq!(matcher.match_path("/unknown/path"), None);
    }

    #[test]
    fn test_route_matcher_longest_prefix() {
        let routes = vec![
            create_test_route("api", "/api", false),
            create_test_route("api-v2", "/api/v2", false),
        ];
        let matcher = RouteMatcher::new(&routes);

        // Longer prefix should win
        assert_eq!(matcher.match_path("/api/v2/users"), Some("api-v2"));
        assert_eq!(matcher.match_path("/api/v1/users"), Some("api"));
    }

    #[test]
    fn test_config_validation() {
        let valid = create_test_route("test", "/test", false);
        assert!(valid.validate().is_ok());

        // Prefix without a leading slash is rejected
        let mut invalid = create_test_route("test", "no-slash", false);
        assert!(invalid.validate().is_err());

        // Valid prefix but empty name is rejected as well
        invalid.path_prefix = "/test".to_string();
        invalid.name = "".to_string();
        assert!(invalid.validate().is_err());
    }

    // ------------------------------------------------------------------------
    // Additional RouteMatcher Tests
    // ------------------------------------------------------------------------

    #[test]
    fn test_route_matcher_empty() {
        let matcher = RouteMatcher::new(&[]);
        assert_eq!(matcher.match_path("/any/path"), None);
    }

    #[test]
    fn test_route_matcher_root_path() {
        let routes = vec![
            create_test_route("root", "/", false),
            create_test_route("api", "/api", false),
        ];
        let matcher = RouteMatcher::new(&routes);

        // More specific should win
        assert_eq!(matcher.match_path("/api/users"), Some("api"));
        // Root should match everything else
        assert_eq!(matcher.match_path("/other"), Some("root"));
    }

    #[test]
    fn test_route_matcher_exact_match() {
        let routes = vec![create_test_route("exact", "/exact", false)];
        let matcher = RouteMatcher::new(&routes);

        assert_eq!(matcher.match_path("/exact"), Some("exact"));
        assert_eq!(matcher.match_path("/exact/sub"), Some("exact"));
        // Note: /exactnot starts with /exact, so it matches (prefix-based routing)
        assert_eq!(matcher.match_path("/exactnot"), Some("exact"));
        // This one doesn't match
        assert_eq!(matcher.match_path("/other"), None);
    }

    // ------------------------------------------------------------------------
    // RouterError Tests
    // ------------------------------------------------------------------------

    #[test]
    fn test_router_error_no_route() {
        let msg = RouterError::NoRoute("/unknown".to_string()).to_string();
        assert!(msg.contains("/unknown"));
    }

    #[test]
    fn test_router_error_transport_init() {
        let msg =
            RouterError::TransportInit("server1".to_string(), "connection failed".to_string())
                .to_string();
        assert!(msg.contains("server1"));
        assert!(msg.contains("connection failed"));
    }

    #[test]
    fn test_router_error_from_transport() {
        let router_err: RouterError = TransportError::Timeout.into();
        assert!(matches!(router_err, RouterError::Transport(_)));
    }

    // ------------------------------------------------------------------------
    // ServerRouteConfig Transport Type Tests
    // ------------------------------------------------------------------------

    #[test]
    fn test_config_validation_stdio_missing_command() {
        let mut config = bare_route("stdio", "/stdio", TransportType::Stdio);
        assert!(config.validate().is_err());

        config.command = Some("node".to_string());
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_config_validation_http_missing_url() {
        let config = bare_route("http", "/http", TransportType::Http);
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_config_validation_sse_missing_url() {
        let config = bare_route("sse", "/sse", TransportType::Sse);
        assert!(config.validate().is_err());
    }

    // ------------------------------------------------------------------------
    // Additional Coverage Tests
    // ------------------------------------------------------------------------

    #[test]
    fn test_router_new_validation() {
        // Test with invalid URL scheme to ensure validation runs
        let invalid_config = ServerRouteConfig {
            url: Some("not-a-url".to_string()),
            ..bare_route("invalid", "/invalid", TransportType::Http)
        };

        let result = block_on(ServerRouter::new(vec![invalid_config]));
        assert!(matches!(result, Err(RouterError::TransportInit(_, _))));
    }

    #[test]
    fn test_router_send_no_route() {
        let router = router_with(vec![]);
        let test_message = Message::request(1, "ping", None);

        let result = block_on(router.send("/unknown", test_message));
        assert!(matches!(result, Err(RouterError::NoRoute(_))));
    }

    #[test]
    fn test_router_receive_no_route() {
        let router = router_with(vec![]);

        let result = block_on(router.receive("/unknown"));
        assert!(matches!(result, Err(RouterError::NoRoute(_))));
    }

    #[test]
    fn test_router_transform_path() {
        let router = router_with(vec![mock_route("strip", "/strip", true)]);

        // Should strip prefix
        assert_eq!(router.transform_path("/strip/foo"), "/foo");

        // Should return original if no match
        assert_eq!(router.transform_path("/other/foo"), "/other/foo");

        // Should return original if strip_prefix is false
        let router_no_strip = router_with(vec![mock_route("no-strip", "/no-strip", false)]);
        assert_eq!(
            router_no_strip.transform_path("/no-strip/foo"),
            "/no-strip/foo"
        );
    }

    #[test]
    fn test_router_route_count() {
        let router = router_with(vec![
            mock_route("s1", "/s1", false),
            mock_route("s2", "/s2", false),
        ]);

        assert_eq!(router.route_count(), 2);
        assert!(router.has_routes());
        assert_eq!(router.route_names(), vec!["s1", "s2"]);
    }

    // -------------------------------------------------------------------------
    // Additional Router Tests
    // -------------------------------------------------------------------------

    #[test]
    fn test_router_with_default_route() {
        let router = router_with(vec![mock_route("api", "/api", false)])
            .with_default(mock_route("default", "/", false));

        // Verify default is set
        assert!(router.has_routes());

        // Should find /api route
        let route = router.find_route("/api/users");
        assert!(route.is_some());
        assert_eq!(route.unwrap().config.name, "api");

        // Unknown path should find default
        let route = router.find_route("/unknown");
        assert!(route.is_some());
        assert_eq!(route.unwrap().config.name, "default");
    }

    #[test]
    fn test_router_get_route_name() {
        let router = router_with(vec![mock_route("github", "/github", false)]);

        assert_eq!(router.get_route_name("/github/repos"), Some("github"));
        assert_eq!(router.get_route_name("/unknown"), None);
    }

    #[test]
    fn test_router_get_transport() {
        let router = router_with(vec![mock_route("test", "/test", false)]);

        // Should return transport for matching route
        assert!(router.get_transport("/test/path").is_some());
        // Should return None for non-matching route
        assert!(router.get_transport("/other/path").is_none());
    }

    #[test]
    fn test_router_debug_formatting() {
        let router = router_with(vec![mock_route("s1", "/s1", false)]);

        // Format should include route count and has_default
        let debug_str = format!("{:?}", router);
        assert!(debug_str.contains("route_count: 1"));
        assert!(debug_str.contains("has_default: false"));
    }

    #[test]
    fn test_router_empty_has_no_routes() {
        let router = router_with(vec![]);

        assert!(!router.has_routes());
        assert_eq!(router.route_count(), 0);
        assert!(router.route_names().is_empty());
    }

    #[test]
    fn test_router_empty_with_default_has_routes() {
        let router = ServerRouter {
            routes: vec![],
            default_route: Some(mock_route("default", "/", false)),
        };

        // Empty routes but has default means has_routes is true
        assert!(router.has_routes());
        assert_eq!(router.route_count(), 0); // route_count only counts routes, not default
    }
}