1use std::{
2 collections::HashMap,
3 sync::{Arc, RwLock},
4};
5
6use crate::{
7 protocol::Protocol,
8 tools::{ToolHandler, Tools},
9 types::{CallToolRequest, CallToolResponse, ListRequest, Tool, ToolsListResponse},
10};
11
12use super::{
13 protocol::ProtocolBuilder,
14 transport::Transport,
15 types::{
16 ClientCapabilities, Implementation, InitializeRequest, InitializeResponse,
17 ServerCapabilities, LATEST_PROTOCOL_VERSION,
18 },
19};
20use anyhow::Result;
21use std::future::Future;
22use std::pin::Pin;
23
24#[derive(Clone)]
25pub struct ClientConnection {
26 client_capabilities: Option<ClientCapabilities>,
27 client_info: Option<Implementation>,
28 initialized: bool,
29}
30
31#[derive(Clone)]
32pub struct Server;
33
34impl Server {
35 pub fn builder(name: String, version: String) -> ServerProtocolBuilder {
36 ServerProtocolBuilder::new(name, version)
37 }
38
39 pub async fn start<T: Transport>(transport: T) -> Result<()> {
40 transport.open().await
41 }
42}
43
44pub struct ServerProtocolBuilder {
45 protocol_builder: ProtocolBuilder,
46 server_info: Implementation,
47 capabilities: ServerCapabilities,
48 tools: HashMap<String, ToolHandler>,
49 client_connection: Arc<RwLock<ClientConnection>>,
50}
51
52impl ServerProtocolBuilder {
53 pub fn new(name: String, version: String) -> Self {
54 ServerProtocolBuilder {
55 protocol_builder: ProtocolBuilder::new(),
56 server_info: Implementation { name, version },
57 capabilities: ServerCapabilities::default(),
58 tools: HashMap::new(),
59 client_connection: Arc::new(RwLock::new(ClientConnection {
60 client_capabilities: None,
61 client_info: None,
62 initialized: false,
63 })),
64 }
65 }
66
67 pub fn capabilities(mut self, capabilities: ServerCapabilities) -> Self {
68 self.capabilities = capabilities;
69 self
70 }
71
72 pub fn register_tool(
73 mut self,
74 tool: Tool,
75 f: impl Fn(CallToolRequest) -> Pin<Box<dyn Future<Output = CallToolResponse> + Send>>
76 + Send
77 + Sync
78 + 'static,
79 ) -> Self {
80 self.tools.insert(
81 tool.name.clone(),
82 ToolHandler {
83 tool,
84 f: Box::new(f),
85 },
86 );
87 self
88 }
89
90 fn handle_init(
92 state: Arc<RwLock<ClientConnection>>,
93 server_info: Implementation,
94 capabilities: ServerCapabilities,
95 ) -> impl Fn(
96 InitializeRequest,
97 )
98 -> Pin<Box<dyn std::future::Future<Output = Result<InitializeResponse>> + Send>> {
99 move |req| {
100 let state = state.clone();
101 let server_info = server_info.clone();
102 let capabilities = capabilities.clone();
103
104 Box::pin(async move {
105 let mut state = state
106 .write()
107 .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
108 state.client_capabilities = Some(req.capabilities);
109 state.client_info = Some(req.client_info);
110
111 Ok(InitializeResponse {
112 protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
113 capabilities,
114 server_info,
115 })
116 })
117 }
118 }
119
120 fn handle_initialized(
122 state: Arc<RwLock<ClientConnection>>,
123 ) -> impl Fn(()) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>> {
124 move |_| {
125 let state = state.clone();
126 Box::pin(async move {
127 let mut state = state
128 .write()
129 .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
130 state.initialized = true;
131 Ok(())
132 })
133 }
134 }
135
136 pub fn get_client_capabilities(&self) -> Option<ClientCapabilities> {
137 self.client_connection
138 .read()
139 .ok()?
140 .client_capabilities
141 .clone()
142 }
143
144 pub fn get_client_info(&self) -> Option<Implementation> {
145 self.client_connection.read().ok()?.client_info.clone()
146 }
147
148 pub fn is_initialized(&self) -> bool {
149 self.client_connection
150 .read()
151 .ok()
152 .map(|client_connection| client_connection.initialized)
153 .unwrap_or(false)
154 }
155
156 pub fn build(self) -> Protocol {
157 let tools = Arc::new(Tools::new(self.tools));
158 let tools_clone = tools.clone();
159 let tools_list = tools.clone();
160 let tools_call = tools_clone.clone();
161
162 let conn_for_list = self.client_connection.clone();
163 let conn_for_call = self.client_connection.clone();
164
165 self.protocol_builder
166 .request_handler(
167 "initialize",
168 Self::handle_init(
169 self.client_connection.clone(),
170 self.server_info,
171 self.capabilities,
172 ),
173 )
174 .notification_handler(
175 "notifications/initialized",
176 Self::handle_initialized(self.client_connection.clone()),
177 )
178 .request_handler("tools/list", move |_req: ListRequest| {
179 let tools = tools_list.clone();
180 let conn = conn_for_list.clone();
181
182 Box::pin(async move {
183 let client_state = conn.read().map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
184
185 if !client_state.initialized {
186 return Err(anyhow::anyhow!(
187 "Client must be initialized before using tools/list"
188 ));
189 }
190
191 Ok(ToolsListResponse {
192 tools: tools.list_tools(),
193 next_cursor: None,
194 meta: None,
195 })
196 })
197 })
198 .request_handler("tools/call", move |req: CallToolRequest| {
199 let tools = tools_call.clone();
200 let conn = conn_for_call.clone();
201
202 Box::pin(async move {
203 {
204 let client_state =
206 conn.read().map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
207
208 if !client_state.initialized {
209 return Err(anyhow::anyhow!(
210 "Client must be initialized before using tools/call"
211 ));
212 }
213 }
214
215 tools.call_tool(req).await
216 })
217 })
218 .build()
219 }
220}