use crate::{backend::McpBackend, context::RequestContext, middleware::MiddlewareStack};
use pulseengine_mcp_auth::AuthenticationManager;
use pulseengine_mcp_logging::{get_metrics, spans};
use pulseengine_mcp_protocol::*;

use std::sync::Arc;
use std::time::Instant;
use thiserror::Error;
use tracing::{Instrument, debug, error, info, instrument};
12
13#[derive(Debug, Error)]
15pub enum HandlerError {
16 #[error("Authentication failed: {0}")]
17 Authentication(String),
18
19 #[error("Authorization failed: {0}")]
20 Authorization(String),
21
22 #[error("Backend error: {0}")]
23 Backend(String),
24
25 #[error("Protocol error: {0}")]
26 Protocol(#[from] Error),
27}
28
29impl pulseengine_mcp_logging::ErrorClassification for HandlerError {
31 fn error_type(&self) -> &str {
32 match self {
33 HandlerError::Authentication(_) => "authentication",
34 HandlerError::Authorization(_) => "authorization",
35 HandlerError::Backend(_) => "backend",
36 HandlerError::Protocol(_) => "protocol",
37 }
38 }
39
40 fn is_retryable(&self) -> bool {
41 match self {
42 HandlerError::Backend(_) => true, _ => false,
44 }
45 }
46
47 fn is_timeout(&self) -> bool {
48 false }
50
51 fn is_auth_error(&self) -> bool {
52 matches!(
53 self,
54 HandlerError::Authentication(_) | HandlerError::Authorization(_)
55 )
56 }
57
58 fn is_connection_error(&self) -> bool {
59 false }
61}
62
63#[derive(Clone)]
65pub struct GenericServerHandler<B: McpBackend> {
66 backend: Arc<B>,
67 #[allow(dead_code)]
68 auth_manager: Arc<AuthenticationManager>,
69 middleware: MiddlewareStack,
70}
71
72impl<B: McpBackend> GenericServerHandler<B> {
73 pub fn new(
75 backend: Arc<B>,
76 auth_manager: Arc<AuthenticationManager>,
77 middleware: MiddlewareStack,
78 ) -> Self {
79 Self {
80 backend,
81 auth_manager,
82 middleware,
83 }
84 }
85
86 #[instrument(skip(self, request), fields(mcp.method = %request.method, mcp.request_id = %request.id))]
88 pub async fn handle_request(
89 &self,
90 request: Request,
91 ) -> std::result::Result<Response, HandlerError> {
92 let start_time = Instant::now();
93 let method = request.method.clone();
94 debug!("Handling request: {}", method);
95
96 let request_id = request.id.clone();
98
99 let context = RequestContext::new();
101
102 let metrics = get_metrics();
104
105 metrics.record_request_start(&method).await;
107
108 let request = self.middleware.process_request(request, &context).await?;
110
111 let result = {
113 let span = spans::mcp_request_span(&method, &request_id.to_string());
114 let _guard = span.enter();
115
116 match request.method.as_str() {
117 "initialize" => self.handle_initialize(request).await,
118 "tools/list" => self.handle_list_tools(request).await,
119 "tools/call" => self.handle_call_tool(request).await,
120 "resources/list" => self.handle_list_resources(request).await,
121 "resources/read" => self.handle_read_resource(request).await,
122 "resources/templates/list" => self.handle_list_resource_templates(request).await,
123 "prompts/list" => self.handle_list_prompts(request).await,
124 "prompts/get" => self.handle_get_prompt(request).await,
125 "resources/subscribe" => self.handle_subscribe(request).await,
126 "resources/unsubscribe" => self.handle_unsubscribe(request).await,
127 "completion/complete" => self.handle_complete(request).await,
128 "elicitation/create" => self.handle_elicit(request).await,
129 "logging/setLevel" => self.handle_set_level(request).await,
130 "ping" => self.handle_ping(request).await,
131 _ => self.handle_custom_method(request).await,
132 }
133 };
134
135 let duration = start_time.elapsed();
137
138 match result {
139 Ok(response) => {
140 metrics.record_request_end(&method, duration, true).await;
142
143 let response = self.middleware.process_response(response, &context).await?;
145
146 info!(
147 method = %method,
148 duration_ms = %duration.as_millis(),
149 request_id = ?request_id,
150 "Request completed successfully"
151 );
152
153 Ok(response)
154 }
155 Err(error) => {
156 metrics.record_request_end(&method, duration, false).await;
158
159 metrics
161 .record_error(&method, &context.request_id.to_string(), &error, duration)
162 .await;
163
164 error!(
165 method = %method,
166 duration_ms = %duration.as_millis(),
167 request_id = ?request_id,
168 error = %error,
169 "Request failed"
170 );
171
172 Ok(Response {
173 jsonrpc: "2.0".to_string(),
174 id: request_id,
175 result: None,
176 error: Some(error),
177 })
178 }
179 }
180 }
181
182 #[instrument(skip(self, request), fields(mcp.method = "initialize"))]
183 async fn handle_initialize(&self, request: Request) -> std::result::Result<Response, Error> {
184 let _params: InitializeRequestParam = serde_json::from_value(request.params)?;
185
186 let server_info = self.backend.get_server_info();
187 let result = InitializeResult {
188 protocol_version: pulseengine_mcp_protocol::MCP_VERSION.to_string(),
189 capabilities: server_info.capabilities,
190 server_info: server_info.server_info.clone(),
191 instructions: server_info.instructions,
192 };
193
194 Ok(Response {
195 jsonrpc: "2.0".to_string(),
196 id: request.id,
197 result: Some(serde_json::to_value(result)?),
198 error: None,
199 })
200 }
201
202 #[instrument(skip(self, request), fields(mcp.method = "tools/list"))]
203 async fn handle_list_tools(&self, request: Request) -> std::result::Result<Response, Error> {
204 let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
205
206 let result = self
207 .backend
208 .list_tools(params)
209 .await
210 .map_err(|e| e.into())?;
211
212 Ok(Response {
213 jsonrpc: "2.0".to_string(),
214 id: request.id,
215 result: Some(serde_json::to_value(result)?),
216 error: None,
217 })
218 }
219
220 #[instrument(skip(self, request), fields(mcp.method = "tools/call"))]
221 async fn handle_call_tool(&self, request: Request) -> std::result::Result<Response, Error> {
222 let params: CallToolRequestParam = serde_json::from_value(request.params)?;
223 let tool_name = params.name.clone();
224 let start_time = Instant::now();
225
226 let metrics = get_metrics();
228 metrics.record_request_start(&tool_name).await;
229
230 let result = {
231 let span = spans::backend_operation_span("call_tool", Some(&tool_name));
232 let _guard = span.enter();
233 match self.backend.call_tool(params).await {
234 Ok(result) => {
235 let duration = start_time.elapsed();
236 metrics.record_request_end(&tool_name, duration, true).await;
237 info!(
238 tool = %tool_name,
239 duration_ms = %duration.as_millis(),
240 "Tool call completed successfully"
241 );
242 result
243 }
244 Err(err) => {
245 let duration = start_time.elapsed();
246 metrics
247 .record_request_end(&tool_name, duration, false)
248 .await;
249 error!(
250 tool = %tool_name,
251 duration_ms = %duration.as_millis(),
252 error = %err,
253 "Tool call failed"
254 );
255 return Err(err.into());
256 }
257 }
258 };
259
260 Ok(Response {
261 jsonrpc: "2.0".to_string(),
262 id: request.id,
263 result: Some(serde_json::to_value(result)?),
264 error: None,
265 })
266 }
267
268 async fn handle_list_resources(
269 &self,
270 request: Request,
271 ) -> std::result::Result<Response, Error> {
272 let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
273
274 let result = self
275 .backend
276 .list_resources(params)
277 .await
278 .map_err(|e| e.into())?;
279
280 Ok(Response {
281 jsonrpc: "2.0".to_string(),
282 id: request.id,
283 result: Some(serde_json::to_value(result)?),
284 error: None,
285 })
286 }
287
288 async fn handle_read_resource(&self, request: Request) -> std::result::Result<Response, Error> {
289 let params: ReadResourceRequestParam = serde_json::from_value(request.params)?;
290
291 let result = self
292 .backend
293 .read_resource(params)
294 .await
295 .map_err(|e| e.into())?;
296
297 Ok(Response {
298 jsonrpc: "2.0".to_string(),
299 id: request.id,
300 result: Some(serde_json::to_value(result)?),
301 error: None,
302 })
303 }
304
305 async fn handle_list_resource_templates(
306 &self,
307 request: Request,
308 ) -> std::result::Result<Response, Error> {
309 let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
310
311 let result = self
312 .backend
313 .list_resource_templates(params)
314 .await
315 .map_err(|e| e.into())?;
316
317 Ok(Response {
318 jsonrpc: "2.0".to_string(),
319 id: request.id,
320 result: Some(serde_json::to_value(result)?),
321 error: None,
322 })
323 }
324
325 async fn handle_list_prompts(&self, request: Request) -> std::result::Result<Response, Error> {
326 let params: PaginatedRequestParam = serde_json::from_value(request.params)?;
327
328 let result = self
329 .backend
330 .list_prompts(params)
331 .await
332 .map_err(|e| e.into())?;
333
334 Ok(Response {
335 jsonrpc: "2.0".to_string(),
336 id: request.id,
337 result: Some(serde_json::to_value(result)?),
338 error: None,
339 })
340 }
341
342 async fn handle_get_prompt(&self, request: Request) -> std::result::Result<Response, Error> {
343 let params: GetPromptRequestParam = serde_json::from_value(request.params)?;
344
345 let result = self
346 .backend
347 .get_prompt(params)
348 .await
349 .map_err(|e| e.into())?;
350
351 Ok(Response {
352 jsonrpc: "2.0".to_string(),
353 id: request.id,
354 result: Some(serde_json::to_value(result)?),
355 error: None,
356 })
357 }
358
359 async fn handle_subscribe(&self, request: Request) -> std::result::Result<Response, Error> {
360 let params: SubscribeRequestParam = serde_json::from_value(request.params)?;
361
362 self.backend.subscribe(params).await.map_err(|e| e.into())?;
363
364 Ok(Response {
365 jsonrpc: "2.0".to_string(),
366 id: request.id,
367 result: Some(serde_json::Value::Object(Default::default())),
368 error: None,
369 })
370 }
371
372 async fn handle_unsubscribe(&self, request: Request) -> std::result::Result<Response, Error> {
373 let params: UnsubscribeRequestParam = serde_json::from_value(request.params)?;
374
375 self.backend
376 .unsubscribe(params)
377 .await
378 .map_err(|e| e.into())?;
379
380 Ok(Response {
381 jsonrpc: "2.0".to_string(),
382 id: request.id,
383 result: Some(serde_json::Value::Object(Default::default())),
384 error: None,
385 })
386 }
387
388 async fn handle_complete(&self, request: Request) -> std::result::Result<Response, Error> {
389 let params: CompleteRequestParam = serde_json::from_value(request.params)?;
390
391 let result = self.backend.complete(params).await.map_err(|e| e.into())?;
392
393 Ok(Response {
394 jsonrpc: "2.0".to_string(),
395 id: request.id,
396 result: Some(serde_json::to_value(result)?),
397 error: None,
398 })
399 }
400
401 async fn handle_elicit(&self, request: Request) -> std::result::Result<Response, Error> {
402 let params: ElicitationRequestParam = serde_json::from_value(request.params)?;
403
404 let result = self.backend.elicit(params).await.map_err(|e| e.into())?;
405
406 Ok(Response {
407 jsonrpc: "2.0".to_string(),
408 id: request.id,
409 result: Some(serde_json::to_value(result)?),
410 error: None,
411 })
412 }
413
414 async fn handle_set_level(&self, request: Request) -> std::result::Result<Response, Error> {
415 let params: SetLevelRequestParam = serde_json::from_value(request.params)?;
416
417 self.backend.set_level(params).await.map_err(|e| e.into())?;
418
419 Ok(Response {
420 jsonrpc: "2.0".to_string(),
421 id: request.id,
422 result: Some(serde_json::Value::Object(Default::default())),
423 error: None,
424 })
425 }
426
427 async fn handle_ping(&self, _request: Request) -> std::result::Result<Response, Error> {
428 Ok(Response {
429 jsonrpc: "2.0".to_string(),
430 id: _request.id,
431 result: Some(serde_json::Value::Object(Default::default())),
432 error: None,
433 })
434 }
435
436 async fn handle_custom_method(&self, request: Request) -> std::result::Result<Response, Error> {
437 let result = self
438 .backend
439 .handle_custom_method(&request.method, request.params)
440 .await
441 .map_err(|e| e.into())?;
442
443 Ok(Response {
444 jsonrpc: "2.0".to_string(),
445 id: request.id,
446 result: Some(result),
447 error: None,
448 })
449 }
450}
451
452impl From<HandlerError> for Error {
454 fn from(err: HandlerError) -> Self {
455 match err {
456 HandlerError::Authentication(msg) => Error::unauthorized(msg),
457 HandlerError::Authorization(msg) => Error::forbidden(msg),
458 HandlerError::Backend(msg) => Error::internal_error(msg),
459 HandlerError::Protocol(e) => e,
460 }
461 }
462}
463
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::McpBackend;
    use crate::middleware::MiddlewareStack;
    use async_trait::async_trait;
    use pulseengine_mcp_auth::AuthenticationManager;
    use pulseengine_mcp_auth::config::AuthConfig;
    use pulseengine_mcp_logging::ErrorClassification;
    use pulseengine_mcp_protocol::{
        CallToolRequestParam, CallToolResult, CompleteRequestParam, CompleteResult, CompletionInfo,
        Content, Error, GetPromptRequestParam, GetPromptResult, Implementation, InitializeResult,
        ListPromptsResult, ListResourceTemplatesResult, ListResourcesResult, ListToolsResult,
        LoggingCapability, PaginatedRequestParam, Prompt, PromptMessage, PromptMessageContent,
        PromptMessageRole, PromptsCapability, ProtocolVersion, ReadResourceRequestParam,
        ReadResourceResult, Request, Resource, ResourceContents, ResourcesCapability,
        ServerCapabilities, ServerInfo, SetLevelRequestParam, SubscribeRequestParam, Tool,
        ToolsCapability, UnsubscribeRequestParam, error::ErrorCode,
    };
    use serde_json::json;
    use std::sync::Arc;

    /// In-memory backend with one tool, one resource and one prompt.
    ///
    /// With `should_error` set, every backend operation below except `initialize` and
    /// `list_resource_templates` fails with `MockBackendError::TestError`.
    #[derive(Clone)]
    struct MockBackend {
        server_info: ServerInfo,
        tools: Vec<Tool>,
        resources: Vec<Resource>,
        prompts: Vec<Prompt>,
        should_error: bool,
    }

    impl MockBackend {
        fn new() -> Self {
            Self {
                server_info: ServerInfo {
                    protocol_version: ProtocolVersion::default(),
                    capabilities: ServerCapabilities {
                        tools: Some(ToolsCapability { list_changed: None }),
                        resources: Some(ResourcesCapability {
                            subscribe: Some(true),
                            list_changed: None,
                        }),
                        prompts: Some(PromptsCapability { list_changed: None }),
                        logging: Some(LoggingCapability { level: None }),
                        sampling: None,
                        elicitation: Some(ElicitationCapability {}),
                    },
                    server_info: Implementation {
                        name: "test-server".to_string(),
                        version: "1.0.0".to_string(),
                    },
                    instructions: None,
                },
                tools: vec![Tool {
                    name: "test_tool".to_string(),
                    description: "A test tool".to_string(),
                    input_schema: json!({
                        "type": "object",
                        "properties": {
                            "input": {"type": "string"}
                        }
                    }),
                    output_schema: None,
                }],
                resources: vec![Resource {
                    uri: "test://resource1".to_string(),
                    name: "Test Resource".to_string(),
                    description: Some("A test resource".to_string()),
                    mime_type: Some("text/plain".to_string()),
                    annotations: None,
                    raw: None,
                }],
                prompts: vec![Prompt {
                    name: "test_prompt".to_string(),
                    description: Some("A test prompt".to_string()),
                    arguments: None,
                }],
                should_error: false,
            }
        }

        fn with_error() -> Self {
            Self {
                should_error: true,
                ..Self::new()
            }
        }
    }

    #[async_trait]
    impl McpBackend for MockBackend {
        type Error = MockBackendError;
        type Config = ();

        async fn initialize(_config: Self::Config) -> std::result::Result<Self, Self::Error> {
            Ok(MockBackend::new())
        }

        fn get_server_info(&self) -> ServerInfo {
            self.server_info.clone()
        }

        async fn health_check(&self) -> std::result::Result<(), Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError(
                    "Health check failed".to_string(),
                ));
            }
            Ok(())
        }

        async fn list_tools(
            &self,
            _params: PaginatedRequestParam,
        ) -> std::result::Result<ListToolsResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Simulated error".to_string()));
            }

            Ok(ListToolsResult {
                tools: self.tools.clone(),
                next_cursor: None,
            })
        }

        async fn call_tool(
            &self,
            params: CallToolRequestParam,
        ) -> std::result::Result<CallToolResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Tool call failed".to_string()));
            }

            if params.name == "test_tool" {
                Ok(CallToolResult {
                    content: vec![Content::Text {
                        text: "Tool executed successfully".to_string(),
                    }],
                    is_error: Some(false),
                    structured_content: None,
                })
            } else {
                Err(MockBackendError::TestError("Tool not found".to_string()))
            }
        }

        async fn list_resources(
            &self,
            _params: PaginatedRequestParam,
        ) -> std::result::Result<ListResourcesResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Simulated error".to_string()));
            }

            Ok(ListResourcesResult {
                resources: self.resources.clone(),
                next_cursor: None,
            })
        }

        async fn read_resource(
            &self,
            params: ReadResourceRequestParam,
        ) -> std::result::Result<ReadResourceResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Simulated error".to_string()));
            }

            if params.uri == "test://resource1" {
                Ok(ReadResourceResult {
                    contents: vec![ResourceContents {
                        uri: params.uri,
                        mime_type: Some("text/plain".to_string()),
                        text: Some("Resource content".to_string()),
                        blob: None,
                    }],
                })
            } else {
                Err(MockBackendError::TestError(
                    "Resource not found".to_string(),
                ))
            }
        }

        async fn list_resource_templates(
            &self,
            _params: PaginatedRequestParam,
        ) -> std::result::Result<ListResourceTemplatesResult, Self::Error> {
            Ok(ListResourceTemplatesResult {
                resource_templates: vec![],
                next_cursor: None,
            })
        }

        async fn list_prompts(
            &self,
            _params: PaginatedRequestParam,
        ) -> std::result::Result<ListPromptsResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Simulated error".to_string()));
            }

            Ok(ListPromptsResult {
                prompts: self.prompts.clone(),
                next_cursor: None,
            })
        }

        async fn get_prompt(
            &self,
            params: GetPromptRequestParam,
        ) -> std::result::Result<GetPromptResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Simulated error".to_string()));
            }

            if params.name == "test_prompt" {
                Ok(GetPromptResult {
                    description: Some("A test prompt".to_string()),
                    messages: vec![PromptMessage {
                        role: PromptMessageRole::User,
                        content: PromptMessageContent::Text {
                            text: "Test prompt message".to_string(),
                        },
                    }],
                })
            } else {
                Err(MockBackendError::TestError("Prompt not found".to_string()))
            }
        }

        async fn subscribe(
            &self,
            _params: SubscribeRequestParam,
        ) -> std::result::Result<(), Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Subscribe failed".to_string()));
            }
            Ok(())
        }

        async fn unsubscribe(
            &self,
            _params: UnsubscribeRequestParam,
        ) -> std::result::Result<(), Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError(
                    "Unsubscribe failed".to_string(),
                ));
            }
            Ok(())
        }

        async fn complete(
            &self,
            _params: CompleteRequestParam,
        ) -> std::result::Result<CompleteResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Complete failed".to_string()));
            }

            Ok(CompleteResult {
                completion: vec![
                    CompletionInfo {
                        completion: "completion1".to_string(),
                        has_more: Some(false),
                    },
                    CompletionInfo {
                        completion: "completion2".to_string(),
                        has_more: Some(false),
                    },
                ],
            })
        }

        async fn elicit(
            &self,
            _params: ElicitationRequestParam,
        ) -> std::result::Result<ElicitationResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError(
                    "Elicitation failed".to_string(),
                ));
            }

            // Simulate the user accepting and supplying the requested fields.
            Ok(ElicitationResult::accept(serde_json::json!({
                "name": "Test User",
                "email": "test@example.com"
            })))
        }

        async fn set_level(
            &self,
            _params: SetLevelRequestParam,
        ) -> std::result::Result<(), Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Set level failed".to_string()));
            }
            Ok(())
        }

        async fn handle_custom_method(
            &self,
            method: &str,
            _params: serde_json::Value,
        ) -> std::result::Result<serde_json::Value, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError(
                    "Custom method failed".to_string(),
                ));
            }

            Ok(json!({
                "method": method,
                "result": "custom method executed"
            }))
        }
    }

    #[derive(Debug, thiserror::Error)]
    enum MockBackendError {
        #[error("Test error: {0}")]
        TestError(String),
    }

    impl From<MockBackendError> for Error {
        fn from(err: MockBackendError) -> Self {
            Error::internal_error(err.to_string())
        }
    }

    impl From<crate::backend::BackendError> for MockBackendError {
        fn from(error: crate::backend::BackendError) -> Self {
            MockBackendError::TestError(error.to_string())
        }
    }

    /// Builds a handler around `backend`, with an in-memory auth manager and no middleware.
    async fn handler_with(backend: MockBackend) -> GenericServerHandler<MockBackend> {
        let auth_manager = Arc::new(
            AuthenticationManager::new(AuthConfig::memory())
                .await
                .expect("in-memory auth manager should initialise"),
        );
        GenericServerHandler::new(Arc::new(backend), auth_manager, MiddlewareStack::new())
    }

    async fn create_test_handler() -> GenericServerHandler<MockBackend> {
        handler_with(MockBackend::new()).await
    }

    async fn create_error_handler() -> GenericServerHandler<MockBackend> {
        handler_with(MockBackend::with_error()).await
    }

    /// Sends a JSON-RPC 2.0 request through `handler`.
    ///
    /// With an empty middleware stack, `handle_request` is expected to always return `Ok`.
    async fn send(
        handler: &GenericServerHandler<MockBackend>,
        method: &str,
        params: serde_json::Value,
        id: i64,
    ) -> Response {
        let request = Request {
            jsonrpc: "2.0".to_string(),
            method: method.to_string(),
            params,
            id: json!(id),
        };
        handler
            .handle_request(request)
            .await
            .expect("handle_request only fails on middleware errors")
    }

    /// Asserts a well-formed success response for request `id`.
    fn assert_success(response: &Response, id: i64) {
        assert_eq!(response.jsonrpc, "2.0");
        assert_eq!(response.id, json!(id));
        assert!(response.result.is_some());
        assert!(response.error.is_none());
    }

    /// Asserts a well-formed in-band error response for request `id`.
    fn assert_error_response(response: &Response, id: i64) {
        assert_eq!(response.jsonrpc, "2.0");
        assert_eq!(response.id, json!(id));
        assert!(response.result.is_none());
        assert!(response.error.is_some());
    }

    #[tokio::test]
    async fn test_handler_creation() {
        let handler = create_test_handler().await;
        assert!(!handler.backend.tools.is_empty());
    }

    #[tokio::test]
    async fn test_handle_initialize() {
        let handler = create_test_handler().await;
        let params = json!({
            "protocolVersion": "2024-11-05",
            "capabilities": {},
            "clientInfo": {
                "name": "test-client",
                "version": "1.0.0"
            }
        });

        let response = send(&handler, "initialize", params, 1).await;
        assert_success(&response, 1);

        let result: InitializeResult = serde_json::from_value(response.result.unwrap()).unwrap();
        assert_eq!(
            result.protocol_version,
            pulseengine_mcp_protocol::MCP_VERSION
        );
        assert_eq!(result.server_info.name, "test-server");
    }

    #[tokio::test]
    async fn test_handle_list_tools() {
        let handler = create_test_handler().await;

        let response = send(&handler, "tools/list", json!({}), 2).await;
        assert_success(&response, 2);

        let result: ListToolsResult = serde_json::from_value(response.result.unwrap()).unwrap();
        assert_eq!(result.tools.len(), 1);
        assert_eq!(result.tools[0].name, "test_tool");
    }

    #[tokio::test]
    async fn test_handle_call_tool_success() {
        let handler = create_test_handler().await;
        let params = json!({
            "name": "test_tool",
            "arguments": {
                "input": "test input"
            }
        });

        let response = send(&handler, "tools/call", params, 3).await;
        assert_success(&response, 3);

        let result: CallToolResult = serde_json::from_value(response.result.unwrap()).unwrap();
        assert_eq!(result.content.len(), 1);
        assert!(!result.is_error.unwrap_or(true));
    }

    #[tokio::test]
    async fn test_handle_call_tool_not_found() {
        let handler = create_test_handler().await;
        let params = json!({
            "name": "nonexistent_tool",
            "arguments": {}
        });

        let response = send(&handler, "tools/call", params, 4).await;
        assert_error_response(&response, 4);
    }

    #[tokio::test]
    async fn test_handle_list_resources() {
        let handler = create_test_handler().await;

        let response = send(&handler, "resources/list", json!({}), 5).await;
        assert_success(&response, 5);

        let result: ListResourcesResult = serde_json::from_value(response.result.unwrap()).unwrap();
        assert_eq!(result.resources.len(), 1);
        assert_eq!(result.resources[0].uri, "test://resource1");
    }

    #[tokio::test]
    async fn test_handle_read_resource() {
        let handler = create_test_handler().await;

        let response = send(
            &handler,
            "resources/read",
            json!({"uri": "test://resource1"}),
            6,
        )
        .await;
        assert_success(&response, 6);

        let result: ReadResourceResult = serde_json::from_value(response.result.unwrap()).unwrap();
        assert_eq!(result.contents.len(), 1);
    }

    #[tokio::test]
    async fn test_handle_list_resource_templates() {
        let handler = create_test_handler().await;

        let response = send(&handler, "resources/templates/list", json!({}), 17).await;
        assert_success(&response, 17);

        let result: ListResourceTemplatesResult =
            serde_json::from_value(response.result.unwrap()).unwrap();
        assert!(result.resource_templates.is_empty());
    }

    #[tokio::test]
    async fn test_handle_list_prompts() {
        let handler = create_test_handler().await;

        let response = send(&handler, "prompts/list", json!({}), 7).await;
        assert_success(&response, 7);

        let result: ListPromptsResult = serde_json::from_value(response.result.unwrap()).unwrap();
        assert_eq!(result.prompts.len(), 1);
        assert_eq!(result.prompts[0].name, "test_prompt");
    }

    #[tokio::test]
    async fn test_handle_get_prompt() {
        let handler = create_test_handler().await;
        let params = json!({
            "name": "test_prompt",
            "arguments": {}
        });

        let response = send(&handler, "prompts/get", params, 8).await;
        assert_success(&response, 8);

        let result: GetPromptResult = serde_json::from_value(response.result.unwrap()).unwrap();
        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_handle_subscribe() {
        let handler = create_test_handler().await;

        let response = send(
            &handler,
            "resources/subscribe",
            json!({"uri": "test://resource1"}),
            9,
        )
        .await;
        assert_success(&response, 9);
    }

    #[tokio::test]
    async fn test_handle_unsubscribe() {
        let handler = create_test_handler().await;

        let response = send(
            &handler,
            "resources/unsubscribe",
            json!({"uri": "test://resource1"}),
            10,
        )
        .await;
        assert_success(&response, 10);
    }

    #[tokio::test]
    async fn test_handle_complete() {
        let handler = create_test_handler().await;
        let params = json!({
            "ref_": "test_prompt",
            "argument": {
                "name": "query",
                "value": "test"
            }
        });

        let response = send(&handler, "completion/complete", params, 11).await;
        assert_success(&response, 11);

        let result: CompleteResult = serde_json::from_value(response.result.unwrap()).unwrap();
        assert_eq!(result.completion.len(), 2);
    }

    #[tokio::test]
    async fn test_handle_elicit() {
        let handler = create_test_handler().await;
        let params = json!({
            "message": "Please provide your contact information",
            "requestedSchema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "description": "Your full name"},
                    "email": {"type": "string", "format": "email"}
                },
                "required": ["name", "email"]
            }
        });

        let response = send(&handler, "elicitation/create", params, 12).await;
        assert_success(&response, 12);

        let result: ElicitationResult = serde_json::from_value(response.result.unwrap()).unwrap();
        assert!(matches!(result.response.action, ElicitationAction::Accept));
        assert!(result.response.data.is_some());
    }

    #[tokio::test]
    async fn test_handle_ping() {
        let handler = create_test_handler().await;

        let response = send(&handler, "ping", json!({}), 13).await;
        assert_success(&response, 13);
    }

    #[tokio::test]
    async fn test_handle_custom_method() {
        let handler = create_test_handler().await;

        let response = send(&handler, "custom/method", json!({"test": "data"}), 14).await;
        assert_success(&response, 14);

        let result = response.result.unwrap();
        assert_eq!(result["method"], "custom/method");
    }

    #[tokio::test]
    async fn test_backend_error_handling() {
        let handler = create_error_handler().await;

        let response = send(&handler, "tools/list", json!({}), 15).await;
        assert_error_response(&response, 15);

        let error = response.error.unwrap();
        assert!(error.message.contains("Simulated error"));
    }

    #[tokio::test]
    async fn test_backend_errors_become_error_responses_for_every_method() {
        let handler = create_error_handler().await;
        // The params match the shapes the success-path tests above parse successfully, so
        // each failure comes from the backend rather than from param deserialisation.
        let cases = [
            ("tools/list", json!({})),
            ("tools/call", json!({"name": "test_tool", "arguments": {}})),
            ("resources/list", json!({})),
            ("resources/read", json!({"uri": "test://resource1"})),
            ("prompts/list", json!({})),
            ("prompts/get", json!({"name": "test_prompt", "arguments": {}})),
            ("resources/subscribe", json!({"uri": "test://resource1"})),
            ("resources/unsubscribe", json!({"uri": "test://resource1"})),
            (
                "completion/complete",
                json!({"ref_": "test_prompt", "argument": {"name": "query", "value": "test"}}),
            ),
            (
                "elicitation/create",
                json!({
                    "message": "Please provide your contact information",
                    "requestedSchema": {
                        "type": "object",
                        "properties": {
                            "name": {"type": "string", "description": "Your full name"},
                            "email": {"type": "string", "format": "email"}
                        },
                        "required": ["name", "email"]
                    }
                }),
            ),
            ("custom/method", json!({"test": "data"})),
        ];

        for (id, (method, params)) in (100..).zip(cases) {
            let response = send(&handler, method, params, id).await;
            assert_error_response(&response, id);

            let error = response.error.unwrap();
            assert!(
                error.message.contains("Test error"),
                "{method}: unexpected error message {:?}",
                error.message
            );
        }
    }

    #[tokio::test]
    async fn test_invalid_params() {
        let handler = create_test_handler().await;

        let response = send(&handler, "tools/call", json!("invalid"), 16).await;
        assert_error_response(&response, 16);
    }

    #[test]
    fn test_handler_error_classification() {
        let auth_error = HandlerError::Authentication("Invalid token".to_string());
        assert_eq!(auth_error.error_type(), "authentication");
        assert!(!auth_error.is_retryable());
        assert!(!auth_error.is_timeout());
        assert!(auth_error.is_auth_error());
        assert!(!auth_error.is_connection_error());

        let authz_error = HandlerError::Authorization("Forbidden".to_string());
        assert_eq!(authz_error.error_type(), "authorization");
        assert!(!authz_error.is_retryable());
        assert!(!authz_error.is_timeout());
        assert!(authz_error.is_auth_error());
        assert!(!authz_error.is_connection_error());

        let backend_error = HandlerError::Backend("Database error".to_string());
        assert_eq!(backend_error.error_type(), "backend");
        assert!(backend_error.is_retryable());
        assert!(!backend_error.is_timeout());
        assert!(!backend_error.is_auth_error());
        assert!(!backend_error.is_connection_error());

        let protocol_error =
            HandlerError::Protocol(Error::invalid_request("Bad request".to_string()));
        assert_eq!(protocol_error.error_type(), "protocol");
        assert!(!protocol_error.is_retryable());
        assert!(!protocol_error.is_timeout());
        assert!(!protocol_error.is_auth_error());
        assert!(!protocol_error.is_connection_error());
    }

    #[test]
    fn test_handler_error_conversion() {
        let auth_error = HandlerError::Authentication("Auth failed".to_string());
        let protocol_error: Error = auth_error.into();
        assert_eq!(protocol_error.code, ErrorCode::Unauthorized);

        let backend_error = HandlerError::Backend("Backend failed".to_string());
        let protocol_error: Error = backend_error.into();
        assert_eq!(protocol_error.code, ErrorCode::InternalError);
    }

    #[test]
    fn test_handler_error_display() {
        let error = HandlerError::Authentication("Test auth error".to_string());
        assert_eq!(error.to_string(), "Authentication failed: Test auth error");

        let error = HandlerError::Authorization("Test auth error".to_string());
        assert_eq!(error.to_string(), "Authorization failed: Test auth error");

        let error = HandlerError::Backend("Test backend error".to_string());
        assert_eq!(error.to_string(), "Backend error: Test backend error");
    }
}
1245}