pulseengine_mcp_server/
handler.rs

1//! Generic request handler for MCP protocol
2
use crate::{backend::McpBackend, context::RequestContext, middleware::MiddlewareStack};
use pulseengine_mcp_auth::AuthenticationManager;
use pulseengine_mcp_logging::{get_metrics, spans};
use pulseengine_mcp_protocol::*;

use std::sync::Arc;
use std::time::Instant;
use thiserror::Error;
use tracing::{Instrument, debug, error, info, instrument};
12
13/// Error type for handler operations
14#[derive(Debug, Error)]
15pub enum HandlerError {
16    #[error("Authentication failed: {0}")]
17    Authentication(String),
18
19    #[error("Authorization failed: {0}")]
20    Authorization(String),
21
22    #[error("Backend error: {0}")]
23    Backend(String),
24
25    #[error("Protocol error: {0}")]
26    Protocol(#[from] Error),
27}
28
29// Implement ErrorClassification for HandlerError
30impl pulseengine_mcp_logging::ErrorClassification for HandlerError {
31    fn error_type(&self) -> &str {
32        match self {
33            HandlerError::Authentication(_) => "authentication",
34            HandlerError::Authorization(_) => "authorization",
35            HandlerError::Backend(_) => "backend",
36            HandlerError::Protocol(_) => "protocol",
37        }
38    }
39
40    fn is_retryable(&self) -> bool {
41        match self {
42            HandlerError::Backend(_) => true, // Backend errors might be temporary
43            _ => false,
44        }
45    }
46
47    fn is_timeout(&self) -> bool {
48        false // HandlerError doesn't represent timeouts directly
49    }
50
51    fn is_auth_error(&self) -> bool {
52        matches!(
53            self,
54            HandlerError::Authentication(_) | HandlerError::Authorization(_)
55        )
56    }
57
58    fn is_connection_error(&self) -> bool {
59        false // HandlerError doesn't represent connection errors directly
60    }
61}
62
63/// Generic server handler that implements the MCP protocol
64#[derive(Clone)]
65pub struct GenericServerHandler<B: McpBackend> {
66    backend: Arc<B>,
67    #[allow(dead_code)]
68    auth_manager: Arc<AuthenticationManager>,
69    middleware: MiddlewareStack,
70}
71
72impl<B: McpBackend> GenericServerHandler<B> {
73    /// Create a new handler
74    pub fn new(
75        backend: Arc<B>,
76        auth_manager: Arc<AuthenticationManager>,
77        middleware: MiddlewareStack,
78    ) -> Self {
79        Self {
80            backend,
81            auth_manager,
82            middleware,
83        }
84    }
85
86    /// Handle an MCP request
87    #[instrument(skip(self, request), fields(mcp.method = %request.method, mcp.request_id = ?request.id))]
88    pub async fn handle_request(
89        &self,
90        request: Request,
91    ) -> std::result::Result<Response, HandlerError> {
92        let start_time = Instant::now();
93        let method = request.method.clone();
94        debug!("Handling request: {}", method);
95
96        // Store request ID before moving request
97        let request_id = request.id.clone();
98
99        // Create request context
100        let context = RequestContext::new();
101
102        // Get metrics collector
103        let metrics = get_metrics();
104
105        // Record request start
106        metrics.record_request_start(&method).await;
107
108        // Apply middleware
109        let request = self.middleware.process_request(request, &context).await?;
110
111        // Route to appropriate handler with tracing
112        let result = {
113            let request_id_str = request_id
114                .as_ref()
115                .map(|id| id.to_string())
116                .unwrap_or_else(|| "none".to_string());
117            let span = spans::mcp_request_span(&method, &request_id_str);
118            let _guard = span.enter();
119
120            match request.method.as_str() {
121                "initialize" => self.handle_initialize(request).await,
122                "tools/list" => self.handle_list_tools(request).await,
123                "tools/call" => self.handle_call_tool(request).await,
124                "resources/list" => self.handle_list_resources(request).await,
125                "resources/read" => self.handle_read_resource(request).await,
126                "resources/templates/list" => self.handle_list_resource_templates(request).await,
127                "prompts/list" => self.handle_list_prompts(request).await,
128                "prompts/get" => self.handle_get_prompt(request).await,
129                "resources/subscribe" => self.handle_subscribe(request).await,
130                "resources/unsubscribe" => self.handle_unsubscribe(request).await,
131                "completion/complete" => self.handle_complete(request).await,
132                "elicitation/create" => self.handle_elicit(request).await,
133                "logging/setLevel" => self.handle_set_level(request).await,
134                "ping" => self.handle_ping(request).await,
135                _ => self.handle_custom_method(request).await,
136            }
137        };
138
139        // Calculate request duration
140        let duration = start_time.elapsed();
141
142        match result {
143            Ok(response) => {
144                // Record successful request
145                metrics.record_request_end(&method, duration, true).await;
146
147                // Apply response middleware
148                let response = self.middleware.process_response(response, &context).await?;
149
150                info!(
151                    method = %method,
152                    duration_ms = %duration.as_millis(),
153                    request_id = ?request_id,
154                    "Request completed successfully"
155                );
156
157                Ok(response)
158            }
159            Err(error) => {
160                // Record failed request
161                metrics.record_request_end(&method, duration, false).await;
162
163                // Record error details
164                metrics
165                    .record_error(&method, &context.request_id.to_string(), &error, duration)
166                    .await;
167
168                error!(
169                    method = %method,
170                    duration_ms = %duration.as_millis(),
171                    request_id = ?request_id,
172                    error = %error,
173                    "Request failed"
174                );
175
176                Ok(Response {
177                    jsonrpc: "2.0".to_string(),
178                    id: request_id,
179                    result: None,
180                    error: Some(error),
181                })
182            }
183        }
184    }
185
186    #[instrument(skip(self, request), fields(mcp.method = "initialize"))]
187    async fn handle_initialize(&self, request: Request) -> std::result::Result<Response, Error> {
188        let _params: InitializeRequestParam = serde_json::from_value(request.params)?;
189
190        let server_info = self.backend.get_server_info();
191        let result = InitializeResult {
192            protocol_version: pulseengine_mcp_protocol::MCP_VERSION.to_string(),
193            capabilities: server_info.capabilities,
194            server_info: server_info.server_info.clone(),
195            instructions: server_info.instructions,
196        };
197
198        Ok(Response {
199            jsonrpc: "2.0".to_string(),
200            id: request.id,
201            result: Some(serde_json::to_value(result)?),
202            error: None,
203        })
204    }
205
206    #[instrument(skip(self, request), fields(mcp.method = "tools/list"))]
207    async fn handle_list_tools(&self, request: Request) -> std::result::Result<Response, Error> {
208        let params: PaginatedRequestParam = if request.params.is_null() {
209            PaginatedRequestParam { cursor: None }
210        } else {
211            serde_json::from_value(request.params)?
212        };
213
214        let result = self
215            .backend
216            .list_tools(params)
217            .await
218            .map_err(|e| e.into())?;
219
220        Ok(Response {
221            jsonrpc: "2.0".to_string(),
222            id: request.id,
223            result: Some(serde_json::to_value(result)?),
224            error: None,
225        })
226    }
227
228    #[instrument(skip(self, request), fields(mcp.method = "tools/call"))]
229    async fn handle_call_tool(&self, request: Request) -> std::result::Result<Response, Error> {
230        let params: CallToolRequestParam = serde_json::from_value(request.params)?;
231        let tool_name = params.name.clone();
232        let start_time = Instant::now();
233
234        // Get metrics collector for tool-specific tracking
235        let metrics = get_metrics();
236        metrics.record_request_start(&tool_name).await;
237
238        let result = {
239            let span = spans::backend_operation_span("call_tool", Some(&tool_name));
240            let _guard = span.enter();
241            match self.backend.call_tool(params).await {
242                Ok(result) => {
243                    let duration = start_time.elapsed();
244                    metrics.record_request_end(&tool_name, duration, true).await;
245                    info!(
246                        tool = %tool_name,
247                        duration_ms = %duration.as_millis(),
248                        "Tool call completed successfully"
249                    );
250                    result
251                }
252                Err(err) => {
253                    let duration = start_time.elapsed();
254                    metrics
255                        .record_request_end(&tool_name, duration, false)
256                        .await;
257                    error!(
258                        tool = %tool_name,
259                        duration_ms = %duration.as_millis(),
260                        error = %err,
261                        "Tool call failed"
262                    );
263                    return Err(err.into());
264                }
265            }
266        };
267
268        Ok(Response {
269            jsonrpc: "2.0".to_string(),
270            id: request.id,
271            result: Some(serde_json::to_value(result)?),
272            error: None,
273        })
274    }
275
276    async fn handle_list_resources(
277        &self,
278        request: Request,
279    ) -> std::result::Result<Response, Error> {
280        let params: PaginatedRequestParam = if request.params.is_null() {
281            PaginatedRequestParam { cursor: None }
282        } else {
283            serde_json::from_value(request.params)?
284        };
285
286        let result = self
287            .backend
288            .list_resources(params)
289            .await
290            .map_err(|e| e.into())?;
291
292        Ok(Response {
293            jsonrpc: "2.0".to_string(),
294            id: request.id,
295            result: Some(serde_json::to_value(result)?),
296            error: None,
297        })
298    }
299
300    async fn handle_read_resource(&self, request: Request) -> std::result::Result<Response, Error> {
301        let params: ReadResourceRequestParam = serde_json::from_value(request.params)?;
302
303        let result = self
304            .backend
305            .read_resource(params)
306            .await
307            .map_err(|e| e.into())?;
308
309        Ok(Response {
310            jsonrpc: "2.0".to_string(),
311            id: request.id,
312            result: Some(serde_json::to_value(result)?),
313            error: None,
314        })
315    }
316
317    async fn handle_list_resource_templates(
318        &self,
319        request: Request,
320    ) -> std::result::Result<Response, Error> {
321        let params: PaginatedRequestParam = if request.params.is_null() {
322            PaginatedRequestParam { cursor: None }
323        } else {
324            serde_json::from_value(request.params)?
325        };
326
327        let result = self
328            .backend
329            .list_resource_templates(params)
330            .await
331            .map_err(|e| e.into())?;
332
333        Ok(Response {
334            jsonrpc: "2.0".to_string(),
335            id: request.id,
336            result: Some(serde_json::to_value(result)?),
337            error: None,
338        })
339    }
340
341    async fn handle_list_prompts(&self, request: Request) -> std::result::Result<Response, Error> {
342        let params: PaginatedRequestParam = if request.params.is_null() {
343            PaginatedRequestParam { cursor: None }
344        } else {
345            serde_json::from_value(request.params)?
346        };
347
348        let result = self
349            .backend
350            .list_prompts(params)
351            .await
352            .map_err(|e| e.into())?;
353
354        Ok(Response {
355            jsonrpc: "2.0".to_string(),
356            id: request.id,
357            result: Some(serde_json::to_value(result)?),
358            error: None,
359        })
360    }
361
362    async fn handle_get_prompt(&self, request: Request) -> std::result::Result<Response, Error> {
363        let params: GetPromptRequestParam = serde_json::from_value(request.params)?;
364
365        let result = self
366            .backend
367            .get_prompt(params)
368            .await
369            .map_err(|e| e.into())?;
370
371        Ok(Response {
372            jsonrpc: "2.0".to_string(),
373            id: request.id,
374            result: Some(serde_json::to_value(result)?),
375            error: None,
376        })
377    }
378
379    async fn handle_subscribe(&self, request: Request) -> std::result::Result<Response, Error> {
380        let params: SubscribeRequestParam = serde_json::from_value(request.params)?;
381
382        self.backend.subscribe(params).await.map_err(|e| e.into())?;
383
384        Ok(Response {
385            jsonrpc: "2.0".to_string(),
386            id: request.id,
387            result: Some(serde_json::Value::Object(Default::default())),
388            error: None,
389        })
390    }
391
392    async fn handle_unsubscribe(&self, request: Request) -> std::result::Result<Response, Error> {
393        let params: UnsubscribeRequestParam = serde_json::from_value(request.params)?;
394
395        self.backend
396            .unsubscribe(params)
397            .await
398            .map_err(|e| e.into())?;
399
400        Ok(Response {
401            jsonrpc: "2.0".to_string(),
402            id: request.id,
403            result: Some(serde_json::Value::Object(Default::default())),
404            error: None,
405        })
406    }
407
408    async fn handle_complete(&self, request: Request) -> std::result::Result<Response, Error> {
409        let params: CompleteRequestParam = serde_json::from_value(request.params)?;
410
411        let result = self.backend.complete(params).await.map_err(|e| e.into())?;
412
413        Ok(Response {
414            jsonrpc: "2.0".to_string(),
415            id: request.id,
416            result: Some(serde_json::to_value(result)?),
417            error: None,
418        })
419    }
420
421    async fn handle_elicit(&self, request: Request) -> std::result::Result<Response, Error> {
422        let params: ElicitationRequestParam = serde_json::from_value(request.params)?;
423
424        let result = self.backend.elicit(params).await.map_err(|e| e.into())?;
425
426        Ok(Response {
427            jsonrpc: "2.0".to_string(),
428            id: request.id,
429            result: Some(serde_json::to_value(result)?),
430            error: None,
431        })
432    }
433
434    async fn handle_set_level(&self, request: Request) -> std::result::Result<Response, Error> {
435        let params: SetLevelRequestParam = serde_json::from_value(request.params)?;
436
437        self.backend.set_level(params).await.map_err(|e| e.into())?;
438
439        Ok(Response {
440            jsonrpc: "2.0".to_string(),
441            id: request.id,
442            result: Some(serde_json::Value::Object(Default::default())),
443            error: None,
444        })
445    }
446
447    async fn handle_ping(&self, _request: Request) -> std::result::Result<Response, Error> {
448        Ok(Response {
449            jsonrpc: "2.0".to_string(),
450            id: _request.id,
451            result: Some(serde_json::Value::Object(Default::default())),
452            error: None,
453        })
454    }
455
456    async fn handle_custom_method(&self, request: Request) -> std::result::Result<Response, Error> {
457        let result = self
458            .backend
459            .handle_custom_method(&request.method, request.params)
460            .await
461            .map_err(|e| e.into())?;
462
463        Ok(Response {
464            jsonrpc: "2.0".to_string(),
465            id: request.id,
466            result: Some(result),
467            error: None,
468        })
469    }
470}
471
472// Convert HandlerError to protocol Error
473impl From<HandlerError> for Error {
474    fn from(err: HandlerError) -> Self {
475        match err {
476            HandlerError::Authentication(msg) => Error::unauthorized(msg),
477            HandlerError::Authorization(msg) => Error::forbidden(msg),
478            HandlerError::Backend(msg) => Error::internal_error(msg),
479            HandlerError::Protocol(e) => e,
480        }
481    }
482}
483
484#[cfg(test)]
485mod tests {
486    use super::*;
487    use crate::backend::McpBackend;
488    use crate::middleware::MiddlewareStack;
489    use async_trait::async_trait;
490    use pulseengine_mcp_auth::AuthenticationManager;
491    use pulseengine_mcp_auth::config::AuthConfig;
492    use pulseengine_mcp_logging::ErrorClassification;
493    use pulseengine_mcp_protocol::{
494        CallToolRequestParam, CallToolResult, CompleteRequestParam, CompleteResult, CompletionInfo,
495        Content, Error, GetPromptRequestParam, GetPromptResult, Implementation, InitializeResult,
496        ListPromptsResult, ListResourceTemplatesResult, ListResourcesResult, ListToolsResult,
497        LoggingCapability, PaginatedRequestParam, Prompt, PromptMessage, PromptMessageContent,
498        PromptMessageRole, PromptsCapability, ProtocolVersion, ReadResourceRequestParam,
499        ReadResourceResult, Request, Resource, ResourceContents, ResourcesCapability,
500        ServerCapabilities, ServerInfo, SetLevelRequestParam, SubscribeRequestParam, Tool,
501        ToolsCapability, UnsubscribeRequestParam, error::ErrorCode,
502    };
503    use serde_json::json;
504    use std::sync::Arc;
505
506    // Mock backend for testing
507    #[derive(Clone)]
508    struct MockBackend {
509        server_info: ServerInfo,
510        tools: Vec<Tool>,
511        resources: Vec<Resource>,
512        prompts: Vec<Prompt>,
513        should_error: bool,
514    }
515
516    impl MockBackend {
517        fn new() -> Self {
518            Self {
519                server_info: ServerInfo {
520                    protocol_version: ProtocolVersion::default(),
521                    capabilities: ServerCapabilities {
522                        tools: Some(ToolsCapability { list_changed: None }),
523                        resources: Some(ResourcesCapability {
524                            subscribe: Some(true),
525                            list_changed: None,
526                        }),
527                        prompts: Some(PromptsCapability { list_changed: None }),
528                        logging: Some(LoggingCapability { level: None }),
529                        sampling: None,
530                        elicitation: Some(ElicitationCapability {}),
531                    },
532                    server_info: Implementation {
533                        name: "test-server".to_string(),
534                        version: "1.0.0".to_string(),
535                    },
536                    instructions: None,
537                },
538                tools: vec![Tool {
539                    name: "test_tool".to_string(),
540                    description: "A test tool".to_string(),
541                    input_schema: json!({
542                        "type": "object",
543                        "properties": {
544                            "input": {"type": "string"}
545                        }
546                    }),
547                    output_schema: None,
548                    title: None,
549                    annotations: None,
550                    icons: None,
551                }],
552                resources: vec![Resource {
553                    uri: "test://resource1".to_string(),
554                    name: "Test Resource".to_string(),
555                    description: Some("A test resource".to_string()),
556                    mime_type: Some("text/plain".to_string()),
557                    annotations: None,
558                    raw: None,
559                    title: None,
560                    icons: None,
561                }],
562                prompts: vec![Prompt {
563                    name: "test_prompt".to_string(),
564                    description: Some("A test prompt".to_string()),
565                    arguments: None,
566                    title: None,
567                    icons: None,
568                }],
569                should_error: false,
570            }
571        }
572
573        fn with_error() -> Self {
574            Self {
575                should_error: true,
576                ..Self::new()
577            }
578        }
579    }
580
581    #[async_trait]
582    impl McpBackend for MockBackend {
583        type Error = MockBackendError;
584        type Config = ();
585
586        async fn initialize(_config: Self::Config) -> std::result::Result<Self, Self::Error> {
587            Ok(MockBackend::new())
588        }
589
590        fn get_server_info(&self) -> ServerInfo {
591            self.server_info.clone()
592        }
593
594        async fn health_check(&self) -> std::result::Result<(), Self::Error> {
595            if self.should_error {
596                return Err(MockBackendError::TestError(
597                    "Health check failed".to_string(),
598                ));
599            }
600            Ok(())
601        }
602
603        async fn list_tools(
604            &self,
605            _params: PaginatedRequestParam,
606        ) -> std::result::Result<ListToolsResult, Self::Error> {
607            if self.should_error {
608                return Err(MockBackendError::TestError("Simulated error".to_string()));
609            }
610
611            Ok(ListToolsResult {
612                tools: self.tools.clone(),
613                next_cursor: None,
614            })
615        }
616
617        async fn call_tool(
618            &self,
619            params: CallToolRequestParam,
620        ) -> std::result::Result<CallToolResult, Self::Error> {
621            if self.should_error {
622                return Err(MockBackendError::TestError("Tool call failed".to_string()));
623            }
624
625            if params.name == "test_tool" {
626                Ok(CallToolResult {
627                    content: vec![Content::Text {
628                        text: "Tool executed successfully".to_string(),
629                        _meta: None,
630                    }],
631                    is_error: Some(false),
632                    structured_content: None,
633                    _meta: None,
634                })
635            } else {
636                Err(MockBackendError::TestError("Tool not found".to_string()))
637            }
638        }
639
640        async fn list_resources(
641            &self,
642            _params: PaginatedRequestParam,
643        ) -> std::result::Result<ListResourcesResult, Self::Error> {
644            if self.should_error {
645                return Err(MockBackendError::TestError("Simulated error".to_string()));
646            }
647
648            Ok(ListResourcesResult {
649                resources: self.resources.clone(),
650                next_cursor: None,
651            })
652        }
653
654        async fn read_resource(
655            &self,
656            params: ReadResourceRequestParam,
657        ) -> std::result::Result<ReadResourceResult, Self::Error> {
658            if self.should_error {
659                return Err(MockBackendError::TestError("Simulated error".to_string()));
660            }
661
662            if params.uri == "test://resource1" {
663                Ok(ReadResourceResult {
664                    contents: vec![ResourceContents {
665                        uri: params.uri,
666                        mime_type: Some("text/plain".to_string()),
667                        text: Some("Resource content".to_string()),
668                        blob: None,
669                        _meta: None,
670                    }],
671                })
672            } else {
673                Err(MockBackendError::TestError(
674                    "Resource not found".to_string(),
675                ))
676            }
677        }
678
679        async fn list_resource_templates(
680            &self,
681            _params: PaginatedRequestParam,
682        ) -> std::result::Result<ListResourceTemplatesResult, Self::Error> {
683            Ok(ListResourceTemplatesResult {
684                resource_templates: vec![],
685                next_cursor: None,
686            })
687        }
688
689        async fn list_prompts(
690            &self,
691            _params: PaginatedRequestParam,
692        ) -> std::result::Result<ListPromptsResult, Self::Error> {
693            if self.should_error {
694                return Err(MockBackendError::TestError("Simulated error".to_string()));
695            }
696
697            Ok(ListPromptsResult {
698                prompts: self.prompts.clone(),
699                next_cursor: None,
700            })
701        }
702
703        async fn get_prompt(
704            &self,
705            params: GetPromptRequestParam,
706        ) -> std::result::Result<GetPromptResult, Self::Error> {
707            if self.should_error {
708                return Err(MockBackendError::TestError("Simulated error".to_string()));
709            }
710
711            if params.name == "test_prompt" {
712                Ok(GetPromptResult {
713                    description: Some("A test prompt".to_string()),
714                    messages: vec![PromptMessage {
715                        role: PromptMessageRole::User,
716                        content: PromptMessageContent::Text {
717                            text: "Test prompt message".to_string(),
718                        },
719                    }],
720                })
721            } else {
722                Err(MockBackendError::TestError("Prompt not found".to_string()))
723            }
724        }
725
726        async fn subscribe(
727            &self,
728            _params: SubscribeRequestParam,
729        ) -> std::result::Result<(), Self::Error> {
730            if self.should_error {
731                return Err(MockBackendError::TestError("Subscribe failed".to_string()));
732            }
733            Ok(())
734        }
735
736        async fn unsubscribe(
737            &self,
738            _params: UnsubscribeRequestParam,
739        ) -> std::result::Result<(), Self::Error> {
740            if self.should_error {
741                return Err(MockBackendError::TestError(
742                    "Unsubscribe failed".to_string(),
743                ));
744            }
745            Ok(())
746        }
747
748        async fn complete(
749            &self,
750            _params: CompleteRequestParam,
751        ) -> std::result::Result<CompleteResult, Self::Error> {
752            if self.should_error {
753                return Err(MockBackendError::TestError("Complete failed".to_string()));
754            }
755
756            Ok(CompleteResult {
757                completion: vec![
758                    CompletionInfo {
759                        completion: "completion1".to_string(),
760                        has_more: Some(false),
761                    },
762                    CompletionInfo {
763                        completion: "completion2".to_string(),
764                        has_more: Some(false),
765                    },
766                ],
767            })
768        }
769
770        async fn elicit(
771            &self,
772            _params: ElicitationRequestParam,
773        ) -> std::result::Result<ElicitationResult, Self::Error> {
774            if self.should_error {
775                return Err(MockBackendError::TestError(
776                    "Elicitation failed".to_string(),
777                ));
778            }
779
780            // Simulate user accepting with sample data
781            Ok(ElicitationResult::accept(serde_json::json!({
782                "name": "Test User",
783                "email": "test@example.com"
784            })))
785        }
786
787        async fn set_level(
788            &self,
789            _params: SetLevelRequestParam,
790        ) -> std::result::Result<(), Self::Error> {
791            if self.should_error {
792                return Err(MockBackendError::TestError("Set level failed".to_string()));
793            }
794            Ok(())
795        }
796
797        async fn handle_custom_method(
798            &self,
799            method: &str,
800            _params: serde_json::Value,
801        ) -> std::result::Result<serde_json::Value, Self::Error> {
802            if self.should_error {
803                return Err(MockBackendError::TestError(
804                    "Custom method failed".to_string(),
805                ));
806            }
807
808            Ok(json!({
809                "method": method,
810                "result": "custom method executed"
811            }))
812        }
813    }
814
815    #[derive(Debug, thiserror::Error)]
816    enum MockBackendError {
817        #[error("Test error: {0}")]
818        TestError(String),
819    }
820
821    impl From<MockBackendError> for Error {
822        fn from(err: MockBackendError) -> Self {
823            Error::internal_error(err.to_string())
824        }
825    }
826
827    impl From<crate::backend::BackendError> for MockBackendError {
828        fn from(error: crate::backend::BackendError) -> Self {
829            MockBackendError::TestError(error.to_string())
830        }
831    }
832
833    async fn create_test_handler() -> GenericServerHandler<MockBackend> {
834        let backend = Arc::new(MockBackend::new());
835        let auth_config = AuthConfig::memory();
836        let auth_manager = Arc::new(AuthenticationManager::new(auth_config).await.unwrap());
837        let middleware = MiddlewareStack::new();
838
839        GenericServerHandler::new(backend, auth_manager, middleware)
840    }
841
842    async fn create_error_handler() -> GenericServerHandler<MockBackend> {
843        let backend = Arc::new(MockBackend::with_error());
844        let auth_config = AuthConfig::memory();
845        let auth_manager = Arc::new(AuthenticationManager::new(auth_config).await.unwrap());
846        let middleware = MiddlewareStack::new();
847
848        GenericServerHandler::new(backend, auth_manager, middleware)
849    }
850
851    #[tokio::test]
852    async fn test_handler_creation() {
853        let handler = create_test_handler().await;
854        // Just verify the handler can be created
855        assert!(!handler.backend.tools.is_empty());
856    }
857
858    #[tokio::test]
859    async fn test_handle_initialize() {
860        let handler = create_test_handler().await;
861        let request = Request {
862            jsonrpc: "2.0".to_string(),
863            method: "initialize".to_string(),
864            params: json!({
865                "protocolVersion": "2024-11-05",
866                "capabilities": {},
867                "clientInfo": {
868                    "name": "test-client",
869                    "version": "1.0.0"
870                }
871            }),
872            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(1)),
873        };
874
875        let response = handler.handle_request(request).await.unwrap();
876
877        assert_eq!(response.jsonrpc, "2.0");
878        assert_eq!(
879            response.id,
880            Some(pulseengine_mcp_protocol::NumberOrString::Number(1))
881        );
882        assert!(response.result.is_some());
883        assert!(response.error.is_none());
884
885        let result: InitializeResult = serde_json::from_value(response.result.unwrap()).unwrap();
886        assert_eq!(
887            result.protocol_version,
888            pulseengine_mcp_protocol::MCP_VERSION
889        );
890        assert_eq!(result.server_info.name, "test-server");
891    }
892
893    #[tokio::test]
894    async fn test_handle_list_tools() {
895        let handler = create_test_handler().await;
896        let request = Request {
897            jsonrpc: "2.0".to_string(),
898            method: "tools/list".to_string(),
899            params: json!({}),
900            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(2)),
901        };
902
903        let response = handler.handle_request(request).await.unwrap();
904
905        assert_eq!(response.jsonrpc, "2.0");
906        assert_eq!(
907            response.id,
908            Some(pulseengine_mcp_protocol::NumberOrString::Number(2))
909        );
910        assert!(response.result.is_some());
911        assert!(response.error.is_none());
912
913        let result: ListToolsResult = serde_json::from_value(response.result.unwrap()).unwrap();
914        assert_eq!(result.tools.len(), 1);
915        assert_eq!(result.tools[0].name, "test_tool");
916    }
917
918    #[tokio::test]
919    async fn test_handle_call_tool_success() {
920        let handler = create_test_handler().await;
921        let request = Request {
922            jsonrpc: "2.0".to_string(),
923            method: "tools/call".to_string(),
924            params: json!({
925                "name": "test_tool",
926                "arguments": {
927                    "input": "test input"
928                }
929            }),
930            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(3)),
931        };
932
933        let response = handler.handle_request(request).await.unwrap();
934
935        assert_eq!(response.jsonrpc, "2.0");
936        assert_eq!(
937            response.id,
938            Some(pulseengine_mcp_protocol::NumberOrString::Number(3))
939        );
940        assert!(response.result.is_some());
941        assert!(response.error.is_none());
942
943        let result: CallToolResult = serde_json::from_value(response.result.unwrap()).unwrap();
944        assert_eq!(result.content.len(), 1);
945        assert!(!result.is_error.unwrap_or(true));
946    }
947
948    #[tokio::test]
949    async fn test_handle_call_tool_not_found() {
950        let handler = create_test_handler().await;
951        let request = Request {
952            jsonrpc: "2.0".to_string(),
953            method: "tools/call".to_string(),
954            params: json!({
955                "name": "nonexistent_tool",
956                "arguments": {}
957            }),
958            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(4)),
959        };
960
961        let response = handler.handle_request(request).await.unwrap();
962
963        assert_eq!(response.jsonrpc, "2.0");
964        assert_eq!(
965            response.id,
966            Some(pulseengine_mcp_protocol::NumberOrString::Number(4))
967        );
968        assert!(response.result.is_none());
969        assert!(response.error.is_some());
970    }
971
972    #[tokio::test]
973    async fn test_handle_list_resources() {
974        let handler = create_test_handler().await;
975        let request = Request {
976            jsonrpc: "2.0".to_string(),
977            method: "resources/list".to_string(),
978            params: json!({}),
979            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(5)),
980        };
981
982        let response = handler.handle_request(request).await.unwrap();
983
984        assert_eq!(response.jsonrpc, "2.0");
985        assert_eq!(
986            response.id,
987            Some(pulseengine_mcp_protocol::NumberOrString::Number(5))
988        );
989        assert!(response.result.is_some());
990        assert!(response.error.is_none());
991
992        let result: ListResourcesResult = serde_json::from_value(response.result.unwrap()).unwrap();
993        assert_eq!(result.resources.len(), 1);
994        assert_eq!(result.resources[0].uri, "test://resource1");
995    }
996
997    #[tokio::test]
998    async fn test_handle_read_resource() {
999        let handler = create_test_handler().await;
1000        let request = Request {
1001            jsonrpc: "2.0".to_string(),
1002            method: "resources/read".to_string(),
1003            params: json!({
1004                "uri": "test://resource1"
1005            }),
1006            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(6)),
1007        };
1008
1009        let response = handler.handle_request(request).await.unwrap();
1010
1011        assert_eq!(response.jsonrpc, "2.0");
1012        assert_eq!(
1013            response.id,
1014            Some(pulseengine_mcp_protocol::NumberOrString::Number(6))
1015        );
1016        assert!(response.result.is_some());
1017        assert!(response.error.is_none());
1018
1019        let result: ReadResourceResult = serde_json::from_value(response.result.unwrap()).unwrap();
1020        assert_eq!(result.contents.len(), 1);
1021    }
1022
1023    #[tokio::test]
1024    async fn test_handle_list_prompts() {
1025        let handler = create_test_handler().await;
1026        let request = Request {
1027            jsonrpc: "2.0".to_string(),
1028            method: "prompts/list".to_string(),
1029            params: json!({}),
1030            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(7)),
1031        };
1032
1033        let response = handler.handle_request(request).await.unwrap();
1034
1035        assert_eq!(response.jsonrpc, "2.0");
1036        assert_eq!(
1037            response.id,
1038            Some(pulseengine_mcp_protocol::NumberOrString::Number(7))
1039        );
1040        assert!(response.result.is_some());
1041        assert!(response.error.is_none());
1042
1043        let result: ListPromptsResult = serde_json::from_value(response.result.unwrap()).unwrap();
1044        assert_eq!(result.prompts.len(), 1);
1045        assert_eq!(result.prompts[0].name, "test_prompt");
1046    }
1047
1048    #[tokio::test]
1049    async fn test_handle_get_prompt() {
1050        let handler = create_test_handler().await;
1051        let request = Request {
1052            jsonrpc: "2.0".to_string(),
1053            method: "prompts/get".to_string(),
1054            params: json!({
1055                "name": "test_prompt",
1056                "arguments": {}
1057            }),
1058            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(8)),
1059        };
1060
1061        let response = handler.handle_request(request).await.unwrap();
1062
1063        assert_eq!(response.jsonrpc, "2.0");
1064        assert_eq!(
1065            response.id,
1066            Some(pulseengine_mcp_protocol::NumberOrString::Number(8))
1067        );
1068        assert!(response.result.is_some());
1069        assert!(response.error.is_none());
1070
1071        let result: GetPromptResult = serde_json::from_value(response.result.unwrap()).unwrap();
1072        assert_eq!(result.messages.len(), 1);
1073    }
1074
1075    #[tokio::test]
1076    async fn test_handle_subscribe() {
1077        let handler = create_test_handler().await;
1078        let request = Request {
1079            jsonrpc: "2.0".to_string(),
1080            method: "resources/subscribe".to_string(),
1081            params: json!({
1082                "uri": "test://resource1"
1083            }),
1084            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(9)),
1085        };
1086
1087        let response = handler.handle_request(request).await.unwrap();
1088
1089        assert_eq!(response.jsonrpc, "2.0");
1090        assert_eq!(
1091            response.id,
1092            Some(pulseengine_mcp_protocol::NumberOrString::Number(9))
1093        );
1094        assert!(response.result.is_some());
1095        assert!(response.error.is_none());
1096    }
1097
1098    #[tokio::test]
1099    async fn test_handle_unsubscribe() {
1100        let handler = create_test_handler().await;
1101        let request = Request {
1102            jsonrpc: "2.0".to_string(),
1103            method: "resources/unsubscribe".to_string(),
1104            params: json!({
1105                "uri": "test://resource1"
1106            }),
1107            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(10)),
1108        };
1109
1110        let response = handler.handle_request(request).await.unwrap();
1111
1112        assert_eq!(response.jsonrpc, "2.0");
1113        assert_eq!(
1114            response.id,
1115            Some(pulseengine_mcp_protocol::NumberOrString::Number(10))
1116        );
1117        assert!(response.result.is_some());
1118        assert!(response.error.is_none());
1119    }
1120
1121    #[tokio::test]
1122    async fn test_handle_complete() {
1123        let handler = create_test_handler().await;
1124        let request = Request {
1125            jsonrpc: "2.0".to_string(),
1126            method: "completion/complete".to_string(),
1127            params: json!({
1128                "ref_": "test_prompt",
1129                "argument": {
1130                    "name": "query",
1131                    "value": "test"
1132                }
1133            }),
1134            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(11)),
1135        };
1136
1137        let response = handler.handle_request(request).await.unwrap();
1138
1139        assert_eq!(response.jsonrpc, "2.0");
1140        assert_eq!(
1141            response.id,
1142            Some(pulseengine_mcp_protocol::NumberOrString::Number(11))
1143        );
1144        assert!(response.result.is_some());
1145        assert!(response.error.is_none());
1146
1147        let result: CompleteResult = serde_json::from_value(response.result.unwrap()).unwrap();
1148        assert_eq!(result.completion.len(), 2);
1149    }
1150
1151    #[tokio::test]
1152    async fn test_handle_elicit() {
1153        let handler = create_test_handler().await;
1154        let request = Request {
1155            jsonrpc: "2.0".to_string(),
1156            method: "elicitation/create".to_string(),
1157            params: json!({
1158                "message": "Please provide your contact information",
1159                "requestedSchema": {
1160                    "type": "object",
1161                    "properties": {
1162                        "name": {"type": "string", "description": "Your full name"},
1163                        "email": {"type": "string", "format": "email"}
1164                    },
1165                    "required": ["name", "email"]
1166                }
1167            }),
1168            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(12)),
1169        };
1170
1171        let response = handler.handle_request(request).await.unwrap();
1172
1173        assert_eq!(response.jsonrpc, "2.0");
1174        assert_eq!(
1175            response.id,
1176            Some(pulseengine_mcp_protocol::NumberOrString::Number(12))
1177        );
1178        assert!(response.result.is_some());
1179        assert!(response.error.is_none());
1180
1181        let result: ElicitationResult = serde_json::from_value(response.result.unwrap()).unwrap();
1182        assert!(matches!(result.response.action, ElicitationAction::Accept));
1183        assert!(result.response.data.is_some());
1184    }
1185
1186    #[tokio::test]
1187    async fn test_handle_ping() {
1188        let handler = create_test_handler().await;
1189        let request = Request {
1190            jsonrpc: "2.0".to_string(),
1191            method: "ping".to_string(),
1192            params: json!({}),
1193            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(13)),
1194        };
1195
1196        let response = handler.handle_request(request).await.unwrap();
1197
1198        assert_eq!(response.jsonrpc, "2.0");
1199        assert_eq!(
1200            response.id,
1201            Some(pulseengine_mcp_protocol::NumberOrString::Number(13))
1202        );
1203        assert!(response.result.is_some());
1204        assert!(response.error.is_none());
1205    }
1206
1207    #[tokio::test]
1208    async fn test_handle_custom_method() {
1209        let handler = create_test_handler().await;
1210        let request = Request {
1211            jsonrpc: "2.0".to_string(),
1212            method: "custom/method".to_string(),
1213            params: json!({"test": "data"}),
1214            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(14)),
1215        };
1216
1217        let response = handler.handle_request(request).await.unwrap();
1218
1219        assert_eq!(response.jsonrpc, "2.0");
1220        assert_eq!(
1221            response.id,
1222            Some(pulseengine_mcp_protocol::NumberOrString::Number(14))
1223        );
1224        assert!(response.result.is_some());
1225        assert!(response.error.is_none());
1226
1227        let result = response.result.unwrap();
1228        assert_eq!(result["method"], "custom/method");
1229    }
1230
1231    #[tokio::test]
1232    async fn test_backend_error_handling() {
1233        let handler = create_error_handler().await;
1234        let request = Request {
1235            jsonrpc: "2.0".to_string(),
1236            method: "tools/list".to_string(),
1237            params: json!({}),
1238            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(15)),
1239        };
1240
1241        let response = handler.handle_request(request).await.unwrap();
1242
1243        assert_eq!(response.jsonrpc, "2.0");
1244        assert_eq!(
1245            response.id,
1246            Some(pulseengine_mcp_protocol::NumberOrString::Number(15))
1247        );
1248        assert!(response.result.is_none());
1249        assert!(response.error.is_some());
1250
1251        let error = response.error.unwrap();
1252        assert!(error.message.contains("Simulated error"));
1253    }
1254
1255    #[tokio::test]
1256    async fn test_invalid_params() {
1257        let handler = create_test_handler().await;
1258        let request = Request {
1259            jsonrpc: "2.0".to_string(),
1260            method: "tools/call".to_string(),
1261            params: json!("invalid"), // Should be an object
1262            id: Some(pulseengine_mcp_protocol::NumberOrString::Number(16)),
1263        };
1264
1265        let response = handler.handle_request(request).await.unwrap();
1266
1267        assert_eq!(response.jsonrpc, "2.0");
1268        assert_eq!(
1269            response.id,
1270            Some(pulseengine_mcp_protocol::NumberOrString::Number(16))
1271        );
1272        assert!(response.result.is_none());
1273        assert!(response.error.is_some());
1274    }
1275
1276    #[test]
1277    fn test_handler_error_classification() {
1278        let auth_error = HandlerError::Authentication("Invalid token".to_string());
1279        assert_eq!(auth_error.error_type(), "authentication");
1280        assert!(!auth_error.is_retryable());
1281        assert!(!auth_error.is_timeout());
1282        assert!(auth_error.is_auth_error());
1283        assert!(!auth_error.is_connection_error());
1284
1285        let backend_error = HandlerError::Backend("Database error".to_string());
1286        assert_eq!(backend_error.error_type(), "backend");
1287        assert!(backend_error.is_retryable());
1288        assert!(!backend_error.is_timeout());
1289        assert!(!backend_error.is_auth_error());
1290        assert!(!backend_error.is_connection_error());
1291
1292        let protocol_error =
1293            HandlerError::Protocol(Error::invalid_request("Bad request".to_string()));
1294        assert_eq!(protocol_error.error_type(), "protocol");
1295        assert!(!protocol_error.is_retryable());
1296        assert!(!protocol_error.is_timeout());
1297        assert!(!protocol_error.is_auth_error());
1298        assert!(!protocol_error.is_connection_error());
1299    }
1300
1301    #[test]
1302    fn test_handler_error_conversion() {
1303        let auth_error = HandlerError::Authentication("Auth failed".to_string());
1304        let protocol_error: Error = auth_error.into();
1305        assert_eq!(protocol_error.code, ErrorCode::Unauthorized);
1306
1307        let backend_error = HandlerError::Backend("Backend failed".to_string());
1308        let protocol_error: Error = backend_error.into();
1309        assert_eq!(protocol_error.code, ErrorCode::InternalError);
1310    }
1311
1312    #[test]
1313    fn test_handler_error_display() {
1314        let error = HandlerError::Authentication("Test auth error".to_string());
1315        assert_eq!(error.to_string(), "Authentication failed: Test auth error");
1316
1317        let error = HandlerError::Authorization("Test auth error".to_string());
1318        assert_eq!(error.to_string(), "Authorization failed: Test auth error");
1319
1320        let error = HandlerError::Backend("Test backend error".to_string());
1321        assert_eq!(error.to_string(), "Backend error: Test backend error");
1322    }
1323}