// pulseengine_mcp_server/handler.rs

//! Generic request handler for the MCP protocol.

use crate::{backend::McpBackend, context::RequestContext, middleware::MiddlewareStack};
use pulseengine_mcp_auth::AuthenticationManager;
use pulseengine_mcp_logging::{get_metrics, spans};
use pulseengine_mcp_protocol::*;

use std::sync::Arc;
use std::time::Instant;
use thiserror::Error;
use tracing::{Instrument, debug, error, info, instrument};

13/// Error type for handler operations
14#[derive(Debug, Error)]
15pub enum HandlerError {
16    #[error("Authentication failed: {0}")]
17    Authentication(String),
18
19    #[error("Authorization failed: {0}")]
20    Authorization(String),
21
22    #[error("Backend error: {0}")]
23    Backend(String),
24
25    #[error("Protocol error: {0}")]
26    Protocol(#[from] Error),
27}
28
29// Implement ErrorClassification for HandlerError
30impl pulseengine_mcp_logging::ErrorClassification for HandlerError {
31    fn error_type(&self) -> &str {
32        match self {
33            HandlerError::Authentication(_) => "authentication",
34            HandlerError::Authorization(_) => "authorization",
35            HandlerError::Backend(_) => "backend",
36            HandlerError::Protocol(_) => "protocol",
37        }
38    }
39
40    fn is_retryable(&self) -> bool {
41        match self {
42            HandlerError::Backend(_) => true, // Backend errors might be temporary
43            _ => false,
44        }
45    }
46
47    fn is_timeout(&self) -> bool {
48        false // HandlerError doesn't represent timeouts directly
49    }
50
51    fn is_auth_error(&self) -> bool {
52        matches!(
53            self,
54            HandlerError::Authentication(_) | HandlerError::Authorization(_)
55        )
56    }
57
58    fn is_connection_error(&self) -> bool {
59        false // HandlerError doesn't represent connection errors directly
60    }
61}
62
63/// Generic server handler that implements the MCP protocol
64#[derive(Clone)]
65pub struct GenericServerHandler<B: McpBackend> {
66    backend: Arc<B>,
67    #[allow(dead_code)]
68    auth_manager: Arc<AuthenticationManager>,
69    middleware: MiddlewareStack,
70}
71
72impl<B: McpBackend> GenericServerHandler<B> {
73    /// Create a new handler
74    pub fn new(
75        backend: Arc<B>,
76        auth_manager: Arc<AuthenticationManager>,
77        middleware: MiddlewareStack,
78    ) -> Self {
79        Self {
80            backend,
81            auth_manager,
82            middleware,
83        }
84    }
85
86    /// Handle an MCP request
87    #[instrument(skip(self, request), fields(mcp.method = %request.method, mcp.request_id = %request.id))]
88    pub async fn handle_request(
89        &self,
90        request: Request,
91    ) -> std::result::Result<Response, HandlerError> {
92        let start_time = Instant::now();
93        let method = request.method.clone();
94        debug!("Handling request: {}", method);
95
96        // Store request ID before moving request
97        let request_id = request.id.clone();
98
99        // Create request context
100        let context = RequestContext::new();
101
102        // Get metrics collector
103        let metrics = get_metrics();
104
105        // Record request start
106        metrics.record_request_start(&method).await;
107
108        // Apply middleware
109        let request = self.middleware.process_request(request, &context).await?;
110
111        // Route to appropriate handler with tracing
112        let result = {
113            let span = spans::mcp_request_span(&method, &request_id.to_string());
114            let _guard = span.enter();
115
116            match request.method.as_str() {
117                "initialize" => self.handle_initialize(request).await,
118                "tools/list" => self.handle_list_tools(request).await,
119                "tools/call" => self.handle_call_tool(request).await,
120                "resources/list" => self.handle_list_resources(request).await,
121                "resources/read" => self.handle_read_resource(request).await,
122                "resources/templates/list" => self.handle_list_resource_templates(request).await,
123                "prompts/list" => self.handle_list_prompts(request).await,
124                "prompts/get" => self.handle_get_prompt(request).await,
125                "resources/subscribe" => self.handle_subscribe(request).await,
126                "resources/unsubscribe" => self.handle_unsubscribe(request).await,
127                "completion/complete" => self.handle_complete(request).await,
128                "elicitation/create" => self.handle_elicit(request).await,
129                "logging/setLevel" => self.handle_set_level(request).await,
130                "ping" => self.handle_ping(request).await,
131                _ => self.handle_custom_method(request).await,
132            }
133        };
134
135        // Calculate request duration
136        let duration = start_time.elapsed();
137
138        match result {
139            Ok(response) => {
140                // Record successful request
141                metrics.record_request_end(&method, duration, true).await;
142
143                // Apply response middleware
144                let response = self.middleware.process_response(response, &context).await?;
145
146                info!(
147                    method = %method,
148                    duration_ms = %duration.as_millis(),
149                    request_id = ?request_id,
150                    "Request completed successfully"
151                );
152
153                Ok(response)
154            }
155            Err(error) => {
156                // Record failed request
157                metrics.record_request_end(&method, duration, false).await;
158
159                // Record error details
160                metrics
161                    .record_error(&method, &context.request_id.to_string(), &error, duration)
162                    .await;
163
164                error!(
165                    method = %method,
166                    duration_ms = %duration.as_millis(),
167                    request_id = ?request_id,
168                    error = %error,
169                    "Request failed"
170                );
171
172                Ok(Response {
173                    jsonrpc: "2.0".to_string(),
174                    id: request_id,
175                    result: None,
176                    error: Some(error),
177                })
178            }
179        }
180    }
181
182    #[instrument(skip(self, request), fields(mcp.method = "initialize"))]
183    async fn handle_initialize(&self, request: Request) -> std::result::Result<Response, Error> {
184        let _params: InitializeRequestParam = serde_json::from_value(request.params)?;
185
186        let server_info = self.backend.get_server_info();
187        let result = InitializeResult {
188            protocol_version: pulseengine_mcp_protocol::MCP_VERSION.to_string(),
189            capabilities: server_info.capabilities,
190            server_info: server_info.server_info.clone(),
191            instructions: server_info.instructions,
192        };
193
194        Ok(Response {
195            jsonrpc: "2.0".to_string(),
196            id: request.id,
197            result: Some(serde_json::to_value(result)?),
198            error: None,
199        })
200    }
201
202    #[instrument(skip(self, request), fields(mcp.method = "tools/list"))]
203    async fn handle_list_tools(&self, request: Request) -> std::result::Result<Response, Error> {
204        let params: PaginatedRequestParam = if request.params.is_null() {
205            PaginatedRequestParam { cursor: None }
206        } else {
207            serde_json::from_value(request.params)?
208        };
209
210        let result = self
211            .backend
212            .list_tools(params)
213            .await
214            .map_err(|e| e.into())?;
215
216        Ok(Response {
217            jsonrpc: "2.0".to_string(),
218            id: request.id,
219            result: Some(serde_json::to_value(result)?),
220            error: None,
221        })
222    }
223
224    #[instrument(skip(self, request), fields(mcp.method = "tools/call"))]
225    async fn handle_call_tool(&self, request: Request) -> std::result::Result<Response, Error> {
226        let params: CallToolRequestParam = serde_json::from_value(request.params)?;
227        let tool_name = params.name.clone();
228        let start_time = Instant::now();
229
230        // Get metrics collector for tool-specific tracking
231        let metrics = get_metrics();
232        metrics.record_request_start(&tool_name).await;
233
234        let result = {
235            let span = spans::backend_operation_span("call_tool", Some(&tool_name));
236            let _guard = span.enter();
237            match self.backend.call_tool(params).await {
238                Ok(result) => {
239                    let duration = start_time.elapsed();
240                    metrics.record_request_end(&tool_name, duration, true).await;
241                    info!(
242                        tool = %tool_name,
243                        duration_ms = %duration.as_millis(),
244                        "Tool call completed successfully"
245                    );
246                    result
247                }
248                Err(err) => {
249                    let duration = start_time.elapsed();
250                    metrics
251                        .record_request_end(&tool_name, duration, false)
252                        .await;
253                    error!(
254                        tool = %tool_name,
255                        duration_ms = %duration.as_millis(),
256                        error = %err,
257                        "Tool call failed"
258                    );
259                    return Err(err.into());
260                }
261            }
262        };
263
264        Ok(Response {
265            jsonrpc: "2.0".to_string(),
266            id: request.id,
267            result: Some(serde_json::to_value(result)?),
268            error: None,
269        })
270    }
271
272    async fn handle_list_resources(
273        &self,
274        request: Request,
275    ) -> std::result::Result<Response, Error> {
276        let params: PaginatedRequestParam = if request.params.is_null() {
277            PaginatedRequestParam { cursor: None }
278        } else {
279            serde_json::from_value(request.params)?
280        };
281
282        let result = self
283            .backend
284            .list_resources(params)
285            .await
286            .map_err(|e| e.into())?;
287
288        Ok(Response {
289            jsonrpc: "2.0".to_string(),
290            id: request.id,
291            result: Some(serde_json::to_value(result)?),
292            error: None,
293        })
294    }
295
296    async fn handle_read_resource(&self, request: Request) -> std::result::Result<Response, Error> {
297        let params: ReadResourceRequestParam = serde_json::from_value(request.params)?;
298
299        let result = self
300            .backend
301            .read_resource(params)
302            .await
303            .map_err(|e| e.into())?;
304
305        Ok(Response {
306            jsonrpc: "2.0".to_string(),
307            id: request.id,
308            result: Some(serde_json::to_value(result)?),
309            error: None,
310        })
311    }
312
313    async fn handle_list_resource_templates(
314        &self,
315        request: Request,
316    ) -> std::result::Result<Response, Error> {
317        let params: PaginatedRequestParam = if request.params.is_null() {
318            PaginatedRequestParam { cursor: None }
319        } else {
320            serde_json::from_value(request.params)?
321        };
322
323        let result = self
324            .backend
325            .list_resource_templates(params)
326            .await
327            .map_err(|e| e.into())?;
328
329        Ok(Response {
330            jsonrpc: "2.0".to_string(),
331            id: request.id,
332            result: Some(serde_json::to_value(result)?),
333            error: None,
334        })
335    }
336
337    async fn handle_list_prompts(&self, request: Request) -> std::result::Result<Response, Error> {
338        let params: PaginatedRequestParam = if request.params.is_null() {
339            PaginatedRequestParam { cursor: None }
340        } else {
341            serde_json::from_value(request.params)?
342        };
343
344        let result = self
345            .backend
346            .list_prompts(params)
347            .await
348            .map_err(|e| e.into())?;
349
350        Ok(Response {
351            jsonrpc: "2.0".to_string(),
352            id: request.id,
353            result: Some(serde_json::to_value(result)?),
354            error: None,
355        })
356    }
357
358    async fn handle_get_prompt(&self, request: Request) -> std::result::Result<Response, Error> {
359        let params: GetPromptRequestParam = serde_json::from_value(request.params)?;
360
361        let result = self
362            .backend
363            .get_prompt(params)
364            .await
365            .map_err(|e| e.into())?;
366
367        Ok(Response {
368            jsonrpc: "2.0".to_string(),
369            id: request.id,
370            result: Some(serde_json::to_value(result)?),
371            error: None,
372        })
373    }
374
375    async fn handle_subscribe(&self, request: Request) -> std::result::Result<Response, Error> {
376        let params: SubscribeRequestParam = serde_json::from_value(request.params)?;
377
378        self.backend.subscribe(params).await.map_err(|e| e.into())?;
379
380        Ok(Response {
381            jsonrpc: "2.0".to_string(),
382            id: request.id,
383            result: Some(serde_json::Value::Object(Default::default())),
384            error: None,
385        })
386    }
387
388    async fn handle_unsubscribe(&self, request: Request) -> std::result::Result<Response, Error> {
389        let params: UnsubscribeRequestParam = serde_json::from_value(request.params)?;
390
391        self.backend
392            .unsubscribe(params)
393            .await
394            .map_err(|e| e.into())?;
395
396        Ok(Response {
397            jsonrpc: "2.0".to_string(),
398            id: request.id,
399            result: Some(serde_json::Value::Object(Default::default())),
400            error: None,
401        })
402    }
403
404    async fn handle_complete(&self, request: Request) -> std::result::Result<Response, Error> {
405        let params: CompleteRequestParam = serde_json::from_value(request.params)?;
406
407        let result = self.backend.complete(params).await.map_err(|e| e.into())?;
408
409        Ok(Response {
410            jsonrpc: "2.0".to_string(),
411            id: request.id,
412            result: Some(serde_json::to_value(result)?),
413            error: None,
414        })
415    }
416
417    async fn handle_elicit(&self, request: Request) -> std::result::Result<Response, Error> {
418        let params: ElicitationRequestParam = serde_json::from_value(request.params)?;
419
420        let result = self.backend.elicit(params).await.map_err(|e| e.into())?;
421
422        Ok(Response {
423            jsonrpc: "2.0".to_string(),
424            id: request.id,
425            result: Some(serde_json::to_value(result)?),
426            error: None,
427        })
428    }
429
430    async fn handle_set_level(&self, request: Request) -> std::result::Result<Response, Error> {
431        let params: SetLevelRequestParam = serde_json::from_value(request.params)?;
432
433        self.backend.set_level(params).await.map_err(|e| e.into())?;
434
435        Ok(Response {
436            jsonrpc: "2.0".to_string(),
437            id: request.id,
438            result: Some(serde_json::Value::Object(Default::default())),
439            error: None,
440        })
441    }
442
443    async fn handle_ping(&self, _request: Request) -> std::result::Result<Response, Error> {
444        Ok(Response {
445            jsonrpc: "2.0".to_string(),
446            id: _request.id,
447            result: Some(serde_json::Value::Object(Default::default())),
448            error: None,
449        })
450    }
451
452    async fn handle_custom_method(&self, request: Request) -> std::result::Result<Response, Error> {
453        let result = self
454            .backend
455            .handle_custom_method(&request.method, request.params)
456            .await
457            .map_err(|e| e.into())?;
458
459        Ok(Response {
460            jsonrpc: "2.0".to_string(),
461            id: request.id,
462            result: Some(result),
463            error: None,
464        })
465    }
466}
467
468// Convert HandlerError to protocol Error
469impl From<HandlerError> for Error {
470    fn from(err: HandlerError) -> Self {
471        match err {
472            HandlerError::Authentication(msg) => Error::unauthorized(msg),
473            HandlerError::Authorization(msg) => Error::forbidden(msg),
474            HandlerError::Backend(msg) => Error::internal_error(msg),
475            HandlerError::Protocol(e) => e,
476        }
477    }
478}
479
480#[cfg(test)]
481mod tests {
482    use super::*;
483    use crate::backend::McpBackend;
484    use crate::middleware::MiddlewareStack;
485    use async_trait::async_trait;
486    use pulseengine_mcp_auth::AuthenticationManager;
487    use pulseengine_mcp_auth::config::AuthConfig;
488    use pulseengine_mcp_logging::ErrorClassification;
489    use pulseengine_mcp_protocol::{
490        CallToolRequestParam, CallToolResult, CompleteRequestParam, CompleteResult, CompletionInfo,
491        Content, Error, GetPromptRequestParam, GetPromptResult, Implementation, InitializeResult,
492        ListPromptsResult, ListResourceTemplatesResult, ListResourcesResult, ListToolsResult,
493        LoggingCapability, PaginatedRequestParam, Prompt, PromptMessage, PromptMessageContent,
494        PromptMessageRole, PromptsCapability, ProtocolVersion, ReadResourceRequestParam,
495        ReadResourceResult, Request, Resource, ResourceContents, ResourcesCapability,
496        ServerCapabilities, ServerInfo, SetLevelRequestParam, SubscribeRequestParam, Tool,
497        ToolsCapability, UnsubscribeRequestParam, error::ErrorCode,
498    };
499    use serde_json::json;
500    use std::sync::Arc;
501
502    // Mock backend for testing
503    #[derive(Clone)]
504    struct MockBackend {
505        server_info: ServerInfo,
506        tools: Vec<Tool>,
507        resources: Vec<Resource>,
508        prompts: Vec<Prompt>,
509        should_error: bool,
510    }
511
512    impl MockBackend {
513        fn new() -> Self {
514            Self {
515                server_info: ServerInfo {
516                    protocol_version: ProtocolVersion::default(),
517                    capabilities: ServerCapabilities {
518                        tools: Some(ToolsCapability { list_changed: None }),
519                        resources: Some(ResourcesCapability {
520                            subscribe: Some(true),
521                            list_changed: None,
522                        }),
523                        prompts: Some(PromptsCapability { list_changed: None }),
524                        logging: Some(LoggingCapability { level: None }),
525                        sampling: None,
526                        elicitation: Some(ElicitationCapability {}),
527                    },
528                    server_info: Implementation {
529                        name: "test-server".to_string(),
530                        version: "1.0.0".to_string(),
531                    },
532                    instructions: None,
533                },
534                tools: vec![Tool {
535                    name: "test_tool".to_string(),
536                    description: "A test tool".to_string(),
537                    input_schema: json!({
538                        "type": "object",
539                        "properties": {
540                            "input": {"type": "string"}
541                        }
542                    }),
543                    output_schema: None,
544                }],
545                resources: vec![Resource {
546                    uri: "test://resource1".to_string(),
547                    name: "Test Resource".to_string(),
548                    description: Some("A test resource".to_string()),
549                    mime_type: Some("text/plain".to_string()),
550                    annotations: None,
551                    raw: None,
552                }],
553                prompts: vec![Prompt {
554                    name: "test_prompt".to_string(),
555                    description: Some("A test prompt".to_string()),
556                    arguments: None,
557                }],
558                should_error: false,
559            }
560        }
561
562        fn with_error() -> Self {
563            Self {
564                should_error: true,
565                ..Self::new()
566            }
567        }
568    }
569
570    #[async_trait]
571    impl McpBackend for MockBackend {
572        type Error = MockBackendError;
573        type Config = ();
574
575        async fn initialize(_config: Self::Config) -> std::result::Result<Self, Self::Error> {
576            Ok(MockBackend::new())
577        }
578
579        fn get_server_info(&self) -> ServerInfo {
580            self.server_info.clone()
581        }
582
583        async fn health_check(&self) -> std::result::Result<(), Self::Error> {
584            if self.should_error {
585                return Err(MockBackendError::TestError(
586                    "Health check failed".to_string(),
587                ));
588            }
589            Ok(())
590        }
591
592        async fn list_tools(
593            &self,
594            _params: PaginatedRequestParam,
595        ) -> std::result::Result<ListToolsResult, Self::Error> {
596            if self.should_error {
597                return Err(MockBackendError::TestError("Simulated error".to_string()));
598            }
599
600            Ok(ListToolsResult {
601                tools: self.tools.clone(),
602                next_cursor: None,
603            })
604        }
605
606        async fn call_tool(
607            &self,
608            params: CallToolRequestParam,
609        ) -> std::result::Result<CallToolResult, Self::Error> {
610            if self.should_error {
611                return Err(MockBackendError::TestError("Tool call failed".to_string()));
612            }
613
614            if params.name == "test_tool" {
615                Ok(CallToolResult {
616                    content: vec![Content::Text {
617                        text: "Tool executed successfully".to_string(),
618                    }],
619                    is_error: Some(false),
620                    structured_content: None,
621                })
622            } else {
623                Err(MockBackendError::TestError("Tool not found".to_string()))
624            }
625        }
626
627        async fn list_resources(
628            &self,
629            _params: PaginatedRequestParam,
630        ) -> std::result::Result<ListResourcesResult, Self::Error> {
631            if self.should_error {
632                return Err(MockBackendError::TestError("Simulated error".to_string()));
633            }
634
635            Ok(ListResourcesResult {
636                resources: self.resources.clone(),
637                next_cursor: None,
638            })
639        }
640
641        async fn read_resource(
642            &self,
643            params: ReadResourceRequestParam,
644        ) -> std::result::Result<ReadResourceResult, Self::Error> {
645            if self.should_error {
646                return Err(MockBackendError::TestError("Simulated error".to_string()));
647            }
648
649            if params.uri == "test://resource1" {
650                Ok(ReadResourceResult {
651                    contents: vec![ResourceContents {
652                        uri: params.uri,
653                        mime_type: Some("text/plain".to_string()),
654                        text: Some("Resource content".to_string()),
655                        blob: None,
656                    }],
657                })
658            } else {
659                Err(MockBackendError::TestError(
660                    "Resource not found".to_string(),
661                ))
662            }
663        }
664
665        async fn list_resource_templates(
666            &self,
667            _params: PaginatedRequestParam,
668        ) -> std::result::Result<ListResourceTemplatesResult, Self::Error> {
669            Ok(ListResourceTemplatesResult {
670                resource_templates: vec![],
671                next_cursor: None,
672            })
673        }
674
675        async fn list_prompts(
676            &self,
677            _params: PaginatedRequestParam,
678        ) -> std::result::Result<ListPromptsResult, Self::Error> {
679            if self.should_error {
680                return Err(MockBackendError::TestError("Simulated error".to_string()));
681            }
682
683            Ok(ListPromptsResult {
684                prompts: self.prompts.clone(),
685                next_cursor: None,
686            })
687        }
688
689        async fn get_prompt(
690            &self,
691            params: GetPromptRequestParam,
692        ) -> std::result::Result<GetPromptResult, Self::Error> {
693            if self.should_error {
694                return Err(MockBackendError::TestError("Simulated error".to_string()));
695            }
696
697            if params.name == "test_prompt" {
698                Ok(GetPromptResult {
699                    description: Some("A test prompt".to_string()),
700                    messages: vec![PromptMessage {
701                        role: PromptMessageRole::User,
702                        content: PromptMessageContent::Text {
703                            text: "Test prompt message".to_string(),
704                        },
705                    }],
706                })
707            } else {
708                Err(MockBackendError::TestError("Prompt not found".to_string()))
709            }
710        }
711
712        async fn subscribe(
713            &self,
714            _params: SubscribeRequestParam,
715        ) -> std::result::Result<(), Self::Error> {
716            if self.should_error {
717                return Err(MockBackendError::TestError("Subscribe failed".to_string()));
718            }
719            Ok(())
720        }
721
722        async fn unsubscribe(
723            &self,
724            _params: UnsubscribeRequestParam,
725        ) -> std::result::Result<(), Self::Error> {
726            if self.should_error {
727                return Err(MockBackendError::TestError(
728                    "Unsubscribe failed".to_string(),
729                ));
730            }
731            Ok(())
732        }
733
734        async fn complete(
735            &self,
736            _params: CompleteRequestParam,
737        ) -> std::result::Result<CompleteResult, Self::Error> {
738            if self.should_error {
739                return Err(MockBackendError::TestError("Complete failed".to_string()));
740            }
741
742            Ok(CompleteResult {
743                completion: vec![
744                    CompletionInfo {
745                        completion: "completion1".to_string(),
746                        has_more: Some(false),
747                    },
748                    CompletionInfo {
749                        completion: "completion2".to_string(),
750                        has_more: Some(false),
751                    },
752                ],
753            })
754        }
755
756        async fn elicit(
757            &self,
758            _params: ElicitationRequestParam,
759        ) -> std::result::Result<ElicitationResult, Self::Error> {
760            if self.should_error {
761                return Err(MockBackendError::TestError(
762                    "Elicitation failed".to_string(),
763                ));
764            }
765
766            // Simulate user accepting with sample data
767            Ok(ElicitationResult::accept(serde_json::json!({
768                "name": "Test User",
769                "email": "test@example.com"
770            })))
771        }
772
773        async fn set_level(
774            &self,
775            _params: SetLevelRequestParam,
776        ) -> std::result::Result<(), Self::Error> {
777            if self.should_error {
778                return Err(MockBackendError::TestError("Set level failed".to_string()));
779            }
780            Ok(())
781        }
782
783        async fn handle_custom_method(
784            &self,
785            method: &str,
786            _params: serde_json::Value,
787        ) -> std::result::Result<serde_json::Value, Self::Error> {
788            if self.should_error {
789                return Err(MockBackendError::TestError(
790                    "Custom method failed".to_string(),
791                ));
792            }
793
794            Ok(json!({
795                "method": method,
796                "result": "custom method executed"
797            }))
798        }
799    }
800
801    #[derive(Debug, thiserror::Error)]
802    enum MockBackendError {
803        #[error("Test error: {0}")]
804        TestError(String),
805    }
806
807    impl From<MockBackendError> for Error {
808        fn from(err: MockBackendError) -> Self {
809            Error::internal_error(err.to_string())
810        }
811    }
812
813    impl From<crate::backend::BackendError> for MockBackendError {
814        fn from(error: crate::backend::BackendError) -> Self {
815            MockBackendError::TestError(error.to_string())
816        }
817    }
818
819    async fn create_test_handler() -> GenericServerHandler<MockBackend> {
820        let backend = Arc::new(MockBackend::new());
821        let auth_config = AuthConfig::memory();
822        let auth_manager = Arc::new(AuthenticationManager::new(auth_config).await.unwrap());
823        let middleware = MiddlewareStack::new();
824
825        GenericServerHandler::new(backend, auth_manager, middleware)
826    }
827
828    async fn create_error_handler() -> GenericServerHandler<MockBackend> {
829        let backend = Arc::new(MockBackend::with_error());
830        let auth_config = AuthConfig::memory();
831        let auth_manager = Arc::new(AuthenticationManager::new(auth_config).await.unwrap());
832        let middleware = MiddlewareStack::new();
833
834        GenericServerHandler::new(backend, auth_manager, middleware)
835    }
836
837    #[tokio::test]
838    async fn test_handler_creation() {
839        let handler = create_test_handler().await;
840        // Just verify the handler can be created
841        assert!(!handler.backend.tools.is_empty());
842    }
843
844    #[tokio::test]
845    async fn test_handle_initialize() {
846        let handler = create_test_handler().await;
847        let request = Request {
848            jsonrpc: "2.0".to_string(),
849            method: "initialize".to_string(),
850            params: json!({
851                "protocolVersion": "2024-11-05",
852                "capabilities": {},
853                "clientInfo": {
854                    "name": "test-client",
855                    "version": "1.0.0"
856                }
857            }),
858            id: json!(1),
859        };
860
861        let response = handler.handle_request(request).await.unwrap();
862
863        assert_eq!(response.jsonrpc, "2.0");
864        assert_eq!(response.id, json!(1));
865        assert!(response.result.is_some());
866        assert!(response.error.is_none());
867
868        let result: InitializeResult = serde_json::from_value(response.result.unwrap()).unwrap();
869        assert_eq!(
870            result.protocol_version,
871            pulseengine_mcp_protocol::MCP_VERSION
872        );
873        assert_eq!(result.server_info.name, "test-server");
874    }
875
876    #[tokio::test]
877    async fn test_handle_list_tools() {
878        let handler = create_test_handler().await;
879        let request = Request {
880            jsonrpc: "2.0".to_string(),
881            method: "tools/list".to_string(),
882            params: json!({}),
883            id: json!(2),
884        };
885
886        let response = handler.handle_request(request).await.unwrap();
887
888        assert_eq!(response.jsonrpc, "2.0");
889        assert_eq!(response.id, json!(2));
890        assert!(response.result.is_some());
891        assert!(response.error.is_none());
892
893        let result: ListToolsResult = serde_json::from_value(response.result.unwrap()).unwrap();
894        assert_eq!(result.tools.len(), 1);
895        assert_eq!(result.tools[0].name, "test_tool");
896    }
897
898    #[tokio::test]
899    async fn test_handle_call_tool_success() {
900        let handler = create_test_handler().await;
901        let request = Request {
902            jsonrpc: "2.0".to_string(),
903            method: "tools/call".to_string(),
904            params: json!({
905                "name": "test_tool",
906                "arguments": {
907                    "input": "test input"
908                }
909            }),
910            id: json!(3),
911        };
912
913        let response = handler.handle_request(request).await.unwrap();
914
915        assert_eq!(response.jsonrpc, "2.0");
916        assert_eq!(response.id, json!(3));
917        assert!(response.result.is_some());
918        assert!(response.error.is_none());
919
920        let result: CallToolResult = serde_json::from_value(response.result.unwrap()).unwrap();
921        assert_eq!(result.content.len(), 1);
922        assert!(!result.is_error.unwrap_or(true));
923    }
924
925    #[tokio::test]
926    async fn test_handle_call_tool_not_found() {
927        let handler = create_test_handler().await;
928        let request = Request {
929            jsonrpc: "2.0".to_string(),
930            method: "tools/call".to_string(),
931            params: json!({
932                "name": "nonexistent_tool",
933                "arguments": {}
934            }),
935            id: json!(4),
936        };
937
938        let response = handler.handle_request(request).await.unwrap();
939
940        assert_eq!(response.jsonrpc, "2.0");
941        assert_eq!(response.id, json!(4));
942        assert!(response.result.is_none());
943        assert!(response.error.is_some());
944    }
945
946    #[tokio::test]
947    async fn test_handle_list_resources() {
948        let handler = create_test_handler().await;
949        let request = Request {
950            jsonrpc: "2.0".to_string(),
951            method: "resources/list".to_string(),
952            params: json!({}),
953            id: json!(5),
954        };
955
956        let response = handler.handle_request(request).await.unwrap();
957
958        assert_eq!(response.jsonrpc, "2.0");
959        assert_eq!(response.id, json!(5));
960        assert!(response.result.is_some());
961        assert!(response.error.is_none());
962
963        let result: ListResourcesResult = serde_json::from_value(response.result.unwrap()).unwrap();
964        assert_eq!(result.resources.len(), 1);
965        assert_eq!(result.resources[0].uri, "test://resource1");
966    }
967
968    #[tokio::test]
969    async fn test_handle_read_resource() {
970        let handler = create_test_handler().await;
971        let request = Request {
972            jsonrpc: "2.0".to_string(),
973            method: "resources/read".to_string(),
974            params: json!({
975                "uri": "test://resource1"
976            }),
977            id: json!(6),
978        };
979
980        let response = handler.handle_request(request).await.unwrap();
981
982        assert_eq!(response.jsonrpc, "2.0");
983        assert_eq!(response.id, json!(6));
984        assert!(response.result.is_some());
985        assert!(response.error.is_none());
986
987        let result: ReadResourceResult = serde_json::from_value(response.result.unwrap()).unwrap();
988        assert_eq!(result.contents.len(), 1);
989    }
990
991    #[tokio::test]
992    async fn test_handle_list_prompts() {
993        let handler = create_test_handler().await;
994        let request = Request {
995            jsonrpc: "2.0".to_string(),
996            method: "prompts/list".to_string(),
997            params: json!({}),
998            id: json!(7),
999        };
1000
1001        let response = handler.handle_request(request).await.unwrap();
1002
1003        assert_eq!(response.jsonrpc, "2.0");
1004        assert_eq!(response.id, json!(7));
1005        assert!(response.result.is_some());
1006        assert!(response.error.is_none());
1007
1008        let result: ListPromptsResult = serde_json::from_value(response.result.unwrap()).unwrap();
1009        assert_eq!(result.prompts.len(), 1);
1010        assert_eq!(result.prompts[0].name, "test_prompt");
1011    }
1012
1013    #[tokio::test]
1014    async fn test_handle_get_prompt() {
1015        let handler = create_test_handler().await;
1016        let request = Request {
1017            jsonrpc: "2.0".to_string(),
1018            method: "prompts/get".to_string(),
1019            params: json!({
1020                "name": "test_prompt",
1021                "arguments": {}
1022            }),
1023            id: json!(8),
1024        };
1025
1026        let response = handler.handle_request(request).await.unwrap();
1027
1028        assert_eq!(response.jsonrpc, "2.0");
1029        assert_eq!(response.id, json!(8));
1030        assert!(response.result.is_some());
1031        assert!(response.error.is_none());
1032
1033        let result: GetPromptResult = serde_json::from_value(response.result.unwrap()).unwrap();
1034        assert_eq!(result.messages.len(), 1);
1035    }
1036
1037    #[tokio::test]
1038    async fn test_handle_subscribe() {
1039        let handler = create_test_handler().await;
1040        let request = Request {
1041            jsonrpc: "2.0".to_string(),
1042            method: "resources/subscribe".to_string(),
1043            params: json!({
1044                "uri": "test://resource1"
1045            }),
1046            id: json!(9),
1047        };
1048
1049        let response = handler.handle_request(request).await.unwrap();
1050
1051        assert_eq!(response.jsonrpc, "2.0");
1052        assert_eq!(response.id, json!(9));
1053        assert!(response.result.is_some());
1054        assert!(response.error.is_none());
1055    }
1056
1057    #[tokio::test]
1058    async fn test_handle_unsubscribe() {
1059        let handler = create_test_handler().await;
1060        let request = Request {
1061            jsonrpc: "2.0".to_string(),
1062            method: "resources/unsubscribe".to_string(),
1063            params: json!({
1064                "uri": "test://resource1"
1065            }),
1066            id: json!(10),
1067        };
1068
1069        let response = handler.handle_request(request).await.unwrap();
1070
1071        assert_eq!(response.jsonrpc, "2.0");
1072        assert_eq!(response.id, json!(10));
1073        assert!(response.result.is_some());
1074        assert!(response.error.is_none());
1075    }
1076
1077    #[tokio::test]
1078    async fn test_handle_complete() {
1079        let handler = create_test_handler().await;
1080        let request = Request {
1081            jsonrpc: "2.0".to_string(),
1082            method: "completion/complete".to_string(),
1083            params: json!({
1084                "ref_": "test_prompt",
1085                "argument": {
1086                    "name": "query",
1087                    "value": "test"
1088                }
1089            }),
1090            id: json!(11),
1091        };
1092
1093        let response = handler.handle_request(request).await.unwrap();
1094
1095        assert_eq!(response.jsonrpc, "2.0");
1096        assert_eq!(response.id, json!(11));
1097        assert!(response.result.is_some());
1098        assert!(response.error.is_none());
1099
1100        let result: CompleteResult = serde_json::from_value(response.result.unwrap()).unwrap();
1101        assert_eq!(result.completion.len(), 2);
1102    }
1103
1104    #[tokio::test]
1105    async fn test_handle_elicit() {
1106        let handler = create_test_handler().await;
1107        let request = Request {
1108            jsonrpc: "2.0".to_string(),
1109            method: "elicitation/create".to_string(),
1110            params: json!({
1111                "message": "Please provide your contact information",
1112                "requestedSchema": {
1113                    "type": "object",
1114                    "properties": {
1115                        "name": {"type": "string", "description": "Your full name"},
1116                        "email": {"type": "string", "format": "email"}
1117                    },
1118                    "required": ["name", "email"]
1119                }
1120            }),
1121            id: json!(12),
1122        };
1123
1124        let response = handler.handle_request(request).await.unwrap();
1125
1126        assert_eq!(response.jsonrpc, "2.0");
1127        assert_eq!(response.id, json!(12));
1128        assert!(response.result.is_some());
1129        assert!(response.error.is_none());
1130
1131        let result: ElicitationResult = serde_json::from_value(response.result.unwrap()).unwrap();
1132        assert!(matches!(result.response.action, ElicitationAction::Accept));
1133        assert!(result.response.data.is_some());
1134    }
1135
1136    #[tokio::test]
1137    async fn test_handle_ping() {
1138        let handler = create_test_handler().await;
1139        let request = Request {
1140            jsonrpc: "2.0".to_string(),
1141            method: "ping".to_string(),
1142            params: json!({}),
1143            id: json!(12),
1144        };
1145
1146        let response = handler.handle_request(request).await.unwrap();
1147
1148        assert_eq!(response.jsonrpc, "2.0");
1149        assert_eq!(response.id, json!(12));
1150        assert!(response.result.is_some());
1151        assert!(response.error.is_none());
1152    }
1153
1154    #[tokio::test]
1155    async fn test_handle_custom_method() {
1156        let handler = create_test_handler().await;
1157        let request = Request {
1158            jsonrpc: "2.0".to_string(),
1159            method: "custom/method".to_string(),
1160            params: json!({"test": "data"}),
1161            id: json!(13),
1162        };
1163
1164        let response = handler.handle_request(request).await.unwrap();
1165
1166        assert_eq!(response.jsonrpc, "2.0");
1167        assert_eq!(response.id, json!(13));
1168        assert!(response.result.is_some());
1169        assert!(response.error.is_none());
1170
1171        let result = response.result.unwrap();
1172        assert_eq!(result["method"], "custom/method");
1173    }
1174
1175    #[tokio::test]
1176    async fn test_backend_error_handling() {
1177        let handler = create_error_handler().await;
1178        let request = Request {
1179            jsonrpc: "2.0".to_string(),
1180            method: "tools/list".to_string(),
1181            params: json!({}),
1182            id: json!(14),
1183        };
1184
1185        let response = handler.handle_request(request).await.unwrap();
1186
1187        assert_eq!(response.jsonrpc, "2.0");
1188        assert_eq!(response.id, json!(14));
1189        assert!(response.result.is_none());
1190        assert!(response.error.is_some());
1191
1192        let error = response.error.unwrap();
1193        assert!(error.message.contains("Simulated error"));
1194    }
1195
1196    #[tokio::test]
1197    async fn test_invalid_params() {
1198        let handler = create_test_handler().await;
1199        let request = Request {
1200            jsonrpc: "2.0".to_string(),
1201            method: "tools/call".to_string(),
1202            params: json!("invalid"), // Should be an object
1203            id: json!(15),
1204        };
1205
1206        let response = handler.handle_request(request).await.unwrap();
1207
1208        assert_eq!(response.jsonrpc, "2.0");
1209        assert_eq!(response.id, json!(15));
1210        assert!(response.result.is_none());
1211        assert!(response.error.is_some());
1212    }
1213
1214    #[test]
1215    fn test_handler_error_classification() {
1216        let auth_error = HandlerError::Authentication("Invalid token".to_string());
1217        assert_eq!(auth_error.error_type(), "authentication");
1218        assert!(!auth_error.is_retryable());
1219        assert!(!auth_error.is_timeout());
1220        assert!(auth_error.is_auth_error());
1221        assert!(!auth_error.is_connection_error());
1222
1223        let backend_error = HandlerError::Backend("Database error".to_string());
1224        assert_eq!(backend_error.error_type(), "backend");
1225        assert!(backend_error.is_retryable());
1226        assert!(!backend_error.is_timeout());
1227        assert!(!backend_error.is_auth_error());
1228        assert!(!backend_error.is_connection_error());
1229
1230        let protocol_error =
1231            HandlerError::Protocol(Error::invalid_request("Bad request".to_string()));
1232        assert_eq!(protocol_error.error_type(), "protocol");
1233        assert!(!protocol_error.is_retryable());
1234        assert!(!protocol_error.is_timeout());
1235        assert!(!protocol_error.is_auth_error());
1236        assert!(!protocol_error.is_connection_error());
1237    }
1238
1239    #[test]
1240    fn test_handler_error_conversion() {
1241        let auth_error = HandlerError::Authentication("Auth failed".to_string());
1242        let protocol_error: Error = auth_error.into();
1243        assert_eq!(protocol_error.code, ErrorCode::Unauthorized);
1244
1245        let backend_error = HandlerError::Backend("Backend failed".to_string());
1246        let protocol_error: Error = backend_error.into();
1247        assert_eq!(protocol_error.code, ErrorCode::InternalError);
1248    }
1249
1250    #[test]
1251    fn test_handler_error_display() {
1252        let error = HandlerError::Authentication("Test auth error".to_string());
1253        assert_eq!(error.to_string(), "Authentication failed: Test auth error");
1254
1255        let error = HandlerError::Authorization("Test auth error".to_string());
1256        assert_eq!(error.to_string(), "Authorization failed: Test auth error");
1257
1258        let error = HandlerError::Backend("Test backend error".to_string());
1259        assert_eq!(error.to_string(), "Backend error: Test backend error");
1260    }
1261}