Skip to main content

mcp_utils/
server.rs

1use std::{sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use rust_mcp_sdk::{
5    McpServer, StdioTransport, ToMcpServerHandler, TransportOptions,
6    error::McpSdkError,
7    mcp_server::{
8        HyperServerOptions, McpServerOptions, ServerHandler, hyper_server,
9        server_runtime::create_server,
10    },
11    schema::{
12        CallToolRequestParams, CallToolResult, Implementation, InitializeResult,
13        LATEST_PROTOCOL_VERSION, ListToolsResult, PaginatedRequestParams, RpcError,
14        ServerCapabilities, ServerCapabilitiesTools, schema_utils::CallToolError,
15    },
16};
17
18use crate::{server_config::ServerConfig, tool_box::ToolBox};
19
20#[derive(Debug, Clone, Default)]
21pub struct ServerBuilder {
22    config: ServerConfig,
23}
24
25impl ServerBuilder {
26    pub fn new() -> Self {
27        Self::default()
28    }
29
30    pub fn with_name(mut self, name: impl Into<String>) -> Self {
31        self.config.name = name.into();
32        self
33    }
34
35    pub fn with_instructions(mut self, instructions: impl Into<String>) -> Self {
36        self.config.instructions = instructions.into();
37        self
38    }
39
40    pub fn with_version(mut self, version: impl Into<String>) -> Self {
41        self.config.version = version.into();
42        self
43    }
44
45    pub fn with_title(mut self, title: impl Into<String>) -> Self {
46        self.config.title = title.into();
47        self
48    }
49
50    pub fn with_timeout(mut self, timeout: Duration) -> Self {
51        self.config.timeout = timeout;
52        self
53    }
54
55    pub fn set_name(&mut self, name: impl Into<String>) {
56        self.config.name = name.into();
57    }
58
59    pub fn set_instructions(&mut self, instructions: impl Into<String>) {
60        self.config.instructions = instructions.into();
61    }
62
63    pub fn set_version(&mut self, version: impl Into<String>) {
64        self.config.version = version.into();
65    }
66
67    pub fn set_title(&mut self, title: impl Into<String>) {
68        self.config.title = title.into();
69    }
70
71    pub fn set_timeout(&mut self, timeout: Duration) {
72        self.config.timeout = timeout;
73    }
74
75    pub fn name(&self) -> &str {
76        &self.config.name
77    }
78
79    pub fn title(&self) -> &str {
80        &self.config.title
81    }
82
83    pub fn version(&self) -> &str {
84        &self.config.version
85    }
86
87    pub fn instructions(&self) -> &str {
88        &self.config.instructions
89    }
90
91    pub async fn start_stdio<T>(self) -> Result<(), McpSdkError>
92    where
93        T: ToolBox + TryFrom<CallToolRequestParams, Error = CallToolError> + Send + Sync + 'static,
94    {
95        let transport_options = TransportOptions {
96            timeout: self.config.timeout,
97        };
98
99        create_server(McpServerOptions {
100            server_details: self.get_server_details::<T>(),
101            transport: StdioTransport::new(transport_options)?,
102            handler: Handler::<T>::new().to_mcp_server_handler(),
103            task_store: None,
104            client_task_store: None,
105            message_observer: None,
106        })
107        .start()
108        .await
109    }
110
111    pub async fn start_server<T>(
112        self,
113        host: impl Into<String>,
114        port: u16,
115    ) -> Result<(), McpSdkError>
116    where
117        T: ToolBox + TryFrom<CallToolRequestParams, Error = CallToolError> + Send + Sync + 'static,
118    {
119        let transport_options = TransportOptions {
120            timeout: self.config.timeout,
121        };
122
123        hyper_server::create_server(
124            self.get_server_details::<T>(),
125            Handler::<T>::new().to_mcp_server_handler(),
126            HyperServerOptions {
127                host: Some(host.into())
128                    .filter(|host| !host.is_empty())
129                    .unwrap_or_else(|| "127.0.0.1".to_string()),
130                port,
131                transport_options: Arc::new(transport_options),
132                ..Default::default()
133            },
134        )
135        .start()
136        .await
137    }
138
139    fn get_server_details<T>(self) -> InitializeResult
140    where
141        T: ToolBox,
142    {
143        InitializeResult {
144            server_info: Implementation {
145                name: self.config.name,
146                version: self.config.version,
147                title: Some(self.config.title).filter(|title| !title.is_empty()),
148                description: Some(self.config.description)
149                    .filter(|description| !description.is_empty()),
150                website_url: None,
151                icons: Default::default(),
152            },
153            capabilities: ServerCapabilities {
154                tools: if T::get_tools().is_empty() {
155                    None
156                } else {
157                    Some(ServerCapabilitiesTools { list_changed: None })
158                },
159                ..Default::default()
160            },
161            meta: None,
162            instructions: Some(self.config.instructions),
163            protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
164        }
165    }
166}
167
/// Stateless server handler for the tool box type `T`.
///
/// No value of `T` is stored; `T` is carried only at the type level so the
/// handler can call `T`'s associated functions.
struct Handler<T> {
    _phantom: std::marker::PhantomData<T>,
}

impl<T> Handler<T> {
    /// Creates a handler for tool box `T`; nothing is allocated.
    pub fn new() -> Self {
        let _phantom = Default::default();
        Self { _phantom }
    }
}
179
180#[async_trait]
181#[allow(unused)]
182impl<T> ServerHandler for Handler<T>
183where
184    T: ToolBox + TryFrom<CallToolRequestParams, Error = CallToolError> + Send + Sync + 'static,
185{
186    async fn handle_list_tools_request(
187        &self,
188        params: Option<PaginatedRequestParams>,
189        runtime: Arc<dyn McpServer>,
190    ) -> Result<ListToolsResult, RpcError> {
191        Ok(ListToolsResult {
192            meta: None,
193            next_cursor: None,
194            tools: T::get_tools(),
195        })
196    }
197
198    async fn handle_call_tool_request(
199        &self,
200        params: CallToolRequestParams,
201        runtime: Arc<dyn McpServer>,
202    ) -> Result<CallToolResult, CallToolError> {
203        let custom_tool = T::try_from(params).map_err(CallToolError::new)?;
204
205        custom_tool.get_tool().call().await
206    }
207}