1use std::{sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use rust_mcp_sdk::{
5 McpServer, StdioTransport, ToMcpServerHandler, TransportOptions,
6 error::McpSdkError,
7 mcp_server::{
8 HyperServerOptions, McpServerOptions, ServerHandler, hyper_server,
9 server_runtime::create_server,
10 },
11 schema::{
12 CallToolRequestParams, CallToolResult, Implementation, InitializeResult,
13 LATEST_PROTOCOL_VERSION, ListToolsResult, PaginatedRequestParams, RpcError,
14 ServerCapabilities, ServerCapabilitiesTools, schema_utils::CallToolError,
15 },
16};
17
18use crate::{server_config::ServerConfig, tool_box::ToolBox};
19
20#[derive(Debug, Clone, Default)]
21pub struct ServerBuilder {
22 config: ServerConfig,
23}
24
25impl ServerBuilder {
26 pub fn new() -> Self {
27 Self::default()
28 }
29
30 pub fn with_name(mut self, name: impl Into<String>) -> Self {
31 self.config.name = name.into();
32 self
33 }
34
35 pub fn with_instructions(mut self, instructions: impl Into<String>) -> Self {
36 self.config.instructions = instructions.into();
37 self
38 }
39
40 pub fn with_version(mut self, version: impl Into<String>) -> Self {
41 self.config.version = version.into();
42 self
43 }
44
45 pub fn with_title(mut self, title: impl Into<String>) -> Self {
46 self.config.title = title.into();
47 self
48 }
49
50 pub fn with_timeout(mut self, timeout: Duration) -> Self {
51 self.config.timeout = timeout;
52 self
53 }
54
55 pub fn set_name(&mut self, name: impl Into<String>) {
56 self.config.name = name.into();
57 }
58
59 pub fn set_instructions(&mut self, instructions: impl Into<String>) {
60 self.config.instructions = instructions.into();
61 }
62
63 pub fn set_version(&mut self, version: impl Into<String>) {
64 self.config.version = version.into();
65 }
66
67 pub fn set_title(&mut self, title: impl Into<String>) {
68 self.config.title = title.into();
69 }
70
71 pub fn set_timeout(&mut self, timeout: Duration) {
72 self.config.timeout = timeout;
73 }
74
75 pub fn name(&self) -> &str {
76 &self.config.name
77 }
78
79 pub fn title(&self) -> &str {
80 &self.config.title
81 }
82
83 pub fn version(&self) -> &str {
84 &self.config.version
85 }
86
87 pub fn instructions(&self) -> &str {
88 &self.config.instructions
89 }
90
91 pub async fn start_stdio<T>(self) -> Result<(), McpSdkError>
92 where
93 T: ToolBox + TryFrom<CallToolRequestParams, Error = CallToolError> + Send + Sync + 'static,
94 {
95 let transport_options = TransportOptions {
96 timeout: self.config.timeout,
97 };
98
99 create_server(McpServerOptions {
100 server_details: self.get_server_details::<T>(),
101 transport: StdioTransport::new(transport_options)?,
102 handler: Handler::<T>::new().to_mcp_server_handler(),
103 task_store: None,
104 client_task_store: None,
105 message_observer: None,
106 })
107 .start()
108 .await
109 }
110
111 pub async fn start_server<T>(
112 self,
113 host: impl Into<String>,
114 port: u16,
115 ) -> Result<(), McpSdkError>
116 where
117 T: ToolBox + TryFrom<CallToolRequestParams, Error = CallToolError> + Send + Sync + 'static,
118 {
119 let transport_options = TransportOptions {
120 timeout: self.config.timeout,
121 };
122
123 hyper_server::create_server(
124 self.get_server_details::<T>(),
125 Handler::<T>::new().to_mcp_server_handler(),
126 HyperServerOptions {
127 host: Some(host.into())
128 .filter(|host| !host.is_empty())
129 .unwrap_or_else(|| "127.0.0.1".to_string()),
130 port,
131 transport_options: Arc::new(transport_options),
132 ..Default::default()
133 },
134 )
135 .start()
136 .await
137 }
138
139 fn get_server_details<T>(self) -> InitializeResult
140 where
141 T: ToolBox,
142 {
143 InitializeResult {
144 server_info: Implementation {
145 name: self.config.name,
146 version: self.config.version,
147 title: Some(self.config.title).filter(|title| !title.is_empty()),
148 description: Some(self.config.description)
149 .filter(|description| !description.is_empty()),
150 website_url: None,
151 icons: Default::default(),
152 },
153 capabilities: ServerCapabilities {
154 tools: if T::get_tools().is_empty() {
155 None
156 } else {
157 Some(ServerCapabilitiesTools { list_changed: None })
158 },
159 ..Default::default()
160 },
161 meta: None,
162 instructions: Some(self.config.instructions),
163 protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
164 }
165 }
166}
167
/// Data-less `ServerHandler` adapter parameterised by the tool box type `T`.
///
/// It stores no value of `T`; the type parameter only selects which tool box
/// the handler's trait methods dispatch to.
struct Handler<T> {
    /// Type marker tying the handler to `T` without holding a `T`.
    _marker: std::marker::PhantomData<T>,
}

impl<T> Handler<T> {
    /// Creates a handler for tool box `T`.
    pub fn new() -> Self {
        let _marker = std::marker::PhantomData;
        Self { _marker }
    }
}
179
180#[async_trait]
181#[allow(unused)]
182impl<T> ServerHandler for Handler<T>
183where
184 T: ToolBox + TryFrom<CallToolRequestParams, Error = CallToolError> + Send + Sync + 'static,
185{
186 async fn handle_list_tools_request(
187 &self,
188 params: Option<PaginatedRequestParams>,
189 runtime: Arc<dyn McpServer>,
190 ) -> Result<ListToolsResult, RpcError> {
191 Ok(ListToolsResult {
192 meta: None,
193 next_cursor: None,
194 tools: T::get_tools(),
195 })
196 }
197
198 async fn handle_call_tool_request(
199 &self,
200 params: CallToolRequestParams,
201 runtime: Arc<dyn McpServer>,
202 ) -> Result<CallToolResult, CallToolError> {
203 let custom_tool = T::try_from(params).map_err(CallToolError::new)?;
204
205 custom_tool.get_tool().call().await
206 }
207}