1use std::{sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use rust_mcp_sdk::{
5 McpServer, StdioTransport, TransportOptions,
6 error::McpSdkError,
7 mcp_server::{HyperServerOptions, ServerHandler, hyper_server, server_runtime::create_server},
8 schema::{
9 CallToolRequest, CallToolRequestParams, CallToolResult, Implementation, InitializeResult,
10 LATEST_PROTOCOL_VERSION, ListToolsRequest, ListToolsResult, RpcError, ServerCapabilities,
11 ServerCapabilitiesTools, schema_utils::CallToolError,
12 },
13};
14
15use crate::{server_config::ServerConfig, tool_box::ToolBox};
16
17#[derive(Debug, Clone, Default)]
18pub struct ServerBuilder {
19 config: ServerConfig,
20}
21
22impl ServerBuilder {
23 pub fn new() -> Self {
24 Self::default()
25 }
26
27 pub fn with_name(mut self, name: impl Into<String>) -> Self {
28 self.config.name = name.into();
29 self
30 }
31
32 pub fn with_instructions(mut self, instructions: impl Into<String>) -> Self {
33 self.config.instructions = instructions.into();
34 self
35 }
36
37 pub fn with_version(mut self, version: impl Into<String>) -> Self {
38 self.config.version = version.into();
39 self
40 }
41
42 pub fn with_title(mut self, title: impl Into<String>) -> Self {
43 self.config.title = title.into();
44 self
45 }
46
47 pub fn with_timeout(mut self, timeout: Duration) -> Self {
48 self.config.timeout = timeout;
49 self
50 }
51
52 pub fn set_name(&mut self, name: impl Into<String>) {
53 self.config.name = name.into();
54 }
55
56 pub fn set_instructions(&mut self, instructions: impl Into<String>) {
57 self.config.instructions = instructions.into();
58 }
59
60 pub fn set_version(&mut self, version: impl Into<String>) {
61 self.config.version = version.into();
62 }
63
64 pub fn set_title(&mut self, title: impl Into<String>) {
65 self.config.title = title.into();
66 }
67
68 pub fn set_timeout(&mut self, timeout: Duration) {
69 self.config.timeout = timeout;
70 }
71
72 pub fn name(&self) -> &str {
73 &self.config.name
74 }
75
76 pub fn title(&self) -> &str {
77 &self.config.title
78 }
79
80 pub fn version(&self) -> &str {
81 &self.config.version
82 }
83
84 pub fn instructions(&self) -> &str {
85 &self.config.instructions
86 }
87
88 pub async fn start_stdio<T>(self) -> Result<(), McpSdkError>
89 where
90 T: ToolBox + TryFrom<CallToolRequestParams, Error = CallToolError> + Send + Sync + 'static,
91 {
92 let transport_options = TransportOptions {
93 timeout: self.config.timeout,
94 };
95
96 create_server(
97 self.get_server_details::<T>(),
98 StdioTransport::new(transport_options)?,
99 Handler::<T>::new(),
100 )
101 .start()
102 .await
103 }
104
105 pub async fn start_server<T>(
106 self,
107 host: impl Into<String>,
108 port: u16,
109 ) -> Result<(), McpSdkError>
110 where
111 T: ToolBox + TryFrom<CallToolRequestParams, Error = CallToolError> + Send + Sync + 'static,
112 {
113 let transport_options = TransportOptions {
114 timeout: self.config.timeout,
115 };
116
117 hyper_server::create_server(
118 self.get_server_details::<T>(),
119 Handler::<T>::new(),
120 HyperServerOptions {
121 host: Some(host.into())
122 .filter(|host| !host.is_empty())
123 .unwrap_or_else(|| "127.0.0.1".to_string()),
124 port,
125 transport_options: Arc::new(transport_options),
126 ..Default::default()
127 },
128 )
129 .start()
130 .await
131 }
132
133 fn get_server_details<T>(self) -> InitializeResult
134 where
135 T: ToolBox,
136 {
137 InitializeResult {
138 server_info: Implementation {
139 name: self.config.name,
140 version: self.config.version,
141 title: Some(self.config.title).filter(|title| !title.is_empty()),
142 },
143 capabilities: ServerCapabilities {
144 tools: if T::get_tools().is_empty() {
145 None
146 } else {
147 Some(ServerCapabilitiesTools { list_changed: None })
148 },
149 ..Default::default()
150 },
151 meta: None,
152 instructions: Some(self.config.instructions),
153 protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
154 }
155 }
156}
157
/// Zero-sized MCP request handler parameterized by the tool box type `T`.
///
/// No `T` value is stored; the type parameter only selects which
/// `ToolBox` implementation the handler dispatches to.
struct Handler<T> {
    /// Type-level marker tying the handler to `T`.
    _phantom: core::marker::PhantomData<T>,
}
161
162impl<T> Handler<T> {
163 pub fn new() -> Self {
164 Self {
165 _phantom: std::marker::PhantomData,
166 }
167 }
168}
169
170#[async_trait]
171#[allow(unused)]
172impl<T> ServerHandler for Handler<T>
173where
174 T: ToolBox + TryFrom<CallToolRequestParams, Error = CallToolError> + Send + Sync + 'static,
175{
176 async fn handle_list_tools_request(
177 &self,
178 request: ListToolsRequest,
179 runtime: &dyn McpServer,
180 ) -> Result<ListToolsResult, RpcError> {
181 runtime.assert_server_request_capabilities(request.method())?;
182
183 Ok(ListToolsResult {
184 meta: None,
185 next_cursor: None,
186 tools: T::get_tools(),
187 })
188 }
189
190 async fn handle_call_tool_request(
191 &self,
192 request: CallToolRequest,
193 runtime: &dyn McpServer,
194 ) -> Result<CallToolResult, CallToolError> {
195 runtime
196 .assert_server_request_capabilities(request.method())
197 .map_err(CallToolError::new)?;
198
199 let custom_tool = T::try_from(request.params).map_err(CallToolError::new)?;
200
201 custom_tool.get_tool().call().await
202 }
203}