// mcp_host/server/core.rs

//! Server core implementation
//!
//! Main MCP server with method handlers and session management.
//!
//! # Bidirectional Communication
//!
//! The server supports bidirectional communication:
//! - Client→Server: tools/call, resources/read, prompts/get, etc.
//! - Server→Client: roots/list, sampling/createMessage, etc.
//!
//! Server-initiated requests are handled via the [`RequestMultiplexer`].

use std::collections::VecDeque;
use std::sync::Arc;
use std::time::Duration;

use arc_swap::ArcSwap;
use dashmap::DashMap;
use serde_json::Value;
use tokio::sync::mpsc;

use crate::content::types::Content;
use crate::protocol::capabilities::{InitializeRequest, InitializeResult, ServerCapabilities};
use crate::protocol::errors::{ErrorType, McpError};
use crate::protocol::methods::McpMethod;
use crate::protocol::types::{Implementation, JsonRpcRequest, JsonRpcResponse};
use crate::protocol::version;
use crate::registry::prompts::PromptManager;
use crate::registry::resources::ResourceManager;
use crate::registry::tools::ToolRegistry;
use crate::server::handler::{
    RequestContext, error_response, require_initialization, success_response,
};
use crate::server::middleware::MiddlewareChain;
use crate::server::multiplexer::{
    ClientRequester, CreateMessageParams, CreateMessageResult, JsonRpcClientRequest,
    ListRootsResult, MultiplexerError, RequestMultiplexer, Root,
};
use crate::server::session::Session;
use crate::server::visibility::VisibilityContext;
use crate::transport::traits::{IncomingMessage, JsonRpcNotification, Transport};

/// Main MCP server.
///
/// Holds the server identity (name, version, instructions), the advertised
/// capabilities, per-session state, the tool/resource/prompt registries, and the
/// channels that carry notifications and server→client requests to the transport
/// loop in [`Server::run`].
pub struct Server {
    /// Server name; returned in the initialize result's server info.
    name: String,

    /// Server version; returned in the initialize result's server info.
    version: String,

    /// Server instructions for LLMs (atomically swappable); returned in the
    /// initialize result.
    instructions: Arc<ArcSwap<Option<String>>>,

    /// Server capabilities (atomically swappable for dynamic updates); a copy of the
    /// current value is returned in the initialize result.
    capabilities: Arc<ArcSwap<ServerCapabilities>>,

    /// Active sessions keyed by session ID. Entries are created on the first request
    /// for an ID in `handle_request`.
    sessions: DashMap<String, Session>,

    /// Middleware chain every incoming request is passed through before dispatch.
    middleware: MiddlewareChain,

    /// Tool registry
    tool_registry: ToolRegistry,

    /// Resource manager
    resource_manager: ResourceManager,

    /// Prompt manager
    prompt_manager: PromptManager,

    /// Notification sender (for background tasks to send notifications). Clones are
    /// handed to sessions and to the logger.
    notification_tx: mpsc::UnboundedSender<JsonRpcNotification>,

    /// Notification receiver (internal). Locked by `run` for as long as it executes.
    notification_rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<JsonRpcNotification>>>,

    /// Global logger for server notifications; built on a clone of `notification_tx`.
    logger: crate::logging::McpLogger,

    /// Request multiplexer: tracks pending server→client requests and routes the
    /// client's responses back to their waiters.
    multiplexer: Arc<RequestMultiplexer>,

    /// Channel used to queue server→client requests for the transport. Used by
    /// `request_roots`/`request_sampling` and cloned into each `ClientRequester`.
    request_tx: mpsc::UnboundedSender<JsonRpcClientRequest>,

    /// Receiver for server→client requests. Locked by `run` for as long as it executes.
    request_rx: Arc<tokio::sync::Mutex<mpsc::UnboundedReceiver<JsonRpcClientRequest>>>,

    /// Task store backing task-augmented `tools/call` execution.
    task_store: Arc<crate::managers::task::TaskStore>,
}
92
93impl Server {
94    /// Create new server
95    pub fn new(name: impl Into<String>, version: impl Into<String>) -> Self {
96        let (notification_tx, notification_rx) = mpsc::unbounded_channel();
97        let (request_tx, request_rx) = mpsc::unbounded_channel();
98
99        // Create task store with 5-minute TTL and 5-second poll interval
100        let task_store = Arc::new(crate::managers::task::TaskStore::new(
101            std::time::Duration::from_secs(300),
102            std::time::Duration::from_secs(5),
103        ));
104
105        // Spawn cleanup task if in tokio runtime context (runs every minute)
106        if tokio::runtime::Handle::try_current().is_ok() {
107            task_store
108                .clone()
109                .spawn_cleanup_task(std::time::Duration::from_secs(60));
110        }
111
112        // Default capabilities with tasks support
113        let default_caps = ServerCapabilities {
114            tasks: Some(crate::protocol::capabilities::TasksCapability {
115                list: Some(crate::protocol::capabilities::EmptyObject {}),
116                cancel: Some(crate::protocol::capabilities::EmptyObject {}),
117                requests: Some(crate::protocol::capabilities::TasksRequestsCapability {
118                    tools: Some(crate::protocol::capabilities::TasksToolsCapability {
119                        call: Some(crate::protocol::capabilities::EmptyObject {}),
120                    }),
121                    ..Default::default()
122                }),
123            }),
124            ..Default::default()
125        };
126
127        // Create logger with default config
128        let logger = crate::logging::McpLogger::new(notification_tx.clone(), "mcp-server");
129
130        Self {
131            name: name.into(),
132            version: version.into(),
133            instructions: Arc::new(ArcSwap::new(Arc::new(None))),
134            capabilities: Arc::new(ArcSwap::new(Arc::new(default_caps))),
135            sessions: DashMap::new(),
136            middleware: MiddlewareChain::new(),
137            tool_registry: ToolRegistry::new(),
138            resource_manager: ResourceManager::new(),
139            prompt_manager: PromptManager::new(),
140            notification_tx,
141            notification_rx: Arc::new(tokio::sync::Mutex::new(notification_rx)),
142            logger,
143            multiplexer: Arc::new(RequestMultiplexer::new()),
144            request_tx,
145            request_rx: Arc::new(tokio::sync::Mutex::new(request_rx)),
146            task_store,
147        }
148    }
149
150    /// Get server name
151    pub fn name(&self) -> &str {
152        &self.name
153    }
154
155    /// Get server version
156    pub fn version(&self) -> &str {
157        &self.version
158    }
159
160    /// Get current capabilities
161    pub fn capabilities(&self) -> Arc<ServerCapabilities> {
162        self.capabilities.load_full()
163    }
164
165    /// Update server capabilities
166    pub fn set_capabilities(&self, capabilities: ServerCapabilities) {
167        self.capabilities.store(Arc::new(capabilities));
168    }
169
170    /// Get current instructions
171    pub fn instructions(&self) -> Option<String> {
172        (**self.instructions.load()).clone()
173    }
174
175    /// Set server instructions for LLMs
176    pub fn set_instructions(&self, instructions: Option<String>) {
177        self.instructions.store(Arc::new(instructions));
178    }
179
180    /// Add middleware to the chain
181    pub fn add_middleware(&mut self, middleware: crate::server::middleware::MiddlewareFn) {
182        self.middleware.add(middleware);
183    }
184
185    /// Get tool registry
186    pub fn tool_registry(&self) -> &ToolRegistry {
187        &self.tool_registry
188    }
189
190    /// Get resource manager
191    pub fn resource_manager(&self) -> &ResourceManager {
192        &self.resource_manager
193    }
194
195    /// Get prompt manager
196    pub fn prompt_manager(&self) -> &PromptManager {
197        &self.prompt_manager
198    }
199
200    /// Get global logger
201    pub fn logger(&self) -> &crate::logging::McpLogger {
202        &self.logger
203    }
204
205    /// Get notification sender (for background tasks)
206    pub fn notification_sender(&self) -> mpsc::UnboundedSender<JsonRpcNotification> {
207        self.notification_tx.clone()
208    }
209
210    /// Send a notification to the client
211    pub fn send_notification(
212        &self,
213        method: impl Into<String>,
214        params: Option<Value>,
215    ) -> Result<(), Box<dyn std::error::Error>> {
216        let notification = JsonRpcNotification::new(method, params);
217        self.notification_tx.send(notification)?;
218        Ok(())
219    }
220
221    /// Get session by ID
222    pub fn get_session(&self, session_id: &str) -> Option<Session> {
223        self.sessions.get(session_id).map(|s| s.clone())
224    }
225
226    /// Remove session
227    pub fn remove_session(&self, session_id: &str) -> Option<Session> {
228        self.sessions.remove(session_id).map(|(_, s)| s)
229    }
230
231    /// Get the request multiplexer (for advanced use cases)
232    pub fn multiplexer(&self) -> Arc<RequestMultiplexer> {
233        self.multiplexer.clone()
234    }
235
236    /// Create a client requester for the given session
237    ///
238    /// The client requester allows tools to make server→client requests
239    /// like roots/list and sampling/createMessage.
240    pub fn create_client_requester(&self, session_id: &str) -> Option<ClientRequester> {
241        let session = self.get_session(session_id)?;
242        let caps = session.capabilities.as_ref()?;
243
244        Some(ClientRequester::new(
245            self.request_tx.clone(),
246            self.multiplexer.clone(),
247            caps.roots.is_some(),
248            caps.sampling.is_some(),
249        ))
250    }
251
252    /// Request workspace roots from the client
253    ///
254    /// Sends a `roots/list` request to the client and waits for the response.
255    /// The client must have the `roots` capability advertised.
256    ///
257    /// # Arguments
258    ///
259    /// * `session_id` - The session to check for roots capability
260    /// * `timeout` - Optional timeout (defaults to 30 seconds)
261    ///
262    /// # Returns
263    ///
264    /// List of workspace roots, or an error if:
265    /// - Client doesn't support roots capability
266    /// - Request times out
267    /// - Client returns an error
268    ///
269    /// # Example
270    ///
271    /// ```rust,ignore
272    /// let roots = server.request_roots("session-123", None).await?;
273    /// for root in roots {
274    ///     println!("Root: {} ({})", root.name.unwrap_or_default(), root.uri);
275    /// }
276    /// ```
277    pub async fn request_roots(
278        &self,
279        session_id: &str,
280        timeout: Option<Duration>,
281    ) -> Result<Vec<Root>, MultiplexerError> {
282        // Check if client supports roots
283        if let Some(session) = self.get_session(session_id) {
284            if let Some(caps) = &session.capabilities {
285                if caps.roots.is_none() {
286                    return Err(MultiplexerError::UnsupportedCapability("roots".to_string()));
287                }
288            } else {
289                return Err(MultiplexerError::UnsupportedCapability("roots".to_string()));
290            }
291        } else {
292            return Err(MultiplexerError::Transport("session not found".to_string()));
293        }
294
295        // Create pending request
296        let (id, rx) = self.multiplexer.create_pending("roots/list");
297
298        // Build and send the request
299        let request = JsonRpcClientRequest::new(&id, "roots/list", Some(serde_json::json!({})));
300
301        self.request_tx
302            .send(request)
303            .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
304
305        // Wait for response with timeout
306        let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
307        let result = tokio::time::timeout(timeout, rx)
308            .await
309            .map_err(|_| MultiplexerError::Timeout(timeout))?
310            .map_err(|_| MultiplexerError::ChannelClosed)??;
311
312        // Parse the result
313        let list_result: ListRootsResult = serde_json::from_value(result)?;
314        Ok(list_result.roots)
315    }
316
317    /// Request an LLM completion from the client
318    ///
319    /// Sends a `sampling/createMessage` request to the client.
320    /// The client must have the `sampling` capability advertised.
321    ///
322    /// # Arguments
323    ///
324    /// * `session_id` - The session to check for sampling capability
325    /// * `params` - The sampling parameters
326    /// * `timeout` - Optional timeout (defaults to 30 seconds)
327    ///
328    /// # Returns
329    ///
330    /// The completion result, or an error if:
331    /// - Client doesn't support sampling capability
332    /// - Request times out
333    /// - Client returns an error
334    ///
335    /// # Example
336    ///
337    /// ```rust,ignore
338    /// use mcp_host::server::multiplexer::{CreateMessageParams, SamplingContent, SamplingMessage};
339    ///
340    /// let params = CreateMessageParams {
341    ///     messages: vec![SamplingMessage {
342    ///         role: "user".to_string(),
343    ///         content: SamplingContent::Text { text: "Hello!".to_string() },
344    ///     }],
345    ///     max_tokens: 1000,
346    ///     ..Default::default()
347    /// };
348    ///
349    /// let result = server.request_sampling("session-123", params, None).await?;
350    /// println!("Response: {:?}", result.content);
351    /// ```
352    pub async fn request_sampling(
353        &self,
354        session_id: &str,
355        params: CreateMessageParams,
356        timeout: Option<Duration>,
357    ) -> Result<CreateMessageResult, MultiplexerError> {
358        // Check if client supports sampling
359        if let Some(session) = self.get_session(session_id) {
360            if let Some(caps) = &session.capabilities {
361                if caps.sampling.is_none() {
362                    return Err(MultiplexerError::UnsupportedCapability(
363                        "sampling".to_string(),
364                    ));
365                }
366            } else {
367                return Err(MultiplexerError::UnsupportedCapability(
368                    "sampling".to_string(),
369                ));
370            }
371        } else {
372            return Err(MultiplexerError::Transport("session not found".to_string()));
373        }
374
375        // Create pending request
376        let (id, rx) = self.multiplexer.create_pending("sampling/createMessage");
377
378        // Build and send the request
379        let params_value = serde_json::to_value(&params)?;
380        let request = JsonRpcClientRequest::new(&id, "sampling/createMessage", Some(params_value));
381
382        self.request_tx
383            .send(request)
384            .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
385
386        // Wait for response with timeout
387        let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
388        let result = tokio::time::timeout(timeout, rx)
389            .await
390            .map_err(|_| MultiplexerError::Timeout(timeout))?
391            .map_err(|_| MultiplexerError::ChannelClosed)??;
392
393        // Parse the result
394        let create_result: CreateMessageResult = serde_json::from_value(result)?;
395        Ok(create_result)
396    }
397
398    /// Run server with the given transport
399    ///
400    /// This is the main event loop that reads requests from the transport,
401    /// processes them, and writes responses back.
402    ///
403    /// Supports bidirectional communication:
404    /// - Incoming client requests are handled and responses sent
405    /// - Server-initiated requests are sent via channels and responses routed back
406    pub async fn run<T: Transport>(
407        &self,
408        mut transport: T,
409    ) -> Result<(), Box<dyn std::error::Error>> {
410        // Generate a unique session ID per MCP spec (globally unique, cryptographically secure)
411        let session_id = uuid::Uuid::new_v4().to_string();
412
413        // Get receivers
414        let mut notification_rx = self.notification_rx.lock().await;
415        let mut request_rx = self.request_rx.lock().await;
416
417        loop {
418            tokio::select! {
419                // Handle outgoing notifications
420                Some(notification) = notification_rx.recv() => {
421                    tracing::debug!(method = %notification.method, "Sending notification");
422                    if let Err(e) = transport.send_notification(notification).await {
423                        tracing::error!(error = %e, "Failed to send notification");
424                    }
425                }
426
427                // Handle outgoing server→client requests
428                Some(request) = request_rx.recv() => {
429                    tracing::debug!(method = %request.method, id = %request.id, "Sending request to client");
430                    if let Err(e) = transport.send_request(request).await {
431                        tracing::error!(error = %e, "Failed to send request to client");
432                    }
433                }
434
435                // Handle incoming messages (requests or responses)
436                result = transport.read_incoming() => {
437                    match result {
438                        Ok(IncomingMessage::Request(request)) => {
439                            // Check if this is a notification (no ID) or a request (has ID)
440                            let is_notification = request.id.is_none();
441
442                            // Handle client request
443                            let response = self.handle_request(&session_id, request).await;
444
445                            // Only write response if this was NOT a notification
446                            // Notifications must NOT receive responses per JSON-RPC spec
447                            if !is_notification
448                                && let Err(e) = transport.write_message(response).await {
449                                    tracing::error!(error = %e, "Failed to write message");
450                                    break;
451                                }
452                        }
453                        Ok(IncomingMessage::Response(response)) => {
454                            // Route response to pending server-initiated request
455                            if !self.multiplexer.route_response(&response) {
456                                tracing::warn!(
457                                    id = ?response.id,
458                                    "Received response for unknown request ID"
459                                );
460                            }
461                        }
462                        Err(crate::transport::traits::TransportError::Closed) => {
463                            tracing::info!("Transport closed, shutting down");
464                            break;
465                        }
466                        Err(e) => {
467                            tracing::error!(error = %e, "Failed to read message");
468                            continue;
469                        }
470                    }
471                }
472            }
473        }
474
475        // Cancel any pending requests
476        self.multiplexer.cancel_all();
477
478        // Shutdown transport gracefully
479        transport.shutdown().await?;
480
481        Ok(())
482    }
483
484    /// Handle incoming JSON-RPC request
485    pub async fn handle_request(
486        &self,
487        session_id: &str,
488        request: JsonRpcRequest,
489    ) -> JsonRpcResponse {
490        // Get or create session
491        let session = self
492            .sessions
493            .entry(session_id.to_string())
494            .or_insert_with(|| {
495                let mut session = Session::with_id(session_id);
496                session.set_notification_channel(self.notification_tx.clone());
497                session
498            })
499            .clone();
500
501        // Create request context
502        let ctx = RequestContext::new(session, request.clone());
503
504        // Process through middleware chain
505        let ctx = match self.middleware.process(ctx) {
506            Ok(ctx) => ctx,
507            Err(err) => return error_response(request.id, err.to_jsonrpc()),
508        };
509
510        // Route to appropriate handler
511        let method = McpMethod::from(request.method.clone());
512
513        match method {
514            McpMethod::Initialize => self.handle_initialize(ctx).await,
515            McpMethod::Ping => self.handle_ping(ctx).await,
516            McpMethod::LoggingSetLevel => self.handle_logging_set_level(ctx).await,
517            McpMethod::ToolsList => self.handle_tools_list(ctx).await,
518            McpMethod::ToolsCall => self.handle_tools_call(ctx).await,
519            McpMethod::ResourcesList => self.handle_resources_list(ctx).await,
520            McpMethod::ResourcesTemplatesList => self.handle_resources_templates_list(ctx).await,
521            McpMethod::ResourcesRead => self.handle_resources_read(ctx).await,
522            McpMethod::PromptsList => self.handle_prompts_list(ctx).await,
523            McpMethod::PromptsGet => self.handle_prompts_get(ctx).await,
524            McpMethod::RootsList => self.handle_roots_list(ctx).await,
525            McpMethod::SamplingCreateMessage => self.handle_sampling_create_message(ctx).await,
526            McpMethod::ElicitationCreate => self.handle_elicitation_create(ctx).await,
527            McpMethod::TasksGet => self.handle_tasks_get(ctx).await,
528            McpMethod::TasksResult => self.handle_tasks_result(ctx).await,
529            McpMethod::TasksList => self.handle_tasks_list(ctx).await,
530            McpMethod::TasksCancel => self.handle_tasks_cancel(ctx).await,
531            _ => error_response(
532                request.id,
533                McpError::method_not_found(&request.method).to_jsonrpc(),
534            ),
535        }
536    }
537
538    /// Handle initialize request
539    async fn handle_initialize(&self, ctx: RequestContext) -> JsonRpcResponse {
540        // Parse initialize params
541        let params = ctx.params().cloned().unwrap_or(Value::Null);
542        let req: InitializeRequest = match serde_json::from_value(params) {
543            Ok(req) => req,
544            Err(_) => {
545                return error_response(
546                    ctx.request.id,
547                    McpError::validation("invalid_params", "Invalid initialize parameters")
548                        .to_jsonrpc(),
549                );
550            }
551        };
552
553        // Negotiate protocol version
554        let protocol_version = match version::negotiate_protocol_version(&req.protocol_version) {
555            Ok(version) => version,
556            Err(supported_versions) => {
557                tracing::warn!(
558                    client = %req.client_info.name,
559                    requested = %req.protocol_version,
560                    supported = ?supported_versions,
561                    "Unsupported protocol version"
562                );
563                return error_response(
564                    ctx.request.id,
565                    McpError::builder(ErrorType::Validation, "unsupported_protocol_version")
566                        .message("Unsupported protocol version")
567                        .detail(
568                            "supported",
569                            serde_json::to_value(&supported_versions).unwrap(),
570                        )
571                        .detail("requested", req.protocol_version.clone())
572                        .build()
573                        .to_jsonrpc(),
574                );
575            }
576        };
577
578        // Log client connection
579        tracing::info!(
580            client = %req.client_info.name,
581            version = %req.client_info.version,
582            protocol = %protocol_version,
583            "Client connected"
584        );
585
586        // Initialize session
587        if let Some(mut session) = self.sessions.get_mut(&ctx.session.id) {
588            session.initialize(req.client_info, req.capabilities, protocol_version.clone());
589        }
590
591        // Build response
592        let result = InitializeResult {
593            protocol_version,
594            capabilities: (**self.capabilities.load()).clone(),
595            server_info: Implementation {
596                name: self.name.clone(),
597                version: self.version.clone(),
598            },
599            instructions: self.instructions(),
600        };
601
602        success_response(
603            ctx.request.id,
604            serde_json::to_value(result).expect("Failed to serialize InitializeResult"),
605        )
606    }
607
608    /// Handle ping request
609    async fn handle_ping(&self, ctx: RequestContext) -> JsonRpcResponse {
610        success_response(ctx.request.id, serde_json::json!({}))
611    }
612
613    /// Handle logging/setLevel request
614    async fn handle_logging_set_level(&self, ctx: RequestContext) -> JsonRpcResponse {
615        use crate::logging::LogLevel;
616        use crate::protocol::types::SetLevelRequest;
617
618        // Parse request params
619        let params = ctx.params().cloned().unwrap_or(Value::Null);
620        let req: SetLevelRequest = match serde_json::from_value(params) {
621            Ok(req) => req,
622            Err(_) => {
623                return error_response(
624                    ctx.request.id,
625                    McpError::validation("invalid_params", "Invalid setLevel parameters")
626                        .to_jsonrpc(),
627                );
628            }
629        };
630
631        // Parse log level
632        let level = match req.level.parse::<LogLevel>() {
633            Ok(level) => level,
634            Err(_) => {
635                return error_response(
636                    ctx.request.id,
637                    McpError::validation(
638                        "invalid_level",
639                        format!(
640                            "Invalid log level '{}'. Valid levels: debug, info, notice, warning, error, critical, alert, emergency",
641                            req.level
642                        ),
643                    )
644                    .to_jsonrpc(),
645                )
646            }
647        };
648
649        // Update logger min level
650        self.logger.set_min_level(level);
651
652        tracing::debug!(level = %req.level, "Log level updated");
653
654        success_response(ctx.request.id, serde_json::json!({}))
655    }
656
657    /// Handle tools/list request
658    async fn handle_tools_list(&self, ctx: RequestContext) -> JsonRpcResponse {
659        if let Err(err) = require_initialization(&ctx) {
660            return error_response(ctx.request.id, err.to_jsonrpc());
661        }
662
663        let visibility_ctx = VisibilityContext::new(&ctx.session);
664        let tools = self
665            .tool_registry
666            .list_for_session(&ctx.session, &visibility_ctx);
667        success_response(ctx.request.id, serde_json::json!({"tools": tools}))
668    }
669
670    /// Handle tools/call request
671    async fn handle_tools_call(&self, ctx: RequestContext) -> JsonRpcResponse {
672        if let Err(err) = require_initialization(&ctx) {
673            return error_response(ctx.request.id, err.to_jsonrpc());
674        }
675
676        // Parse tool call params
677        let params = ctx.params().cloned().unwrap_or(Value::Null);
678        let tool_name = match params.get("name").and_then(|v| v.as_str()) {
679            Some(name) => name,
680            None => {
681                return error_response(
682                    ctx.request.id,
683                    McpError::validation("invalid_params", "Missing 'name' field").to_jsonrpc(),
684                );
685            }
686        };
687
688        let tool_params = params.get("arguments").cloned().unwrap_or(Value::Null);
689
690        // Check for task-augmented execution
691        let task_meta: Option<crate::protocol::types::TaskMetadata> = params
692            .get("task")
693            .and_then(|t| serde_json::from_value(t.clone()).ok());
694
695        if let Some(task_metadata) = task_meta {
696            // Task-augmented execution
697            return self
698                .handle_task_augmented_tool_call(ctx, tool_name, tool_params, task_metadata)
699                .await;
700        }
701
702        // Regular synchronous execution
703        let client_requester = self.create_client_requester(&ctx.session.id);
704
705        match self
706            .tool_registry
707            .call(
708                tool_name,
709                tool_params,
710                &ctx.session,
711                &self.logger,
712                client_requester,
713            )
714            .await
715        {
716            Ok(content) => {
717                let content_values: Vec<Value> = content.iter().map(|c| c.to_value()).collect();
718                success_response(
719                    ctx.request.id,
720                    serde_json::json!({"content": content_values}),
721                )
722            }
723            Err(e) => error_response(
724                ctx.request.id,
725                McpError::internal("tool_execution_failed", e.to_string()).to_jsonrpc(),
726            ),
727        }
728    }

    /// Handle a task-augmented `tools/call`.
    ///
    /// Creates a task for this session in the task store, using the TTL from
    /// `task_metadata`. The tool is run on a spawned Tokio task, and the response
    /// (a `CreateTaskResult` wrapping the new task) is returned immediately.
    ///
    /// When the tool finishes, the background task updates the store:
    /// - On success: marks the task `Completed` and stores `{"content": [...]}`.
    /// - On failure: marks it `Failed` with the error message, and stores
    ///   `{"content": [{"type": "text", "text": <message>}], "isError": true}`.
    ///
    /// The return values of these task-store calls are discarded (`let _ =`).
    async fn handle_task_augmented_tool_call(
        &self,
        ctx: RequestContext,
        tool_name: &str,
        tool_params: Value,
        task_metadata: crate::protocol::types::TaskMetadata,
    ) -> JsonRpcResponse {
        // Create task. The receiver returned alongside it is unused and is dropped when
        // this function returns.
        let (task, _result_rx) =
            self.task_store
                .create_task(&ctx.session.id, ctx.request.clone(), task_metadata.ttl);

        let task_id = task.task_id.clone();

        // Spawn background execution. Everything the spawned future uses is cloned or
        // owned, so it can outlive this handler.
        let task_store = self.task_store.clone();
        let tool_registry = self.tool_registry.clone();
        let logger = self.logger.clone();
        let session = ctx.session.clone();
        let client_requester = self.create_client_requester(&ctx.session.id);
        let tool_name = tool_name.to_string();

        tokio::spawn(async move {
            // Execute tool
            match tool_registry
                .call(&tool_name, tool_params, &session, &logger, client_requester)
                .await
            {
                Ok(content) => {
                    // Success - store result
                    let content_values: Vec<Value> = content.iter().map(|c| c.to_value()).collect();
                    let result = serde_json::json!({"content": content_values});

                    // NOTE(review): the status becomes Completed *before* the result is
                    // stored. A client that sees Completed via tasks/get and immediately
                    // calls tasks/result may race the store_result below. Confirm whether
                    // TaskStore tolerates this order, or whether the two calls should swap.
                    let _ = task_store
                        .update_status(
                            &task_id,
                            crate::protocol::types::TaskStatus::Completed,
                            None,
                        )
                        .await;
                    let _ = task_store.store_result(&task_id, result).await;
                }
                Err(e) => {
                    // Failure - store error
                    let error_message = e.to_string();
                    let _ = task_store
                        .update_status(
                            &task_id,
                            crate::protocol::types::TaskStatus::Failed,
                            Some(error_message.clone()),
                        )
                        .await;

                    // Store error as result (for tasks/result), in tool-result shape with
                    // `isError: true` rather than as a JSON-RPC error.
                    let error_result = serde_json::json!({
                        "content": [{
                            "type": "text",
                            "text": error_message
                        }],
                        "isError": true
                    });
                    let _ = task_store.store_result(&task_id, error_result).await;
                }
            }
        });

        // Return immediately with task info
        success_response(
            ctx.request.id,
            serde_json::to_value(crate::protocol::types::CreateTaskResult { task }).unwrap(),
        )
    }
803
804    /// Handle resources/list request
805    async fn handle_resources_list(&self, ctx: RequestContext) -> JsonRpcResponse {
806        if let Err(err) = require_initialization(&ctx) {
807            return error_response(ctx.request.id, err.to_jsonrpc());
808        }
809
810        let visibility_ctx = VisibilityContext::new(&ctx.session);
811        let resources = self
812            .resource_manager
813            .list_for_session(&ctx.session, &visibility_ctx);
814        success_response(ctx.request.id, serde_json::json!({"resources": resources}))
815    }
816
817    /// Handle resources/templates/list request
818    async fn handle_resources_templates_list(&self, ctx: RequestContext) -> JsonRpcResponse {
819        if let Err(err) = require_initialization(&ctx) {
820            return error_response(ctx.request.id, err.to_jsonrpc());
821        }
822
823        let visibility_ctx = VisibilityContext::new(&ctx.session);
824        let templates = self
825            .resource_manager
826            .list_templates_for_session(&ctx.session, &visibility_ctx);
827        success_response(
828            ctx.request.id,
829            serde_json::json!({"resourceTemplates": templates}),
830        )
831    }
832
833    /// Handle resources/read request
834    async fn handle_resources_read(&self, ctx: RequestContext) -> JsonRpcResponse {
835        if let Err(err) = require_initialization(&ctx) {
836            return error_response(ctx.request.id, err.to_jsonrpc());
837        }
838
839        // Parse resource read params
840        let params = ctx.params().cloned().unwrap_or(Value::Null);
841        let uri = match params.get("uri").and_then(|v| v.as_str()) {
842            Some(uri) => uri,
843            None => {
844                return error_response(
845                    ctx.request.id,
846                    McpError::validation("invalid_params", "Missing 'uri' field").to_jsonrpc(),
847                );
848            }
849        };
850
851        // Read resource - ResourceContent already includes uri/mimeType
852        match self
853            .resource_manager
854            .read(
855                uri,
856                std::collections::HashMap::new(),
857                &ctx.session,
858                &self.logger,
859            )
860            .await
861        {
862            Ok(contents) => {
863                // ResourceContent.to_value() already produces MCP-compliant JSON with uri/mimeType
864                let content_values: Vec<Value> = contents.iter().map(|c| c.to_value()).collect();
865                success_response(
866                    ctx.request.id,
867                    serde_json::json!({"contents": content_values}),
868                )
869            }
870            Err(e) => error_response(
871                ctx.request.id,
872                McpError::internal("resource_read_failed", e.to_string()).to_jsonrpc(),
873            ),
874        }
875    }
876
877    /// Handle prompts/list request
878    async fn handle_prompts_list(&self, ctx: RequestContext) -> JsonRpcResponse {
879        if let Err(err) = require_initialization(&ctx) {
880            return error_response(ctx.request.id, err.to_jsonrpc());
881        }
882
883        let visibility_ctx = VisibilityContext::new(&ctx.session);
884        let prompts = self
885            .prompt_manager
886            .list_for_session(&ctx.session, &visibility_ctx);
887        success_response(ctx.request.id, serde_json::json!({"prompts": prompts}))
888    }
889
890    /// Handle prompts/get request
891    async fn handle_prompts_get(&self, ctx: RequestContext) -> JsonRpcResponse {
892        if let Err(err) = require_initialization(&ctx) {
893            return error_response(ctx.request.id, err.to_jsonrpc());
894        }
895
896        // Parse prompt get params
897        let params = ctx.params().cloned().unwrap_or(Value::Null);
898        let prompt_name = match params.get("name").and_then(|v| v.as_str()) {
899            Some(name) => name,
900            None => {
901                return error_response(
902                    ctx.request.id,
903                    McpError::validation("invalid_params", "Missing 'name' field").to_jsonrpc(),
904                );
905            }
906        };
907
908        let prompt_params = params.get("arguments").cloned().unwrap_or(Value::Null);
909
910        // Call prompt
911        match self
912            .prompt_manager
913            .call(prompt_name, prompt_params, &ctx.session, &self.logger)
914            .await
915        {
916            Ok(result) => success_response(
917                ctx.request.id,
918                serde_json::to_value(result).expect("Failed to serialize prompt result"),
919            ),
920            Err(e) => error_response(
921                ctx.request.id,
922                McpError::internal("prompt_get_failed", e.to_string()).to_jsonrpc(),
923            ),
924        }
925    }
926
927    /// Handle roots/list request
928    ///
929    /// NOTE: This is a CLIENT capability - the server requests roots FROM the client.
930    /// This handler would only be used if mcphost-rs is acting as a client.
931    /// For server→client root requests, use `request_roots()` instead.
932    async fn handle_roots_list(&self, ctx: RequestContext) -> JsonRpcResponse {
933        if let Err(err) = require_initialization(&ctx) {
934            return error_response(ctx.request.id, err.to_jsonrpc());
935        }
936
937        // If mcphost-rs is acting as a server (normal case), we don't have roots to provide
938        // The client provides roots TO us, not the other way around
939        // Return empty list for now
940        use crate::protocol::types::ListRootsResult;
941
942        let result = ListRootsResult { roots: vec![] };
943
944        success_response(
945            ctx.request.id,
946            serde_json::to_value(result).expect("Failed to serialize roots list"),
947        )
948    }
949
950    /// Handle sampling/createMessage request
951    ///
952    /// NOTE: This is a CLIENT capability - the server requests LLM completions FROM the client.
953    /// This is for when the MCP server needs the client to generate LLM responses.
954    /// This is NOT for the server to generate responses itself.
955    async fn handle_sampling_create_message(&self, ctx: RequestContext) -> JsonRpcResponse {
956        if let Err(err) = require_initialization(&ctx) {
957            return error_response(ctx.request.id, err.to_jsonrpc());
958        }
959
960        // Servers don't generate LLM responses, clients do
961        // This would only be used if mcphost-rs acts as a client
962        error_response(
963            ctx.request.id,
964            McpError::not_implemented(
965                "sampling/createMessage is a client capability. Use ClientRequester.create_message() for server→client requests."
966            ).to_jsonrpc(),
967        )
968    }
969
970    /// Handle elicitation/create request
971    ///
972    /// NOTE: This is a CLIENT capability - the server requests user input FROM the client.
973    /// This is for when the MCP server needs the client to prompt the user for structured input.
974    async fn handle_elicitation_create(&self, ctx: RequestContext) -> JsonRpcResponse {
975        if let Err(err) = require_initialization(&ctx) {
976            return error_response(ctx.request.id, err.to_jsonrpc());
977        }
978
979        // Servers don't prompt users directly, clients do
980        // This would only be used if mcphost-rs acts as a client
981        error_response(
982            ctx.request.id,
983            McpError::not_implemented(
984                "elicitation/create is a client capability. Use ClientRequester.create_elicitation() for server→client requests."
985            ).to_jsonrpc(),
986        )
987    }
988
989    /// Handle tasks/get request
990    async fn handle_tasks_get(&self, ctx: RequestContext) -> JsonRpcResponse {
991        if let Err(err) = require_initialization(&ctx) {
992            return error_response(ctx.request.id, err.to_jsonrpc());
993        }
994
995        let params: crate::protocol::types::GetTaskParams = match ctx.params() {
996            Some(p) => match serde_json::from_value(p.clone()) {
997                Ok(params) => params,
998                Err(_) => {
999                    return error_response(
1000                        ctx.request.id,
1001                        McpError::validation("invalid_params", "Missing or invalid taskId")
1002                            .to_jsonrpc(),
1003                    );
1004                }
1005            },
1006            None => {
1007                return error_response(
1008                    ctx.request.id,
1009                    McpError::validation("invalid_params", "Missing taskId parameter").to_jsonrpc(),
1010                );
1011            }
1012        };
1013
1014        match self
1015            .task_store
1016            .get_task_for_session(&params.task_id, &ctx.session.id)
1017            .await
1018        {
1019            Some(task) => success_response(ctx.request.id, serde_json::to_value(task).unwrap()),
1020            None => error_response(
1021                ctx.request.id,
1022                McpError::validation("invalid_params", "Task not found").to_jsonrpc(),
1023            ),
1024        }
1025    }
1026
1027    /// Handle tasks/result request (blocks until terminal state)
1028    async fn handle_tasks_result(&self, ctx: RequestContext) -> JsonRpcResponse {
1029        if let Err(err) = require_initialization(&ctx) {
1030            return error_response(ctx.request.id, err.to_jsonrpc());
1031        }
1032
1033        let params: crate::protocol::types::GetTaskParams = match ctx.params() {
1034            Some(p) => match serde_json::from_value(p.clone()) {
1035                Ok(params) => params,
1036                Err(_) => {
1037                    return error_response(
1038                        ctx.request.id,
1039                        McpError::validation("invalid_params", "Missing or invalid taskId")
1040                            .to_jsonrpc(),
1041                    );
1042                }
1043            },
1044            None => {
1045                return error_response(
1046                    ctx.request.id,
1047                    McpError::validation("invalid_params", "Missing taskId parameter").to_jsonrpc(),
1048                );
1049            }
1050        };
1051
1052        // Verify session owns task
1053        if self
1054            .task_store
1055            .get_task_for_session(&params.task_id, &ctx.session.id)
1056            .await
1057            .is_none()
1058        {
1059            return error_response(
1060                ctx.request.id,
1061                McpError::validation("invalid_params", "Task not found").to_jsonrpc(),
1062            );
1063        }
1064
1065        // Wait for result (5-minute timeout)
1066        match self
1067            .task_store
1068            .wait_for_result(&params.task_id, std::time::Duration::from_secs(300))
1069            .await
1070        {
1071            Ok(result) => success_response(ctx.request.id, result),
1072            Err(e) => error_response(
1073                ctx.request.id,
1074                McpError::internal("task_error", e.to_string()).to_jsonrpc(),
1075            ),
1076        }
1077    }
1078
1079    /// Handle tasks/list request
1080    async fn handle_tasks_list(&self, ctx: RequestContext) -> JsonRpcResponse {
1081        if let Err(err) = require_initialization(&ctx) {
1082            return error_response(ctx.request.id, err.to_jsonrpc());
1083        }
1084
1085        // Parse cursor if present
1086        let cursor = ctx
1087            .params()
1088            .and_then(|p| p.get("cursor"))
1089            .and_then(|c| c.as_str());
1090
1091        let (tasks, next_cursor) = self
1092            .task_store
1093            .list_tasks(&ctx.session.id, cursor, 100)
1094            .await;
1095
1096        success_response(
1097            ctx.request.id,
1098            serde_json::json!({
1099                "tasks": tasks,
1100                "nextCursor": next_cursor,
1101            }),
1102        )
1103    }
1104
1105    /// Handle tasks/cancel request
1106    async fn handle_tasks_cancel(&self, ctx: RequestContext) -> JsonRpcResponse {
1107        if let Err(err) = require_initialization(&ctx) {
1108            return error_response(ctx.request.id, err.to_jsonrpc());
1109        }
1110
1111        let params: crate::protocol::types::CancelTaskParams = match ctx.params() {
1112            Some(p) => match serde_json::from_value(p.clone()) {
1113                Ok(params) => params,
1114                Err(_) => {
1115                    return error_response(
1116                        ctx.request.id,
1117                        McpError::validation("invalid_params", "Missing or invalid taskId")
1118                            .to_jsonrpc(),
1119                    );
1120                }
1121            },
1122            None => {
1123                return error_response(
1124                    ctx.request.id,
1125                    McpError::validation("invalid_params", "Missing taskId parameter").to_jsonrpc(),
1126                );
1127            }
1128        };
1129
1130        match self
1131            .task_store
1132            .cancel_task(&params.task_id, &ctx.session.id)
1133            .await
1134        {
1135            Ok(task) => success_response(ctx.request.id, serde_json::to_value(task).unwrap()),
1136            Err(e) => {
1137                let error_msg = match e {
1138                    crate::managers::task::TaskError::NotFound(_) => {
1139                        McpError::validation("invalid_params", "Task not found")
1140                    }
1141                    crate::managers::task::TaskError::AlreadyTerminal(status) => {
1142                        McpError::validation(
1143                            "invalid_params",
1144                            format!(
1145                                "Cannot cancel task: already in terminal status '{:?}'",
1146                                status
1147                            ),
1148                        )
1149                    }
1150                    _ => McpError::internal("task_error", e.to_string()),
1151                };
1152                error_response(ctx.request.id, error_msg.to_jsonrpc())
1153            }
1154        }
1155    }
1156}
1157
1158#[cfg(test)]
1159mod tests {
1160    use super::*;
1161
1162    #[tokio::test]
1163    async fn test_server_creation() {
1164        let server = Server::new("test-server", "1.0.0");
1165        assert_eq!(server.name(), "test-server");
1166        assert_eq!(server.version(), "1.0.0");
1167    }
1168
1169    #[tokio::test]
1170    async fn test_ping() {
1171        let server = Server::new("test-server", "1.0.0");
1172
1173        let request = JsonRpcRequest {
1174            jsonrpc: "2.0".to_string(),
1175            id: Some(Value::Number(1.into())),
1176            method: "ping".to_string(),
1177            params: None,
1178        };
1179
1180        let response = server.handle_request("test-session", request).await;
1181
1182        assert!(response.result.is_some());
1183        assert!(response.error.is_none());
1184    }
1185
1186    #[tokio::test]
1187    async fn test_initialize() {
1188        let server = Server::new("test-server", "1.0.0");
1189
1190        let request = JsonRpcRequest {
1191            jsonrpc: "2.0".to_string(),
1192            id: Some(Value::Number(1.into())),
1193            method: "initialize".to_string(),
1194            params: Some(serde_json::json!({
1195                "protocolVersion": "2025-11-25",
1196                "capabilities": {},
1197                "clientInfo": {
1198                    "name": "test-client",
1199                    "version": "1.0.0"
1200                }
1201            })),
1202        };
1203
1204        let response = server.handle_request("test-session", request).await;
1205
1206        assert!(response.result.is_some());
1207        assert!(response.error.is_none());
1208
1209        // Check session was initialized
1210        let session = server.get_session("test-session").unwrap();
1211        assert!(session.is_initialized());
1212        assert_eq!(session.client_info.unwrap().name, "test-client");
1213    }
1214
1215    #[tokio::test]
1216    async fn test_method_not_found() {
1217        let server = Server::new("test-server", "1.0.0");
1218
1219        let request = JsonRpcRequest {
1220            jsonrpc: "2.0".to_string(),
1221            id: Some(Value::Number(1.into())),
1222            method: "unknown/method".to_string(),
1223            params: None,
1224        };
1225
1226        let response = server.handle_request("test-session", request).await;
1227
1228        assert!(response.result.is_none());
1229        assert!(response.error.is_some());
1230        assert_eq!(response.error.unwrap().code, -32601);
1231    }
1232
1233    #[tokio::test]
1234    async fn test_requires_initialization() {
1235        let server = Server::new("test-server", "1.0.0");
1236
1237        let request = JsonRpcRequest {
1238            jsonrpc: "2.0".to_string(),
1239            id: Some(Value::Number(1.into())),
1240            method: "tools/list".to_string(),
1241            params: None,
1242        };
1243
1244        // Should fail without initialization
1245        let response = server.handle_request("test-session", request.clone()).await;
1246        assert!(response.error.is_some());
1247
1248        // Initialize session
1249        let init_request = JsonRpcRequest {
1250            jsonrpc: "2.0".to_string(),
1251            id: Some(Value::Number(2.into())),
1252            method: "initialize".to_string(),
1253            params: Some(serde_json::json!({
1254                "protocolVersion": "2025-11-25",
1255                "capabilities": {},
1256                "clientInfo": {
1257                    "name": "test-client",
1258                    "version": "1.0.0"
1259                }
1260            })),
1261        };
1262        server.handle_request("test-session", init_request).await;
1263
1264        // Should succeed after initialization
1265        let response = server.handle_request("test-session", request).await;
1266        assert!(response.result.is_some());
1267    }
1268
1269    #[tokio::test]
1270    async fn test_session_management() {
1271        let server = Server::new("test-server", "1.0.0");
1272
1273        // Create session
1274        let request = JsonRpcRequest {
1275            jsonrpc: "2.0".to_string(),
1276            id: Some(Value::Number(1.into())),
1277            method: "ping".to_string(),
1278            params: None,
1279        };
1280        server.handle_request("session-1", request).await;
1281
1282        // Session should exist
1283        assert!(server.get_session("session-1").is_some());
1284
1285        // Remove session
1286        let removed = server.remove_session("session-1");
1287        assert!(removed.is_some());
1288
1289        // Session should no longer exist
1290        assert!(server.get_session("session-1").is_none());
1291    }
1292
1293    #[tokio::test]
1294    async fn test_capabilities_update() {
1295        let server = Server::new("test-server", "1.0.0");
1296
1297        let caps = ServerCapabilities {
1298            tools: Some(crate::protocol::capabilities::ToolsCapability {
1299                list_changed: Some(true),
1300            }),
1301            ..Default::default()
1302        };
1303
1304        server.set_capabilities(caps.clone());
1305
1306        let loaded_caps = server.capabilities();
1307        assert_eq!(loaded_caps.tools, caps.tools);
1308    }
1309
1310    // ========================================================================
1311    // Task Integration Tests
1312    // ========================================================================
1313
1314    /// Helper: Initialize a test session
1315    async fn init_test_session(server: &Server, session_id: &str) {
1316        let request = JsonRpcRequest {
1317            jsonrpc: "2.0".to_string(),
1318            id: Some(Value::Number(1.into())),
1319            method: "initialize".to_string(),
1320            params: Some(serde_json::json!({
1321                "protocolVersion": "2025-11-25",
1322                "capabilities": {
1323                    "tasks": {
1324                        "list": {},
1325                        "cancel": {},
1326                        "requests": {
1327                            "tools": {
1328                                "call": {}
1329                            }
1330                        }
1331                    }
1332                },
1333                "clientInfo": {
1334                    "name": "test-client",
1335                    "version": "1.0.0"
1336                }
1337            })),
1338        };
1339
1340        server.handle_request(session_id, request).await;
1341    }
1342
1343    /// Test tool that completes immediately
1344    struct TestTaskTool;
1345
1346    #[async_trait::async_trait]
1347    impl crate::registry::tools::Tool for TestTaskTool {
1348        fn name(&self) -> &str {
1349            "test_task"
1350        }
1351
1352        fn description(&self) -> Option<&str> {
1353            Some("Test tool for task execution")
1354        }
1355
1356        fn input_schema(&self) -> Value {
1357            serde_json::json!({
1358                "type": "object",
1359                "properties": {
1360                    "message": {"type": "string"}
1361                }
1362            })
1363        }
1364
1365        fn execution(&self) -> Option<crate::protocol::types::ToolExecution> {
1366            Some(crate::protocol::types::ToolExecution {
1367                task_support: Some(crate::protocol::types::TaskSupport::Optional),
1368            })
1369        }
1370
1371        async fn execute(
1372            &self,
1373            ctx: crate::prelude::ExecutionContext<'_>,
1374        ) -> Result<Vec<Box<dyn crate::content::types::Content>>, crate::registry::tools::ToolError>
1375        {
1376            let msg = ctx
1377                .params
1378                .get("message")
1379                .and_then(|v| v.as_str())
1380                .unwrap_or("default");
1381
1382            Ok(vec![Box::new(crate::content::types::TextContent::new(
1383                format!("Processed: {}", msg),
1384            ))])
1385        }
1386    }
1387
1388    /// Test tool that takes time to complete
1389    struct SlowTestTool;
1390
1391    #[async_trait::async_trait]
1392    impl crate::registry::tools::Tool for SlowTestTool {
1393        fn name(&self) -> &str {
1394            "slow_test"
1395        }
1396
1397        fn description(&self) -> Option<&str> {
1398            Some("Slow test tool")
1399        }
1400
1401        fn input_schema(&self) -> Value {
1402            serde_json::json!({"type": "object"})
1403        }
1404
1405        fn execution(&self) -> Option<crate::protocol::types::ToolExecution> {
1406            Some(crate::protocol::types::ToolExecution {
1407                task_support: Some(crate::protocol::types::TaskSupport::Optional),
1408            })
1409        }
1410
1411        async fn execute(
1412            &self,
1413            _ctx: crate::prelude::ExecutionContext<'_>,
1414        ) -> Result<Vec<Box<dyn crate::content::types::Content>>, crate::registry::tools::ToolError>
1415        {
1416            tokio::time::sleep(std::time::Duration::from_millis(500)).await;
1417            Ok(vec![Box::new(crate::content::types::TextContent::new(
1418                "Slow operation complete",
1419            ))])
1420        }
1421    }
1422
1423    #[tokio::test]
1424    async fn test_task_augmented_tool_call() {
1425        let server = Server::new("test-server", "1.0.0");
1426        server.tool_registry().register(TestTaskTool);
1427
1428        init_test_session(&server, "test-session").await;
1429
1430        // Call tool with task metadata
1431        let request = JsonRpcRequest {
1432            jsonrpc: "2.0".to_string(),
1433            id: Some(Value::Number(2.into())),
1434            method: "tools/call".to_string(),
1435            params: Some(serde_json::json!({
1436                "name": "test_task",
1437                "arguments": {"message": "hello"},
1438                "task": {"ttl": 60000}
1439            })),
1440        };
1441
1442        let response = server.handle_request("test-session", request).await;
1443
1444        // Should get CreateTaskResult with task metadata
1445        assert!(response.result.is_some());
1446        assert!(response.error.is_none());
1447
1448        let result = response.result.unwrap();
1449        assert!(result.get("task").is_some());
1450
1451        let task = result.get("task").unwrap();
1452        assert!(task.get("taskId").is_some());
1453        assert_eq!(task.get("status").unwrap().as_str().unwrap(), "working");
1454        assert!(task.get("createdAt").is_some());
1455        assert_eq!(task.get("ttl").unwrap().as_u64().unwrap(), 60000);
1456    }
1457
1458    #[tokio::test]
1459    async fn test_task_get_status() {
1460        let server = Server::new("test-server", "1.0.0");
1461        server.tool_registry().register(SlowTestTool);
1462
1463        init_test_session(&server, "test-session").await;
1464
1465        // Create task
1466        let create_request = JsonRpcRequest {
1467            jsonrpc: "2.0".to_string(),
1468            id: Some(Value::Number(2.into())),
1469            method: "tools/call".to_string(),
1470            params: Some(serde_json::json!({
1471                "name": "slow_test",
1472                "arguments": {},
1473                "task": {"ttl": 60000}
1474            })),
1475        };
1476
1477        let create_response = server.handle_request("test-session", create_request).await;
1478        let task_id = create_response.result.unwrap()["task"]["taskId"]
1479            .as_str()
1480            .unwrap()
1481            .to_string();
1482
1483        // Get task status immediately (should be working)
1484        let get_request = JsonRpcRequest {
1485            jsonrpc: "2.0".to_string(),
1486            id: Some(Value::Number(3.into())),
1487            method: "tasks/get".to_string(),
1488            params: Some(serde_json::json!({"taskId": task_id})),
1489        };
1490
1491        let get_response = server.handle_request("test-session", get_request).await;
1492
1493        assert!(get_response.result.is_some());
1494        let result = get_response.result.unwrap();
1495        let status = result["status"].as_str().unwrap();
1496        assert!(status == "working" || status == "completed");
1497    }
1498
1499    #[tokio::test]
1500    async fn test_task_result_blocking() {
1501        let server = Server::new("test-server", "1.0.0");
1502        server.tool_registry().register(SlowTestTool);
1503
1504        init_test_session(&server, "test-session").await;
1505
1506        // Create task
1507        let create_request = JsonRpcRequest {
1508            jsonrpc: "2.0".to_string(),
1509            id: Some(Value::Number(2.into())),
1510            method: "tools/call".to_string(),
1511            params: Some(serde_json::json!({
1512                "name": "slow_test",
1513                "arguments": {},
1514                "task": {"ttl": 60000}
1515            })),
1516        };
1517
1518        let create_response = server.handle_request("test-session", create_request).await;
1519        let task_id = create_response.result.unwrap()["task"]["taskId"]
1520            .as_str()
1521            .unwrap()
1522            .to_string();
1523
1524        // Get result (should block until complete)
1525        let result_request = JsonRpcRequest {
1526            jsonrpc: "2.0".to_string(),
1527            id: Some(Value::Number(3.into())),
1528            method: "tasks/result".to_string(),
1529            params: Some(serde_json::json!({"taskId": task_id})),
1530        };
1531
1532        let result_response = server.handle_request("test-session", result_request).await;
1533
1534        assert!(result_response.result.is_some());
1535        assert!(result_response.error.is_none());
1536
1537        // Should have tool result
1538        let result = result_response.result.unwrap();
1539        assert!(result.get("content").is_some());
1540    }
1541
1542    #[tokio::test]
1543    async fn test_task_cancel() {
1544        let server = Server::new("test-server", "1.0.0");
1545        server.tool_registry().register(SlowTestTool);
1546
1547        init_test_session(&server, "test-session").await;
1548
1549        // Create task
1550        let create_request = JsonRpcRequest {
1551            jsonrpc: "2.0".to_string(),
1552            id: Some(Value::Number(2.into())),
1553            method: "tools/call".to_string(),
1554            params: Some(serde_json::json!({
1555                "name": "slow_test",
1556                "arguments": {},
1557                "task": {"ttl": 60000}
1558            })),
1559        };
1560
1561        let create_response = server.handle_request("test-session", create_request).await;
1562        let task_id = create_response.result.unwrap()["task"]["taskId"]
1563            .as_str()
1564            .unwrap()
1565            .to_string();
1566
1567        // Cancel immediately
1568        let cancel_request = JsonRpcRequest {
1569            jsonrpc: "2.0".to_string(),
1570            id: Some(Value::Number(3.into())),
1571            method: "tasks/cancel".to_string(),
1572            params: Some(serde_json::json!({"taskId": task_id})),
1573        };
1574
1575        let cancel_response = server.handle_request("test-session", cancel_request).await;
1576
1577        // Should succeed if still working
1578        if cancel_response.result.is_some() {
1579            let result = cancel_response.result.unwrap();
1580            let status = result["status"].as_str().unwrap();
1581            assert_eq!(status, "cancelled");
1582        }
1583        // Or fail if already completed (timing-dependent)
1584    }
1585
1586    #[tokio::test]
1587    async fn test_task_list() {
1588        let server = Server::new("test-server", "1.0.0");
1589        server.tool_registry().register(TestTaskTool);
1590
1591        init_test_session(&server, "test-session").await;
1592
1593        // Create multiple tasks
1594        for i in 0..3 {
1595            let request = JsonRpcRequest {
1596                jsonrpc: "2.0".to_string(),
1597                id: Some(Value::Number((i + 2).into())),
1598                method: "tools/call".to_string(),
1599                params: Some(serde_json::json!({
1600                    "name": "test_task",
1601                    "arguments": {"message": format!("task-{}", i)},
1602                    "task": {"ttl": 60000}
1603                })),
1604            };
1605            server.handle_request("test-session", request).await;
1606        }
1607
1608        // List tasks
1609        let list_request = JsonRpcRequest {
1610            jsonrpc: "2.0".to_string(),
1611            id: Some(Value::Number(10.into())),
1612            method: "tasks/list".to_string(),
1613            params: None,
1614        };
1615
1616        let list_response = server.handle_request("test-session", list_request).await;
1617
1618        assert!(list_response.result.is_some());
1619        let result = list_response.result.unwrap();
1620        let tasks = result["tasks"].as_array().unwrap();
1621        assert!(tasks.len() >= 3);
1622    }
1623
1624    #[tokio::test]
1625    async fn test_task_session_isolation() {
1626        let server = Server::new("test-server", "1.0.0");
1627        server.tool_registry().register(TestTaskTool);
1628
1629        init_test_session(&server, "session-1").await;
1630        init_test_session(&server, "session-2").await;
1631
1632        // Create task in session-1
1633        let request = JsonRpcRequest {
1634            jsonrpc: "2.0".to_string(),
1635            id: Some(Value::Number(2.into())),
1636            method: "tools/call".to_string(),
1637            params: Some(serde_json::json!({
1638                "name": "test_task",
1639                "arguments": {"message": "private"},
1640                "task": {"ttl": 60000}
1641            })),
1642        };
1643
1644        let response = server.handle_request("session-1", request).await;
1645        let task_id = response.result.unwrap()["task"]["taskId"]
1646            .as_str()
1647            .unwrap()
1648            .to_string();
1649
1650        // Try to access from session-2 (should fail)
1651        let get_request = JsonRpcRequest {
1652            jsonrpc: "2.0".to_string(),
1653            id: Some(Value::Number(3.into())),
1654            method: "tasks/get".to_string(),
1655            params: Some(serde_json::json!({"taskId": task_id})),
1656        };
1657
1658        let get_response = server.handle_request("session-2", get_request).await;
1659
1660        // Should return error (task not found for this session)
1661        assert!(get_response.error.is_some());
1662    }
1663
1664    #[tokio::test]
1665    async fn test_task_not_found() {
1666        let server = Server::new("test-server", "1.0.0");
1667        init_test_session(&server, "test-session").await;
1668
1669        let request = JsonRpcRequest {
1670            jsonrpc: "2.0".to_string(),
1671            id: Some(Value::Number(2.into())),
1672            method: "tasks/get".to_string(),
1673            params: Some(serde_json::json!({"taskId": "nonexistent-task-id"})),
1674        };
1675
1676        let response = server.handle_request("test-session", request).await;
1677
1678        assert!(response.error.is_some());
1679        assert_eq!(response.error.unwrap().code, -32602);
1680    }
1681
1682    #[tokio::test]
1683    async fn test_task_double_cancel() {
1684        let server = Server::new("test-server", "1.0.0");
1685        server.tool_registry().register(SlowTestTool);
1686
1687        init_test_session(&server, "test-session").await;
1688
1689        // Create task
1690        let create_request = JsonRpcRequest {
1691            jsonrpc: "2.0".to_string(),
1692            id: Some(Value::Number(2.into())),
1693            method: "tools/call".to_string(),
1694            params: Some(serde_json::json!({
1695                "name": "slow_test",
1696                "arguments": {},
1697                "task": {"ttl": 60000}
1698            })),
1699        };
1700
1701        let create_response = server.handle_request("test-session", create_request).await;
1702        let task_id = create_response.result.unwrap()["task"]["taskId"]
1703            .as_str()
1704            .unwrap()
1705            .to_string();
1706
1707        // Cancel first time
1708        let cancel_request = JsonRpcRequest {
1709            jsonrpc: "2.0".to_string(),
1710            id: Some(Value::Number(3.into())),
1711            method: "tasks/cancel".to_string(),
1712            params: Some(serde_json::json!({"taskId": task_id.clone()})),
1713        };
1714
1715        let _ = server
1716            .handle_request("test-session", cancel_request.clone())
1717            .await;
1718
1719        // Wait a bit to ensure cancellation completes
1720        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
1721
1722        // Try to cancel again (should fail - already terminal)
1723        let cancel_request2 = JsonRpcRequest {
1724            jsonrpc: "2.0".to_string(),
1725            id: Some(Value::Number(4.into())),
1726            method: "tasks/cancel".to_string(),
1727            params: Some(serde_json::json!({"taskId": task_id})),
1728        };
1729
1730        let cancel_response2 = server.handle_request("test-session", cancel_request2).await;
1731
1732        // Second cancel should fail (already in terminal state)
1733        assert!(cancel_response2.error.is_some());
1734    }
1735}