pulseengine_mcp_server/
handler.rs

1//! Generic request handler for MCP protocol
2
3use crate::{backend::McpBackend, context::RequestContext, middleware::MiddlewareStack};
4use pulseengine_mcp_auth::AuthenticationManager;
5use pulseengine_mcp_protocol::*;
6
7use std::sync::Arc;
8use thiserror::Error;
9use tracing::{debug, error, instrument};
10
11/// Error type for handler operations
12#[derive(Debug, Error)]
13pub enum HandlerError {
14    #[error("Authentication failed: {0}")]
15    Authentication(String),
16
17    #[error("Authorization failed: {0}")]
18    Authorization(String),
19
20    #[error("Backend error: {0}")]
21    Backend(String),
22
23    #[error("Protocol error: {0}")]
24    Protocol(#[from] Error),
25}
26
27/// Generic server handler that implements the MCP protocol
28#[derive(Clone)]
29pub struct GenericServerHandler<B: McpBackend> {
30    backend: Arc<B>,
31    #[allow(dead_code)]
32    auth_manager: Arc<AuthenticationManager>,
33    middleware: MiddlewareStack,
34}
35
36impl<B: McpBackend> GenericServerHandler<B> {
37    /// Create a new handler
38    pub fn new(
39        backend: Arc<B>,
40        auth_manager: Arc<AuthenticationManager>,
41        middleware: MiddlewareStack,
42    ) -> Self {
43        Self {
44            backend,
45            auth_manager,
46            middleware,
47        }
48    }
49
50    /// Handle an MCP request
51    #[instrument(skip(self, request))]
52    pub async fn handle_request(
53        &self,
54        request: Request,
55    ) -> std::result::Result<Response, HandlerError> {
56        debug!("Handling request: {}", request.method);
57
58        // Store request ID before moving request
59        let request_id = request.id.clone();
60
61        // Create request context
62        let context = RequestContext::new();
63
64        // Apply middleware
65        let request = self.middleware.process_request(request, &context).await?;
66
67        // Route to appropriate handler
68        let result = match request.method.as_str() {
69            "initialize" => self.handle_initialize(request).await,
70            "tools/list" => self.handle_list_tools(request).await,
71            "tools/call" => self.handle_call_tool(request).await,
72            "resources/list" => self.handle_list_resources(request).await,
73            "resources/read" => self.handle_read_resource(request).await,
74            "resources/templates/list" => self.handle_list_resource_templates(request).await,
75            "prompts/list" => self.handle_list_prompts(request).await,
76            "prompts/get" => self.handle_get_prompt(request).await,
77            "resources/subscribe" => self.handle_subscribe(request).await,
78            "resources/unsubscribe" => self.handle_unsubscribe(request).await,
79            "completion/complete" => self.handle_complete(request).await,
80            "logging/setLevel" => self.handle_set_level(request).await,
81            "ping" => self.handle_ping(request).await,
82            _ => self.handle_custom_method(request).await,
83        };
84
85        match result {
86            Ok(response) => {
87                // Apply response middleware
88                let response = self.middleware.process_response(response, &context).await?;
89                Ok(response)
90            }
91            Err(error) => {
92                error!("Request failed: {}", error);
93                Ok(Response {
94                    jsonrpc: "2.0".to_string(),
95                    id: request_id,
96                    result: None,
97                    error: Some(error),
98                })
99            }
100        }
101    }
102
103    async fn handle_initialize(&self, request: Request) -> std::result::Result<Response, Error> {
104        let _params: InitializeRequestParam = serde_json::from_value(request.params)?;
105
106        let server_info = self.backend.get_server_info();
107        let result = InitializeResult {
108            protocol_version: pulseengine_mcp_protocol::MCP_VERSION.to_string(),
109            capabilities: server_info.capabilities,
110            server_info: server_info.server_info.clone(),
111            instructions: server_info.instructions,
112        };
113
114        Ok(Response {
115            jsonrpc: "2.0".to_string(),
116            id: request.id,
117            result: Some(serde_json::to_value(result)?),
118            error: None,
119        })
120    }
121
122    async fn handle_list_tools(&self, request: Request) -> std::result::Result<Response, Error> {
123        let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
124
125        let result = self
126            .backend
127            .list_tools(params)
128            .await
129            .map_err(|e| e.into())?;
130
131        Ok(Response {
132            jsonrpc: "2.0".to_string(),
133            id: request.id,
134            result: Some(serde_json::to_value(result)?),
135            error: None,
136        })
137    }
138
139    async fn handle_call_tool(&self, request: Request) -> std::result::Result<Response, Error> {
140        let params: CallToolRequestParam = serde_json::from_value(request.params)?;
141
142        let result = self.backend.call_tool(params).await.map_err(|e| e.into())?;
143
144        Ok(Response {
145            jsonrpc: "2.0".to_string(),
146            id: request.id,
147            result: Some(serde_json::to_value(result)?),
148            error: None,
149        })
150    }
151
152    async fn handle_list_resources(
153        &self,
154        request: Request,
155    ) -> std::result::Result<Response, Error> {
156        let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
157
158        let result = self
159            .backend
160            .list_resources(params)
161            .await
162            .map_err(|e| e.into())?;
163
164        Ok(Response {
165            jsonrpc: "2.0".to_string(),
166            id: request.id,
167            result: Some(serde_json::to_value(result)?),
168            error: None,
169        })
170    }
171
172    async fn handle_read_resource(&self, request: Request) -> std::result::Result<Response, Error> {
173        let params: ReadResourceRequestParam = serde_json::from_value(request.params)?;
174
175        let result = self
176            .backend
177            .read_resource(params)
178            .await
179            .map_err(|e| e.into())?;
180
181        Ok(Response {
182            jsonrpc: "2.0".to_string(),
183            id: request.id,
184            result: Some(serde_json::to_value(result)?),
185            error: None,
186        })
187    }
188
189    async fn handle_list_resource_templates(
190        &self,
191        request: Request,
192    ) -> std::result::Result<Response, Error> {
193        let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
194
195        let result = self
196            .backend
197            .list_resource_templates(params)
198            .await
199            .map_err(|e| e.into())?;
200
201        Ok(Response {
202            jsonrpc: "2.0".to_string(),
203            id: request.id,
204            result: Some(serde_json::to_value(result)?),
205            error: None,
206        })
207    }
208
209    async fn handle_list_prompts(&self, request: Request) -> std::result::Result<Response, Error> {
210        let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
211
212        let result = self
213            .backend
214            .list_prompts(params)
215            .await
216            .map_err(|e| e.into())?;
217
218        Ok(Response {
219            jsonrpc: "2.0".to_string(),
220            id: request.id,
221            result: Some(serde_json::to_value(result)?),
222            error: None,
223        })
224    }
225
226    async fn handle_get_prompt(&self, request: Request) -> std::result::Result<Response, Error> {
227        let params: GetPromptRequestParam = serde_json::from_value(request.params)?;
228
229        let result = self
230            .backend
231            .get_prompt(params)
232            .await
233            .map_err(|e| e.into())?;
234
235        Ok(Response {
236            jsonrpc: "2.0".to_string(),
237            id: request.id,
238            result: Some(serde_json::to_value(result)?),
239            error: None,
240        })
241    }
242
243    async fn handle_subscribe(&self, request: Request) -> std::result::Result<Response, Error> {
244        let params: SubscribeRequestParam = serde_json::from_value(request.params)?;
245
246        self.backend.subscribe(params).await.map_err(|e| e.into())?;
247
248        Ok(Response {
249            jsonrpc: "2.0".to_string(),
250            id: request.id,
251            result: Some(serde_json::Value::Object(Default::default())),
252            error: None,
253        })
254    }
255
256    async fn handle_unsubscribe(&self, request: Request) -> std::result::Result<Response, Error> {
257        let params: UnsubscribeRequestParam = serde_json::from_value(request.params)?;
258
259        self.backend
260            .unsubscribe(params)
261            .await
262            .map_err(|e| e.into())?;
263
264        Ok(Response {
265            jsonrpc: "2.0".to_string(),
266            id: request.id,
267            result: Some(serde_json::Value::Object(Default::default())),
268            error: None,
269        })
270    }
271
272    async fn handle_complete(&self, request: Request) -> std::result::Result<Response, Error> {
273        let params: CompleteRequestParam = serde_json::from_value(request.params)?;
274
275        let result = self.backend.complete(params).await.map_err(|e| e.into())?;
276
277        Ok(Response {
278            jsonrpc: "2.0".to_string(),
279            id: request.id,
280            result: Some(serde_json::to_value(result)?),
281            error: None,
282        })
283    }
284
285    async fn handle_set_level(&self, request: Request) -> std::result::Result<Response, Error> {
286        let params: SetLevelRequestParam = serde_json::from_value(request.params)?;
287
288        self.backend.set_level(params).await.map_err(|e| e.into())?;
289
290        Ok(Response {
291            jsonrpc: "2.0".to_string(),
292            id: request.id,
293            result: Some(serde_json::Value::Object(Default::default())),
294            error: None,
295        })
296    }
297
298    async fn handle_ping(&self, _request: Request) -> std::result::Result<Response, Error> {
299        Ok(Response {
300            jsonrpc: "2.0".to_string(),
301            id: _request.id,
302            result: Some(serde_json::Value::Object(Default::default())),
303            error: None,
304        })
305    }
306
307    async fn handle_custom_method(&self, request: Request) -> std::result::Result<Response, Error> {
308        let result = self
309            .backend
310            .handle_custom_method(&request.method, request.params)
311            .await
312            .map_err(|e| e.into())?;
313
314        Ok(Response {
315            jsonrpc: "2.0".to_string(),
316            id: request.id,
317            result: Some(result),
318            error: None,
319        })
320    }
321}
322
323// Convert HandlerError to protocol Error
324impl From<HandlerError> for Error {
325    fn from(err: HandlerError) -> Self {
326        match err {
327            HandlerError::Authentication(msg) => Error::unauthorized(msg),
328            HandlerError::Authorization(msg) => Error::forbidden(msg),
329            HandlerError::Backend(msg) => Error::internal_error(msg),
330            HandlerError::Protocol(e) => e,
331        }
332    }
333}