1use crate::{backend::McpBackend, context::RequestContext, middleware::MiddlewareStack};
4use pulseengine_mcp_auth::AuthenticationManager;
5use pulseengine_mcp_logging::{get_metrics, spans};
6use pulseengine_mcp_protocol::*;
7
8use std::sync::Arc;
9use std::time::Instant;
10use thiserror::Error;
11use tracing::{debug, error, info, instrument};
12
/// Errors that can arise while a request is being handled.
///
/// The `Display` text of each variant is its `#[error(...)]` template with the
/// payload interpolated.
#[derive(Debug, Error)]
pub enum HandlerError {
    /// Authentication failed; the payload is the message shown after the prefix.
    #[error("Authentication failed: {0}")]
    Authentication(String),

    /// Authorization failed; the payload is the message shown after the prefix.
    #[error("Authorization failed: {0}")]
    Authorization(String),

    /// The backend reported a failure; the payload is the message shown after the prefix.
    #[error("Backend error: {0}")]
    Backend(String),

    /// A protocol-level [`Error`]. `#[from]` generates the `From<Error>`
    /// conversion, so `?` can lift a protocol error into this variant.
    #[error("Protocol error: {0}")]
    Protocol(#[from] Error),
}
28
29impl pulseengine_mcp_logging::ErrorClassification for HandlerError {
31    fn error_type(&self) -> &str {
32        match self {
33            HandlerError::Authentication(_) => "authentication",
34            HandlerError::Authorization(_) => "authorization",
35            HandlerError::Backend(_) => "backend",
36            HandlerError::Protocol(_) => "protocol",
37        }
38    }
39
40    fn is_retryable(&self) -> bool {
41        match self {
42            HandlerError::Backend(_) => true, _ => false,
44        }
45    }
46
47    fn is_timeout(&self) -> bool {
48        false }
50
51    fn is_auth_error(&self) -> bool {
52        matches!(
53            self,
54            HandlerError::Authentication(_) | HandlerError::Authorization(_)
55        )
56    }
57
58    fn is_connection_error(&self) -> bool {
59        false }
61}
62
/// Request handler that is generic over the [`McpBackend`] it dispatches to.
///
/// Cloning copies the `Arc` handles and the middleware stack; the backend and
/// authentication manager themselves are shared, not duplicated.
#[derive(Clone)]
pub struct GenericServerHandler<B: McpBackend> {
    /// Backend that serves the individual protocol operations.
    backend: Arc<B>,
    /// Stored at construction time. None of the handler methods in this file
    /// read it, which is why the dead-code lint is suppressed.
    #[allow(dead_code)]
    auth_manager: Arc<AuthenticationManager>,
    /// Middleware run on every request before dispatch and on every successful
    /// response before it is returned.
    middleware: MiddlewareStack,
}
71
72impl<B: McpBackend> GenericServerHandler<B> {
73    pub fn new(
75        backend: Arc<B>,
76        auth_manager: Arc<AuthenticationManager>,
77        middleware: MiddlewareStack,
78    ) -> Self {
79        Self {
80            backend,
81            auth_manager,
82            middleware,
83        }
84    }
85
86    #[instrument(skip(self, request), fields(mcp.method = %request.method, mcp.request_id = %request.id))]
88    pub async fn handle_request(
89        &self,
90        request: Request,
91    ) -> std::result::Result<Response, HandlerError> {
92        let start_time = Instant::now();
93        let method = request.method.clone();
94        debug!("Handling request: {}", method);
95
96        let request_id = request.id.clone();
98
99        let context = RequestContext::new();
101
102        let metrics = get_metrics();
104
105        metrics.record_request_start(&method).await;
107
108        let request = self.middleware.process_request(request, &context).await?;
110
111        let result = {
113            let span = spans::mcp_request_span(&method, &request_id.to_string());
114            let _guard = span.enter();
115
116            match request.method.as_str() {
117                "initialize" => self.handle_initialize(request).await,
118                "tools/list" => self.handle_list_tools(request).await,
119                "tools/call" => self.handle_call_tool(request).await,
120                "resources/list" => self.handle_list_resources(request).await,
121                "resources/read" => self.handle_read_resource(request).await,
122                "resources/templates/list" => self.handle_list_resource_templates(request).await,
123                "prompts/list" => self.handle_list_prompts(request).await,
124                "prompts/get" => self.handle_get_prompt(request).await,
125                "resources/subscribe" => self.handle_subscribe(request).await,
126                "resources/unsubscribe" => self.handle_unsubscribe(request).await,
127                "completion/complete" => self.handle_complete(request).await,
128                "elicitation/create" => self.handle_elicit(request).await,
129                "logging/setLevel" => self.handle_set_level(request).await,
130                "ping" => self.handle_ping(request).await,
131                _ => self.handle_custom_method(request).await,
132            }
133        };
134
135        let duration = start_time.elapsed();
137
138        match result {
139            Ok(response) => {
140                metrics.record_request_end(&method, duration, true).await;
142
143                let response = self.middleware.process_response(response, &context).await?;
145
146                info!(
147                    method = %method,
148                    duration_ms = %duration.as_millis(),
149                    request_id = ?request_id,
150                    "Request completed successfully"
151                );
152
153                Ok(response)
154            }
155            Err(error) => {
156                metrics.record_request_end(&method, duration, false).await;
158
159                metrics
161                    .record_error(&method, &context.request_id.to_string(), &error, duration)
162                    .await;
163
164                error!(
165                    method = %method,
166                    duration_ms = %duration.as_millis(),
167                    request_id = ?request_id,
168                    error = %error,
169                    "Request failed"
170                );
171
172                Ok(Response {
173                    jsonrpc: "2.0".to_string(),
174                    id: request_id,
175                    result: None,
176                    error: Some(error),
177                })
178            }
179        }
180    }
181
182    #[instrument(skip(self, request), fields(mcp.method = "initialize"))]
183    async fn handle_initialize(&self, request: Request) -> std::result::Result<Response, Error> {
184        let _params: InitializeRequestParam = serde_json::from_value(request.params)?;
185
186        let server_info = self.backend.get_server_info();
187        let result = InitializeResult {
188            protocol_version: pulseengine_mcp_protocol::MCP_VERSION.to_string(),
189            capabilities: server_info.capabilities,
190            server_info: server_info.server_info.clone(),
191            instructions: server_info.instructions,
192        };
193
194        Ok(Response {
195            jsonrpc: "2.0".to_string(),
196            id: request.id,
197            result: Some(serde_json::to_value(result)?),
198            error: None,
199        })
200    }
201
202    #[instrument(skip(self, request), fields(mcp.method = "tools/list"))]
203    async fn handle_list_tools(&self, request: Request) -> std::result::Result<Response, Error> {
204        let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
205
206        let result = self
207            .backend
208            .list_tools(params)
209            .await
210            .map_err(|e| e.into())?;
211
212        Ok(Response {
213            jsonrpc: "2.0".to_string(),
214            id: request.id,
215            result: Some(serde_json::to_value(result)?),
216            error: None,
217        })
218    }
219
220    #[instrument(skip(self, request), fields(mcp.method = "tools/call"))]
221    async fn handle_call_tool(&self, request: Request) -> std::result::Result<Response, Error> {
222        let params: CallToolRequestParam = serde_json::from_value(request.params)?;
223        let tool_name = params.name.clone();
224        let start_time = Instant::now();
225
226        let metrics = get_metrics();
228        metrics.record_request_start(&tool_name).await;
229
230        let result = {
231            let span = spans::backend_operation_span("call_tool", Some(&tool_name));
232            let _guard = span.enter();
233            match self.backend.call_tool(params).await {
234                Ok(result) => {
235                    let duration = start_time.elapsed();
236                    metrics.record_request_end(&tool_name, duration, true).await;
237                    info!(
238                        tool = %tool_name,
239                        duration_ms = %duration.as_millis(),
240                        "Tool call completed successfully"
241                    );
242                    result
243                }
244                Err(err) => {
245                    let duration = start_time.elapsed();
246                    metrics
247                        .record_request_end(&tool_name, duration, false)
248                        .await;
249                    error!(
250                        tool = %tool_name,
251                        duration_ms = %duration.as_millis(),
252                        error = %err,
253                        "Tool call failed"
254                    );
255                    return Err(err.into());
256                }
257            }
258        };
259
260        Ok(Response {
261            jsonrpc: "2.0".to_string(),
262            id: request.id,
263            result: Some(serde_json::to_value(result)?),
264            error: None,
265        })
266    }
267
268    async fn handle_list_resources(
269        &self,
270        request: Request,
271    ) -> std::result::Result<Response, Error> {
272        let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
273
274        let result = self
275            .backend
276            .list_resources(params)
277            .await
278            .map_err(|e| e.into())?;
279
280        Ok(Response {
281            jsonrpc: "2.0".to_string(),
282            id: request.id,
283            result: Some(serde_json::to_value(result)?),
284            error: None,
285        })
286    }
287
288    async fn handle_read_resource(&self, request: Request) -> std::result::Result<Response, Error> {
289        let params: ReadResourceRequestParam = serde_json::from_value(request.params)?;
290
291        let result = self
292            .backend
293            .read_resource(params)
294            .await
295            .map_err(|e| e.into())?;
296
297        Ok(Response {
298            jsonrpc: "2.0".to_string(),
299            id: request.id,
300            result: Some(serde_json::to_value(result)?),
301            error: None,
302        })
303    }
304
305    async fn handle_list_resource_templates(
306        &self,
307        request: Request,
308    ) -> std::result::Result<Response, Error> {
309        let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
310
311        let result = self
312            .backend
313            .list_resource_templates(params)
314            .await
315            .map_err(|e| e.into())?;
316
317        Ok(Response {
318            jsonrpc: "2.0".to_string(),
319            id: request.id,
320            result: Some(serde_json::to_value(result)?),
321            error: None,
322        })
323    }
324
325    async fn handle_list_prompts(&self, request: Request) -> std::result::Result<Response, Error> {
326        let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
327
328        let result = self
329            .backend
330            .list_prompts(params)
331            .await
332            .map_err(|e| e.into())?;
333
334        Ok(Response {
335            jsonrpc: "2.0".to_string(),
336            id: request.id,
337            result: Some(serde_json::to_value(result)?),
338            error: None,
339        })
340    }
341
342    async fn handle_get_prompt(&self, request: Request) -> std::result::Result<Response, Error> {
343        let params: GetPromptRequestParam = serde_json::from_value(request.params)?;
344
345        let result = self
346            .backend
347            .get_prompt(params)
348            .await
349            .map_err(|e| e.into())?;
350
351        Ok(Response {
352            jsonrpc: "2.0".to_string(),
353            id: request.id,
354            result: Some(serde_json::to_value(result)?),
355            error: None,
356        })
357    }
358
359    async fn handle_subscribe(&self, request: Request) -> std::result::Result<Response, Error> {
360        let params: SubscribeRequestParam = serde_json::from_value(request.params)?;
361
362        self.backend.subscribe(params).await.map_err(|e| e.into())?;
363
364        Ok(Response {
365            jsonrpc: "2.0".to_string(),
366            id: request.id,
367            result: Some(serde_json::Value::Object(Default::default())),
368            error: None,
369        })
370    }
371
372    async fn handle_unsubscribe(&self, request: Request) -> std::result::Result<Response, Error> {
373        let params: UnsubscribeRequestParam = serde_json::from_value(request.params)?;
374
375        self.backend
376            .unsubscribe(params)
377            .await
378            .map_err(|e| e.into())?;
379
380        Ok(Response {
381            jsonrpc: "2.0".to_string(),
382            id: request.id,
383            result: Some(serde_json::Value::Object(Default::default())),
384            error: None,
385        })
386    }
387
388    async fn handle_complete(&self, request: Request) -> std::result::Result<Response, Error> {
389        let params: CompleteRequestParam = serde_json::from_value(request.params)?;
390
391        let result = self.backend.complete(params).await.map_err(|e| e.into())?;
392
393        Ok(Response {
394            jsonrpc: "2.0".to_string(),
395            id: request.id,
396            result: Some(serde_json::to_value(result)?),
397            error: None,
398        })
399    }
400
401    async fn handle_elicit(&self, request: Request) -> std::result::Result<Response, Error> {
402        let params: ElicitationRequestParam = serde_json::from_value(request.params)?;
403
404        let result = self.backend.elicit(params).await.map_err(|e| e.into())?;
405
406        Ok(Response {
407            jsonrpc: "2.0".to_string(),
408            id: request.id,
409            result: Some(serde_json::to_value(result)?),
410            error: None,
411        })
412    }
413
414    async fn handle_set_level(&self, request: Request) -> std::result::Result<Response, Error> {
415        let params: SetLevelRequestParam = serde_json::from_value(request.params)?;
416
417        self.backend.set_level(params).await.map_err(|e| e.into())?;
418
419        Ok(Response {
420            jsonrpc: "2.0".to_string(),
421            id: request.id,
422            result: Some(serde_json::Value::Object(Default::default())),
423            error: None,
424        })
425    }
426
427    async fn handle_ping(&self, _request: Request) -> std::result::Result<Response, Error> {
428        Ok(Response {
429            jsonrpc: "2.0".to_string(),
430            id: _request.id,
431            result: Some(serde_json::Value::Object(Default::default())),
432            error: None,
433        })
434    }
435
436    async fn handle_custom_method(&self, request: Request) -> std::result::Result<Response, Error> {
437        let result = self
438            .backend
439            .handle_custom_method(&request.method, request.params)
440            .await
441            .map_err(|e| e.into())?;
442
443        Ok(Response {
444            jsonrpc: "2.0".to_string(),
445            id: request.id,
446            result: Some(result),
447            error: None,
448        })
449    }
450}
451
452impl From<HandlerError> for Error {
454    fn from(err: HandlerError) -> Self {
455        match err {
456            HandlerError::Authentication(msg) => Error::unauthorized(msg),
457            HandlerError::Authorization(msg) => Error::forbidden(msg),
458            HandlerError::Backend(msg) => Error::internal_error(msg),
459            HandlerError::Protocol(e) => e,
460        }
461    }
462}
463
464#[cfg(test)]
465mod tests {
466    use super::*;
467    use crate::backend::McpBackend;
468    use crate::middleware::MiddlewareStack;
469    use async_trait::async_trait;
470    use pulseengine_mcp_auth::config::AuthConfig;
471    use pulseengine_mcp_auth::AuthenticationManager;
472    use pulseengine_mcp_logging::ErrorClassification;
473    use pulseengine_mcp_protocol::{
474        error::ErrorCode, CallToolRequestParam, CallToolResult, CompleteRequestParam,
475        CompleteResult, CompletionInfo, Content, Error, GetPromptRequestParam, GetPromptResult,
476        Implementation, InitializeResult, ListPromptsResult, ListResourceTemplatesResult,
477        ListResourcesResult, ListToolsResult, LoggingCapability, PaginatedRequestParam, Prompt,
478        PromptMessage, PromptMessageContent, PromptMessageRole, PromptsCapability, ProtocolVersion,
479        ReadResourceRequestParam, ReadResourceResult, Request, Resource, ResourceContents,
480        ResourcesCapability, ServerCapabilities, ServerInfo, SetLevelRequestParam,
481        SubscribeRequestParam, Tool, ToolsCapability, UnsubscribeRequestParam,
482    };
483    use serde_json::json;
484    use std::sync::Arc;
485
    /// In-memory backend used by the handler tests.
    #[derive(Clone)]
    struct MockBackend {
        /// Returned verbatim (cloned) from `get_server_info`.
        server_info: ServerInfo,
        /// Tools reported by `list_tools`.
        tools: Vec<Tool>,
        /// Resources reported by `list_resources`.
        resources: Vec<Resource>,
        /// Prompts reported by `list_prompts`.
        prompts: Vec<Prompt>,
        /// When set, the backend operations that check it fail with a test error.
        should_error: bool,
    }
495
496    impl MockBackend {
497        fn new() -> Self {
498            Self {
499                server_info: ServerInfo {
500                    protocol_version: ProtocolVersion::default(),
501                    capabilities: ServerCapabilities {
502                        tools: Some(ToolsCapability { list_changed: None }),
503                        resources: Some(ResourcesCapability {
504                            subscribe: Some(true),
505                            list_changed: None,
506                        }),
507                        prompts: Some(PromptsCapability { list_changed: None }),
508                        logging: Some(LoggingCapability { level: None }),
509                        sampling: None,
510                        elicitation: Some(ElicitationCapability {}),
511                    },
512                    server_info: Implementation {
513                        name: "test-server".to_string(),
514                        version: "1.0.0".to_string(),
515                    },
516                    instructions: None,
517                },
518                tools: vec![Tool {
519                    name: "test_tool".to_string(),
520                    description: "A test tool".to_string(),
521                    input_schema: json!({
522                        "type": "object",
523                        "properties": {
524                            "input": {"type": "string"}
525                        }
526                    }),
527                    output_schema: None,
528                }],
529                resources: vec![Resource {
530                    uri: "test://resource1".to_string(),
531                    name: "Test Resource".to_string(),
532                    description: Some("A test resource".to_string()),
533                    mime_type: Some("text/plain".to_string()),
534                    annotations: None,
535                    raw: None,
536                }],
537                prompts: vec![Prompt {
538                    name: "test_prompt".to_string(),
539                    description: Some("A test prompt".to_string()),
540                    arguments: None,
541                }],
542                should_error: false,
543            }
544        }
545
546        fn with_error() -> Self {
547            Self {
548                should_error: true,
549                ..Self::new()
550            }
551        }
552    }
553
554    #[async_trait]
555    impl McpBackend for MockBackend {
556        type Error = MockBackendError;
557        type Config = ();
558
559        async fn initialize(_config: Self::Config) -> std::result::Result<Self, Self::Error> {
560            Ok(MockBackend::new())
561        }
562
563        fn get_server_info(&self) -> ServerInfo {
564            self.server_info.clone()
565        }
566
567        async fn health_check(&self) -> std::result::Result<(), Self::Error> {
568            if self.should_error {
569                return Err(MockBackendError::TestError(
570                    "Health check failed".to_string(),
571                ));
572            }
573            Ok(())
574        }
575
576        async fn list_tools(
577            &self,
578            _params: PaginatedRequestParam,
579        ) -> std::result::Result<ListToolsResult, Self::Error> {
580            if self.should_error {
581                return Err(MockBackendError::TestError("Simulated error".to_string()));
582            }
583
584            Ok(ListToolsResult {
585                tools: self.tools.clone(),
586                next_cursor: None,
587            })
588        }
589
590        async fn call_tool(
591            &self,
592            params: CallToolRequestParam,
593        ) -> std::result::Result<CallToolResult, Self::Error> {
594            if self.should_error {
595                return Err(MockBackendError::TestError("Tool call failed".to_string()));
596            }
597
598            if params.name == "test_tool" {
599                Ok(CallToolResult {
600                    content: vec![Content::Text {
601                        text: "Tool executed successfully".to_string(),
602                    }],
603                    is_error: Some(false),
604                    structured_content: None,
605                })
606            } else {
607                Err(MockBackendError::TestError("Tool not found".to_string()))
608            }
609        }
610
611        async fn list_resources(
612            &self,
613            _params: PaginatedRequestParam,
614        ) -> std::result::Result<ListResourcesResult, Self::Error> {
615            if self.should_error {
616                return Err(MockBackendError::TestError("Simulated error".to_string()));
617            }
618
619            Ok(ListResourcesResult {
620                resources: self.resources.clone(),
621                next_cursor: None,
622            })
623        }
624
625        async fn read_resource(
626            &self,
627            params: ReadResourceRequestParam,
628        ) -> std::result::Result<ReadResourceResult, Self::Error> {
629            if self.should_error {
630                return Err(MockBackendError::TestError("Simulated error".to_string()));
631            }
632
633            if params.uri == "test://resource1" {
634                Ok(ReadResourceResult {
635                    contents: vec![ResourceContents {
636                        uri: params.uri,
637                        mime_type: Some("text/plain".to_string()),
638                        text: Some("Resource content".to_string()),
639                        blob: None,
640                    }],
641                })
642            } else {
643                Err(MockBackendError::TestError(
644                    "Resource not found".to_string(),
645                ))
646            }
647        }
648
649        async fn list_resource_templates(
650            &self,
651            _params: PaginatedRequestParam,
652        ) -> std::result::Result<ListResourceTemplatesResult, Self::Error> {
653            Ok(ListResourceTemplatesResult {
654                resource_templates: vec![],
655                next_cursor: None,
656            })
657        }
658
659        async fn list_prompts(
660            &self,
661            _params: PaginatedRequestParam,
662        ) -> std::result::Result<ListPromptsResult, Self::Error> {
663            if self.should_error {
664                return Err(MockBackendError::TestError("Simulated error".to_string()));
665            }
666
667            Ok(ListPromptsResult {
668                prompts: self.prompts.clone(),
669                next_cursor: None,
670            })
671        }
672
673        async fn get_prompt(
674            &self,
675            params: GetPromptRequestParam,
676        ) -> std::result::Result<GetPromptResult, Self::Error> {
677            if self.should_error {
678                return Err(MockBackendError::TestError("Simulated error".to_string()));
679            }
680
681            if params.name == "test_prompt" {
682                Ok(GetPromptResult {
683                    description: Some("A test prompt".to_string()),
684                    messages: vec![PromptMessage {
685                        role: PromptMessageRole::User,
686                        content: PromptMessageContent::Text {
687                            text: "Test prompt message".to_string(),
688                        },
689                    }],
690                })
691            } else {
692                Err(MockBackendError::TestError("Prompt not found".to_string()))
693            }
694        }
695
696        async fn subscribe(
697            &self,
698            _params: SubscribeRequestParam,
699        ) -> std::result::Result<(), Self::Error> {
700            if self.should_error {
701                return Err(MockBackendError::TestError("Subscribe failed".to_string()));
702            }
703            Ok(())
704        }
705
706        async fn unsubscribe(
707            &self,
708            _params: UnsubscribeRequestParam,
709        ) -> std::result::Result<(), Self::Error> {
710            if self.should_error {
711                return Err(MockBackendError::TestError(
712                    "Unsubscribe failed".to_string(),
713                ));
714            }
715            Ok(())
716        }
717
718        async fn complete(
719            &self,
720            _params: CompleteRequestParam,
721        ) -> std::result::Result<CompleteResult, Self::Error> {
722            if self.should_error {
723                return Err(MockBackendError::TestError("Complete failed".to_string()));
724            }
725
726            Ok(CompleteResult {
727                completion: vec![
728                    CompletionInfo {
729                        completion: "completion1".to_string(),
730                        has_more: Some(false),
731                    },
732                    CompletionInfo {
733                        completion: "completion2".to_string(),
734                        has_more: Some(false),
735                    },
736                ],
737            })
738        }
739
740        async fn elicit(
741            &self,
742            _params: ElicitationRequestParam,
743        ) -> std::result::Result<ElicitationResult, Self::Error> {
744            if self.should_error {
745                return Err(MockBackendError::TestError(
746                    "Elicitation failed".to_string(),
747                ));
748            }
749
750            Ok(ElicitationResult::accept(serde_json::json!({
752                "name": "Test User",
753                "email": "test@example.com"
754            })))
755        }
756
757        async fn set_level(
758            &self,
759            _params: SetLevelRequestParam,
760        ) -> std::result::Result<(), Self::Error> {
761            if self.should_error {
762                return Err(MockBackendError::TestError("Set level failed".to_string()));
763            }
764            Ok(())
765        }
766
767        async fn handle_custom_method(
768            &self,
769            method: &str,
770            _params: serde_json::Value,
771        ) -> std::result::Result<serde_json::Value, Self::Error> {
772            if self.should_error {
773                return Err(MockBackendError::TestError(
774                    "Custom method failed".to_string(),
775                ));
776            }
777
778            Ok(json!({
779                "method": method,
780                "result": "custom method executed"
781            }))
782        }
783    }
784
    /// Error type produced by [`MockBackend`].
    #[derive(Debug, thiserror::Error)]
    enum MockBackendError {
        /// Generic test failure; displays as `Test error: <message>`.
        #[error("Test error: {0}")]
        TestError(String),
    }
790
791    impl From<MockBackendError> for Error {
792        fn from(err: MockBackendError) -> Self {
793            Error::internal_error(err.to_string())
794        }
795    }
796
797    impl From<crate::backend::BackendError> for MockBackendError {
798        fn from(error: crate::backend::BackendError) -> Self {
799            MockBackendError::TestError(error.to_string())
800        }
801    }
802
803    async fn create_test_handler() -> GenericServerHandler<MockBackend> {
804        let backend = Arc::new(MockBackend::new());
805        let auth_config = AuthConfig::memory();
806        let auth_manager = Arc::new(AuthenticationManager::new(auth_config).await.unwrap());
807        let middleware = MiddlewareStack::new();
808
809        GenericServerHandler::new(backend, auth_manager, middleware)
810    }
811
812    async fn create_error_handler() -> GenericServerHandler<MockBackend> {
813        let backend = Arc::new(MockBackend::with_error());
814        let auth_config = AuthConfig::memory();
815        let auth_manager = Arc::new(AuthenticationManager::new(auth_config).await.unwrap());
816        let middleware = MiddlewareStack::new();
817
818        GenericServerHandler::new(backend, auth_manager, middleware)
819    }
820
821    #[tokio::test]
822    async fn test_handler_creation() {
823        let handler = create_test_handler().await;
824        assert!(!handler.backend.tools.is_empty());
826    }
827
828    #[tokio::test]
829    async fn test_handle_initialize() {
830        let handler = create_test_handler().await;
831        let request = Request {
832            jsonrpc: "2.0".to_string(),
833            method: "initialize".to_string(),
834            params: json!({
835                "protocolVersion": "2024-11-05",
836                "capabilities": {},
837                "clientInfo": {
838                    "name": "test-client",
839                    "version": "1.0.0"
840                }
841            }),
842            id: json!(1),
843        };
844
845        let response = handler.handle_request(request).await.unwrap();
846
847        assert_eq!(response.jsonrpc, "2.0");
848        assert_eq!(response.id, json!(1));
849        assert!(response.result.is_some());
850        assert!(response.error.is_none());
851
852        let result: InitializeResult = serde_json::from_value(response.result.unwrap()).unwrap();
853        assert_eq!(
854            result.protocol_version,
855            pulseengine_mcp_protocol::MCP_VERSION
856        );
857        assert_eq!(result.server_info.name, "test-server");
858    }
859
860    #[tokio::test]
861    async fn test_handle_list_tools() {
862        let handler = create_test_handler().await;
863        let request = Request {
864            jsonrpc: "2.0".to_string(),
865            method: "tools/list".to_string(),
866            params: json!({}),
867            id: json!(2),
868        };
869
870        let response = handler.handle_request(request).await.unwrap();
871
872        assert_eq!(response.jsonrpc, "2.0");
873        assert_eq!(response.id, json!(2));
874        assert!(response.result.is_some());
875        assert!(response.error.is_none());
876
877        let result: ListToolsResult = serde_json::from_value(response.result.unwrap()).unwrap();
878        assert_eq!(result.tools.len(), 1);
879        assert_eq!(result.tools[0].name, "test_tool");
880    }
881
882    #[tokio::test]
883    async fn test_handle_call_tool_success() {
884        let handler = create_test_handler().await;
885        let request = Request {
886            jsonrpc: "2.0".to_string(),
887            method: "tools/call".to_string(),
888            params: json!({
889                "name": "test_tool",
890                "arguments": {
891                    "input": "test input"
892                }
893            }),
894            id: json!(3),
895        };
896
897        let response = handler.handle_request(request).await.unwrap();
898
899        assert_eq!(response.jsonrpc, "2.0");
900        assert_eq!(response.id, json!(3));
901        assert!(response.result.is_some());
902        assert!(response.error.is_none());
903
904        let result: CallToolResult = serde_json::from_value(response.result.unwrap()).unwrap();
905        assert_eq!(result.content.len(), 1);
906        assert!(!result.is_error.unwrap_or(true));
907    }
908
909    #[tokio::test]
910    async fn test_handle_call_tool_not_found() {
911        let handler = create_test_handler().await;
912        let request = Request {
913            jsonrpc: "2.0".to_string(),
914            method: "tools/call".to_string(),
915            params: json!({
916                "name": "nonexistent_tool",
917                "arguments": {}
918            }),
919            id: json!(4),
920        };
921
922        let response = handler.handle_request(request).await.unwrap();
923
924        assert_eq!(response.jsonrpc, "2.0");
925        assert_eq!(response.id, json!(4));
926        assert!(response.result.is_none());
927        assert!(response.error.is_some());
928    }
929
930    #[tokio::test]
931    async fn test_handle_list_resources() {
932        let handler = create_test_handler().await;
933        let request = Request {
934            jsonrpc: "2.0".to_string(),
935            method: "resources/list".to_string(),
936            params: json!({}),
937            id: json!(5),
938        };
939
940        let response = handler.handle_request(request).await.unwrap();
941
942        assert_eq!(response.jsonrpc, "2.0");
943        assert_eq!(response.id, json!(5));
944        assert!(response.result.is_some());
945        assert!(response.error.is_none());
946
947        let result: ListResourcesResult = serde_json::from_value(response.result.unwrap()).unwrap();
948        assert_eq!(result.resources.len(), 1);
949        assert_eq!(result.resources[0].uri, "test://resource1");
950    }
951
952    #[tokio::test]
953    async fn test_handle_read_resource() {
954        let handler = create_test_handler().await;
955        let request = Request {
956            jsonrpc: "2.0".to_string(),
957            method: "resources/read".to_string(),
958            params: json!({
959                "uri": "test://resource1"
960            }),
961            id: json!(6),
962        };
963
964        let response = handler.handle_request(request).await.unwrap();
965
966        assert_eq!(response.jsonrpc, "2.0");
967        assert_eq!(response.id, json!(6));
968        assert!(response.result.is_some());
969        assert!(response.error.is_none());
970
971        let result: ReadResourceResult = serde_json::from_value(response.result.unwrap()).unwrap();
972        assert_eq!(result.contents.len(), 1);
973    }
974
975    #[tokio::test]
976    async fn test_handle_list_prompts() {
977        let handler = create_test_handler().await;
978        let request = Request {
979            jsonrpc: "2.0".to_string(),
980            method: "prompts/list".to_string(),
981            params: json!({}),
982            id: json!(7),
983        };
984
985        let response = handler.handle_request(request).await.unwrap();
986
987        assert_eq!(response.jsonrpc, "2.0");
988        assert_eq!(response.id, json!(7));
989        assert!(response.result.is_some());
990        assert!(response.error.is_none());
991
992        let result: ListPromptsResult = serde_json::from_value(response.result.unwrap()).unwrap();
993        assert_eq!(result.prompts.len(), 1);
994        assert_eq!(result.prompts[0].name, "test_prompt");
995    }
996
997    #[tokio::test]
998    async fn test_handle_get_prompt() {
999        let handler = create_test_handler().await;
1000        let request = Request {
1001            jsonrpc: "2.0".to_string(),
1002            method: "prompts/get".to_string(),
1003            params: json!({
1004                "name": "test_prompt",
1005                "arguments": {}
1006            }),
1007            id: json!(8),
1008        };
1009
1010        let response = handler.handle_request(request).await.unwrap();
1011
1012        assert_eq!(response.jsonrpc, "2.0");
1013        assert_eq!(response.id, json!(8));
1014        assert!(response.result.is_some());
1015        assert!(response.error.is_none());
1016
1017        let result: GetPromptResult = serde_json::from_value(response.result.unwrap()).unwrap();
1018        assert_eq!(result.messages.len(), 1);
1019    }
1020
1021    #[tokio::test]
1022    async fn test_handle_subscribe() {
1023        let handler = create_test_handler().await;
1024        let request = Request {
1025            jsonrpc: "2.0".to_string(),
1026            method: "resources/subscribe".to_string(),
1027            params: json!({
1028                "uri": "test://resource1"
1029            }),
1030            id: json!(9),
1031        };
1032
1033        let response = handler.handle_request(request).await.unwrap();
1034
1035        assert_eq!(response.jsonrpc, "2.0");
1036        assert_eq!(response.id, json!(9));
1037        assert!(response.result.is_some());
1038        assert!(response.error.is_none());
1039    }
1040
1041    #[tokio::test]
1042    async fn test_handle_unsubscribe() {
1043        let handler = create_test_handler().await;
1044        let request = Request {
1045            jsonrpc: "2.0".to_string(),
1046            method: "resources/unsubscribe".to_string(),
1047            params: json!({
1048                "uri": "test://resource1"
1049            }),
1050            id: json!(10),
1051        };
1052
1053        let response = handler.handle_request(request).await.unwrap();
1054
1055        assert_eq!(response.jsonrpc, "2.0");
1056        assert_eq!(response.id, json!(10));
1057        assert!(response.result.is_some());
1058        assert!(response.error.is_none());
1059    }
1060
1061    #[tokio::test]
1062    async fn test_handle_complete() {
1063        let handler = create_test_handler().await;
1064        let request = Request {
1065            jsonrpc: "2.0".to_string(),
1066            method: "completion/complete".to_string(),
1067            params: json!({
1068                "ref_": "test_prompt",
1069                "argument": {
1070                    "name": "query",
1071                    "value": "test"
1072                }
1073            }),
1074            id: json!(11),
1075        };
1076
1077        let response = handler.handle_request(request).await.unwrap();
1078
1079        assert_eq!(response.jsonrpc, "2.0");
1080        assert_eq!(response.id, json!(11));
1081        assert!(response.result.is_some());
1082        assert!(response.error.is_none());
1083
1084        let result: CompleteResult = serde_json::from_value(response.result.unwrap()).unwrap();
1085        assert_eq!(result.completion.len(), 2);
1086    }
1087
1088    #[tokio::test]
1089    async fn test_handle_elicit() {
1090        let handler = create_test_handler().await;
1091        let request = Request {
1092            jsonrpc: "2.0".to_string(),
1093            method: "elicitation/create".to_string(),
1094            params: json!({
1095                "message": "Please provide your contact information",
1096                "requestedSchema": {
1097                    "type": "object",
1098                    "properties": {
1099                        "name": {"type": "string", "description": "Your full name"},
1100                        "email": {"type": "string", "format": "email"}
1101                    },
1102                    "required": ["name", "email"]
1103                }
1104            }),
1105            id: json!(12),
1106        };
1107
1108        let response = handler.handle_request(request).await.unwrap();
1109
1110        assert_eq!(response.jsonrpc, "2.0");
1111        assert_eq!(response.id, json!(12));
1112        assert!(response.result.is_some());
1113        assert!(response.error.is_none());
1114
1115        let result: ElicitationResult = serde_json::from_value(response.result.unwrap()).unwrap();
1116        assert!(matches!(result.response.action, ElicitationAction::Accept));
1117        assert!(result.response.data.is_some());
1118    }
1119
1120    #[tokio::test]
1121    async fn test_handle_ping() {
1122        let handler = create_test_handler().await;
1123        let request = Request {
1124            jsonrpc: "2.0".to_string(),
1125            method: "ping".to_string(),
1126            params: json!({}),
1127            id: json!(12),
1128        };
1129
1130        let response = handler.handle_request(request).await.unwrap();
1131
1132        assert_eq!(response.jsonrpc, "2.0");
1133        assert_eq!(response.id, json!(12));
1134        assert!(response.result.is_some());
1135        assert!(response.error.is_none());
1136    }
1137
1138    #[tokio::test]
1139    async fn test_handle_custom_method() {
1140        let handler = create_test_handler().await;
1141        let request = Request {
1142            jsonrpc: "2.0".to_string(),
1143            method: "custom/method".to_string(),
1144            params: json!({"test": "data"}),
1145            id: json!(13),
1146        };
1147
1148        let response = handler.handle_request(request).await.unwrap();
1149
1150        assert_eq!(response.jsonrpc, "2.0");
1151        assert_eq!(response.id, json!(13));
1152        assert!(response.result.is_some());
1153        assert!(response.error.is_none());
1154
1155        let result = response.result.unwrap();
1156        assert_eq!(result["method"], "custom/method");
1157    }
1158
1159    #[tokio::test]
1160    async fn test_backend_error_handling() {
1161        let handler = create_error_handler().await;
1162        let request = Request {
1163            jsonrpc: "2.0".to_string(),
1164            method: "tools/list".to_string(),
1165            params: json!({}),
1166            id: json!(14),
1167        };
1168
1169        let response = handler.handle_request(request).await.unwrap();
1170
1171        assert_eq!(response.jsonrpc, "2.0");
1172        assert_eq!(response.id, json!(14));
1173        assert!(response.result.is_none());
1174        assert!(response.error.is_some());
1175
1176        let error = response.error.unwrap();
1177        assert!(error.message.contains("Simulated error"));
1178    }
1179
1180    #[tokio::test]
1181    async fn test_invalid_params() {
1182        let handler = create_test_handler().await;
1183        let request = Request {
1184            jsonrpc: "2.0".to_string(),
1185            method: "tools/call".to_string(),
1186            params: json!("invalid"), id: json!(15),
1188        };
1189
1190        let response = handler.handle_request(request).await.unwrap();
1191
1192        assert_eq!(response.jsonrpc, "2.0");
1193        assert_eq!(response.id, json!(15));
1194        assert!(response.result.is_none());
1195        assert!(response.error.is_some());
1196    }
1197
1198    #[test]
1199    fn test_handler_error_classification() {
1200        let auth_error = HandlerError::Authentication("Invalid token".to_string());
1201        assert_eq!(auth_error.error_type(), "authentication");
1202        assert!(!auth_error.is_retryable());
1203        assert!(!auth_error.is_timeout());
1204        assert!(auth_error.is_auth_error());
1205        assert!(!auth_error.is_connection_error());
1206
1207        let backend_error = HandlerError::Backend("Database error".to_string());
1208        assert_eq!(backend_error.error_type(), "backend");
1209        assert!(backend_error.is_retryable());
1210        assert!(!backend_error.is_timeout());
1211        assert!(!backend_error.is_auth_error());
1212        assert!(!backend_error.is_connection_error());
1213
1214        let protocol_error =
1215            HandlerError::Protocol(Error::invalid_request("Bad request".to_string()));
1216        assert_eq!(protocol_error.error_type(), "protocol");
1217        assert!(!protocol_error.is_retryable());
1218        assert!(!protocol_error.is_timeout());
1219        assert!(!protocol_error.is_auth_error());
1220        assert!(!protocol_error.is_connection_error());
1221    }
1222
1223    #[test]
1224    fn test_handler_error_conversion() {
1225        let auth_error = HandlerError::Authentication("Auth failed".to_string());
1226        let protocol_error: Error = auth_error.into();
1227        assert_eq!(protocol_error.code, ErrorCode::Unauthorized);
1228
1229        let backend_error = HandlerError::Backend("Backend failed".to_string());
1230        let protocol_error: Error = backend_error.into();
1231        assert_eq!(protocol_error.code, ErrorCode::InternalError);
1232    }
1233
1234    #[test]
1235    fn test_handler_error_display() {
1236        let error = HandlerError::Authentication("Test auth error".to_string());
1237        assert_eq!(error.to_string(), "Authentication failed: Test auth error");
1238
1239        let error = HandlerError::Authorization("Test auth error".to_string());
1240        assert_eq!(error.to_string(), "Authorization failed: Test auth error");
1241
1242        let error = HandlerError::Backend("Test backend error".to_string());
1243        assert_eq!(error.to_string(), "Backend error: Test backend error");
1244    }
1245}