mcp_utils/
server.rs

1use std::{sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use rust_mcp_sdk::{
5    McpServer, StdioTransport, TransportOptions,
6    error::McpSdkError,
7    mcp_server::{HyperServerOptions, ServerHandler, hyper_server, server_runtime::create_server},
8    schema::{
9        CallToolRequest, CallToolRequestParams, CallToolResult, Implementation, InitializeResult,
10        LATEST_PROTOCOL_VERSION, ListToolsRequest, ListToolsResult, RpcError, ServerCapabilities,
11        ServerCapabilitiesTools, schema_utils::CallToolError,
12    },
13};
14
15use crate::{server_config::ServerConfig, tool_box::ToolBox};
16
17#[derive(Debug, Clone, Default)]
18pub struct ServerBuilder {
19    config: ServerConfig,
20}
21
22impl ServerBuilder {
23    pub fn new() -> Self {
24        Self::default()
25    }
26
27    pub fn with_name(mut self, name: impl Into<String>) -> Self {
28        self.config.name = name.into();
29        self
30    }
31
32    pub fn with_instructions(mut self, instructions: impl Into<String>) -> Self {
33        self.config.instructions = instructions.into();
34        self
35    }
36
37    pub fn with_version(mut self, version: impl Into<String>) -> Self {
38        self.config.version = version.into();
39        self
40    }
41
42    pub fn with_title(mut self, title: impl Into<String>) -> Self {
43        self.config.title = title.into();
44        self
45    }
46
47    pub fn with_timeout(mut self, timeout: Duration) -> Self {
48        self.config.timeout = timeout;
49        self
50    }
51
52    pub fn set_name(&mut self, name: impl Into<String>) {
53        self.config.name = name.into();
54    }
55
56    pub fn set_instructions(&mut self, instructions: impl Into<String>) {
57        self.config.instructions = instructions.into();
58    }
59
60    pub fn set_version(&mut self, version: impl Into<String>) {
61        self.config.version = version.into();
62    }
63
64    pub fn set_title(&mut self, title: impl Into<String>) {
65        self.config.title = title.into();
66    }
67
68    pub fn set_timeout(&mut self, timeout: Duration) {
69        self.config.timeout = timeout;
70    }
71
72    pub fn name(&self) -> &str {
73        &self.config.name
74    }
75
76    pub fn title(&self) -> &str {
77        &self.config.title
78    }
79
80    pub fn version(&self) -> &str {
81        &self.config.version
82    }
83
84    pub fn instructions(&self) -> &str {
85        &self.config.instructions
86    }
87
88    pub async fn start_stdio<T>(self) -> Result<(), McpSdkError>
89    where
90        T: ToolBox + TryFrom<CallToolRequestParams, Error = CallToolError> + Send + Sync + 'static,
91    {
92        let transport_options = TransportOptions {
93            timeout: self.config.timeout,
94        };
95
96        create_server(
97            self.get_server_details::<T>(),
98            StdioTransport::new(transport_options)?,
99            Handler::<T>::new(),
100        )
101        .start()
102        .await
103    }
104
105    pub async fn start_server<T>(
106        self,
107        host: impl Into<String>,
108        port: u16,
109    ) -> Result<(), McpSdkError>
110    where
111        T: ToolBox + TryFrom<CallToolRequestParams, Error = CallToolError> + Send + Sync + 'static,
112    {
113        let transport_options = TransportOptions {
114            timeout: self.config.timeout,
115        };
116
117        hyper_server::create_server(
118            self.get_server_details::<T>(),
119            Handler::<T>::new(),
120            HyperServerOptions {
121                host: Some(host.into())
122                    .filter(|host| !host.is_empty())
123                    .unwrap_or_else(|| "127.0.0.1".to_string()),
124                port,
125                transport_options: Arc::new(transport_options),
126                ..Default::default()
127            },
128        )
129        .start()
130        .await
131    }
132
133    fn get_server_details<T>(self) -> InitializeResult
134    where
135        T: ToolBox,
136    {
137        InitializeResult {
138            server_info: Implementation {
139                name: self.config.name,
140                version: self.config.version,
141                title: Some(self.config.title).filter(|title| !title.is_empty()),
142            },
143            capabilities: ServerCapabilities {
144                tools: if T::get_tools().is_empty() {
145                    None
146                } else {
147                    Some(ServerCapabilitiesTools { list_changed: None })
148                },
149                ..Default::default()
150            },
151            meta: None,
152            instructions: Some(self.config.instructions),
153            protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
154        }
155    }
156}
157
/// Stateless MCP request handler for the tool box type `T`.
///
/// No `T` value is stored: the marker field is zero-sized, and `T` only selects which tool box
/// the handler's trait implementation dispatches to.
struct Handler<T> {
    /// Zero-sized marker binding the handler to its tool-box type.
    _phantom: core::marker::PhantomData<T>,
}
161
162impl<T> Handler<T> {
163    pub fn new() -> Self {
164        Self {
165            _phantom: std::marker::PhantomData,
166        }
167    }
168}
169
170#[async_trait]
171#[allow(unused)]
172impl<T> ServerHandler for Handler<T>
173where
174    T: ToolBox + TryFrom<CallToolRequestParams, Error = CallToolError> + Send + Sync + 'static,
175{
176    async fn handle_list_tools_request(
177        &self,
178        request: ListToolsRequest,
179        runtime: &dyn McpServer,
180    ) -> Result<ListToolsResult, RpcError> {
181        runtime.assert_server_request_capabilities(request.method())?;
182
183        Ok(ListToolsResult {
184            meta: None,
185            next_cursor: None,
186            tools: T::get_tools(),
187        })
188    }
189
190    async fn handle_call_tool_request(
191        &self,
192        request: CallToolRequest,
193        runtime: &dyn McpServer,
194    ) -> Result<CallToolResult, CallToolError> {
195        runtime
196            .assert_server_request_capabilities(request.method())
197            .map_err(CallToolError::new)?;
198
199        let custom_tool = T::try_from(request.params).map_err(CallToolError::new)?;
200
201        custom_tool.get_tool().call().await
202    }
203}