mcp_host/server/
multiplexer.rs

1//! Request Multiplexer for bidirectional JSON-RPC communication
2//!
3//! Enables server→client requests (roots/list, sampling/createMessage, etc.)
4//! by tracking pending requests and routing responses back to callers.
5//!
6//! # Architecture
7//!
8//! The multiplexer sits between the server and transport, handling:
9//! - Client→Server: Normal requests (tools/call, resources/read, etc.)
10//! - Server→Client: Requests initiated by the server (roots/list, sampling/createMessage)
11//!
12//! For server-initiated requests, the multiplexer:
13//! 1. Generates a unique request ID (UUID)
14//! 2. Stores a oneshot channel in pending requests map
15//! 3. Sends the request via transport
16//! 4. When response arrives, routes it to the waiting channel
17//!
18//! # Example
19//!
20//! ```rust,ignore
//! // Server requesting roots from client (`None` = use the default timeout)
//! let roots = requester.request_roots(None).await?;
//! for root in roots {
//!     println!("Client root: {:?} ({})", root.name, root.uri);
//! }
//!
//! // Server requesting sampling from client
//! let result = requester.request_sampling(params, None).await?;
//! println!("Model response: {:?}", result.content);
30//! ```
31
32use std::sync::Arc;
33use std::time::Duration;
34
35use tokio::sync::mpsc;
36
37use dashmap::DashMap;
38use serde::{Deserialize, Serialize};
39use serde_json::Value;
40use thiserror::Error;
41use tokio::sync::oneshot;
42
43use crate::protocol::types::JsonRpcResponse;
44
45/// Errors from server→client requests
46#[derive(Debug, Error)]
47pub enum MultiplexerError {
48    /// Request timed out waiting for response
49    #[error("request timed out after {0:?}")]
50    Timeout(Duration),
51
52    /// Client returned an error response
53    #[error("client error {code}: {message}")]
54    ClientError { code: i32, message: String },
55
56    /// Transport error
57    #[error("transport error: {0}")]
58    Transport(String),
59
60    /// Response channel was closed (internal error)
61    #[error("response channel closed")]
62    ChannelClosed,
63
64    /// Failed to serialize request
65    #[error("serialization error: {0}")]
66    Serialization(#[from] serde_json::Error),
67
68    /// Client doesn't support the requested capability
69    #[error("client does not support {0}")]
70    UnsupportedCapability(String),
71}
72
73/// A pending server-initiated request awaiting response
74pub struct PendingRequest {
75    /// Request ID (UUID string)
76    pub id: String,
77
78    /// Method name for logging/debugging
79    pub method: String,
80
81    /// Channel to send the response
82    pub response_tx: oneshot::Sender<Result<Value, MultiplexerError>>,
83
84    /// When this request was created (for timeout tracking)
85    pub created_at: std::time::Instant,
86}
87
88/// Request multiplexer for handling bidirectional communication
89///
90/// Tracks pending server→client requests and routes responses
91pub struct RequestMultiplexer {
92    /// Pending requests awaiting responses: request_id → PendingRequest
93    pending: DashMap<String, PendingRequest>,
94
95    /// Default timeout for requests
96    default_timeout: Duration,
97}
98
99impl Default for RequestMultiplexer {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105impl RequestMultiplexer {
106    /// Create a new request multiplexer
107    pub fn new() -> Self {
108        Self {
109            pending: DashMap::new(),
110            default_timeout: Duration::from_secs(30),
111        }
112    }
113
114    /// Create with custom default timeout
115    pub fn with_timeout(timeout: Duration) -> Self {
116        Self {
117            pending: DashMap::new(),
118            default_timeout: timeout,
119        }
120    }
121
122    /// Get default timeout
123    pub fn default_timeout(&self) -> Duration {
124        self.default_timeout
125    }
126
127    /// Create a new pending request and return the receiver
128    ///
129    /// Returns (request_id, receiver) - the caller should await the receiver
130    pub fn create_pending(&self, method: impl Into<String>) -> (String, oneshot::Receiver<Result<Value, MultiplexerError>>) {
131        let id = uuid::Uuid::new_v4().to_string();
132        let method = method.into();
133        let (tx, rx) = oneshot::channel();
134
135        let pending = PendingRequest {
136            id: id.clone(),
137            method,
138            response_tx: tx,
139            created_at: std::time::Instant::now(),
140        };
141
142        self.pending.insert(id.clone(), pending);
143
144        (id, rx)
145    }
146
147    /// Route an incoming response to its pending request
148    ///
149    /// Returns true if the response was routed, false if no matching request found
150    pub fn route_response(&self, response: &JsonRpcResponse) -> bool {
151        // Extract request ID from response
152        let id = match &response.id {
153            Value::String(s) => s.clone(),
154            Value::Number(n) => n.to_string(),
155            _ => return false,
156        };
157
158        // Find and remove the pending request
159        if let Some((_, pending)) = self.pending.remove(&id) {
160            // Build result or error
161            let result = if let Some(ref error) = response.error {
162                Err(MultiplexerError::ClientError {
163                    code: error.code,
164                    message: error.message.clone(),
165                })
166            } else if let Some(ref result) = response.result {
167                Ok(result.clone())
168            } else {
169                // Empty response (no result, no error) - treat as empty object
170                Ok(Value::Object(serde_json::Map::new()))
171            };
172
173            // Send to waiting caller (ignore if receiver dropped)
174            let _ = pending.response_tx.send(result);
175
176            true
177        } else {
178            false
179        }
180    }
181
182    /// Check if a response ID matches a pending request
183    ///
184    /// Used to distinguish client responses from client requests in the message loop
185    pub fn is_pending_response(&self, id: &Value) -> bool {
186        let id_str = match id {
187            Value::String(s) => s.clone(),
188            Value::Number(n) => n.to_string(),
189            _ => return false,
190        };
191
192        self.pending.contains_key(&id_str)
193    }
194
195    /// Get number of pending requests
196    pub fn pending_count(&self) -> usize {
197        self.pending.len()
198    }
199
200    /// Cancel a pending request
201    pub fn cancel(&self, id: &str) {
202        if let Some((_, pending)) = self.pending.remove(id) {
203            let _ = pending.response_tx.send(Err(MultiplexerError::ChannelClosed));
204        }
205    }
206
207    /// Cancel all pending requests
208    pub fn cancel_all(&self) {
209        let ids: Vec<String> = self.pending.iter().map(|e| e.key().clone()).collect();
210        for id in ids {
211            self.cancel(&id);
212        }
213    }
214
215    /// Clean up timed-out requests
216    ///
217    /// Returns the number of requests that were cleaned up
218    pub fn cleanup_timed_out(&self, timeout: Duration) -> usize {
219        let now = std::time::Instant::now();
220        let mut cleaned = 0;
221
222        // Collect IDs to remove (can't modify during iteration with DashMap)
223        let timed_out: Vec<String> = self
224            .pending
225            .iter()
226            .filter(|e| now.duration_since(e.created_at) > timeout)
227            .map(|e| e.key().clone())
228            .collect();
229
230        for id in timed_out {
231            if let Some((_, pending)) = self.pending.remove(&id) {
232                let _ = pending
233                    .response_tx
234                    .send(Err(MultiplexerError::Timeout(timeout)));
235                cleaned += 1;
236            }
237        }
238
239        cleaned
240    }
241}
242
243/// Server→client request (what we send to the client)
244#[derive(Debug, Clone, Serialize, Deserialize)]
245pub struct JsonRpcClientRequest {
246    /// JSON-RPC version
247    pub jsonrpc: String,
248
249    /// Request ID (UUID string for server-initiated requests)
250    pub id: String,
251
252    /// Method name
253    pub method: String,
254
255    /// Request parameters
256    #[serde(skip_serializing_if = "Option::is_none")]
257    pub params: Option<Value>,
258}
259
260impl JsonRpcClientRequest {
261    /// Create a new client request
262    pub fn new(id: impl Into<String>, method: impl Into<String>, params: Option<Value>) -> Self {
263        Self {
264            jsonrpc: "2.0".to_string(),
265            id: id.into(),
266            method: method.into(),
267            params,
268        }
269    }
270}
271
272/// Root entry from roots/list response
273#[derive(Debug, Clone, Serialize, Deserialize)]
274pub struct Root {
275    /// URI of the root (e.g., "file:///workspace")
276    pub uri: String,
277
278    /// Human-readable name for the root
279    #[serde(skip_serializing_if = "Option::is_none")]
280    pub name: Option<String>,
281}
282
283/// Result of roots/list request
284#[derive(Debug, Clone, Serialize, Deserialize)]
285pub struct ListRootsResult {
286    /// List of workspace roots
287    pub roots: Vec<Root>,
288}
289
290/// Sampling message for createMessage request
291#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct SamplingMessage {
293    /// Role: "user" or "assistant"
294    pub role: String,
295
296    /// Message content
297    pub content: SamplingContent,
298}
299
300/// Content in a sampling message
301#[derive(Debug, Clone, Serialize, Deserialize)]
302#[serde(tag = "type")]
303pub enum SamplingContent {
304    /// Text content
305    #[serde(rename = "text")]
306    Text { text: String },
307
308    /// Image content
309    #[serde(rename = "image")]
310    Image { data: String, mime_type: String },
311}
312
313/// Model preferences for sampling
314#[derive(Debug, Clone, Default, Serialize, Deserialize)]
315#[serde(rename_all = "camelCase")]
316pub struct ModelPreferences {
317    /// Preferred model hints
318    #[serde(skip_serializing_if = "Option::is_none")]
319    pub hints: Option<Vec<ModelHint>>,
320
321    /// Cost priority (0.0 = prefer cheap, 1.0 = prefer quality)
322    #[serde(skip_serializing_if = "Option::is_none")]
323    pub cost_priority: Option<f64>,
324
325    /// Speed priority (0.0 = prefer slow, 1.0 = prefer fast)
326    #[serde(skip_serializing_if = "Option::is_none")]
327    pub speed_priority: Option<f64>,
328
329    /// Intelligence priority (0.0 = prefer simple, 1.0 = prefer capable)
330    #[serde(skip_serializing_if = "Option::is_none")]
331    pub intelligence_priority: Option<f64>,
332}
333
334/// Model hint for sampling
335#[derive(Debug, Clone, Serialize, Deserialize)]
336pub struct ModelHint {
337    /// Model name pattern
338    #[serde(skip_serializing_if = "Option::is_none")]
339    pub name: Option<String>,
340}
341
342/// Parameters for sampling/createMessage request
343#[derive(Debug, Clone, Serialize, Deserialize)]
344#[serde(rename_all = "camelCase")]
345pub struct CreateMessageParams {
346    /// Messages to send
347    pub messages: Vec<SamplingMessage>,
348
349    /// Model preferences
350    #[serde(skip_serializing_if = "Option::is_none")]
351    pub model_preferences: Option<ModelPreferences>,
352
353    /// System prompt
354    #[serde(skip_serializing_if = "Option::is_none")]
355    pub system_prompt: Option<String>,
356
357    /// Include context from MCP servers
358    #[serde(skip_serializing_if = "Option::is_none")]
359    pub include_context: Option<String>,
360
361    /// Temperature
362    #[serde(skip_serializing_if = "Option::is_none")]
363    pub temperature: Option<f64>,
364
365    /// Maximum tokens to generate
366    pub max_tokens: i32,
367
368    /// Stop sequences
369    #[serde(skip_serializing_if = "Option::is_none")]
370    pub stop_sequences: Option<Vec<String>>,
371}
372
373/// Result of sampling/createMessage request
374#[derive(Debug, Clone, Serialize, Deserialize)]
375#[serde(rename_all = "camelCase")]
376pub struct CreateMessageResult {
377    /// Role of the response
378    pub role: String,
379
380    /// Response content
381    pub content: SamplingContent,
382
383    /// Model that was used
384    pub model: String,
385
386    /// Reason for stopping
387    #[serde(skip_serializing_if = "Option::is_none")]
388    pub stop_reason: Option<String>,
389}
390
391/// Client requester for making server→client requests from tools
392///
393/// This is passed to tools via ExecutionContext, allowing them to
394/// make requests to the client (roots/list, sampling/createMessage, etc.)
395///
396/// # Example
397///
398/// ```rust,ignore
399/// impl Tool for ListRootsTool {
400///     async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<Vec<Box<dyn Content>>, ToolError> {
401///         if let Some(requester) = ctx.client_requester() {
402///             match requester.request_roots(None).await {
403///                 Ok(roots) => {
404///                     let msg = format!("Found {} roots", roots.len());
405///                     Ok(vec![Box::new(TextContent::new(msg))])
406///                 }
407///                 Err(e) => Err(ToolError::Execution(e.to_string())),
408///             }
409///         } else {
410///             Err(ToolError::Execution("No client requester available".into()))
411///         }
412///     }
413/// }
414/// ```
415#[derive(Clone)]
416pub struct ClientRequester {
417    /// Channel to send requests to the transport
418    request_tx: mpsc::UnboundedSender<JsonRpcClientRequest>,
419
420    /// Multiplexer for tracking pending requests
421    multiplexer: Arc<RequestMultiplexer>,
422
423    /// Whether the client supports roots capability
424    supports_roots: bool,
425
426    /// Whether the client supports sampling capability
427    supports_sampling: bool,
428}
429
430impl ClientRequester {
431    /// Create a new client requester
432    pub fn new(
433        request_tx: mpsc::UnboundedSender<JsonRpcClientRequest>,
434        multiplexer: Arc<RequestMultiplexer>,
435        supports_roots: bool,
436        supports_sampling: bool,
437    ) -> Self {
438        Self {
439            request_tx,
440            multiplexer,
441            supports_roots,
442            supports_sampling,
443        }
444    }
445
446    /// Check if client supports roots capability
447    pub fn supports_roots(&self) -> bool {
448        self.supports_roots
449    }
450
451    /// Check if client supports sampling capability
452    pub fn supports_sampling(&self) -> bool {
453        self.supports_sampling
454    }
455
456    /// Request workspace roots from the client
457    ///
458    /// Returns an error if the client doesn't support roots capability.
459    pub async fn request_roots(
460        &self,
461        timeout: Option<Duration>,
462    ) -> Result<Vec<Root>, MultiplexerError> {
463        if !self.supports_roots {
464            return Err(MultiplexerError::UnsupportedCapability("roots".to_string()));
465        }
466
467        // Create pending request
468        let (id, rx) = self.multiplexer.create_pending("roots/list");
469
470        // Build and send the request
471        let request = JsonRpcClientRequest::new(&id, "roots/list", Some(serde_json::json!({})));
472
473        self.request_tx
474            .send(request)
475            .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
476
477        // Wait for response with timeout
478        let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
479        let result = tokio::time::timeout(timeout, rx)
480            .await
481            .map_err(|_| MultiplexerError::Timeout(timeout))?
482            .map_err(|_| MultiplexerError::ChannelClosed)??;
483
484        // Parse the result
485        let list_result: ListRootsResult = serde_json::from_value(result)?;
486        Ok(list_result.roots)
487    }
488
489    /// Request an LLM completion from the client
490    ///
491    /// Returns an error if the client doesn't support sampling capability.
492    pub async fn request_sampling(
493        &self,
494        params: CreateMessageParams,
495        timeout: Option<Duration>,
496    ) -> Result<CreateMessageResult, MultiplexerError> {
497        if !self.supports_sampling {
498            return Err(MultiplexerError::UnsupportedCapability("sampling".to_string()));
499        }
500
501        // Create pending request
502        let (id, rx) = self.multiplexer.create_pending("sampling/createMessage");
503
504        // Build and send the request
505        let params_value = serde_json::to_value(&params)?;
506        let request = JsonRpcClientRequest::new(&id, "sampling/createMessage", Some(params_value));
507
508        self.request_tx
509            .send(request)
510            .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
511
512        // Wait for response with timeout
513        let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
514        let result = tokio::time::timeout(timeout, rx)
515            .await
516            .map_err(|_| MultiplexerError::Timeout(timeout))?
517            .map_err(|_| MultiplexerError::ChannelClosed)??;
518
519        // Parse the result
520        let create_result: CreateMessageResult = serde_json::from_value(result)?;
521        Ok(create_result)
522    }
523
524    /// Request user input elicitation from the client
525    ///
526    /// Sends an `elicitation/create` request to the client and waits for the response.
527    /// The client must have the `elicitation` capability advertised.
528    ///
529    /// # Arguments
530    ///
531    /// * `message` - The message to show the user
532    /// * `schema` - The schema defining the structure of requested input
533    /// * `timeout` - Optional timeout for the request
534    ///
535    /// # Returns
536    ///
537    /// The user's response (accept/decline/cancel with optional data), or an error if:
538    /// - Client doesn't support elicitation capability
539    /// - Request times out
540    /// - Transport error occurs
541    pub async fn request_elicitation(
542        &self,
543        message: String,
544        requested_schema: Value,
545        timeout: Option<Duration>,
546    ) -> Result<crate::protocol::types::CreateElicitationResult, MultiplexerError> {
547        // Note: We'd need to track elicitation capability during initialization
548        // For now, just attempt the request
549
550        // Create pending request
551        let (id, rx) = self.multiplexer.create_pending("elicitation/create");
552
553        // Build request params
554        let params = serde_json::json!({
555            "message": message,
556            "requestedSchema": requested_schema,
557        });
558
559        let request = JsonRpcClientRequest::new(&id, "elicitation/create", Some(params));
560
561        self.request_tx
562            .send(request)
563            .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
564
565        // Wait for response with timeout
566        let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
567        let result = tokio::time::timeout(timeout, rx)
568            .await
569            .map_err(|_| MultiplexerError::Timeout(timeout))?
570            .map_err(|_| MultiplexerError::ChannelClosed)??;
571
572        // Parse the result
573        let elicitation_result: crate::protocol::types::CreateElicitationResult =
574            serde_json::from_value(result)?;
575        Ok(elicitation_result)
576    }
577}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// Two registrations yield distinct IDs and are both tracked.
    #[test]
    fn test_multiplexer_create_pending() {
        let mux = RequestMultiplexer::new();

        let (first_id, _first_rx) = mux.create_pending("roots/list");
        let (second_id, _second_rx) = mux.create_pending("sampling/createMessage");

        assert_ne!(first_id, second_id);
        assert_eq!(mux.pending_count(), 2);
    }

    /// A success response reaches the waiting receiver and clears the entry.
    #[tokio::test]
    async fn test_multiplexer_route_response() {
        let mux = RequestMultiplexer::new();
        let (id, rx) = mux.create_pending("test/method");

        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: Value::String(id),
            result: Some(serde_json::json!({"status": "ok"})),
            error: None,
        };

        let routed = mux.route_response(&response);
        assert!(routed);
        assert_eq!(mux.pending_count(), 0);

        let payload = rx.await.unwrap().unwrap();
        assert_eq!(payload["status"], "ok");
    }

    /// An error response is surfaced to the receiver as `ClientError`.
    #[tokio::test]
    async fn test_multiplexer_route_error() {
        let mux = RequestMultiplexer::new();
        let (id, rx) = mux.create_pending("test/method");

        let client_error = crate::protocol::types::JsonRpcError {
            code: -32600,
            message: "Invalid request".to_string(),
            data: None,
        };
        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: Value::String(id),
            result: None,
            error: Some(client_error),
        };

        assert!(mux.route_response(&response));

        match rx.await.unwrap() {
            Err(MultiplexerError::ClientError { code, .. }) => assert_eq!(code, -32600),
            other => panic!("expected client error, got {:?}", other),
        }
    }

    /// Only IDs of registered requests are reported as pending.
    #[test]
    fn test_multiplexer_is_pending() {
        let mux = RequestMultiplexer::new();
        let (id, _rx) = mux.create_pending("test");

        let known = Value::String(id);
        let unknown = Value::String("unknown".to_string());

        assert!(mux.is_pending_response(&known));
        assert!(!mux.is_pending_response(&unknown));
    }

    /// Cancelling removes the entry from the pending map.
    #[test]
    fn test_multiplexer_cancel() {
        let mux = RequestMultiplexer::new();
        let (id, _rx) = mux.create_pending("test");
        assert_eq!(mux.pending_count(), 1);

        mux.cancel(&id);

        assert_eq!(mux.pending_count(), 0);
    }

    /// The serialized request carries version, ID and method.
    #[test]
    fn test_client_request_serialization() {
        let req = JsonRpcClientRequest::new("abc-123", "roots/list", Some(serde_json::json!({})));
        let json = serde_json::to_string(&req).unwrap();

        let expected_fragments = [
            "\"jsonrpc\":\"2.0\"",
            "\"id\":\"abc-123\"",
            "\"method\":\"roots/list\"",
        ];
        for fragment in expected_fragments.iter() {
            assert!(json.contains(*fragment), "missing {} in {}", fragment, json);
        }
    }

    /// A root with both fields present deserializes into `Root`.
    #[test]
    fn test_root_deserialization() {
        let json = r#"{"uri": "file:///workspace", "name": "My Project"}"#;

        let root: Root = serde_json::from_str(json).unwrap();

        assert_eq!(root.uri, "file:///workspace");
        assert_eq!(root.name.as_deref(), Some("My Project"));
    }

    /// Text content is tagged with `type: "text"` next to its `text` field.
    #[test]
    fn test_sampling_content() {
        let content = SamplingContent::Text {
            text: "Hello, world!".to_string(),
        };

        let json = serde_json::to_string(&content).unwrap();

        assert!(json.contains("\"type\":\"text\""));
        assert!(json.contains("\"text\":\"Hello, world!\""));
    }
}