1use std::{
2 collections::HashMap,
3 sync::{Arc, RwLock},
4};
5
6use crate::{
7 protocol::Protocol,
8 tools::{ToolHandler, ToolHandlerFn, Tools},
9 types::{CallToolRequest, ListRequest, Tool, ToolsListResponse},
10};
11
12use super::{
13 protocol::ProtocolBuilder,
14 transport::Transport,
15 types::{
16 ClientCapabilities, Implementation, InitializeRequest, InitializeResponse,
17 ServerCapabilities, LATEST_PROTOCOL_VERSION,
18 },
19};
20use anyhow::Result;
21use std::pin::Pin;
22
23#[derive(Clone)]
24pub struct ClientConnection {
25 client_capabilities: Option<ClientCapabilities>,
26 client_info: Option<Implementation>,
27 initialized: bool,
28}
29
/// Stateless entry point for constructing and running a server.
///
/// `Server` holds no data; it only namespaces the associated functions `builder` and
/// `start`. It is zero-sized, so it is also `Copy` and has a trivial `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Server;
32
33impl Server {
34 pub fn builder(name: String, version: String) -> ServerProtocolBuilder {
35 ServerProtocolBuilder::new(name, version)
36 }
37
38 pub async fn start<T: Transport>(transport: T) -> Result<()> {
39 transport.open().await
40 }
41}
42
43pub struct ServerProtocolBuilder {
44 protocol_builder: ProtocolBuilder,
45 server_info: Implementation,
46 capabilities: ServerCapabilities,
47 tools: HashMap<String, ToolHandler>,
48 client_connection: Arc<RwLock<ClientConnection>>,
49}
50
51impl ServerProtocolBuilder {
52 pub fn new(name: String, version: String) -> Self {
53 ServerProtocolBuilder {
54 protocol_builder: ProtocolBuilder::new(),
55 server_info: Implementation { name, version },
56 capabilities: ServerCapabilities::default(),
57 tools: HashMap::new(),
58 client_connection: Arc::new(RwLock::new(ClientConnection {
59 client_capabilities: None,
60 client_info: None,
61 initialized: false,
62 })),
63 }
64 }
65
66 pub fn capabilities(mut self, capabilities: ServerCapabilities) -> Self {
67 self.capabilities = capabilities;
68 self
69 }
70
71 pub fn register_tool(mut self, tool: Tool, f: ToolHandlerFn) -> Self {
72 self.tools.insert(
73 tool.name.clone(),
74 ToolHandler {
75 tool,
76 f: Box::new(f),
77 },
78 );
79 self
80 }
81
82 fn handle_init(
84 state: Arc<RwLock<ClientConnection>>,
85 server_info: Implementation,
86 capabilities: ServerCapabilities,
87 ) -> impl Fn(
88 InitializeRequest,
89 )
90 -> Pin<Box<dyn std::future::Future<Output = Result<InitializeResponse>> + Send>> {
91 move |req| {
92 let state = state.clone();
93 let server_info = server_info.clone();
94 let capabilities = capabilities.clone();
95
96 Box::pin(async move {
97 let mut state = state
98 .write()
99 .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
100 state.client_capabilities = Some(req.capabilities);
101 state.client_info = Some(req.client_info);
102
103 Ok(InitializeResponse {
104 protocol_version: LATEST_PROTOCOL_VERSION.to_string(),
105 capabilities,
106 server_info,
107 })
108 })
109 }
110 }
111
112 fn handle_initialized(
114 state: Arc<RwLock<ClientConnection>>,
115 ) -> impl Fn(()) -> Pin<Box<dyn std::future::Future<Output = Result<()>> + Send>> {
116 move |_| {
117 let state = state.clone();
118 Box::pin(async move {
119 let mut state = state
120 .write()
121 .map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
122 state.initialized = true;
123 Ok(())
124 })
125 }
126 }
127
128 pub fn get_client_capabilities(&self) -> Option<ClientCapabilities> {
129 self.client_connection
130 .read()
131 .ok()?
132 .client_capabilities
133 .clone()
134 }
135
136 pub fn get_client_info(&self) -> Option<Implementation> {
137 self.client_connection.read().ok()?.client_info.clone()
138 }
139
140 pub fn is_initialized(&self) -> bool {
141 self.client_connection
142 .read()
143 .ok()
144 .map(|client_connection| client_connection.initialized)
145 .unwrap_or(false)
146 }
147
148 pub fn build(self) -> Protocol {
149 let tools = Arc::new(Tools::new(self.tools));
150 let tools_clone = tools.clone();
151 let tools_list = tools.clone();
152 let tools_call = tools_clone.clone();
153
154 let conn_for_list = self.client_connection.clone();
155 let conn_for_call = self.client_connection.clone();
156
157 self.protocol_builder
158 .request_handler(
159 "initialize",
160 Self::handle_init(
161 self.client_connection.clone(),
162 self.server_info,
163 self.capabilities,
164 ),
165 )
166 .notification_handler(
167 "notifications/initialized",
168 Self::handle_initialized(self.client_connection.clone()),
169 )
170 .request_handler("tools/list", move |_req: ListRequest| {
171 let tools = tools_list.clone();
172 let conn = conn_for_list.clone();
173
174 Box::pin(async move {
175 let client_state = conn.read().map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
176
177 if !client_state.initialized {
178 return Err(anyhow::anyhow!(
179 "Client must be initialized before using tools/list"
180 ));
181 }
182
183 Ok(ToolsListResponse {
184 tools: tools.list_tools(),
185 next_cursor: None,
186 meta: None,
187 })
188 })
189 })
190 .request_handler("tools/call", move |req: CallToolRequest| {
191 let tools = tools_call.clone();
192 let conn = conn_for_call.clone();
193
194 Box::pin(async move {
195 {
196 let client_state =
198 conn.read().map_err(|_| anyhow::anyhow!("Lock poisoned"))?;
199
200 if !client_state.initialized {
201 return Err(anyhow::anyhow!(
202 "Client must be initialized before using tools/call"
203 ));
204 }
205 }
206
207 tools.call_tool(req).await
208 })
209 })
210 .build()
211 }
212}