1use std::sync::{Arc, RwLock};
2
3use crate::{
4 tools::Tools,
5 types::{CallToolRequest, ListRequest, ToolsListResponse},
6};
7
8use super::{
9 protocol::{Protocol, ProtocolBuilder},
10 transport::Transport,
11 types::{
12 ClientCapabilities, Implementation, InitializeRequest, InitializeResponse,
13 ServerCapabilities, LATEST_PROTOCOL_VERSION,
14 },
15};
16use anyhow::Result;
17use serde::{de::DeserializeOwned, Serialize};
18
/// Session state recorded by the server's built-in handlers.
///
/// A freshly built server starts with both options set to `None` and
/// `initialized` set to `false`; the `initialize` request handler and the
/// `notifications/initialized` notification handler fill the fields in.
#[derive(Clone)]
pub struct ServerState {
    /// Capabilities copied from the last handled `initialize` request, if any.
    client_capabilities: Option<ClientCapabilities>,
    /// Client implementation info copied from the last handled `initialize`
    /// request, if any.
    client_info: Option<Implementation>,
    /// `true` once a `notifications/initialized` notification has been handled.
    initialized: bool,
}
25
/// A server that pairs a built [`Protocol`] with shared [`ServerState`].
///
/// The state is held behind an `Arc`, so clones of a `Server` read and write
/// the same [`ServerState`].
#[derive(Clone)]
pub struct Server<T: Transport> {
    /// Built protocol instance; `listen` delegates to it.
    protocol: Protocol<T>,
    /// State shared with the `initialize` / `notifications/initialized`
    /// handlers registered when the server is constructed.
    state: Arc<RwLock<ServerState>>,
}
31
/// Collects the configuration used to construct a [`Server`].
///
/// Obtained from `Server::builder` and consumed by `build`.
pub struct ServerBuilder<T: Transport> {
    /// Protocol builder that accumulates request and notification handlers.
    protocol: ProtocolBuilder<T>,
    /// Name and version returned to the client in the `initialize` response.
    server_info: Implementation,
    /// Capabilities returned to the client in the `initialize` response.
    capabilities: ServerCapabilities,
    /// Optional tool set; when present, `tools/list` and `tools/call`
    /// handlers are registered at construction time.
    tools: Option<Tools>,
}
38
39impl<T: Transport> ServerBuilder<T> {
40 pub fn name<S: Into<String>>(mut self, name: S) -> Self {
41 self.server_info.name = name.into();
42 self
43 }
44
45 pub fn version<S: Into<String>>(mut self, version: S) -> Self {
46 self.server_info.version = version.into();
47 self
48 }
49
50 pub fn capabilities(mut self, capabilities: ServerCapabilities) -> Self {
51 self.capabilities = capabilities;
52 self
53 }
54
55 pub fn request_handler<Req, Resp>(
58 mut self,
59 method: &str,
60 handler: impl Fn(Req) -> Result<Resp> + Send + Sync + 'static,
61 ) -> Self
62 where
63 Req: DeserializeOwned + Send + Sync + 'static,
64 Resp: Serialize + Send + Sync + 'static,
65 {
66 self.protocol = self.protocol.request_handler(method, handler);
67 self
68 }
69
70 pub fn notification_handler<N>(
71 mut self,
72 method: &str,
73 handler: impl Fn(N) -> Result<()> + Send + Sync + 'static,
74 ) -> Self
75 where
76 N: DeserializeOwned + Send + Sync + 'static,
77 {
78 self.protocol = self.protocol.notification_handler(method, handler);
79 self
80 }
81
82 pub fn tools(mut self, tools: Tools) -> Self {
83 self.tools = Some(tools);
84 self
85 }
86
87 pub fn build(self) -> Server<T> {
88 Server::new(self)
89 }
90}
91
92impl<T: Transport> Server<T> {
93 pub fn builder(transport: T) -> ServerBuilder<T> {
94 ServerBuilder {
95 protocol: Protocol::builder(transport),
96 server_info: Implementation {
97 name: env!("CARGO_PKG_NAME").to_string(),
98 version: env!("CARGO_PKG_VERSION").to_string(),
99 },
100 capabilities: Default::default(),
101 tools: None,
102 }
103 }
104
105 fn new(builder: ServerBuilder<T>) -> Self {
106 let state = Arc::new(RwLock::new(ServerState {
107 client_capabilities: None,
108 client_info: None,
109 initialized: false,
110 }));
111
112 let mut protocol = builder
114 .protocol
115 .request_handler(
116 "initialize",
117 Self::handle_init(state.clone(), builder.server_info, builder.capabilities),
118 )
119 .notification_handler(
120 "notifications/initialized",
121 Self::handle_initialized(state.clone()),
122 );
123 if let Some(tools) = builder.tools {
124 let tools = Arc::new(tools);
126 let tools_clone = tools.clone();
127 protocol = protocol
128 .request_handler("tools/list", move |_req: ListRequest| {
129 Ok(ToolsListResponse {
130 tools: tools.list_tools(),
131 next_cursor: None,
132 meta: None,
133 })
134 })
135 .request_handler("tools/call", move |req: CallToolRequest| {
136 let response = tools_clone.call_tool(req);
137 Ok(response)
138 });
139 }
140
141 Server {
142 protocol: protocol.build(),
143 state,
144 }
145 }
146
147 fn handle_init(
149 state: Arc<RwLock<ServerState>>,
150 server_info: Implementation,
151 capabilities: ServerCapabilities,
152 ) -> impl Fn(InitializeRequest) -> Result<InitializeResponse> {
153 move |req| {
154 let mut state = state
155 .write()
156 .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
157 state.client_capabilities = Some(req.capabilities);
158 state.client_info = Some(req.client_info);
159
160 Ok(InitializeResponse {
161 protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
162 capabilities: capabilities.clone(),
163 server_info: server_info.clone(),
164 })
165 }
166 }
167
168 fn handle_initialized(state: Arc<RwLock<ServerState>>) -> impl Fn(()) -> Result<()> {
170 move |_| {
171 let mut state = state
172 .write()
173 .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
174 state.initialized = true;
175 Ok(())
176 }
177 }
178
179 pub fn get_client_capabilities(&self) -> Option<ClientCapabilities> {
180 self.state.read().ok()?.client_capabilities.clone()
181 }
182
183 pub fn get_client_info(&self) -> Option<Implementation> {
184 self.state.read().ok()?.client_info.clone()
185 }
186
187 pub fn is_initialized(&self) -> bool {
188 self.state
189 .read()
190 .ok()
191 .map(|state| state.initialized)
192 .unwrap_or(false)
193 }
194
195 pub async fn listen(&self) -> Result<()> {
196 self.protocol.listen().await
197 }
198}