use crate::server::streamable_http_server::{
build_mcp_router, make_server_state, StreamableHttpServerConfig,
};
use crate::server::tower_layers::{DnsRebindingLayer, SecurityHeadersLayer};
use crate::server::Server;
use axum::Router;
use std::sync::Arc;
pub use crate::server::tower_layers::AllowedOrigins;
/// Options accepted by [`router_with_config`].
///
/// The derived `Default` leaves `allowed_origins` as `None`, which
/// `router_with_config` resolves to [`AllowedOrigins::localhost`]. The other fields use
/// their own types' `Default` implementations.
#[derive(Default, Debug)]
pub struct RouterConfig {
    /// Origins trusted by the CORS and DNS-rebinding layers and handed to the server
    /// state. `None` means "fall back to [`AllowedOrigins::localhost`]".
    pub allowed_origins: Option<AllowedOrigins>,
    /// Security-headers middleware layered directly around the MCP routes.
    pub security_headers: SecurityHeadersLayer,
    /// Settings for the streamable HTTP server state. Note that `router_with_config`
    /// overwrites its `allowed_origins` field with the resolved origin list.
    pub server_config: StreamableHttpServerConfig,
}
/// Builds the MCP HTTP router for `server` using [`RouterConfig::default`].
///
/// Shorthand for `router_with_config(server, RouterConfig::default())`. Because the
/// default config leaves `allowed_origins` unset, only localhost origins
/// ([`AllowedOrigins::localhost`]) are accepted.
pub fn router(server: Arc<tokio::sync::Mutex<Server>>) -> Router {
    let defaults = RouterConfig::default();
    router_with_config(server, defaults)
}
pub fn router_with_config(server: Arc<tokio::sync::Mutex<Server>>, config: RouterConfig) -> Router {
let allowed = config
.allowed_origins
.unwrap_or_else(AllowedOrigins::localhost);
let mut server_config = config.server_config;
server_config.allowed_origins = Some(allowed.clone());
let state = make_server_state(server, server_config);
let base_router = build_mcp_router(state);
let cors = crate::server::tower_layers::build_mcp_cors_layer(&allowed);
base_router
.layer(config.security_headers)
.layer(DnsRebindingLayer::new(allowed))
.layer(cors)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal named server, shared the way the router expects it.
    fn shared_test_server() -> Arc<tokio::sync::Mutex<Server>> {
        let built = Server::builder()
            .name("test")
            .version("0.1.0")
            .build()
            .unwrap();
        Arc::new(tokio::sync::Mutex::new(built))
    }

    /// Drives `body` to completion on a freshly created Tokio runtime.
    fn with_runtime<F: std::future::Future>(body: F) -> F::Output {
        let runtime = tokio::runtime::Runtime::new().unwrap();
        runtime.block_on(body)
    }

    #[test]
    fn test_router_returns_router() {
        with_runtime(async {
            let _app = router(shared_test_server());
        });
    }

    #[test]
    fn test_router_with_explicit_origins() {
        with_runtime(async {
            let server = shared_test_server();
            let origins = AllowedOrigins::explicit(vec!["https://example.com".to_string()]);
            let config = RouterConfig {
                allowed_origins: Some(origins),
                ..Default::default()
            };
            let _app = router_with_config(server, config);
        });
    }

    #[test]
    fn test_router_config_default() {
        assert!(RouterConfig::default().allowed_origins.is_none());
    }
}