Skip to main content

fastmcp_server/
bidirectional.rs

1//! Bidirectional request handling for server-to-client communication.
2//!
3//! This module provides the infrastructure for server-initiated requests to clients,
4//! such as:
5//! - `sampling/createMessage` - Request LLM completion from the client
6//! - `elicitation/elicit` - Request user input from the client
7//! - `roots/list` - Request filesystem roots from the client
8//!
9//! # Architecture
10//!
11//! The MCP protocol is bidirectional: while clients typically send requests to servers,
12//! servers can also send requests to clients. This creates a challenge because the
13//! server's main loop is typically blocking on `recv()`.
14//!
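//! The solution is a message dispatcher pattern:
//! 1. A background task continuously reads from the transport.
//! 2. Incoming messages are routed based on whether they're requests or responses.
//! 3. Responses are matched to pending requests via their ID.
//! 4. Requests are dispatched to handlers.
//!
//! This module provides the request/response bookkeeping for that pattern:
//! [`RequestSender`] writes a request and waits for its answer, and
//! [`PendingRequests::route_response`] delivers an incoming response to the waiting
//! caller. The transport-reading loop itself is not part of this module.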
20//!
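//! # Usage
//!
//! ```ignore
//! use std::sync::Arc;
//! use fastmcp_server::bidirectional::{PendingRequests, RequestSender};
//!
//! let pending = Arc::new(PendingRequests::new());
//! let sender = RequestSender::new(pending.clone(), send_fn);
//!
//! // Send a request and block until the client's response is routed back.
//! // Some other thread (the transport reader) must call
//! // `pending.route_response(&response)` when the answer arrives.
//! let result: fastmcp_protocol::CreateMessageResult = sender.send_request(
//!     &cx,
//!     "sampling/createMessage",
//!     params,
//! )?;
//! ```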
35
36use std::collections::HashMap;
37use std::sync::atomic::{AtomicU64, Ordering};
38use std::sync::{Arc, Mutex};
39
40use asupersync::Cx;
41use fastmcp_core::{
42    ElicitationAction, ElicitationMode, ElicitationRequest, ElicitationResponse, ElicitationSender,
43    McpError, McpErrorCode, McpResult, SamplingRequest, SamplingResponse, SamplingRole,
44    SamplingSender, SamplingStopReason,
45};
46use fastmcp_protocol::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, RequestId};
47
48// ============================================================================
49// Pending Request Tracking
50// ============================================================================
51
52/// A oneshot channel for receiving a response.
53type ResponseSender = std::sync::mpsc::Sender<Result<serde_json::Value, JsonRpcError>>;
54type ResponseReceiver = std::sync::mpsc::Receiver<Result<serde_json::Value, JsonRpcError>>;
55
56/// Tracks pending server-to-client requests.
57///
58/// When the server sends a request to the client, it registers a response sender
59/// here. When a response arrives, the dispatcher routes it to the correct sender.
60#[derive(Debug)]
61pub struct PendingRequests {
62    /// Map from request ID to response sender.
63    pending: Mutex<HashMap<RequestId, ResponseSender>>,
64    /// Counter for generating unique request IDs.
65    next_id: AtomicU64,
66}
67
68impl PendingRequests {
69    /// Creates a new pending request tracker.
70    #[must_use]
71    pub fn new() -> Self {
72        Self {
73            pending: Mutex::new(HashMap::new()),
74            // Start at a high number to avoid collision with client request IDs
75            next_id: AtomicU64::new(1_000_000),
76        }
77    }
78
79    /// Generates a new unique request ID.
80    #[allow(clippy::cast_possible_wrap)]
81    pub fn next_request_id(&self) -> RequestId {
82        let id = self.next_id.fetch_add(1, Ordering::SeqCst);
83        RequestId::Number(id as i64)
84    }
85
86    /// Registers a pending request and returns a receiver for the response.
87    pub fn register(&self, id: RequestId) -> ResponseReceiver {
88        let (tx, rx) = std::sync::mpsc::channel();
89        let mut pending = self.pending.lock().unwrap();
90        pending.insert(id, tx);
91        rx
92    }
93
94    /// Routes a response to the appropriate pending request.
95    ///
96    /// Returns `true` if the response was routed, `false` if no matching request was found.
97    pub fn route_response(&self, response: &JsonRpcResponse) -> bool {
98        let Some(ref id) = response.id else {
99            return false;
100        };
101
102        let sender = {
103            let mut pending = self.pending.lock().unwrap();
104            pending.remove(id)
105        };
106
107        if let Some(sender) = sender {
108            let result = if let Some(ref error) = response.error {
109                Err(error.clone())
110            } else {
111                Ok(response.result.clone().unwrap_or(serde_json::Value::Null))
112            };
113            // Ignore send errors (receiver may have been dropped due to cancellation)
114            let _ = sender.send(result);
115            true
116        } else {
117            false
118        }
119    }
120
121    /// Removes a pending request (e.g., on timeout or cancellation).
122    pub fn remove(&self, id: &RequestId) {
123        let mut pending = self.pending.lock().unwrap();
124        pending.remove(id);
125    }
126
127    /// Cancels all pending requests with a connection closed error.
128    pub fn cancel_all(&self) {
129        let mut pending = self.pending.lock().unwrap();
130        for (_, sender) in pending.drain() {
131            let _ = sender.send(Err(JsonRpcError {
132                code: McpErrorCode::InternalError.into(),
133                message: "Connection closed".to_string(),
134                data: None,
135            }));
136        }
137    }
138}
139
140impl Default for PendingRequests {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146// ============================================================================
147// Transport Request Sender
148// ============================================================================
149
150/// Callback type for sending messages through the transport.
151pub type TransportSendFn = Arc<dyn Fn(&JsonRpcMessage) -> Result<(), String> + Send + Sync>;
152
153/// Sends server-to-client requests through the transport.
154///
155/// This struct provides a way to send requests to the client and await responses.
156/// It works in conjunction with [`PendingRequests`] to track in-flight requests.
157#[derive(Clone)]
158pub struct RequestSender {
159    /// Pending request tracker.
160    pending: Arc<PendingRequests>,
161    /// Transport send callback.
162    send_fn: TransportSendFn,
163}
164
165impl RequestSender {
166    /// Creates a new request sender.
167    pub fn new(pending: Arc<PendingRequests>, send_fn: TransportSendFn) -> Self {
168        Self { pending, send_fn }
169    }
170
171    /// Sends a request to the client and waits for a response.
172    ///
173    /// # Errors
174    ///
175    /// Returns an error if:
176    /// - The transport send fails
177    /// - The request times out (based on budget)
178    /// - The client returns an error response
179    /// - The connection is closed
180    pub fn send_request<T: serde::de::DeserializeOwned>(
181        &self,
182        _cx: &Cx,
183        method: &str,
184        params: serde_json::Value,
185    ) -> McpResult<T> {
186        let id = self.pending.next_request_id();
187        let receiver = self.pending.register(id.clone());
188
189        let request = JsonRpcRequest::new(method.to_string(), Some(params), id.clone());
190        let message = JsonRpcMessage::Request(request);
191
192        // Send the request through the transport
193        if let Err(e) = (self.send_fn)(&message) {
194            self.pending.remove(&id);
195            return Err(McpError::internal_error(format!(
196                "Failed to send request: {}",
197                e
198            )));
199        }
200
201        // Wait for response
202        // TODO: Add timeout based on budget
203        match receiver.recv() {
204            Ok(Ok(value)) => serde_json::from_value(value)
205                .map_err(|e| McpError::internal_error(format!("Failed to parse response: {}", e))),
206            Ok(Err(error)) => Err(McpError::new(McpErrorCode::from(error.code), error.message)),
207            Err(_) => Err(McpError::internal_error(
208                "Response channel closed unexpectedly",
209            )),
210        }
211    }
212}
213
214impl std::fmt::Debug for RequestSender {
215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216        f.debug_struct("RequestSender")
217            .field("pending", &self.pending)
218            .finish_non_exhaustive()
219    }
220}
221
222// ============================================================================
223// Sampling Sender Implementation
224// ============================================================================
225
226/// Sends sampling requests to the client via the transport.
227#[derive(Clone)]
228pub struct TransportSamplingSender {
229    sender: RequestSender,
230}
231
232impl TransportSamplingSender {
233    /// Creates a new transport-backed sampling sender.
234    pub fn new(sender: RequestSender) -> Self {
235        Self { sender }
236    }
237}
238
239impl SamplingSender for TransportSamplingSender {
240    fn create_message(
241        &self,
242        request: SamplingRequest,
243    ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<SamplingResponse>> + Send + '_>>
244    {
245        Box::pin(async move {
246            // Convert to protocol types
247            let params = fastmcp_protocol::CreateMessageParams {
248                messages: request
249                    .messages
250                    .into_iter()
251                    .map(|m| fastmcp_protocol::SamplingMessage {
252                        role: match m.role {
253                            SamplingRole::User => fastmcp_protocol::Role::User,
254                            SamplingRole::Assistant => fastmcp_protocol::Role::Assistant,
255                        },
256                        content: fastmcp_protocol::SamplingContent::Text { text: m.text },
257                    })
258                    .collect(),
259                max_tokens: request.max_tokens,
260                system_prompt: request.system_prompt,
261                temperature: request.temperature,
262                stop_sequences: request.stop_sequences,
263                model_preferences: if request.model_hints.is_empty() {
264                    None
265                } else {
266                    Some(fastmcp_protocol::ModelPreferences {
267                        hints: request
268                            .model_hints
269                            .into_iter()
270                            .map(|name| fastmcp_protocol::ModelHint { name: Some(name) })
271                            .collect(),
272                        ..Default::default()
273                    })
274                },
275                include_context: None,
276                meta: None,
277            };
278
279            let params_value = serde_json::to_value(&params)
280                .map_err(|e| McpError::internal_error(format!("Failed to serialize: {}", e)))?;
281
282            // Create a temporary Cx for the request
283            let cx = Cx::for_testing();
284
285            let result: fastmcp_protocol::CreateMessageResult =
286                self.sender
287                    .send_request(&cx, "sampling/createMessage", params_value)?;
288
289            Ok(SamplingResponse {
290                text: match result.content {
291                    fastmcp_protocol::SamplingContent::Text { text } => text,
292                    fastmcp_protocol::SamplingContent::Image { data, mime_type } => {
293                        format!("[image: {} bytes, type: {}]", data.len(), mime_type)
294                    }
295                },
296                model: result.model,
297                stop_reason: match result.stop_reason {
298                    fastmcp_protocol::StopReason::EndTurn => SamplingStopReason::EndTurn,
299                    fastmcp_protocol::StopReason::StopSequence => SamplingStopReason::StopSequence,
300                    fastmcp_protocol::StopReason::MaxTokens => SamplingStopReason::MaxTokens,
301                },
302            })
303        })
304    }
305}
306
307// ============================================================================
308// Elicitation Sender Implementation
309// ============================================================================
310
311/// Sends elicitation requests to the client via the transport.
312#[derive(Clone)]
313pub struct TransportElicitationSender {
314    sender: RequestSender,
315}
316
317impl TransportElicitationSender {
318    /// Creates a new transport-backed elicitation sender.
319    pub fn new(sender: RequestSender) -> Self {
320        Self { sender }
321    }
322}
323
324impl ElicitationSender for TransportElicitationSender {
325    fn elicit(
326        &self,
327        request: ElicitationRequest,
328    ) -> std::pin::Pin<
329        Box<dyn std::future::Future<Output = McpResult<ElicitationResponse>> + Send + '_>,
330    > {
331        Box::pin(async move {
332            let params_value = match request.mode {
333                ElicitationMode::Form => {
334                    let params = fastmcp_protocol::ElicitRequestFormParams {
335                        mode: fastmcp_protocol::ElicitMode::Form,
336                        message: request.message.clone(),
337                        requested_schema: request.schema.unwrap_or(serde_json::json!({})),
338                    };
339                    serde_json::to_value(&params).map_err(|e| {
340                        McpError::internal_error(format!("Failed to serialize: {}", e))
341                    })?
342                }
343                ElicitationMode::Url => {
344                    let params = fastmcp_protocol::ElicitRequestUrlParams {
345                        mode: fastmcp_protocol::ElicitMode::Url,
346                        message: request.message.clone(),
347                        url: request.url.unwrap_or_default(),
348                        elicitation_id: request.elicitation_id.unwrap_or_default(),
349                    };
350                    serde_json::to_value(&params).map_err(|e| {
351                        McpError::internal_error(format!("Failed to serialize: {}", e))
352                    })?
353                }
354            };
355
356            // Create a temporary Cx for the request
357            let cx = Cx::for_testing();
358
359            let result: fastmcp_protocol::ElicitResult =
360                self.sender
361                    .send_request(&cx, "elicitation/elicit", params_value)?;
362
363            // Convert HashMap<String, ElicitContentValue> to HashMap<String, serde_json::Value>
364            let content = result.content.map(|content_map| {
365                let mut map = std::collections::HashMap::new();
366                for (key, value) in content_map {
367                    let json_value = match value {
368                        fastmcp_protocol::ElicitContentValue::Null => serde_json::Value::Null,
369                        fastmcp_protocol::ElicitContentValue::Bool(b) => serde_json::Value::Bool(b),
370                        fastmcp_protocol::ElicitContentValue::Int(i) => {
371                            serde_json::Value::Number(i.into())
372                        }
373                        fastmcp_protocol::ElicitContentValue::Float(f) => {
374                            serde_json::Number::from_f64(f)
375                                .map(serde_json::Value::Number)
376                                .unwrap_or(serde_json::Value::Null)
377                        }
378                        fastmcp_protocol::ElicitContentValue::String(s) => {
379                            serde_json::Value::String(s)
380                        }
381                        fastmcp_protocol::ElicitContentValue::StringArray(arr) => {
382                            serde_json::Value::Array(
383                                arr.into_iter().map(serde_json::Value::String).collect(),
384                            )
385                        }
386                    };
387                    map.insert(key, json_value);
388                }
389                map
390            });
391
392            Ok(ElicitationResponse {
393                action: match result.action {
394                    fastmcp_protocol::ElicitAction::Accept => ElicitationAction::Accept,
395                    fastmcp_protocol::ElicitAction::Decline => ElicitationAction::Decline,
396                    fastmcp_protocol::ElicitAction::Cancel => ElicitationAction::Cancel,
397                },
398                content,
399            })
400        })
401    }
402}
403
404// ============================================================================
405// Roots Provider Implementation
406// ============================================================================
407
408/// Provider for filesystem roots from the client.
409#[derive(Clone)]
410pub struct TransportRootsProvider {
411    sender: RequestSender,
412}
413
414impl TransportRootsProvider {
415    /// Creates a new transport-backed roots provider.
416    pub fn new(sender: RequestSender) -> Self {
417        Self { sender }
418    }
419
420    /// Lists the filesystem roots from the client.
421    pub fn list_roots(&self) -> McpResult<Vec<fastmcp_protocol::Root>> {
422        let cx = Cx::for_testing();
423        let result: fastmcp_protocol::ListRootsResult =
424            self.sender
425                .send_request(&cx, "roots/list", serde_json::json!({}))?;
426        Ok(result.roots)
427    }
428}
429
430// ============================================================================
431// Tests
432// ============================================================================
433
#[cfg(test)]
mod tests {
    use super::*;

    /// A success response for a registered ID reaches that ID's receiver intact.
    #[test]
    fn test_pending_requests_register_and_route() {
        let tracker = PendingRequests::new();
        let request_id = tracker.next_request_id();
        let rx = tracker.register(request_id.clone());

        let payload = serde_json::json!({"result": "ok"});
        let routed = tracker.route_response(&JsonRpcResponse::success(request_id, payload.clone()));
        assert!(routed, "a registered ID must be routed");

        let delivered = rx.recv().expect("the routed sender should deliver a value");
        assert!(delivered.is_ok());
        assert_eq!(delivered.unwrap(), payload);
    }

    /// An error response is delivered as `Err` carrying the client's message.
    #[test]
    fn test_pending_requests_error_response() {
        let tracker = PendingRequests::new();
        let request_id = tracker.next_request_id();
        let rx = tracker.register(request_id.clone());

        let failure = JsonRpcError {
            code: -32600,
            message: "Invalid request".to_string(),
            data: None,
        };
        assert!(tracker.route_response(&JsonRpcResponse::error(Some(request_id), failure)));

        let delivered = rx.recv().expect("the routed sender should deliver a value");
        assert!(delivered.is_err());
        assert_eq!(delivered.unwrap_err().message, "Invalid request");
    }

    /// `cancel_all` fails every outstanding request.
    #[test]
    fn test_pending_requests_cancel_all() {
        let tracker = PendingRequests::new();
        let receivers: Vec<_> = (0..2)
            .map(|_| {
                let id = tracker.next_request_id();
                tracker.register(id)
            })
            .collect();

        tracker.cancel_all();

        for rx in receivers {
            let outcome = rx.recv().expect("cancel_all should notify every receiver");
            assert!(outcome.is_err());
        }
    }

    /// A response whose ID was never registered is reported as unrouted.
    #[test]
    fn test_route_unknown_response() {
        let tracker = PendingRequests::new();
        let stray = JsonRpcResponse::success(
            RequestId::Number(999_999),
            serde_json::json!({"result": "ok"}),
        );
        assert!(!tracker.route_response(&stray), "unknown IDs must not be routed");
    }
}