1use std::{
2 collections::HashMap,
3 sync::{Arc, RwLock},
4};
5
6use crate::{
7 tools::{ToolHandler, Tools},
8 types::{CallToolRequest, CallToolResponse, ListRequest, Tool, ToolsListResponse},
9};
10
11use super::{
12 protocol::{Protocol, ProtocolBuilder},
13 transport::Transport,
14 types::{
15 ClientCapabilities, Implementation, InitializeRequest, InitializeResponse,
16 ServerCapabilities, LATEST_PROTOCOL_VERSION,
17 },
18};
19use anyhow::Result;
20use serde::{de::DeserializeOwned, Serialize};
21use std::future::Future;
22use std::pin::Pin;
23
24#[derive(Clone)]
25pub struct ServerState {
26 client_capabilities: Option<ClientCapabilities>,
27 client_info: Option<Implementation>,
28 initialized: bool,
29}
30
31#[derive(Clone)]
32pub struct Server<T: Transport> {
33 protocol: Protocol<T>,
34 state: Arc<RwLock<ServerState>>,
35}
36
37pub struct ServerBuilder<T: Transport> {
38 protocol: ProtocolBuilder<T>,
39 server_info: Implementation,
40 capabilities: ServerCapabilities,
41 tools: HashMap<String, ToolHandler>,
42}
43
44impl<T: Transport> ServerBuilder<T> {
45 pub fn name<S: Into<String>>(mut self, name: S) -> Self {
46 self.server_info.name = name.into();
47 self
48 }
49
50 pub fn version<S: Into<String>>(mut self, version: S) -> Self {
51 self.server_info.version = version.into();
52 self
53 }
54
55 pub fn capabilities(mut self, capabilities: ServerCapabilities) -> Self {
56 self.capabilities = capabilities;
57 self
58 }
59
60 pub fn request_handler<Req, Resp>(
63 mut self,
64 method: &str,
65 handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
66 + Send
67 + Sync
68 + 'static,
69 ) -> Self
70 where
71 Req: DeserializeOwned + Send + Sync + 'static,
72 Resp: Serialize + Send + Sync + 'static,
73 {
74 self.protocol = self.protocol.request_handler(method, handler);
75 self
76 }
77
78 pub fn notification_handler<N>(
79 mut self,
80 method: &str,
81 handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
82 + Send
83 + Sync
84 + 'static,
85 ) -> Self
86 where
87 N: DeserializeOwned + Send + Sync + 'static,
88 {
89 self.protocol = self.protocol.notification_handler(method, handler);
90 self
91 }
92
93 pub fn register_tool(
94 &mut self,
95 tool: Tool,
96 f: impl Fn(CallToolRequest) -> Pin<Box<dyn Future<Output = CallToolResponse> + Send>>
97 + Send
98 + Sync
99 + 'static,
100 ) {
101 self.tools.insert(
102 tool.name.clone(),
103 ToolHandler {
104 tool,
105 f: Box::new(f),
106 },
107 );
108 }
109
110 pub fn build(self) -> Server<T> {
111 Server::new(self)
112 }
113}
114
115impl<T: Transport> Server<T> {
116 pub fn builder(transport: T) -> ServerBuilder<T> {
117 ServerBuilder {
118 protocol: Protocol::builder(transport),
119 server_info: Implementation {
120 name: env!("CARGO_PKG_NAME").to_string(),
121 version: env!("CARGO_PKG_VERSION").to_string(),
122 },
123 capabilities: Default::default(),
124 tools: HashMap::new(),
125 }
126 }
127
128 fn new(builder: ServerBuilder<T>) -> Self {
129 let state = Arc::new(RwLock::new(ServerState {
130 client_capabilities: None,
131 client_info: None,
132 initialized: false,
133 }));
134
135 let mut protocol = builder
137 .protocol
138 .request_handler(
139 "initialize",
140 Self::handle_init(state.clone(), builder.server_info, builder.capabilities),
141 )
142 .notification_handler(
143 "notifications/initialized",
144 Self::handle_initialized(state.clone()),
145 );
146
147 if !protocol.has_request_handler("tools/list") {
149 let tools = Arc::new(Tools::new(builder.tools));
150 let tools_clone = tools.clone();
151 let tools_list = tools.clone();
152 let tools_call = tools_clone.clone();
153
154 protocol = protocol
155 .request_handler("tools/list", move |_req: ListRequest| {
156 let tools = tools_list.clone();
157 Box::pin(async move {
158 Ok(ToolsListResponse {
159 tools: tools.list_tools(),
160 next_cursor: None,
161 meta: None,
162 })
163 })
164 })
165 .request_handler("tools/call", move |req: CallToolRequest| {
166 let tools = tools_call.clone();
167 Box::pin(async move { tools.call_tool(req).await })
168 });
169 }
170
171 Server {
172 protocol: protocol.build(),
173 state,
174 }
175 }
176
177 fn handle_init(
179 state: Arc<RwLock<ServerState>>,
180 server_info: Implementation,
181 capabilities: ServerCapabilities,
182 ) -> impl Fn(
183 InitializeRequest,
184 )
185 -> Pin<Box<dyn std::future::Future<Output = Result<InitializeResponse>> + Send>> {
186 move |req| {
187 let state = state.clone();
188 let server_info = server_info.clone();
189 let capabilities = capabilities.clone();
190
191 Box::pin(async move {
192 let mut state = state
193 .write()
194 .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
195 state.client_capabilities = Some(req.capabilities);
196 state.client_info = Some(req.client_info);
197
198 Ok(InitializeResponse {
199 protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
200 capabilities,
201 server_info,
202 })
203 })
204 }
205 }
206
207 fn handle_initialized(
209 state: Arc<RwLock<ServerState>>,
210 ) -> impl Fn(()) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>> {
211 move |_| {
212 let state = state.clone();
213 Box::pin(async move {
214 let mut state = state
215 .write()
216 .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
217 state.initialized = true;
218 Ok(())
219 })
220 }
221 }
222
223 pub fn get_client_capabilities(&self) -> Option<ClientCapabilities> {
224 self.state.read().ok()?.client_capabilities.clone()
225 }
226
227 pub fn get_client_info(&self) -> Option<Implementation> {
228 self.state.read().ok()?.client_info.clone()
229 }
230
231 pub fn is_initialized(&self) -> bool {
232 self.state
233 .read()
234 .ok()
235 .map(|state| state.initialized)
236 .unwrap_or(false)
237 }
238
239 pub async fn listen(&self) -> Result<()> {
240 self.protocol.listen().await
241 }
242}