mcp_core/
server.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, RwLock},
4};
5
6use crate::{
7    tools::{ToolHandler, Tools},
8    types::{CallToolRequest, CallToolResponse, ListRequest, Tool, ToolsListResponse},
9};
10
11use super::{
12    protocol::{Protocol, ProtocolBuilder},
13    transport::Transport,
14    types::{
15        ClientCapabilities, Implementation, InitializeRequest, InitializeResponse,
16        ServerCapabilities, LATEST_PROTOCOL_VERSION,
17    },
18};
19use anyhow::Result;
20use serde::{de::DeserializeOwned, Serialize};
21use std::future::Future;
22use std::pin::Pin;
23
24#[derive(Clone)]
25pub struct ServerState {
26    client_capabilities: Option<ClientCapabilities>,
27    client_info: Option<Implementation>,
28    initialized: bool,
29}
30
31#[derive(Clone)]
32pub struct Server<T: Transport> {
33    protocol: Protocol<T>,
34    state: Arc<RwLock<ServerState>>,
35}
36
37pub struct ServerBuilder<T: Transport> {
38    protocol: ProtocolBuilder<T>,
39    server_info: Implementation,
40    capabilities: ServerCapabilities,
41    tools: HashMap<String, ToolHandler>,
42}
43
44impl<T: Transport> ServerBuilder<T> {
45    pub fn name<S: Into<String>>(mut self, name: S) -> Self {
46        self.server_info.name = name.into();
47        self
48    }
49
50    pub fn version<S: Into<String>>(mut self, version: S) -> Self {
51        self.server_info.version = version.into();
52        self
53    }
54
55    pub fn capabilities(mut self, capabilities: ServerCapabilities) -> Self {
56        self.capabilities = capabilities;
57        self
58    }
59
60    /// Register a typed request handler
61    /// for higher-level api use add tool
62    pub fn request_handler<Req, Resp>(
63        mut self,
64        method: &str,
65        handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
66            + Send
67            + Sync
68            + 'static,
69    ) -> Self
70    where
71        Req: DeserializeOwned + Send + Sync + 'static,
72        Resp: Serialize + Send + Sync + 'static,
73    {
74        self.protocol = self.protocol.request_handler(method, handler);
75        self
76    }
77
78    pub fn notification_handler<N>(
79        mut self,
80        method: &str,
81        handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
82            + Send
83            + Sync
84            + 'static,
85    ) -> Self
86    where
87        N: DeserializeOwned + Send + Sync + 'static,
88    {
89        self.protocol = self.protocol.notification_handler(method, handler);
90        self
91    }
92
93    pub fn register_tool(
94        &mut self,
95        tool: Tool,
96        f: impl Fn(CallToolRequest) -> Pin<Box<dyn Future<Output = CallToolResponse> + Send>>
97            + Send
98            + Sync
99            + 'static,
100    ) {
101        self.tools.insert(
102            tool.name.clone(),
103            ToolHandler {
104                tool,
105                f: Box::new(f),
106            },
107        );
108    }
109
110    pub fn build(self) -> Server<T> {
111        Server::new(self)
112    }
113}
114
115impl<T: Transport> Server<T> {
116    pub fn builder(transport: T) -> ServerBuilder<T> {
117        ServerBuilder {
118            protocol: Protocol::builder(transport),
119            server_info: Implementation {
120                name: env!("CARGO_PKG_NAME").to_string(),
121                version: env!("CARGO_PKG_VERSION").to_string(),
122            },
123            capabilities: Default::default(),
124            tools: HashMap::new(),
125        }
126    }
127
128    fn new(builder: ServerBuilder<T>) -> Self {
129        let state = Arc::new(RwLock::new(ServerState {
130            client_capabilities: None,
131            client_info: None,
132            initialized: false,
133        }));
134
135        // Initialize protocol with handlers
136        let mut protocol = builder
137            .protocol
138            .request_handler(
139                "initialize",
140                Self::handle_init(state.clone(), builder.server_info, builder.capabilities),
141            )
142            .notification_handler(
143                "notifications/initialized",
144                Self::handle_initialized(state.clone()),
145            );
146
147        // Add tools handlers if not already present
148        if !protocol.has_request_handler("tools/list") {
149            let tools = Arc::new(Tools::new(builder.tools));
150            let tools_clone = tools.clone();
151            let tools_list = tools.clone();
152            let tools_call = tools_clone.clone();
153
154            protocol = protocol
155                .request_handler("tools/list", move |_req: ListRequest| {
156                    let tools = tools_list.clone();
157                    Box::pin(async move {
158                        Ok(ToolsListResponse {
159                            tools: tools.list_tools(),
160                            next_cursor: None,
161                            meta: None,
162                        })
163                    })
164                })
165                .request_handler("tools/call", move |req: CallToolRequest| {
166                    let tools = tools_call.clone();
167                    Box::pin(async move { tools.call_tool(req).await })
168                });
169        }
170
171        Server {
172            protocol: protocol.build(),
173            state,
174        }
175    }
176
177    // Helper function for initialize handler
178    fn handle_init(
179        state: Arc<RwLock<ServerState>>,
180        server_info: Implementation,
181        capabilities: ServerCapabilities,
182    ) -> impl Fn(
183        InitializeRequest,
184    )
185        -> Pin<Box<dyn std::future::Future<Output = Result<InitializeResponse>> + Send>> {
186        move |req| {
187            let state = state.clone();
188            let server_info = server_info.clone();
189            let capabilities = capabilities.clone();
190
191            Box::pin(async move {
192                let mut state = state
193                    .write()
194                    .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
195                state.client_capabilities = Some(req.capabilities);
196                state.client_info = Some(req.client_info);
197
198                Ok(InitializeResponse {
199                    protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
200                    capabilities,
201                    server_info,
202                })
203            })
204        }
205    }
206
207    // Helper function for initialized handler
208    fn handle_initialized(
209        state: Arc<RwLock<ServerState>>,
210    ) -> impl Fn(()) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>> {
211        move |_| {
212            let state = state.clone();
213            Box::pin(async move {
214                let mut state = state
215                    .write()
216                    .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
217                state.initialized = true;
218                Ok(())
219            })
220        }
221    }
222
223    pub fn get_client_capabilities(&self) -> Option<ClientCapabilities> {
224        self.state.read().ok()?.client_capabilities.clone()
225    }
226
227    pub fn get_client_info(&self) -> Option<Implementation> {
228        self.state.read().ok()?.client_info.clone()
229    }
230
231    pub fn is_initialized(&self) -> bool {
232        self.state
233            .read()
234            .ok()
235            .map(|state| state.initialized)
236            .unwrap_or(false)
237    }
238
239    pub async fn listen(&self) -> Result<()> {
240        self.protocol.listen().await
241    }
242}