1use std::sync::Arc;
33use std::time::Duration;
34
35use tokio::sync::mpsc;
36
37use dashmap::DashMap;
38use serde::{Deserialize, Serialize};
39use serde_json::Value;
40use thiserror::Error;
41use tokio::sync::oneshot;
42
43use crate::protocol::types::JsonRpcResponse;
44
45#[derive(Debug, Error)]
47pub enum MultiplexerError {
48 #[error("request timed out after {0:?}")]
50 Timeout(Duration),
51
52 #[error("client error {code}: {message}")]
54 ClientError { code: i32, message: String },
55
56 #[error("transport error: {0}")]
58 Transport(String),
59
60 #[error("response channel closed")]
62 ChannelClosed,
63
64 #[error("serialization error: {0}")]
66 Serialization(#[from] serde_json::Error),
67
68 #[error("client does not support {0}")]
70 UnsupportedCapability(String),
71}
72
73pub struct PendingRequest {
75 pub id: String,
77
78 pub method: String,
80
81 pub response_tx: oneshot::Sender<Result<Value, MultiplexerError>>,
83
84 pub created_at: std::time::Instant,
86}
87
88pub struct RequestMultiplexer {
92 pending: DashMap<String, PendingRequest>,
94
95 default_timeout: Duration,
97}
98
99impl Default for RequestMultiplexer {
100 fn default() -> Self {
101 Self::new()
102 }
103}
104
105impl RequestMultiplexer {
106 pub fn new() -> Self {
108 Self {
109 pending: DashMap::new(),
110 default_timeout: Duration::from_secs(30),
111 }
112 }
113
114 pub fn with_timeout(timeout: Duration) -> Self {
116 Self {
117 pending: DashMap::new(),
118 default_timeout: timeout,
119 }
120 }
121
122 pub fn default_timeout(&self) -> Duration {
124 self.default_timeout
125 }
126
127 pub fn create_pending(&self, method: impl Into<String>) -> (String, oneshot::Receiver<Result<Value, MultiplexerError>>) {
131 let id = uuid::Uuid::new_v4().to_string();
132 let method = method.into();
133 let (tx, rx) = oneshot::channel();
134
135 let pending = PendingRequest {
136 id: id.clone(),
137 method,
138 response_tx: tx,
139 created_at: std::time::Instant::now(),
140 };
141
142 self.pending.insert(id.clone(), pending);
143
144 (id, rx)
145 }
146
147 pub fn route_response(&self, response: &JsonRpcResponse) -> bool {
151 let id = match &response.id {
153 Value::String(s) => s.clone(),
154 Value::Number(n) => n.to_string(),
155 _ => return false,
156 };
157
158 if let Some((_, pending)) = self.pending.remove(&id) {
160 let result = if let Some(ref error) = response.error {
162 Err(MultiplexerError::ClientError {
163 code: error.code,
164 message: error.message.clone(),
165 })
166 } else if let Some(ref result) = response.result {
167 Ok(result.clone())
168 } else {
169 Ok(Value::Object(serde_json::Map::new()))
171 };
172
173 let _ = pending.response_tx.send(result);
175
176 true
177 } else {
178 false
179 }
180 }
181
182 pub fn is_pending_response(&self, id: &Value) -> bool {
186 let id_str = match id {
187 Value::String(s) => s.clone(),
188 Value::Number(n) => n.to_string(),
189 _ => return false,
190 };
191
192 self.pending.contains_key(&id_str)
193 }
194
195 pub fn pending_count(&self) -> usize {
197 self.pending.len()
198 }
199
200 pub fn cancel(&self, id: &str) {
202 if let Some((_, pending)) = self.pending.remove(id) {
203 let _ = pending.response_tx.send(Err(MultiplexerError::ChannelClosed));
204 }
205 }
206
207 pub fn cancel_all(&self) {
209 let ids: Vec<String> = self.pending.iter().map(|e| e.key().clone()).collect();
210 for id in ids {
211 self.cancel(&id);
212 }
213 }
214
215 pub fn cleanup_timed_out(&self, timeout: Duration) -> usize {
219 let now = std::time::Instant::now();
220 let mut cleaned = 0;
221
222 let timed_out: Vec<String> = self
224 .pending
225 .iter()
226 .filter(|e| now.duration_since(e.created_at) > timeout)
227 .map(|e| e.key().clone())
228 .collect();
229
230 for id in timed_out {
231 if let Some((_, pending)) = self.pending.remove(&id) {
232 let _ = pending
233 .response_tx
234 .send(Err(MultiplexerError::Timeout(timeout)));
235 cleaned += 1;
236 }
237 }
238
239 cleaned
240 }
241}
242
243#[derive(Debug, Clone, Serialize, Deserialize)]
245pub struct JsonRpcClientRequest {
246 pub jsonrpc: String,
248
249 pub id: String,
251
252 pub method: String,
254
255 #[serde(skip_serializing_if = "Option::is_none")]
257 pub params: Option<Value>,
258}
259
260impl JsonRpcClientRequest {
261 pub fn new(id: impl Into<String>, method: impl Into<String>, params: Option<Value>) -> Self {
263 Self {
264 jsonrpc: "2.0".to_string(),
265 id: id.into(),
266 method: method.into(),
267 params,
268 }
269 }
270}
271
272#[derive(Debug, Clone, Serialize, Deserialize)]
274pub struct Root {
275 pub uri: String,
277
278 #[serde(skip_serializing_if = "Option::is_none")]
280 pub name: Option<String>,
281}
282
283#[derive(Debug, Clone, Serialize, Deserialize)]
285pub struct ListRootsResult {
286 pub roots: Vec<Root>,
288}
289
290#[derive(Debug, Clone, Serialize, Deserialize)]
292pub struct SamplingMessage {
293 pub role: String,
295
296 pub content: SamplingContent,
298}
299
300#[derive(Debug, Clone, Serialize, Deserialize)]
302#[serde(tag = "type")]
303pub enum SamplingContent {
304 #[serde(rename = "text")]
306 Text { text: String },
307
308 #[serde(rename = "image")]
310 Image { data: String, mime_type: String },
311}
312
313#[derive(Debug, Clone, Default, Serialize, Deserialize)]
315#[serde(rename_all = "camelCase")]
316pub struct ModelPreferences {
317 #[serde(skip_serializing_if = "Option::is_none")]
319 pub hints: Option<Vec<ModelHint>>,
320
321 #[serde(skip_serializing_if = "Option::is_none")]
323 pub cost_priority: Option<f64>,
324
325 #[serde(skip_serializing_if = "Option::is_none")]
327 pub speed_priority: Option<f64>,
328
329 #[serde(skip_serializing_if = "Option::is_none")]
331 pub intelligence_priority: Option<f64>,
332}
333
334#[derive(Debug, Clone, Serialize, Deserialize)]
336pub struct ModelHint {
337 #[serde(skip_serializing_if = "Option::is_none")]
339 pub name: Option<String>,
340}
341
342#[derive(Debug, Clone, Serialize, Deserialize)]
344#[serde(rename_all = "camelCase")]
345pub struct CreateMessageParams {
346 pub messages: Vec<SamplingMessage>,
348
349 #[serde(skip_serializing_if = "Option::is_none")]
351 pub model_preferences: Option<ModelPreferences>,
352
353 #[serde(skip_serializing_if = "Option::is_none")]
355 pub system_prompt: Option<String>,
356
357 #[serde(skip_serializing_if = "Option::is_none")]
359 pub include_context: Option<String>,
360
361 #[serde(skip_serializing_if = "Option::is_none")]
363 pub temperature: Option<f64>,
364
365 pub max_tokens: i32,
367
368 #[serde(skip_serializing_if = "Option::is_none")]
370 pub stop_sequences: Option<Vec<String>>,
371}
372
373#[derive(Debug, Clone, Serialize, Deserialize)]
375#[serde(rename_all = "camelCase")]
376pub struct CreateMessageResult {
377 pub role: String,
379
380 pub content: SamplingContent,
382
383 pub model: String,
385
386 #[serde(skip_serializing_if = "Option::is_none")]
388 pub stop_reason: Option<String>,
389}
390
391#[derive(Clone)]
416pub struct ClientRequester {
417 request_tx: mpsc::UnboundedSender<JsonRpcClientRequest>,
419
420 multiplexer: Arc<RequestMultiplexer>,
422
423 supports_roots: bool,
425
426 supports_sampling: bool,
428}
429
430impl ClientRequester {
431 pub fn new(
433 request_tx: mpsc::UnboundedSender<JsonRpcClientRequest>,
434 multiplexer: Arc<RequestMultiplexer>,
435 supports_roots: bool,
436 supports_sampling: bool,
437 ) -> Self {
438 Self {
439 request_tx,
440 multiplexer,
441 supports_roots,
442 supports_sampling,
443 }
444 }
445
446 pub fn supports_roots(&self) -> bool {
448 self.supports_roots
449 }
450
451 pub fn supports_sampling(&self) -> bool {
453 self.supports_sampling
454 }
455
456 pub async fn request_roots(
460 &self,
461 timeout: Option<Duration>,
462 ) -> Result<Vec<Root>, MultiplexerError> {
463 if !self.supports_roots {
464 return Err(MultiplexerError::UnsupportedCapability("roots".to_string()));
465 }
466
467 let (id, rx) = self.multiplexer.create_pending("roots/list");
469
470 let request = JsonRpcClientRequest::new(&id, "roots/list", Some(serde_json::json!({})));
472
473 self.request_tx
474 .send(request)
475 .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
476
477 let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
479 let result = tokio::time::timeout(timeout, rx)
480 .await
481 .map_err(|_| MultiplexerError::Timeout(timeout))?
482 .map_err(|_| MultiplexerError::ChannelClosed)??;
483
484 let list_result: ListRootsResult = serde_json::from_value(result)?;
486 Ok(list_result.roots)
487 }
488
489 pub async fn request_sampling(
493 &self,
494 params: CreateMessageParams,
495 timeout: Option<Duration>,
496 ) -> Result<CreateMessageResult, MultiplexerError> {
497 if !self.supports_sampling {
498 return Err(MultiplexerError::UnsupportedCapability("sampling".to_string()));
499 }
500
501 let (id, rx) = self.multiplexer.create_pending("sampling/createMessage");
503
504 let params_value = serde_json::to_value(¶ms)?;
506 let request = JsonRpcClientRequest::new(&id, "sampling/createMessage", Some(params_value));
507
508 self.request_tx
509 .send(request)
510 .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
511
512 let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
514 let result = tokio::time::timeout(timeout, rx)
515 .await
516 .map_err(|_| MultiplexerError::Timeout(timeout))?
517 .map_err(|_| MultiplexerError::ChannelClosed)??;
518
519 let create_result: CreateMessageResult = serde_json::from_value(result)?;
521 Ok(create_result)
522 }
523
524 pub async fn request_elicitation(
542 &self,
543 message: String,
544 requested_schema: Value,
545 timeout: Option<Duration>,
546 ) -> Result<crate::protocol::types::CreateElicitationResult, MultiplexerError> {
547 let (id, rx) = self.multiplexer.create_pending("elicitation/create");
552
553 let params = serde_json::json!({
555 "message": message,
556 "requestedSchema": requested_schema,
557 });
558
559 let request = JsonRpcClientRequest::new(&id, "elicitation/create", Some(params));
560
561 self.request_tx
562 .send(request)
563 .map_err(|e| MultiplexerError::Transport(e.to_string()))?;
564
565 let timeout = timeout.unwrap_or(self.multiplexer.default_timeout());
567 let result = tokio::time::timeout(timeout, rx)
568 .await
569 .map_err(|_| MultiplexerError::Timeout(timeout))?
570 .map_err(|_| MultiplexerError::ChannelClosed)??;
571
572 let elicitation_result: crate::protocol::types::CreateElicitationResult =
574 serde_json::from_value(result)?;
575 Ok(elicitation_result)
576 }
577}
578
#[cfg(test)]
mod tests {
    use super::*;

    /// Two registrations yield distinct ids and are both counted.
    #[test]
    fn test_multiplexer_create_pending() {
        let multiplexer = RequestMultiplexer::new();

        let (first_id, _first_rx) = multiplexer.create_pending("roots/list");
        let (second_id, _second_rx) = multiplexer.create_pending("sampling/createMessage");

        assert_eq!(multiplexer.pending_count(), 2);
        assert_ne!(first_id, second_id);
    }

    /// A successful response is delivered to the receiver and clears the
    /// pending entry.
    #[tokio::test]
    async fn test_multiplexer_route_response() {
        let multiplexer = RequestMultiplexer::new();
        let (request_id, receiver) = multiplexer.create_pending("test/method");

        let routed = multiplexer.route_response(&JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: Value::String(request_id.clone()),
            result: Some(serde_json::json!({"status": "ok"})),
            error: None,
        });

        assert!(routed);
        assert_eq!(multiplexer.pending_count(), 0);

        let value = receiver.await.unwrap().unwrap();
        assert_eq!(value["status"], "ok");
    }

    /// An error response is delivered as `ClientError` with the same code.
    #[tokio::test]
    async fn test_multiplexer_route_error() {
        let multiplexer = RequestMultiplexer::new();
        let (request_id, receiver) = multiplexer.create_pending("test/method");

        let failure = crate::protocol::types::JsonRpcError {
            code: -32600,
            message: "Invalid request".to_string(),
            data: None,
        };
        let response = JsonRpcResponse {
            jsonrpc: "2.0".to_string(),
            id: Value::String(request_id.clone()),
            result: None,
            error: Some(failure),
        };

        assert!(multiplexer.route_response(&response));

        match receiver.await.unwrap() {
            Err(MultiplexerError::ClientError { code, .. }) => assert_eq!(code, -32600),
            other => panic!("unexpected outcome: {:?}", other),
        }
    }

    /// Only registered ids are reported as pending.
    #[test]
    fn test_multiplexer_is_pending() {
        let multiplexer = RequestMultiplexer::new();
        let (request_id, _receiver) = multiplexer.create_pending("test");

        let known = Value::String(request_id.clone());
        let unknown = Value::String("unknown".to_string());

        assert!(multiplexer.is_pending_response(&known));
        assert!(!multiplexer.is_pending_response(&unknown));
    }

    /// Cancelling removes the pending entry.
    #[test]
    fn test_multiplexer_cancel() {
        let multiplexer = RequestMultiplexer::new();

        let (request_id, _receiver) = multiplexer.create_pending("test");
        assert_eq!(multiplexer.pending_count(), 1);

        multiplexer.cancel(&request_id);
        assert_eq!(multiplexer.pending_count(), 0);
    }

    /// The serialized request carries the version, id and method.
    #[test]
    fn test_client_request_serialization() {
        let request =
            JsonRpcClientRequest::new("abc-123", "roots/list", Some(serde_json::json!({})));

        let encoded = serde_json::to_string(&request).unwrap();

        let expected_fragments = [
            "\"jsonrpc\":\"2.0\"",
            "\"id\":\"abc-123\"",
            "\"method\":\"roots/list\"",
        ];
        for fragment in expected_fragments.iter() {
            assert!(encoded.contains(fragment));
        }
    }

    /// A root with both fields present deserializes into matching values.
    #[test]
    fn test_root_deserialization() {
        let raw = r#"{"uri": "file:///workspace", "name": "My Project"}"#;

        let parsed: Root = serde_json::from_str(raw).unwrap();

        assert_eq!(parsed.uri, "file:///workspace");
        assert_eq!(parsed.name, Some("My Project".to_string()));
    }

    /// Text content serializes with its type tag and text field.
    #[test]
    fn test_sampling_content() {
        let text_content = SamplingContent::Text {
            text: "Hello, world!".to_string(),
        };

        let encoded = serde_json::to_string(&text_content).unwrap();

        assert!(encoded.contains("\"type\":\"text\""));
        assert!(encoded.contains("\"text\":\"Hello, world!\""));
    }
}