mcp_sdk/
server.rs

1use std::sync::{Arc, RwLock};
2
3use crate::{
4    tools::Tools,
5    types::{CallToolRequest, ListRequest, ToolsListResponse},
6};
7
8use super::{
9    protocol::{Protocol, ProtocolBuilder},
10    transport::Transport,
11    types::{
12        ClientCapabilities, Implementation, InitializeRequest, InitializeResponse,
13        ServerCapabilities, LATEST_PROTOCOL_VERSION,
14    },
15};
16use anyhow::Result;
17use serde::{de::DeserializeOwned, Serialize};
18
19#[derive(Clone)]
20pub struct ServerState {
21    client_capabilities: Option<ClientCapabilities>,
22    client_info: Option<Implementation>,
23    initialized: bool,
24}
25
26#[derive(Clone)]
27pub struct Server<T: Transport> {
28    protocol: Protocol<T>,
29    state: Arc<RwLock<ServerState>>,
30}
31
32pub struct ServerBuilder<T: Transport> {
33    protocol: ProtocolBuilder<T>,
34    server_info: Implementation,
35    capabilities: ServerCapabilities,
36    tools: Option<Tools>,
37}
38
39impl<T: Transport> ServerBuilder<T> {
40    pub fn name<S: Into<String>>(mut self, name: S) -> Self {
41        self.server_info.name = name.into();
42        self
43    }
44
45    pub fn version<S: Into<String>>(mut self, version: S) -> Self {
46        self.server_info.version = version.into();
47        self
48    }
49
50    pub fn capabilities(mut self, capabilities: ServerCapabilities) -> Self {
51        self.capabilities = capabilities;
52        self
53    }
54
55    /// Register a typed request handler
56    /// for higher-level api use add tool
57    pub fn request_handler<Req, Resp>(
58        mut self,
59        method: &str,
60        handler: impl Fn(Req) -> Result<Resp> + Send + Sync + 'static,
61    ) -> Self
62    where
63        Req: DeserializeOwned + Send + Sync + 'static,
64        Resp: Serialize + Send + Sync + 'static,
65    {
66        self.protocol = self.protocol.request_handler(method, handler);
67        self
68    }
69
70    pub fn notification_handler<N>(
71        mut self,
72        method: &str,
73        handler: impl Fn(N) -> Result<()> + Send + Sync + 'static,
74    ) -> Self
75    where
76        N: DeserializeOwned + Send + Sync + 'static,
77    {
78        self.protocol = self.protocol.notification_handler(method, handler);
79        self
80    }
81
82    pub fn tools(mut self, tools: Tools) -> Self {
83        self.tools = Some(tools);
84        self
85    }
86
87    pub fn build(self) -> Server<T> {
88        Server::new(self)
89    }
90}
91
92impl<T: Transport> Server<T> {
93    pub fn builder(transport: T) -> ServerBuilder<T> {
94        ServerBuilder {
95            protocol: Protocol::builder(transport),
96            server_info: Implementation {
97                name: env!("CARGO_PKG_NAME").to_string(),
98                version: env!("CARGO_PKG_VERSION").to_string(),
99            },
100            capabilities: Default::default(),
101            tools: None,
102        }
103    }
104
105    fn new(builder: ServerBuilder<T>) -> Self {
106        let state = Arc::new(RwLock::new(ServerState {
107            client_capabilities: None,
108            client_info: None,
109            initialized: false,
110        }));
111
112        // Initialize protocol with handlers
113        let mut protocol = builder
114            .protocol
115            .request_handler(
116                "initialize",
117                Self::handle_init(state.clone(), builder.server_info, builder.capabilities),
118            )
119            .notification_handler(
120                "notifications/initialized",
121                Self::handle_initialized(state.clone()),
122            );
123        if let Some(tools) = builder.tools {
124            // Add tools handlers if not already present
125            let tools = Arc::new(tools);
126            let tools_clone = tools.clone();
127            protocol = protocol
128                .request_handler("tools/list", move |_req: ListRequest| {
129                    Ok(ToolsListResponse {
130                        tools: tools.list_tools(),
131                        next_cursor: None,
132                        meta: None,
133                    })
134                })
135                .request_handler("tools/call", move |req: CallToolRequest| {
136                    let response = tools_clone.call_tool(req);
137                    Ok(response)
138                });
139        }
140
141        Server {
142            protocol: protocol.build(),
143            state,
144        }
145    }
146
147    // Helper function for initialize handler
148    fn handle_init(
149        state: Arc<RwLock<ServerState>>,
150        server_info: Implementation,
151        capabilities: ServerCapabilities,
152    ) -> impl Fn(InitializeRequest) -> Result<InitializeResponse> {
153        move |req| {
154            let mut state = state
155                .write()
156                .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
157            state.client_capabilities = Some(req.capabilities);
158            state.client_info = Some(req.client_info);
159
160            Ok(InitializeResponse {
161                protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
162                capabilities: capabilities.clone(),
163                server_info: server_info.clone(),
164            })
165        }
166    }
167
168    // Helper function for initialized handler
169    fn handle_initialized(state: Arc<RwLock<ServerState>>) -> impl Fn(()) -> Result<()> {
170        move |_| {
171            let mut state = state
172                .write()
173                .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
174            state.initialized = true;
175            Ok(())
176        }
177    }
178
179    pub fn get_client_capabilities(&self) -> Option<ClientCapabilities> {
180        self.state.read().ok()?.client_capabilities.clone()
181    }
182
183    pub fn get_client_info(&self) -> Option<Implementation> {
184        self.state.read().ok()?.client_info.clone()
185    }
186
187    pub fn is_initialized(&self) -> bool {
188        self.state
189            .read()
190            .ok()
191            .map(|state| state.initialized)
192            .unwrap_or(false)
193    }
194
195    pub async fn listen(&self) -> Result<()> {
196        self.protocol.listen().await
197    }
198}