1use crate::{backend::McpBackend, context::RequestContext, middleware::MiddlewareStack};
4use pulseengine_mcp_auth::AuthenticationManager;
5use pulseengine_mcp_protocol::*;
6
7use std::sync::Arc;
8use thiserror::Error;
9use tracing::{debug, error, instrument};
10
11#[derive(Debug, Error)]
13pub enum HandlerError {
14 #[error("Authentication failed: {0}")]
15 Authentication(String),
16
17 #[error("Authorization failed: {0}")]
18 Authorization(String),
19
20 #[error("Backend error: {0}")]
21 Backend(String),
22
23 #[error("Protocol error: {0}")]
24 Protocol(#[from] Error),
25}
26
27#[derive(Clone)]
29pub struct GenericServerHandler<B: McpBackend> {
30 backend: Arc<B>,
31 #[allow(dead_code)]
32 auth_manager: Arc<AuthenticationManager>,
33 middleware: MiddlewareStack,
34}
35
36impl<B: McpBackend> GenericServerHandler<B> {
37 pub fn new(
39 backend: Arc<B>,
40 auth_manager: Arc<AuthenticationManager>,
41 middleware: MiddlewareStack,
42 ) -> Self {
43 Self {
44 backend,
45 auth_manager,
46 middleware,
47 }
48 }
49
50 #[instrument(skip(self, request))]
52 pub async fn handle_request(
53 &self,
54 request: Request,
55 ) -> std::result::Result<Response, HandlerError> {
56 debug!("Handling request: {}", request.method);
57
58 let request_id = request.id.clone();
60
61 let context = RequestContext::new();
63
64 let request = self.middleware.process_request(request, &context).await?;
66
67 let result = match request.method.as_str() {
69 "initialize" => self.handle_initialize(request).await,
70 "tools/list" => self.handle_list_tools(request).await,
71 "tools/call" => self.handle_call_tool(request).await,
72 "resources/list" => self.handle_list_resources(request).await,
73 "resources/read" => self.handle_read_resource(request).await,
74 "resources/templates/list" => self.handle_list_resource_templates(request).await,
75 "prompts/list" => self.handle_list_prompts(request).await,
76 "prompts/get" => self.handle_get_prompt(request).await,
77 "resources/subscribe" => self.handle_subscribe(request).await,
78 "resources/unsubscribe" => self.handle_unsubscribe(request).await,
79 "completion/complete" => self.handle_complete(request).await,
80 "logging/setLevel" => self.handle_set_level(request).await,
81 "ping" => self.handle_ping(request).await,
82 _ => self.handle_custom_method(request).await,
83 };
84
85 match result {
86 Ok(response) => {
87 let response = self.middleware.process_response(response, &context).await?;
89 Ok(response)
90 }
91 Err(error) => {
92 error!("Request failed: {}", error);
93 Ok(Response {
94 jsonrpc: "2.0".to_string(),
95 id: request_id,
96 result: None,
97 error: Some(error),
98 })
99 }
100 }
101 }
102
103 async fn handle_initialize(&self, request: Request) -> std::result::Result<Response, Error> {
104 let _params: InitializeRequestParam = serde_json::from_value(request.params)?;
105
106 let result = InitializeResult {
107 protocol_version: pulseengine_mcp_protocol::MCP_VERSION.to_string(),
108 capabilities: self.backend.get_server_info().capabilities,
109 server_info: self.backend.get_server_info().server_info.clone(),
110 instructions: Some(String::new()), };
112
113 Ok(Response {
114 jsonrpc: "2.0".to_string(),
115 id: request.id,
116 result: Some(serde_json::to_value(result)?),
117 error: None,
118 })
119 }
120
121 async fn handle_list_tools(&self, request: Request) -> std::result::Result<Response, Error> {
122 let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
123
124 let result = self
125 .backend
126 .list_tools(params)
127 .await
128 .map_err(|e| e.into())?;
129
130 Ok(Response {
131 jsonrpc: "2.0".to_string(),
132 id: request.id,
133 result: Some(serde_json::to_value(result)?),
134 error: None,
135 })
136 }
137
138 async fn handle_call_tool(&self, request: Request) -> std::result::Result<Response, Error> {
139 let params: CallToolRequestParam = serde_json::from_value(request.params)?;
140
141 let result = self.backend.call_tool(params).await.map_err(|e| e.into())?;
142
143 Ok(Response {
144 jsonrpc: "2.0".to_string(),
145 id: request.id,
146 result: Some(serde_json::to_value(result)?),
147 error: None,
148 })
149 }
150
151 async fn handle_list_resources(
152 &self,
153 request: Request,
154 ) -> std::result::Result<Response, Error> {
155 let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
156
157 let result = self
158 .backend
159 .list_resources(params)
160 .await
161 .map_err(|e| e.into())?;
162
163 Ok(Response {
164 jsonrpc: "2.0".to_string(),
165 id: request.id,
166 result: Some(serde_json::to_value(result)?),
167 error: None,
168 })
169 }
170
171 async fn handle_read_resource(&self, request: Request) -> std::result::Result<Response, Error> {
172 let params: ReadResourceRequestParam = serde_json::from_value(request.params)?;
173
174 let result = self
175 .backend
176 .read_resource(params)
177 .await
178 .map_err(|e| e.into())?;
179
180 Ok(Response {
181 jsonrpc: "2.0".to_string(),
182 id: request.id,
183 result: Some(serde_json::to_value(result)?),
184 error: None,
185 })
186 }
187
188 async fn handle_list_resource_templates(
189 &self,
190 request: Request,
191 ) -> std::result::Result<Response, Error> {
192 let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
193
194 let result = self
195 .backend
196 .list_resource_templates(params)
197 .await
198 .map_err(|e| e.into())?;
199
200 Ok(Response {
201 jsonrpc: "2.0".to_string(),
202 id: request.id,
203 result: Some(serde_json::to_value(result)?),
204 error: None,
205 })
206 }
207
208 async fn handle_list_prompts(&self, request: Request) -> std::result::Result<Response, Error> {
209 let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
210
211 let result = self
212 .backend
213 .list_prompts(params)
214 .await
215 .map_err(|e| e.into())?;
216
217 Ok(Response {
218 jsonrpc: "2.0".to_string(),
219 id: request.id,
220 result: Some(serde_json::to_value(result)?),
221 error: None,
222 })
223 }
224
225 async fn handle_get_prompt(&self, request: Request) -> std::result::Result<Response, Error> {
226 let params: GetPromptRequestParam = serde_json::from_value(request.params)?;
227
228 let result = self
229 .backend
230 .get_prompt(params)
231 .await
232 .map_err(|e| e.into())?;
233
234 Ok(Response {
235 jsonrpc: "2.0".to_string(),
236 id: request.id,
237 result: Some(serde_json::to_value(result)?),
238 error: None,
239 })
240 }
241
242 async fn handle_subscribe(&self, request: Request) -> std::result::Result<Response, Error> {
243 let params: SubscribeRequestParam = serde_json::from_value(request.params)?;
244
245 self.backend.subscribe(params).await.map_err(|e| e.into())?;
246
247 Ok(Response {
248 jsonrpc: "2.0".to_string(),
249 id: request.id,
250 result: Some(serde_json::Value::Object(Default::default())),
251 error: None,
252 })
253 }
254
255 async fn handle_unsubscribe(&self, request: Request) -> std::result::Result<Response, Error> {
256 let params: UnsubscribeRequestParam = serde_json::from_value(request.params)?;
257
258 self.backend
259 .unsubscribe(params)
260 .await
261 .map_err(|e| e.into())?;
262
263 Ok(Response {
264 jsonrpc: "2.0".to_string(),
265 id: request.id,
266 result: Some(serde_json::Value::Object(Default::default())),
267 error: None,
268 })
269 }
270
271 async fn handle_complete(&self, request: Request) -> std::result::Result<Response, Error> {
272 let params: CompleteRequestParam = serde_json::from_value(request.params)?;
273
274 let result = self.backend.complete(params).await.map_err(|e| e.into())?;
275
276 Ok(Response {
277 jsonrpc: "2.0".to_string(),
278 id: request.id,
279 result: Some(serde_json::to_value(result)?),
280 error: None,
281 })
282 }
283
284 async fn handle_set_level(&self, request: Request) -> std::result::Result<Response, Error> {
285 let params: SetLevelRequestParam = serde_json::from_value(request.params)?;
286
287 self.backend.set_level(params).await.map_err(|e| e.into())?;
288
289 Ok(Response {
290 jsonrpc: "2.0".to_string(),
291 id: request.id,
292 result: Some(serde_json::Value::Object(Default::default())),
293 error: None,
294 })
295 }
296
297 async fn handle_ping(&self, _request: Request) -> std::result::Result<Response, Error> {
298 Ok(Response {
299 jsonrpc: "2.0".to_string(),
300 id: _request.id,
301 result: Some(serde_json::Value::Object(Default::default())),
302 error: None,
303 })
304 }
305
306 async fn handle_custom_method(&self, request: Request) -> std::result::Result<Response, Error> {
307 let result = self
308 .backend
309 .handle_custom_method(&request.method, request.params)
310 .await
311 .map_err(|e| e.into())?;
312
313 Ok(Response {
314 jsonrpc: "2.0".to_string(),
315 id: request.id,
316 result: Some(result),
317 error: None,
318 })
319 }
320}
321
322impl From<HandlerError> for Error {
324 fn from(err: HandlerError) -> Self {
325 match err {
326 HandlerError::Authentication(msg) => Error::unauthorized(msg),
327 HandlerError::Authorization(msg) => Error::forbidden(msg),
328 HandlerError::Backend(msg) => Error::internal_error(msg),
329 HandlerError::Protocol(e) => e,
330 }
331 }
332}