use crate::{backend::McpBackend, context::RequestContext, middleware::MiddlewareStack};
use pulseengine_mcp_auth::AuthenticationManager;
use pulseengine_mcp_logging::{get_metrics, spans};
use pulseengine_mcp_protocol::*;

use std::sync::Arc;
use std::time::Instant;
use thiserror::Error;
use tracing::{Instrument, debug, error, info, instrument};
12
/// Error type returned by [`GenericServerHandler::handle_request`].
///
/// Converts into the protocol [`Error`] through the `From<HandlerError>` impl in
/// this module: `Authentication` maps to `unauthorized`, `Authorization` to
/// `forbidden`, `Backend` to `internal_error`, and `Protocol` is unwrapped as-is.
#[derive(Debug, Error)]
pub enum HandlerError {
    /// Authentication failure, carrying a descriptive message.
    #[error("Authentication failed: {0}")]
    Authentication(String),

    /// Authorization failure, carrying a descriptive message.
    #[error("Authorization failed: {0}")]
    Authorization(String),

    /// Failure reported by the backend, carrying a descriptive message.
    #[error("Backend error: {0}")]
    Backend(String),

    /// A protocol-level error; `#[from]` lets `?` convert `Error` into this variant.
    #[error("Protocol error: {0}")]
    Protocol(#[from] Error),
}
28
29impl pulseengine_mcp_logging::ErrorClassification for HandlerError {
31 fn error_type(&self) -> &str {
32 match self {
33 HandlerError::Authentication(_) => "authentication",
34 HandlerError::Authorization(_) => "authorization",
35 HandlerError::Backend(_) => "backend",
36 HandlerError::Protocol(_) => "protocol",
37 }
38 }
39
40 fn is_retryable(&self) -> bool {
41 match self {
42 HandlerError::Backend(_) => true, _ => false,
44 }
45 }
46
47 fn is_timeout(&self) -> bool {
48 false }
50
51 fn is_auth_error(&self) -> bool {
52 matches!(
53 self,
54 HandlerError::Authentication(_) | HandlerError::Authorization(_)
55 )
56 }
57
58 fn is_connection_error(&self) -> bool {
59 false }
61}
62
63#[derive(Clone)]
65pub struct GenericServerHandler<B: McpBackend> {
66 backend: Arc<B>,
67 #[allow(dead_code)]
68 auth_manager: Arc<AuthenticationManager>,
69 middleware: MiddlewareStack,
70}
71
72impl<B: McpBackend> GenericServerHandler<B> {
73 pub fn new(
75 backend: Arc<B>,
76 auth_manager: Arc<AuthenticationManager>,
77 middleware: MiddlewareStack,
78 ) -> Self {
79 Self {
80 backend,
81 auth_manager,
82 middleware,
83 }
84 }
85
86 #[instrument(skip(self, request), fields(mcp.method = %request.method, mcp.request_id = ?request.id))]
88 pub async fn handle_request(
89 &self,
90 request: Request,
91 ) -> std::result::Result<Response, HandlerError> {
92 let start_time = Instant::now();
93 let method = request.method.clone();
94 debug!("Handling request: {}", method);
95
96 let request_id = request.id.clone();
98
99 let context = RequestContext::new();
101
102 let metrics = get_metrics();
104
105 metrics.record_request_start(&method).await;
107
108 let request = self.middleware.process_request(request, &context).await?;
110
111 let result = {
113 let request_id_str = request_id
114 .as_ref()
115 .map(|id| id.to_string())
116 .unwrap_or_else(|| "none".to_string());
117 let span = spans::mcp_request_span(&method, &request_id_str);
118 let _guard = span.enter();
119
120 match request.method.as_str() {
121 "initialize" => self.handle_initialize(request).await,
122 "tools/list" => self.handle_list_tools(request).await,
123 "tools/call" => self.handle_call_tool(request).await,
124 "resources/list" => self.handle_list_resources(request).await,
125 "resources/read" => self.handle_read_resource(request).await,
126 "resources/templates/list" => self.handle_list_resource_templates(request).await,
127 "prompts/list" => self.handle_list_prompts(request).await,
128 "prompts/get" => self.handle_get_prompt(request).await,
129 "resources/subscribe" => self.handle_subscribe(request).await,
130 "resources/unsubscribe" => self.handle_unsubscribe(request).await,
131 "completion/complete" => self.handle_complete(request).await,
132 "elicitation/create" => self.handle_elicit(request).await,
133 "logging/setLevel" => self.handle_set_level(request).await,
134 "ping" => self.handle_ping(request).await,
135 _ => self.handle_custom_method(request).await,
136 }
137 };
138
139 let duration = start_time.elapsed();
141
142 match result {
143 Ok(response) => {
144 metrics.record_request_end(&method, duration, true).await;
146
147 let response = self.middleware.process_response(response, &context).await?;
149
150 info!(
151 method = %method,
152 duration_ms = %duration.as_millis(),
153 request_id = ?request_id,
154 "Request completed successfully"
155 );
156
157 Ok(response)
158 }
159 Err(error) => {
160 metrics.record_request_end(&method, duration, false).await;
162
163 metrics
165 .record_error(&method, &context.request_id.to_string(), &error, duration)
166 .await;
167
168 error!(
169 method = %method,
170 duration_ms = %duration.as_millis(),
171 request_id = ?request_id,
172 error = %error,
173 "Request failed"
174 );
175
176 Ok(Response {
177 jsonrpc: "2.0".to_string(),
178 id: request_id,
179 result: None,
180 error: Some(error),
181 })
182 }
183 }
184 }
185
186 #[instrument(skip(self, request), fields(mcp.method = "initialize"))]
187 async fn handle_initialize(&self, request: Request) -> std::result::Result<Response, Error> {
188 let _params: InitializeRequestParam = serde_json::from_value(request.params)?;
189
190 let server_info = self.backend.get_server_info();
191 let result = InitializeResult {
192 protocol_version: pulseengine_mcp_protocol::MCP_VERSION.to_string(),
193 capabilities: server_info.capabilities,
194 server_info: server_info.server_info.clone(),
195 instructions: server_info.instructions,
196 };
197
198 Ok(Response {
199 jsonrpc: "2.0".to_string(),
200 id: request.id,
201 result: Some(serde_json::to_value(result)?),
202 error: None,
203 })
204 }
205
206 #[instrument(skip(self, request), fields(mcp.method = "tools/list"))]
207 async fn handle_list_tools(&self, request: Request) -> std::result::Result<Response, Error> {
208 let params: PaginatedRequestParam = if request.params.is_null() {
209 PaginatedRequestParam { cursor: None }
210 } else {
211 serde_json::from_value(request.params)?
212 };
213
214 let result = self
215 .backend
216 .list_tools(params)
217 .await
218 .map_err(|e| e.into())?;
219
220 Ok(Response {
221 jsonrpc: "2.0".to_string(),
222 id: request.id,
223 result: Some(serde_json::to_value(result)?),
224 error: None,
225 })
226 }
227
228 #[instrument(skip(self, request), fields(mcp.method = "tools/call"))]
229 async fn handle_call_tool(&self, request: Request) -> std::result::Result<Response, Error> {
230 let params: CallToolRequestParam = serde_json::from_value(request.params)?;
231 let tool_name = params.name.clone();
232 let start_time = Instant::now();
233
234 let metrics = get_metrics();
236 metrics.record_request_start(&tool_name).await;
237
238 let result = {
239 let span = spans::backend_operation_span("call_tool", Some(&tool_name));
240 let _guard = span.enter();
241 match self.backend.call_tool(params).await {
242 Ok(result) => {
243 let duration = start_time.elapsed();
244 metrics.record_request_end(&tool_name, duration, true).await;
245 info!(
246 tool = %tool_name,
247 duration_ms = %duration.as_millis(),
248 "Tool call completed successfully"
249 );
250 result
251 }
252 Err(err) => {
253 let duration = start_time.elapsed();
254 metrics
255 .record_request_end(&tool_name, duration, false)
256 .await;
257 error!(
258 tool = %tool_name,
259 duration_ms = %duration.as_millis(),
260 error = %err,
261 "Tool call failed"
262 );
263 return Err(err.into());
264 }
265 }
266 };
267
268 Ok(Response {
269 jsonrpc: "2.0".to_string(),
270 id: request.id,
271 result: Some(serde_json::to_value(result)?),
272 error: None,
273 })
274 }
275
276 async fn handle_list_resources(
277 &self,
278 request: Request,
279 ) -> std::result::Result<Response, Error> {
280 let params: PaginatedRequestParam = if request.params.is_null() {
281 PaginatedRequestParam { cursor: None }
282 } else {
283 serde_json::from_value(request.params)?
284 };
285
286 let result = self
287 .backend
288 .list_resources(params)
289 .await
290 .map_err(|e| e.into())?;
291
292 Ok(Response {
293 jsonrpc: "2.0".to_string(),
294 id: request.id,
295 result: Some(serde_json::to_value(result)?),
296 error: None,
297 })
298 }
299
300 async fn handle_read_resource(&self, request: Request) -> std::result::Result<Response, Error> {
301 let params: ReadResourceRequestParam = serde_json::from_value(request.params)?;
302
303 let result = self
304 .backend
305 .read_resource(params)
306 .await
307 .map_err(|e| e.into())?;
308
309 Ok(Response {
310 jsonrpc: "2.0".to_string(),
311 id: request.id,
312 result: Some(serde_json::to_value(result)?),
313 error: None,
314 })
315 }
316
317 async fn handle_list_resource_templates(
318 &self,
319 request: Request,
320 ) -> std::result::Result<Response, Error> {
321 let params: PaginatedRequestParam = if request.params.is_null() {
322 PaginatedRequestParam { cursor: None }
323 } else {
324 serde_json::from_value(request.params)?
325 };
326
327 let result = self
328 .backend
329 .list_resource_templates(params)
330 .await
331 .map_err(|e| e.into())?;
332
333 Ok(Response {
334 jsonrpc: "2.0".to_string(),
335 id: request.id,
336 result: Some(serde_json::to_value(result)?),
337 error: None,
338 })
339 }
340
341 async fn handle_list_prompts(&self, request: Request) -> std::result::Result<Response, Error> {
342 let params: PaginatedRequestParam = if request.params.is_null() {
343 PaginatedRequestParam { cursor: None }
344 } else {
345 serde_json::from_value(request.params)?
346 };
347
348 let result = self
349 .backend
350 .list_prompts(params)
351 .await
352 .map_err(|e| e.into())?;
353
354 Ok(Response {
355 jsonrpc: "2.0".to_string(),
356 id: request.id,
357 result: Some(serde_json::to_value(result)?),
358 error: None,
359 })
360 }
361
362 async fn handle_get_prompt(&self, request: Request) -> std::result::Result<Response, Error> {
363 let params: GetPromptRequestParam = serde_json::from_value(request.params)?;
364
365 let result = self
366 .backend
367 .get_prompt(params)
368 .await
369 .map_err(|e| e.into())?;
370
371 Ok(Response {
372 jsonrpc: "2.0".to_string(),
373 id: request.id,
374 result: Some(serde_json::to_value(result)?),
375 error: None,
376 })
377 }
378
379 async fn handle_subscribe(&self, request: Request) -> std::result::Result<Response, Error> {
380 let params: SubscribeRequestParam = serde_json::from_value(request.params)?;
381
382 self.backend.subscribe(params).await.map_err(|e| e.into())?;
383
384 Ok(Response {
385 jsonrpc: "2.0".to_string(),
386 id: request.id,
387 result: Some(serde_json::Value::Object(Default::default())),
388 error: None,
389 })
390 }
391
392 async fn handle_unsubscribe(&self, request: Request) -> std::result::Result<Response, Error> {
393 let params: UnsubscribeRequestParam = serde_json::from_value(request.params)?;
394
395 self.backend
396 .unsubscribe(params)
397 .await
398 .map_err(|e| e.into())?;
399
400 Ok(Response {
401 jsonrpc: "2.0".to_string(),
402 id: request.id,
403 result: Some(serde_json::Value::Object(Default::default())),
404 error: None,
405 })
406 }
407
408 async fn handle_complete(&self, request: Request) -> std::result::Result<Response, Error> {
409 let params: CompleteRequestParam = serde_json::from_value(request.params)?;
410
411 let result = self.backend.complete(params).await.map_err(|e| e.into())?;
412
413 Ok(Response {
414 jsonrpc: "2.0".to_string(),
415 id: request.id,
416 result: Some(serde_json::to_value(result)?),
417 error: None,
418 })
419 }
420
421 async fn handle_elicit(&self, request: Request) -> std::result::Result<Response, Error> {
422 let params: ElicitationRequestParam = serde_json::from_value(request.params)?;
423
424 let result = self.backend.elicit(params).await.map_err(|e| e.into())?;
425
426 Ok(Response {
427 jsonrpc: "2.0".to_string(),
428 id: request.id,
429 result: Some(serde_json::to_value(result)?),
430 error: None,
431 })
432 }
433
434 async fn handle_set_level(&self, request: Request) -> std::result::Result<Response, Error> {
435 let params: SetLevelRequestParam = serde_json::from_value(request.params)?;
436
437 self.backend.set_level(params).await.map_err(|e| e.into())?;
438
439 Ok(Response {
440 jsonrpc: "2.0".to_string(),
441 id: request.id,
442 result: Some(serde_json::Value::Object(Default::default())),
443 error: None,
444 })
445 }
446
447 async fn handle_ping(&self, _request: Request) -> std::result::Result<Response, Error> {
448 Ok(Response {
449 jsonrpc: "2.0".to_string(),
450 id: _request.id,
451 result: Some(serde_json::Value::Object(Default::default())),
452 error: None,
453 })
454 }
455
456 async fn handle_custom_method(&self, request: Request) -> std::result::Result<Response, Error> {
457 let result = self
458 .backend
459 .handle_custom_method(&request.method, request.params)
460 .await
461 .map_err(|e| e.into())?;
462
463 Ok(Response {
464 jsonrpc: "2.0".to_string(),
465 id: request.id,
466 result: Some(result),
467 error: None,
468 })
469 }
470}
471
472impl From<HandlerError> for Error {
474 fn from(err: HandlerError) -> Self {
475 match err {
476 HandlerError::Authentication(msg) => Error::unauthorized(msg),
477 HandlerError::Authorization(msg) => Error::forbidden(msg),
478 HandlerError::Backend(msg) => Error::internal_error(msg),
479 HandlerError::Protocol(e) => e,
480 }
481 }
482}
483
484#[cfg(test)]
485mod tests {
486 use super::*;
487 use crate::backend::McpBackend;
488 use crate::middleware::MiddlewareStack;
489 use async_trait::async_trait;
490 use pulseengine_mcp_auth::AuthenticationManager;
491 use pulseengine_mcp_auth::config::AuthConfig;
492 use pulseengine_mcp_logging::ErrorClassification;
493 use pulseengine_mcp_protocol::{
494 CallToolRequestParam, CallToolResult, CompleteRequestParam, CompleteResult, CompletionInfo,
495 Content, Error, GetPromptRequestParam, GetPromptResult, Implementation, InitializeResult,
496 ListPromptsResult, ListResourceTemplatesResult, ListResourcesResult, ListToolsResult,
497 LoggingCapability, PaginatedRequestParam, Prompt, PromptMessage, PromptMessageContent,
498 PromptMessageRole, PromptsCapability, ProtocolVersion, ReadResourceRequestParam,
499 ReadResourceResult, Request, Resource, ResourceContents, ResourcesCapability,
500 ServerCapabilities, ServerInfo, SetLevelRequestParam, SubscribeRequestParam, Tool,
501 ToolsCapability, UnsubscribeRequestParam, error::ErrorCode,
502 };
503 use serde_json::json;
504 use std::sync::Arc;
505
/// In-memory `McpBackend` used by the handler tests.
///
/// Serves a fixed set of tools, resources and prompts. When `should_error` is
/// set, every backend method that checks the flag returns
/// `MockBackendError::TestError` instead of succeeding.
#[derive(Clone)]
struct MockBackend {
    server_info: ServerInfo,
    tools: Vec<Tool>,
    resources: Vec<Resource>,
    prompts: Vec<Prompt>,
    // Toggles the simulated-failure paths in the `McpBackend` impl.
    should_error: bool,
}
515
516 impl MockBackend {
517 fn new() -> Self {
518 Self {
519 server_info: ServerInfo {
520 protocol_version: ProtocolVersion::default(),
521 capabilities: ServerCapabilities {
522 tools: Some(ToolsCapability { list_changed: None }),
523 resources: Some(ResourcesCapability {
524 subscribe: Some(true),
525 list_changed: None,
526 }),
527 prompts: Some(PromptsCapability { list_changed: None }),
528 logging: Some(LoggingCapability { level: None }),
529 sampling: None,
530 elicitation: Some(ElicitationCapability {}),
531 },
532 server_info: Implementation {
533 name: "test-server".to_string(),
534 version: "1.0.0".to_string(),
535 },
536 instructions: None,
537 },
538 tools: vec![Tool {
539 name: "test_tool".to_string(),
540 description: "A test tool".to_string(),
541 input_schema: json!({
542 "type": "object",
543 "properties": {
544 "input": {"type": "string"}
545 }
546 }),
547 output_schema: None,
548 title: None,
549 annotations: None,
550 icons: None,
551 }],
552 resources: vec![Resource {
553 uri: "test://resource1".to_string(),
554 name: "Test Resource".to_string(),
555 description: Some("A test resource".to_string()),
556 mime_type: Some("text/plain".to_string()),
557 annotations: None,
558 raw: None,
559 title: None,
560 icons: None,
561 }],
562 prompts: vec![Prompt {
563 name: "test_prompt".to_string(),
564 description: Some("A test prompt".to_string()),
565 arguments: None,
566 title: None,
567 icons: None,
568 }],
569 should_error: false,
570 }
571 }
572
573 fn with_error() -> Self {
574 Self {
575 should_error: true,
576 ..Self::new()
577 }
578 }
579 }
580
581 #[async_trait]
582 impl McpBackend for MockBackend {
583 type Error = MockBackendError;
584 type Config = ();
585
586 async fn initialize(_config: Self::Config) -> std::result::Result<Self, Self::Error> {
587 Ok(MockBackend::new())
588 }
589
590 fn get_server_info(&self) -> ServerInfo {
591 self.server_info.clone()
592 }
593
594 async fn health_check(&self) -> std::result::Result<(), Self::Error> {
595 if self.should_error {
596 return Err(MockBackendError::TestError(
597 "Health check failed".to_string(),
598 ));
599 }
600 Ok(())
601 }
602
603 async fn list_tools(
604 &self,
605 _params: PaginatedRequestParam,
606 ) -> std::result::Result<ListToolsResult, Self::Error> {
607 if self.should_error {
608 return Err(MockBackendError::TestError("Simulated error".to_string()));
609 }
610
611 Ok(ListToolsResult {
612 tools: self.tools.clone(),
613 next_cursor: None,
614 })
615 }
616
617 async fn call_tool(
618 &self,
619 params: CallToolRequestParam,
620 ) -> std::result::Result<CallToolResult, Self::Error> {
621 if self.should_error {
622 return Err(MockBackendError::TestError("Tool call failed".to_string()));
623 }
624
625 if params.name == "test_tool" {
626 Ok(CallToolResult {
627 content: vec![Content::Text {
628 text: "Tool executed successfully".to_string(),
629 _meta: None,
630 }],
631 is_error: Some(false),
632 structured_content: None,
633 _meta: None,
634 })
635 } else {
636 Err(MockBackendError::TestError("Tool not found".to_string()))
637 }
638 }
639
640 async fn list_resources(
641 &self,
642 _params: PaginatedRequestParam,
643 ) -> std::result::Result<ListResourcesResult, Self::Error> {
644 if self.should_error {
645 return Err(MockBackendError::TestError("Simulated error".to_string()));
646 }
647
648 Ok(ListResourcesResult {
649 resources: self.resources.clone(),
650 next_cursor: None,
651 })
652 }
653
654 async fn read_resource(
655 &self,
656 params: ReadResourceRequestParam,
657 ) -> std::result::Result<ReadResourceResult, Self::Error> {
658 if self.should_error {
659 return Err(MockBackendError::TestError("Simulated error".to_string()));
660 }
661
662 if params.uri == "test://resource1" {
663 Ok(ReadResourceResult {
664 contents: vec![ResourceContents {
665 uri: params.uri,
666 mime_type: Some("text/plain".to_string()),
667 text: Some("Resource content".to_string()),
668 blob: None,
669 _meta: None,
670 }],
671 })
672 } else {
673 Err(MockBackendError::TestError(
674 "Resource not found".to_string(),
675 ))
676 }
677 }
678
679 async fn list_resource_templates(
680 &self,
681 _params: PaginatedRequestParam,
682 ) -> std::result::Result<ListResourceTemplatesResult, Self::Error> {
683 Ok(ListResourceTemplatesResult {
684 resource_templates: vec![],
685 next_cursor: None,
686 })
687 }
688
689 async fn list_prompts(
690 &self,
691 _params: PaginatedRequestParam,
692 ) -> std::result::Result<ListPromptsResult, Self::Error> {
693 if self.should_error {
694 return Err(MockBackendError::TestError("Simulated error".to_string()));
695 }
696
697 Ok(ListPromptsResult {
698 prompts: self.prompts.clone(),
699 next_cursor: None,
700 })
701 }
702
703 async fn get_prompt(
704 &self,
705 params: GetPromptRequestParam,
706 ) -> std::result::Result<GetPromptResult, Self::Error> {
707 if self.should_error {
708 return Err(MockBackendError::TestError("Simulated error".to_string()));
709 }
710
711 if params.name == "test_prompt" {
712 Ok(GetPromptResult {
713 description: Some("A test prompt".to_string()),
714 messages: vec![PromptMessage {
715 role: PromptMessageRole::User,
716 content: PromptMessageContent::Text {
717 text: "Test prompt message".to_string(),
718 },
719 }],
720 })
721 } else {
722 Err(MockBackendError::TestError("Prompt not found".to_string()))
723 }
724 }
725
726 async fn subscribe(
727 &self,
728 _params: SubscribeRequestParam,
729 ) -> std::result::Result<(), Self::Error> {
730 if self.should_error {
731 return Err(MockBackendError::TestError("Subscribe failed".to_string()));
732 }
733 Ok(())
734 }
735
736 async fn unsubscribe(
737 &self,
738 _params: UnsubscribeRequestParam,
739 ) -> std::result::Result<(), Self::Error> {
740 if self.should_error {
741 return Err(MockBackendError::TestError(
742 "Unsubscribe failed".to_string(),
743 ));
744 }
745 Ok(())
746 }
747
748 async fn complete(
749 &self,
750 _params: CompleteRequestParam,
751 ) -> std::result::Result<CompleteResult, Self::Error> {
752 if self.should_error {
753 return Err(MockBackendError::TestError("Complete failed".to_string()));
754 }
755
756 Ok(CompleteResult {
757 completion: vec![
758 CompletionInfo {
759 completion: "completion1".to_string(),
760 has_more: Some(false),
761 },
762 CompletionInfo {
763 completion: "completion2".to_string(),
764 has_more: Some(false),
765 },
766 ],
767 })
768 }
769
770 async fn elicit(
771 &self,
772 _params: ElicitationRequestParam,
773 ) -> std::result::Result<ElicitationResult, Self::Error> {
774 if self.should_error {
775 return Err(MockBackendError::TestError(
776 "Elicitation failed".to_string(),
777 ));
778 }
779
780 Ok(ElicitationResult::accept(serde_json::json!({
782 "name": "Test User",
783 "email": "test@example.com"
784 })))
785 }
786
787 async fn set_level(
788 &self,
789 _params: SetLevelRequestParam,
790 ) -> std::result::Result<(), Self::Error> {
791 if self.should_error {
792 return Err(MockBackendError::TestError("Set level failed".to_string()));
793 }
794 Ok(())
795 }
796
797 async fn handle_custom_method(
798 &self,
799 method: &str,
800 _params: serde_json::Value,
801 ) -> std::result::Result<serde_json::Value, Self::Error> {
802 if self.should_error {
803 return Err(MockBackendError::TestError(
804 "Custom method failed".to_string(),
805 ));
806 }
807
808 Ok(json!({
809 "method": method,
810 "result": "custom method executed"
811 }))
812 }
813 }
814
/// Single error type used by `MockBackend` for all of its failures.
#[derive(Debug, thiserror::Error)]
enum MockBackendError {
    /// Failure message; displayed as `Test error: <message>`.
    #[error("Test error: {0}")]
    TestError(String),
}
820
821 impl From<MockBackendError> for Error {
822 fn from(err: MockBackendError) -> Self {
823 Error::internal_error(err.to_string())
824 }
825 }
826
827 impl From<crate::backend::BackendError> for MockBackendError {
828 fn from(error: crate::backend::BackendError) -> Self {
829 MockBackendError::TestError(error.to_string())
830 }
831 }
832
833 async fn create_test_handler() -> GenericServerHandler<MockBackend> {
834 let backend = Arc::new(MockBackend::new());
835 let auth_config = AuthConfig::memory();
836 let auth_manager = Arc::new(AuthenticationManager::new(auth_config).await.unwrap());
837 let middleware = MiddlewareStack::new();
838
839 GenericServerHandler::new(backend, auth_manager, middleware)
840 }
841
842 async fn create_error_handler() -> GenericServerHandler<MockBackend> {
843 let backend = Arc::new(MockBackend::with_error());
844 let auth_config = AuthConfig::memory();
845 let auth_manager = Arc::new(AuthenticationManager::new(auth_config).await.unwrap());
846 let middleware = MiddlewareStack::new();
847
848 GenericServerHandler::new(backend, auth_manager, middleware)
849 }
850
851 #[tokio::test]
852 async fn test_handler_creation() {
853 let handler = create_test_handler().await;
854 assert!(!handler.backend.tools.is_empty());
856 }
857
858 #[tokio::test]
859 async fn test_handle_initialize() {
860 let handler = create_test_handler().await;
861 let request = Request {
862 jsonrpc: "2.0".to_string(),
863 method: "initialize".to_string(),
864 params: json!({
865 "protocolVersion": "2024-11-05",
866 "capabilities": {},
867 "clientInfo": {
868 "name": "test-client",
869 "version": "1.0.0"
870 }
871 }),
872 id: Some(pulseengine_mcp_protocol::NumberOrString::Number(1)),
873 };
874
875 let response = handler.handle_request(request).await.unwrap();
876
877 assert_eq!(response.jsonrpc, "2.0");
878 assert_eq!(
879 response.id,
880 Some(pulseengine_mcp_protocol::NumberOrString::Number(1))
881 );
882 assert!(response.result.is_some());
883 assert!(response.error.is_none());
884
885 let result: InitializeResult = serde_json::from_value(response.result.unwrap()).unwrap();
886 assert_eq!(
887 result.protocol_version,
888 pulseengine_mcp_protocol::MCP_VERSION
889 );
890 assert_eq!(result.server_info.name, "test-server");
891 }
892
893 #[tokio::test]
894 async fn test_handle_list_tools() {
895 let handler = create_test_handler().await;
896 let request = Request {
897 jsonrpc: "2.0".to_string(),
898 method: "tools/list".to_string(),
899 params: json!({}),
900 id: Some(pulseengine_mcp_protocol::NumberOrString::Number(2)),
901 };
902
903 let response = handler.handle_request(request).await.unwrap();
904
905 assert_eq!(response.jsonrpc, "2.0");
906 assert_eq!(
907 response.id,
908 Some(pulseengine_mcp_protocol::NumberOrString::Number(2))
909 );
910 assert!(response.result.is_some());
911 assert!(response.error.is_none());
912
913 let result: ListToolsResult = serde_json::from_value(response.result.unwrap()).unwrap();
914 assert_eq!(result.tools.len(), 1);
915 assert_eq!(result.tools[0].name, "test_tool");
916 }
917
918 #[tokio::test]
919 async fn test_handle_call_tool_success() {
920 let handler = create_test_handler().await;
921 let request = Request {
922 jsonrpc: "2.0".to_string(),
923 method: "tools/call".to_string(),
924 params: json!({
925 "name": "test_tool",
926 "arguments": {
927 "input": "test input"
928 }
929 }),
930 id: Some(pulseengine_mcp_protocol::NumberOrString::Number(3)),
931 };
932
933 let response = handler.handle_request(request).await.unwrap();
934
935 assert_eq!(response.jsonrpc, "2.0");
936 assert_eq!(
937 response.id,
938 Some(pulseengine_mcp_protocol::NumberOrString::Number(3))
939 );
940 assert!(response.result.is_some());
941 assert!(response.error.is_none());
942
943 let result: CallToolResult = serde_json::from_value(response.result.unwrap()).unwrap();
944 assert_eq!(result.content.len(), 1);
945 assert!(!result.is_error.unwrap_or(true));
946 }
947
948 #[tokio::test]
949 async fn test_handle_call_tool_not_found() {
950 let handler = create_test_handler().await;
951 let request = Request {
952 jsonrpc: "2.0".to_string(),
953 method: "tools/call".to_string(),
954 params: json!({
955 "name": "nonexistent_tool",
956 "arguments": {}
957 }),
958 id: Some(pulseengine_mcp_protocol::NumberOrString::Number(4)),
959 };
960
961 let response = handler.handle_request(request).await.unwrap();
962
963 assert_eq!(response.jsonrpc, "2.0");
964 assert_eq!(
965 response.id,
966 Some(pulseengine_mcp_protocol::NumberOrString::Number(4))
967 );
968 assert!(response.result.is_none());
969 assert!(response.error.is_some());
970 }
971
972 #[tokio::test]
973 async fn test_handle_list_resources() {
974 let handler = create_test_handler().await;
975 let request = Request {
976 jsonrpc: "2.0".to_string(),
977 method: "resources/list".to_string(),
978 params: json!({}),
979 id: Some(pulseengine_mcp_protocol::NumberOrString::Number(5)),
980 };
981
982 let response = handler.handle_request(request).await.unwrap();
983
984 assert_eq!(response.jsonrpc, "2.0");
985 assert_eq!(
986 response.id,
987 Some(pulseengine_mcp_protocol::NumberOrString::Number(5))
988 );
989 assert!(response.result.is_some());
990 assert!(response.error.is_none());
991
992 let result: ListResourcesResult = serde_json::from_value(response.result.unwrap()).unwrap();
993 assert_eq!(result.resources.len(), 1);
994 assert_eq!(result.resources[0].uri, "test://resource1");
995 }
996
997 #[tokio::test]
998 async fn test_handle_read_resource() {
999 let handler = create_test_handler().await;
1000 let request = Request {
1001 jsonrpc: "2.0".to_string(),
1002 method: "resources/read".to_string(),
1003 params: json!({
1004 "uri": "test://resource1"
1005 }),
1006 id: Some(pulseengine_mcp_protocol::NumberOrString::Number(6)),
1007 };
1008
1009 let response = handler.handle_request(request).await.unwrap();
1010
1011 assert_eq!(response.jsonrpc, "2.0");
1012 assert_eq!(
1013 response.id,
1014 Some(pulseengine_mcp_protocol::NumberOrString::Number(6))
1015 );
1016 assert!(response.result.is_some());
1017 assert!(response.error.is_none());
1018
1019 let result: ReadResourceResult = serde_json::from_value(response.result.unwrap()).unwrap();
1020 assert_eq!(result.contents.len(), 1);
1021 }
1022
1023 #[tokio::test]
1024 async fn test_handle_list_prompts() {
1025 let handler = create_test_handler().await;
1026 let request = Request {
1027 jsonrpc: "2.0".to_string(),
1028 method: "prompts/list".to_string(),
1029 params: json!({}),
1030 id: Some(pulseengine_mcp_protocol::NumberOrString::Number(7)),
1031 };
1032
1033 let response = handler.handle_request(request).await.unwrap();
1034
1035 assert_eq!(response.jsonrpc, "2.0");
1036 assert_eq!(
1037 response.id,
1038 Some(pulseengine_mcp_protocol::NumberOrString::Number(7))
1039 );
1040 assert!(response.result.is_some());
1041 assert!(response.error.is_none());
1042
1043 let result: ListPromptsResult = serde_json::from_value(response.result.unwrap()).unwrap();
1044 assert_eq!(result.prompts.len(), 1);
1045 assert_eq!(result.prompts[0].name, "test_prompt");
1046 }
1047
1048 #[tokio::test]
1049 async fn test_handle_get_prompt() {
1050 let handler = create_test_handler().await;
1051 let request = Request {
1052 jsonrpc: "2.0".to_string(),
1053 method: "prompts/get".to_string(),
1054 params: json!({
1055 "name": "test_prompt",
1056 "arguments": {}
1057 }),
1058 id: Some(pulseengine_mcp_protocol::NumberOrString::Number(8)),
1059 };
1060
1061 let response = handler.handle_request(request).await.unwrap();
1062
1063 assert_eq!(response.jsonrpc, "2.0");
1064 assert_eq!(
1065 response.id,
1066 Some(pulseengine_mcp_protocol::NumberOrString::Number(8))
1067 );
1068 assert!(response.result.is_some());
1069 assert!(response.error.is_none());
1070
1071 let result: GetPromptResult = serde_json::from_value(response.result.unwrap()).unwrap();
1072 assert_eq!(result.messages.len(), 1);
1073 }
1074
1075 #[tokio::test]
1076 async fn test_handle_subscribe() {
1077 let handler = create_test_handler().await;
1078 let request = Request {
1079 jsonrpc: "2.0".to_string(),
1080 method: "resources/subscribe".to_string(),
1081 params: json!({
1082 "uri": "test://resource1"
1083 }),
1084 id: Some(pulseengine_mcp_protocol::NumberOrString::Number(9)),
1085 };
1086
1087 let response = handler.handle_request(request).await.unwrap();
1088
1089 assert_eq!(response.jsonrpc, "2.0");
1090 assert_eq!(
1091 response.id,
1092 Some(pulseengine_mcp_protocol::NumberOrString::Number(9))
1093 );
1094 assert!(response.result.is_some());
1095 assert!(response.error.is_none());
1096 }
1097
1098 #[tokio::test]
1099 async fn test_handle_unsubscribe() {
1100 let handler = create_test_handler().await;
1101 let request = Request {
1102 jsonrpc: "2.0".to_string(),
1103 method: "resources/unsubscribe".to_string(),
1104 params: json!({
1105 "uri": "test://resource1"
1106 }),
1107 id: Some(pulseengine_mcp_protocol::NumberOrString::Number(10)),
1108 };
1109
1110 let response = handler.handle_request(request).await.unwrap();
1111
1112 assert_eq!(response.jsonrpc, "2.0");
1113 assert_eq!(
1114 response.id,
1115 Some(pulseengine_mcp_protocol::NumberOrString::Number(10))
1116 );
1117 assert!(response.result.is_some());
1118 assert!(response.error.is_none());
1119 }
1120
1121 #[tokio::test]
1122 async fn test_handle_complete() {
1123 let handler = create_test_handler().await;
1124 let request = Request {
1125 jsonrpc: "2.0".to_string(),
1126 method: "completion/complete".to_string(),
1127 params: json!({
1128 "ref_": "test_prompt",
1129 "argument": {
1130 "name": "query",
1131 "value": "test"
1132 }
1133 }),
1134 id: Some(pulseengine_mcp_protocol::NumberOrString::Number(11)),
1135 };
1136
1137 let response = handler.handle_request(request).await.unwrap();
1138
1139 assert_eq!(response.jsonrpc, "2.0");
1140 assert_eq!(
1141 response.id,
1142 Some(pulseengine_mcp_protocol::NumberOrString::Number(11))
1143 );
1144 assert!(response.result.is_some());
1145 assert!(response.error.is_none());
1146
1147 let result: CompleteResult = serde_json::from_value(response.result.unwrap()).unwrap();
1148 assert_eq!(result.completion.len(), 2);
1149 }
1150
1151 #[tokio::test]
1152 async fn test_handle_elicit() {
1153 let handler = create_test_handler().await;
1154 let request = Request {
1155 jsonrpc: "2.0".to_string(),
1156 method: "elicitation/create".to_string(),
1157 params: json!({
1158 "message": "Please provide your contact information",
1159 "requestedSchema": {
1160 "type": "object",
1161 "properties": {
1162 "name": {"type": "string", "description": "Your full name"},
1163 "email": {"type": "string", "format": "email"}
1164 },
1165 "required": ["name", "email"]
1166 }
1167 }),
1168 id: Some(pulseengine_mcp_protocol::NumberOrString::Number(12)),
1169 };
1170
1171 let response = handler.handle_request(request).await.unwrap();
1172
1173 assert_eq!(response.jsonrpc, "2.0");
1174 assert_eq!(
1175 response.id,
1176 Some(pulseengine_mcp_protocol::NumberOrString::Number(12))
1177 );
1178 assert!(response.result.is_some());
1179 assert!(response.error.is_none());
1180
1181 let result: ElicitationResult = serde_json::from_value(response.result.unwrap()).unwrap();
1182 assert!(matches!(result.response.action, ElicitationAction::Accept));
1183 assert!(result.response.data.is_some());
1184 }
1185
1186 #[tokio::test]
1187 async fn test_handle_ping() {
1188 let handler = create_test_handler().await;
1189 let request = Request {
1190 jsonrpc: "2.0".to_string(),
1191 method: "ping".to_string(),
1192 params: json!({}),
1193 id: Some(pulseengine_mcp_protocol::NumberOrString::Number(13)),
1194 };
1195
1196 let response = handler.handle_request(request).await.unwrap();
1197
1198 assert_eq!(response.jsonrpc, "2.0");
1199 assert_eq!(
1200 response.id,
1201 Some(pulseengine_mcp_protocol::NumberOrString::Number(13))
1202 );
1203 assert!(response.result.is_some());
1204 assert!(response.error.is_none());
1205 }
1206
1207 #[tokio::test]
1208 async fn test_handle_custom_method() {
1209 let handler = create_test_handler().await;
1210 let request = Request {
1211 jsonrpc: "2.0".to_string(),
1212 method: "custom/method".to_string(),
1213 params: json!({"test": "data"}),
1214 id: Some(pulseengine_mcp_protocol::NumberOrString::Number(14)),
1215 };
1216
1217 let response = handler.handle_request(request).await.unwrap();
1218
1219 assert_eq!(response.jsonrpc, "2.0");
1220 assert_eq!(
1221 response.id,
1222 Some(pulseengine_mcp_protocol::NumberOrString::Number(14))
1223 );
1224 assert!(response.result.is_some());
1225 assert!(response.error.is_none());
1226
1227 let result = response.result.unwrap();
1228 assert_eq!(result["method"], "custom/method");
1229 }
1230
1231 #[tokio::test]
1232 async fn test_backend_error_handling() {
1233 let handler = create_error_handler().await;
1234 let request = Request {
1235 jsonrpc: "2.0".to_string(),
1236 method: "tools/list".to_string(),
1237 params: json!({}),
1238 id: Some(pulseengine_mcp_protocol::NumberOrString::Number(15)),
1239 };
1240
1241 let response = handler.handle_request(request).await.unwrap();
1242
1243 assert_eq!(response.jsonrpc, "2.0");
1244 assert_eq!(
1245 response.id,
1246 Some(pulseengine_mcp_protocol::NumberOrString::Number(15))
1247 );
1248 assert!(response.result.is_none());
1249 assert!(response.error.is_some());
1250
1251 let error = response.error.unwrap();
1252 assert!(error.message.contains("Simulated error"));
1253 }
1254
1255 #[tokio::test]
1256 async fn test_invalid_params() {
1257 let handler = create_test_handler().await;
1258 let request = Request {
1259 jsonrpc: "2.0".to_string(),
1260 method: "tools/call".to_string(),
1261 params: json!("invalid"), id: Some(pulseengine_mcp_protocol::NumberOrString::Number(16)),
1263 };
1264
1265 let response = handler.handle_request(request).await.unwrap();
1266
1267 assert_eq!(response.jsonrpc, "2.0");
1268 assert_eq!(
1269 response.id,
1270 Some(pulseengine_mcp_protocol::NumberOrString::Number(16))
1271 );
1272 assert!(response.result.is_none());
1273 assert!(response.error.is_some());
1274 }
1275
#[test]
fn test_handler_error_classification() {
    // One row per `HandlerError` variant, including `Authorization`, which the
    // earlier version of this test never exercised:
    // (error, expected `error_type` label, expected `is_retryable`, expected `is_auth_error`).
    // `is_timeout` and `is_connection_error` are unconditionally false in the
    // `ErrorClassification` impl, so they are checked as false for every row.
    let cases = [
        (
            HandlerError::Authentication("Invalid token".to_string()),
            "authentication",
            false,
            true,
        ),
        (
            HandlerError::Authorization("Access denied".to_string()),
            "authorization",
            false,
            true,
        ),
        (
            HandlerError::Backend("Database error".to_string()),
            "backend",
            true,
            false,
        ),
        (
            HandlerError::Protocol(Error::invalid_request("Bad request".to_string())),
            "protocol",
            false,
            false,
        ),
    ];

    for (error, expected_type, expected_retryable, expected_auth) in cases {
        assert_eq!(error.error_type(), expected_type, "error_type for {:?}", error);
        assert_eq!(
            error.is_retryable(),
            expected_retryable,
            "is_retryable for {:?}",
            error
        );
        assert!(!error.is_timeout(), "is_timeout for {:?}", error);
        assert_eq!(
            error.is_auth_error(),
            expected_auth,
            "is_auth_error for {:?}",
            error
        );
        assert!(
            !error.is_connection_error(),
            "is_connection_error for {:?}",
            error
        );
    }
}
1300
#[test]
fn test_handler_error_conversion() {
    // Converting a `HandlerError` into the protocol `Error` must pick the
    // matching JSON-RPC error code for each variant covered here.
    let cases = [
        (
            HandlerError::Authentication("Auth failed".to_string()),
            ErrorCode::Unauthorized,
        ),
        (
            HandlerError::Backend("Backend failed".to_string()),
            ErrorCode::InternalError,
        ),
    ];

    for (handler_error, expected_code) in cases {
        let converted: Error = handler_error.into();
        assert_eq!(converted.code, expected_code);
    }
}
1311
#[test]
fn test_handler_error_display() {
    // Each variant's `Display` output comes from its `#[error(...)]` format string.
    let error = HandlerError::Authentication("Test auth error".to_string());
    assert_eq!(error.to_string(), "Authentication failed: Test auth error");

    let error = HandlerError::Authorization("Test auth error".to_string());
    assert_eq!(error.to_string(), "Authorization failed: Test auth error");

    let error = HandlerError::Backend("Test backend error".to_string());
    assert_eq!(error.to_string(), "Backend error: Test backend error");

    // `Protocol` was previously unchecked. The expected text is built from the
    // inner error's own `Display`, so the test pins only the "Protocol error: "
    // wrapper and does not depend on how the protocol `Error` renders itself.
    let inner = Error::invalid_request("Bad request".to_string());
    let expected = format!("Protocol error: {}", inner);
    let error = HandlerError::Protocol(inner);
    assert_eq!(error.to_string(), expected);
}
1323}