use crate::{backend::McpBackend, context::RequestContext, middleware::MiddlewareStack};
use pulseengine_mcp_auth::AuthenticationManager;
use pulseengine_mcp_logging::{get_metrics, spans};
use pulseengine_mcp_protocol::*;

use std::sync::Arc;
use std::time::Instant;
use thiserror::Error;
use tracing::{Instrument, debug, error, info, instrument};
12
13#[derive(Debug, Error)]
15pub enum HandlerError {
16 #[error("Authentication failed: {0}")]
17 Authentication(String),
18
19 #[error("Authorization failed: {0}")]
20 Authorization(String),
21
22 #[error("Backend error: {0}")]
23 Backend(String),
24
25 #[error("Protocol error: {0}")]
26 Protocol(#[from] Error),
27}
28
29impl pulseengine_mcp_logging::ErrorClassification for HandlerError {
31 fn error_type(&self) -> &str {
32 match self {
33 HandlerError::Authentication(_) => "authentication",
34 HandlerError::Authorization(_) => "authorization",
35 HandlerError::Backend(_) => "backend",
36 HandlerError::Protocol(_) => "protocol",
37 }
38 }
39
40 fn is_retryable(&self) -> bool {
41 match self {
42 HandlerError::Backend(_) => true, _ => false,
44 }
45 }
46
47 fn is_timeout(&self) -> bool {
48 false }
50
51 fn is_auth_error(&self) -> bool {
52 matches!(
53 self,
54 HandlerError::Authentication(_) | HandlerError::Authorization(_)
55 )
56 }
57
58 fn is_connection_error(&self) -> bool {
59 false }
61}
62
63#[derive(Clone)]
65pub struct GenericServerHandler<B: McpBackend> {
66 backend: Arc<B>,
67 #[allow(dead_code)]
68 auth_manager: Arc<AuthenticationManager>,
69 middleware: MiddlewareStack,
70}
71
72impl<B: McpBackend> GenericServerHandler<B> {
73 pub fn new(
75 backend: Arc<B>,
76 auth_manager: Arc<AuthenticationManager>,
77 middleware: MiddlewareStack,
78 ) -> Self {
79 Self {
80 backend,
81 auth_manager,
82 middleware,
83 }
84 }
85
86 #[instrument(skip(self, request), fields(mcp.method = %request.method, mcp.request_id = %request.id))]
88 pub async fn handle_request(
89 &self,
90 request: Request,
91 ) -> std::result::Result<Response, HandlerError> {
92 let start_time = Instant::now();
93 let method = request.method.clone();
94 debug!("Handling request: {}", method);
95
96 let request_id = request.id.clone();
98
99 let context = RequestContext::new();
101
102 let metrics = get_metrics();
104
105 metrics.record_request_start(&method).await;
107
108 let request = self.middleware.process_request(request, &context).await?;
110
111 let result = {
113 let span = spans::mcp_request_span(&method, &request_id.to_string());
114 let _guard = span.enter();
115
116 match request.method.as_str() {
117 "initialize" => self.handle_initialize(request).await,
118 "tools/list" => self.handle_list_tools(request).await,
119 "tools/call" => self.handle_call_tool(request).await,
120 "resources/list" => self.handle_list_resources(request).await,
121 "resources/read" => self.handle_read_resource(request).await,
122 "resources/templates/list" => self.handle_list_resource_templates(request).await,
123 "prompts/list" => self.handle_list_prompts(request).await,
124 "prompts/get" => self.handle_get_prompt(request).await,
125 "resources/subscribe" => self.handle_subscribe(request).await,
126 "resources/unsubscribe" => self.handle_unsubscribe(request).await,
127 "completion/complete" => self.handle_complete(request).await,
128 "elicitation/create" => self.handle_elicit(request).await,
129 "logging/setLevel" => self.handle_set_level(request).await,
130 "ping" => self.handle_ping(request).await,
131 _ => self.handle_custom_method(request).await,
132 }
133 };
134
135 let duration = start_time.elapsed();
137
138 match result {
139 Ok(response) => {
140 metrics.record_request_end(&method, duration, true).await;
142
143 let response = self.middleware.process_response(response, &context).await?;
145
146 info!(
147 method = %method,
148 duration_ms = %duration.as_millis(),
149 request_id = ?request_id,
150 "Request completed successfully"
151 );
152
153 Ok(response)
154 }
155 Err(error) => {
156 metrics.record_request_end(&method, duration, false).await;
158
159 metrics
161 .record_error(&method, &context.request_id.to_string(), &error, duration)
162 .await;
163
164 error!(
165 method = %method,
166 duration_ms = %duration.as_millis(),
167 request_id = ?request_id,
168 error = %error,
169 "Request failed"
170 );
171
172 Ok(Response {
173 jsonrpc: "2.0".to_string(),
174 id: request_id,
175 result: None,
176 error: Some(error),
177 })
178 }
179 }
180 }
181
182 #[instrument(skip(self, request), fields(mcp.method = "initialize"))]
183 async fn handle_initialize(&self, request: Request) -> std::result::Result<Response, Error> {
184 let _params: InitializeRequestParam = serde_json::from_value(request.params)?;
185
186 let server_info = self.backend.get_server_info();
187 let result = InitializeResult {
188 protocol_version: pulseengine_mcp_protocol::MCP_VERSION.to_string(),
189 capabilities: server_info.capabilities,
190 server_info: server_info.server_info.clone(),
191 instructions: server_info.instructions,
192 };
193
194 Ok(Response {
195 jsonrpc: "2.0".to_string(),
196 id: request.id,
197 result: Some(serde_json::to_value(result)?),
198 error: None,
199 })
200 }
201
202 #[instrument(skip(self, request), fields(mcp.method = "tools/list"))]
203 async fn handle_list_tools(&self, request: Request) -> std::result::Result<Response, Error> {
204 let params: PaginatedRequestParam = if request.params.is_null() {
205 PaginatedRequestParam { cursor: None }
206 } else {
207 serde_json::from_value(request.params)?
208 };
209
210 let result = self
211 .backend
212 .list_tools(params)
213 .await
214 .map_err(|e| e.into())?;
215
216 Ok(Response {
217 jsonrpc: "2.0".to_string(),
218 id: request.id,
219 result: Some(serde_json::to_value(result)?),
220 error: None,
221 })
222 }
223
224 #[instrument(skip(self, request), fields(mcp.method = "tools/call"))]
225 async fn handle_call_tool(&self, request: Request) -> std::result::Result<Response, Error> {
226 let params: CallToolRequestParam = serde_json::from_value(request.params)?;
227 let tool_name = params.name.clone();
228 let start_time = Instant::now();
229
230 let metrics = get_metrics();
232 metrics.record_request_start(&tool_name).await;
233
234 let result = {
235 let span = spans::backend_operation_span("call_tool", Some(&tool_name));
236 let _guard = span.enter();
237 match self.backend.call_tool(params).await {
238 Ok(result) => {
239 let duration = start_time.elapsed();
240 metrics.record_request_end(&tool_name, duration, true).await;
241 info!(
242 tool = %tool_name,
243 duration_ms = %duration.as_millis(),
244 "Tool call completed successfully"
245 );
246 result
247 }
248 Err(err) => {
249 let duration = start_time.elapsed();
250 metrics
251 .record_request_end(&tool_name, duration, false)
252 .await;
253 error!(
254 tool = %tool_name,
255 duration_ms = %duration.as_millis(),
256 error = %err,
257 "Tool call failed"
258 );
259 return Err(err.into());
260 }
261 }
262 };
263
264 Ok(Response {
265 jsonrpc: "2.0".to_string(),
266 id: request.id,
267 result: Some(serde_json::to_value(result)?),
268 error: None,
269 })
270 }
271
272 async fn handle_list_resources(
273 &self,
274 request: Request,
275 ) -> std::result::Result<Response, Error> {
276 let params: PaginatedRequestParam = if request.params.is_null() {
277 PaginatedRequestParam { cursor: None }
278 } else {
279 serde_json::from_value(request.params)?
280 };
281
282 let result = self
283 .backend
284 .list_resources(params)
285 .await
286 .map_err(|e| e.into())?;
287
288 Ok(Response {
289 jsonrpc: "2.0".to_string(),
290 id: request.id,
291 result: Some(serde_json::to_value(result)?),
292 error: None,
293 })
294 }
295
296 async fn handle_read_resource(&self, request: Request) -> std::result::Result<Response, Error> {
297 let params: ReadResourceRequestParam = serde_json::from_value(request.params)?;
298
299 let result = self
300 .backend
301 .read_resource(params)
302 .await
303 .map_err(|e| e.into())?;
304
305 Ok(Response {
306 jsonrpc: "2.0".to_string(),
307 id: request.id,
308 result: Some(serde_json::to_value(result)?),
309 error: None,
310 })
311 }
312
313 async fn handle_list_resource_templates(
314 &self,
315 request: Request,
316 ) -> std::result::Result<Response, Error> {
317 let params: PaginatedRequestParam = if request.params.is_null() {
318 PaginatedRequestParam { cursor: None }
319 } else {
320 serde_json::from_value(request.params)?
321 };
322
323 let result = self
324 .backend
325 .list_resource_templates(params)
326 .await
327 .map_err(|e| e.into())?;
328
329 Ok(Response {
330 jsonrpc: "2.0".to_string(),
331 id: request.id,
332 result: Some(serde_json::to_value(result)?),
333 error: None,
334 })
335 }
336
337 async fn handle_list_prompts(&self, request: Request) -> std::result::Result<Response, Error> {
338 let params: PaginatedRequestParam = if request.params.is_null() {
339 PaginatedRequestParam { cursor: None }
340 } else {
341 serde_json::from_value(request.params)?
342 };
343
344 let result = self
345 .backend
346 .list_prompts(params)
347 .await
348 .map_err(|e| e.into())?;
349
350 Ok(Response {
351 jsonrpc: "2.0".to_string(),
352 id: request.id,
353 result: Some(serde_json::to_value(result)?),
354 error: None,
355 })
356 }
357
358 async fn handle_get_prompt(&self, request: Request) -> std::result::Result<Response, Error> {
359 let params: GetPromptRequestParam = serde_json::from_value(request.params)?;
360
361 let result = self
362 .backend
363 .get_prompt(params)
364 .await
365 .map_err(|e| e.into())?;
366
367 Ok(Response {
368 jsonrpc: "2.0".to_string(),
369 id: request.id,
370 result: Some(serde_json::to_value(result)?),
371 error: None,
372 })
373 }
374
375 async fn handle_subscribe(&self, request: Request) -> std::result::Result<Response, Error> {
376 let params: SubscribeRequestParam = serde_json::from_value(request.params)?;
377
378 self.backend.subscribe(params).await.map_err(|e| e.into())?;
379
380 Ok(Response {
381 jsonrpc: "2.0".to_string(),
382 id: request.id,
383 result: Some(serde_json::Value::Object(Default::default())),
384 error: None,
385 })
386 }
387
388 async fn handle_unsubscribe(&self, request: Request) -> std::result::Result<Response, Error> {
389 let params: UnsubscribeRequestParam = serde_json::from_value(request.params)?;
390
391 self.backend
392 .unsubscribe(params)
393 .await
394 .map_err(|e| e.into())?;
395
396 Ok(Response {
397 jsonrpc: "2.0".to_string(),
398 id: request.id,
399 result: Some(serde_json::Value::Object(Default::default())),
400 error: None,
401 })
402 }
403
404 async fn handle_complete(&self, request: Request) -> std::result::Result<Response, Error> {
405 let params: CompleteRequestParam = serde_json::from_value(request.params)?;
406
407 let result = self.backend.complete(params).await.map_err(|e| e.into())?;
408
409 Ok(Response {
410 jsonrpc: "2.0".to_string(),
411 id: request.id,
412 result: Some(serde_json::to_value(result)?),
413 error: None,
414 })
415 }
416
417 async fn handle_elicit(&self, request: Request) -> std::result::Result<Response, Error> {
418 let params: ElicitationRequestParam = serde_json::from_value(request.params)?;
419
420 let result = self.backend.elicit(params).await.map_err(|e| e.into())?;
421
422 Ok(Response {
423 jsonrpc: "2.0".to_string(),
424 id: request.id,
425 result: Some(serde_json::to_value(result)?),
426 error: None,
427 })
428 }
429
430 async fn handle_set_level(&self, request: Request) -> std::result::Result<Response, Error> {
431 let params: SetLevelRequestParam = serde_json::from_value(request.params)?;
432
433 self.backend.set_level(params).await.map_err(|e| e.into())?;
434
435 Ok(Response {
436 jsonrpc: "2.0".to_string(),
437 id: request.id,
438 result: Some(serde_json::Value::Object(Default::default())),
439 error: None,
440 })
441 }
442
443 async fn handle_ping(&self, _request: Request) -> std::result::Result<Response, Error> {
444 Ok(Response {
445 jsonrpc: "2.0".to_string(),
446 id: _request.id,
447 result: Some(serde_json::Value::Object(Default::default())),
448 error: None,
449 })
450 }
451
452 async fn handle_custom_method(&self, request: Request) -> std::result::Result<Response, Error> {
453 let result = self
454 .backend
455 .handle_custom_method(&request.method, request.params)
456 .await
457 .map_err(|e| e.into())?;
458
459 Ok(Response {
460 jsonrpc: "2.0".to_string(),
461 id: request.id,
462 result: Some(result),
463 error: None,
464 })
465 }
466}
467
468impl From<HandlerError> for Error {
470 fn from(err: HandlerError) -> Self {
471 match err {
472 HandlerError::Authentication(msg) => Error::unauthorized(msg),
473 HandlerError::Authorization(msg) => Error::forbidden(msg),
474 HandlerError::Backend(msg) => Error::internal_error(msg),
475 HandlerError::Protocol(e) => e,
476 }
477 }
478}
479
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::McpBackend;
    use crate::middleware::MiddlewareStack;
    use async_trait::async_trait;
    use pulseengine_mcp_auth::AuthenticationManager;
    use pulseengine_mcp_auth::config::AuthConfig;
    use pulseengine_mcp_logging::ErrorClassification;
    use pulseengine_mcp_protocol::{
        CallToolRequestParam, CallToolResult, CompleteRequestParam, CompleteResult, CompletionInfo,
        Content, Error, GetPromptRequestParam, GetPromptResult, Implementation, InitializeResult,
        ListPromptsResult, ListResourceTemplatesResult, ListResourcesResult, ListToolsResult,
        LoggingCapability, PaginatedRequestParam, Prompt, PromptMessage, PromptMessageContent,
        PromptMessageRole, PromptsCapability, ProtocolVersion, ReadResourceRequestParam,
        ReadResourceResult, Request, Resource, ResourceContents, ResourcesCapability,
        ServerCapabilities, ServerInfo, SetLevelRequestParam, SubscribeRequestParam, Tool,
        ToolsCapability, UnsubscribeRequestParam, error::ErrorCode,
    };
    use serde_json::json;
    use std::sync::Arc;

    /// In-memory backend with one tool, one resource and one prompt.
    /// When `should_error` is set, most operations fail with a `MockBackendError`.
    #[derive(Clone)]
    struct MockBackend {
        server_info: ServerInfo,
        tools: Vec<Tool>,
        resources: Vec<Resource>,
        prompts: Vec<Prompt>,
        should_error: bool,
    }

    impl MockBackend {
        fn new() -> Self {
            Self {
                server_info: ServerInfo {
                    protocol_version: ProtocolVersion::default(),
                    capabilities: ServerCapabilities {
                        tools: Some(ToolsCapability { list_changed: None }),
                        resources: Some(ResourcesCapability {
                            subscribe: Some(true),
                            list_changed: None,
                        }),
                        prompts: Some(PromptsCapability { list_changed: None }),
                        logging: Some(LoggingCapability { level: None }),
                        sampling: None,
                        elicitation: Some(ElicitationCapability {}),
                    },
                    server_info: Implementation {
                        name: "test-server".to_string(),
                        version: "1.0.0".to_string(),
                    },
                    instructions: None,
                },
                tools: vec![Tool {
                    name: "test_tool".to_string(),
                    description: "A test tool".to_string(),
                    input_schema: json!({
                        "type": "object",
                        "properties": {
                            "input": {"type": "string"}
                        }
                    }),
                    output_schema: None,
                }],
                resources: vec![Resource {
                    uri: "test://resource1".to_string(),
                    name: "Test Resource".to_string(),
                    description: Some("A test resource".to_string()),
                    mime_type: Some("text/plain".to_string()),
                    annotations: None,
                    raw: None,
                }],
                prompts: vec![Prompt {
                    name: "test_prompt".to_string(),
                    description: Some("A test prompt".to_string()),
                    arguments: None,
                }],
                should_error: false,
            }
        }

        /// Same fixtures as `new`, but with error simulation enabled.
        fn with_error() -> Self {
            Self {
                should_error: true,
                ..Self::new()
            }
        }
    }

    #[async_trait]
    impl McpBackend for MockBackend {
        type Error = MockBackendError;
        type Config = ();

        async fn initialize(_config: Self::Config) -> std::result::Result<Self, Self::Error> {
            Ok(MockBackend::new())
        }

        fn get_server_info(&self) -> ServerInfo {
            self.server_info.clone()
        }

        async fn health_check(&self) -> std::result::Result<(), Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError(
                    "Health check failed".to_string(),
                ));
            }
            Ok(())
        }

        async fn list_tools(
            &self,
            _params: PaginatedRequestParam,
        ) -> std::result::Result<ListToolsResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Simulated error".to_string()));
            }

            Ok(ListToolsResult {
                tools: self.tools.clone(),
                next_cursor: None,
            })
        }

        async fn call_tool(
            &self,
            params: CallToolRequestParam,
        ) -> std::result::Result<CallToolResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Tool call failed".to_string()));
            }

            if params.name == "test_tool" {
                Ok(CallToolResult {
                    content: vec![Content::Text {
                        text: "Tool executed successfully".to_string(),
                    }],
                    is_error: Some(false),
                    structured_content: None,
                })
            } else {
                Err(MockBackendError::TestError("Tool not found".to_string()))
            }
        }

        async fn list_resources(
            &self,
            _params: PaginatedRequestParam,
        ) -> std::result::Result<ListResourcesResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Simulated error".to_string()));
            }

            Ok(ListResourcesResult {
                resources: self.resources.clone(),
                next_cursor: None,
            })
        }

        async fn read_resource(
            &self,
            params: ReadResourceRequestParam,
        ) -> std::result::Result<ReadResourceResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Simulated error".to_string()));
            }

            if params.uri == "test://resource1" {
                Ok(ReadResourceResult {
                    contents: vec![ResourceContents {
                        uri: params.uri,
                        mime_type: Some("text/plain".to_string()),
                        text: Some("Resource content".to_string()),
                        blob: None,
                    }],
                })
            } else {
                Err(MockBackendError::TestError(
                    "Resource not found".to_string(),
                ))
            }
        }

        async fn list_resource_templates(
            &self,
            _params: PaginatedRequestParam,
        ) -> std::result::Result<ListResourceTemplatesResult, Self::Error> {
            Ok(ListResourceTemplatesResult {
                resource_templates: vec![],
                next_cursor: None,
            })
        }

        async fn list_prompts(
            &self,
            _params: PaginatedRequestParam,
        ) -> std::result::Result<ListPromptsResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Simulated error".to_string()));
            }

            Ok(ListPromptsResult {
                prompts: self.prompts.clone(),
                next_cursor: None,
            })
        }

        async fn get_prompt(
            &self,
            params: GetPromptRequestParam,
        ) -> std::result::Result<GetPromptResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Simulated error".to_string()));
            }

            if params.name == "test_prompt" {
                Ok(GetPromptResult {
                    description: Some("A test prompt".to_string()),
                    messages: vec![PromptMessage {
                        role: PromptMessageRole::User,
                        content: PromptMessageContent::Text {
                            text: "Test prompt message".to_string(),
                        },
                    }],
                })
            } else {
                Err(MockBackendError::TestError("Prompt not found".to_string()))
            }
        }

        async fn subscribe(
            &self,
            _params: SubscribeRequestParam,
        ) -> std::result::Result<(), Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Subscribe failed".to_string()));
            }
            Ok(())
        }

        async fn unsubscribe(
            &self,
            _params: UnsubscribeRequestParam,
        ) -> std::result::Result<(), Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError(
                    "Unsubscribe failed".to_string(),
                ));
            }
            Ok(())
        }

        async fn complete(
            &self,
            _params: CompleteRequestParam,
        ) -> std::result::Result<CompleteResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Complete failed".to_string()));
            }

            Ok(CompleteResult {
                completion: vec![
                    CompletionInfo {
                        completion: "completion1".to_string(),
                        has_more: Some(false),
                    },
                    CompletionInfo {
                        completion: "completion2".to_string(),
                        has_more: Some(false),
                    },
                ],
            })
        }

        async fn elicit(
            &self,
            _params: ElicitationRequestParam,
        ) -> std::result::Result<ElicitationResult, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError(
                    "Elicitation failed".to_string(),
                ));
            }

            Ok(ElicitationResult::accept(serde_json::json!({
                "name": "Test User",
                "email": "test@example.com"
            })))
        }

        async fn set_level(
            &self,
            _params: SetLevelRequestParam,
        ) -> std::result::Result<(), Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError("Set level failed".to_string()));
            }
            Ok(())
        }

        async fn handle_custom_method(
            &self,
            method: &str,
            _params: serde_json::Value,
        ) -> std::result::Result<serde_json::Value, Self::Error> {
            if self.should_error {
                return Err(MockBackendError::TestError(
                    "Custom method failed".to_string(),
                ));
            }

            Ok(json!({
                "method": method,
                "result": "custom method executed"
            }))
        }
    }

    #[derive(Debug, thiserror::Error)]
    enum MockBackendError {
        #[error("Test error: {0}")]
        TestError(String),
    }

    impl From<MockBackendError> for Error {
        fn from(err: MockBackendError) -> Self {
            Error::internal_error(err.to_string())
        }
    }

    impl From<crate::backend::BackendError> for MockBackendError {
        fn from(error: crate::backend::BackendError) -> Self {
            MockBackendError::TestError(error.to_string())
        }
    }

    /// Wraps `backend` in a handler with an in-memory auth manager and an empty middleware
    /// stack.
    async fn handler_with(backend: MockBackend) -> GenericServerHandler<MockBackend> {
        let auth_manager = Arc::new(
            AuthenticationManager::new(AuthConfig::memory())
                .await
                .expect("in-memory auth manager should initialise"),
        );
        GenericServerHandler::new(Arc::new(backend), auth_manager, MiddlewareStack::new())
    }

    async fn create_test_handler() -> GenericServerHandler<MockBackend> {
        handler_with(MockBackend::new()).await
    }

    async fn create_error_handler() -> GenericServerHandler<MockBackend> {
        handler_with(MockBackend::with_error()).await
    }

    /// Builds a JSON-RPC 2.0 request.
    fn make_request(method: &str, params: serde_json::Value, id: serde_json::Value) -> Request {
        Request {
            jsonrpc: "2.0".to_string(),
            method: method.to_string(),
            params,
            id,
        }
    }

    /// Asserts `response` is a successful JSON-RPC 2.0 reply to `id` and returns its result.
    fn expect_success(response: Response, id: serde_json::Value) -> serde_json::Value {
        assert_eq!(response.jsonrpc, "2.0");
        assert_eq!(response.id, id);
        assert!(response.error.is_none());
        response
            .result
            .expect("successful response should carry a result")
    }

    /// Asserts `response` is a failed JSON-RPC 2.0 reply to `id` and returns its error.
    fn expect_failure(response: Response, id: serde_json::Value) -> Error {
        assert_eq!(response.jsonrpc, "2.0");
        assert_eq!(response.id, id);
        assert!(response.result.is_none());
        response
            .error
            .expect("failed response should carry an error")
    }

    #[tokio::test]
    async fn test_handler_creation() {
        let handler = create_test_handler().await;
        assert!(!handler.backend.tools.is_empty());
    }

    #[tokio::test]
    async fn test_handle_initialize() {
        let handler = create_test_handler().await;
        let params = json!({
            "protocolVersion": "2024-11-05",
            "capabilities": {},
            "clientInfo": {
                "name": "test-client",
                "version": "1.0.0"
            }
        });

        let response = handler
            .handle_request(make_request("initialize", params, json!(1)))
            .await
            .unwrap();

        let result: InitializeResult =
            serde_json::from_value(expect_success(response, json!(1))).unwrap();
        assert_eq!(
            result.protocol_version,
            pulseengine_mcp_protocol::MCP_VERSION
        );
        assert_eq!(result.server_info.name, "test-server");
    }

    #[tokio::test]
    async fn test_handle_list_tools() {
        let handler = create_test_handler().await;

        let response = handler
            .handle_request(make_request("tools/list", json!({}), json!(2)))
            .await
            .unwrap();

        let result: ListToolsResult =
            serde_json::from_value(expect_success(response, json!(2))).unwrap();
        assert_eq!(result.tools.len(), 1);
        assert_eq!(result.tools[0].name, "test_tool");
    }

    #[tokio::test]
    async fn test_handle_list_tools_without_params() {
        let handler = create_test_handler().await;

        let response = handler
            .handle_request(make_request(
                "tools/list",
                serde_json::Value::Null,
                json!(16),
            ))
            .await
            .unwrap();

        let result: ListToolsResult =
            serde_json::from_value(expect_success(response, json!(16))).unwrap();
        assert_eq!(result.tools.len(), 1);
    }

    #[tokio::test]
    async fn test_handle_call_tool_success() {
        let handler = create_test_handler().await;
        let params = json!({
            "name": "test_tool",
            "arguments": {
                "input": "test input"
            }
        });

        let response = handler
            .handle_request(make_request("tools/call", params, json!(3)))
            .await
            .unwrap();

        let result: CallToolResult =
            serde_json::from_value(expect_success(response, json!(3))).unwrap();
        assert_eq!(result.content.len(), 1);
        assert!(!result.is_error.unwrap_or(true));
    }

    #[tokio::test]
    async fn test_handle_call_tool_not_found() {
        let handler = create_test_handler().await;
        let params = json!({
            "name": "nonexistent_tool",
            "arguments": {}
        });

        let response = handler
            .handle_request(make_request("tools/call", params, json!(4)))
            .await
            .unwrap();

        let error = expect_failure(response, json!(4));
        assert!(error.message.contains("Tool not found"));
    }

    #[tokio::test]
    async fn test_handle_list_resources() {
        let handler = create_test_handler().await;

        let response = handler
            .handle_request(make_request("resources/list", json!({}), json!(5)))
            .await
            .unwrap();

        let result: ListResourcesResult =
            serde_json::from_value(expect_success(response, json!(5))).unwrap();
        assert_eq!(result.resources.len(), 1);
        assert_eq!(result.resources[0].uri, "test://resource1");
    }

    #[tokio::test]
    async fn test_handle_read_resource() {
        let handler = create_test_handler().await;
        let params = json!({
            "uri": "test://resource1"
        });

        let response = handler
            .handle_request(make_request("resources/read", params, json!(6)))
            .await
            .unwrap();

        let result: ReadResourceResult =
            serde_json::from_value(expect_success(response, json!(6))).unwrap();
        assert_eq!(result.contents.len(), 1);
    }

    #[tokio::test]
    async fn test_handle_list_prompts() {
        let handler = create_test_handler().await;

        let response = handler
            .handle_request(make_request("prompts/list", json!({}), json!(7)))
            .await
            .unwrap();

        let result: ListPromptsResult =
            serde_json::from_value(expect_success(response, json!(7))).unwrap();
        assert_eq!(result.prompts.len(), 1);
        assert_eq!(result.prompts[0].name, "test_prompt");
    }

    #[tokio::test]
    async fn test_handle_get_prompt() {
        let handler = create_test_handler().await;
        let params = json!({
            "name": "test_prompt",
            "arguments": {}
        });

        let response = handler
            .handle_request(make_request("prompts/get", params, json!(8)))
            .await
            .unwrap();

        let result: GetPromptResult =
            serde_json::from_value(expect_success(response, json!(8))).unwrap();
        assert_eq!(result.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_handle_subscribe() {
        let handler = create_test_handler().await;
        let params = json!({
            "uri": "test://resource1"
        });

        let response = handler
            .handle_request(make_request("resources/subscribe", params, json!(9)))
            .await
            .unwrap();

        expect_success(response, json!(9));
    }

    #[tokio::test]
    async fn test_handle_unsubscribe() {
        let handler = create_test_handler().await;
        let params = json!({
            "uri": "test://resource1"
        });

        let response = handler
            .handle_request(make_request("resources/unsubscribe", params, json!(10)))
            .await
            .unwrap();

        expect_success(response, json!(10));
    }

    #[tokio::test]
    async fn test_handle_complete() {
        let handler = create_test_handler().await;
        let params = json!({
            "ref_": "test_prompt",
            "argument": {
                "name": "query",
                "value": "test"
            }
        });

        let response = handler
            .handle_request(make_request("completion/complete", params, json!(11)))
            .await
            .unwrap();

        let result: CompleteResult =
            serde_json::from_value(expect_success(response, json!(11))).unwrap();
        assert_eq!(result.completion.len(), 2);
    }

    #[tokio::test]
    async fn test_handle_elicit() {
        let handler = create_test_handler().await;
        let params = json!({
            "message": "Please provide your contact information",
            "requestedSchema": {
                "type": "object",
                "properties": {
                    "name": {"type": "string", "description": "Your full name"},
                    "email": {"type": "string", "format": "email"}
                },
                "required": ["name", "email"]
            }
        });

        let response = handler
            .handle_request(make_request("elicitation/create", params, json!(12)))
            .await
            .unwrap();

        let result: ElicitationResult =
            serde_json::from_value(expect_success(response, json!(12))).unwrap();
        assert!(matches!(result.response.action, ElicitationAction::Accept));
        assert!(result.response.data.is_some());
    }

    #[tokio::test]
    async fn test_handle_ping() {
        let handler = create_test_handler().await;

        let response = handler
            .handle_request(make_request("ping", json!({}), json!(12)))
            .await
            .unwrap();

        expect_success(response, json!(12));
    }

    #[tokio::test]
    async fn test_handle_custom_method() {
        let handler = create_test_handler().await;

        let response = handler
            .handle_request(make_request("custom/method", json!({"test": "data"}), json!(13)))
            .await
            .unwrap();

        let result = expect_success(response, json!(13));
        assert_eq!(result["method"], "custom/method");
    }

    #[tokio::test]
    async fn test_backend_error_handling() {
        let handler = create_error_handler().await;

        let response = handler
            .handle_request(make_request("tools/list", json!({}), json!(14)))
            .await
            .unwrap();

        let error = expect_failure(response, json!(14));
        assert!(error.message.contains("Simulated error"));
    }

    #[tokio::test]
    async fn test_invalid_params() {
        let handler = create_test_handler().await;

        // A bare string cannot be deserialized into `CallToolRequestParam`.
        let response = handler
            .handle_request(make_request("tools/call", json!("invalid"), json!(15)))
            .await
            .unwrap();

        expect_failure(response, json!(15));
    }

    #[test]
    fn test_handler_error_classification() {
        let auth_error = HandlerError::Authentication("Invalid token".to_string());
        assert_eq!(auth_error.error_type(), "authentication");
        assert!(!auth_error.is_retryable());
        assert!(!auth_error.is_timeout());
        assert!(auth_error.is_auth_error());
        assert!(!auth_error.is_connection_error());

        let authz_error = HandlerError::Authorization("Not allowed".to_string());
        assert_eq!(authz_error.error_type(), "authorization");
        assert!(!authz_error.is_retryable());
        assert!(!authz_error.is_timeout());
        assert!(authz_error.is_auth_error());
        assert!(!authz_error.is_connection_error());

        let backend_error = HandlerError::Backend("Database error".to_string());
        assert_eq!(backend_error.error_type(), "backend");
        assert!(backend_error.is_retryable());
        assert!(!backend_error.is_timeout());
        assert!(!backend_error.is_auth_error());
        assert!(!backend_error.is_connection_error());

        let protocol_error =
            HandlerError::Protocol(Error::invalid_request("Bad request".to_string()));
        assert_eq!(protocol_error.error_type(), "protocol");
        assert!(!protocol_error.is_retryable());
        assert!(!protocol_error.is_timeout());
        assert!(!protocol_error.is_auth_error());
        assert!(!protocol_error.is_connection_error());
    }

    #[test]
    fn test_handler_error_conversion() {
        let auth_error = HandlerError::Authentication("Auth failed".to_string());
        let protocol_error: Error = auth_error.into();
        assert_eq!(protocol_error.code, ErrorCode::Unauthorized);

        let backend_error = HandlerError::Backend("Backend failed".to_string());
        let protocol_error: Error = backend_error.into();
        assert_eq!(protocol_error.code, ErrorCode::InternalError);
    }

    #[test]
    fn test_handler_error_display() {
        let error = HandlerError::Authentication("Test auth error".to_string());
        assert_eq!(error.to_string(), "Authentication failed: Test auth error");

        let error = HandlerError::Authorization("Test auth error".to_string());
        assert_eq!(error.to_string(), "Authorization failed: Test auth error");

        let error = HandlerError::Backend("Test backend error".to_string());
        assert_eq!(error.to_string(), "Backend error: Test backend error");
    }
}
1261}