mcp_core/
server.rs

1use std::{
2    collections::HashMap,
3    sync::{Arc, RwLock},
4};
5
6use crate::{
7    tools::{ToolHandler, Tools},
8    types::{CallToolRequest, CallToolResponse, ListRequest, Tool, ToolsListResponse},
9};
10
11use super::{
12    protocol::{Protocol, ProtocolBuilder},
13    transport::Transport,
14    types::{
15        ClientCapabilities, Implementation, InitializeRequest, InitializeResponse,
16        ServerCapabilities, LATEST_PROTOCOL_VERSION,
17    },
18};
19use anyhow::Result;
20use serde::{de::DeserializeOwned, Serialize};
21use std::future::Future;
22use std::pin::Pin;
23use tracing::info;
24
25#[derive(Clone)]
26pub struct ServerState {
27    client_capabilities: Option<ClientCapabilities>,
28    client_info: Option<Implementation>,
29    initialized: bool,
30}
31
32#[derive(Clone)]
33pub struct Server<T: Transport> {
34    protocol: Protocol<T>,
35    state: Arc<RwLock<ServerState>>,
36}
37
38pub struct ServerBuilder<T: Transport> {
39    protocol: ProtocolBuilder<T>,
40    server_info: Implementation,
41    capabilities: ServerCapabilities,
42    tools: HashMap<String, ToolHandler>,
43}
44
45impl<T: Transport> ServerBuilder<T> {
46    pub fn name<S: Into<String>>(mut self, name: S) -> Self {
47        self.server_info.name = name.into();
48        self
49    }
50
51    pub fn version<S: Into<String>>(mut self, version: S) -> Self {
52        self.server_info.version = version.into();
53        self
54    }
55
56    pub fn capabilities(mut self, capabilities: ServerCapabilities) -> Self {
57        self.capabilities = capabilities;
58        self
59    }
60
61    /// Register a typed request handler
62    /// for higher-level api use add tool
63    pub fn request_handler<Req, Resp>(
64        mut self,
65        method: &str,
66        handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
67            + Send
68            + Sync
69            + 'static,
70    ) -> Self
71    where
72        Req: DeserializeOwned + Send + Sync + 'static,
73        Resp: Serialize + Send + Sync + 'static,
74    {
75        self.protocol = self.protocol.request_handler(method, handler);
76        self
77    }
78
79    pub fn notification_handler<N>(
80        mut self,
81        method: &str,
82        handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
83            + Send
84            + Sync
85            + 'static,
86    ) -> Self
87    where
88        N: DeserializeOwned + Send + Sync + 'static,
89    {
90        self.protocol = self.protocol.notification_handler(method, handler);
91        self
92    }
93
94    pub fn register_tool(
95        &mut self,
96        tool: Tool,
97        f: impl Fn(CallToolRequest) -> Pin<Box<dyn Future<Output = Result<CallToolResponse>> + Send>>
98            + Send
99            + Sync
100            + 'static,
101    ) {
102        self.tools.insert(
103            tool.name.clone(),
104            ToolHandler {
105                tool,
106                f: Box::new(f),
107            },
108        );
109    }
110
111    pub fn build(self) -> Server<T> {
112        Server::new(self)
113    }
114}
115
116impl<T: Transport> Server<T> {
117    pub fn builder(transport: T) -> ServerBuilder<T> {
118        ServerBuilder {
119            protocol: Protocol::builder(transport),
120            server_info: Implementation {
121                name: env!("CARGO_PKG_NAME").to_string(),
122                version: env!("CARGO_PKG_VERSION").to_string(),
123            },
124            capabilities: Default::default(),
125            tools: HashMap::new(),
126        }
127    }
128
129    fn new(builder: ServerBuilder<T>) -> Self {
130        let state = Arc::new(RwLock::new(ServerState {
131            client_capabilities: None,
132            client_info: None,
133            initialized: false,
134        }));
135
136        // Initialize protocol with handlers
137        let mut protocol = builder
138            .protocol
139            .request_handler(
140                "initialize",
141                Self::handle_init(state.clone(), builder.server_info, builder.capabilities),
142            )
143            .notification_handler(
144                "notifications/initialized",
145                Self::handle_initialized(state.clone()),
146            );
147
148        // Add tools handlers if not already present
149        if !protocol.has_request_handler("tools/list") {
150            let tools = Arc::new(Tools::new(builder.tools));
151            let tools_clone = tools.clone();
152            let tools_list = tools.clone();
153            let tools_call = tools_clone.clone();
154
155            protocol = protocol
156                .request_handler("tools/list", move |_req: ListRequest| {
157                    let tools = tools_list.clone();
158                    Box::pin(async move {
159                        Ok(ToolsListResponse {
160                            tools: tools.list_tools(),
161                            next_cursor: None,
162                            meta: None,
163                        })
164                    })
165                })
166                .request_handler("tools/call", move |req: CallToolRequest| {
167                    let tools = tools_call.clone();
168                    Box::pin(async move { tools.call_tool(req).await })
169                });
170        }
171
172        Server {
173            protocol: protocol.build(),
174            state,
175        }
176    }
177
178    // Helper function for initialize handler
179    fn handle_init(
180        state: Arc<RwLock<ServerState>>,
181        server_info: Implementation,
182        capabilities: ServerCapabilities,
183    ) -> impl Fn(
184        InitializeRequest,
185    )
186        -> Pin<Box<dyn std::future::Future<Output = Result<InitializeResponse>> + Send>> {
187        move |req| {
188            let state = state.clone();
189            let server_info = server_info.clone();
190            let capabilities = capabilities.clone();
191
192            Box::pin(async move {
193                let mut state = state
194                    .write()
195                    .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
196                state.client_capabilities = Some(req.capabilities);
197                state.client_info = Some(req.client_info);
198
199                Ok(InitializeResponse {
200                    protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
201                    capabilities,
202                    server_info,
203                })
204            })
205        }
206    }
207
208    // Helper function for initialized handler
209    fn handle_initialized(
210        state: Arc<RwLock<ServerState>>,
211    ) -> impl Fn(()) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>> {
212        move |_| {
213            let state = state.clone();
214            Box::pin(async move {
215                let mut state = state
216                    .write()
217                    .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
218                state.initialized = true;
219                Ok(())
220            })
221        }
222    }
223
224    pub fn get_client_capabilities(&self) -> Option<ClientCapabilities> {
225        self.state.read().ok()?.client_capabilities.clone()
226    }
227
228    pub fn get_client_info(&self) -> Option<Implementation> {
229        self.state.read().ok()?.client_info.clone()
230    }
231
232    pub fn is_initialized(&self) -> bool {
233        self.state
234            .read()
235            .ok()
236            .map(|state| state.initialized)
237            .unwrap_or(false)
238    }
239
240    pub async fn listen(&self) -> Result<()> {
241        self.protocol.listen().await
242    }
243}