1use std::sync::Arc;
14use std::time::Duration;
15
16use arc_swap::ArcSwap;
17use dashmap::DashMap;
18use serde_json::Value;
19use tokio::sync::mpsc;
20
21use crate::content::types::Content;
22use crate::protocol::capabilities::{InitializeRequest, InitializeResult, ServerCapabilities};
23use crate::protocol::errors::{ErrorType, McpError};
24use crate::protocol::methods::McpMethod;
25use crate::protocol::types::{Implementation, JsonRpcRequest, JsonRpcResponse};
26use crate::protocol::version;
27use crate::registry::prompts::PromptManager;
28use crate::registry::resources::ResourceManager;
29use crate::registry::tools::ToolRegistry;
30use crate::server::handler::{
31 RequestContext, error_response, require_initialization, success_response,
32};
33use crate::server::middleware::MiddlewareChain;
34use crate::server::multiplexer::{
35 ClientRequester, CreateMessageParams, CreateMessageResult, JsonRpcClientRequest,
36 ListRootsResult, MultiplexerError, RequestMultiplexer, Root,
37};
38use crate::server::session::Session;
39use crate::server::visibility::VisibilityContext;
40use crate::transport::traits::{IncomingMessage, JsonRpcNotification, Transport};
41
/// An MCP server: owns the tool/resource/prompt registries, the per-session
/// state, and the queues through which notifications and server→client
/// requests are handed to the transport loop in [`Server::run`].
pub struct Server {
    /// Server name reported in `InitializeResult::server_info`.
    name: String,

    /// Server version reported in `InitializeResult::server_info`.
    version: String,

    /// Optional instructions returned on `initialize`; swappable at runtime.
    instructions: Arc<ArcSwap<Option<String>>>,

    /// Capabilities advertised on `initialize`; swappable at runtime.
    capabilities: Arc<ArcSwap<ServerCapabilities>>,

    /// Sessions keyed by session ID; entries are created lazily by `handle_request`.
    sessions: DashMap<String, Session>,

    /// Middleware applied to every request before method dispatch.
    middleware: MiddlewareChain,

    /// Registered tools.
    tool_registry: ToolRegistry,

    /// Registered resources and resource templates.
    resource_manager: ResourceManager,

    /// Registered prompts.
    prompt_manager: PromptManager,

    /// Producer side of the outgoing notification queue (also cloned into
    /// each session and into the logger).
    notification_tx: mpsc::UnboundedSender<JsonRpcNotification>,

    /// Consumer side of the notification queue; locked by `run` for as long
    /// as its loop executes.
    notification_rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<JsonRpcNotification>>>,

    /// Logger constructed with a clone of `notification_tx` (presumably it
    /// forwards log records to the client as notifications — confirm in
    /// `crate::logging`).
    logger: crate::logging::McpLogger,

    /// Correlates server→client requests with the client's responses
    /// (`create_pending` on send, `route_response` on receipt).
    multiplexer: Arc<RequestMultiplexer>,

    /// Producer side of the server→client request queue; drained by `run`.
    request_tx: mpsc::UnboundedSender<JsonRpcClientRequest>,

    /// Consumer side of the server→client request queue; locked by `run`.
    request_rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<JsonRpcClientRequest>>>,

    /// Storage for task-augmented tool calls (`tasks/*` methods).
    task_store: Arc<crate::managers::task::TaskStore>,
}
92
93impl Server {
94 pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
96 let (notification_tx, notification_rx) = mpsc::unbounded_channel();
97 let (request_tx, request_rx) = mpsc::unbounded_channel();
98
99 let task_store = Arc::new(crate::managers::task::TaskStore::new(
101 std::time::Duration::from_secs(300),
102 std::time::Duration::from_secs(5),
103 ));
104
105 if tokio::runtime::Handle::try_current().is_ok() {
107 task_store
108 .clone()
109 .spawn_cleanup_task(std::time::Duration::from_secs(60));
110 }
111
112 let default_caps = ServerCapabilities {
114 tasks: Some(crate::protocol::capabilities::TasksCapability {
115 list: Some(crate::protocol::capabilities::EmptyObject {}),
116 cancel: Some(crate::protocol::capabilities::EmptyObject {}),
117 requests: Some(crate::protocol::capabilities::TasksRequestsCapability {
118 tools: Some(crate::protocol::capabilities::TasksToolsCapability {
119 call: Some(crate::protocol::capabilities::EmptyObject {}),
120 }),
121 ..Default::default()
122 }),
123 }),
124 ..Default::default()
125 };
126
127 let logger = crate::logging::McpLogger::new(notification_tx.clone(), "mcp-server");
129
130 Self {
131 name: name.into(),
132 version: version.into(),
133 instructions: Arc::new(ArcSwap::new(Arc::new(None))),
134 capabilities: Arc::new(ArcSwap::new(Arc::new(default_caps))),
135 sessions: DashMap::new(),
136 middleware: MiddlewareChain::new(),
137 tool_registry: ToolRegistry::new(),
138 resource_manager: ResourceManager::new(),
139 prompt_manager: PromptManager::new(),
140 notification_tx,
141 notification_rx: Arc::new(tokio::sync::Mutex::new(notification_rx)),
142 logger,
143 multiplexer: Arc::new(RequestMultiplexer::new()),
144 request_tx,
145 request_rx: Arc::new(tokio::sync::Mutex::new(request_rx)),
146 task_store,
147 }
148 }
149
150 pub fn name(&self) -> &str {
152 &self.name
153 }
154
155 pub fn version(&self) -> &str {
157 &self.version
158 }
159
160 pub fn capabilities(&self) -> Arc<ServerCapabilities> {
162 self.capabilities.load_full()
163 }
164
165 pub fn set_capabilities(&self, capabilities: ServerCapabilities) {
167 self.capabilities.store(Arc::new(capabilities));
168 }
169
170 pub fn instructions(&self) -> Option<String> {
172 (**self.instructions.load()).clone()
173 }
174
175 pub fn set_instructions(&self, instructions: Option<String>) {
177 self.instructions.store(Arc::new(instructions));
178 }
179
180 pub fn add_middleware(&mut self, middleware: crate::server::middleware::MiddlewareFn) {
182 self.middleware.add(middleware);
183 }
184
185 pub fn tool_registry(&self) -> &ToolRegistry {
187 &self.tool_registry
188 }
189
190 pub fn resource_manager(&self) -> &ResourceManager {
192 &self.resource_manager
193 }
194
195 pub fn prompt_manager(&self) -> &PromptManager {
197 &self.prompt_manager
198 }
199
200 pub fn logger(&self) -> &crate::logging::McpLogger {
202 &self.logger
203 }
204
205 pub fn notification_sender(&self) -> mpsc::UnboundedSender<JsonRpcNotification> {
207 self.notification_tx.clone()
208 }
209
210 pub fn send_notification(
212 &self,
213 method: impl Into<String>,
214 params: Option<Value>,
215 ) -> Result<(), Box<dyn std::error::Error>> {
216 let notification = JsonRpcNotification::new(method, params);
217 self.notification_tx.send(notification)?;
218 Ok(())
219 }
220
221 pub fn get_session(&self, session_id: &str) -> Option<Session> {
223 self.sessions.get(session_id).map(|s| s.clone())
224 }
225
226 pub fn remove_session(&self, session_id: &str) -> Option<Session> {
228 self.sessions.remove(session_id).map(|(_, s)| s)
229 }
230
231 pub fn multiplexer(&self) -> Arc<RequestMultiplexer> {
233 self.multiplexer.clone()
234 }
235
236 pub fn create_client_requester(&self, session_id: &str) -> Option<ClientRequester> {
241 let session = self.get_session(session_id)?;
242 let caps = session.capabilities.as_ref()?;
243
244 Some(ClientRequester::new(
245 self.request_tx.clone(),
246 self.multiplexer.clone(),
247 caps.roots.is_some(),
248 caps.sampling.is_some(),
249 ))
250 }
251
252 pub async fn request_roots(
278 &self,
279 session_id: &str,
280 timeout: Option<Duration>,
281 ) -> Result<Vec<Root>, MultiplexerError> {
282 if let Some(session) = self.get_session(session_id) {
284 if let Some(caps) = &session.capabilities {
285 if caps.roots.is_none() {
286 return Err(MultiplexerError::UnsupportedCapability("roots".to_string()));
287 }
288 } else {
289 return Err(MultiplexerError::UnsupportedCapability("roots".to_string()));
290 }
291 } else {
292 return Err(MultiplexerError::Transport("session not found".to_string()));
293 }
294
295 let (id, rx) = self.multiplexer.create_pending("roots/list");
297
298 let request = JsonRpcClientRequest::new(&id, "roots/list", Some(serde_json::json!({})));
300
301 self.request_tx
302 .send(request)
303 .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
304
305 let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
307 let result = tokio::time::timeout(timeout, rx)
308 .await
309 .map_err(|_| MultiplexerError::Timeout(timeout))?
310 .map_err(|_| MultiplexerError::ChannelClosed)??;
311
312 let list_result: ListRootsResult = serde_json::from_value(result)?;
314 Ok(list_result.roots)
315 }
316
317 pub async fn request_sampling(
353 &self,
354 session_id: &str,
355 params: CreateMessageParams,
356 timeout: Option<Duration>,
357 ) -> Result<CreateMessageResult, MultiplexerError> {
358 if let Some(session) = self.get_session(session_id) {
360 if let Some(caps) = &session.capabilities {
361 if caps.sampling.is_none() {
362 return Err(MultiplexerError::UnsupportedCapability(
363 "sampling".to_string(),
364 ));
365 }
366 } else {
367 return Err(MultiplexerError::UnsupportedCapability(
368 "sampling".to_string(),
369 ));
370 }
371 } else {
372 return Err(MultiplexerError::Transport("session not found".to_string()));
373 }
374
375 let (id, rx) = self.multiplexer.create_pending("sampling/createMessage");
377
378 let params_value = serde_json::to_value(¶ms)?;
380 let request = JsonRpcClientRequest::new(&id, "sampling/createMessage", Some(params_value));
381
382 self.request_tx
383 .send(request)
384 .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
385
386 let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
388 let result = tokio::time::timeout(timeout, rx)
389 .await
390 .map_err(|_| MultiplexerError::Timeout(timeout))?
391 .map_err(|_| MultiplexerError::ChannelClosed)??;
392
393 let create_result: CreateMessageResult = serde_json::from_value(result)?;
395 Ok(create_result)
396 }
397
398 pub async fn run<T: Transport>(
407 &self,
408 mut transport: T,
409 ) -> Result<(), Box<dyn std::error::Error>> {
410 let session_id = uuid::Uuid::new_v4().to_string();
412
413 let mut notification_rx = self.notification_rx.lock().await;
415 let mut request_rx = self.request_rx.lock().await;
416
417 loop {
418 tokio::select! {
419 Some(notification) = notification_rx.recv() => {
421 tracing::debug!(method = %notification.method, "Sending notification");
422 if let Err(e) = transport.send_notification(notification).await {
423 tracing::error!(error = %e, "Failed to send notification");
424 }
425 }
426
427 Some(request) = request_rx.recv() => {
429 tracing::debug!(method = %request.method, id = %request.id, "Sending request to client");
430 if let Err(e) = transport.send_request(request).await {
431 tracing::error!(error = %e, "Failed to send request to client");
432 }
433 }
434
435 result = transport.read_incoming() => {
437 match result {
438 Ok(IncomingMessage::Request(request)) => {
439 let is_notification = request.id.is_none();
441
442 let response = self.handle_request(&session_id, request).await;
444
445 if !is_notification
448 && let Err(e) = transport.write_message(response).await {
449 tracing::error!(error = %e, "Failed to write message");
450 break;
451 }
452 }
453 Ok(IncomingMessage::Response(response)) => {
454 if !self.multiplexer.route_response(&response) {
456 tracing::warn!(
457 id = ?response.id,
458 "Received response for unknown request ID"
459 );
460 }
461 }
462 Err(crate::transport::traits::TransportError::Closed) => {
463 tracing::info!("Transport closed, shutting down");
464 break;
465 }
466 Err(e) => {
467 tracing::error!(error = %e, "Failed to read message");
468 continue;
469 }
470 }
471 }
472 }
473 }
474
475 self.multiplexer.cancel_all();
477
478 transport.shutdown().await?;
480
481 Ok(())
482 }
483
484 pub async fn handle_request(
486 &self,
487 session_id: &str,
488 request: JsonRpcRequest,
489 ) -> JsonRpcResponse {
490 let session = self
492 .sessions
493 .entry(session_id.to_string())
494 .or_insert_with(|| {
495 let mut session = Session::with_id(session_id);
496 session.set_notification_channel(self.notification_tx.clone());
497 session
498 })
499 .clone();
500
501 let ctx = RequestContext::new(session, request.clone());
503
504 let ctx = match self.middleware.process(ctx) {
506 Ok(ctx) => ctx,
507 Err(err) => return error_response(request.id, err.to_jsonrpc()),
508 };
509
510 let method = McpMethod::from(request.method.clone());
512
513 match method {
514 McpMethod::Initialize => self.handle_initialize(ctx).await,
515 McpMethod::Ping => self.handle_ping(ctx).await,
516 McpMethod::LoggingSetLevel => self.handle_logging_set_level(ctx).await,
517 McpMethod::ToolsList => self.handle_tools_list(ctx).await,
518 McpMethod::ToolsCall => self.handle_tools_call(ctx).await,
519 McpMethod::ResourcesList => self.handle_resources_list(ctx).await,
520 McpMethod::ResourcesTemplatesList => self.handle_resources_templates_list(ctx).await,
521 McpMethod::ResourcesRead => self.handle_resources_read(ctx).await,
522 McpMethod::PromptsList => self.handle_prompts_list(ctx).await,
523 McpMethod::PromptsGet => self.handle_prompts_get(ctx).await,
524 McpMethod::RootsList => self.handle_roots_list(ctx).await,
525 McpMethod::SamplingCreateMessage => self.handle_sampling_create_message(ctx).await,
526 McpMethod::ElicitationCreate => self.handle_elicitation_create(ctx).await,
527 McpMethod::TasksGet => self.handle_tasks_get(ctx).await,
528 McpMethod::TasksResult => self.handle_tasks_result(ctx).await,
529 McpMethod::TasksList => self.handle_tasks_list(ctx).await,
530 McpMethod::TasksCancel => self.handle_tasks_cancel(ctx).await,
531 _ => error_response(
532 request.id,
533 McpError::method_not_found(&request.method).to_jsonrpc(),
534 ),
535 }
536 }
537
538 async fn handle_initialize(&self, ctx: RequestContext) -> JsonRpcResponse {
540 let params = ctx.params().cloned().unwrap_or(Value::Null);
542 let req: InitializeRequest = match serde_json::from_value(params) {
543 Ok(req) => req,
544 Err(_) => {
545 return error_response(
546 ctx.request.id,
547 McpError::validation("invalid_params", "Invalid initialize parameters")
548 .to_jsonrpc(),
549 );
550 }
551 };
552
553 let protocol_version = match version::negotiate_protocol_version(&req.protocol_version) {
555 Ok(version) => version,
556 Err(supported_versions) => {
557 tracing::warn!(
558 client = %req.client_info.name,
559 requested = %req.protocol_version,
560 supported = ?supported_versions,
561 "Unsupported protocol version"
562 );
563 return error_response(
564 ctx.request.id,
565 McpError::builder(ErrorType::Validation, "unsupported_protocol_version")
566 .message("Unsupported protocol version")
567 .detail(
568 "supported",
569 serde_json::to_value(&supported_versions).unwrap(),
570 )
571 .detail("requested", req.protocol_version.clone())
572 .build()
573 .to_jsonrpc(),
574 );
575 }
576 };
577
578 tracing::info!(
580 client = %req.client_info.name,
581 version = %req.client_info.version,
582 protocol = %protocol_version,
583 "Client connected"
584 );
585
586 if let Some(mut session) = self.sessions.get_mut(&ctx.session.id) {
588 session.initialize(req.client_info, req.capabilities, protocol_version.clone());
589 }
590
591 let result = InitializeResult {
593 protocol_version,
594 capabilities: (**self.capabilities.load()).clone(),
595 server_info: Implementation {
596 name: self.name.clone(),
597 version: self.version.clone(),
598 },
599 instructions: self.instructions(),
600 };
601
602 success_response(
603 ctx.request.id,
604 serde_json::to_value(result).expect("Failed to serialize InitializeResult"),
605 )
606 }
607
608 async fn handle_ping(&self, ctx: RequestContext) -> JsonRpcResponse {
610 success_response(ctx.request.id, serde_json::json!({}))
611 }
612
613 async fn handle_logging_set_level(&self, ctx: RequestContext) -> JsonRpcResponse {
615 use crate::logging::LogLevel;
616 use crate::protocol::types::SetLevelRequest;
617
618 let params = ctx.params().cloned().unwrap_or(Value::Null);
620 let req: SetLevelRequest = match serde_json::from_value(params) {
621 Ok(req) => req,
622 Err(_) => {
623 return error_response(
624 ctx.request.id,
625 McpError::validation("invalid_params", "Invalid setLevel parameters")
626 .to_jsonrpc(),
627 );
628 }
629 };
630
631 let level = match req.level.parse::<LogLevel>() {
633 Ok(level) => level,
634 Err(_) => {
635 return error_response(
636 ctx.request.id,
637 McpError::validation(
638 "invalid_level",
639 format!(
640 "Invalid log level '{}'. Valid levels: debug, info, notice, warning, error, critical, alert, emergency",
641 req.level
642 ),
643 )
644 .to_jsonrpc(),
645 )
646 }
647 };
648
649 self.logger.set_min_level(level);
651
652 tracing::debug!(level = %req.level, "Log level updated");
653
654 success_response(ctx.request.id, serde_json::json!({}))
655 }
656
657 async fn handle_tools_list(&self, ctx: RequestContext) -> JsonRpcResponse {
659 if let Err(err) = require_initialization(&ctx) {
660 return error_response(ctx.request.id, err.to_jsonrpc());
661 }
662
663 let visibility_ctx = VisibilityContext::new(&ctx.session);
664 let tools = self
665 .tool_registry
666 .list_for_session(&ctx.session, &visibility_ctx);
667 success_response(ctx.request.id, serde_json::json!({"tools": tools}))
668 }
669
670 async fn handle_tools_call(&self, ctx: RequestContext) -> JsonRpcResponse {
672 if let Err(err) = require_initialization(&ctx) {
673 return error_response(ctx.request.id, err.to_jsonrpc());
674 }
675
676 let params = ctx.params().cloned().unwrap_or(Value::Null);
678 let tool_name = match params.get("name").and_then(|v| v.as_str()) {
679 Some(name) => name,
680 None => {
681 return error_response(
682 ctx.request.id,
683 McpError::validation("invalid_params", "Missing 'name' field").to_jsonrpc(),
684 );
685 }
686 };
687
688 let tool_params = params.get("arguments").cloned().unwrap_or(Value::Null);
689
690 let task_meta: Option<crate::protocol::types::TaskMetadata> = params
692 .get("task")
693 .and_then(|t| serde_json::from_value(t.clone()).ok());
694
695 if let Some(task_metadata) = task_meta {
696 return self
698 .handle_task_augmented_tool_call(ctx, tool_name, tool_params, task_metadata)
699 .await;
700 }
701
702 let client_requester = self.create_client_requester(&ctx.session.id);
704
705 match self
706 .tool_registry
707 .call(
708 tool_name,
709 tool_params,
710 &ctx.session,
711 &self.logger,
712 client_requester,
713 )
714 .await
715 {
716 Ok(content) => {
717 let content_values: Vec<Value> = content.iter().map(|c| c.to_value()).collect();
718 success_response(
719 ctx.request.id,
720 serde_json::json!({"content": content_values}),
721 )
722 }
723 Err(e) => error_response(
724 ctx.request.id,
725 McpError::internal("tool_execution_failed", e.to_string()).to_jsonrpc(),
726 ),
727 }
728 }
729
730 async fn handle_task_augmented_tool_call(
732 &self,
733 ctx: RequestContext,
734 tool_name: &str,
735 tool_params: Value,
736 task_metadata: crate::protocol::types::TaskMetadata,
737 ) -> JsonRpcResponse {
738 let (task, _result_rx) =
740 self.task_store
741 .create_task(&ctx.session.id, ctx.request.clone(), task_metadata.ttl);
742
743 let task_id = task.task_id.clone();
744
745 let task_store = self.task_store.clone();
747 let tool_registry = self.tool_registry.clone();
748 let logger = self.logger.clone();
749 let session = ctx.session.clone();
750 let client_requester = self.create_client_requester(&ctx.session.id);
751 let tool_name = tool_name.to_string();
752
753 tokio::spawn(async move {
754 match tool_registry
756 .call(&tool_name, tool_params, &session, &logger, client_requester)
757 .await
758 {
759 Ok(content) => {
760 let content_values: Vec<Value> = content.iter().map(|c| c.to_value()).collect();
762 let result = serde_json::json!({"content": content_values});
763
764 let _ = task_store
765 .update_status(
766 &task_id,
767 crate::protocol::types::TaskStatus::Completed,
768 None,
769 )
770 .await;
771 let _ = task_store.store_result(&task_id, result).await;
772 }
773 Err(e) => {
774 let error_message = e.to_string();
776 let _ = task_store
777 .update_status(
778 &task_id,
779 crate::protocol::types::TaskStatus::Failed,
780 Some(error_message.clone()),
781 )
782 .await;
783
784 let error_result = serde_json::json!({
786 "content": [{
787 "type": "text",
788 "text": error_message
789 }],
790 "isError": true
791 });
792 let _ = task_store.store_result(&task_id, error_result).await;
793 }
794 }
795 });
796
797 success_response(
799 ctx.request.id,
800 serde_json::to_value(crate::protocol::types::CreateTaskResult { task }).unwrap(),
801 )
802 }
803
804 async fn handle_resources_list(&self, ctx: RequestContext) -> JsonRpcResponse {
806 if let Err(err) = require_initialization(&ctx) {
807 return error_response(ctx.request.id, err.to_jsonrpc());
808 }
809
810 let visibility_ctx = VisibilityContext::new(&ctx.session);
811 let resources = self
812 .resource_manager
813 .list_for_session(&ctx.session, &visibility_ctx);
814 success_response(ctx.request.id, serde_json::json!({"resources": resources}))
815 }
816
817 async fn handle_resources_templates_list(&self, ctx: RequestContext) -> JsonRpcResponse {
819 if let Err(err) = require_initialization(&ctx) {
820 return error_response(ctx.request.id, err.to_jsonrpc());
821 }
822
823 let visibility_ctx = VisibilityContext::new(&ctx.session);
824 let templates = self
825 .resource_manager
826 .list_templates_for_session(&ctx.session, &visibility_ctx);
827 success_response(
828 ctx.request.id,
829 serde_json::json!({"resourceTemplates": templates}),
830 )
831 }
832
833 async fn handle_resources_read(&self, ctx: RequestContext) -> JsonRpcResponse {
835 if let Err(err) = require_initialization(&ctx) {
836 return error_response(ctx.request.id, err.to_jsonrpc());
837 }
838
839 let params = ctx.params().cloned().unwrap_or(Value::Null);
841 let uri = match params.get("uri").and_then(|v| v.as_str()) {
842 Some(uri) => uri,
843 None => {
844 return error_response(
845 ctx.request.id,
846 McpError::validation("invalid_params", "Missing 'uri' field").to_jsonrpc(),
847 );
848 }
849 };
850
851 match self
853 .resource_manager
854 .read(
855 uri,
856 std::collections::HashMap::new(),
857 &ctx.session,
858 &self.logger,
859 )
860 .await
861 {
862 Ok(contents) => {
863 let content_values: Vec<Value> = contents.iter().map(|c| c.to_value()).collect();
865 success_response(
866 ctx.request.id,
867 serde_json::json!({"contents": content_values}),
868 )
869 }
870 Err(e) => error_response(
871 ctx.request.id,
872 McpError::internal("resource_read_failed", e.to_string()).to_jsonrpc(),
873 ),
874 }
875 }
876
877 async fn handle_prompts_list(&self, ctx: RequestContext) -> JsonRpcResponse {
879 if let Err(err) = require_initialization(&ctx) {
880 return error_response(ctx.request.id, err.to_jsonrpc());
881 }
882
883 let visibility_ctx = VisibilityContext::new(&ctx.session);
884 let prompts = self
885 .prompt_manager
886 .list_for_session(&ctx.session, &visibility_ctx);
887 success_response(ctx.request.id, serde_json::json!({"prompts": prompts}))
888 }
889
890 async fn handle_prompts_get(&self, ctx: RequestContext) -> JsonRpcResponse {
892 if let Err(err) = require_initialization(&ctx) {
893 return error_response(ctx.request.id, err.to_jsonrpc());
894 }
895
896 let params = ctx.params().cloned().unwrap_or(Value::Null);
898 let prompt_name = match params.get("name").and_then(|v| v.as_str()) {
899 Some(name) => name,
900 None => {
901 return error_response(
902 ctx.request.id,
903 McpError::validation("invalid_params", "Missing 'name' field").to_jsonrpc(),
904 );
905 }
906 };
907
908 let prompt_params = params.get("arguments").cloned().unwrap_or(Value::Null);
909
910 match self
912 .prompt_manager
913 .call(prompt_name, prompt_params, &ctx.session, &self.logger)
914 .await
915 {
916 Ok(result) => success_response(
917 ctx.request.id,
918 serde_json::to_value(result).expect("Failed to serialize prompt result"),
919 ),
920 Err(e) => error_response(
921 ctx.request.id,
922 McpError::internal("prompt_get_failed", e.to_string()).to_jsonrpc(),
923 ),
924 }
925 }
926
927 async fn handle_roots_list(&self, ctx: RequestContext) -> JsonRpcResponse {
933 if let Err(err) = require_initialization(&ctx) {
934 return error_response(ctx.request.id, err.to_jsonrpc());
935 }
936
937 use crate::protocol::types::ListRootsResult;
941
942 let result = ListRootsResult { roots: vec![] };
943
944 success_response(
945 ctx.request.id,
946 serde_json::to_value(result).expect("Failed to serialize roots list"),
947 )
948 }
949
950 async fn handle_sampling_create_message(&self, ctx: RequestContext) -> JsonRpcResponse {
956 if let Err(err) = require_initialization(&ctx) {
957 return error_response(ctx.request.id, err.to_jsonrpc());
958 }
959
960 error_response(
963 ctx.request.id,
964 McpError::not_implemented(
965 "sampling/createMessage is a client capability. Use ClientRequester.create_message() for server→client requests."
966 ).to_jsonrpc(),
967 )
968 }
969
970 async fn handle_elicitation_create(&self, ctx: RequestContext) -> JsonRpcResponse {
975 if let Err(err) = require_initialization(&ctx) {
976 return error_response(ctx.request.id, err.to_jsonrpc());
977 }
978
979 error_response(
982 ctx.request.id,
983 McpError::not_implemented(
984 "elicitation/create is a client capability. Use ClientRequester.create_elicitation() for server→client requests."
985 ).to_jsonrpc(),
986 )
987 }
988
989 async fn handle_tasks_get(&self, ctx: RequestContext) -> JsonRpcResponse {
991 if let Err(err) = require_initialization(&ctx) {
992 return error_response(ctx.request.id, err.to_jsonrpc());
993 }
994
995 let params: crate::protocol::types::GetTaskParams = match ctx.params() {
996 Some(p) => match serde_json::from_value(p.clone()) {
997 Ok(params) => params,
998 Err(_) => {
999 return error_response(
1000 ctx.request.id,
1001 McpError::validation("invalid_params", "Missing or invalid taskId")
1002 .to_jsonrpc(),
1003 );
1004 }
1005 },
1006 None => {
1007 return error_response(
1008 ctx.request.id,
1009 McpError::validation("invalid_params", "Missing taskId parameter").to_jsonrpc(),
1010 );
1011 }
1012 };
1013
1014 match self
1015 .task_store
1016 .get_task_for_session(¶ms.task_id, &ctx.session.id)
1017 .await
1018 {
1019 Some(task) => success_response(ctx.request.id, serde_json::to_value(task).unwrap()),
1020 None => error_response(
1021 ctx.request.id,
1022 McpError::validation("invalid_params", "Task not found").to_jsonrpc(),
1023 ),
1024 }
1025 }
1026
1027 async fn handle_tasks_result(&self, ctx: RequestContext) -> JsonRpcResponse {
1029 if let Err(err) = require_initialization(&ctx) {
1030 return error_response(ctx.request.id, err.to_jsonrpc());
1031 }
1032
1033 let params: crate::protocol::types::GetTaskParams = match ctx.params() {
1034 Some(p) => match serde_json::from_value(p.clone()) {
1035 Ok(params) => params,
1036 Err(_) => {
1037 return error_response(
1038 ctx.request.id,
1039 McpError::validation("invalid_params", "Missing or invalid taskId")
1040 .to_jsonrpc(),
1041 );
1042 }
1043 },
1044 None => {
1045 return error_response(
1046 ctx.request.id,
1047 McpError::validation("invalid_params", "Missing taskId parameter").to_jsonrpc(),
1048 );
1049 }
1050 };
1051
1052 if self
1054 .task_store
1055 .get_task_for_session(¶ms.task_id, &ctx.session.id)
1056 .await
1057 .is_none()
1058 {
1059 return error_response(
1060 ctx.request.id,
1061 McpError::validation("invalid_params", "Task not found").to_jsonrpc(),
1062 );
1063 }
1064
1065 match self
1067 .task_store
1068 .wait_for_result(¶ms.task_id, std::time::Duration::from_secs(300))
1069 .await
1070 {
1071 Ok(result) => success_response(ctx.request.id, result),
1072 Err(e) => error_response(
1073 ctx.request.id,
1074 McpError::internal("task_error", e.to_string()).to_jsonrpc(),
1075 ),
1076 }
1077 }
1078
1079 async fn handle_tasks_list(&self, ctx: RequestContext) -> JsonRpcResponse {
1081 if let Err(err) = require_initialization(&ctx) {
1082 return error_response(ctx.request.id, err.to_jsonrpc());
1083 }
1084
1085 let cursor = ctx
1087 .params()
1088 .and_then(|p| p.get("cursor"))
1089 .and_then(|c| c.as_str());
1090
1091 let (tasks, next_cursor) = self
1092 .task_store
1093 .list_tasks(&ctx.session.id, cursor, 100)
1094 .await;
1095
1096 success_response(
1097 ctx.request.id,
1098 serde_json::json!({
1099 "tasks": tasks,
1100 "nextCursor": next_cursor,
1101 }),
1102 )
1103 }
1104
1105 async fn handle_tasks_cancel(&self, ctx: RequestContext) -> JsonRpcResponse {
1107 if let Err(err) = require_initialization(&ctx) {
1108 return error_response(ctx.request.id, err.to_jsonrpc());
1109 }
1110
1111 let params: crate::protocol::types::CancelTaskParams = match ctx.params() {
1112 Some(p) => match serde_json::from_value(p.clone()) {
1113 Ok(params) => params,
1114 Err(_) => {
1115 return error_response(
1116 ctx.request.id,
1117 McpError::validation("invalid_params", "Missing or invalid taskId")
1118 .to_jsonrpc(),
1119 );
1120 }
1121 },
1122 None => {
1123 return error_response(
1124 ctx.request.id,
1125 McpError::validation("invalid_params", "Missing taskId parameter").to_jsonrpc(),
1126 );
1127 }
1128 };
1129
1130 match self
1131 .task_store
1132 .cancel_task(¶ms.task_id, &ctx.session.id)
1133 .await
1134 {
1135 Ok(task) => success_response(ctx.request.id, serde_json::to_value(task).unwrap()),
1136 Err(e) => {
1137 let error_msg = match e {
1138 crate::managers::task::TaskError::NotFound(_) => {
1139 McpError::validation("invalid_params", "Task not found")
1140 }
1141 crate::managers::task::TaskError::AlreadyTerminal(status) => {
1142 McpError::validation(
1143 "invalid_params",
1144 format!(
1145 "Cannot cancel task: already in terminal status '{:?}'",
1146 status
1147 ),
1148 )
1149 }
1150 _ => McpError::internal("task_error", e.to_string()),
1151 };
1152 error_response(ctx.request.id, error_msg.to_jsonrpc())
1153 }
1154 }
1155 }
1156}
1157
1158#[cfg(test)]
1159mod tests {
1160 use super::*;
1161
1162 #[tokio::test]
1163 async fn test_server_creation() {
1164 let server = Server::new("test-server", "1.0.0");
1165 assert_eq!(server.name(), "test-server");
1166 assert_eq!(server.version(), "1.0.0");
1167 }
1168
1169 #[tokio::test]
1170 async fn test_ping() {
1171 let server = Server::new("test-server", "1.0.0");
1172
1173 let request = JsonRpcRequest {
1174 jsonrpc: "2.0".to_string(),
1175 id: Some(Value::Number(1.into())),
1176 method: "ping".to_string(),
1177 params: None,
1178 };
1179
1180 let response = server.handle_request("test-session", request).await;
1181
1182 assert!(response.result.is_some());
1183 assert!(response.error.is_none());
1184 }
1185
1186 #[tokio::test]
1187 async fn test_initialize() {
1188 let server = Server::new("test-server", "1.0.0");
1189
1190 let request = JsonRpcRequest {
1191 jsonrpc: "2.0".to_string(),
1192 id: Some(Value::Number(1.into())),
1193 method: "initialize".to_string(),
1194 params: Some(serde_json::json!({
1195 "protocolVersion": "2025-11-25",
1196 "capabilities": {},
1197 "clientInfo": {
1198 "name": "test-client",
1199 "version": "1.0.0"
1200 }
1201 })),
1202 };
1203
1204 let response = server.handle_request("test-session", request).await;
1205
1206 assert!(response.result.is_some());
1207 assert!(response.error.is_none());
1208
1209 let session = server.get_session("test-session").unwrap();
1211 assert!(session.is_initialized());
1212 assert_eq!(session.client_info.unwrap().name, "test-client");
1213 }
1214
1215 #[tokio::test]
1216 async fn test_method_not_found() {
1217 let server = Server::new("test-server", "1.0.0");
1218
1219 let request = JsonRpcRequest {
1220 jsonrpc: "2.0".to_string(),
1221 id: Some(Value::Number(1.into())),
1222 method: "unknown/method".to_string(),
1223 params: None,
1224 };
1225
1226 let response = server.handle_request("test-session", request).await;
1227
1228 assert!(response.result.is_none());
1229 assert!(response.error.is_some());
1230 assert_eq!(response.error.unwrap().code, -32601);
1231 }
1232
1233 #[tokio::test]
1234 async fn test_requires_initialization() {
1235 let server = Server::new("test-server", "1.0.0");
1236
1237 let request = JsonRpcRequest {
1238 jsonrpc: "2.0".to_string(),
1239 id: Some(Value::Number(1.into())),
1240 method: "tools/list".to_string(),
1241 params: None,
1242 };
1243
1244 let response = server.handle_request("test-session", request.clone()).await;
1246 assert!(response.error.is_some());
1247
1248 let init_request = JsonRpcRequest {
1250 jsonrpc: "2.0".to_string(),
1251 id: Some(Value::Number(2.into())),
1252 method: "initialize".to_string(),
1253 params: Some(serde_json::json!({
1254 "protocolVersion": "2025-11-25",
1255 "capabilities": {},
1256 "clientInfo": {
1257 "name": "test-client",
1258 "version": "1.0.0"
1259 }
1260 })),
1261 };
1262 server.handle_request("test-session", init_request).await;
1263
1264 let response = server.handle_request("test-session", request).await;
1266 assert!(response.result.is_some());
1267 }
1268
1269 #[tokio::test]
1270 async fn test_session_management() {
1271 let server = Server::new("test-server", "1.0.0");
1272
1273 let request = JsonRpcRequest {
1275 jsonrpc: "2.0".to_string(),
1276 id: Some(Value::Number(1.into())),
1277 method: "ping".to_string(),
1278 params: None,
1279 };
1280 server.handle_request("session-1", request).await;
1281
1282 assert!(server.get_session("session-1").is_some());
1284
1285 let removed = server.remove_session("session-1");
1287 assert!(removed.is_some());
1288
1289 assert!(server.get_session("session-1").is_none());
1291 }
1292
1293 #[tokio::test]
1294 async fn test_capabilities_update() {
1295 let server = Server::new("test-server", "1.0.0");
1296
1297 let caps = ServerCapabilities {
1298 tools: Some(crate::protocol::capabilities::ToolsCapability {
1299 list_changed: Some(true),
1300 }),
1301 ..Default::default()
1302 };
1303
1304 server.set_capabilities(caps.clone());
1305
1306 let loaded_caps = server.capabilities();
1307 assert_eq!(loaded_caps.tools, caps.tools);
1308 }
1309
1310 async fn init_test_session(server: &Server, session_id: &str) {
1316 let request = JsonRpcRequest {
1317 jsonrpc: "2.0".to_string(),
1318 id: Some(Value::Number(1.into())),
1319 method: "initialize".to_string(),
1320 params: Some(serde_json::json!({
1321 "protocolVersion": "2025-11-25",
1322 "capabilities": {
1323 "tasks": {
1324 "list": {},
1325 "cancel": {},
1326 "requests": {
1327 "tools": {
1328 "call": {}
1329 }
1330 }
1331 }
1332 },
1333 "clientInfo": {
1334 "name": "test-client",
1335 "version": "1.0.0"
1336 }
1337 })),
1338 };
1339
1340 server.handle_request(session_id, request).await;
1341 }
1342
1343 struct TestTaskTool;
1345
1346 #[async_trait::async_trait]
1347 impl crate::registry::tools::Tool for TestTaskTool {
1348 fn name(&self) -> &str {
1349 "test_task"
1350 }
1351
1352 fn description(&self) -> Option<&str> {
1353 Some("Test tool for task execution")
1354 }
1355
1356 fn input_schema(&self) -> Value {
1357 serde_json::json!({
1358 "type": "object",
1359 "properties": {
1360 "message": {"type": "string"}
1361 }
1362 })
1363 }
1364
1365 fn execution(&self) -> Option<crate::protocol::types::ToolExecution> {
1366 Some(crate::protocol::types::ToolExecution {
1367 task_support: Some(crate::protocol::types::TaskSupport::Optional),
1368 })
1369 }
1370
1371 async fn execute(
1372 &self,
1373 ctx: crate::prelude::ExecutionContext<'_>,
1374 ) -> Result<Vec<Box<dyn crate::content::types::Content>>, crate::registry::tools::ToolError>
1375 {
1376 let msg = ctx
1377 .params
1378 .get("message")
1379 .and_then(|v| v.as_str())
1380 .unwrap_or("default");
1381
1382 Ok(vec![Box::new(crate::content::types::TextContent::new(
1383 format!("Processed: {}", msg),
1384 ))])
1385 }
1386 }
1387
1388 struct SlowTestTool;
1390
1391 #[async_trait::async_trait]
1392 impl crate::registry::tools::Tool for SlowTestTool {
1393 fn name(&self) -> &str {
1394 "slow_test"
1395 }
1396
1397 fn description(&self) -> Option<&str> {
1398 Some("Slow test tool")
1399 }
1400
1401 fn input_schema(&self) -> Value {
1402 serde_json::json!({"type": "object"})
1403 }
1404
1405 fn execution(&self) -> Option<crate::protocol::types::ToolExecution> {
1406 Some(crate::protocol::types::ToolExecution {
1407 task_support: Some(crate::protocol::types::TaskSupport::Optional),
1408 })
1409 }
1410
1411 async fn execute(
1412 &self,
1413 _ctx: crate::prelude::ExecutionContext<'_>,
1414 ) -> Result<Vec<Box<dyn crate::content::types::Content>>, crate::registry::tools::ToolError>
1415 {
1416 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
1417 Ok(vec![Box::new(crate::content::types::TextContent::new(
1418 "Slow operation complete",
1419 ))])
1420 }
1421 }
1422
1423 #[tokio::test]
1424 async fn test_task_augmented_tool_call() {
1425 let server = Server::new("test-server", "1.0.0");
1426 server.tool_registry().register(TestTaskTool);
1427
1428 init_test_session(&server, "test-session").await;
1429
1430 let request = JsonRpcRequest {
1432 jsonrpc: "2.0".to_string(),
1433 id: Some(Value::Number(2.into())),
1434 method: "tools/call".to_string(),
1435 params: Some(serde_json::json!({
1436 "name": "test_task",
1437 "arguments": {"message": "hello"},
1438 "task": {"ttl": 60000}
1439 })),
1440 };
1441
1442 let response = server.handle_request("test-session", request).await;
1443
1444 assert!(response.result.is_some());
1446 assert!(response.error.is_none());
1447
1448 let result = response.result.unwrap();
1449 assert!(result.get("task").is_some());
1450
1451 let task = result.get("task").unwrap();
1452 assert!(task.get("taskId").is_some());
1453 assert_eq!(task.get("status").unwrap().as_str().unwrap(), "working");
1454 assert!(task.get("createdAt").is_some());
1455 assert_eq!(task.get("ttl").unwrap().as_u64().unwrap(), 60000);
1456 }
1457
1458 #[tokio::test]
1459 async fn test_task_get_status() {
1460 let server = Server::new("test-server", "1.0.0");
1461 server.tool_registry().register(SlowTestTool);
1462
1463 init_test_session(&server, "test-session").await;
1464
1465 let create_request = JsonRpcRequest {
1467 jsonrpc: "2.0".to_string(),
1468 id: Some(Value::Number(2.into())),
1469 method: "tools/call".to_string(),
1470 params: Some(serde_json::json!({
1471 "name": "slow_test",
1472 "arguments": {},
1473 "task": {"ttl": 60000}
1474 })),
1475 };
1476
1477 let create_response = server.handle_request("test-session", create_request).await;
1478 let task_id = create_response.result.unwrap()["task"]["taskId"]
1479 .as_str()
1480 .unwrap()
1481 .to_string();
1482
1483 let get_request = JsonRpcRequest {
1485 jsonrpc: "2.0".to_string(),
1486 id: Some(Value::Number(3.into())),
1487 method: "tasks/get".to_string(),
1488 params: Some(serde_json::json!({"taskId": task_id})),
1489 };
1490
1491 let get_response = server.handle_request("test-session", get_request).await;
1492
1493 assert!(get_response.result.is_some());
1494 let result = get_response.result.unwrap();
1495 let status = result["status"].as_str().unwrap();
1496 assert!(status == "working" || status == "completed");
1497 }
1498
1499 #[tokio::test]
1500 async fn test_task_result_blocking() {
1501 let server = Server::new("test-server", "1.0.0");
1502 server.tool_registry().register(SlowTestTool);
1503
1504 init_test_session(&server, "test-session").await;
1505
1506 let create_request = JsonRpcRequest {
1508 jsonrpc: "2.0".to_string(),
1509 id: Some(Value::Number(2.into())),
1510 method: "tools/call".to_string(),
1511 params: Some(serde_json::json!({
1512 "name": "slow_test",
1513 "arguments": {},
1514 "task": {"ttl": 60000}
1515 })),
1516 };
1517
1518 let create_response = server.handle_request("test-session", create_request).await;
1519 let task_id = create_response.result.unwrap()["task"]["taskId"]
1520 .as_str()
1521 .unwrap()
1522 .to_string();
1523
1524 let result_request = JsonRpcRequest {
1526 jsonrpc: "2.0".to_string(),
1527 id: Some(Value::Number(3.into())),
1528 method: "tasks/result".to_string(),
1529 params: Some(serde_json::json!({"taskId": task_id})),
1530 };
1531
1532 let result_response = server.handle_request("test-session", result_request).await;
1533
1534 assert!(result_response.result.is_some());
1535 assert!(result_response.error.is_none());
1536
1537 let result = result_response.result.unwrap();
1539 assert!(result.get("content").is_some());
1540 }
1541
1542 #[tokio::test]
1543 async fn test_task_cancel() {
1544 let server = Server::new("test-server", "1.0.0");
1545 server.tool_registry().register(SlowTestTool);
1546
1547 init_test_session(&server, "test-session").await;
1548
1549 let create_request = JsonRpcRequest {
1551 jsonrpc: "2.0".to_string(),
1552 id: Some(Value::Number(2.into())),
1553 method: "tools/call".to_string(),
1554 params: Some(serde_json::json!({
1555 "name": "slow_test",
1556 "arguments": {},
1557 "task": {"ttl": 60000}
1558 })),
1559 };
1560
1561 let create_response = server.handle_request("test-session", create_request).await;
1562 let task_id = create_response.result.unwrap()["task"]["taskId"]
1563 .as_str()
1564 .unwrap()
1565 .to_string();
1566
1567 let cancel_request = JsonRpcRequest {
1569 jsonrpc: "2.0".to_string(),
1570 id: Some(Value::Number(3.into())),
1571 method: "tasks/cancel".to_string(),
1572 params: Some(serde_json::json!({"taskId": task_id})),
1573 };
1574
1575 let cancel_response = server.handle_request("test-session", cancel_request).await;
1576
1577 if cancel_response.result.is_some() {
1579 let result = cancel_response.result.unwrap();
1580 let status = result["status"].as_str().unwrap();
1581 assert_eq!(status, "cancelled");
1582 }
1583 }
1585
1586 #[tokio::test]
1587 async fn test_task_list() {
1588 let server = Server::new("test-server", "1.0.0");
1589 server.tool_registry().register(TestTaskTool);
1590
1591 init_test_session(&server, "test-session").await;
1592
1593 for i in 0..3 {
1595 let request = JsonRpcRequest {
1596 jsonrpc: "2.0".to_string(),
1597 id: Some(Value::Number((i + 2).into())),
1598 method: "tools/call".to_string(),
1599 params: Some(serde_json::json!({
1600 "name": "test_task",
1601 "arguments": {"message": format!("task-{}", i)},
1602 "task": {"ttl": 60000}
1603 })),
1604 };
1605 server.handle_request("test-session", request).await;
1606 }
1607
1608 let list_request = JsonRpcRequest {
1610 jsonrpc: "2.0".to_string(),
1611 id: Some(Value::Number(10.into())),
1612 method: "tasks/list".to_string(),
1613 params: None,
1614 };
1615
1616 let list_response = server.handle_request("test-session", list_request).await;
1617
1618 assert!(list_response.result.is_some());
1619 let result = list_response.result.unwrap();
1620 let tasks = result["tasks"].as_array().unwrap();
1621 assert!(tasks.len() >= 3);
1622 }
1623
1624 #[tokio::test]
1625 async fn test_task_session_isolation() {
1626 let server = Server::new("test-server", "1.0.0");
1627 server.tool_registry().register(TestTaskTool);
1628
1629 init_test_session(&server, "session-1").await;
1630 init_test_session(&server, "session-2").await;
1631
1632 let request = JsonRpcRequest {
1634 jsonrpc: "2.0".to_string(),
1635 id: Some(Value::Number(2.into())),
1636 method: "tools/call".to_string(),
1637 params: Some(serde_json::json!({
1638 "name": "test_task",
1639 "arguments": {"message": "private"},
1640 "task": {"ttl": 60000}
1641 })),
1642 };
1643
1644 let response = server.handle_request("session-1", request).await;
1645 let task_id = response.result.unwrap()["task"]["taskId"]
1646 .as_str()
1647 .unwrap()
1648 .to_string();
1649
1650 let get_request = JsonRpcRequest {
1652 jsonrpc: "2.0".to_string(),
1653 id: Some(Value::Number(3.into())),
1654 method: "tasks/get".to_string(),
1655 params: Some(serde_json::json!({"taskId": task_id})),
1656 };
1657
1658 let get_response = server.handle_request("session-2", get_request).await;
1659
1660 assert!(get_response.error.is_some());
1662 }
1663
1664 #[tokio::test]
1665 async fn test_task_not_found() {
1666 let server = Server::new("test-server", "1.0.0");
1667 init_test_session(&server, "test-session").await;
1668
1669 let request = JsonRpcRequest {
1670 jsonrpc: "2.0".to_string(),
1671 id: Some(Value::Number(2.into())),
1672 method: "tasks/get".to_string(),
1673 params: Some(serde_json::json!({"taskId": "nonexistent-task-id"})),
1674 };
1675
1676 let response = server.handle_request("test-session", request).await;
1677
1678 assert!(response.error.is_some());
1679 assert_eq!(response.error.unwrap().code, -32602);
1680 }
1681
1682 #[tokio::test]
1683 async fn test_task_double_cancel() {
1684 let server = Server::new("test-server", "1.0.0");
1685 server.tool_registry().register(SlowTestTool);
1686
1687 init_test_session(&server, "test-session").await;
1688
1689 let create_request = JsonRpcRequest {
1691 jsonrpc: "2.0".to_string(),
1692 id: Some(Value::Number(2.into())),
1693 method: "tools/call".to_string(),
1694 params: Some(serde_json::json!({
1695 "name": "slow_test",
1696 "arguments": {},
1697 "task": {"ttl": 60000}
1698 })),
1699 };
1700
1701 let create_response = server.handle_request("test-session", create_request).await;
1702 let task_id = create_response.result.unwrap()["task"]["taskId"]
1703 .as_str()
1704 .unwrap()
1705 .to_string();
1706
1707 let cancel_request = JsonRpcRequest {
1709 jsonrpc: "2.0".to_string(),
1710 id: Some(Value::Number(3.into())),
1711 method: "tasks/cancel".to_string(),
1712 params: Some(serde_json::json!({"taskId": task_id.clone()})),
1713 };
1714
1715 let _ = server
1716 .handle_request("test-session", cancel_request.clone())
1717 .await;
1718
1719 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1721
1722 let cancel_request2 = JsonRpcRequest {
1724 jsonrpc: "2.0".to_string(),
1725 id: Some(Value::Number(4.into())),
1726 method: "tasks/cancel".to_string(),
1727 params: Some(serde_json::json!({"taskId": task_id})),
1728 };
1729
1730 let cancel_response2 = server.handle_request("test-session", cancel_request2).await;
1731
1732 assert!(cancel_response2.error.is_some());
1734 }
1735}