use crate::handler::{handle_mcp_post, handle_sse};
use crate::state::{HasServerInfo, McpState};
use axum::routing::{get, post};
use axum::Router;
use mcpkit_server::ServerHandler;
use tower_http::cors::{Any, CorsLayer};
use tower_http::trace::TraceLayer;
/// Builder that assembles an axum [`Router`] serving an MCP handler over HTTP.
///
/// Create one with `McpRouter::new`, adjust it with the `with_*` and `*_path`
/// methods, and finish with `into_router`.
pub struct McpRouter<H> {
    /// Route for MCP `POST` requests (`/` unless overridden).
    post_path: String,
    /// Route for the SSE `GET` endpoint (`/sse` unless overridden).
    sse_path: String,
    /// When set, `into_router` wraps the routes in a permissive CORS layer.
    enable_cors: bool,
    /// When set, `into_router` wraps the routes in an HTTP trace layer.
    enable_tracing: bool,
    /// Handler state handed to the route handlers through `Router::with_state`.
    state: McpState<H>,
}
impl<H> McpRouter<H>
where
    H: ServerHandler + HasServerInfo + Send + Sync + Clone + 'static,
{
    /// Creates a router builder for `handler` with the default configuration.
    ///
    /// Defaults:
    /// - CORS disabled.
    /// - Request tracing disabled.
    /// - MCP `POST` requests served at `/`.
    /// - The SSE stream served at `/sse`.
    #[must_use]
    pub fn new(handler: H) -> Self {
        Self {
            state: McpState::new(handler),
            enable_cors: false,
            enable_tracing: false,
            post_path: "/".to_string(),
            sse_path: "/sse".to_string(),
        }
    }

    /// Enables a permissive CORS layer that allows any origin, any method and any header.
    #[must_use]
    pub const fn with_cors(mut self) -> Self {
        self.enable_cors = true;
        self
    }

    /// Enables HTTP request tracing via `TraceLayer::new_for_http`.
    #[must_use]
    pub const fn with_tracing(mut self) -> Self {
        self.enable_tracing = true;
        self
    }

    /// Sets the route on which MCP `POST` requests are handled by `handle_mcp_post`.
    ///
    /// A missing leading `/` is added (an empty string becomes `/`).
    /// axum panics when a route path does not start with `/`.
    #[must_use]
    pub fn post_path(mut self, path: impl Into<String>) -> Self {
        self.post_path = Self::normalize_path(path.into());
        self
    }

    /// Sets the route on which `GET` requests are handled by `handle_sse`.
    ///
    /// A missing leading `/` is added (an empty string becomes `/`).
    /// axum panics when a route path does not start with `/`.
    #[must_use]
    pub fn sse_path(mut self, path: impl Into<String>) -> Self {
        self.sse_path = Self::normalize_path(path.into());
        self
    }

    /// Consumes the builder and produces the configured axum [`Router`].
    ///
    /// The router serves `POST` on the configured post path and `GET` (SSE) on the
    /// configured SSE path. Optional CORS and tracing layers wrap both routes.
    ///
    /// # Panics
    ///
    /// Panics if axum's `Router::route` rejects the configured paths. For example,
    /// custom paths may conflict with route patterns axum cannot register together.
    #[must_use]
    pub fn into_router(self) -> Router {
        let mut router = Router::new()
            .route(&self.post_path, post(handle_mcp_post::<H>))
            .route(&self.sse_path, get(handle_sse::<H>))
            .with_state(self.state);

        if self.enable_cors {
            // Wide open on purpose: any origin/method/header.
            // NOTE(review): no `expose_headers` is set, so browser clients cannot read
            // non-safelisted response headers. Confirm the handlers do not rely on any.
            router = router.layer(
                CorsLayer::new()
                    .allow_origin(Any)
                    .allow_methods(Any)
                    .allow_headers(Any),
            );
        }

        if self.enable_tracing {
            // Added last, so it is the outermost layer.
            // Requests answered directly by the CORS layer are traced as well.
            router = router.layer(TraceLayer::new_for_http());
        }

        router
    }

    /// Returns `path` with a guaranteed leading `/`, as required by `Router::route`.
    fn normalize_path(path: String) -> String {
        if path.starts_with('/') {
            path
        } else {
            format!("/{path}")
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use mcpkit_core::capability::{ServerCapabilities, ServerInfo};
    use mcpkit_server::ServerHandler;

    /// Minimal handler used only to instantiate `McpRouter` in tests.
    #[derive(Clone)]
    struct TestHandler;

    impl ServerHandler for TestHandler {
        fn server_info(&self) -> ServerInfo {
            ServerInfo {
                name: "test-server".to_string(),
                version: "1.0.0".to_string(),
                protocol_version: None,
            }
        }

        fn capabilities(&self) -> ServerCapabilities {
            ServerCapabilities::default()
        }
    }

    /// Building with the defaults (no CORS, no tracing, `/` and `/sse`) must not panic.
    #[test]
    fn test_router_default_config() {
        let _router: Router = McpRouter::new(TestHandler).into_router();
    }

    /// Building with every option enabled and custom paths must not panic.
    #[test]
    fn test_router_builder() {
        let _router: Router = McpRouter::new(TestHandler)
            .with_cors()
            .with_tracing()
            .post_path("/api/mcp")
            .sse_path("/api/sse")
            .into_router();
    }
}