1use crate::{backend::McpBackend, context::RequestContext, middleware::MiddlewareStack};
4use pulseengine_mcp_auth::AuthenticationManager;
5use pulseengine_mcp_protocol::*;
6
7use std::sync::Arc;
8use thiserror::Error;
9use tracing::{debug, error, instrument};
10
/// Errors that can escape [`GenericServerHandler::handle_request`].
///
/// Each variant can be turned into a protocol-level `Error` through the
/// `From<HandlerError> for Error` impl in this file.
#[derive(Debug, Error)]
pub enum HandlerError {
    /// The caller could not be authenticated; the payload is a human-readable reason.
    #[error("Authentication failed: {0}")]
    Authentication(String),

    /// The caller was authenticated but is not permitted to perform the operation.
    #[error("Authorization failed: {0}")]
    Authorization(String),

    /// The backend reported a failure, described by the payload message.
    #[error("Backend error: {0}")]
    Backend(String),

    /// A protocol `Error`; `#[from]` lets `?` convert such errors automatically
    /// (e.g. those returned by the middleware stack).
    #[error("Protocol error: {0}")]
    Protocol(#[from] Error),
}
26
27#[derive(Clone)]
29pub struct GenericServerHandler<B: McpBackend> {
30 backend: Arc<B>,
31 #[allow(dead_code)]
32 auth_manager: Arc<AuthenticationManager>,
33 middleware: MiddlewareStack,
34}
35
36impl<B: McpBackend> GenericServerHandler<B> {
37 pub fn new(
39 backend: Arc<B>,
40 auth_manager: Arc<AuthenticationManager>,
41 middleware: MiddlewareStack,
42 ) -> Self {
43 Self {
44 backend,
45 auth_manager,
46 middleware,
47 }
48 }
49
50 #[instrument(skip(self, request))]
52 pub async fn handle_request(
53 &self,
54 request: Request,
55 ) -> std::result::Result<Response, HandlerError> {
56 debug!("Handling request: {}", request.method);
57
58 let request_id = request.id.clone();
60
61 let context = RequestContext::new();
63
64 let request = self.middleware.process_request(request, &context).await?;
66
67 let result = match request.method.as_str() {
69 "initialize" => self.handle_initialize(request).await,
70 "tools/list" => self.handle_list_tools(request).await,
71 "tools/call" => self.handle_call_tool(request).await,
72 "resources/list" => self.handle_list_resources(request).await,
73 "resources/read" => self.handle_read_resource(request).await,
74 "resources/templates/list" => self.handle_list_resource_templates(request).await,
75 "prompts/list" => self.handle_list_prompts(request).await,
76 "prompts/get" => self.handle_get_prompt(request).await,
77 "resources/subscribe" => self.handle_subscribe(request).await,
78 "resources/unsubscribe" => self.handle_unsubscribe(request).await,
79 "completion/complete" => self.handle_complete(request).await,
80 "logging/setLevel" => self.handle_set_level(request).await,
81 "ping" => self.handle_ping(request).await,
82 _ => self.handle_custom_method(request).await,
83 };
84
85 match result {
86 Ok(response) => {
87 let response = self.middleware.process_response(response, &context).await?;
89 Ok(response)
90 }
91 Err(error) => {
92 error!("Request failed: {}", error);
93 Ok(Response {
94 jsonrpc: "2.0".to_string(),
95 id: request_id,
96 result: None,
97 error: Some(error),
98 })
99 }
100 }
101 }
102
103 async fn handle_initialize(&self, request: Request) -> std::result::Result<Response, Error> {
104 let _params: InitializeRequestParam = serde_json::from_value(request.params)?;
105
106 let server_info = self.backend.get_server_info();
107 let result = InitializeResult {
108 protocol_version: pulseengine_mcp_protocol::MCP_VERSION.to_string(),
109 capabilities: server_info.capabilities,
110 server_info: server_info.server_info.clone(),
111 instructions: server_info.instructions,
112 };
113
114 Ok(Response {
115 jsonrpc: "2.0".to_string(),
116 id: request.id,
117 result: Some(serde_json::to_value(result)?),
118 error: None,
119 })
120 }
121
122 async fn handle_list_tools(&self, request: Request) -> std::result::Result<Response, Error> {
123 let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
124
125 let result = self
126 .backend
127 .list_tools(params)
128 .await
129 .map_err(|e| e.into())?;
130
131 Ok(Response {
132 jsonrpc: "2.0".to_string(),
133 id: request.id,
134 result: Some(serde_json::to_value(result)?),
135 error: None,
136 })
137 }
138
139 async fn handle_call_tool(&self, request: Request) -> std::result::Result<Response, Error> {
140 let params: CallToolRequestParam = serde_json::from_value(request.params)?;
141
142 let result = self.backend.call_tool(params).await.map_err(|e| e.into())?;
143
144 Ok(Response {
145 jsonrpc: "2.0".to_string(),
146 id: request.id,
147 result: Some(serde_json::to_value(result)?),
148 error: None,
149 })
150 }
151
152 async fn handle_list_resources(
153 &self,
154 request: Request,
155 ) -> std::result::Result<Response, Error> {
156 let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
157
158 let result = self
159 .backend
160 .list_resources(params)
161 .await
162 .map_err(|e| e.into())?;
163
164 Ok(Response {
165 jsonrpc: "2.0".to_string(),
166 id: request.id,
167 result: Some(serde_json::to_value(result)?),
168 error: None,
169 })
170 }
171
172 async fn handle_read_resource(&self, request: Request) -> std::result::Result<Response, Error> {
173 let params: ReadResourceRequestParam = serde_json::from_value(request.params)?;
174
175 let result = self
176 .backend
177 .read_resource(params)
178 .await
179 .map_err(|e| e.into())?;
180
181 Ok(Response {
182 jsonrpc: "2.0".to_string(),
183 id: request.id,
184 result: Some(serde_json::to_value(result)?),
185 error: None,
186 })
187 }
188
189 async fn handle_list_resource_templates(
190 &self,
191 request: Request,
192 ) -> std::result::Result<Response, Error> {
193 let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
194
195 let result = self
196 .backend
197 .list_resource_templates(params)
198 .await
199 .map_err(|e| e.into())?;
200
201 Ok(Response {
202 jsonrpc: "2.0".to_string(),
203 id: request.id,
204 result: Some(serde_json::to_value(result)?),
205 error: None,
206 })
207 }
208
209 async fn handle_list_prompts(&self, request: Request) -> std::result::Result<Response, Error> {
210 let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
211
212 let result = self
213 .backend
214 .list_prompts(params)
215 .await
216 .map_err(|e| e.into())?;
217
218 Ok(Response {
219 jsonrpc: "2.0".to_string(),
220 id: request.id,
221 result: Some(serde_json::to_value(result)?),
222 error: None,
223 })
224 }
225
226 async fn handle_get_prompt(&self, request: Request) -> std::result::Result<Response, Error> {
227 let params: GetPromptRequestParam = serde_json::from_value(request.params)?;
228
229 let result = self
230 .backend
231 .get_prompt(params)
232 .await
233 .map_err(|e| e.into())?;
234
235 Ok(Response {
236 jsonrpc: "2.0".to_string(),
237 id: request.id,
238 result: Some(serde_json::to_value(result)?),
239 error: None,
240 })
241 }
242
243 async fn handle_subscribe(&self, request: Request) -> std::result::Result<Response, Error> {
244 let params: SubscribeRequestParam = serde_json::from_value(request.params)?;
245
246 self.backend.subscribe(params).await.map_err(|e| e.into())?;
247
248 Ok(Response {
249 jsonrpc: "2.0".to_string(),
250 id: request.id,
251 result: Some(serde_json::Value::Object(Default::default())),
252 error: None,
253 })
254 }
255
256 async fn handle_unsubscribe(&self, request: Request) -> std::result::Result<Response, Error> {
257 let params: UnsubscribeRequestParam = serde_json::from_value(request.params)?;
258
259 self.backend
260 .unsubscribe(params)
261 .await
262 .map_err(|e| e.into())?;
263
264 Ok(Response {
265 jsonrpc: "2.0".to_string(),
266 id: request.id,
267 result: Some(serde_json::Value::Object(Default::default())),
268 error: None,
269 })
270 }
271
272 async fn handle_complete(&self, request: Request) -> std::result::Result<Response, Error> {
273 let params: CompleteRequestParam = serde_json::from_value(request.params)?;
274
275 let result = self.backend.complete(params).await.map_err(|e| e.into())?;
276
277 Ok(Response {
278 jsonrpc: "2.0".to_string(),
279 id: request.id,
280 result: Some(serde_json::to_value(result)?),
281 error: None,
282 })
283 }
284
285 async fn handle_set_level(&self, request: Request) -> std::result::Result<Response, Error> {
286 let params: SetLevelRequestParam = serde_json::from_value(request.params)?;
287
288 self.backend.set_level(params).await.map_err(|e| e.into())?;
289
290 Ok(Response {
291 jsonrpc: "2.0".to_string(),
292 id: request.id,
293 result: Some(serde_json::Value::Object(Default::default())),
294 error: None,
295 })
296 }
297
298 async fn handle_ping(&self, _request: Request) -> std::result::Result<Response, Error> {
299 Ok(Response {
300 jsonrpc: "2.0".to_string(),
301 id: _request.id,
302 result: Some(serde_json::Value::Object(Default::default())),
303 error: None,
304 })
305 }
306
307 async fn handle_custom_method(&self, request: Request) -> std::result::Result<Response, Error> {
308 let result = self
309 .backend
310 .handle_custom_method(&request.method, request.params)
311 .await
312 .map_err(|e| e.into())?;
313
314 Ok(Response {
315 jsonrpc: "2.0".to_string(),
316 id: request.id,
317 result: Some(result),
318 error: None,
319 })
320 }
321}
322
323impl From<HandlerError> for Error {
325 fn from(err: HandlerError) -> Self {
326 match err {
327 HandlerError::Authentication(msg) => Error::unauthorized(msg),
328 HandlerError::Authorization(msg) => Error::forbidden(msg),
329 HandlerError::Backend(msg) => Error::internal_error(msg),
330 HandlerError::Protocol(e) => e,
331 }
332 }
333}