1use std::sync::Arc;
14use std::time::Duration;
15
16use arc_swap::ArcSwap;
17use dashmap::DashMap;
18use serde_json::Value;
19use tokio::sync::mpsc;
20
21use crate::content::types::Content;
22use crate::protocol::capabilities::{InitializeRequest, InitializeResult, ServerCapabilities};
23use crate::protocol::errors::{ErrorType, McpError};
24use crate::protocol::methods::McpMethod;
25use crate::protocol::types::{Implementation, JsonRpcRequest, JsonRpcResponse};
26use crate::protocol::version;
27use crate::registry::tools::ToolRegistry;
28use crate::registry::resources::ResourceManager;
29use crate::registry::prompts::PromptManager;
30use crate::server::handler::{error_response, require_initialization, success_response, RequestContext};
31use crate::server::middleware::MiddlewareChain;
32use crate::server::multiplexer::{
33 ClientRequester, CreateMessageParams, CreateMessageResult, JsonRpcClientRequest,
34 ListRootsResult, MultiplexerError, RequestMultiplexer, Root,
35};
36use crate::server::session::Session;
37use crate::server::visibility::VisibilityContext;
38use crate::transport::traits::{IncomingMessage, JsonRpcNotification, Transport};
39
40pub struct Server {
42 name: String,
44
45 version: String,
47
48 instructions: Arc<ArcSwap<Option<String>>>,
50
51 capabilities: Arc<ArcSwap<ServerCapabilities>>,
53
54 sessions: DashMap<String, Session>,
56
57 middleware: MiddlewareChain,
59
60 tool_registry: ToolRegistry,
62
63 resource_manager: ResourceManager,
65
66 prompt_manager: PromptManager,
68
69 notification_tx: mpsc::UnboundedSender<JsonRpcNotification>,
71
72 notification_rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<JsonRpcNotification>>>,
74
75 multiplexer: Arc<RequestMultiplexer>,
77
78 request_tx: mpsc::UnboundedSender<JsonRpcClientRequest>,
80
81 request_rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<JsonRpcClientRequest>>>,
83
84 task_store: Arc<crate::managers::task::TaskStore>,
86}
87
88impl Server {
89 pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
91 let (notification_tx, notification_rx) = mpsc::unbounded_channel();
92 let (request_tx, request_rx) = mpsc::unbounded_channel();
93
94 let task_store = Arc::new(crate::managers::task::TaskStore::new(
96 std::time::Duration::from_secs(300),
97 std::time::Duration::from_secs(5),
98 ));
99
100 if tokio::runtime::Handle::try_current().is_ok() {
102 task_store.clone().spawn_cleanup_task(std::time::Duration::from_secs(60));
103 }
104
105 let default_caps = ServerCapabilities {
107 tasks: Some(crate::protocol::capabilities::TasksCapability {
108 list: Some(crate::protocol::capabilities::EmptyObject {}),
109 cancel: Some(crate::protocol::capabilities::EmptyObject {}),
110 requests: Some(crate::protocol::capabilities::TasksRequestsCapability {
111 tools: Some(crate::protocol::capabilities::TasksToolsCapability {
112 call: Some(crate::protocol::capabilities::EmptyObject {}),
113 }),
114 ..Default::default()
115 }),
116 }),
117 ..Default::default()
118 };
119
120 Self {
121 name: name.into(),
122 version: version.into(),
123 instructions: Arc::new(ArcSwap::new(Arc::new(None))),
124 capabilities: Arc::new(ArcSwap::new(Arc::new(default_caps))),
125 sessions: DashMap::new(),
126 middleware: MiddlewareChain::new(),
127 tool_registry: ToolRegistry::new(),
128 resource_manager: ResourceManager::new(),
129 prompt_manager: PromptManager::new(),
130 notification_tx,
131 notification_rx: Arc::new(tokio::sync::Mutex::new(notification_rx)),
132 multiplexer: Arc::new(RequestMultiplexer::new()),
133 request_tx,
134 request_rx: Arc::new(tokio::sync::Mutex::new(request_rx)),
135 task_store,
136 }
137 }
138
139 pub fn name(&self) -> &str {
141 &self.name
142 }
143
144 pub fn version(&self) -> &str {
146 &self.version
147 }
148
149 pub fn capabilities(&self) -> Arc<ServerCapabilities> {
151 self.capabilities.load_full()
152 }
153
154 pub fn set_capabilities(&self, capabilities: ServerCapabilities) {
156 self.capabilities.store(Arc::new(capabilities));
157 }
158
159 pub fn instructions(&self) -> Option<String> {
161 (**self.instructions.load()).clone()
162 }
163
164 pub fn set_instructions(&self, instructions: Option<String>) {
166 self.instructions.store(Arc::new(instructions));
167 }
168
169 pub fn add_middleware(&mut self, middleware: crate::server::middleware::MiddlewareFn) {
171 self.middleware.add(middleware);
172 }
173
174 pub fn tool_registry(&self) -> &ToolRegistry {
176 &self.tool_registry
177 }
178
179 pub fn resource_manager(&self) -> &ResourceManager {
181 &self.resource_manager
182 }
183
184 pub fn prompt_manager(&self) -> &PromptManager {
186 &self.prompt_manager
187 }
188
189 pub fn notification_sender(&self) -> mpsc::UnboundedSender<JsonRpcNotification> {
191 self.notification_tx.clone()
192 }
193
194 pub fn send_notification(&self, method: impl Into<String>, params: Option<Value>) -> Result<(), Box<dyn std::error::Error>> {
196 let notification = JsonRpcNotification::new(method, params);
197 self.notification_tx.send(notification)?;
198 Ok(())
199 }
200
201 pub fn get_session(&self, session_id: &str) -> Option<Session> {
203 self.sessions.get(session_id).map(|s| s.clone())
204 }
205
206 pub fn remove_session(&self, session_id: &str) -> Option<Session> {
208 self.sessions.remove(session_id).map(|(_, s)| s)
209 }
210
211 pub fn multiplexer(&self) -> Arc<RequestMultiplexer> {
213 self.multiplexer.clone()
214 }
215
216 pub fn create_client_requester(&self, session_id: &str) -> Option<ClientRequester> {
221 let session = self.get_session(session_id)?;
222 let caps = session.capabilities.as_ref()?;
223
224 Some(ClientRequester::new(
225 self.request_tx.clone(),
226 self.multiplexer.clone(),
227 caps.roots.is_some(),
228 caps.sampling.is_some(),
229 ))
230 }
231
232 pub async fn request_roots(
258 &self,
259 session_id: &str,
260 timeout: Option<Duration>,
261 ) -> Result<Vec<Root>, MultiplexerError> {
262 if let Some(session) = self.get_session(session_id) {
264 if let Some(caps) = &session.capabilities {
265 if caps.roots.is_none() {
266 return Err(MultiplexerError::UnsupportedCapability("roots".to_string()));
267 }
268 } else {
269 return Err(MultiplexerError::UnsupportedCapability("roots".to_string()));
270 }
271 } else {
272 return Err(MultiplexerError::Transport("session not found".to_string()));
273 }
274
275 let (id, rx) = self.multiplexer.create_pending("roots/list");
277
278 let request = JsonRpcClientRequest::new(&id, "roots/list", Some(serde_json::json!({})));
280
281 self.request_tx
282 .send(request)
283 .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
284
285 let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
287 let result = tokio::time::timeout(timeout, rx)
288 .await
289 .map_err(|_| MultiplexerError::Timeout(timeout))?
290 .map_err(|_| MultiplexerError::ChannelClosed)??;
291
292 let list_result: ListRootsResult = serde_json::from_value(result)?;
294 Ok(list_result.roots)
295 }
296
297 pub async fn request_sampling(
333 &self,
334 session_id: &str,
335 params: CreateMessageParams,
336 timeout: Option<Duration>,
337 ) -> Result<CreateMessageResult, MultiplexerError> {
338 if let Some(session) = self.get_session(session_id) {
340 if let Some(caps) = &session.capabilities {
341 if caps.sampling.is_none() {
342 return Err(MultiplexerError::UnsupportedCapability("sampling".to_string()));
343 }
344 } else {
345 return Err(MultiplexerError::UnsupportedCapability("sampling".to_string()));
346 }
347 } else {
348 return Err(MultiplexerError::Transport("session not found".to_string()));
349 }
350
351 let (id, rx) = self.multiplexer.create_pending("sampling/createMessage");
353
354 let params_value = serde_json::to_value(¶ms)?;
356 let request = JsonRpcClientRequest::new(&id, "sampling/createMessage", Some(params_value));
357
358 self.request_tx
359 .send(request)
360 .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
361
362 let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
364 let result = tokio::time::timeout(timeout, rx)
365 .await
366 .map_err(|_| MultiplexerError::Timeout(timeout))?
367 .map_err(|_| MultiplexerError::ChannelClosed)??;
368
369 let create_result: CreateMessageResult = serde_json::from_value(result)?;
371 Ok(create_result)
372 }
373
374 pub async fn run<T: Transport>(&self, mut transport: T) -> Result<(), Box<dyn std::error::Error>> {
383 let session_id = uuid::Uuid::new_v4().to_string();
385
386 let mut notification_rx = self.notification_rx.lock().await;
388 let mut request_rx = self.request_rx.lock().await;
389
390 loop {
391 tokio::select! {
392 Some(notification) = notification_rx.recv() => {
394 tracing::debug!(method = %notification.method, "Sending notification");
395 if let Err(e) = transport.send_notification(notification).await {
396 tracing::error!(error = %e, "Failed to send notification");
397 }
398 }
399
400 Some(request) = request_rx.recv() => {
402 tracing::debug!(method = %request.method, id = %request.id, "Sending request to client");
403 if let Err(e) = transport.send_request(request).await {
404 tracing::error!(error = %e, "Failed to send request to client");
405 }
406 }
407
408 result = transport.read_incoming() => {
410 match result {
411 Ok(IncomingMessage::Request(request)) => {
412 let is_notification = request.id.is_none();
414
415 let response = self.handle_request(&session_id, request).await;
417
418 if !is_notification
421 && let Err(e) = transport.write_message(response).await {
422 tracing::error!(error = %e, "Failed to write message");
423 break;
424 }
425 }
426 Ok(IncomingMessage::Response(response)) => {
427 if !self.multiplexer.route_response(&response) {
429 tracing::warn!(
430 id = ?response.id,
431 "Received response for unknown request ID"
432 );
433 }
434 }
435 Err(crate::transport::traits::TransportError::Closed) => {
436 tracing::info!("Transport closed, shutting down");
437 break;
438 }
439 Err(e) => {
440 tracing::error!(error = %e, "Failed to read message");
441 continue;
442 }
443 }
444 }
445 }
446 }
447
448 self.multiplexer.cancel_all();
450
451 transport.shutdown().await?;
453
454 Ok(())
455 }
456
457 pub async fn handle_request(
459 &self,
460 session_id: &str,
461 request: JsonRpcRequest,
462 ) -> JsonRpcResponse {
463 let session = self
465 .sessions
466 .entry(session_id.to_string())
467 .or_insert_with(|| Session::with_id(session_id))
468 .clone();
469
470 let ctx = RequestContext::new(session, request.clone());
472
473 let ctx = match self.middleware.process(ctx) {
475 Ok(ctx) => ctx,
476 Err(err) => return error_response(request.id, err.to_jsonrpc()),
477 };
478
479 let method = McpMethod::from(request.method.clone());
481
482 match method {
483 McpMethod::Initialize => self.handle_initialize(ctx).await,
484 McpMethod::Ping => self.handle_ping(ctx).await,
485 McpMethod::ToolsList => self.handle_tools_list(ctx).await,
486 McpMethod::ToolsCall => self.handle_tools_call(ctx).await,
487 McpMethod::ResourcesList => self.handle_resources_list(ctx).await,
488 McpMethod::ResourcesTemplatesList => self.handle_resources_templates_list(ctx).await,
489 McpMethod::ResourcesRead => self.handle_resources_read(ctx).await,
490 McpMethod::PromptsList => self.handle_prompts_list(ctx).await,
491 McpMethod::PromptsGet => self.handle_prompts_get(ctx).await,
492 McpMethod::RootsList => self.handle_roots_list(ctx).await,
493 McpMethod::SamplingCreateMessage => self.handle_sampling_create_message(ctx).await,
494 McpMethod::ElicitationCreate => self.handle_elicitation_create(ctx).await,
495 McpMethod::TasksGet => self.handle_tasks_get(ctx).await,
496 McpMethod::TasksResult => self.handle_tasks_result(ctx).await,
497 McpMethod::TasksList => self.handle_tasks_list(ctx).await,
498 McpMethod::TasksCancel => self.handle_tasks_cancel(ctx).await,
499 _ => error_response(
500 request.id,
501 McpError::method_not_found(&request.method).to_jsonrpc(),
502 ),
503 }
504 }
505
506 async fn handle_initialize(&self, ctx: RequestContext) -> JsonRpcResponse {
508 let params = ctx.params().cloned().unwrap_or(Value::Null);
510 let req: InitializeRequest = match serde_json::from_value(params) {
511 Ok(req) => req,
512 Err(_) => {
513 return error_response(
514 ctx.request.id,
515 McpError::validation("invalid_params", "Invalid initialize parameters")
516 .to_jsonrpc(),
517 )
518 }
519 };
520
521 let protocol_version = match version::negotiate_protocol_version(&req.protocol_version) {
523 Ok(version) => version,
524 Err(supported_versions) => {
525 tracing::warn!(
526 client = %req.client_info.name,
527 requested = %req.protocol_version,
528 supported = ?supported_versions,
529 "Unsupported protocol version"
530 );
531 return error_response(
532 ctx.request.id,
533 McpError::builder(ErrorType::Validation, "unsupported_protocol_version")
534 .message("Unsupported protocol version")
535 .detail("supported", serde_json::to_value(&supported_versions).unwrap())
536 .detail("requested", req.protocol_version.clone())
537 .build()
538 .to_jsonrpc()
539 );
540 }
541 };
542
543 tracing::info!(
545 client = %req.client_info.name,
546 version = %req.client_info.version,
547 protocol = %protocol_version,
548 "Client connected"
549 );
550
551 if let Some(mut session) = self.sessions.get_mut(&ctx.session.id) {
553 session.initialize(req.client_info, req.capabilities, protocol_version.clone());
554 }
555
556 let result = InitializeResult {
558 protocol_version,
559 capabilities: (**self.capabilities.load()).clone(),
560 server_info: Implementation {
561 name: self.name.clone(),
562 version: self.version.clone(),
563 },
564 instructions: self.instructions(),
565 };
566
567 success_response(
568 ctx.request.id,
569 serde_json::to_value(result).expect("Failed to serialize InitializeResult"),
570 )
571 }
572
573 async fn handle_ping(&self, ctx: RequestContext) -> JsonRpcResponse {
575 success_response(ctx.request.id, serde_json::json!({}))
576 }
577
578 async fn handle_tools_list(&self, ctx: RequestContext) -> JsonRpcResponse {
580 if let Err(err) = require_initialization(&ctx) {
581 return error_response(ctx.request.id, err.to_jsonrpc());
582 }
583
584 let visibility_ctx = VisibilityContext::new(&ctx.session);
585 let tools = self.tool_registry.list_for_session(&ctx.session, &visibility_ctx);
586 success_response(ctx.request.id, serde_json::json!({"tools": tools}))
587 }
588
589 async fn handle_tools_call(&self, ctx: RequestContext) -> JsonRpcResponse {
591 if let Err(err) = require_initialization(&ctx) {
592 return error_response(ctx.request.id, err.to_jsonrpc());
593 }
594
595 let params = ctx.params().cloned().unwrap_or(Value::Null);
597 let tool_name = match params.get("name").and_then(|v| v.as_str()) {
598 Some(name) => name,
599 None => {
600 return error_response(
601 ctx.request.id,
602 McpError::validation("invalid_params", "Missing 'name' field").to_jsonrpc(),
603 )
604 }
605 };
606
607 let tool_params = params.get("arguments").cloned().unwrap_or(Value::Null);
608
609 let task_meta: Option<crate::protocol::types::TaskMetadata> = params
611 .get("task")
612 .and_then(|t| serde_json::from_value(t.clone()).ok());
613
614 if let Some(task_metadata) = task_meta {
615 return self
617 .handle_task_augmented_tool_call(ctx, tool_name, tool_params, task_metadata)
618 .await;
619 }
620
621 let client_requester = self.create_client_requester(&ctx.session.id);
623
624 match self
625 .tool_registry
626 .call(tool_name, tool_params, &ctx.session, client_requester)
627 .await
628 {
629 Ok(content) => {
630 let content_values: Vec<Value> = content.iter().map(|c| c.to_value()).collect();
631 success_response(
632 ctx.request.id,
633 serde_json::json!({"content": content_values}),
634 )
635 }
636 Err(e) => error_response(
637 ctx.request.id,
638 McpError::internal("tool_execution_failed", e.to_string()).to_jsonrpc(),
639 ),
640 }
641 }
642
643 async fn handle_task_augmented_tool_call(
645 &self,
646 ctx: RequestContext,
647 tool_name: &str,
648 tool_params: Value,
649 task_metadata: crate::protocol::types::TaskMetadata,
650 ) -> JsonRpcResponse {
651 let (task, _result_rx) = self.task_store.create_task(
653 &ctx.session.id,
654 ctx.request.clone(),
655 task_metadata.ttl,
656 );
657
658 let task_id = task.task_id.clone();
659
660 let task_store = self.task_store.clone();
662 let tool_registry = self.tool_registry.clone();
663 let session = ctx.session.clone();
664 let client_requester = self.create_client_requester(&ctx.session.id);
665 let tool_name = tool_name.to_string();
666
667 tokio::spawn(async move {
668 match tool_registry.call(&tool_name, tool_params, &session, client_requester).await {
670 Ok(content) => {
671 let content_values: Vec<Value> =
673 content.iter().map(|c| c.to_value()).collect();
674 let result = serde_json::json!({"content": content_values});
675
676 let _ = task_store.update_status(&task_id, crate::protocol::types::TaskStatus::Completed, None).await;
677 let _ = task_store.store_result(&task_id, result).await;
678 }
679 Err(e) => {
680 let error_message = e.to_string();
682 let _ = task_store
683 .update_status(
684 &task_id,
685 crate::protocol::types::TaskStatus::Failed,
686 Some(error_message.clone()),
687 )
688 .await;
689
690 let error_result = serde_json::json!({
692 "content": [{
693 "type": "text",
694 "text": error_message
695 }],
696 "isError": true
697 });
698 let _ = task_store.store_result(&task_id, error_result).await;
699 }
700 }
701 });
702
703 success_response(
705 ctx.request.id,
706 serde_json::to_value(crate::protocol::types::CreateTaskResult { task }).unwrap(),
707 )
708 }
709
710 async fn handle_resources_list(&self, ctx: RequestContext) -> JsonRpcResponse {
712 if let Err(err) = require_initialization(&ctx) {
713 return error_response(ctx.request.id, err.to_jsonrpc());
714 }
715
716 let visibility_ctx = VisibilityContext::new(&ctx.session);
717 let resources = self.resource_manager.list_for_session(&ctx.session, &visibility_ctx);
718 success_response(ctx.request.id, serde_json::json!({"resources": resources}))
719 }
720
721 async fn handle_resources_templates_list(&self, ctx: RequestContext) -> JsonRpcResponse {
723 if let Err(err) = require_initialization(&ctx) {
724 return error_response(ctx.request.id, err.to_jsonrpc());
725 }
726
727 let visibility_ctx = VisibilityContext::new(&ctx.session);
728 let templates = self.resource_manager.list_templates_for_session(&ctx.session, &visibility_ctx);
729 success_response(ctx.request.id, serde_json::json!({"resourceTemplates": templates}))
730 }
731
732 async fn handle_resources_read(&self, ctx: RequestContext) -> JsonRpcResponse {
734 if let Err(err) = require_initialization(&ctx) {
735 return error_response(ctx.request.id, err.to_jsonrpc());
736 }
737
738 let params = ctx.params().cloned().unwrap_or(Value::Null);
740 let uri = match params.get("uri").and_then(|v| v.as_str()) {
741 Some(uri) => uri,
742 None => {
743 return error_response(
744 ctx.request.id,
745 McpError::validation("invalid_params", "Missing 'uri' field").to_jsonrpc(),
746 )
747 }
748 };
749
750 match self.resource_manager.read(uri, std::collections::HashMap::new(), &ctx.session).await {
752 Ok(contents) => {
753 let content_values: Vec<Value> = contents
755 .iter()
756 .map(|c| c.to_value())
757 .collect();
758 success_response(ctx.request.id, serde_json::json!({"contents": content_values}))
759 }
760 Err(e) => error_response(
761 ctx.request.id,
762 McpError::internal("resource_read_failed", e.to_string()).to_jsonrpc(),
763 ),
764 }
765 }
766
767 async fn handle_prompts_list(&self, ctx: RequestContext) -> JsonRpcResponse {
769 if let Err(err) = require_initialization(&ctx) {
770 return error_response(ctx.request.id, err.to_jsonrpc());
771 }
772
773 let visibility_ctx = VisibilityContext::new(&ctx.session);
774 let prompts = self.prompt_manager.list_for_session(&ctx.session, &visibility_ctx);
775 success_response(ctx.request.id, serde_json::json!({"prompts": prompts}))
776 }
777
778 async fn handle_prompts_get(&self, ctx: RequestContext) -> JsonRpcResponse {
780 if let Err(err) = require_initialization(&ctx) {
781 return error_response(ctx.request.id, err.to_jsonrpc());
782 }
783
784 let params = ctx.params().cloned().unwrap_or(Value::Null);
786 let prompt_name = match params.get("name").and_then(|v| v.as_str()) {
787 Some(name) => name,
788 None => {
789 return error_response(
790 ctx.request.id,
791 McpError::validation("invalid_params", "Missing 'name' field").to_jsonrpc(),
792 )
793 }
794 };
795
796 let prompt_params = params.get("arguments").cloned().unwrap_or(Value::Null);
797
798 match self.prompt_manager.call(prompt_name, prompt_params, &ctx.session).await {
800 Ok(result) => {
801 success_response(ctx.request.id, serde_json::to_value(result).expect("Failed to serialize prompt result"))
802 }
803 Err(e) => error_response(
804 ctx.request.id,
805 McpError::internal("prompt_get_failed", e.to_string()).to_jsonrpc(),
806 ),
807 }
808 }
809
810 async fn handle_roots_list(&self, ctx: RequestContext) -> JsonRpcResponse {
816 if let Err(err) = require_initialization(&ctx) {
817 return error_response(ctx.request.id, err.to_jsonrpc());
818 }
819
820 use crate::protocol::types::ListRootsResult;
824
825 let result = ListRootsResult { roots: vec![] };
826
827 success_response(
828 ctx.request.id,
829 serde_json::to_value(result).expect("Failed to serialize roots list"),
830 )
831 }
832
833 async fn handle_sampling_create_message(&self, ctx: RequestContext) -> JsonRpcResponse {
839 if let Err(err) = require_initialization(&ctx) {
840 return error_response(ctx.request.id, err.to_jsonrpc());
841 }
842
843 error_response(
846 ctx.request.id,
847 McpError::not_implemented(
848 "sampling/createMessage is a client capability. Use ClientRequester.create_message() for server→client requests."
849 ).to_jsonrpc(),
850 )
851 }
852
853 async fn handle_elicitation_create(&self, ctx: RequestContext) -> JsonRpcResponse {
858 if let Err(err) = require_initialization(&ctx) {
859 return error_response(ctx.request.id, err.to_jsonrpc());
860 }
861
862 error_response(
865 ctx.request.id,
866 McpError::not_implemented(
867 "elicitation/create is a client capability. Use ClientRequester.create_elicitation() for server→client requests."
868 ).to_jsonrpc(),
869 )
870 }
871
872 async fn handle_tasks_get(&self, ctx: RequestContext) -> JsonRpcResponse {
874 if let Err(err) = require_initialization(&ctx) {
875 return error_response(ctx.request.id, err.to_jsonrpc());
876 }
877
878 let params: crate::protocol::types::GetTaskParams = match ctx.params() {
879 Some(p) => match serde_json::from_value(p.clone()) {
880 Ok(params) => params,
881 Err(_) => {
882 return error_response(
883 ctx.request.id,
884 McpError::validation("invalid_params", "Missing or invalid taskId")
885 .to_jsonrpc(),
886 )
887 }
888 },
889 None => {
890 return error_response(
891 ctx.request.id,
892 McpError::validation("invalid_params", "Missing taskId parameter")
893 .to_jsonrpc(),
894 )
895 }
896 };
897
898 match self
899 .task_store
900 .get_task_for_session(¶ms.task_id, &ctx.session.id)
901 .await
902 {
903 Some(task) => success_response(ctx.request.id, serde_json::to_value(task).unwrap()),
904 None => error_response(
905 ctx.request.id,
906 McpError::validation("invalid_params", "Task not found").to_jsonrpc(),
907 ),
908 }
909 }
910
911 async fn handle_tasks_result(&self, ctx: RequestContext) -> JsonRpcResponse {
913 if let Err(err) = require_initialization(&ctx) {
914 return error_response(ctx.request.id, err.to_jsonrpc());
915 }
916
917 let params: crate::protocol::types::GetTaskParams = match ctx.params() {
918 Some(p) => match serde_json::from_value(p.clone()) {
919 Ok(params) => params,
920 Err(_) => {
921 return error_response(
922 ctx.request.id,
923 McpError::validation("invalid_params", "Missing or invalid taskId")
924 .to_jsonrpc(),
925 )
926 }
927 },
928 None => {
929 return error_response(
930 ctx.request.id,
931 McpError::validation("invalid_params", "Missing taskId parameter")
932 .to_jsonrpc(),
933 )
934 }
935 };
936
937 if self
939 .task_store
940 .get_task_for_session(¶ms.task_id, &ctx.session.id)
941 .await
942 .is_none()
943 {
944 return error_response(
945 ctx.request.id,
946 McpError::validation("invalid_params", "Task not found").to_jsonrpc(),
947 );
948 }
949
950 match self
952 .task_store
953 .wait_for_result(¶ms.task_id, std::time::Duration::from_secs(300))
954 .await
955 {
956 Ok(result) => success_response(ctx.request.id, result),
957 Err(e) => error_response(
958 ctx.request.id,
959 McpError::internal("task_error", e.to_string()).to_jsonrpc(),
960 ),
961 }
962 }
963
964 async fn handle_tasks_list(&self, ctx: RequestContext) -> JsonRpcResponse {
966 if let Err(err) = require_initialization(&ctx) {
967 return error_response(ctx.request.id, err.to_jsonrpc());
968 }
969
970 let cursor = ctx
972 .params()
973 .and_then(|p| p.get("cursor"))
974 .and_then(|c| c.as_str());
975
976 let (tasks, next_cursor) = self
977 .task_store
978 .list_tasks(&ctx.session.id, cursor, 100)
979 .await;
980
981 success_response(
982 ctx.request.id,
983 serde_json::json!({
984 "tasks": tasks,
985 "nextCursor": next_cursor,
986 }),
987 )
988 }
989
990 async fn handle_tasks_cancel(&self, ctx: RequestContext) -> JsonRpcResponse {
992 if let Err(err) = require_initialization(&ctx) {
993 return error_response(ctx.request.id, err.to_jsonrpc());
994 }
995
996 let params: crate::protocol::types::CancelTaskParams = match ctx.params() {
997 Some(p) => match serde_json::from_value(p.clone()) {
998 Ok(params) => params,
999 Err(_) => {
1000 return error_response(
1001 ctx.request.id,
1002 McpError::validation("invalid_params", "Missing or invalid taskId")
1003 .to_jsonrpc(),
1004 )
1005 }
1006 },
1007 None => {
1008 return error_response(
1009 ctx.request.id,
1010 McpError::validation("invalid_params", "Missing taskId parameter")
1011 .to_jsonrpc(),
1012 )
1013 }
1014 };
1015
1016 match self
1017 .task_store
1018 .cancel_task(¶ms.task_id, &ctx.session.id)
1019 .await
1020 {
1021 Ok(task) => success_response(ctx.request.id, serde_json::to_value(task).unwrap()),
1022 Err(e) => {
1023 let error_msg = match e {
1024 crate::managers::task::TaskError::NotFound(_) => {
1025 McpError::validation("invalid_params", "Task not found")
1026 }
1027 crate::managers::task::TaskError::AlreadyTerminal(status) => {
1028 McpError::validation(
1029 "invalid_params",
1030 format!("Cannot cancel task: already in terminal status '{:?}'", status),
1031 )
1032 }
1033 _ => McpError::internal("task_error", e.to_string()),
1034 };
1035 error_response(ctx.request.id, error_msg.to_jsonrpc())
1036 }
1037 }
1038 }
1039}
1040
1041#[cfg(test)]
1042mod tests {
1043 use super::*;
1044
1045 #[tokio::test]
1046 async fn test_server_creation() {
1047 let server = Server::new("test-server", "1.0.0");
1048 assert_eq!(server.name(), "test-server");
1049 assert_eq!(server.version(), "1.0.0");
1050 }
1051
1052 #[tokio::test]
1053 async fn test_ping() {
1054 let server = Server::new("test-server", "1.0.0");
1055
1056 let request = JsonRpcRequest {
1057 jsonrpc: "2.0".to_string(),
1058 id: Some(Value::Number(1.into())),
1059 method: "ping".to_string(),
1060 params: None,
1061 };
1062
1063 let response = server.handle_request("test-session", request).await;
1064
1065 assert!(response.result.is_some());
1066 assert!(response.error.is_none());
1067 }
1068
1069 #[tokio::test]
1070 async fn test_initialize() {
1071 let server = Server::new("test-server", "1.0.0");
1072
1073 let request = JsonRpcRequest {
1074 jsonrpc: "2.0".to_string(),
1075 id: Some(Value::Number(1.into())),
1076 method: "initialize".to_string(),
1077 params: Some(serde_json::json!({
1078 "protocolVersion": "2025-11-25",
1079 "capabilities": {},
1080 "clientInfo": {
1081 "name": "test-client",
1082 "version": "1.0.0"
1083 }
1084 })),
1085 };
1086
1087 let response = server.handle_request("test-session", request).await;
1088
1089 assert!(response.result.is_some());
1090 assert!(response.error.is_none());
1091
1092 let session = server.get_session("test-session").unwrap();
1094 assert!(session.is_initialized());
1095 assert_eq!(session.client_info.unwrap().name, "test-client");
1096 }
1097
1098 #[tokio::test]
1099 async fn test_method_not_found() {
1100 let server = Server::new("test-server", "1.0.0");
1101
1102 let request = JsonRpcRequest {
1103 jsonrpc: "2.0".to_string(),
1104 id: Some(Value::Number(1.into())),
1105 method: "unknown/method".to_string(),
1106 params: None,
1107 };
1108
1109 let response = server.handle_request("test-session", request).await;
1110
1111 assert!(response.result.is_none());
1112 assert!(response.error.is_some());
1113 assert_eq!(response.error.unwrap().code, -32601);
1114 }
1115
1116 #[tokio::test]
1117 async fn test_requires_initialization() {
1118 let server = Server::new("test-server", "1.0.0");
1119
1120 let request = JsonRpcRequest {
1121 jsonrpc: "2.0".to_string(),
1122 id: Some(Value::Number(1.into())),
1123 method: "tools/list".to_string(),
1124 params: None,
1125 };
1126
1127 let response = server.handle_request("test-session", request.clone()).await;
1129 assert!(response.error.is_some());
1130
1131 let init_request = JsonRpcRequest {
1133 jsonrpc: "2.0".to_string(),
1134 id: Some(Value::Number(2.into())),
1135 method: "initialize".to_string(),
1136 params: Some(serde_json::json!({
1137 "protocolVersion": "2025-11-25",
1138 "capabilities": {},
1139 "clientInfo": {
1140 "name": "test-client",
1141 "version": "1.0.0"
1142 }
1143 })),
1144 };
1145 server.handle_request("test-session", init_request).await;
1146
1147 let response = server.handle_request("test-session", request).await;
1149 assert!(response.result.is_some());
1150 }
1151
1152 #[tokio::test]
1153 async fn test_session_management() {
1154 let server = Server::new("test-server", "1.0.0");
1155
1156 let request = JsonRpcRequest {
1158 jsonrpc: "2.0".to_string(),
1159 id: Some(Value::Number(1.into())),
1160 method: "ping".to_string(),
1161 params: None,
1162 };
1163 server.handle_request("session-1", request).await;
1164
1165 assert!(server.get_session("session-1").is_some());
1167
1168 let removed = server.remove_session("session-1");
1170 assert!(removed.is_some());
1171
1172 assert!(server.get_session("session-1").is_none());
1174 }
1175
1176 #[tokio::test]
1177 async fn test_capabilities_update() {
1178 let server = Server::new("test-server", "1.0.0");
1179
1180 let caps = ServerCapabilities {
1181 tools: Some(crate::protocol::capabilities::ToolsCapability {
1182 list_changed: Some(true),
1183 }),
1184 ..Default::default()
1185 };
1186
1187 server.set_capabilities(caps.clone());
1188
1189 let loaded_caps = server.capabilities();
1190 assert_eq!(loaded_caps.tools, caps.tools);
1191 }
1192
1193 async fn init_test_session(server: &Server, session_id: &str) {
1199 let request = JsonRpcRequest {
1200 jsonrpc: "2.0".to_string(),
1201 id: Some(Value::Number(1.into())),
1202 method: "initialize".to_string(),
1203 params: Some(serde_json::json!({
1204 "protocolVersion": "2025-11-25",
1205 "capabilities": {
1206 "tasks": {
1207 "list": {},
1208 "cancel": {},
1209 "requests": {
1210 "tools": {
1211 "call": {}
1212 }
1213 }
1214 }
1215 },
1216 "clientInfo": {
1217 "name": "test-client",
1218 "version": "1.0.0"
1219 }
1220 })),
1221 };
1222
1223 server.handle_request(session_id, request).await;
1224 }
1225
1226 struct TestTaskTool;
1228
1229 #[async_trait::async_trait]
1230 impl crate::registry::tools::Tool for TestTaskTool {
1231 fn name(&self) -> &str {
1232 "test_task"
1233 }
1234
1235 fn description(&self) -> Option<&str> {
1236 Some("Test tool for task execution")
1237 }
1238
1239 fn input_schema(&self) -> Value {
1240 serde_json::json!({
1241 "type": "object",
1242 "properties": {
1243 "message": {"type": "string"}
1244 }
1245 })
1246 }
1247
1248 fn execution(&self) -> Option<crate::protocol::types::ToolExecution> {
1249 Some(crate::protocol::types::ToolExecution {
1250 task_support: Some(crate::protocol::types::TaskSupport::Optional),
1251 })
1252 }
1253
1254 async fn execute(
1255 &self,
1256 ctx: crate::prelude::ExecutionContext<'_>,
1257 ) -> Result<Vec<Box<dyn crate::content::types::Content>>, crate::registry::tools::ToolError>
1258 {
1259 let msg = ctx
1260 .params
1261 .get("message")
1262 .and_then(|v| v.as_str())
1263 .unwrap_or("default");
1264
1265 Ok(vec![Box::new(crate::content::types::TextContent::new(
1266 format!("Processed: {}", msg),
1267 ))])
1268 }
1269 }
1270
1271 struct SlowTestTool;
1273
1274 #[async_trait::async_trait]
1275 impl crate::registry::tools::Tool for SlowTestTool {
1276 fn name(&self) -> &str {
1277 "slow_test"
1278 }
1279
1280 fn description(&self) -> Option<&str> {
1281 Some("Slow test tool")
1282 }
1283
1284 fn input_schema(&self) -> Value {
1285 serde_json::json!({"type": "object"})
1286 }
1287
1288 fn execution(&self) -> Option<crate::protocol::types::ToolExecution> {
1289 Some(crate::protocol::types::ToolExecution {
1290 task_support: Some(crate::protocol::types::TaskSupport::Optional),
1291 })
1292 }
1293
1294 async fn execute(
1295 &self,
1296 _ctx: crate::prelude::ExecutionContext<'_>,
1297 ) -> Result<Vec<Box<dyn crate::content::types::Content>>, crate::registry::tools::ToolError>
1298 {
1299 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
1300 Ok(vec![Box::new(crate::content::types::TextContent::new(
1301 "Slow operation complete",
1302 ))])
1303 }
1304 }
1305
1306 #[tokio::test]
1307 async fn test_task_augmented_tool_call() {
1308 let server = Server::new("test-server", "1.0.0");
1309 server.tool_registry().register(TestTaskTool);
1310
1311 init_test_session(&server, "test-session").await;
1312
1313 let request = JsonRpcRequest {
1315 jsonrpc: "2.0".to_string(),
1316 id: Some(Value::Number(2.into())),
1317 method: "tools/call".to_string(),
1318 params: Some(serde_json::json!({
1319 "name": "test_task",
1320 "arguments": {"message": "hello"},
1321 "task": {"ttl": 60000}
1322 })),
1323 };
1324
1325 let response = server.handle_request("test-session", request).await;
1326
1327 assert!(response.result.is_some());
1329 assert!(response.error.is_none());
1330
1331 let result = response.result.unwrap();
1332 assert!(result.get("task").is_some());
1333
1334 let task = result.get("task").unwrap();
1335 assert!(task.get("taskId").is_some());
1336 assert_eq!(task.get("status").unwrap().as_str().unwrap(), "working");
1337 assert!(task.get("createdAt").is_some());
1338 assert_eq!(task.get("ttl").unwrap().as_u64().unwrap(), 60000);
1339 }
1340
1341 #[tokio::test]
1342 async fn test_task_get_status() {
1343 let server = Server::new("test-server", "1.0.0");
1344 server.tool_registry().register(SlowTestTool);
1345
1346 init_test_session(&server, "test-session").await;
1347
1348 let create_request = JsonRpcRequest {
1350 jsonrpc: "2.0".to_string(),
1351 id: Some(Value::Number(2.into())),
1352 method: "tools/call".to_string(),
1353 params: Some(serde_json::json!({
1354 "name": "slow_test",
1355 "arguments": {},
1356 "task": {"ttl": 60000}
1357 })),
1358 };
1359
1360 let create_response = server.handle_request("test-session", create_request).await;
1361 let task_id = create_response.result.unwrap()["task"]["taskId"]
1362 .as_str()
1363 .unwrap()
1364 .to_string();
1365
1366 let get_request = JsonRpcRequest {
1368 jsonrpc: "2.0".to_string(),
1369 id: Some(Value::Number(3.into())),
1370 method: "tasks/get".to_string(),
1371 params: Some(serde_json::json!({"taskId": task_id})),
1372 };
1373
1374 let get_response = server.handle_request("test-session", get_request).await;
1375
1376 assert!(get_response.result.is_some());
1377 let result = get_response.result.unwrap();
1378 let status = result["status"].as_str().unwrap();
1379 assert!(status == "working" || status == "completed");
1380 }
1381
1382 #[tokio::test]
1383 async fn test_task_result_blocking() {
1384 let server = Server::new("test-server", "1.0.0");
1385 server.tool_registry().register(SlowTestTool);
1386
1387 init_test_session(&server, "test-session").await;
1388
1389 let create_request = JsonRpcRequest {
1391 jsonrpc: "2.0".to_string(),
1392 id: Some(Value::Number(2.into())),
1393 method: "tools/call".to_string(),
1394 params: Some(serde_json::json!({
1395 "name": "slow_test",
1396 "arguments": {},
1397 "task": {"ttl": 60000}
1398 })),
1399 };
1400
1401 let create_response = server.handle_request("test-session", create_request).await;
1402 let task_id = create_response.result.unwrap()["task"]["taskId"]
1403 .as_str()
1404 .unwrap()
1405 .to_string();
1406
1407 let result_request = JsonRpcRequest {
1409 jsonrpc: "2.0".to_string(),
1410 id: Some(Value::Number(3.into())),
1411 method: "tasks/result".to_string(),
1412 params: Some(serde_json::json!({"taskId": task_id})),
1413 };
1414
1415 let result_response = server.handle_request("test-session", result_request).await;
1416
1417 assert!(result_response.result.is_some());
1418 assert!(result_response.error.is_none());
1419
1420 let result = result_response.result.unwrap();
1422 assert!(result.get("content").is_some());
1423 }
1424
1425 #[tokio::test]
1426 async fn test_task_cancel() {
1427 let server = Server::new("test-server", "1.0.0");
1428 server.tool_registry().register(SlowTestTool);
1429
1430 init_test_session(&server, "test-session").await;
1431
1432 let create_request = JsonRpcRequest {
1434 jsonrpc: "2.0".to_string(),
1435 id: Some(Value::Number(2.into())),
1436 method: "tools/call".to_string(),
1437 params: Some(serde_json::json!({
1438 "name": "slow_test",
1439 "arguments": {},
1440 "task": {"ttl": 60000}
1441 })),
1442 };
1443
1444 let create_response = server.handle_request("test-session", create_request).await;
1445 let task_id = create_response.result.unwrap()["task"]["taskId"]
1446 .as_str()
1447 .unwrap()
1448 .to_string();
1449
1450 let cancel_request = JsonRpcRequest {
1452 jsonrpc: "2.0".to_string(),
1453 id: Some(Value::Number(3.into())),
1454 method: "tasks/cancel".to_string(),
1455 params: Some(serde_json::json!({"taskId": task_id})),
1456 };
1457
1458 let cancel_response = server.handle_request("test-session", cancel_request).await;
1459
1460 if cancel_response.result.is_some() {
1462 let result = cancel_response.result.unwrap();
1463 let status = result["status"].as_str().unwrap();
1464 assert_eq!(status, "cancelled");
1465 }
1466 }
1468
1469 #[tokio::test]
1470 async fn test_task_list() {
1471 let server = Server::new("test-server", "1.0.0");
1472 server.tool_registry().register(TestTaskTool);
1473
1474 init_test_session(&server, "test-session").await;
1475
1476 for i in 0..3 {
1478 let request = JsonRpcRequest {
1479 jsonrpc: "2.0".to_string(),
1480 id: Some(Value::Number((i + 2).into())),
1481 method: "tools/call".to_string(),
1482 params: Some(serde_json::json!({
1483 "name": "test_task",
1484 "arguments": {"message": format!("task-{}", i)},
1485 "task": {"ttl": 60000}
1486 })),
1487 };
1488 server.handle_request("test-session", request).await;
1489 }
1490
1491 let list_request = JsonRpcRequest {
1493 jsonrpc: "2.0".to_string(),
1494 id: Some(Value::Number(10.into())),
1495 method: "tasks/list".to_string(),
1496 params: None,
1497 };
1498
1499 let list_response = server.handle_request("test-session", list_request).await;
1500
1501 assert!(list_response.result.is_some());
1502 let result = list_response.result.unwrap();
1503 let tasks = result["tasks"].as_array().unwrap();
1504 assert!(tasks.len() >= 3);
1505 }
1506
1507 #[tokio::test]
1508 async fn test_task_session_isolation() {
1509 let server = Server::new("test-server", "1.0.0");
1510 server.tool_registry().register(TestTaskTool);
1511
1512 init_test_session(&server, "session-1").await;
1513 init_test_session(&server, "session-2").await;
1514
1515 let request = JsonRpcRequest {
1517 jsonrpc: "2.0".to_string(),
1518 id: Some(Value::Number(2.into())),
1519 method: "tools/call".to_string(),
1520 params: Some(serde_json::json!({
1521 "name": "test_task",
1522 "arguments": {"message": "private"},
1523 "task": {"ttl": 60000}
1524 })),
1525 };
1526
1527 let response = server.handle_request("session-1", request).await;
1528 let task_id = response.result.unwrap()["task"]["taskId"]
1529 .as_str()
1530 .unwrap()
1531 .to_string();
1532
1533 let get_request = JsonRpcRequest {
1535 jsonrpc: "2.0".to_string(),
1536 id: Some(Value::Number(3.into())),
1537 method: "tasks/get".to_string(),
1538 params: Some(serde_json::json!({"taskId": task_id})),
1539 };
1540
1541 let get_response = server.handle_request("session-2", get_request).await;
1542
1543 assert!(get_response.error.is_some());
1545 }
1546
1547 #[tokio::test]
1548 async fn test_task_not_found() {
1549 let server = Server::new("test-server", "1.0.0");
1550 init_test_session(&server, "test-session").await;
1551
1552 let request = JsonRpcRequest {
1553 jsonrpc: "2.0".to_string(),
1554 id: Some(Value::Number(2.into())),
1555 method: "tasks/get".to_string(),
1556 params: Some(serde_json::json!({"taskId": "nonexistent-task-id"})),
1557 };
1558
1559 let response = server.handle_request("test-session", request).await;
1560
1561 assert!(response.error.is_some());
1562 assert_eq!(response.error.unwrap().code, -32602);
1563 }
1564
1565 #[tokio::test]
1566 async fn test_task_double_cancel() {
1567 let server = Server::new("test-server", "1.0.0");
1568 server.tool_registry().register(SlowTestTool);
1569
1570 init_test_session(&server, "test-session").await;
1571
1572 let create_request = JsonRpcRequest {
1574 jsonrpc: "2.0".to_string(),
1575 id: Some(Value::Number(2.into())),
1576 method: "tools/call".to_string(),
1577 params: Some(serde_json::json!({
1578 "name": "slow_test",
1579 "arguments": {},
1580 "task": {"ttl": 60000}
1581 })),
1582 };
1583
1584 let create_response = server.handle_request("test-session", create_request).await;
1585 let task_id = create_response.result.unwrap()["task"]["taskId"]
1586 .as_str()
1587 .unwrap()
1588 .to_string();
1589
1590 let cancel_request = JsonRpcRequest {
1592 jsonrpc: "2.0".to_string(),
1593 id: Some(Value::Number(3.into())),
1594 method: "tasks/cancel".to_string(),
1595 params: Some(serde_json::json!({"taskId": task_id.clone()})),
1596 };
1597
1598 let _ = server.handle_request("test-session", cancel_request.clone()).await;
1599
1600 tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1602
1603 let cancel_request2 = JsonRpcRequest {
1605 jsonrpc: "2.0".to_string(),
1606 id: Some(Value::Number(4.into())),
1607 method: "tasks/cancel".to_string(),
1608 params: Some(serde_json::json!({"taskId": task_id})),
1609 };
1610
1611 let cancel_response2 = server.handle_request("test-session", cancel_request2).await;
1612
1613 assert!(cancel_response2.error.is_some());
1615 }
1616}