1use std::{
2 collections::HashMap,
3 sync::{Arc, RwLock},
4};
5
6use crate::{
7 tools::{ToolHandler, Tools},
8 types::{CallToolRequest, CallToolResponse, ListRequest, Tool, ToolsListResponse},
9};
10
11use super::{
12 protocol::{Protocol, ProtocolBuilder},
13 transport::Transport,
14 types::{
15 ClientCapabilities, Implementation, InitializeRequest, InitializeResponse,
16 ServerCapabilities, LATEST_PROTOCOL_VERSION,
17 },
18};
19use anyhow::Result;
20use serde::{de::DeserializeOwned, Serialize};
21use std::future::Future;
22use std::pin::Pin;
23use tracing::info;
24
25#[derive(Clone)]
26pub struct ServerState {
27 client_capabilities: Option<ClientCapabilities>,
28 client_info: Option<Implementation>,
29 initialized: bool,
30}
31
32#[derive(Clone)]
33pub struct Server<T: Transport> {
34 protocol: Protocol<T>,
35 state: Arc<RwLock<ServerState>>,
36}
37
38pub struct ServerBuilder<T: Transport> {
39 protocol: ProtocolBuilder<T>,
40 server_info: Implementation,
41 capabilities: ServerCapabilities,
42 tools: HashMap<String, ToolHandler>,
43}
44
45impl<T: Transport> ServerBuilder<T> {
46 pub fn name<S: Into<String>>(mut self, name: S) -> Self {
47 self.server_info.name = name.into();
48 self
49 }
50
51 pub fn version<S: Into<String>>(mut self, version: S) -> Self {
52 self.server_info.version = version.into();
53 self
54 }
55
56 pub fn capabilities(mut self, capabilities: ServerCapabilities) -> Self {
57 self.capabilities = capabilities;
58 self
59 }
60
61 pub fn request_handler<Req, Resp>(
64 mut self,
65 method: &str,
66 handler: impl Fn(Req) -> Pin<Box<dyn std::future::Future<Output = Result<Resp>> + Send>>
67 + Send
68 + Sync
69 + 'static,
70 ) -> Self
71 where
72 Req: DeserializeOwned + Send + Sync + 'static,
73 Resp: Serialize + Send + Sync + 'static,
74 {
75 self.protocol = self.protocol.request_handler(method, handler);
76 self
77 }
78
79 pub fn notification_handler<N>(
80 mut self,
81 method: &str,
82 handler: impl Fn(N) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>>
83 + Send
84 + Sync
85 + 'static,
86 ) -> Self
87 where
88 N: DeserializeOwned + Send + Sync + 'static,
89 {
90 self.protocol = self.protocol.notification_handler(method, handler);
91 self
92 }
93
94 pub fn register_tool(
95 &mut self,
96 tool: Tool,
97 f: impl Fn(CallToolRequest) -> Pin<Box<dyn Future<Output = Result<CallToolResponse>> + Send>>
98 + Send
99 + Sync
100 + 'static,
101 ) {
102 self.tools.insert(
103 tool.name.clone(),
104 ToolHandler {
105 tool,
106 f: Box::new(f),
107 },
108 );
109 }
110
111 pub fn build(self) -> Server<T> {
112 Server::new(self)
113 }
114}
115
116impl<T: Transport> Server<T> {
117 pub fn builder(transport: T) -> ServerBuilder<T> {
118 ServerBuilder {
119 protocol: Protocol::builder(transport),
120 server_info: Implementation {
121 name: env!("CARGO_PKG_NAME").to_string(),
122 version: env!("CARGO_PKG_VERSION").to_string(),
123 },
124 capabilities: Default::default(),
125 tools: HashMap::new(),
126 }
127 }
128
129 fn new(builder: ServerBuilder<T>) -> Self {
130 let state = Arc::new(RwLock::new(ServerState {
131 client_capabilities: None,
132 client_info: None,
133 initialized: false,
134 }));
135
136 let mut protocol = builder
138 .protocol
139 .request_handler(
140 "initialize",
141 Self::handle_init(state.clone(), builder.server_info, builder.capabilities),
142 )
143 .notification_handler(
144 "notifications/initialized",
145 Self::handle_initialized(state.clone()),
146 );
147
148 if !protocol.has_request_handler("tools/list") {
150 let tools = Arc::new(Tools::new(builder.tools));
151 let tools_clone = tools.clone();
152 let tools_list = tools.clone();
153 let tools_call = tools_clone.clone();
154
155 protocol = protocol
156 .request_handler("tools/list", move |_req: ListRequest| {
157 let tools = tools_list.clone();
158 Box::pin(async move {
159 Ok(ToolsListResponse {
160 tools: tools.list_tools(),
161 next_cursor: None,
162 meta: None,
163 })
164 })
165 })
166 .request_handler("tools/call", move |req: CallToolRequest| {
167 let tools = tools_call.clone();
168 Box::pin(async move { tools.call_tool(req).await })
169 });
170 }
171
172 Server {
173 protocol: protocol.build(),
174 state,
175 }
176 }
177
178 fn handle_init(
180 state: Arc<RwLock<ServerState>>,
181 server_info: Implementation,
182 capabilities: ServerCapabilities,
183 ) -> impl Fn(
184 InitializeRequest,
185 )
186 -> Pin<Box<dyn std::future::Future<Output = Result<InitializeResponse>> + Send>> {
187 move |req| {
188 let state = state.clone();
189 let server_info = server_info.clone();
190 let capabilities = capabilities.clone();
191
192 Box::pin(async move {
193 let mut state = state
194 .write()
195 .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
196 state.client_capabilities = Some(req.capabilities);
197 state.client_info = Some(req.client_info);
198
199 Ok(InitializeResponse {
200 protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
201 capabilities,
202 server_info,
203 })
204 })
205 }
206 }
207
208 fn handle_initialized(
210 state: Arc<RwLock<ServerState>>,
211 ) -> impl Fn(()) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>> {
212 move |_| {
213 let state = state.clone();
214 Box::pin(async move {
215 let mut state = state
216 .write()
217 .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
218 state.initialized = true;
219 Ok(())
220 })
221 }
222 }
223
224 pub fn get_client_capabilities(&self) -> Option<ClientCapabilities> {
225 self.state.read().ok()?.client_capabilities.clone()
226 }
227
228 pub fn get_client_info(&self) -> Option<Implementation> {
229 self.state.read().ok()?.client_info.clone()
230 }
231
232 pub fn is_initialized(&self) -> bool {
233 self.state
234 .read()
235 .ok()
236 .map(|state| state.initialized)
237 .unwrap_or(false)
238 }
239
240 pub async fn listen(&self) -> Result<()> {
241 self.protocol.listen().await
242 }
243}