mcp_core/
server.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, RwLock},
4};
5
6use crate::{
7    protocol::Protocol,
8    tools::{ToolHandler, Tools},
9    types::{CallToolRequest, CallToolResponse, ListRequest, Tool, ToolsListResponse},
10};
11
12use super::{
13    protocol::ProtocolBuilder,
14    transport::Transport,
15    types::{
16        ClientCapabilities, Implementation, InitializeRequest, InitializeResponse,
17        ServerCapabilities, LATEST_PROTOCOL_VERSION,
18    },
19};
20use anyhow::Result;
21use std::future::Future;
22use std::pin::Pin;
23
24#[derive(Clone)]
25pub struct ClientConnection {
26    client_capabilities: Option<ClientCapabilities>,
27    client_info: Option<Implementation>,
28    initialized: bool,
29}
30
/// Stateless entry point for configuring and starting a server.
///
/// All functionality lives in associated functions (`Server::builder`, `Server::start`),
/// so the type itself carries no data.
#[derive(Clone)]
pub struct Server;

34impl Server {
35    pub fn builder(name: String, version: String) -> ServerProtocolBuilder {
36        ServerProtocolBuilder::new(name, version)
37    }
38
39    pub async fn start<T: Transport>(transport: T) -> Result<()> {
40        transport.open().await
41    }
42}
43
44pub struct ServerProtocolBuilder {
45    protocol_builder: ProtocolBuilder,
46    server_info: Implementation,
47    capabilities: ServerCapabilities,
48    tools: HashMap<String, ToolHandler>,
49    client_connection: Arc<RwLock<ClientConnection>>,
50}
51
52impl ServerProtocolBuilder {
53    pub fn new(name: String, version: String) -> Self {
54        ServerProtocolBuilder {
55            protocol_builder: ProtocolBuilder::new(),
56            server_info: Implementation { name, version },
57            capabilities: ServerCapabilities::default(),
58            tools: HashMap::new(),
59            client_connection: Arc::new(RwLock::new(ClientConnection {
60                client_capabilities: None,
61                client_info: None,
62                initialized: false,
63            })),
64        }
65    }
66
67    pub fn capabilities(mut self, capabilities: ServerCapabilities) -> Self {
68        self.capabilities = capabilities;
69        self
70    }
71
72    pub fn register_tool(
73        mut self,
74        tool: Tool,
75        f: impl Fn(CallToolRequest) -> Pin<Box<dyn Future<Output = CallToolResponse> + Send>>
76            + Send
77            + Sync
78            + 'static,
79    ) -> Self {
80        self.tools.insert(
81            tool.name.clone(),
82            ToolHandler {
83                tool,
84                f: Box::new(f),
85            },
86        );
87        self
88    }
89
90    // Helper function for initialize handler
91    fn handle_init(
92        state: Arc<RwLock<ClientConnection>>,
93        server_info: Implementation,
94        capabilities: ServerCapabilities,
95    ) -> impl Fn(
96        InitializeRequest,
97    )
98        -> Pin<Box<dyn std::future::Future<Output = Result<InitializeResponse>> + Send>> {
99        move |req| {
100            let state = state.clone();
101            let server_info = server_info.clone();
102            let capabilities = capabilities.clone();
103
104            Box::pin(async move {
105                let mut state = state
106                    .write()
107                    .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
108                state.client_capabilities = Some(req.capabilities);
109                state.client_info = Some(req.client_info);
110
111                Ok(InitializeResponse {
112                    protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
113                    capabilities,
114                    server_info,
115                })
116            })
117        }
118    }
119
120    // Helper function for initialized handler
121    fn handle_initialized(
122        state: Arc<RwLock<ClientConnection>>,
123    ) -> impl Fn(()) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>> {
124        move |_| {
125            let state = state.clone();
126            Box::pin(async move {
127                let mut state = state
128                    .write()
129                    .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
130                state.initialized = true;
131                Ok(())
132            })
133        }
134    }
135
136    pub fn get_client_capabilities(&self) -> Option<ClientCapabilities> {
137        self.client_connection
138            .read()
139            .ok()?
140            .client_capabilities
141            .clone()
142    }
143
144    pub fn get_client_info(&self) -> Option<Implementation> {
145        self.client_connection.read().ok()?.client_info.clone()
146    }
147
148    pub fn is_initialized(&self) -> bool {
149        self.client_connection
150            .read()
151            .ok()
152            .map(|client_connection| client_connection.initialized)
153            .unwrap_or(false)
154    }
155
156    pub fn build(self) -> Protocol {
157        let tools = Arc::new(Tools::new(self.tools));
158        let tools_clone = tools.clone();
159        let tools_list = tools.clone();
160        let tools_call = tools_clone.clone();
161
162        let conn_for_list = self.client_connection.clone();
163        let conn_for_call = self.client_connection.clone();
164
165        self.protocol_builder
166            .request_handler(
167                "initialize",
168                Self::handle_init(
169                    self.client_connection.clone(),
170                    self.server_info,
171                    self.capabilities,
172                ),
173            )
174            .notification_handler(
175                "notifications/initialized",
176                Self::handle_initialized(self.client_connection.clone()),
177            )
178            .request_handler("tools/list", move |_req: ListRequest| {
179                let tools = tools_list.clone();
180                let conn = conn_for_list.clone();
181
182                Box::pin(async move {
183                    let client_state = conn.read().map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
184
185                    if !client_state.initialized {
186                        return Err(anyhow::anyhow!(
187                            "Client must be initialized before using tools/list"
188                        ));
189                    }
190
191                    Ok(ToolsListResponse {
192                        tools: tools.list_tools(),
193                        next_cursor: None,
194                        meta: None,
195                    })
196                })
197            })
198            .request_handler("tools/call", move |req: CallToolRequest| {
199                let tools = tools_call.clone();
200                let conn = conn_for_call.clone();
201
202                Box::pin(async move {
203                    {
204                        // Check if client is initialized
205                        let client_state =
206                            conn.read().map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
207
208                        if !client_state.initialized {
209                            return Err(anyhow::anyhow!(
210                                "Client must be initialized before using tools/call"
211                            ));
212                        }
213                    }
214
215                    tools.call_tool(req).await
216                })
217            })
218            .build()
219    }
220}