use config::ServerConfig;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{sync::Arc, time::Duration};
use tokio::sync::watch;
use tokio::sync::RwLock;
use tracing::info;
use crate::prompts::{GetPromptRequest, ListPromptsRequest, PromptCapabilities, PromptManager};
use crate::tools::{ToolCapabilities, ToolManager};
use crate::{
client::ServerCapabilities,
error::McpError,
logging::{LoggingCapabilities, LoggingManager, SetLevelRequest},
protocol::types::*,
protocol::{
BasicRequestHandler, JsonRpcNotification, Protocol, ProtocolBuilder, ProtocolOptions,
RequestHandler,
},
resource::{ListResourcesRequest, ReadResourceRequest, ResourceCapabilities, ResourceManager},
tools::{CallToolRequest, ListToolsRequest},
transport::{sse::SseTransport, stdio::StdioTransport, Transport},
};
use tokio::sync::mpsc;
pub mod config;
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeParams {
pub protocol_version: String,
pub capabilities: ClientCapabilities,
pub client_info: ClientInfo,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeResult {
pub protocol_version: String,
pub capabilities: ServerCapabilities,
pub server_info: InitializeServerInfo,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InitializeServerInfo {
pub name: String,
pub version: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClientCapabilities {
pub roots: Option<RootsCapabilities>,
pub sampling: Option<SamplingCapabilities>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RootsCapabilities {
pub list_changed: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SamplingCapabilities {}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClientInfo {
pub name: String,
pub version: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerInfo {
pub name: String,
pub version: String,
}
#[derive(Debug, Clone, Copy, PartialEq)]
enum ServerState {
Created,
Initializing,
Running,
ShuttingDown,
}
pub struct McpServer<H>
where
H: RequestHandler + Send + Sync + 'static,
{
pub handler: Arc<H>,
pub config: ServerConfig,
pub resource_manager: Arc<ResourceManager>,
pub tool_manager: Arc<ToolManager>,
pub prompt_manager: Arc<PromptManager>,
pub logging_manager: Arc<tokio::sync::Mutex<LoggingManager>>,
notification_tx: mpsc::Sender<JsonRpcNotification>,
notification_rx: Option<mpsc::Receiver<JsonRpcNotification>>, state: Arc<(watch::Sender<ServerState>, watch::Receiver<ServerState>)>,
supported_versions: Vec<String>,
client_capabilities: Arc<RwLock<Option<ClientCapabilities>>>,
}
impl<H> McpServer<H>
where
H: RequestHandler + Send + Sync + 'static,
{
pub fn new(config: ServerConfig, handler: H) -> Self {
let (notification_tx, notification_rx) = mpsc::channel(100);
let (state_tx, state_rx) = watch::channel(ServerState::Created);
Self {
handler: Arc::new(handler),
config: config.clone(),
resource_manager: Arc::new(ResourceManager::new(ResourceCapabilities {
subscribe: false,
list_changed: false,
})),
tool_manager: Arc::new(ToolManager::new(ToolCapabilities {
list_changed: false,
})),
prompt_manager: Arc::new(PromptManager::new(PromptCapabilities {
list_changed: false,
})),
logging_manager: Arc::new(tokio::sync::Mutex::new(LoggingManager::new())),
notification_tx,
notification_rx: Some(notification_rx),
state: Arc::new((state_tx, state_rx)),
supported_versions: vec!["1.0".to_string()],
client_capabilities: Arc::new(RwLock::new(None)),
}
}
pub async fn process_request(
&self,
method: &str,
params: Option<Value>,
) -> Result<Value, McpError> {
self.handler.handle_request(method, params).await
}
pub async fn run_stdio_transport(&mut self) -> Result<(), McpError> {
let transport = StdioTransport::new(Some(1024));
self.run_transport(transport).await
}
pub async fn run_sse_transport(&mut self) -> Result<(), McpError> {
let transport = SseTransport::new_server(
self.config.server.host.clone(),
self.config.server.port,
1024, );
self.run_transport(transport).await
}
pub async fn run_unix_transport(&mut self) -> Result<(), McpError> {
tracing::info!("Starting Unix transport");
Ok(())
}
async fn run_transport<T: Transport>(&mut self, transport: T) -> Result<(), McpError> {
let notification_rx = self.notification_rx.take().ok_or_else(|| {
McpError::InternalError("Notification receiver already taken".to_string())
})?;
let (shutdown_tx, mut shutdown_rx) = tokio::sync::mpsc::channel::<()>(1);
let state = Arc::clone(&self.state);
tokio::spawn(async move {
loop {
if *state.1.borrow() == ServerState::ShuttingDown {
let _ = shutdown_tx.send(()).await;
break;
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
});
let resource_manager = Arc::clone(&self.resource_manager);
let resource_manager2 = Arc::clone(&self.resource_manager);
let resource_manager3 = Arc::clone(&self.resource_manager);
let tool_manager = Arc::clone(&self.tool_manager);
let tool_manager2 = Arc::clone(&self.tool_manager);
let prompt_manager = Arc::clone(&self.prompt_manager);
let prompt_manager2 = Arc::clone(&self.prompt_manager);
let mut protocol = Protocol::builder(Some(ProtocolOptions {
enforce_strict_capabilities: false,
}))
.with_request_handler(
"initialize",
Box::new(|req, _extra| {
Box::pin(async move {
let params: InitializeParams =
serde_json::from_value(req.params.unwrap_or_default())
.map_err(|_| McpError::InvalidParams)?;
let result = InitializeResult {
protocol_version: "2024-11-05".to_string(),
capabilities: ServerCapabilities {
logging: Some(LoggingCapabilities {}),
prompts: Some(PromptCapabilities {
list_changed: false,
}),
resources: Some(ResourceCapabilities {
subscribe: false,
list_changed: false,
}),
tools: Some(ToolCapabilities {
list_changed: false,
}),
},
server_info: InitializeServerInfo {
name: "test-server".to_string(),
version: "1.0.0".to_string(),
},
};
Ok(serde_json::to_value(result).unwrap())
})
}),
)
.with_request_handler(
"resources/list",
Box::new(move |req, _extra| {
let resource_manager = Arc::clone(&resource_manager);
Box::pin(async move {
let params: ListResourcesRequest = req
.params
.map(|params| serde_json::from_value(params))
.transpose()
.map_err(|_| McpError::InvalidParams)?
.unwrap_or_default();
let resources_list = resource_manager.list_resources(params.cursor).await?;
Ok(serde_json::to_value(resources_list).unwrap())
})
}),
)
.with_request_handler(
"resources/read",
Box::new(move |req, _extra| {
let resource_manager = Arc::clone(&resource_manager2);
Box::pin(async move {
let params: ReadResourceRequest =
serde_json::from_value(req.params.unwrap_or_default())
.map_err(|_| McpError::InvalidParams)?;
let resource = resource_manager.read_resource(¶ms.uri).await?;
Ok(serde_json::to_value(resource).unwrap())
})
}),
)
.with_request_handler(
"resources/templates/list",
Box::new(move |_req, _extra| {
let resource_manager = Arc::clone(&resource_manager3);
Box::pin(async move {
let templates_list = resource_manager.list_templates().await?;
Ok(serde_json::to_value(templates_list).unwrap())
})
}),
)
.with_request_handler(
"tools/list",
Box::new(move |req, _extra| {
let tool_manager = Arc::clone(&tool_manager);
Box::pin(async move {
let params: ListToolsRequest = req
.params
.map(|params| serde_json::from_value(params))
.transpose()
.map_err(|_| McpError::InvalidParams)?
.unwrap_or_default();
let tools_list = tool_manager.list_tools(params.cursor).await?;
Ok(serde_json::to_value(tools_list).unwrap())
})
}),
)
.with_request_handler(
"tools/call",
Box::new(move |req, _extra| {
let tool_manager = Arc::clone(&tool_manager2);
Box::pin(async move {
let params: CallToolRequest =
serde_json::from_value(req.params.unwrap_or_default())
.map_err(|_| McpError::InvalidParams)?;
let result = tool_manager
.call_tool(¶ms.name, params.arguments)
.await?;
Ok(serde_json::to_value(result).unwrap())
})
}),
)
.with_request_handler(
"prompts/list",
Box::new(move |req, _extra| {
let prompt_manager = Arc::clone(&prompt_manager);
Box::pin(async move {
let params: ListPromptsRequest = req
.params
.map(|params| serde_json::from_value(params))
.transpose()
.map_err(|_| McpError::InvalidParams)?
.unwrap_or_default();
let prompts_list = prompt_manager.list_prompts(params.cursor).await?;
Ok(serde_json::to_value(prompts_list).unwrap())
})
}),
)
.with_request_handler(
"prompts/get",
Box::new(move |req, _extra| {
let prompt_manager = Arc::clone(&prompt_manager2);
Box::pin(async move {
let params: GetPromptRequest =
serde_json::from_value(req.params.unwrap_or_default())
.map_err(|_| McpError::InvalidParams)?;
let prompt = prompt_manager
.get_prompt(¶ms.name, params.arguments)
.await?;
Ok(serde_json::to_value(prompt).unwrap())
})
}),
)
.build();
let protocol_handle = protocol.connect(transport).await?;
info!("Server running and ready to handle requests");
shutdown_rx.recv().await;
protocol_handle.close().await?;
info!("Server shutting down");
Ok(())
}
}
#[cfg(test)]
mod tests {
use crate::transport::{TransportChannels, TransportCommand, TransportEvent};
use super::*;
use async_trait::async_trait;
use serde_json::json;
use std::time::Duration;
use tokio::time::sleep;
struct MockTransport {
_channels: Option<TransportChannels>,
}
impl MockTransport {
fn new() -> Self {
Self { _channels: None }
}
}
#[async_trait]
impl Transport for MockTransport {
async fn start(&mut self) -> Result<TransportChannels, McpError> {
let (command_tx, mut command_rx) = mpsc::channel(32);
let (event_tx, event_rx) = mpsc::channel(32);
let channels = TransportChannels {
cmd_tx: command_tx,
event_rx: Arc::new(tokio::sync::Mutex::new(event_rx)),
};
tokio::spawn(async move {
while let Some(cmd) = command_rx.recv().await {
match cmd {
TransportCommand::SendMessage(JsonRpcMessage::Request(req)) => {
if req.method == "initialize" {
let response = JsonRpcMessage::Response(JsonRpcResponse {
jsonrpc: "2.0".to_string(),
id: req.id,
result: Some(json!({
"protocol_version": "1.0",
"capabilities": {
"roots": { "listChanged": true }
},
"client_info": {
"name": "test-client",
"version": "1.0.0"
}
})),
error: None,
});
event_tx
.send(TransportEvent::Message(response))
.await
.unwrap();
}
}
TransportCommand::Close => break,
_ => {}
}
}
event_tx.send(TransportEvent::Closed).await.unwrap();
});
self._channels = Some(channels.clone());
Ok(channels)
}
}
struct MockHandler;
#[async_trait]
impl RequestHandler for MockHandler {
async fn handle_request(
&self,
method: &str,
params: Option<Value>,
) -> Result<Value, McpError> {
match method {
"test.echo" => Ok(params.unwrap_or(Value::Null)),
_ => Ok(Value::Null),
}
}
async fn handle_notification(
&self,
_method: &str,
_params: Option<Value>,
) -> Result<(), McpError> {
Ok(())
}
fn get_capabilities(&self) -> crate::protocol::ServerCapabilities {
crate::protocol::ServerCapabilities {
name: "test-server".to_string(),
version: "1.0.0".to_string(),
protocol_version: "1.0".to_string(),
capabilities: vec!["test.echo".to_string()],
}
}
}
#[tokio::test]
async fn test_run_transport() {
let mut config = ServerConfig::default();
config.server.host = "localhost".to_string();
config.server.port = 8080;
let mut server = McpServer::new(config, MockHandler);
let notification_tx = server.notification_tx.clone();
let state = Arc::clone(&server.state);
let server_handle = tokio::spawn(async move {
let transport = MockTransport::new();
server.run_transport(transport).await
});
sleep(Duration::from_millis(100)).await;
let test_notification = JsonRpcNotification {
jsonrpc: "2.0".to_string(),
method: "test.notification".to_string(),
params: Some(json!({"message": "test"})),
};
notification_tx.send(test_notification).await.unwrap();
sleep(Duration::from_millis(100)).await;
let (state_tx, _): &(watch::Sender<ServerState>, watch::Receiver<ServerState>) = &*state;
state_tx.send(ServerState::ShuttingDown).unwrap();
match tokio::time::timeout(Duration::from_secs(1), server_handle).await {
Ok(result) => {
assert!(result.unwrap().is_ok(), "Server should shut down cleanly");
}
Err(_) => panic!("Server did not shut down within timeout period"),
}
}
#[tokio::test]
async fn test_protocol_messages() {
let mut config = ServerConfig::default();
config.server.host = "localhost".to_string();
config.server.port = 8080;
let mut server = McpServer::new(config, MockHandler);
let notification_tx = server.notification_tx.clone();
let state = Arc::clone(&server.state);
let server_handle = tokio::spawn(async move {
let transport = MockTransport::new();
server.run_transport(transport).await
});
sleep(Duration::from_millis(100)).await;
let notifications = vec![
JsonRpcNotification {
jsonrpc: "2.0".to_string(),
method: "resource.changed".to_string(),
params: Some(json!({
"path": "/test/resource",
"type": "modified"
})),
},
JsonRpcNotification {
jsonrpc: "2.0".to_string(),
method: "tool.executed".to_string(),
params: Some(json!({
"tool": "test-tool",
"status": "success"
})),
},
];
for notification in notifications {
notification_tx.send(notification).await.unwrap();
sleep(Duration::from_millis(50)).await;
}
let (state_tx, _): &(watch::Sender<ServerState>, watch::Receiver<ServerState>) = &*state;
state_tx.send(ServerState::ShuttingDown).unwrap();
match tokio::time::timeout(Duration::from_secs(1), server_handle).await {
Ok(result) => {
assert!(result.unwrap().is_ok(), "Server should shut down cleanly");
}
Err(_) => panic!("Server did not shut down within timeout period"),
}
}
}