Skip to main content

fastmcp_server/
bidirectional.rs

1//! Bidirectional request handling for server-to-client communication.
2//!
3//! This module provides the infrastructure for server-initiated requests to clients,
4//! such as:
5//! - `sampling/createMessage` - Request LLM completion from the client
6//! - `elicitation/elicit` - Request user input from the client
7//! - `roots/list` - Request filesystem roots from the client
8//!
9//! # Architecture
10//!
11//! The MCP protocol is bidirectional: while clients typically send requests to servers,
12//! servers can also send requests to clients. This creates a challenge because the
13//! server's main loop is typically blocking on `recv()`.
14//!
15//! The solution is a message dispatcher pattern:
16//! 1. A background task continuously reads from the transport
17//! 2. Incoming messages are routed based on whether they're requests or responses
18//! 3. Responses are matched to pending requests via their ID
19//! 4. Requests are dispatched to handlers
20//!
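//! # Usage
//!
//! ```ignore
//! use std::sync::Arc;
//!
//! use fastmcp_server::bidirectional::{PendingRequests, RequestSender, TransportSendFn};
//!
//! let pending = Arc::new(PendingRequests::new());
//! // `write_to_client` stands in for whatever writes a message to the transport.
//! let send_fn: TransportSendFn = Arc::new(|msg| write_to_client(msg));
//! let sender = RequestSender::new(Arc::clone(&pending), send_fn);
//!
//! // The transport reader must hand each incoming response to
//! // `pending.route_response(&response)` so the waiting caller is woken.
//!
//! // Send a request and block until its response is routed back.
//! let roots: fastmcp_protocol::ListRootsResult =
//!     sender.send_request(&cx, "roots/list", serde_json::json!({}))?;
//! ```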
35
36use std::collections::HashMap;
37use std::sync::atomic::{AtomicU64, Ordering};
38use std::sync::{Arc, Mutex};
39use std::time::Duration;
40
41use asupersync::Cx;
42use fastmcp_core::{
43    ElicitationAction, ElicitationMode, ElicitationRequest, ElicitationResponse, ElicitationSender,
44    McpError, McpErrorCode, McpResult, SamplingRequest, SamplingResponse, SamplingRole,
45    SamplingSender, SamplingStopReason,
46};
47use fastmcp_protocol::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, RequestId};
48
49// ============================================================================
50// Pending Request Tracking
51// ============================================================================
52
53/// A oneshot channel for receiving a response.
54type ResponseSender = std::sync::mpsc::Sender<Result<serde_json::Value, JsonRpcError>>;
55type ResponseReceiver = std::sync::mpsc::Receiver<Result<serde_json::Value, JsonRpcError>>;
56
57/// Tracks pending server-to-client requests.
58///
59/// When the server sends a request to the client, it registers a response sender
60/// here. When a response arrives, the dispatcher routes it to the correct sender.
61#[derive(Debug)]
62pub struct PendingRequests {
63    /// Map from request ID to response sender.
64    pending: Mutex<HashMap<RequestId, ResponseSender>>,
65    /// Counter for generating unique request IDs.
66    next_id: AtomicU64,
67}
68
69impl PendingRequests {
70    fn lock_pending(&self) -> std::sync::MutexGuard<'_, HashMap<RequestId, ResponseSender>> {
71        match self.pending.lock() {
72            Ok(guard) => guard,
73            // Prefer availability over panic if another task panicked while holding the lock.
74            Err(poisoned) => poisoned.into_inner(),
75        }
76    }
77
78    /// Creates a new pending request tracker.
79    #[must_use]
80    pub fn new() -> Self {
81        Self {
82            pending: Mutex::new(HashMap::new()),
83            // Start at a high number to avoid collision with client request IDs
84            next_id: AtomicU64::new(1_000_000),
85        }
86    }
87
88    /// Generates a new unique request ID.
89    #[allow(clippy::cast_possible_wrap)]
90    pub fn next_request_id(&self) -> RequestId {
91        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
92        RequestId::Number(id as i64)
93    }
94
95    /// Registers a pending request and returns a receiver for the response.
96    pub fn register(&self, id: RequestId) -> ResponseReceiver {
97        let (tx, rx) = std::sync::mpsc::channel();
98        let mut pending = self.lock_pending();
99        pending.insert(id, tx);
100        rx
101    }
102
103    /// Routes a response to the appropriate pending request.
104    ///
105    /// Returns `true` if the response was routed, `false` if no matching request was found.
106    pub fn route_response(&self, response: &JsonRpcResponse) -> bool {
107        let Some(ref id) = response.id else {
108            return false;
109        };
110
111        let sender = {
112            let mut pending = self.lock_pending();
113            pending.remove(id)
114        };
115
116        if let Some(sender) = sender {
117            let result = if let Some(ref error) = response.error {
118                Err(error.clone())
119            } else {
120                Ok(response.result.clone().unwrap_or(serde_json::Value::Null))
121            };
122            // Ignore send errors (receiver may have been dropped due to cancellation)
123            let _ = sender.send(result);
124            true
125        } else {
126            false
127        }
128    }
129
130    /// Removes a pending request (e.g., on timeout or cancellation).
131    pub fn remove(&self, id: &RequestId) {
132        let mut pending = self.lock_pending();
133        pending.remove(id);
134    }
135
136    /// Cancels all pending requests with a connection closed error.
137    pub fn cancel_all(&self) {
138        let mut pending = self.lock_pending();
139        for (_, sender) in pending.drain() {
140            let _ = sender.send(Err(JsonRpcError {
141                code: McpErrorCode::InternalError.into(),
142                message: "Connection closed".to_string(),
143                data: None,
144            }));
145        }
146    }
147}
148
149impl Default for PendingRequests {
150    fn default() -> Self {
151        Self::new()
152    }
153}
154
155// ============================================================================
156// Transport Request Sender
157// ============================================================================
158
159/// Callback type for sending messages through the transport.
160pub type TransportSendFn = Arc<dyn Fn(&JsonRpcMessage) -> Result<(), String> + Send + Sync>;
161
162/// Sends server-to-client requests through the transport.
163///
164/// This struct provides a way to send requests to the client and await responses.
165/// It works in conjunction with [`PendingRequests`] to track in-flight requests.
166#[derive(Clone)]
167pub struct RequestSender {
168    /// Pending request tracker.
169    pending: Arc<PendingRequests>,
170    /// Transport send callback.
171    send_fn: TransportSendFn,
172}
173
174impl RequestSender {
175    /// Creates a new request sender.
176    pub fn new(pending: Arc<PendingRequests>, send_fn: TransportSendFn) -> Self {
177        Self { pending, send_fn }
178    }
179
180    /// Sends a request to the client and waits for a response.
181    ///
182    /// # Errors
183    ///
184    /// Returns an error if:
185    /// - The transport send fails
186    /// - The request times out (based on budget)
187    /// - The client returns an error response
188    /// - The connection is closed
189    pub fn send_request<T: serde::de::DeserializeOwned>(
190        &self,
191        cx: &Cx,
192        method: &str,
193        params: serde_json::Value,
194    ) -> McpResult<T> {
195        let id = self.pending.next_request_id();
196        let receiver = self.pending.register(id.clone());
197
198        let request = JsonRpcRequest::new(method.to_string(), Some(params), id.clone());
199        let message = JsonRpcMessage::Request(request);
200
201        // Send the request through the transport
202        if let Err(e) = (self.send_fn)(&message) {
203            self.pending.remove(&id);
204            return Err(McpError::internal_error(format!(
205                "Failed to send request: {}",
206                e
207            )));
208        }
209
210        // Wait for response while observing cancellation/budget via checkpoints.
211        //
212        // This uses periodic `recv_timeout` polling so we can call `cx.checkpoint()`
213        // to notice budget exhaustion or explicit cancellation.
214        let tick = Duration::from_millis(25);
215        loop {
216            if cx.checkpoint().is_err() {
217                self.pending.remove(&id);
218                return Err(McpError::request_cancelled());
219            }
220
221            match receiver.recv_timeout(tick) {
222                Ok(Ok(value)) => {
223                    return serde_json::from_value(value).map_err(|e| {
224                        McpError::internal_error(format!("Failed to parse response: {e}"))
225                    });
226                }
227                Ok(Err(error)) => {
228                    return Err(McpError::new(McpErrorCode::from(error.code), error.message));
229                }
230                Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
231                    // Keep waiting, but allow budget/cancellation to be observed.
232                }
233                Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
234                    return Err(McpError::internal_error(
235                        "Response channel closed unexpectedly",
236                    ));
237                }
238            }
239        }
240    }
241}
242
243impl std::fmt::Debug for RequestSender {
244    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245        f.debug_struct("RequestSender")
246            .field("pending", &self.pending)
247            .finish_non_exhaustive()
248    }
249}
250
251// ============================================================================
252// Sampling Sender Implementation
253// ============================================================================
254
255/// Sends sampling requests to the client via the transport.
256#[derive(Clone)]
257pub struct TransportSamplingSender {
258    sender: RequestSender,
259}
260
261impl TransportSamplingSender {
262    /// Creates a new transport-backed sampling sender.
263    pub fn new(sender: RequestSender) -> Self {
264        Self { sender }
265    }
266}
267
268impl SamplingSender for TransportSamplingSender {
269    fn create_message(
270        &self,
271        request: SamplingRequest,
272    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<SamplingResponse>> + Send + '_>>
273    {
274        Box::pin(async move {
275            // Convert to protocol types
276            let params = fastmcp_protocol::CreateMessageParams {
277                messages: request
278                    .messages
279                    .into_iter()
280                    .map(|m| fastmcp_protocol::SamplingMessage {
281                        role: match m.role {
282                            SamplingRole::User => fastmcp_protocol::Role::User,
283                            SamplingRole::Assistant => fastmcp_protocol::Role::Assistant,
284                        },
285                        content: fastmcp_protocol::SamplingContent::Text { text: m.text },
286                    })
287                    .collect(),
288                max_tokens: request.max_tokens,
289                system_prompt: request.system_prompt,
290                temperature: request.temperature,
291                stop_sequences: request.stop_sequences,
292                model_preferences: if request.model_hints.is_empty() {
293                    None
294                } else {
295                    Some(fastmcp_protocol::ModelPreferences {
296                        hints: request
297                            .model_hints
298                            .into_iter()
299                            .map(|name| fastmcp_protocol::ModelHint { name: Some(name) })
300                            .collect(),
301                        ..Default::default()
302                    })
303                },
304                include_context: None,
305                meta: None,
306            };
307
308            let params_value = serde_json::to_value(&params)
309                .map_err(|e| McpError::internal_error(format!("Failed to serialize: {}", e)))?;
310
311            // Create a request-scoped Cx for this server-initiated request.
312            let cx = Cx::for_request();
313
314            let result: fastmcp_protocol::CreateMessageResult =
315                self.sender
316                    .send_request(&cx, "sampling/createMessage", params_value)?;
317
318            Ok(SamplingResponse {
319                text: match result.content {
320                    fastmcp_protocol::SamplingContent::Text { text } => text,
321                    fastmcp_protocol::SamplingContent::Image { data, mime_type } => {
322                        format!("[image: {} bytes, type: {}]", data.len(), mime_type)
323                    }
324                },
325                model: result.model,
326                stop_reason: match result.stop_reason {
327                    fastmcp_protocol::StopReason::EndTurn => SamplingStopReason::EndTurn,
328                    fastmcp_protocol::StopReason::StopSequence => SamplingStopReason::StopSequence,
329                    fastmcp_protocol::StopReason::MaxTokens => SamplingStopReason::MaxTokens,
330                },
331            })
332        })
333    }
334}
335
336// ============================================================================
337// Elicitation Sender Implementation
338// ============================================================================
339
340/// Sends elicitation requests to the client via the transport.
341#[derive(Clone)]
342pub struct TransportElicitationSender {
343    sender: RequestSender,
344}
345
346impl TransportElicitationSender {
347    /// Creates a new transport-backed elicitation sender.
348    pub fn new(sender: RequestSender) -> Self {
349        Self { sender }
350    }
351}
352
353impl ElicitationSender for TransportElicitationSender {
354    fn elicit(
355        &self,
356        request: ElicitationRequest,
357    ) -> std::pin::Pin<
358        Box<dyn std::future::Future<Output = McpResult<ElicitationResponse>> + Send + '_>,
359    > {
360        Box::pin(async move {
361            let params_value = match request.mode {
362                ElicitationMode::Form => {
363                    let params = fastmcp_protocol::ElicitRequestFormParams {
364                        mode: fastmcp_protocol::ElicitMode::Form,
365                        message: request.message.clone(),
366                        requested_schema: request.schema.unwrap_or(serde_json::json!({})),
367                    };
368                    serde_json::to_value(&params).map_err(|e| {
369                        McpError::internal_error(format!("Failed to serialize: {}", e))
370                    })?
371                }
372                ElicitationMode::Url => {
373                    let params = fastmcp_protocol::ElicitRequestUrlParams {
374                        mode: fastmcp_protocol::ElicitMode::Url,
375                        message: request.message.clone(),
376                        url: request.url.unwrap_or_default(),
377                        elicitation_id: request.elicitation_id.unwrap_or_default(),
378                    };
379                    serde_json::to_value(&params).map_err(|e| {
380                        McpError::internal_error(format!("Failed to serialize: {}", e))
381                    })?
382                }
383            };
384
385            // Create a request-scoped Cx for this server-initiated request.
386            let cx = Cx::for_request();
387
388            let result: fastmcp_protocol::ElicitResult =
389                self.sender
390                    .send_request(&cx, "elicitation/elicit", params_value)?;
391
392            // Convert HashMap<String, ElicitContentValue> to HashMap<String, serde_json::Value>
393            let content = result.content.map(|content_map| {
394                let mut map = std::collections::HashMap::new();
395                for (key, value) in content_map {
396                    let json_value = match value {
397                        fastmcp_protocol::ElicitContentValue::Null => serde_json::Value::Null,
398                        fastmcp_protocol::ElicitContentValue::Bool(b) => serde_json::Value::Bool(b),
399                        fastmcp_protocol::ElicitContentValue::Int(i) => {
400                            serde_json::Value::Number(i.into())
401                        }
402                        fastmcp_protocol::ElicitContentValue::Float(f) => {
403                            serde_json::Number::from_f64(f)
404                                .map(serde_json::Value::Number)
405                                .unwrap_or(serde_json::Value::Null)
406                        }
407                        fastmcp_protocol::ElicitContentValue::String(s) => {
408                            serde_json::Value::String(s)
409                        }
410                        fastmcp_protocol::ElicitContentValue::StringArray(arr) => {
411                            serde_json::Value::Array(
412                                arr.into_iter().map(serde_json::Value::String).collect(),
413                            )
414                        }
415                    };
416                    map.insert(key, json_value);
417                }
418                map
419            });
420
421            Ok(ElicitationResponse {
422                action: match result.action {
423                    fastmcp_protocol::ElicitAction::Accept => ElicitationAction::Accept,
424                    fastmcp_protocol::ElicitAction::Decline => ElicitationAction::Decline,
425                    fastmcp_protocol::ElicitAction::Cancel => ElicitationAction::Cancel,
426                },
427                content,
428            })
429        })
430    }
431}
432
433// ============================================================================
434// Roots Provider Implementation
435// ============================================================================
436
437/// Provider for filesystem roots from the client.
438#[derive(Clone)]
439pub struct TransportRootsProvider {
440    sender: RequestSender,
441}
442
443impl TransportRootsProvider {
444    /// Creates a new transport-backed roots provider.
445    pub fn new(sender: RequestSender) -> Self {
446        Self { sender }
447    }
448
449    /// Lists the filesystem roots from the client.
450    pub fn list_roots(&self) -> McpResult<Vec<fastmcp_protocol::Root>> {
451        let cx = Cx::for_request();
452        let result: fastmcp_protocol::ListRootsResult =
453            self.sender
454                .send_request(&cx, "roots/list", serde_json::json!({}))?;
455        Ok(result.roots)
456    }
457}
458
459// ============================================================================
460// Tests
461// ============================================================================
462
463#[cfg(test)]
464mod tests {
465    use super::*;
466
467    #[test]
468    fn test_pending_requests_register_and_route() {
469        let pending = PendingRequests::new();
470
471        // Register a request
472        let id = pending.next_request_id();
473        let receiver = pending.register(id.clone());
474
475        // Simulate a response
476        let response = JsonRpcResponse::success(id, serde_json::json!({"result": "ok"}));
477        assert!(pending.route_response(&response));
478
479        // Receive the response
480        let result = receiver.recv().unwrap();
481        assert!(result.is_ok());
482        assert_eq!(result.unwrap(), serde_json::json!({"result": "ok"}));
483    }
484
485    #[test]
486    fn test_pending_requests_error_response() {
487        let pending = PendingRequests::new();
488
489        let id = pending.next_request_id();
490        let receiver = pending.register(id.clone());
491
492        // Simulate an error response
493        let response = JsonRpcResponse::error(
494            Some(id),
495            JsonRpcError {
496                code: -32600,
497                message: "Invalid request".to_string(),
498                data: None,
499            },
500        );
501        assert!(pending.route_response(&response));
502
503        // Receive the error
504        let result = receiver.recv().unwrap();
505        assert!(result.is_err());
506        assert_eq!(result.unwrap_err().message, "Invalid request");
507    }
508
509    #[test]
510    fn test_pending_requests_cancel_all() {
511        let pending = PendingRequests::new();
512
513        let id1 = pending.next_request_id();
514        let id2 = pending.next_request_id();
515        let receiver1 = pending.register(id1);
516        let receiver2 = pending.register(id2);
517
518        // Cancel all
519        pending.cancel_all();
520
521        // Both should receive errors
522        let result1 = receiver1.recv().unwrap();
523        let result2 = receiver2.recv().unwrap();
524        assert!(result1.is_err());
525        assert!(result2.is_err());
526    }
527
528    #[test]
529    fn test_route_unknown_response() {
530        let pending = PendingRequests::new();
531
532        // Route a response with unknown ID
533        let response = JsonRpcResponse::success(
534            RequestId::Number(999999),
535            serde_json::json!({"result": "ok"}),
536        );
537        assert!(!pending.route_response(&response));
538    }
539
540    // ── PendingRequests additional coverage ───────────────────────────
541
542    #[test]
543    fn pending_requests_default_is_same_as_new() {
544        let pr = PendingRequests::default();
545        let id = pr.next_request_id();
546        // IDs start at 1_000_000
547        assert_eq!(id, RequestId::Number(1_000_000));
548    }
549
550    #[test]
551    fn pending_requests_ids_are_sequential() {
552        let pr = PendingRequests::new();
553        let id1 = pr.next_request_id();
554        let id2 = pr.next_request_id();
555        let id3 = pr.next_request_id();
556        assert_eq!(id1, RequestId::Number(1_000_000));
557        assert_eq!(id2, RequestId::Number(1_000_001));
558        assert_eq!(id3, RequestId::Number(1_000_002));
559    }
560
561    #[test]
562    fn pending_requests_remove_prevents_routing() {
563        let pr = PendingRequests::new();
564        let id = pr.next_request_id();
565        let _receiver = pr.register(id.clone());
566
567        // Remove the pending request
568        pr.remove(&id);
569
570        // Routing should fail now
571        let response = JsonRpcResponse::success(id, serde_json::json!(null));
572        assert!(!pr.route_response(&response));
573    }
574
575    #[test]
576    fn pending_requests_route_response_without_id_returns_false() {
577        let pr = PendingRequests::new();
578        // A response with no id
579        let response = JsonRpcResponse {
580            jsonrpc: std::borrow::Cow::Borrowed("2.0"),
581            id: None,
582            result: Some(serde_json::json!(null)),
583            error: None,
584        };
585        assert!(!pr.route_response(&response));
586    }
587
588    #[test]
589    fn pending_requests_route_response_with_null_result() {
590        let pr = PendingRequests::new();
591        let id = pr.next_request_id();
592        let receiver = pr.register(id.clone());
593
594        // Response with no result field (result is None → becomes Null)
595        let response = JsonRpcResponse {
596            jsonrpc: std::borrow::Cow::Borrowed("2.0"),
597            id: Some(id),
598            result: None,
599            error: None,
600        };
601        assert!(pr.route_response(&response));
602
603        let result = receiver.recv().unwrap().unwrap();
604        assert_eq!(result, serde_json::Value::Null);
605    }
606
607    #[test]
608    fn pending_requests_route_after_receiver_dropped_does_not_panic() {
609        let pr = PendingRequests::new();
610        let id = pr.next_request_id();
611        let receiver = pr.register(id.clone());
612
613        // Drop the receiver
614        drop(receiver);
615
616        // Routing should still succeed (sender.send returns Err but is ignored)
617        let response = JsonRpcResponse::success(id, serde_json::json!(42));
618        assert!(pr.route_response(&response));
619    }
620
621    #[test]
622    fn pending_requests_cancel_all_clears_pending() {
623        let pr = PendingRequests::new();
624        let id = pr.next_request_id();
625        let _receiver = pr.register(id.clone());
626
627        pr.cancel_all();
628
629        // No more pending requests to route to
630        let response = JsonRpcResponse::success(id, serde_json::json!(null));
631        assert!(!pr.route_response(&response));
632    }
633
634    #[test]
635    fn pending_requests_cancel_all_empty_is_noop() {
636        let pr = PendingRequests::new();
637        // Should not panic on empty
638        pr.cancel_all();
639    }
640
641    #[test]
642    fn pending_requests_debug_format() {
643        let pr = PendingRequests::new();
644        let debug = format!("{:?}", pr);
645        assert!(debug.contains("PendingRequests"));
646    }
647
648    // ── RequestSender ────────────────────────────────────────────────
649
650    #[test]
651    fn request_sender_debug_format() {
652        let pending = Arc::new(PendingRequests::new());
653        let send_fn: TransportSendFn = Arc::new(|_| Ok(()));
654        let sender = RequestSender::new(pending, send_fn);
655        let debug = format!("{:?}", sender);
656        assert!(debug.contains("RequestSender"));
657    }
658
659    #[test]
660    fn request_sender_transport_failure_returns_error() {
661        let pending = Arc::new(PendingRequests::new());
662        let send_fn: TransportSendFn = Arc::new(|_| Err("transport down".to_string()));
663        let sender = RequestSender::new(pending, send_fn);
664
665        let cx = Cx::for_testing();
666        let result: McpResult<serde_json::Value> =
667            sender.send_request(&cx, "test/method", serde_json::json!({}));
668        let err = result.unwrap_err();
669        assert!(err.message.contains("Failed to send request"));
670        assert!(err.message.contains("transport down"));
671    }
672
673    #[test]
674    fn request_sender_transport_failure_cleans_up_pending() {
675        let pending = Arc::new(PendingRequests::new());
676        let send_fn: TransportSendFn = Arc::new(|_| Err("fail".to_string()));
677        let sender = RequestSender::new(Arc::clone(&pending), send_fn);
678
679        let cx = Cx::for_testing();
680        let _err: McpResult<serde_json::Value> =
681            sender.send_request(&cx, "test/method", serde_json::json!({}));
682
683        // The pending request should have been cleaned up
684        let id = RequestId::Number(1_000_000); // first ID
685        let response = JsonRpcResponse::success(id, serde_json::json!(null));
686        assert!(!pending.route_response(&response));
687    }
688
689    #[test]
690    fn request_sender_clone() {
691        let pending = Arc::new(PendingRequests::new());
692        let send_fn: TransportSendFn = Arc::new(|_| Ok(()));
693        let sender = RequestSender::new(pending, send_fn);
694        let cloned = sender.clone();
695        let debug = format!("{:?}", cloned);
696        assert!(debug.contains("RequestSender"));
697    }
698
699    // ── RequestSender send_request paths ─────────────────────────────
700
701    #[test]
702    fn request_sender_success_path() {
703        let pending = Arc::new(PendingRequests::new());
704        let pending_clone = Arc::clone(&pending);
705        let send_fn: TransportSendFn = Arc::new(move |msg| {
706            if let JsonRpcMessage::Request(req) = msg {
707                let id = req.id.clone().unwrap();
708                let response = JsonRpcResponse::success(id, serde_json::json!({"answer": 42}));
709                pending_clone.route_response(&response);
710            }
711            Ok(())
712        });
713        let sender = RequestSender::new(Arc::clone(&pending), send_fn);
714        let cx = Cx::for_testing();
715        let result: McpResult<serde_json::Value> =
716            sender.send_request(&cx, "test/method", serde_json::json!({}));
717        let value = result.unwrap();
718        assert_eq!(value["answer"], 42);
719    }
720
721    #[test]
722    fn request_sender_error_response_path() {
723        let pending = Arc::new(PendingRequests::new());
724        let pending_clone = Arc::clone(&pending);
725        let send_fn: TransportSendFn = Arc::new(move |msg| {
726            if let JsonRpcMessage::Request(req) = msg {
727                let id = req.id.clone().unwrap();
728                let response = JsonRpcResponse::error(
729                    Some(id),
730                    JsonRpcError {
731                        code: -32600,
732                        message: "bad request".to_string(),
733                        data: None,
734                    },
735                );
736                pending_clone.route_response(&response);
737            }
738            Ok(())
739        });
740        let sender = RequestSender::new(Arc::clone(&pending), send_fn);
741        let cx = Cx::for_testing();
742        let result: McpResult<serde_json::Value> =
743            sender.send_request(&cx, "test/method", serde_json::json!({}));
744        let err = result.unwrap_err();
745        assert!(err.message.contains("bad request"));
746    }
747
748    #[test]
749    fn request_sender_disconnected_path() {
750        let pending = Arc::new(PendingRequests::new());
751        let pending_clone = Arc::clone(&pending);
752        let send_fn: TransportSendFn = Arc::new(move |msg| {
753            if let JsonRpcMessage::Request(req) = msg {
754                let id = req.id.clone().unwrap();
755                // Remove the pending entry so tx is dropped, causing Disconnected
756                pending_clone.remove(&id);
757            }
758            Ok(())
759        });
760        let sender = RequestSender::new(Arc::clone(&pending), send_fn);
761        let cx = Cx::for_testing();
762        let result: McpResult<serde_json::Value> =
763            sender.send_request(&cx, "test/method", serde_json::json!({}));
764        let err = result.unwrap_err();
765        assert!(err.message.contains("Response channel closed"));
766    }
767
768    #[test]
769    fn request_sender_deserialization_error() {
770        let pending = Arc::new(PendingRequests::new());
771        let pending_clone = Arc::clone(&pending);
772        let send_fn: TransportSendFn = Arc::new(move |msg| {
773            if let JsonRpcMessage::Request(req) = msg {
774                let id = req.id.clone().unwrap();
775                // Return a string value, which won't deserialize to Vec<String>
776                let response =
777                    JsonRpcResponse::success(id, serde_json::json!("not a vec of strings"));
778                pending_clone.route_response(&response);
779            }
780            Ok(())
781        });
782        let sender = RequestSender::new(Arc::clone(&pending), send_fn);
783        let cx = Cx::for_testing();
784        let result: McpResult<Vec<String>> =
785            sender.send_request(&cx, "test/method", serde_json::json!({}));
786        let err = result.unwrap_err();
787        assert!(err.message.contains("Failed to parse response"));
788    }
789
790    // ── cancel_all error details ─────────────────────────────────────
791
792    #[test]
793    fn cancel_all_sends_connection_closed_error() {
794        let pr = PendingRequests::new();
795        let id = pr.next_request_id();
796        let receiver = pr.register(id);
797        pr.cancel_all();
798        let result = receiver.recv().unwrap();
799        let err = result.unwrap_err();
800        assert_eq!(err.code, i32::from(McpErrorCode::InternalError));
801        assert!(err.message.contains("Connection closed"));
802        assert!(err.data.is_none());
803    }
804
805    // ── route_response with error containing data ────────────────────
806
807    #[test]
808    fn route_response_error_with_data() {
809        let pr = PendingRequests::new();
810        let id = pr.next_request_id();
811        let receiver = pr.register(id.clone());
812        let response = JsonRpcResponse::error(
813            Some(id),
814            JsonRpcError {
815                code: -32001,
816                message: "custom error".to_string(),
817                data: Some(serde_json::json!({"detail": "extra info"})),
818            },
819        );
820        assert!(pr.route_response(&response));
821        let result = receiver.recv().unwrap();
822        let err = result.unwrap_err();
823        assert_eq!(err.code, -32001);
824        assert!(err.message.contains("custom error"));
825        assert!(err.data.is_some());
826    }
827
828    // ── Multiple concurrent register/route ───────────────────────────
829
830    #[test]
831    fn pending_requests_multiple_register_and_route_independently() {
832        let pr = PendingRequests::new();
833        let id1 = pr.next_request_id();
834        let id2 = pr.next_request_id();
835        let id3 = pr.next_request_id();
836        let rx1 = pr.register(id1.clone());
837        let rx2 = pr.register(id2.clone());
838        let rx3 = pr.register(id3.clone());
839
840        // Route them out of order
841        let r2 = JsonRpcResponse::success(id2.clone(), serde_json::json!("second"));
842        let r3 = JsonRpcResponse::success(id3.clone(), serde_json::json!("third"));
843        let r1 = JsonRpcResponse::success(id1.clone(), serde_json::json!("first"));
844        assert!(pr.route_response(&r2));
845        assert!(pr.route_response(&r3));
846        assert!(pr.route_response(&r1));
847
848        assert_eq!(rx1.recv().unwrap().unwrap(), serde_json::json!("first"));
849        assert_eq!(rx2.recv().unwrap().unwrap(), serde_json::json!("second"));
850        assert_eq!(rx3.recv().unwrap().unwrap(), serde_json::json!("third"));
851    }
852
853    // ── Register same id overwrites ──────────────────────────────────
854
855    #[test]
856    fn pending_requests_register_same_id_overwrites() {
857        let pr = PendingRequests::new();
858        let id = pr.next_request_id();
859        let _rx1 = pr.register(id.clone());
860        let rx2 = pr.register(id.clone()); // overwrites
861
862        let response = JsonRpcResponse::success(id, serde_json::json!("response"));
863        assert!(pr.route_response(&response));
864
865        // rx2 should receive the response (it replaced rx1)
866        let result = rx2.recv().unwrap().unwrap();
867        assert_eq!(result, serde_json::json!("response"));
868    }
869
870    // ── Transport sender constructors ────────────────────────────────
871
872    #[test]
873    fn transport_sampling_sender_new_and_clone() {
874        let pending = Arc::new(PendingRequests::new());
875        let send_fn: TransportSendFn = Arc::new(|_| Ok(()));
876        let sender = RequestSender::new(pending, send_fn);
877        let sampling = TransportSamplingSender::new(sender);
878        let _cloned = sampling.clone();
879    }
880
881    #[test]
882    fn transport_elicitation_sender_new_and_clone() {
883        let pending = Arc::new(PendingRequests::new());
884        let send_fn: TransportSendFn = Arc::new(|_| Ok(()));
885        let sender = RequestSender::new(pending, send_fn);
886        let elicitation = TransportElicitationSender::new(sender);
887        let _cloned = elicitation.clone();
888    }
889
890    #[test]
891    fn transport_roots_provider_new_and_clone() {
892        let pending = Arc::new(PendingRequests::new());
893        let send_fn: TransportSendFn = Arc::new(|_| Ok(()));
894        let sender = RequestSender::new(pending, send_fn);
895        let roots = TransportRootsProvider::new(sender);
896        let _cloned = roots.clone();
897    }
898
899    // ── lock_pending with poisoned mutex ─────────────────────────────
900
901    #[test]
902    fn pending_requests_lock_pending_recovers_from_poison() {
903        let pr = Arc::new(PendingRequests::new());
904        let id = pr.next_request_id();
905        let rx = pr.register(id.clone());
906
907        // Poison the mutex by panicking while holding the lock
908        let pr2 = Arc::clone(&pr);
909        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
910            let _guard = pr2.pending.lock().unwrap();
911            panic!("intentional poison");
912        }));
913
914        // lock_pending should recover from poison (into_inner)
915        // Routing should still work
916        let response = JsonRpcResponse::success(id, serde_json::json!("recovered"));
917        assert!(pr.route_response(&response));
918        let result = rx.recv().unwrap().unwrap();
919        assert_eq!(result, serde_json::json!("recovered"));
920    }
921
922    // ── TransportSamplingSender — create_message ─────────────────────
923
924    fn make_sender_with_responder(
925        responder: impl Fn(&JsonRpcRequest) -> serde_json::Value + Send + Sync + 'static,
926    ) -> RequestSender {
927        let pending = Arc::new(PendingRequests::new());
928        let pending_clone = Arc::clone(&pending);
929        let send_fn: TransportSendFn = Arc::new(move |msg| {
930            if let JsonRpcMessage::Request(req) = msg {
931                let id = req.id.clone().unwrap();
932                let result = responder(req);
933                let response = JsonRpcResponse::success(id, result);
934                pending_clone.route_response(&response);
935            }
936            Ok(())
937        });
938        RequestSender::new(pending, send_fn)
939    }
940
941    #[test]
942    fn transport_sampling_sender_create_message_text() {
943        let sender = make_sender_with_responder(|_| {
944            serde_json::json!({
945                "content": {"type": "text", "text": "Hello world"},
946                "role": "assistant",
947                "model": "test-model",
948                "stopReason": "endTurn"
949            })
950        });
951        let sampling = TransportSamplingSender::new(sender);
952
953        let request = SamplingRequest {
954            messages: vec![fastmcp_core::SamplingRequestMessage {
955                role: SamplingRole::User,
956                text: "Hi".to_string(),
957            }],
958            max_tokens: 100,
959            system_prompt: Some("Be helpful".to_string()),
960            temperature: Some(0.7),
961            stop_sequences: vec!["STOP".to_string()],
962            model_hints: vec![],
963        };
964
965        let future = SamplingSender::create_message(&sampling, request);
966        let result = fastmcp_core::block_on(future).unwrap();
967        assert_eq!(result.text, "Hello world");
968        assert_eq!(result.model, "test-model");
969        assert!(matches!(result.stop_reason, SamplingStopReason::EndTurn));
970    }
971
972    #[test]
973    fn transport_sampling_sender_create_message_image() {
974        let sender = make_sender_with_responder(|_| {
975            serde_json::json!({
976                "content": {"type": "image", "data": "aW1hZ2VkYXRh", "mimeType": "image/png"},
977                "role": "assistant",
978                "model": "vision-model",
979                "stopReason": "maxTokens"
980            })
981        });
982        let sampling = TransportSamplingSender::new(sender);
983
984        let request = SamplingRequest {
985            messages: vec![fastmcp_core::SamplingRequestMessage {
986                role: SamplingRole::User,
987                text: "Describe image".to_string(),
988            }],
989            max_tokens: 50,
990            system_prompt: None,
991            temperature: None,
992            stop_sequences: vec![],
993            model_hints: vec![],
994        };
995
996        let future = SamplingSender::create_message(&sampling, request);
997        let result = fastmcp_core::block_on(future).unwrap();
998        // Image content is formatted as "[image: N bytes, type: ...]"
999        assert!(result.text.contains("image"));
1000        assert!(result.text.contains("image/png"));
1001        assert_eq!(result.model, "vision-model");
1002        assert!(matches!(result.stop_reason, SamplingStopReason::MaxTokens));
1003    }
1004
1005    #[test]
1006    fn transport_sampling_sender_create_message_with_model_hints() {
1007        let sender = make_sender_with_responder(|req| {
1008            // Verify model_preferences was sent
1009            let params: serde_json::Value =
1010                serde_json::from_value(req.params.clone().unwrap()).unwrap();
1011            assert!(params["modelPreferences"]["hints"].is_array());
1012            serde_json::json!({
1013                "content": {"type": "text", "text": "ok"},
1014                "role": "assistant",
1015                "model": "preferred",
1016                "stopReason": "stopSequence"
1017            })
1018        });
1019        let sampling = TransportSamplingSender::new(sender);
1020
1021        let request = SamplingRequest {
1022            messages: vec![fastmcp_core::SamplingRequestMessage {
1023                role: SamplingRole::User,
1024                text: "Hi".to_string(),
1025            }],
1026            max_tokens: 10,
1027            system_prompt: None,
1028            temperature: None,
1029            stop_sequences: vec![],
1030            model_hints: vec!["claude-3".to_string()],
1031        };
1032
1033        let future = SamplingSender::create_message(&sampling, request);
1034        let result = fastmcp_core::block_on(future).unwrap();
1035        assert!(matches!(
1036            result.stop_reason,
1037            SamplingStopReason::StopSequence
1038        ));
1039    }
1040
1041    #[test]
1042    fn transport_sampling_sender_create_message_assistant_role() {
1043        let sender = make_sender_with_responder(|req| {
1044            let params: serde_json::Value =
1045                serde_json::from_value(req.params.clone().unwrap()).unwrap();
1046            assert_eq!(params["messages"][0]["role"], "assistant");
1047            serde_json::json!({
1048                "content": {"type": "text", "text": "continued"},
1049                "role": "assistant",
1050                "model": "m",
1051                "stopReason": "endTurn"
1052            })
1053        });
1054        let sampling = TransportSamplingSender::new(sender);
1055
1056        let request = SamplingRequest {
1057            messages: vec![fastmcp_core::SamplingRequestMessage {
1058                role: SamplingRole::Assistant,
1059                text: "Previous response".to_string(),
1060            }],
1061            max_tokens: 10,
1062            system_prompt: None,
1063            temperature: None,
1064            stop_sequences: vec![],
1065            model_hints: vec![],
1066        };
1067
1068        let future = SamplingSender::create_message(&sampling, request);
1069        let result = fastmcp_core::block_on(future).unwrap();
1070        assert_eq!(result.text, "continued");
1071    }
1072
1073    // ── TransportElicitationSender — elicit ──────────────────────────
1074
1075    #[test]
1076    fn transport_elicitation_sender_form_accept_with_content() {
1077        let sender = make_sender_with_responder(|req| {
1078            let params: serde_json::Value =
1079                serde_json::from_value(req.params.clone().unwrap()).unwrap();
1080            assert_eq!(params["mode"], "form");
1081            serde_json::json!({
1082                "action": "accept",
1083                "content": {
1084                    "name": "Alice",
1085                    "age": 30,
1086                    "active": true,
1087                    "score": 9.5,
1088                    "tags": ["a", "b"],
1089                    "empty": null
1090                }
1091            })
1092        });
1093        let elicitation = TransportElicitationSender::new(sender);
1094
1095        let request = ElicitationRequest {
1096            message: "Fill the form".to_string(),
1097            mode: ElicitationMode::Form,
1098            schema: Some(serde_json::json!({"type": "object"})),
1099            url: None,
1100            elicitation_id: None,
1101        };
1102
1103        let future = ElicitationSender::elicit(&elicitation, request);
1104        let result = fastmcp_core::block_on(future).unwrap();
1105        assert!(matches!(result.action, ElicitationAction::Accept));
1106        let content = result.content.unwrap();
1107        assert_eq!(content["name"], serde_json::json!("Alice"));
1108        assert_eq!(content["age"], serde_json::json!(30));
1109        assert_eq!(content["active"], serde_json::json!(true));
1110        assert_eq!(content["score"], serde_json::json!(9.5));
1111        assert_eq!(content["tags"], serde_json::json!(["a", "b"]));
1112        assert_eq!(content["empty"], serde_json::Value::Null);
1113    }
1114
1115    #[test]
1116    fn transport_elicitation_sender_form_decline() {
1117        let sender = make_sender_with_responder(|_| {
1118            serde_json::json!({
1119                "action": "decline"
1120            })
1121        });
1122        let elicitation = TransportElicitationSender::new(sender);
1123
1124        let request = ElicitationRequest {
1125            message: "Confirm?".to_string(),
1126            mode: ElicitationMode::Form,
1127            schema: None,
1128            url: None,
1129            elicitation_id: None,
1130        };
1131
1132        let future = ElicitationSender::elicit(&elicitation, request);
1133        let result = fastmcp_core::block_on(future).unwrap();
1134        assert!(matches!(result.action, ElicitationAction::Decline));
1135        assert!(result.content.is_none());
1136    }
1137
1138    #[test]
1139    fn transport_elicitation_sender_url_mode() {
1140        let sender = make_sender_with_responder(|req| {
1141            let params: serde_json::Value =
1142                serde_json::from_value(req.params.clone().unwrap()).unwrap();
1143            assert_eq!(params["mode"], "url");
1144            assert_eq!(params["url"], "https://example.com/auth");
1145            serde_json::json!({
1146                "action": "cancel"
1147            })
1148        });
1149        let elicitation = TransportElicitationSender::new(sender);
1150
1151        let request = ElicitationRequest {
1152            message: "Please authenticate".to_string(),
1153            mode: ElicitationMode::Url,
1154            schema: None,
1155            url: Some("https://example.com/auth".to_string()),
1156            elicitation_id: Some("eid-123".to_string()),
1157        };
1158
1159        let future = ElicitationSender::elicit(&elicitation, request);
1160        let result = fastmcp_core::block_on(future).unwrap();
1161        assert!(matches!(result.action, ElicitationAction::Cancel));
1162    }
1163
1164    // ── TransportRootsProvider — list_roots ──────────────────────────
1165
1166    #[test]
1167    fn transport_roots_provider_list_roots() {
1168        let sender = make_sender_with_responder(|_| {
1169            serde_json::json!({
1170                "roots": [
1171                    {"uri": "file:///home/user/project", "name": "Project"},
1172                    {"uri": "file:///tmp"}
1173                ]
1174            })
1175        });
1176        let roots = TransportRootsProvider::new(sender);
1177        let result = roots.list_roots().unwrap();
1178        assert_eq!(result.len(), 2);
1179        assert_eq!(result[0].uri, "file:///home/user/project");
1180        assert_eq!(result[0].name, Some("Project".to_string()));
1181        assert_eq!(result[1].uri, "file:///tmp");
1182        assert!(result[1].name.is_none());
1183    }
1184
1185    #[test]
1186    fn transport_roots_provider_empty_roots() {
1187        let sender = make_sender_with_responder(|_| serde_json::json!({ "roots": [] }));
1188        let roots = TransportRootsProvider::new(sender);
1189        let result = roots.list_roots().unwrap();
1190        assert!(result.is_empty());
1191    }
1192
    // ── RequestSender — cancelled cx path ──────────────────────────
1196
1197    #[test]
1198    fn request_sender_cancelled_cx_returns_cancelled_error() {
1199        let pending = Arc::new(PendingRequests::new());
1200        // Transport succeeds but never sends a response
1201        let send_fn: TransportSendFn = Arc::new(|_| Ok(()));
1202        let sender = RequestSender::new(Arc::clone(&pending), send_fn);
1203
1204        let cx = Cx::for_testing();
1205        cx.set_cancel_requested(true);
1206
1207        let result: McpResult<serde_json::Value> =
1208            sender.send_request(&cx, "test/cancel", serde_json::json!({}));
1209        let err = result.unwrap_err();
1210        assert_eq!(err.code, McpErrorCode::RequestCancelled);
1211    }
1212
1213    // ── ElicitationSender url mode with None url/elicitation_id ──
1214
1215    #[test]
1216    fn transport_elicitation_sender_url_mode_defaults() {
1217        let sender = make_sender_with_responder(|req| {
1218            let params: serde_json::Value =
1219                serde_json::from_value(req.params.clone().unwrap()).unwrap();
1220            assert_eq!(params["mode"], "url");
1221            // url and elicitation_id default to empty strings
1222            assert_eq!(params["url"], "");
1223            assert_eq!(params["elicitationId"], "");
1224            serde_json::json!({ "action": "accept" })
1225        });
1226        let elicitation = TransportElicitationSender::new(sender);
1227
1228        let request = ElicitationRequest {
1229            message: "Auth".to_string(),
1230            mode: ElicitationMode::Url,
1231            schema: None,
1232            url: None,
1233            elicitation_id: None,
1234        };
1235
1236        let future = ElicitationSender::elicit(&elicitation, request);
1237        let result = fastmcp_core::block_on(future).unwrap();
1238        assert!(matches!(result.action, ElicitationAction::Accept));
1239    }
1240
1241    // ── TransportRootsProvider — transport failure ───────────────
1242
1243    #[test]
1244    fn transport_roots_provider_transport_failure() {
1245        let pending = Arc::new(PendingRequests::new());
1246        let send_fn: TransportSendFn = Arc::new(|_| Err("network error".to_string()));
1247        let sender = RequestSender::new(pending, send_fn);
1248        let roots = TransportRootsProvider::new(sender);
1249
1250        let result = roots.list_roots();
1251        assert!(result.is_err());
1252        assert!(
1253            result
1254                .unwrap_err()
1255                .message
1256                .contains("Failed to send request")
1257        );
1258    }
1259
1260    // ── SamplingSender — transport failure ───────────────────────
1261
1262    #[test]
1263    fn transport_sampling_sender_transport_failure() {
1264        let pending = Arc::new(PendingRequests::new());
1265        let send_fn: TransportSendFn = Arc::new(|_| Err("connection reset".to_string()));
1266        let sender = RequestSender::new(pending, send_fn);
1267        let sampling = TransportSamplingSender::new(sender);
1268
1269        let request = SamplingRequest {
1270            messages: vec![fastmcp_core::SamplingRequestMessage {
1271                role: SamplingRole::User,
1272                text: "Hi".to_string(),
1273            }],
1274            max_tokens: 10,
1275            system_prompt: None,
1276            temperature: None,
1277            stop_sequences: vec![],
1278            model_hints: vec![],
1279        };
1280
1281        let future = SamplingSender::create_message(&sampling, request);
1282        let result = fastmcp_core::block_on(future);
1283        assert!(result.is_err());
1284        assert!(
1285            result
1286                .unwrap_err()
1287                .message
1288                .contains("Failed to send request")
1289        );
1290    }
1291
1292    // ── SamplingSender — multiple messages ───────────────────────
1293
1294    #[test]
1295    fn transport_sampling_sender_multiple_messages() {
1296        let sender = make_sender_with_responder(|req| {
1297            let params: serde_json::Value =
1298                serde_json::from_value(req.params.clone().unwrap()).unwrap();
1299            let messages = params["messages"].as_array().unwrap();
1300            assert_eq!(messages.len(), 3);
1301            assert_eq!(messages[0]["role"], "user");
1302            assert_eq!(messages[1]["role"], "assistant");
1303            assert_eq!(messages[2]["role"], "user");
1304            serde_json::json!({
1305                "content": {"type": "text", "text": "done"},
1306                "role": "assistant",
1307                "model": "m",
1308                "stopReason": "endTurn"
1309            })
1310        });
1311        let sampling = TransportSamplingSender::new(sender);
1312
1313        let request = SamplingRequest {
1314            messages: vec![
1315                fastmcp_core::SamplingRequestMessage {
1316                    role: SamplingRole::User,
1317                    text: "Hello".to_string(),
1318                },
1319                fastmcp_core::SamplingRequestMessage {
1320                    role: SamplingRole::Assistant,
1321                    text: "Hi".to_string(),
1322                },
1323                fastmcp_core::SamplingRequestMessage {
1324                    role: SamplingRole::User,
1325                    text: "Follow up".to_string(),
1326                },
1327            ],
1328            max_tokens: 100,
1329            system_prompt: None,
1330            temperature: None,
1331            stop_sequences: vec![],
1332            model_hints: vec![],
1333        };
1334
1335        let future = SamplingSender::create_message(&sampling, request);
1336        let result = fastmcp_core::block_on(future).unwrap();
1337        assert_eq!(result.text, "done");
1338    }
1339
1340    // ── RequestSender — ID cleanup after success ────────────────
1341
1342    #[test]
1343    fn request_sender_id_cleaned_from_pending_after_success() {
1344        let pending = Arc::new(PendingRequests::new());
1345        let pending_clone = Arc::clone(&pending);
1346        let send_fn: TransportSendFn = Arc::new(move |msg| {
1347            if let JsonRpcMessage::Request(req) = msg {
1348                let id = req.id.clone().unwrap();
1349                let response = JsonRpcResponse::success(id, serde_json::json!(null));
1350                pending_clone.route_response(&response);
1351            }
1352            Ok(())
1353        });
1354        let sender = RequestSender::new(Arc::clone(&pending), send_fn);
1355        let cx = Cx::for_testing();
1356        let _: serde_json::Value = sender
1357            .send_request(&cx, "test/method", serde_json::json!({}))
1358            .unwrap();
1359
1360        // The pending request should have been consumed by route_response
1361        let first_id = RequestId::Number(1_000_000);
1362        let response = JsonRpcResponse::success(first_id, serde_json::json!(null));
1363        assert!(!pending.route_response(&response));
1364    }
1365}