mcp_host/server/
multiplexer.rs

1//! Request Multiplexer for bidirectional JSON-RPC communication
2//!
3//! Enables server→client requests (roots/list, sampling/createMessage, etc.)
4//! by tracking pending requests and routing responses back to callers.
5//!
6//! # Architecture
7//!
8//! The multiplexer sits between the server and transport, handling:
9//! - Client→Server: Normal requests (tools/call, resources/read, etc.)
10//! - Server→Client: Requests initiated by the server (roots/list, sampling/createMessage)
11//!
12//! For server-initiated requests, the multiplexer:
13//! 1. Generates a unique request ID (UUID)
14//! 2. Stores a oneshot channel in pending requests map
15//! 3. Sends the request via transport
16//! 4. When response arrives, routes it to the waiting channel
17//!
18//! # Example
19//!
20//! ```rust,ignore
21//! // Server requesting roots from client
22//! let roots = server.request_roots().await?;
23//! for root in roots {
//!     println!("Client root: {} ({})", root.name.as_deref().unwrap_or("<unnamed>"), root.uri);
25//! }
26//!
27//! // Server requesting sampling from client
28//! let result = server.request_sampling(messages, preferences).await?;
//! println!("Model response: {:?}", result.content);
30//! ```
31
32use std::sync::Arc;
33use std::time::Duration;
34
35use tokio::sync::mpsc;
36
37use dashmap::DashMap;
38use serde::{Deserialize, Serialize};
39use serde_json::Value;
40use thiserror::Error;
41use tokio::sync::oneshot;
42
43use crate::protocol::types::JsonRpcResponse;
44
/// Errors from server→client requests
///
/// Values of this type are delivered through a pending request's oneshot
/// channel by [`RequestMultiplexer`] and are returned from the `request_*`
/// methods on [`ClientRequester`].
#[derive(Debug, Error)]
pub enum MultiplexerError {
    /// Request timed out waiting for response; carries the timeout that was applied
    #[error("request timed out after {0:?}")]
    Timeout(Duration),

    /// Client returned an error response; `code` and `message` are copied
    /// from that response's error object
    #[error("client error {code}: {message}")]
    ClientError { code: i32, message: String },

    /// Transport error: the request could not be handed to the outgoing
    /// request channel (the payload is the send error rendered as text)
    #[error("transport error: {0}")]
    Transport(String),

    /// Response channel was closed (internal error); this is also the value
    /// delivered to the waiter when a pending request is cancelled
    #[error("response channel closed")]
    ChannelClosed,

    /// Failed to serialize request parameters, or to deserialize the
    /// client's result into the expected type
    #[error("serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    /// Client doesn't support the requested capability (payload is the capability name)
    #[error("client does not support {0}")]
    UnsupportedCapability(String),
}
72
/// A pending server-initiated request awaiting response
///
/// One of these is stored in the multiplexer's map for every in-flight
/// request and is removed when the request is answered, cancelled, or
/// expired by a cleanup pass.
pub struct PendingRequest {
    /// Request ID (UUID string); the same value as the map key it is stored under
    pub id: String,

    /// Method name for logging/debugging
    pub method: String,

    /// Channel to send the response; sending consumes the sender, so a
    /// request can be resolved at most once
    pub response_tx: oneshot::Sender<Result<Value, MultiplexerError>>,

    /// When this request was created (for timeout tracking)
    pub created_at: std::time::Instant,
}
87
/// Request multiplexer for handling bidirectional communication
///
/// Tracks pending server→client requests and routes responses back to the
/// caller waiting on each request's oneshot receiver.
pub struct RequestMultiplexer {
    /// Pending requests awaiting responses: request_id → PendingRequest
    pending: DashMap<String, PendingRequest>,

    /// Default timeout for requests. The multiplexer only stores and exposes
    /// this value; callers apply it when waiting for a response.
    default_timeout: Duration,
}
98
/// `Default` is equivalent to [`RequestMultiplexer::new`].
impl Default for RequestMultiplexer {
    /// Build a multiplexer by delegating to [`RequestMultiplexer::new`]
    fn default() -> Self {
        Self::new()
    }
}
104
105impl RequestMultiplexer {
106    /// Create a new request multiplexer
107    pub fn new() -> Self {
108        Self {
109            pending: DashMap::new(),
110            default_timeout: Duration::from_secs(30),
111        }
112    }
113
114    /// Create with custom default timeout
115    pub fn with_timeout(timeout: Duration) -> Self {
116        Self {
117            pending: DashMap::new(),
118            default_timeout: timeout,
119        }
120    }
121
122    /// Get default timeout
123    pub fn default_timeout(&self) -> Duration {
124        self.default_timeout
125    }
126
127    /// Create a new pending request and return the receiver
128    ///
129    /// Returns (request_id, receiver) - the caller should await the receiver
130    pub fn create_pending(
131        &self,
132        method: impl Into<String>,
133    ) -> (String, oneshot::Receiver<Result<Value, MultiplexerError>>) {
134        let id = uuid::Uuid::new_v4().to_string();
135        let method = method.into();
136        let (tx, rx) = oneshot::channel();
137
138        let pending = PendingRequest {
139            id: id.clone(),
140            method,
141            response_tx: tx,
142            created_at: std::time::Instant::now(),
143        };
144
145        self.pending.insert(id.clone(), pending);
146
147        (id, rx)
148    }
149
150    /// Route an incoming response to its pending request
151    ///
152    /// Returns true if the response was routed, false if no matching request found
153    pub fn route_response(&self, response: &JsonRpcResponse) -> bool {
154        // Extract request ID from response
155        let id = match &response.id {
156            Value::String(s) => s.clone(),
157            Value::Number(n) => n.to_string(),
158            _ => return false,
159        };
160
161        // Find and remove the pending request
162        if let Some((_, pending)) = self.pending.remove(&id) {
163            // Build result or error
164            let result = if let Some(ref error) = response.error {
165                Err(MultiplexerError::ClientError {
166                    code: error.code,
167                    message: error.message.clone(),
168                })
169            } else if let Some(ref result) = response.result {
170                Ok(result.clone())
171            } else {
172                // Empty response (no result, no error) - treat as empty object
173                Ok(Value::Object(serde_json::Map::new()))
174            };
175
176            // Send to waiting caller (ignore if receiver dropped)
177            let _ = pending.response_tx.send(result);
178
179            true
180        } else {
181            false
182        }
183    }
184
185    /// Check if a response ID matches a pending request
186    ///
187    /// Used to distinguish client responses from client requests in the message loop
188    pub fn is_pending_response(&self, id: &Value) -> bool {
189        let id_str = match id {
190            Value::String(s) => s.clone(),
191            Value::Number(n) => n.to_string(),
192            _ => return false,
193        };
194
195        self.pending.contains_key(&id_str)
196    }
197
198    /// Get number of pending requests
199    pub fn pending_count(&self) -> usize {
200        self.pending.len()
201    }
202
203    /// Cancel a pending request
204    pub fn cancel(&self, id: &str) {
205        if let Some((_, pending)) = self.pending.remove(id) {
206            let _ = pending
207                .response_tx
208                .send(Err(MultiplexerError::ChannelClosed));
209        }
210    }
211
212    /// Cancel all pending requests
213    pub fn cancel_all(&self) {
214        let ids: Vec<String> = self.pending.iter().map(|e| e.key().clone()).collect();
215        for id in ids {
216            self.cancel(&id);
217        }
218    }
219
220    /// Clean up timed-out requests
221    ///
222    /// Returns the number of requests that were cleaned up
223    pub fn cleanup_timed_out(&self, timeout: Duration) -> usize {
224        let now = std::time::Instant::now();
225        let mut cleaned = 0;
226
227        // Collect IDs to remove (can't modify during iteration with DashMap)
228        let timed_out: Vec<String> = self
229            .pending
230            .iter()
231            .filter(|e| now.duration_since(e.created_at) > timeout)
232            .map(|e| e.key().clone())
233            .collect();
234
235        for id in timed_out {
236            if let Some((_, pending)) = self.pending.remove(&id) {
237                let _ = pending
238                    .response_tx
239                    .send(Err(MultiplexerError::Timeout(timeout)));
240                cleaned += 1;
241            }
242        }
243
244        cleaned
245    }
246}
247
/// Server→client request (what we send to the client)
///
/// Serializes to a JSON-RPC request object with `jsonrpc`, `id`, `method`
/// and, when present, `params`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct JsonRpcClientRequest {
    /// JSON-RPC version
    pub jsonrpc: String,

    /// Request ID (UUID string for server-initiated requests)
    pub id: String,

    /// Method name
    pub method: String,

    /// Request parameters; the `params` key is omitted entirely when `None`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
}
264
265impl JsonRpcClientRequest {
266    /// Create a new client request
267    pub fn new(id: impl Into<String>, method: impl Into<String>, params: Option<Value>) -> Self {
268        Self {
269            jsonrpc: "2.0".to_string(),
270            id: id.into(),
271            method: method.into(),
272            params,
273        }
274    }
275}
276
/// Root entry from roots/list response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Root {
    /// URI of the root (e.g., "file:///workspace")
    pub uri: String,

    /// Human-readable name for the root; optional, and omitted from the
    /// serialized form when `None`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
287
/// Result of roots/list request
///
/// This is the shape the client's `result` value is deserialized into.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListRootsResult {
    /// List of workspace roots
    pub roots: Vec<Root>,
}
294
/// Sampling message for createMessage request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingMessage {
    /// Role: "user" or "assistant" (stored as a free-form string; not validated here)
    pub role: String,

    /// Message content
    pub content: SamplingContent,
}
304
305/// Content in a sampling message
306#[derive(Debug, Clone, Serialize, Deserialize)]
307#[serde(tag = "type")]
308pub enum SamplingContent {
309    /// Text content
310    #[serde(rename = "text")]
311    Text { text: String },
312
313    /// Image content
314    #[serde(rename = "image")]
315    Image { data: String, mime_type: String },
316}
317
/// Model preferences for sampling
///
/// All fields are optional and are omitted from the serialized form when
/// `None`. Field names are camelCase on the wire (e.g. `costPriority`).
///
/// NOTE(review): in the MCP specification each priority is a weight where
/// 0 means the dimension is not important and 1 means it is the most
/// important factor — confirm against the spec revision this crate targets.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ModelPreferences {
    /// Preferred model hints
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hints: Option<Vec<ModelHint>>,

    /// Cost priority (0.0 = cost is not important, 1.0 = cost is the most important factor)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub cost_priority: Option<f64>,

    /// Speed priority (0.0 = speed is not important, 1.0 = speed is the most important factor)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speed_priority: Option<f64>,

    /// Intelligence priority (0.0 = capability is not important, 1.0 = capability is the most important factor)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub intelligence_priority: Option<f64>,
}
338
/// Model hint for sampling
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelHint {
    /// Model name pattern; omitted from the serialized form when `None`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
346
/// Parameters for sampling/createMessage request
///
/// Field names are camelCase on the wire (e.g. `maxTokens`,
/// `systemPrompt`). Optional fields are omitted when `None`; `messages`
/// and `max_tokens` are always serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateMessageParams {
    /// Messages to send
    pub messages: Vec<SamplingMessage>,

    /// Model preferences
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_preferences: Option<ModelPreferences>,

    /// System prompt
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system_prompt: Option<String>,

    /// Include context from MCP servers
    // NOTE(review): presumably one of "none" / "thisServer" / "allServers"
    // as in the MCP schema; stored as a free-form string here — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_context: Option<String>,

    /// Temperature
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,

    /// Maximum tokens to generate (required)
    pub max_tokens: i32,

    /// Stop sequences
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_sequences: Option<Vec<String>>,
}
377
/// Result of sampling/createMessage request
///
/// This is the shape the client's `result` value is deserialized into.
/// Field names are camelCase on the wire (e.g. `stopReason`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateMessageResult {
    /// Role of the response
    pub role: String,

    /// Response content
    pub content: SamplingContent,

    /// Model that was used
    pub model: String,

    /// Reason for stopping; optional, and omitted from the serialized form when `None`
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<String>,
}
395
/// Client requester for making server→client requests from tools
///
/// This is passed to tools via ExecutionContext, allowing them to
/// make requests to the client (roots/list, sampling/createMessage, etc.)
///
/// The type is `Clone`; clones share the same multiplexer through the `Arc`.
///
/// # Example
///
/// ```rust,ignore
/// impl Tool for ListRootsTool {
///     async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<Vec<Box<dyn Content>>, ToolError> {
///         if let Some(requester) = ctx.client_requester() {
///             match requester.request_roots(None).await {
///                 Ok(roots) => {
///                     let msg = format!("Found {} roots", roots.len());
///                     Ok(vec![Box::new(TextContent::new(msg))])
///                 }
///                 Err(e) => Err(ToolError::Execution(e.to_string())),
///             }
///         } else {
///             Err(ToolError::Execution("No client requester available".into()))
///         }
///     }
/// }
/// ```
#[derive(Clone)]
pub struct ClientRequester {
    /// Channel to send requests to the transport
    request_tx: mpsc::UnboundedSender<JsonRpcClientRequest>,

    /// Multiplexer for tracking pending requests
    multiplexer: Arc<RequestMultiplexer>,

    /// Whether the client supports roots capability (fixed at construction)
    supports_roots: bool,

    /// Whether the client supports sampling capability (fixed at construction)
    supports_sampling: bool,
}
434
435impl ClientRequester {
436    /// Create a new client requester
437    pub fn new(
438        request_tx: mpsc::UnboundedSender<JsonRpcClientRequest>,
439        multiplexer: Arc<RequestMultiplexer>,
440        supports_roots: bool,
441        supports_sampling: bool,
442    ) -> Self {
443        Self {
444            request_tx,
445            multiplexer,
446            supports_roots,
447            supports_sampling,
448        }
449    }
450
451    /// Check if client supports roots capability
452    pub fn supports_roots(&self) -> bool {
453        self.supports_roots
454    }
455
456    /// Check if client supports sampling capability
457    pub fn supports_sampling(&self) -> bool {
458        self.supports_sampling
459    }
460
461    /// Request workspace roots from the client
462    ///
463    /// Returns an error if the client doesn't support roots capability.
464    pub async fn request_roots(
465        &self,
466        timeout: Option<Duration>,
467    ) -> Result<Vec<Root>, MultiplexerError> {
468        if !self.supports_roots {
469            return Err(MultiplexerError::UnsupportedCapability("roots".to_string()));
470        }
471
472        // Create pending request
473        let (id, rx) = self.multiplexer.create_pending("roots/list");
474
475        // Build and send the request
476        let request = JsonRpcClientRequest::new(&id, "roots/list", Some(serde_json::json!({})));
477
478        self.request_tx
479            .send(request)
480            .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
481
482        // Wait for response with timeout
483        let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
484        let result = tokio::time::timeout(timeout, rx)
485            .await
486            .map_err(|_| MultiplexerError::Timeout(timeout))?
487            .map_err(|_| MultiplexerError::ChannelClosed)??;
488
489        // Parse the result
490        let list_result: ListRootsResult = serde_json::from_value(result)?;
491        Ok(list_result.roots)
492    }
493
494    /// Request an LLM completion from the client
495    ///
496    /// Returns an error if the client doesn't support sampling capability.
497    pub async fn request_sampling(
498        &self,
499        params: CreateMessageParams,
500        timeout: Option<Duration>,
501    ) -> Result<CreateMessageResult, MultiplexerError> {
502        if !self.supports_sampling {
503            return Err(MultiplexerError::UnsupportedCapability(
504                "sampling".to_string(),
505            ));
506        }
507
508        // Create pending request
509        let (id, rx) = self.multiplexer.create_pending("sampling/createMessage");
510
511        // Build and send the request
512        let params_value = serde_json::to_value(&params)?;
513        let request = JsonRpcClientRequest::new(&id, "sampling/createMessage", Some(params_value));
514
515        self.request_tx
516            .send(request)
517            .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
518
519        // Wait for response with timeout
520        let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
521        let result = tokio::time::timeout(timeout, rx)
522            .await
523            .map_err(|_| MultiplexerError::Timeout(timeout))?
524            .map_err(|_| MultiplexerError::ChannelClosed)??;
525
526        // Parse the result
527        let create_result: CreateMessageResult = serde_json::from_value(result)?;
528        Ok(create_result)
529    }
530
531    /// Request user input elicitation from the client
532    ///
533    /// Sends an `elicitation/create` request to the client and waits for the response.
534    /// The client must have the `elicitation` capability advertised.
535    ///
536    /// # Arguments
537    ///
538    /// * `message` - The message to show the user
539    /// * `schema` - The schema defining the structure of requested input
540    /// * `timeout` - Optional timeout for the request
541    ///
542    /// # Returns
543    ///
544    /// The user's response (accept/decline/cancel with optional data), or an error if:
545    /// - Client doesn't support elicitation capability
546    /// - Request times out
547    /// - Transport error occurs
548    pub async fn request_elicitation(
549        &self,
550        message: String,
551        requested_schema: Value,
552        timeout: Option<Duration>,
553    ) -> Result<crate::protocol::types::CreateElicitationResult, MultiplexerError> {
554        // Note: We'd need to track elicitation capability during initialization
555        // For now, just attempt the request
556
557        // Create pending request
558        let (id, rx) = self.multiplexer.create_pending("elicitation/create");
559
560        // Build request params
561        let params = serde_json::json!({
562            "message": message,
563            "requestedSchema": requested_schema,
564        });
565
566        let request = JsonRpcClientRequest::new(&id, "elicitation/create", Some(params));
567
568        self.request_tx
569            .send(request)
570            .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
571
572        // Wait for response with timeout
573        let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
574        let result = tokio::time::timeout(timeout, rx)
575            .await
576            .map_err(|_| MultiplexerError::Timeout(timeout))?
577            .map_err(|_| MultiplexerError::ChannelClosed)??;
578
579        // Parse the result
580        let elicitation_result: crate::protocol::types::CreateElicitationResult =
581            serde_json::from_value(result)?;
582        Ok(elicitation_result)
583    }
584}
585
#[cfg(test)]
mod tests {
    use super::*;

    /// Two registrations yield distinct IDs and are both tracked.
    #[test]
    fn test_multiplexer_create_pending() {
        let multiplexer = RequestMultiplexer::new();

        let (first_id, _first_rx) = multiplexer.create_pending("roots/list");
        let (second_id, _second_rx) = multiplexer.create_pending("sampling/createMessage");

        assert_ne!(first_id, second_id);
        assert_eq!(multiplexer.pending_count(), 2);
    }

    /// A successful response is delivered to the waiter and clears the entry.
    #[tokio::test]
    async fn test_multiplexer_route_response() {
        let multiplexer = RequestMultiplexer::new();
        let (request_id, receiver) = multiplexer.create_pending("test/method");

        let payload = serde_json::json!({"status": "ok"});
        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: Value::String(request_id),
            result: Some(payload),
            error: None,
        };

        let routed = multiplexer.route_response(&response);
        assert!(routed);
        assert_eq!(multiplexer.pending_count(), 0);

        // The waiter observes the routed result.
        let delivered = receiver.await.unwrap().unwrap();
        assert_eq!(delivered["status"], "ok");
    }

    /// An error response is surfaced to the waiter as `ClientError`.
    #[tokio::test]
    async fn test_multiplexer_route_error() {
        let multiplexer = RequestMultiplexer::new();
        let (request_id, receiver) = multiplexer.create_pending("test/method");

        let rpc_error = crate::protocol::types::JsonRpcError {
            code: -32600,
            message: "Invalid request".to_string(),
            data: None,
        };
        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: Value::String(request_id),
            result: None,
            error: Some(rpc_error),
        };

        assert!(multiplexer.route_response(&response));

        // The waiter observes the client's error code.
        let delivered = receiver.await.unwrap();
        assert!(matches!(
            delivered,
            Err(MultiplexerError::ClientError { code: -32600, .. })
        ));
    }

    /// Only IDs of registered requests are reported as pending.
    #[test]
    fn test_multiplexer_is_pending() {
        let multiplexer = RequestMultiplexer::new();
        let (request_id, _receiver) = multiplexer.create_pending("test");

        let known = Value::String(request_id);
        let unknown = Value::String("unknown".to_string());

        assert!(multiplexer.is_pending_response(&known));
        assert!(!multiplexer.is_pending_response(&unknown));
    }

    /// Cancelling removes the entry from the pending map.
    #[test]
    fn test_multiplexer_cancel() {
        let multiplexer = RequestMultiplexer::new();

        let (request_id, _receiver) = multiplexer.create_pending("test");
        assert_eq!(multiplexer.pending_count(), 1);

        multiplexer.cancel(&request_id);
        assert_eq!(multiplexer.pending_count(), 0);
    }

    /// The serialized request carries the version, ID, and method.
    #[test]
    fn test_client_request_serialization() {
        let request =
            JsonRpcClientRequest::new("abc-123", "roots/list", Some(serde_json::json!({})));

        let encoded = serde_json::to_string(&request).unwrap();
        assert!(encoded.contains("\"jsonrpc\":\"2.0\""));
        assert!(encoded.contains("\"id\":\"abc-123\""));
        assert!(encoded.contains("\"method\":\"roots/list\""));
    }

    /// A root with both fields present deserializes into `Root`.
    #[test]
    fn test_root_deserialization() {
        let raw = r#"{"uri": "file:///workspace", "name": "My Project"}"#;
        let parsed: Root = serde_json::from_str(raw).unwrap();

        assert_eq!(parsed.uri, "file:///workspace");
        assert_eq!(parsed.name, Some("My Project".to_string()));
    }

    /// Text content is tagged with `type: "text"` on the wire.
    #[test]
    fn test_sampling_content() {
        let text_content = SamplingContent::Text {
            text: "Hello, world!".to_string(),
        };

        let encoded = serde_json::to_string(&text_content).unwrap();
        assert!(encoded.contains("\"type\":\"text\""));
        assert!(encoded.contains("\"text\":\"Hello, world!\""));
    }
}