use std::time::Duration;
use turbomcp_core::error::McpResult;
use turbomcp_core::handler::McpHandler;
use super::config::{
ConnectionLimits, ProtocolConfig, RateLimitConfig, ServerConfig, ServerConfigBuilder,
};
/// Transport a server listens on.
///
/// Only [`Transport::Stdio`] is always compiled in; every other variant exists
/// only when the cargo feature of the same name (`http`, `websocket`, `tcp`,
/// `unix`) is enabled.
#[derive(Debug, Clone, Default)]
pub enum Transport {
    /// Standard input / standard output. This is the default transport.
    #[default]
    Stdio,
    /// HTTP; `addr` is the address string handed to the HTTP transport.
    #[cfg(feature = "http")]
    Http { addr: String },
    /// WebSocket; `addr` is the address string handed to the WebSocket transport.
    #[cfg(feature = "websocket")]
    WebSocket { addr: String },
    /// Raw TCP; `addr` is the address string handed to the TCP transport.
    #[cfg(feature = "tcp")]
    Tcp { addr: String },
    /// Unix domain socket; `path` is the socket path handed to the Unix transport.
    #[cfg(feature = "unix")]
    Unix { path: String },
}
impl Transport {
#[must_use]
pub fn stdio() -> Self {
Self::Stdio
}
#[cfg(feature = "http")]
#[must_use]
pub fn http(addr: impl Into<String>) -> Self {
Self::Http { addr: addr.into() }
}
#[cfg(feature = "websocket")]
#[must_use]
pub fn websocket(addr: impl Into<String>) -> Self {
Self::WebSocket { addr: addr.into() }
}
#[cfg(feature = "tcp")]
#[must_use]
pub fn tcp(addr: impl Into<String>) -> Self {
Self::Tcp { addr: addr.into() }
}
#[cfg(feature = "unix")]
#[must_use]
pub fn unix(path: impl Into<String>) -> Self {
Self::Unix { path: path.into() }
}
}
/// Fluent builder that pairs an [`McpHandler`] with a [`Transport`] and a
/// server configuration, and then runs it.
#[derive(Debug)]
pub struct ServerBuilder<H>
where
    H: McpHandler,
{
    /// The user-supplied request handler.
    handler: H,
    /// Transport used by `serve`.
    transport: Transport,
    /// Configuration accumulated by the `with_*` methods; built when serving.
    config: ServerConfigBuilder,
    /// Timeout recorded by `with_graceful_shutdown`, if any.
    graceful_shutdown: Option<Duration>,
}
impl<H: McpHandler> ServerBuilder<H> {
pub fn new(handler: H) -> Self {
Self {
handler,
transport: Transport::default(),
config: ServerConfig::builder(),
graceful_shutdown: None,
}
}
#[must_use]
pub fn transport(mut self, transport: Transport) -> Self {
self.transport = transport;
self
}
#[must_use]
pub fn with_rate_limit(mut self, max_requests: u32, window: Duration) -> Self {
self.config = self.config.rate_limit(RateLimitConfig {
max_requests,
window,
per_client: true,
});
self
}
#[must_use]
pub fn with_connection_limit(mut self, max: usize) -> Self {
self.config = self.config.connection_limits(ConnectionLimits {
max_tcp_connections: max,
max_websocket_connections: max,
max_http_concurrent: max,
max_unix_connections: max,
});
self
}
#[must_use]
pub fn with_graceful_shutdown(mut self, timeout: Duration) -> Self {
self.graceful_shutdown = Some(timeout);
self
}
#[must_use]
pub fn with_protocol(mut self, protocol: ProtocolConfig) -> Self {
self.config = self.config.protocol(protocol);
self
}
#[must_use]
pub fn with_max_message_size(mut self, size: usize) -> Self {
self.config = self.config.max_message_size(size);
self
}
#[must_use]
pub fn with_config(mut self, config: ServerConfig) -> Self {
let mut builder = ServerConfig::builder()
.protocol(config.protocol)
.connection_limits(config.connection_limits)
.required_capabilities(config.required_capabilities)
.max_message_size(config.max_message_size);
if let Some(rate_limit) = config.rate_limit {
builder = builder.rate_limit(rate_limit);
}
self.config = builder;
self
}
#[allow(unused_variables)]
pub async fn serve(self) -> McpResult<()> {
let config = self.config.build();
match self.transport {
Transport::Stdio => {
#[cfg(feature = "stdio")]
{
super::transport::stdio::run_with_config(&self.handler, &config).await
}
#[cfg(not(feature = "stdio"))]
{
Err(turbomcp_core::error::McpError::internal(
"STDIO transport not available. Enable the 'stdio' feature.",
))
}
}
#[cfg(feature = "http")]
Transport::Http { addr } => {
super::transport::http::run_with_config(&self.handler, &addr, &config).await
}
#[cfg(feature = "websocket")]
Transport::WebSocket { addr } => {
super::transport::websocket::run_with_config(&self.handler, &addr, &config).await
}
#[cfg(feature = "tcp")]
Transport::Tcp { addr } => {
super::transport::tcp::run_with_config(&self.handler, &addr, &config).await
}
#[cfg(feature = "unix")]
Transport::Unix { path } => {
super::transport::unix::run_with_config(&self.handler, &path, &config).await
}
}
}
#[must_use]
pub fn handler(&self) -> &H {
&self.handler
}
#[must_use]
pub fn into_handler(self) -> H {
self.handler
}
#[cfg(feature = "http")]
pub fn into_axum_router(self) -> axum::Router {
use axum::{Router, routing::post};
use std::sync::Arc;
let config = self.config.build();
let handler = Arc::new(self.handler);
let rate_limiter = config
.rate_limit
.as_ref()
.map(|cfg| Arc::new(crate::config::RateLimiter::new(cfg.clone())));
let session_manager = crate::transport::http::SessionManager::new();
let session_versions = Arc::new(tokio::sync::RwLock::new(std::collections::HashMap::<
String,
turbomcp_core::types::core::ProtocolVersion,
>::new()));
Router::new()
.route("/", post(handle_json_rpc::<H>))
.route("/mcp", post(handle_json_rpc::<H>))
.with_state(AppState {
handler,
rate_limiter,
config: Some(config),
session_manager,
session_versions,
})
}
#[cfg(feature = "http")]
pub fn into_service(
self,
) -> impl tower::Service<
axum::http::Request<axum::body::Body>,
Response = axum::http::Response<axum::body::Body>,
Error = std::convert::Infallible,
Future = impl Future<
Output = Result<axum::http::Response<axum::body::Body>, std::convert::Infallible>,
> + Send,
> + Clone
+ Send {
use tower::ServiceExt;
self.into_axum_router()
.into_service()
.map_err(|e| match e {})
}
}
/// Per-router state that Axum clones into each call of `handle_json_rpc`.
#[cfg(feature = "http")]
#[derive(Clone)]
struct AppState<H>
where
    H: McpHandler,
{
    /// Handler shared by all requests.
    handler: std::sync::Arc<H>,
    /// Present only when a rate limit was configured.
    rate_limiter: Option<std::sync::Arc<crate::config::RateLimiter>>,
    /// Built server configuration forwarded to the request router.
    config: Option<crate::config::ServerConfig>,
    // Constructed by `into_axum_router` but not yet read by `handle_json_rpc`.
    #[allow(dead_code)]
    session_manager: crate::transport::http::SessionManager,
    /// Protocol version negotiated at `initialize`, keyed by the value of the
    /// `mcp-session-id` request header.
    session_versions: std::sync::Arc<
        tokio::sync::RwLock<
            std::collections::HashMap<String, turbomcp_core::types::core::ProtocolVersion>,
        >,
    >,
}
/// Builds a JSON-RPC 2.0 error body (with a `null` id) paired with an HTTP status.
///
/// Used on paths where no request id is available to echo back.
#[cfg(feature = "http")]
fn jsonrpc_error(
    status: axum::http::StatusCode,
    code: i64,
    message: &str,
) -> (axum::http::StatusCode, axum::Json<serde_json::Value>) {
    (
        status,
        axum::Json(serde_json::json!({
            "jsonrpc": "2.0",
            "error": {
                "code": code,
                "message": message
            },
            "id": null
        })),
    )
}

/// Axum handler that routes one JSON-RPC request to the MCP handler.
///
/// For a successful `initialize` request that carries an `mcp-session-id` header,
/// it stores the negotiated `protocolVersion`. Later requests with the same header
/// are routed with that stored version. Requests without a stored version use the
/// server configuration.
///
/// NOTE(review): the rate limiter is checked with `None` (no client key), so the
/// limit is effectively global.
///
/// NOTE(review): `session_versions` is keyed by the client-supplied header, and
/// nothing in this module removes entries, so the map can grow without bound.
/// Consider eviction.
#[cfg(feature = "http")]
async fn handle_json_rpc<H: McpHandler>(
    axum::extract::State(state): axum::extract::State<AppState<H>>,
    headers: axum::http::HeaderMap,
    axum::Json(request): axum::Json<serde_json::Value>,
) -> impl axum::response::IntoResponse {
    use super::context::RequestContext;
    use super::router::{
        parse_request, route_request_versioned, route_request_with_config, serialize_response,
    };
    use axum::http::StatusCode;

    if let Some(ref limiter) = state.rate_limiter
        && !limiter.check(None)
    {
        return jsonrpc_error(StatusCode::TOO_MANY_REQUESTS, -32000, "Rate limit exceeded");
    }

    let session_id = headers
        .get("mcp-session-id")
        .and_then(|v| v.to_str().ok())
        .map(str::to_owned);

    // The router parses from a string, so re-serialize the extracted JSON value.
    let request_str = match serde_json::to_string(&request) {
        Ok(s) => s,
        Err(e) => {
            return jsonrpc_error(
                StatusCode::BAD_REQUEST,
                -32700,
                &format!("Parse error: {e}"),
            );
        }
    };
    let parsed = match parse_request(&request_str) {
        Ok(p) => p,
        Err(e) => {
            return jsonrpc_error(
                StatusCode::BAD_REQUEST,
                -32700,
                &format!("Parse error: {e}"),
            );
        }
    };

    let ctx = RequestContext::http();
    let core_ctx = ctx.to_core_context();
    let response = if parsed.method == "initialize" {
        let resp =
            route_request_with_config(&*state.handler, parsed, &core_ctx, state.config.as_ref())
                .await;
        if resp.result.is_some() {
            let negotiated: Option<turbomcp_core::types::core::ProtocolVersion> = resp
                .result
                .as_ref()
                .and_then(|r| r.get("protocolVersion"))
                .and_then(|v| v.as_str())
                .map(turbomcp_core::types::core::ProtocolVersion::from);
            if let (Some(sid), Some(version)) = (session_id.as_deref(), negotiated) {
                state
                    .session_versions
                    .write()
                    .await
                    .insert(sid.to_owned(), version);
                tracing::debug!(
                    session_id = sid,
                    "Stored negotiated protocol version for BYO Axum session"
                );
            }
        }
        resp
    } else {
        // The read guard is dropped at the end of this statement, before routing awaits.
        let stored_version = match session_id.as_deref() {
            Some(sid) => state.session_versions.read().await.get(sid).cloned(),
            None => None,
        };
        match stored_version {
            Some(version) => {
                route_request_versioned(&*state.handler, parsed, &core_ctx, &version).await
            }
            None => {
                route_request_with_config(&*state.handler, parsed, &core_ctx, state.config.as_ref())
                    .await
            }
        }
    };

    if !response.should_send() {
        return (StatusCode::NO_CONTENT, axum::Json(serde_json::json!(null)));
    }

    match serialize_response(&response) {
        // Report a failed round-trip as an internal error instead of a silent 200 `null`.
        Ok(json_str) => match serde_json::from_str::<serde_json::Value>(&json_str) {
            Ok(value) => (StatusCode::OK, axum::Json(value)),
            Err(e) => jsonrpc_error(
                StatusCode::INTERNAL_SERVER_ERROR,
                -32603,
                &format!("Internal error: {e}"),
            ),
        },
        Err(e) => jsonrpc_error(
            StatusCode::INTERNAL_SERVER_ERROR,
            -32603,
            &format!("Internal error: {e}"),
        ),
    }
}
/// Extension trait that lets any [`McpHandler`] start a [`ServerBuilder`]
/// with `handler.builder()`.
pub trait McpServerExt: McpHandler + Sized {
    /// Moves `self` into a fresh [`ServerBuilder`] (same as `ServerBuilder::new(self)`).
    fn builder(self) -> ServerBuilder<Self> {
        ServerBuilder::<Self>::new(self)
    }
}

/// Blanket implementation: every handler gets `.builder()`.
impl<T> McpServerExt for T where T: McpHandler {}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;
    use turbomcp_core::context::RequestContext as CoreRequestContext;
    use turbomcp_core::error::McpError;
    use turbomcp_types::{
        Prompt, PromptResult, Resource, ResourceResult, ServerInfo, Tool, ToolResult,
    };

    /// Minimal handler: one tool, no resources or prompts.
    #[derive(Clone)]
    struct TestHandler;

    // The trait signatures return `impl Future`, so the async-block bodies are intentional.
    #[allow(clippy::manual_async_fn)]
    impl McpHandler for TestHandler {
        fn server_info(&self) -> ServerInfo {
            ServerInfo::new("test", "1.0.0")
        }
        fn list_tools(&self) -> Vec<Tool> {
            vec![Tool::new("test", "Test tool")]
        }
        fn list_resources(&self) -> Vec<Resource> {
            vec![]
        }
        fn list_prompts(&self) -> Vec<Prompt> {
            vec![]
        }
        fn call_tool<'a>(
            &'a self,
            _name: &'a str,
            _args: Value,
            _ctx: &'a CoreRequestContext,
        ) -> impl std::future::Future<Output = McpResult<ToolResult>> + Send + 'a {
            async { Ok(ToolResult::text("ok")) }
        }
        fn read_resource<'a>(
            &'a self,
            uri: &'a str,
            _ctx: &'a CoreRequestContext,
        ) -> impl std::future::Future<Output = McpResult<ResourceResult>> + Send + 'a {
            let uri = uri.to_string();
            async move { Err(McpError::resource_not_found(&uri)) }
        }
        fn get_prompt<'a>(
            &'a self,
            name: &'a str,
            _args: Option<Value>,
            _ctx: &'a CoreRequestContext,
        ) -> impl std::future::Future<Output = McpResult<PromptResult>> + Send + 'a {
            let name = name.to_string();
            async move { Err(McpError::prompt_not_found(&name)) }
        }
    }

    #[test]
    fn test_transport_default_is_stdio() {
        let transport = Transport::default();
        assert!(matches!(transport, Transport::Stdio));
    }

    #[test]
    fn test_builder_creation() {
        let handler = TestHandler;
        let builder = handler.builder();
        assert!(matches!(builder.transport, Transport::Stdio));
    }

    #[test]
    fn test_builder_transport_selection() {
        let handler = TestHandler;
        let builder = handler.clone().builder().transport(Transport::stdio());
        assert!(matches!(builder.transport, Transport::Stdio));
    }

    #[cfg(feature = "http")]
    #[test]
    fn test_builder_http_transport() {
        let handler = TestHandler;
        let builder = handler.builder().transport(Transport::http("0.0.0.0:8080"));
        assert!(matches!(builder.transport, Transport::Http { .. }));
    }

    #[test]
    fn test_builder_rate_limit() {
        let handler = TestHandler;
        let builder = handler
            .builder()
            .with_rate_limit(100, Duration::from_secs(1));
        let config = builder.config.build();
        assert!(config.rate_limit.is_some());
    }

    #[test]
    fn test_builder_connection_limit() {
        let handler = TestHandler;
        let builder = handler.builder().with_connection_limit(500);
        let config = builder.config.build();
        assert_eq!(config.connection_limits.max_tcp_connections, 500);
        assert_eq!(config.connection_limits.max_websocket_connections, 500);
        assert_eq!(config.connection_limits.max_http_concurrent, 500);
        assert_eq!(config.connection_limits.max_unix_connections, 500);
    }

    #[test]
    fn test_builder_graceful_shutdown() {
        let handler = TestHandler;
        let builder = handler
            .builder()
            .with_graceful_shutdown(Duration::from_secs(30));
        assert_eq!(builder.graceful_shutdown, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_builder_max_message_size() {
        let handler = TestHandler;
        let builder = handler.builder().with_max_message_size(2 * 1024 * 1024);
        let config = builder.config.build();
        assert_eq!(config.max_message_size, 2 * 1024 * 1024);
    }

    #[test]
    fn test_builder_with_config_replaces_earlier_settings() {
        // `with_config` starts from a fresh builder, so the earlier
        // `with_max_message_size` call must not survive.
        let replacement = ServerConfig::builder()
            .max_message_size(1024 * 1024)
            .build();
        let handler = TestHandler;
        let builder = handler
            .builder()
            .with_max_message_size(2 * 1024 * 1024)
            .with_config(replacement);
        let config = builder.config.build();
        assert_eq!(config.max_message_size, 1024 * 1024);
    }

    #[test]
    fn test_builder_into_handler() {
        let handler = TestHandler;
        let builder = handler.builder();
        let recovered = builder.into_handler();
        assert_eq!(recovered.server_info().name, "test");
    }
}