// mcp_core/server.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, RwLock},
4};
5
6use crate::{
7    protocol::Protocol,
8    tools::{ToolHandler, ToolHandlerFn, Tools},
9    types::{CallToolRequest, ListRequest, Tool, ToolsListResponse},
10};
11
12use super::{
13    protocol::ProtocolBuilder,
14    transport::Transport,
15    types::{
16        ClientCapabilities, Implementation, InitializeRequest, InitializeResponse,
17        ServerCapabilities, LATEST_PROTOCOL_VERSION,
18    },
19};
20use anyhow::Result;
21use std::pin::Pin;
22
23#[derive(Clone)]
24pub struct ClientConnection {
25    client_capabilities: Option<ClientCapabilities>,
26    client_info: Option<Implementation>,
27    initialized: bool,
28}
29
30#[derive(Clone)]
31pub struct Server;
32
33impl Server {
34    pub fn builder(name: String, version: String) -> ServerProtocolBuilder {
35        ServerProtocolBuilder::new(name, version)
36    }
37
38    pub async fn start<T: Transport>(transport: T) -> Result<()> {
39        transport.open().await
40    }
41}
42
43pub struct ServerProtocolBuilder {
44    protocol_builder: ProtocolBuilder,
45    server_info: Implementation,
46    capabilities: ServerCapabilities,
47    tools: HashMap<String, ToolHandler>,
48    client_connection: Arc<RwLock<ClientConnection>>,
49}
50
51impl ServerProtocolBuilder {
52    pub fn new(name: String, version: String) -> Self {
53        ServerProtocolBuilder {
54            protocol_builder: ProtocolBuilder::new(),
55            server_info: Implementation { name, version },
56            capabilities: ServerCapabilities::default(),
57            tools: HashMap::new(),
58            client_connection: Arc::new(RwLock::new(ClientConnection {
59                client_capabilities: None,
60                client_info: None,
61                initialized: false,
62            })),
63        }
64    }
65
66    pub fn capabilities(mut self, capabilities: ServerCapabilities) -> Self {
67        self.capabilities = capabilities;
68        self
69    }
70
71    pub fn register_tool(mut self, tool: Tool, f: ToolHandlerFn) -> Self {
72        self.tools.insert(
73            tool.name.clone(),
74            ToolHandler {
75                tool,
76                f: Box::new(f),
77            },
78        );
79        self
80    }
81
82    // Helper function for initialize handler
83    fn handle_init(
84        state: Arc<RwLock<ClientConnection>>,
85        server_info: Implementation,
86        capabilities: ServerCapabilities,
87    ) -> impl Fn(
88        InitializeRequest,
89    )
90        -> Pin<Box<dyn std::future::Future<Output = Result<InitializeResponse>> + Send>> {
91        move |req| {
92            let state = state.clone();
93            let server_info = server_info.clone();
94            let capabilities = capabilities.clone();
95
96            Box::pin(async move {
97                let mut state = state
98                    .write()
99                    .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
100                state.client_capabilities = Some(req.capabilities);
101                state.client_info = Some(req.client_info);
102
103                Ok(InitializeResponse {
104                    protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
105                    capabilities,
106                    server_info,
107                })
108            })
109        }
110    }
111
112    // Helper function for initialized handler
113    fn handle_initialized(
114        state: Arc<RwLock<ClientConnection>>,
115    ) -> impl Fn(()) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>> {
116        move |_| {
117            let state = state.clone();
118            Box::pin(async move {
119                let mut state = state
120                    .write()
121                    .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
122                state.initialized = true;
123                Ok(())
124            })
125        }
126    }
127
128    pub fn get_client_capabilities(&self) -> Option<ClientCapabilities> {
129        self.client_connection
130            .read()
131            .ok()?
132            .client_capabilities
133            .clone()
134    }
135
136    pub fn get_client_info(&self) -> Option<Implementation> {
137        self.client_connection.read().ok()?.client_info.clone()
138    }
139
140    pub fn is_initialized(&self) -> bool {
141        self.client_connection
142            .read()
143            .ok()
144            .map(|client_connection| client_connection.initialized)
145            .unwrap_or(false)
146    }
147
148    pub fn build(self) -> Protocol {
149        let tools = Arc::new(Tools::new(self.tools));
150        let tools_clone = tools.clone();
151        let tools_list = tools.clone();
152        let tools_call = tools_clone.clone();
153
154        let conn_for_list = self.client_connection.clone();
155        let conn_for_call = self.client_connection.clone();
156
157        self.protocol_builder
158            .request_handler(
159                "initialize",
160                Self::handle_init(
161                    self.client_connection.clone(),
162                    self.server_info,
163                    self.capabilities,
164                ),
165            )
166            .notification_handler(
167                "notifications/initialized",
168                Self::handle_initialized(self.client_connection.clone()),
169            )
170            .request_handler("tools/list", move |_req: ListRequest| {
171                let tools = tools_list.clone();
172                let conn = conn_for_list.clone();
173
174                Box::pin(async move {
175                    let client_state = conn.read().map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
176
177                    if !client_state.initialized {
178                        return Err(anyhow::anyhow!(
179                            "Client must be initialized before using tools/list"
180                        ));
181                    }
182
183                    Ok(ToolsListResponse {
184                        tools: tools.list_tools(),
185                        next_cursor: None,
186                        meta: None,
187                    })
188                })
189            })
190            .request_handler("tools/call", move |req: CallToolRequest| {
191                let tools = tools_call.clone();
192                let conn = conn_for_call.clone();
193
194                Box::pin(async move {
195                    {
196                        // Check if client is initialized
197                        let client_state =
198                            conn.read().map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
199
200                        if !client_state.initialized {
201                            return Err(anyhow::anyhow!(
202                                "Client must be initialized before using tools/call"
203                            ));
204                        }
205                    }
206
207                    tools.call_tool(req).await
208                })
209            })
210            .build()
211    }
212}