use std::time::Duration;
use turbomcp_core::error::McpResult;
use turbomcp_core::handler::McpHandler;
use super::config::{
ConnectionLimits, OriginValidationConfig, ProtocolConfig, RateLimitConfig, ServerConfig,
ServerConfigBuilder,
};
/// Transport over which [`ServerBuilder::serve`] exposes the server.
///
/// The `Stdio` variant always exists (running it still requires the `stdio` feature).
/// Every other variant exists only when its cargo feature (`http`, `websocket`, `tcp`,
/// `unix`) is enabled.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum Transport {
    /// Standard input/output; this is the default transport.
    #[default]
    Stdio,
    /// HTTP transport.
    #[cfg(feature = "http")]
    Http {
        /// Address string handed to the HTTP transport, e.g. `"0.0.0.0:8080"`.
        addr: String,
    },
    /// WebSocket transport.
    #[cfg(feature = "websocket")]
    WebSocket {
        /// Address string handed to the WebSocket transport.
        addr: String,
    },
    /// Raw TCP transport.
    #[cfg(feature = "tcp")]
    Tcp {
        /// Address string handed to the TCP transport.
        addr: String,
    },
    /// Unix-socket transport.
    #[cfg(feature = "unix")]
    Unix {
        /// Path handed to the Unix transport.
        path: String,
    },
}

impl Transport {
    /// Returns [`Transport::Stdio`] (identical to `Transport::default()`).
    #[must_use]
    pub fn stdio() -> Self {
        Self::Stdio
    }

    /// Returns an HTTP transport for `addr`.
    #[cfg(feature = "http")]
    #[must_use]
    pub fn http(addr: impl Into<String>) -> Self {
        Self::Http { addr: addr.into() }
    }

    /// Returns a WebSocket transport for `addr`.
    #[cfg(feature = "websocket")]
    #[must_use]
    pub fn websocket(addr: impl Into<String>) -> Self {
        Self::WebSocket { addr: addr.into() }
    }

    /// Returns a TCP transport for `addr`.
    #[cfg(feature = "tcp")]
    #[must_use]
    pub fn tcp(addr: impl Into<String>) -> Self {
        Self::Tcp { addr: addr.into() }
    }

    /// Returns a Unix-socket transport for `path`.
    #[cfg(feature = "unix")]
    #[must_use]
    pub fn unix(path: impl Into<String>) -> Self {
        Self::Unix { path: path.into() }
    }
}
/// Fluent builder that combines an [`McpHandler`], a [`Transport`], and server
/// configuration, and finally runs the server with [`ServerBuilder::serve`].
///
/// Usually obtained through [`McpServerExt::builder`] or [`ServerBuilder::new`].
#[derive(Debug)]
pub struct ServerBuilder<H: McpHandler> {
    /// Handler that services MCP requests.
    handler: H,
    /// Transport used by [`ServerBuilder::serve`]; defaults to [`Transport::Stdio`].
    transport: Transport,
    /// Configuration accumulated by the `with_*` methods; built when serving.
    config: ServerConfigBuilder,
    /// Timeout recorded by [`ServerBuilder::with_graceful_shutdown`], if any.
    graceful_shutdown: Option<Duration>,
}

impl<H: McpHandler> ServerBuilder<H> {
    /// Creates a builder for `handler` with [`Transport::Stdio`], a default
    /// [`ServerConfig`] builder, and no graceful-shutdown timeout.
    #[must_use]
    pub fn new(handler: H) -> Self {
        Self {
            handler,
            transport: Transport::default(),
            config: ServerConfig::builder(),
            graceful_shutdown: None,
        }
    }

    /// Selects the transport that [`ServerBuilder::serve`] will run on.
    #[must_use]
    pub fn transport(mut self, transport: Transport) -> Self {
        self.transport = transport;
        self
    }

    /// Configures rate limiting of `max_requests` per `window`, with `per_client`
    /// set to `true`.
    #[must_use]
    pub fn with_rate_limit(mut self, max_requests: u32, window: Duration) -> Self {
        self.config = self.config.rate_limit(RateLimitConfig {
            max_requests,
            window,
            per_client: true,
        });
        self
    }

    /// Registers `origin` as an allowed origin via the config builder's `allow_origin`.
    #[must_use]
    pub fn with_allowed_origin(mut self, origin: impl Into<String>) -> Self {
        self.config = self.config.allow_origin(origin);
        self
    }

    /// Sets the origin-validation configuration.
    #[must_use]
    pub fn with_origin_validation(mut self, config: OriginValidationConfig) -> Self {
        self.config = self.config.origin_validation(config);
        self
    }

    /// Sets whether localhost origins are allowed (forwarded to the config builder).
    #[must_use]
    pub fn allow_localhost_origins(mut self, allow: bool) -> Self {
        self.config = self.config.allow_localhost_origins(allow);
        self
    }

    /// Sets whether any origin is allowed (forwarded to the config builder).
    #[must_use]
    pub fn allow_any_origin(mut self, allow: bool) -> Self {
        self.config = self.config.allow_any_origin(allow);
        self
    }

    /// Applies `max` uniformly to the TCP, WebSocket, HTTP-concurrency, and Unix
    /// connection limits.
    #[must_use]
    pub fn with_connection_limit(mut self, max: usize) -> Self {
        self.config = self.config.connection_limits(ConnectionLimits {
            max_tcp_connections: max,
            max_websocket_connections: max,
            max_http_concurrent: max,
            max_unix_connections: max,
        });
        self
    }

    /// Records a graceful-shutdown timeout.
    ///
    /// The value can be read back with [`Self::graceful_shutdown_timeout`].
    #[must_use]
    pub fn with_graceful_shutdown(mut self, timeout: Duration) -> Self {
        // NOTE(review): `serve` and `into_axum_router` do not currently pass this
        // timeout to any transport; it is only stored here.
        self.graceful_shutdown = Some(timeout);
        self
    }

    /// Returns the timeout set by [`Self::with_graceful_shutdown`], or `None` if unset.
    #[must_use]
    pub fn graceful_shutdown_timeout(&self) -> Option<Duration> {
        self.graceful_shutdown
    }

    /// Sets the protocol configuration.
    #[must_use]
    pub fn with_protocol(mut self, protocol: ProtocolConfig) -> Self {
        self.config = self.config.protocol(protocol);
        self
    }

    /// Sets the maximum message size on the config builder.
    #[must_use]
    pub fn with_max_message_size(mut self, size: usize) -> Self {
        self.config = self.config.max_message_size(size);
        self
    }

    /// Replaces the accumulated configuration with one seeded from `config`.
    ///
    /// A fresh config builder is created, so settings applied earlier through other
    /// `with_*` methods are discarded. The protocol, connection limits, required
    /// capabilities, maximum message size, origin validation, and (when present) rate
    /// limit are copied from `config`.
    #[must_use]
    pub fn with_config(mut self, config: ServerConfig) -> Self {
        // NOTE(review): any other `ServerConfig` fields are not carried over; keep this
        // list in sync with `ServerConfig`.
        let mut builder = ServerConfig::builder()
            .protocol(config.protocol)
            .connection_limits(config.connection_limits)
            .required_capabilities(config.required_capabilities)
            .max_message_size(config.max_message_size)
            .origin_validation(config.origin_validation);
        if let Some(rate_limit) = config.rate_limit {
            builder = builder.rate_limit(rate_limit);
        }
        self.config = builder;
        self
    }

    /// Builds the final [`ServerConfig`] and runs the handler on the selected transport.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the selected transport's `run_with_config`.
    /// If [`Transport::Stdio`] is selected while the `stdio` feature is disabled, this
    /// returns an internal error immediately.
    // `config` is unused when no transport feature is enabled, because every arm that
    // reads it is `cfg`-gated.
    #[allow(unused_variables)]
    pub async fn serve(self) -> McpResult<()> {
        let config = self.config.build();
        match self.transport {
            Transport::Stdio => {
                #[cfg(feature = "stdio")]
                {
                    super::transport::stdio::run_with_config(&self.handler, &config).await
                }
                #[cfg(not(feature = "stdio"))]
                {
                    Err(turbomcp_core::error::McpError::internal(
                        "STDIO transport not available. Enable the 'stdio' feature.",
                    ))
                }
            }
            #[cfg(feature = "http")]
            Transport::Http { addr } => {
                super::transport::http::run_with_config(&self.handler, &addr, &config).await
            }
            #[cfg(feature = "websocket")]
            Transport::WebSocket { addr } => {
                super::transport::websocket::run_with_config(&self.handler, &addr, &config).await
            }
            #[cfg(feature = "tcp")]
            Transport::Tcp { addr } => {
                super::transport::tcp::run_with_config(&self.handler, &addr, &config).await
            }
            #[cfg(feature = "unix")]
            Transport::Unix { path } => {
                super::transport::unix::run_with_config(&self.handler, &path, &config).await
            }
        }
    }

    /// Returns a reference to the wrapped handler.
    #[must_use]
    pub fn handler(&self) -> &H {
        &self.handler
    }

    /// Consumes the builder and returns the handler, discarding all configuration.
    #[must_use]
    pub fn into_handler(self) -> H {
        self.handler
    }

    /// Consumes the builder and returns an axum router serving the handler over HTTP.
    ///
    /// The selected [`Transport`] is ignored. If a rate limit is configured, an
    /// `Arc`-shared `RateLimiter` built from it is passed to the router.
    #[cfg(feature = "http")]
    pub fn into_axum_router(self) -> axum::Router {
        use std::sync::Arc;
        let config = self.config.build();
        let rate_limiter = config
            .rate_limit
            .as_ref()
            .map(|cfg| Arc::new(crate::config::RateLimiter::new(cfg.clone())));
        crate::transport::http::build_router(self.handler, rate_limiter, Some(config))
    }

    /// Like [`Self::into_axum_router`], but returns the router as a cloneable, `Send`
    /// `tower::Service` whose error type is `Infallible`.
    #[cfg(feature = "http")]
    pub fn into_service(
        self,
    ) -> impl tower::Service<
        axum::http::Request<axum::body::Body>,
        Response = axum::http::Response<axum::body::Body>,
        Error = std::convert::Infallible,
        Future = impl Future<
            Output = Result<axum::http::Response<axum::body::Body>, std::convert::Infallible>,
        > + Send,
    > + Clone
    + Send {
        use tower::ServiceExt;
        // The router's error type is uninhabited; the empty match converts it to the
        // declared `Infallible` error type.
        self.into_axum_router()
            .into_service()
            .map_err(|e| match e {})
    }
}
/// Extension trait that starts a [`ServerBuilder`] from any [`McpHandler`].
///
/// Every `McpHandler` gets it through the blanket impl below, so `handler.builder()`
/// works wherever this trait is in scope.
pub trait McpServerExt: McpHandler + Sized {
    /// Consumes the handler and wraps it in a new [`ServerBuilder`]
    /// (the same as calling [`ServerBuilder::new`]).
    fn builder(self) -> ServerBuilder<Self> {
        ServerBuilder::<Self>::new(self)
    }
}

impl<T> McpServerExt for T where T: McpHandler {}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;
    use turbomcp_core::context::RequestContext as CoreRequestContext;
    use turbomcp_core::error::McpError;
    use turbomcp_types::{
        Prompt, PromptResult, Resource, ResourceResult, ServerInfo, Tool, ToolResult,
    };

    /// Minimal handler: one tool answering "ok", no resources, no prompts.
    #[derive(Clone)]
    struct TestHandler;

    // The trait methods spell out `impl Future + Send + 'a` return types, so clippy's
    // suggestion to use `async fn` instead is silenced.
    #[allow(clippy::manual_async_fn)]
    impl McpHandler for TestHandler {
        fn server_info(&self) -> ServerInfo {
            ServerInfo::new("test", "1.0.0")
        }

        fn list_tools(&self) -> Vec<Tool> {
            vec![Tool::new("test", "Test tool")]
        }

        fn list_resources(&self) -> Vec<Resource> {
            vec![]
        }

        fn list_prompts(&self) -> Vec<Prompt> {
            vec![]
        }

        fn call_tool<'a>(
            &'a self,
            _name: &'a str,
            _args: Value,
            _ctx: &'a CoreRequestContext,
        ) -> impl std::future::Future<Output = McpResult<ToolResult>> + Send + 'a {
            async { Ok(ToolResult::text("ok")) }
        }

        fn read_resource<'a>(
            &'a self,
            uri: &'a str,
            _ctx: &'a CoreRequestContext,
        ) -> impl std::future::Future<Output = McpResult<ResourceResult>> + Send + 'a {
            let uri = uri.to_string();
            async move { Err(McpError::resource_not_found(&uri)) }
        }

        fn get_prompt<'a>(
            &'a self,
            name: &'a str,
            _args: Option<Value>,
            _ctx: &'a CoreRequestContext,
        ) -> impl std::future::Future<Output = McpResult<PromptResult>> + Send + 'a {
            let name = name.to_string();
            async move { Err(McpError::prompt_not_found(&name)) }
        }
    }

    #[test]
    fn test_transport_default_is_stdio() {
        let transport = Transport::default();
        assert!(matches!(transport, Transport::Stdio));
    }

    #[test]
    fn test_builder_creation() {
        let handler = TestHandler;
        let builder = handler.builder();
        assert!(matches!(builder.transport, Transport::Stdio));
    }

    #[test]
    fn test_builder_transport_selection() {
        let handler = TestHandler;
        let builder = handler.builder().transport(Transport::stdio());
        assert!(matches!(builder.transport, Transport::Stdio));
    }

    #[cfg(feature = "http")]
    #[test]
    fn test_builder_http_transport() {
        let handler = TestHandler;
        let builder = handler.builder().transport(Transport::http("0.0.0.0:8080"));
        assert!(matches!(builder.transport, Transport::Http { .. }));
    }

    #[test]
    fn test_builder_rate_limit() {
        let handler = TestHandler;
        let builder = handler
            .builder()
            .with_rate_limit(100, Duration::from_secs(1));
        let config = builder.config.build();
        let rate_limit = config
            .rate_limit
            .expect("with_rate_limit should configure a rate limit");
        assert_eq!(rate_limit.max_requests, 100);
        assert_eq!(rate_limit.window, Duration::from_secs(1));
        assert!(rate_limit.per_client);
    }

    #[test]
    fn test_builder_connection_limit() {
        let handler = TestHandler;
        let builder = handler.builder().with_connection_limit(500);
        let config = builder.config.build();
        assert_eq!(config.connection_limits.max_tcp_connections, 500);
        assert_eq!(config.connection_limits.max_websocket_connections, 500);
        assert_eq!(config.connection_limits.max_http_concurrent, 500);
        assert_eq!(config.connection_limits.max_unix_connections, 500);
    }

    #[test]
    fn test_builder_max_message_size() {
        let handler = TestHandler;
        let builder = handler.builder().with_max_message_size(65_536);
        let config = builder.config.build();
        assert_eq!(config.max_message_size, 65_536);
    }

    #[test]
    fn test_builder_with_config_round_trip() {
        let source = ServerConfig::builder()
            .max_message_size(131_072)
            .connection_limits(ConnectionLimits {
                max_tcp_connections: 11,
                max_websocket_connections: 12,
                max_http_concurrent: 13,
                max_unix_connections: 14,
            })
            .rate_limit(RateLimitConfig {
                max_requests: 42,
                window: Duration::from_secs(5),
                per_client: false,
            })
            .build();
        let handler = TestHandler;
        let builder = handler.builder().with_config(source);
        let config = builder.config.build();
        assert_eq!(config.max_message_size, 131_072);
        assert_eq!(config.connection_limits.max_tcp_connections, 11);
        assert_eq!(config.connection_limits.max_websocket_connections, 12);
        assert_eq!(config.connection_limits.max_http_concurrent, 13);
        assert_eq!(config.connection_limits.max_unix_connections, 14);
        let rate_limit = config
            .rate_limit
            .expect("with_config should carry over the rate limit");
        assert_eq!(rate_limit.max_requests, 42);
        assert_eq!(rate_limit.window, Duration::from_secs(5));
        assert!(!rate_limit.per_client);
    }

    #[test]
    fn test_builder_graceful_shutdown() {
        let handler = TestHandler;
        let builder = handler
            .builder()
            .with_graceful_shutdown(Duration::from_secs(30));
        assert_eq!(builder.graceful_shutdown, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_builder_into_handler() {
        let handler = TestHandler;
        let builder = handler.builder();
        let recovered = builder.into_handler();
        assert_eq!(recovered.server_info().name, "test");
    }
}