pulseengine_mcp_server/
handler.rs

1//! Generic request handler for MCP protocol
2
3use crate::{backend::McpBackend, context::RequestContext, middleware::MiddlewareStack};
4use pulseengine_mcp_auth::AuthenticationManager;
5use pulseengine_mcp_protocol::*;
6
7use std::sync::Arc;
8use thiserror::Error;
9use tracing::{debug, error, instrument};
10
11/// Error type for handler operations
12#[derive(Debug, Error)]
13pub enum HandlerError {
14    #[error("Authentication failed: {0}")]
15    Authentication(String),
16
17    #[error("Authorization failed: {0}")]
18    Authorization(String),
19
20    #[error("Backend error: {0}")]
21    Backend(String),
22
23    #[error("Protocol error: {0}")]
24    Protocol(#[from] Error),
25}
26
27/// Generic server handler that implements the MCP protocol
28#[derive(Clone)]
29pub struct GenericServerHandler<B: McpBackend> {
30    backend: Arc<B>,
31    #[allow(dead_code)]
32    auth_manager: Arc<AuthenticationManager>,
33    middleware: MiddlewareStack,
34}
35
36impl<B: McpBackend> GenericServerHandler<B> {
37    /// Create a new handler
38    pub fn new(
39        backend: Arc<B>,
40        auth_manager: Arc<AuthenticationManager>,
41        middleware: MiddlewareStack,
42    ) -> Self {
43        Self {
44            backend,
45            auth_manager,
46            middleware,
47        }
48    }
49
50    /// Handle an MCP request
51    #[instrument(skip(self, request))]
52    pub async fn handle_request(
53        &self,
54        request: Request,
55    ) -> std::result::Result<Response, HandlerError> {
56        debug!("Handling request: {}", request.method);
57
58        // Store request ID before moving request
59        let request_id = request.id.clone();
60
61        // Create request context
62        let context = RequestContext::new();
63
64        // Apply middleware
65        let request = self.middleware.process_request(request, &context).await?;
66
67        // Route to appropriate handler
68        let result = match request.method.as_str() {
69            "initialize" => self.handle_initialize(request).await,
70            "tools/list" => self.handle_list_tools(request).await,
71            "tools/call" => self.handle_call_tool(request).await,
72            "resources/list" => self.handle_list_resources(request).await,
73            "resources/read" => self.handle_read_resource(request).await,
74            "resources/templates/list" => self.handle_list_resource_templates(request).await,
75            "prompts/list" => self.handle_list_prompts(request).await,
76            "prompts/get" => self.handle_get_prompt(request).await,
77            "resources/subscribe" => self.handle_subscribe(request).await,
78            "resources/unsubscribe" => self.handle_unsubscribe(request).await,
79            "completion/complete" => self.handle_complete(request).await,
80            "logging/setLevel" => self.handle_set_level(request).await,
81            "ping" => self.handle_ping(request).await,
82            _ => self.handle_custom_method(request).await,
83        };
84
85        match result {
86            Ok(response) => {
87                // Apply response middleware
88                let response = self.middleware.process_response(response, &context).await?;
89                Ok(response)
90            }
91            Err(error) => {
92                error!("Request failed: {}", error);
93                Ok(Response {
94                    jsonrpc: "2.0".to_string(),
95                    id: request_id,
96                    result: None,
97                    error: Some(error),
98                })
99            }
100        }
101    }
102
103    async fn handle_initialize(&self, request: Request) -> std::result::Result<Response, Error> {
104        let _params: InitializeRequestParam = serde_json::from_value(request.params)?;
105
106        let result = InitializeResult {
107            protocol_version: pulseengine_mcp_protocol::MCP_VERSION.to_string(),
108            capabilities: self.backend.get_server_info().capabilities,
109            server_info: self.backend.get_server_info().server_info.clone(),
110            instructions: Some(String::new()), // MCP Inspector expects a string, not null
111        };
112
113        Ok(Response {
114            jsonrpc: "2.0".to_string(),
115            id: request.id,
116            result: Some(serde_json::to_value(result)?),
117            error: None,
118        })
119    }
120
121    async fn handle_list_tools(&self, request: Request) -> std::result::Result<Response, Error> {
122        let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
123
124        let result = self
125            .backend
126            .list_tools(params)
127            .await
128            .map_err(|e| e.into())?;
129
130        Ok(Response {
131            jsonrpc: "2.0".to_string(),
132            id: request.id,
133            result: Some(serde_json::to_value(result)?),
134            error: None,
135        })
136    }
137
138    async fn handle_call_tool(&self, request: Request) -> std::result::Result<Response, Error> {
139        let params: CallToolRequestParam = serde_json::from_value(request.params)?;
140
141        let result = self.backend.call_tool(params).await.map_err(|e| e.into())?;
142
143        Ok(Response {
144            jsonrpc: "2.0".to_string(),
145            id: request.id,
146            result: Some(serde_json::to_value(result)?),
147            error: None,
148        })
149    }
150
151    async fn handle_list_resources(
152        &self,
153        request: Request,
154    ) -> std::result::Result<Response, Error> {
155        let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
156
157        let result = self
158            .backend
159            .list_resources(params)
160            .await
161            .map_err(|e| e.into())?;
162
163        Ok(Response {
164            jsonrpc: "2.0".to_string(),
165            id: request.id,
166            result: Some(serde_json::to_value(result)?),
167            error: None,
168        })
169    }
170
171    async fn handle_read_resource(&self, request: Request) -> std::result::Result<Response, Error> {
172        let params: ReadResourceRequestParam = serde_json::from_value(request.params)?;
173
174        let result = self
175            .backend
176            .read_resource(params)
177            .await
178            .map_err(|e| e.into())?;
179
180        Ok(Response {
181            jsonrpc: "2.0".to_string(),
182            id: request.id,
183            result: Some(serde_json::to_value(result)?),
184            error: None,
185        })
186    }
187
188    async fn handle_list_resource_templates(
189        &self,
190        request: Request,
191    ) -> std::result::Result<Response, Error> {
192        let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
193
194        let result = self
195            .backend
196            .list_resource_templates(params)
197            .await
198            .map_err(|e| e.into())?;
199
200        Ok(Response {
201            jsonrpc: "2.0".to_string(),
202            id: request.id,
203            result: Some(serde_json::to_value(result)?),
204            error: None,
205        })
206    }
207
208    async fn handle_list_prompts(&self, request: Request) -> std::result::Result<Response, Error> {
209        let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
210
211        let result = self
212            .backend
213            .list_prompts(params)
214            .await
215            .map_err(|e| e.into())?;
216
217        Ok(Response {
218            jsonrpc: "2.0".to_string(),
219            id: request.id,
220            result: Some(serde_json::to_value(result)?),
221            error: None,
222        })
223    }
224
225    async fn handle_get_prompt(&self, request: Request) -> std::result::Result<Response, Error> {
226        let params: GetPromptRequestParam = serde_json::from_value(request.params)?;
227
228        let result = self
229            .backend
230            .get_prompt(params)
231            .await
232            .map_err(|e| e.into())?;
233
234        Ok(Response {
235            jsonrpc: "2.0".to_string(),
236            id: request.id,
237            result: Some(serde_json::to_value(result)?),
238            error: None,
239        })
240    }
241
242    async fn handle_subscribe(&self, request: Request) -> std::result::Result<Response, Error> {
243        let params: SubscribeRequestParam = serde_json::from_value(request.params)?;
244
245        self.backend.subscribe(params).await.map_err(|e| e.into())?;
246
247        Ok(Response {
248            jsonrpc: "2.0".to_string(),
249            id: request.id,
250            result: Some(serde_json::Value::Object(Default::default())),
251            error: None,
252        })
253    }
254
255    async fn handle_unsubscribe(&self, request: Request) -> std::result::Result<Response, Error> {
256        let params: UnsubscribeRequestParam = serde_json::from_value(request.params)?;
257
258        self.backend
259            .unsubscribe(params)
260            .await
261            .map_err(|e| e.into())?;
262
263        Ok(Response {
264            jsonrpc: "2.0".to_string(),
265            id: request.id,
266            result: Some(serde_json::Value::Object(Default::default())),
267            error: None,
268        })
269    }
270
271    async fn handle_complete(&self, request: Request) -> std::result::Result<Response, Error> {
272        let params: CompleteRequestParam = serde_json::from_value(request.params)?;
273
274        let result = self.backend.complete(params).await.map_err(|e| e.into())?;
275
276        Ok(Response {
277            jsonrpc: "2.0".to_string(),
278            id: request.id,
279            result: Some(serde_json::to_value(result)?),
280            error: None,
281        })
282    }
283
284    async fn handle_set_level(&self, request: Request) -> std::result::Result<Response, Error> {
285        let params: SetLevelRequestParam = serde_json::from_value(request.params)?;
286
287        self.backend.set_level(params).await.map_err(|e| e.into())?;
288
289        Ok(Response {
290            jsonrpc: "2.0".to_string(),
291            id: request.id,
292            result: Some(serde_json::Value::Object(Default::default())),
293            error: None,
294        })
295    }
296
297    async fn handle_ping(&self, _request: Request) -> std::result::Result<Response, Error> {
298        Ok(Response {
299            jsonrpc: "2.0".to_string(),
300            id: _request.id,
301            result: Some(serde_json::Value::Object(Default::default())),
302            error: None,
303        })
304    }
305
306    async fn handle_custom_method(&self, request: Request) -> std::result::Result<Response, Error> {
307        let result = self
308            .backend
309            .handle_custom_method(&request.method, request.params)
310            .await
311            .map_err(|e| e.into())?;
312
313        Ok(Response {
314            jsonrpc: "2.0".to_string(),
315            id: request.id,
316            result: Some(result),
317            error: None,
318        })
319    }
320}
321
322// Convert HandlerError to protocol Error
323impl From<HandlerError> for Error {
324    fn from(err: HandlerError) -> Self {
325        match err {
326            HandlerError::Authentication(msg) => Error::unauthorized(msg),
327            HandlerError::Authorization(msg) => Error::forbidden(msg),
328            HandlerError::Backend(msg) => Error::internal_error(msg),
329            HandlerError::Protocol(e) => e,
330        }
331    }
332}