1use std::sync::Arc;
33use std::time::Duration;
34
35use tokio::sync::mpsc;
36
37use dashmap::DashMap;
38use serde::{Deserialize, Serialize};
39use serde_json::Value;
40use thiserror::Error;
41use tokio::sync::oneshot;
42
43use crate::protocol::types::JsonRpcResponse;
44
45#[derive(Debug, Error)]
47pub enum MultiplexerError {
48 #[error("request timed out after {0:?}")]
50 Timeout(Duration),
51
52 #[error("client error {code}: {message}")]
54 ClientError { code: i32, message: String },
55
56 #[error("transport error: {0}")]
58 Transport(String),
59
60 #[error("response channel closed")]
62 ChannelClosed,
63
64 #[error("serialization error: {0}")]
66 Serialization(#[from] serde_json::Error),
67
68 #[error("client does not support {0}")]
70 UnsupportedCapability(String),
71}
72
73pub struct PendingRequest {
75 pub id: String,
77
78 pub method: String,
80
81 pub response_tx: oneshot::Sender<Result<Value, MultiplexerError>>,
83
84 pub created_at: std::time::Instant,
86}
87
88pub struct RequestMultiplexer {
92 pending: DashMap<String, PendingRequest>,
94
95 default_timeout: Duration,
97}
98
99impl Default for RequestMultiplexer {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105impl RequestMultiplexer {
106 pub fn new() -> Self {
108 Self {
109 pending: DashMap::new(),
110 default_timeout: Duration::from_secs(30),
111 }
112 }
113
114 pub fn with_timeout(timeout: Duration) -> Self {
116 Self {
117 pending: DashMap::new(),
118 default_timeout: timeout,
119 }
120 }
121
122 pub fn default_timeout(&self) -> Duration {
124 self.default_timeout
125 }
126
127 pub fn create_pending(
131 &self,
132 method: impl Into<String>,
133 ) -> (String, oneshot::Receiver<Result<Value, MultiplexerError>>) {
134 let id = uuid::Uuid::new_v4().to_string();
135 let method = method.into();
136 let (tx, rx) = oneshot::channel();
137
138 let pending = PendingRequest {
139 id: id.clone(),
140 method,
141 response_tx: tx,
142 created_at: std::time::Instant::now(),
143 };
144
145 self.pending.insert(id.clone(), pending);
146
147 (id, rx)
148 }
149
150 pub fn route_response(&self, response: &JsonRpcResponse) -> bool {
154 let id = match &response.id {
156 Value::String(s) => s.clone(),
157 Value::Number(n) => n.to_string(),
158 _ => return false,
159 };
160
161 if let Some((_, pending)) = self.pending.remove(&id) {
163 let result = if let Some(ref error) = response.error {
165 Err(MultiplexerError::ClientError {
166 code: error.code,
167 message: error.message.clone(),
168 })
169 } else if let Some(ref result) = response.result {
170 Ok(result.clone())
171 } else {
172 Ok(Value::Object(serde_json::Map::new()))
174 };
175
176 let _ = pending.response_tx.send(result);
178
179 true
180 } else {
181 false
182 }
183 }
184
185 pub fn is_pending_response(&self, id: &Value) -> bool {
189 let id_str = match id {
190 Value::String(s) => s.clone(),
191 Value::Number(n) => n.to_string(),
192 _ => return false,
193 };
194
195 self.pending.contains_key(&id_str)
196 }
197
198 pub fn pending_count(&self) -> usize {
200 self.pending.len()
201 }
202
203 pub fn cancel(&self, id: &str) {
205 if let Some((_, pending)) = self.pending.remove(id) {
206 let _ = pending
207 .response_tx
208 .send(Err(MultiplexerError::ChannelClosed));
209 }
210 }
211
212 pub fn cancel_all(&self) {
214 let ids: Vec<String> = self.pending.iter().map(|e| e.key().clone()).collect();
215 for id in ids {
216 self.cancel(&id);
217 }
218 }
219
220 pub fn cleanup_timed_out(&self, timeout: Duration) -> usize {
224 let now = std::time::Instant::now();
225 let mut cleaned = 0;
226
227 let timed_out: Vec<String> = self
229 .pending
230 .iter()
231 .filter(|e| now.duration_since(e.created_at) > timeout)
232 .map(|e| e.key().clone())
233 .collect();
234
235 for id in timed_out {
236 if let Some((_, pending)) = self.pending.remove(&id) {
237 let _ = pending
238 .response_tx
239 .send(Err(MultiplexerError::Timeout(timeout)));
240 cleaned += 1;
241 }
242 }
243
244 cleaned
245 }
246}
247
248#[derive(Debug, Clone, Serialize, Deserialize)]
250pub struct JsonRpcClientRequest {
251 pub jsonrpc: String,
253
254 pub id: String,
256
257 pub method: String,
259
260 #[serde(skip_serializing_if = "Option::is_none")]
262 pub params: Option<Value>,
263}
264
265impl JsonRpcClientRequest {
266 pub fn new(id: impl Into<String>, method: impl Into<String>, params: Option<Value>) -> Self {
268 Self {
269 jsonrpc: "2.0".to_string(),
270 id: id.into(),
271 method: method.into(),
272 params,
273 }
274 }
275}
276
277#[derive(Debug, Clone, Serialize, Deserialize)]
279pub struct Root {
280 pub uri: String,
282
283 #[serde(skip_serializing_if = "Option::is_none")]
285 pub name: Option<String>,
286}
287
288#[derive(Debug, Clone, Serialize, Deserialize)]
290pub struct ListRootsResult {
291 pub roots: Vec<Root>,
293}
294
295#[derive(Debug, Clone, Serialize, Deserialize)]
297pub struct SamplingMessage {
298 pub role: String,
300
301 pub content: SamplingContent,
303}
304
305#[derive(Debug, Clone, Serialize, Deserialize)]
307#[serde(tag = "type")]
308pub enum SamplingContent {
309 #[serde(rename = "text")]
311 Text { text: String },
312
313 #[serde(rename = "image")]
315 Image { data: String, mime_type: String },
316}
317
318#[derive(Debug, Clone, Default, Serialize, Deserialize)]
320#[serde(rename_all = "camelCase")]
321pub struct ModelPreferences {
322 #[serde(skip_serializing_if = "Option::is_none")]
324 pub hints: Option<Vec<ModelHint>>,
325
326 #[serde(skip_serializing_if = "Option::is_none")]
328 pub cost_priority: Option<f64>,
329
330 #[serde(skip_serializing_if = "Option::is_none")]
332 pub speed_priority: Option<f64>,
333
334 #[serde(skip_serializing_if = "Option::is_none")]
336 pub intelligence_priority: Option<f64>,
337}
338
339#[derive(Debug, Clone, Serialize, Deserialize)]
341pub struct ModelHint {
342 #[serde(skip_serializing_if = "Option::is_none")]
344 pub name: Option<String>,
345}
346
347#[derive(Debug, Clone, Serialize, Deserialize)]
349#[serde(rename_all = "camelCase")]
350pub struct CreateMessageParams {
351 pub messages: Vec<SamplingMessage>,
353
354 #[serde(skip_serializing_if = "Option::is_none")]
356 pub model_preferences: Option<ModelPreferences>,
357
358 #[serde(skip_serializing_if = "Option::is_none")]
360 pub system_prompt: Option<String>,
361
362 #[serde(skip_serializing_if = "Option::is_none")]
364 pub include_context: Option<String>,
365
366 #[serde(skip_serializing_if = "Option::is_none")]
368 pub temperature: Option<f64>,
369
370 pub max_tokens: i32,
372
373 #[serde(skip_serializing_if = "Option::is_none")]
375 pub stop_sequences: Option<Vec<String>>,
376}
377
378#[derive(Debug, Clone, Serialize, Deserialize)]
380#[serde(rename_all = "camelCase")]
381pub struct CreateMessageResult {
382 pub role: String,
384
385 pub content: SamplingContent,
387
388 pub model: String,
390
391 #[serde(skip_serializing_if = "Option::is_none")]
393 pub stop_reason: Option<String>,
394}
395
396#[derive(Clone)]
421pub struct ClientRequester {
422 request_tx: mpsc::UnboundedSender<JsonRpcClientRequest>,
424
425 multiplexer: Arc<RequestMultiplexer>,
427
428 supports_roots: bool,
430
431 supports_sampling: bool,
433}
434
435impl ClientRequester {
436 pub fn new(
438 request_tx: mpsc::UnboundedSender<JsonRpcClientRequest>,
439 multiplexer: Arc<RequestMultiplexer>,
440 supports_roots: bool,
441 supports_sampling: bool,
442 ) -> Self {
443 Self {
444 request_tx,
445 multiplexer,
446 supports_roots,
447 supports_sampling,
448 }
449 }
450
451 pub fn supports_roots(&self) -> bool {
453 self.supports_roots
454 }
455
456 pub fn supports_sampling(&self) -> bool {
458 self.supports_sampling
459 }
460
461 pub async fn request_roots(
465 &self,
466 timeout: Option<Duration>,
467 ) -> Result<Vec<Root>, MultiplexerError> {
468 if !self.supports_roots {
469 return Err(MultiplexerError::UnsupportedCapability("roots".to_string()));
470 }
471
472 let (id, rx) = self.multiplexer.create_pending("roots/list");
474
475 let request = JsonRpcClientRequest::new(&id, "roots/list", Some(serde_json::json!({})));
477
478 self.request_tx
479 .send(request)
480 .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
481
482 let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
484 let result = tokio::time::timeout(timeout, rx)
485 .await
486 .map_err(|_| MultiplexerError::Timeout(timeout))?
487 .map_err(|_| MultiplexerError::ChannelClosed)??;
488
489 let list_result: ListRootsResult = serde_json::from_value(result)?;
491 Ok(list_result.roots)
492 }
493
494 pub async fn request_sampling(
498 &self,
499 params: CreateMessageParams,
500 timeout: Option<Duration>,
501 ) -> Result<CreateMessageResult, MultiplexerError> {
502 if !self.supports_sampling {
503 return Err(MultiplexerError::UnsupportedCapability(
504 "sampling".to_string(),
505 ));
506 }
507
508 let (id, rx) = self.multiplexer.create_pending("sampling/createMessage");
510
511 let params_value = serde_json::to_value(¶ms)?;
513 let request = JsonRpcClientRequest::new(&id, "sampling/createMessage", Some(params_value));
514
515 self.request_tx
516 .send(request)
517 .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
518
519 let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
521 let result = tokio::time::timeout(timeout, rx)
522 .await
523 .map_err(|_| MultiplexerError::Timeout(timeout))?
524 .map_err(|_| MultiplexerError::ChannelClosed)??;
525
526 let create_result: CreateMessageResult = serde_json::from_value(result)?;
528 Ok(create_result)
529 }
530
531 pub async fn request_elicitation(
549 &self,
550 message: String,
551 requested_schema: Value,
552 timeout: Option<Duration>,
553 ) -> Result<crate::protocol::types::CreateElicitationResult, MultiplexerError> {
554 let (id, rx) = self.multiplexer.create_pending("elicitation/create");
559
560 let params = serde_json::json!({
562 "message": message,
563 "requestedSchema": requested_schema,
564 });
565
566 let request = JsonRpcClientRequest::new(&id, "elicitation/create", Some(params));
567
568 self.request_tx
569 .send(request)
570 .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
571
572 let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
574 let result = tokio::time::timeout(timeout, rx)
575 .await
576 .map_err(|_| MultiplexerError::Timeout(timeout))?
577 .map_err(|_| MultiplexerError::ChannelClosed)??;
578
579 let elicitation_result: crate::protocol::types::CreateElicitationResult =
581 serde_json::from_value(result)?;
582 Ok(elicitation_result)
583 }
584}
585
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a response with the given id and optional result, and no error.
    fn ok_response(id: Value, result: Option<Value>) -> JsonRpcResponse {
        JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id,
            result,
            error: None,
        }
    }

    #[test]
    fn test_multiplexer_create_pending() {
        let mux = RequestMultiplexer::new();

        let (id1, _rx1) = mux.create_pending("roots/list");
        let (id2, _rx2) = mux.create_pending("sampling/createMessage");

        assert_ne!(id1, id2);
        assert_eq!(mux.pending_count(), 2);
    }

    #[tokio::test]
    async fn test_multiplexer_route_response() {
        let mux = RequestMultiplexer::new();

        let (id, rx) = mux.create_pending("test/method");

        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: Value::String(id.clone()),
            result: Some(serde_json::json!({"status": "ok"})),
            error: None,
        };

        assert!(mux.route_response(&response));
        assert_eq!(mux.pending_count(), 0);

        let result = rx.await.unwrap().unwrap();
        assert_eq!(result["status"], "ok");
    }

    #[tokio::test]
    async fn test_multiplexer_route_error() {
        let mux = RequestMultiplexer::new();

        let (id, rx) = mux.create_pending("test/method");

        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: Value::String(id.clone()),
            result: None,
            error: Some(crate::protocol::types::JsonRpcError {
                code: -32600,
                message: "Invalid request".to_string(),
                data: None,
            }),
        };

        assert!(mux.route_response(&response));

        let result = rx.await.unwrap();
        assert!(matches!(
            result,
            Err(MultiplexerError::ClientError { code: -32600, .. })
        ));
    }

    #[tokio::test]
    async fn test_multiplexer_route_empty_result() {
        let mux = RequestMultiplexer::new();
        let (id, rx) = mux.create_pending("test/method");

        // Neither `result` nor `error`: resolves to an empty object.
        assert!(mux.route_response(&ok_response(Value::String(id), None)));
        assert_eq!(rx.await.unwrap().unwrap(), serde_json::json!({}));
    }

    #[test]
    fn test_multiplexer_route_unknown_or_invalid_id() {
        let mux = RequestMultiplexer::new();
        let (_id, _rx) = mux.create_pending("test");

        let ids = vec![
            Value::String("unknown".to_string()),
            Value::Null,
            serde_json::json!(42),
            serde_json::json!({"nested": true}),
        ];
        for id in ids {
            assert!(!mux.route_response(&ok_response(id, None)));
        }

        // Unmatched responses must not disturb outstanding requests.
        assert_eq!(mux.pending_count(), 1);
    }

    #[test]
    fn test_multiplexer_is_pending() {
        let mux = RequestMultiplexer::new();

        let (id, _rx) = mux.create_pending("test");

        assert!(mux.is_pending_response(&Value::String(id.clone())));
        assert!(!mux.is_pending_response(&Value::String("unknown".to_string())));
        assert!(!mux.is_pending_response(&Value::Null));
    }

    #[test]
    fn test_multiplexer_cancel() {
        let mux = RequestMultiplexer::new();

        let (id, mut rx) = mux.create_pending("test");
        assert_eq!(mux.pending_count(), 1);

        mux.cancel(&id);
        assert_eq!(mux.pending_count(), 0);

        // The waiter is notified rather than left hanging.
        assert!(matches!(
            rx.try_recv(),
            Ok(Err(MultiplexerError::ChannelClosed))
        ));

        // Cancelling an id that is no longer pending is a no-op.
        mux.cancel(&id);
        assert_eq!(mux.pending_count(), 0);
    }

    #[test]
    fn test_multiplexer_cancel_all() {
        let mux = RequestMultiplexer::new();

        let (_a, mut rx_a) = mux.create_pending("a");
        let (_b, mut rx_b) = mux.create_pending("b");

        mux.cancel_all();

        assert_eq!(mux.pending_count(), 0);
        assert!(matches!(
            rx_a.try_recv(),
            Ok(Err(MultiplexerError::ChannelClosed))
        ));
        assert!(matches!(
            rx_b.try_recv(),
            Ok(Err(MultiplexerError::ChannelClosed))
        ));
    }

    #[test]
    fn test_multiplexer_cleanup_timed_out() {
        let mux = RequestMultiplexer::new();
        let (_id, mut rx) = mux.create_pending("test");

        // Nothing is older than a generous timeout yet.
        assert_eq!(mux.cleanup_timed_out(Duration::from_secs(60)), 0);
        assert_eq!(mux.pending_count(), 1);

        std::thread::sleep(Duration::from_millis(5));

        assert_eq!(mux.cleanup_timed_out(Duration::from_millis(1)), 1);
        assert_eq!(mux.pending_count(), 0);
        assert!(matches!(
            rx.try_recv(),
            Ok(Err(MultiplexerError::Timeout(_)))
        ));
    }

    #[tokio::test]
    async fn test_requester_unsupported_capabilities() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let mux = Arc::new(RequestMultiplexer::new());
        let requester = ClientRequester::new(tx, Arc::clone(&mux), false, false);

        assert!(!requester.supports_roots());
        assert!(!requester.supports_sampling());

        assert!(matches!(
            requester.request_roots(None).await,
            Err(MultiplexerError::UnsupportedCapability(ref cap)) if cap == "roots"
        ));

        let params = CreateMessageParams {
            messages: Vec::new(),
            model_preferences: None,
            system_prompt: None,
            include_context: None,
            temperature: None,
            max_tokens: 16,
            stop_sequences: None,
        };
        assert!(matches!(
            requester.request_sampling(params, None).await,
            Err(MultiplexerError::UnsupportedCapability(ref cap)) if cap == "sampling"
        ));

        // The capability gate runs before anything is registered.
        assert_eq!(mux.pending_count(), 0);
    }

    #[tokio::test]
    async fn test_requester_transport_closed() {
        let (tx, rx) = mpsc::unbounded_channel();
        drop(rx);
        let requester =
            ClientRequester::new(tx, Arc::new(RequestMultiplexer::new()), true, false);

        assert!(matches!(
            requester.request_roots(None).await,
            Err(MultiplexerError::Transport(_))
        ));
    }

    #[tokio::test]
    async fn test_requester_roots_round_trip() {
        let (tx, mut rx) = mpsc::unbounded_channel::<JsonRpcClientRequest>();
        let mux = Arc::new(RequestMultiplexer::new());
        let requester = ClientRequester::new(tx, Arc::clone(&mux), true, false);

        // Play the client: answer the first outgoing request through the multiplexer.
        let responder = tokio::spawn(async move {
            let request = rx.recv().await.expect("request should be sent");
            assert_eq!(request.method, "roots/list");
            let response = ok_response(
                Value::String(request.id),
                Some(serde_json::json!({
                    "roots": [{"uri": "file:///workspace", "name": "ws"}]
                })),
            );
            assert!(mux.route_response(&response));
        });

        let roots = requester
            .request_roots(Some(Duration::from_secs(5)))
            .await
            .expect("roots request should succeed");
        responder.await.expect("responder task panicked");

        assert_eq!(roots.len(), 1);
        assert_eq!(roots[0].uri, "file:///workspace");
        assert_eq!(roots[0].name.as_deref(), Some("ws"));
    }

    #[test]
    fn test_client_request_serialization() {
        let req = JsonRpcClientRequest::new("abc-123", "roots/list", Some(serde_json::json!({})));

        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"jsonrpc\":\"2.0\""));
        assert!(json.contains("\"id\":\"abc-123\""));
        assert!(json.contains("\"method\":\"roots/list\""));
    }

    #[test]
    fn test_client_request_omits_missing_params() {
        let req = JsonRpcClientRequest::new("abc-123", "roots/list", None);

        let json = serde_json::to_string(&req).unwrap();
        assert!(!json.contains("params"));
    }

    #[test]
    fn test_root_deserialization() {
        let json = r#"{"uri": "file:///workspace", "name": "My Project"}"#;
        let root: Root = serde_json::from_str(json).unwrap();

        assert_eq!(root.uri, "file:///workspace");
        assert_eq!(root.name, Some("My Project".to_string()));
    }

    #[test]
    fn test_sampling_content() {
        let content = SamplingContent::Text {
            text: "Hello, world!".to_string(),
        };

        let json = serde_json::to_string(&content).unwrap();
        assert!(json.contains("\"type\":\"text\""));
        assert!(json.contains("\"text\":\"Hello, world!\""));
    }
}