1use std::collections::HashMap;
37use std::sync::atomic::{AtomicU64, Ordering};
38use std::sync::{Arc, Mutex};
39use std::time::Duration;
40
41use asupersync::Cx;
42use fastmcp_core::{
43 ElicitationAction, ElicitationMode, ElicitationRequest, ElicitationResponse, ElicitationSender,
44 McpError, McpErrorCode, McpResult, SamplingRequest, SamplingResponse, SamplingRole,
45 SamplingSender, SamplingStopReason,
46};
47use fastmcp_protocol::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, RequestId};
48
/// Sending half of a per-request response channel.
///
/// `PendingRequests` stores one of these per outstanding request id. It removes the entry
/// before sending, so at most one value is ever sent through a given sender.
type ResponseSender = std::sync::mpsc::Sender<Result<serde_json::Value, JsonRpcError>>;
/// Receiving half of a per-request response channel, returned by `PendingRequests::register`.
///
/// `Ok` carries the response `result` (or `null` when the response had none). `Err` carries the
/// JSON-RPC error sent by the peer, or the synthetic "Connection closed" error from `cancel_all`.
type ResponseReceiver = std::sync::mpsc::Receiver<Result<serde_json::Value, JsonRpcError>>;
56
/// Table of outbound requests that are still waiting for a response from the peer.
///
/// Each waiter is keyed by its [`RequestId`]. Ids are generated locally from a counter that is
/// incremented atomically on every allocation.
#[derive(Debug)]
pub struct PendingRequests {
    /// Waiting callers keyed by request id. An entry is removed when its response is routed,
    /// when the caller gives up (`remove`), or when everything is cancelled (`cancel_all`).
    pending: Mutex<HashMap<RequestId, ResponseSender>>,
    /// Next id handed out by `next_request_id`; initialised to `1_000_000` in `new`.
    next_id: AtomicU64,
}
68
69impl PendingRequests {
70 fn lock_pending(&self) -> std::sync::MutexGuard<'_, HashMap<RequestId, ResponseSender>> {
71 match self.pending.lock() {
72 Ok(guard) => guard,
73 Err(poisoned) => poisoned.into_inner(),
75 }
76 }
77
78 #[must_use]
80 pub fn new() -> Self {
81 Self {
82 pending: Mutex::new(HashMap::new()),
83 next_id: AtomicU64::new(1_000_000),
85 }
86 }
87
88 #[allow(clippy::cast_possible_wrap)]
90 pub fn next_request_id(&self) -> RequestId {
91 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
92 RequestId::Number(id as i64)
93 }
94
95 pub fn register(&self, id: RequestId) -> ResponseReceiver {
97 let (tx, rx) = std::sync::mpsc::channel();
98 let mut pending = self.lock_pending();
99 pending.insert(id, tx);
100 rx
101 }
102
103 pub fn route_response(&self, response: &JsonRpcResponse) -> bool {
107 let Some(ref id) = response.id else {
108 return false;
109 };
110
111 let sender = {
112 let mut pending = self.lock_pending();
113 pending.remove(id)
114 };
115
116 if let Some(sender) = sender {
117 let result = if let Some(ref error) = response.error {
118 Err(error.clone())
119 } else {
120 Ok(response.result.clone().unwrap_or(serde_json::Value::Null))
121 };
122 let _ = sender.send(result);
124 true
125 } else {
126 false
127 }
128 }
129
130 pub fn remove(&self, id: &RequestId) {
132 let mut pending = self.lock_pending();
133 pending.remove(id);
134 }
135
136 pub fn cancel_all(&self) {
138 let mut pending = self.lock_pending();
139 for (_, sender) in pending.drain() {
140 let _ = sender.send(Err(JsonRpcError {
141 code: McpErrorCode::InternalError.into(),
142 message: "Connection closed".to_string(),
143 data: None,
144 }));
145 }
146 }
147}
148
149impl Default for PendingRequests {
150 fn default() -> Self {
151 Self::new()
152 }
153}
154
/// Transport hook that `RequestSender` uses to write one outbound JSON-RPC message.
///
/// Returning `Err(reason)` makes `RequestSender::send_request` unregister the request. It then
/// fails with an internal error whose message includes `reason`.
pub type TransportSendFn = Arc<dyn Fn(&JsonRpcMessage) -> Result<(), String> + Send + Sync>;
161
/// Cloneable handle for issuing requests to the peer and waiting for their responses.
///
/// Clones share the same `PendingRequests` table (through the `Arc`) and the same transport
/// callback.
#[derive(Clone)]
pub struct RequestSender {
    /// Shared table in which each outgoing request registers its response channel.
    pending: Arc<PendingRequests>,
    /// Callback that hands outbound messages to the transport.
    send_fn: TransportSendFn,
}
173
174impl RequestSender {
175 pub fn new(pending: Arc<PendingRequests>, send_fn: TransportSendFn) -> Self {
177 Self { pending, send_fn }
178 }
179
180 pub fn send_request<T: serde::de::DeserializeOwned>(
190 &self,
191 cx: &Cx,
192 method: &str,
193 params: serde_json::Value,
194 ) -> McpResult<T> {
195 let id = self.pending.next_request_id();
196 let receiver = self.pending.register(id.clone());
197
198 let request = JsonRpcRequest::new(method.to_string(), Some(params), id.clone());
199 let message = JsonRpcMessage::Request(request);
200
201 if let Err(e) = (self.send_fn)(&message) {
203 self.pending.remove(&id);
204 return Err(McpError::internal_error(format!(
205 "Failed to send request: {}",
206 e
207 )));
208 }
209
210 let tick = Duration::from_millis(25);
215 loop {
216 if cx.checkpoint().is_err() {
217 self.pending.remove(&id);
218 return Err(McpError::request_cancelled());
219 }
220
221 match receiver.recv_timeout(tick) {
222 Ok(Ok(value)) => {
223 return serde_json::from_value(value).map_err(|e| {
224 McpError::internal_error(format!("Failed to parse response: {e}"))
225 });
226 }
227 Ok(Err(error)) => {
228 return Err(McpError::new(McpErrorCode::from(error.code), error.message));
229 }
230 Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
231 }
233 Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
234 return Err(McpError::internal_error(
235 "Response channel closed unexpectedly",
236 ));
237 }
238 }
239 }
240 }
241}
242
243impl std::fmt::Debug for RequestSender {
244 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
245 f.debug_struct("RequestSender")
246 .field("pending", &self.pending)
247 .finish_non_exhaustive()
248 }
249}
250
251#[derive(Clone)]
257pub struct TransportSamplingSender {
258 sender: RequestSender,
259}
260
261impl TransportSamplingSender {
262 pub fn new(sender: RequestSender) -> Self {
264 Self { sender }
265 }
266}
267
268impl SamplingSender for TransportSamplingSender {
269 fn create_message(
270 &self,
271 request: SamplingRequest,
272 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<SamplingResponse>> + Send + '_>>
273 {
274 Box::pin(async move {
275 let params = fastmcp_protocol::CreateMessageParams {
277 messages: request
278 .messages
279 .into_iter()
280 .map(|m| fastmcp_protocol::SamplingMessage {
281 role: match m.role {
282 SamplingRole::User => fastmcp_protocol::Role::User,
283 SamplingRole::Assistant => fastmcp_protocol::Role::Assistant,
284 },
285 content: fastmcp_protocol::SamplingContent::Text { text: m.text },
286 })
287 .collect(),
288 max_tokens: request.max_tokens,
289 system_prompt: request.system_prompt,
290 temperature: request.temperature,
291 stop_sequences: request.stop_sequences,
292 model_preferences: if request.model_hints.is_empty() {
293 None
294 } else {
295 Some(fastmcp_protocol::ModelPreferences {
296 hints: request
297 .model_hints
298 .into_iter()
299 .map(|name| fastmcp_protocol::ModelHint { name: Some(name) })
300 .collect(),
301 ..Default::default()
302 })
303 },
304 include_context: None,
305 meta: None,
306 };
307
308 let params_value = serde_json::to_value(¶ms)
309 .map_err(|e| McpError::internal_error(format!("Failed to serialize: {}", e)))?;
310
311 let cx = Cx::for_request();
313
314 let result: fastmcp_protocol::CreateMessageResult =
315 self.sender
316 .send_request(&cx, "sampling/createMessage", params_value)?;
317
318 Ok(SamplingResponse {
319 text: match result.content {
320 fastmcp_protocol::SamplingContent::Text { text } => text,
321 fastmcp_protocol::SamplingContent::Image { data, mime_type } => {
322 format!("[image: {} bytes, type: {}]", data.len(), mime_type)
323 }
324 },
325 model: result.model,
326 stop_reason: match result.stop_reason {
327 fastmcp_protocol::StopReason::EndTurn => SamplingStopReason::EndTurn,
328 fastmcp_protocol::StopReason::StopSequence => SamplingStopReason::StopSequence,
329 fastmcp_protocol::StopReason::MaxTokens => SamplingStopReason::MaxTokens,
330 },
331 })
332 })
333 }
334}
335
336#[derive(Clone)]
342pub struct TransportElicitationSender {
343 sender: RequestSender,
344}
345
346impl TransportElicitationSender {
347 pub fn new(sender: RequestSender) -> Self {
349 Self { sender }
350 }
351}
352
353impl ElicitationSender for TransportElicitationSender {
354 fn elicit(
355 &self,
356 request: ElicitationRequest,
357 ) -> std::pin::Pin<
358 Box<dyn std::future::Future<Output = McpResult<ElicitationResponse>> + Send + '_>,
359 > {
360 Box::pin(async move {
361 let params_value = match request.mode {
362 ElicitationMode::Form => {
363 let params = fastmcp_protocol::ElicitRequestFormParams {
364 mode: fastmcp_protocol::ElicitMode::Form,
365 message: request.message.clone(),
366 requested_schema: request.schema.unwrap_or(serde_json::json!({})),
367 };
368 serde_json::to_value(¶ms).map_err(|e| {
369 McpError::internal_error(format!("Failed to serialize: {}", e))
370 })?
371 }
372 ElicitationMode::Url => {
373 let params = fastmcp_protocol::ElicitRequestUrlParams {
374 mode: fastmcp_protocol::ElicitMode::Url,
375 message: request.message.clone(),
376 url: request.url.unwrap_or_default(),
377 elicitation_id: request.elicitation_id.unwrap_or_default(),
378 };
379 serde_json::to_value(¶ms).map_err(|e| {
380 McpError::internal_error(format!("Failed to serialize: {}", e))
381 })?
382 }
383 };
384
385 let cx = Cx::for_request();
387
388 let result: fastmcp_protocol::ElicitResult =
389 self.sender
390 .send_request(&cx, "elicitation/elicit", params_value)?;
391
392 let content = result.content.map(|content_map| {
394 let mut map = std::collections::HashMap::new();
395 for (key, value) in content_map {
396 let json_value = match value {
397 fastmcp_protocol::ElicitContentValue::Null => serde_json::Value::Null,
398 fastmcp_protocol::ElicitContentValue::Bool(b) => serde_json::Value::Bool(b),
399 fastmcp_protocol::ElicitContentValue::Int(i) => {
400 serde_json::Value::Number(i.into())
401 }
402 fastmcp_protocol::ElicitContentValue::Float(f) => {
403 serde_json::Number::from_f64(f)
404 .map(serde_json::Value::Number)
405 .unwrap_or(serde_json::Value::Null)
406 }
407 fastmcp_protocol::ElicitContentValue::String(s) => {
408 serde_json::Value::String(s)
409 }
410 fastmcp_protocol::ElicitContentValue::StringArray(arr) => {
411 serde_json::Value::Array(
412 arr.into_iter().map(serde_json::Value::String).collect(),
413 )
414 }
415 };
416 map.insert(key, json_value);
417 }
418 map
419 });
420
421 Ok(ElicitationResponse {
422 action: match result.action {
423 fastmcp_protocol::ElicitAction::Accept => ElicitationAction::Accept,
424 fastmcp_protocol::ElicitAction::Decline => ElicitationAction::Decline,
425 fastmcp_protocol::ElicitAction::Cancel => ElicitationAction::Cancel,
426 },
427 content,
428 })
429 })
430 }
431}
432
433#[derive(Clone)]
439pub struct TransportRootsProvider {
440 sender: RequestSender,
441}
442
443impl TransportRootsProvider {
444 pub fn new(sender: RequestSender) -> Self {
446 Self { sender }
447 }
448
449 pub fn list_roots(&self) -> McpResult<Vec<fastmcp_protocol::Root>> {
451 let cx = Cx::for_request();
452 let result: fastmcp_protocol::ListRootsResult =
453 self.sender
454 .send_request(&cx, "roots/list", serde_json::json!({}))?;
455 Ok(result.roots)
456 }
457}
458
459#[cfg(test)]
464mod tests {
465 use super::*;
466
467 #[test]
468 fn test_pending_requests_register_and_route() {
469 let pending = PendingRequests::new();
470
471 let id = pending.next_request_id();
473 let receiver = pending.register(id.clone());
474
475 let response = JsonRpcResponse::success(id, serde_json::json!({"result": "ok"}));
477 assert!(pending.route_response(&response));
478
479 let result = receiver.recv().unwrap();
481 assert!(result.is_ok());
482 assert_eq!(result.unwrap(), serde_json::json!({"result": "ok"}));
483 }
484
485 #[test]
486 fn test_pending_requests_error_response() {
487 let pending = PendingRequests::new();
488
489 let id = pending.next_request_id();
490 let receiver = pending.register(id.clone());
491
492 let response = JsonRpcResponse::error(
494 Some(id),
495 JsonRpcError {
496 code: -32600,
497 message: "Invalid request".to_string(),
498 data: None,
499 },
500 );
501 assert!(pending.route_response(&response));
502
503 let result = receiver.recv().unwrap();
505 assert!(result.is_err());
506 assert_eq!(result.unwrap_err().message, "Invalid request");
507 }
508
509 #[test]
510 fn test_pending_requests_cancel_all() {
511 let pending = PendingRequests::new();
512
513 let id1 = pending.next_request_id();
514 let id2 = pending.next_request_id();
515 let receiver1 = pending.register(id1);
516 let receiver2 = pending.register(id2);
517
518 pending.cancel_all();
520
521 let result1 = receiver1.recv().unwrap();
523 let result2 = receiver2.recv().unwrap();
524 assert!(result1.is_err());
525 assert!(result2.is_err());
526 }
527
528 #[test]
529 fn test_route_unknown_response() {
530 let pending = PendingRequests::new();
531
532 let response = JsonRpcResponse::success(
534 RequestId::Number(999999),
535 serde_json::json!({"result": "ok"}),
536 );
537 assert!(!pending.route_response(&response));
538 }
539
540 #[test]
543 fn pending_requests_default_is_same_as_new() {
544 let pr = PendingRequests::default();
545 let id = pr.next_request_id();
546 assert_eq!(id, RequestId::Number(1_000_000));
548 }
549
550 #[test]
551 fn pending_requests_ids_are_sequential() {
552 let pr = PendingRequests::new();
553 let id1 = pr.next_request_id();
554 let id2 = pr.next_request_id();
555 let id3 = pr.next_request_id();
556 assert_eq!(id1, RequestId::Number(1_000_000));
557 assert_eq!(id2, RequestId::Number(1_000_001));
558 assert_eq!(id3, RequestId::Number(1_000_002));
559 }
560
561 #[test]
562 fn pending_requests_remove_prevents_routing() {
563 let pr = PendingRequests::new();
564 let id = pr.next_request_id();
565 let _receiver = pr.register(id.clone());
566
567 pr.remove(&id);
569
570 let response = JsonRpcResponse::success(id, serde_json::json!(null));
572 assert!(!pr.route_response(&response));
573 }
574
575 #[test]
576 fn pending_requests_route_response_without_id_returns_false() {
577 let pr = PendingRequests::new();
578 let response = JsonRpcResponse {
580 jsonrpc: std::borrow::Cow::Borrowed("2.0"),
581 id: None,
582 result: Some(serde_json::json!(null)),
583 error: None,
584 };
585 assert!(!pr.route_response(&response));
586 }
587
588 #[test]
589 fn pending_requests_route_response_with_null_result() {
590 let pr = PendingRequests::new();
591 let id = pr.next_request_id();
592 let receiver = pr.register(id.clone());
593
594 let response = JsonRpcResponse {
596 jsonrpc: std::borrow::Cow::Borrowed("2.0"),
597 id: Some(id),
598 result: None,
599 error: None,
600 };
601 assert!(pr.route_response(&response));
602
603 let result = receiver.recv().unwrap().unwrap();
604 assert_eq!(result, serde_json::Value::Null);
605 }
606
607 #[test]
608 fn pending_requests_route_after_receiver_dropped_does_not_panic() {
609 let pr = PendingRequests::new();
610 let id = pr.next_request_id();
611 let receiver = pr.register(id.clone());
612
613 drop(receiver);
615
616 let response = JsonRpcResponse::success(id, serde_json::json!(42));
618 assert!(pr.route_response(&response));
619 }
620
621 #[test]
622 fn pending_requests_cancel_all_clears_pending() {
623 let pr = PendingRequests::new();
624 let id = pr.next_request_id();
625 let _receiver = pr.register(id.clone());
626
627 pr.cancel_all();
628
629 let response = JsonRpcResponse::success(id, serde_json::json!(null));
631 assert!(!pr.route_response(&response));
632 }
633
634 #[test]
635 fn pending_requests_cancel_all_empty_is_noop() {
636 let pr = PendingRequests::new();
637 pr.cancel_all();
639 }
640
641 #[test]
642 fn pending_requests_debug_format() {
643 let pr = PendingRequests::new();
644 let debug = format!("{:?}", pr);
645 assert!(debug.contains("PendingRequests"));
646 }
647
648 #[test]
651 fn request_sender_debug_format() {
652 let pending = Arc::new(PendingRequests::new());
653 let send_fn: TransportSendFn = Arc::new(|_| Ok(()));
654 let sender = RequestSender::new(pending, send_fn);
655 let debug = format!("{:?}", sender);
656 assert!(debug.contains("RequestSender"));
657 }
658
659 #[test]
660 fn request_sender_transport_failure_returns_error() {
661 let pending = Arc::new(PendingRequests::new());
662 let send_fn: TransportSendFn = Arc::new(|_| Err("transport down".to_string()));
663 let sender = RequestSender::new(pending, send_fn);
664
665 let cx = Cx::for_testing();
666 let result: McpResult<serde_json::Value> =
667 sender.send_request(&cx, "test/method", serde_json::json!({}));
668 let err = result.unwrap_err();
669 assert!(err.message.contains("Failed to send request"));
670 assert!(err.message.contains("transport down"));
671 }
672
673 #[test]
674 fn request_sender_transport_failure_cleans_up_pending() {
675 let pending = Arc::new(PendingRequests::new());
676 let send_fn: TransportSendFn = Arc::new(|_| Err("fail".to_string()));
677 let sender = RequestSender::new(Arc::clone(&pending), send_fn);
678
679 let cx = Cx::for_testing();
680 let _err: McpResult<serde_json::Value> =
681 sender.send_request(&cx, "test/method", serde_json::json!({}));
682
683 let id = RequestId::Number(1_000_000); let response = JsonRpcResponse::success(id, serde_json::json!(null));
686 assert!(!pending.route_response(&response));
687 }
688
689 #[test]
690 fn request_sender_clone() {
691 let pending = Arc::new(PendingRequests::new());
692 let send_fn: TransportSendFn = Arc::new(|_| Ok(()));
693 let sender = RequestSender::new(pending, send_fn);
694 let cloned = sender.clone();
695 let debug = format!("{:?}", cloned);
696 assert!(debug.contains("RequestSender"));
697 }
698
699 #[test]
702 fn request_sender_success_path() {
703 let pending = Arc::new(PendingRequests::new());
704 let pending_clone = Arc::clone(&pending);
705 let send_fn: TransportSendFn = Arc::new(move |msg| {
706 if let JsonRpcMessage::Request(req) = msg {
707 let id = req.id.clone().unwrap();
708 let response = JsonRpcResponse::success(id, serde_json::json!({"answer": 42}));
709 pending_clone.route_response(&response);
710 }
711 Ok(())
712 });
713 let sender = RequestSender::new(Arc::clone(&pending), send_fn);
714 let cx = Cx::for_testing();
715 let result: McpResult<serde_json::Value> =
716 sender.send_request(&cx, "test/method", serde_json::json!({}));
717 let value = result.unwrap();
718 assert_eq!(value["answer"], 42);
719 }
720
721 #[test]
722 fn request_sender_error_response_path() {
723 let pending = Arc::new(PendingRequests::new());
724 let pending_clone = Arc::clone(&pending);
725 let send_fn: TransportSendFn = Arc::new(move |msg| {
726 if let JsonRpcMessage::Request(req) = msg {
727 let id = req.id.clone().unwrap();
728 let response = JsonRpcResponse::error(
729 Some(id),
730 JsonRpcError {
731 code: -32600,
732 message: "bad request".to_string(),
733 data: None,
734 },
735 );
736 pending_clone.route_response(&response);
737 }
738 Ok(())
739 });
740 let sender = RequestSender::new(Arc::clone(&pending), send_fn);
741 let cx = Cx::for_testing();
742 let result: McpResult<serde_json::Value> =
743 sender.send_request(&cx, "test/method", serde_json::json!({}));
744 let err = result.unwrap_err();
745 assert!(err.message.contains("bad request"));
746 }
747
748 #[test]
749 fn request_sender_disconnected_path() {
750 let pending = Arc::new(PendingRequests::new());
751 let pending_clone = Arc::clone(&pending);
752 let send_fn: TransportSendFn = Arc::new(move |msg| {
753 if let JsonRpcMessage::Request(req) = msg {
754 let id = req.id.clone().unwrap();
755 pending_clone.remove(&id);
757 }
758 Ok(())
759 });
760 let sender = RequestSender::new(Arc::clone(&pending), send_fn);
761 let cx = Cx::for_testing();
762 let result: McpResult<serde_json::Value> =
763 sender.send_request(&cx, "test/method", serde_json::json!({}));
764 let err = result.unwrap_err();
765 assert!(err.message.contains("Response channel closed"));
766 }
767
768 #[test]
769 fn request_sender_deserialization_error() {
770 let pending = Arc::new(PendingRequests::new());
771 let pending_clone = Arc::clone(&pending);
772 let send_fn: TransportSendFn = Arc::new(move |msg| {
773 if let JsonRpcMessage::Request(req) = msg {
774 let id = req.id.clone().unwrap();
775 let response =
777 JsonRpcResponse::success(id, serde_json::json!("not a vec of strings"));
778 pending_clone.route_response(&response);
779 }
780 Ok(())
781 });
782 let sender = RequestSender::new(Arc::clone(&pending), send_fn);
783 let cx = Cx::for_testing();
784 let result: McpResult<Vec<String>> =
785 sender.send_request(&cx, "test/method", serde_json::json!({}));
786 let err = result.unwrap_err();
787 assert!(err.message.contains("Failed to parse response"));
788 }
789
790 #[test]
793 fn cancel_all_sends_connection_closed_error() {
794 let pr = PendingRequests::new();
795 let id = pr.next_request_id();
796 let receiver = pr.register(id);
797 pr.cancel_all();
798 let result = receiver.recv().unwrap();
799 let err = result.unwrap_err();
800 assert_eq!(err.code, i32::from(McpErrorCode::InternalError));
801 assert!(err.message.contains("Connection closed"));
802 assert!(err.data.is_none());
803 }
804
805 #[test]
808 fn route_response_error_with_data() {
809 let pr = PendingRequests::new();
810 let id = pr.next_request_id();
811 let receiver = pr.register(id.clone());
812 let response = JsonRpcResponse::error(
813 Some(id),
814 JsonRpcError {
815 code: -32001,
816 message: "custom error".to_string(),
817 data: Some(serde_json::json!({"detail": "extra info"})),
818 },
819 );
820 assert!(pr.route_response(&response));
821 let result = receiver.recv().unwrap();
822 let err = result.unwrap_err();
823 assert_eq!(err.code, -32001);
824 assert!(err.message.contains("custom error"));
825 assert!(err.data.is_some());
826 }
827
828 #[test]
831 fn pending_requests_multiple_register_and_route_independently() {
832 let pr = PendingRequests::new();
833 let id1 = pr.next_request_id();
834 let id2 = pr.next_request_id();
835 let id3 = pr.next_request_id();
836 let rx1 = pr.register(id1.clone());
837 let rx2 = pr.register(id2.clone());
838 let rx3 = pr.register(id3.clone());
839
840 let r2 = JsonRpcResponse::success(id2.clone(), serde_json::json!("second"));
842 let r3 = JsonRpcResponse::success(id3.clone(), serde_json::json!("third"));
843 let r1 = JsonRpcResponse::success(id1.clone(), serde_json::json!("first"));
844 assert!(pr.route_response(&r2));
845 assert!(pr.route_response(&r3));
846 assert!(pr.route_response(&r1));
847
848 assert_eq!(rx1.recv().unwrap().unwrap(), serde_json::json!("first"));
849 assert_eq!(rx2.recv().unwrap().unwrap(), serde_json::json!("second"));
850 assert_eq!(rx3.recv().unwrap().unwrap(), serde_json::json!("third"));
851 }
852
853 #[test]
856 fn pending_requests_register_same_id_overwrites() {
857 let pr = PendingRequests::new();
858 let id = pr.next_request_id();
859 let _rx1 = pr.register(id.clone());
860 let rx2 = pr.register(id.clone()); let response = JsonRpcResponse::success(id, serde_json::json!("response"));
863 assert!(pr.route_response(&response));
864
865 let result = rx2.recv().unwrap().unwrap();
867 assert_eq!(result, serde_json::json!("response"));
868 }
869
870 #[test]
873 fn transport_sampling_sender_new_and_clone() {
874 let pending = Arc::new(PendingRequests::new());
875 let send_fn: TransportSendFn = Arc::new(|_| Ok(()));
876 let sender = RequestSender::new(pending, send_fn);
877 let sampling = TransportSamplingSender::new(sender);
878 let _cloned = sampling.clone();
879 }
880
881 #[test]
882 fn transport_elicitation_sender_new_and_clone() {
883 let pending = Arc::new(PendingRequests::new());
884 let send_fn: TransportSendFn = Arc::new(|_| Ok(()));
885 let sender = RequestSender::new(pending, send_fn);
886 let elicitation = TransportElicitationSender::new(sender);
887 let _cloned = elicitation.clone();
888 }
889
890 #[test]
891 fn transport_roots_provider_new_and_clone() {
892 let pending = Arc::new(PendingRequests::new());
893 let send_fn: TransportSendFn = Arc::new(|_| Ok(()));
894 let sender = RequestSender::new(pending, send_fn);
895 let roots = TransportRootsProvider::new(sender);
896 let _cloned = roots.clone();
897 }
898
899 #[test]
902 fn pending_requests_lock_pending_recovers_from_poison() {
903 let pr = Arc::new(PendingRequests::new());
904 let id = pr.next_request_id();
905 let rx = pr.register(id.clone());
906
907 let pr2 = Arc::clone(&pr);
909 let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
910 let _guard = pr2.pending.lock().unwrap();
911 panic!("intentional poison");
912 }));
913
914 let response = JsonRpcResponse::success(id, serde_json::json!("recovered"));
917 assert!(pr.route_response(&response));
918 let result = rx.recv().unwrap().unwrap();
919 assert_eq!(result, serde_json::json!("recovered"));
920 }
921
922 fn make_sender_with_responder(
925 responder: impl Fn(&JsonRpcRequest) -> serde_json::Value + Send + Sync + 'static,
926 ) -> RequestSender {
927 let pending = Arc::new(PendingRequests::new());
928 let pending_clone = Arc::clone(&pending);
929 let send_fn: TransportSendFn = Arc::new(move |msg| {
930 if let JsonRpcMessage::Request(req) = msg {
931 let id = req.id.clone().unwrap();
932 let result = responder(req);
933 let response = JsonRpcResponse::success(id, result);
934 pending_clone.route_response(&response);
935 }
936 Ok(())
937 });
938 RequestSender::new(pending, send_fn)
939 }
940
941 #[test]
942 fn transport_sampling_sender_create_message_text() {
943 let sender = make_sender_with_responder(|_| {
944 serde_json::json!({
945 "content": {"type": "text", "text": "Hello world"},
946 "role": "assistant",
947 "model": "test-model",
948 "stopReason": "endTurn"
949 })
950 });
951 let sampling = TransportSamplingSender::new(sender);
952
953 let request = SamplingRequest {
954 messages: vec![fastmcp_core::SamplingRequestMessage {
955 role: SamplingRole::User,
956 text: "Hi".to_string(),
957 }],
958 max_tokens: 100,
959 system_prompt: Some("Be helpful".to_string()),
960 temperature: Some(0.7),
961 stop_sequences: vec!["STOP".to_string()],
962 model_hints: vec![],
963 };
964
965 let future = SamplingSender::create_message(&sampling, request);
966 let result = fastmcp_core::block_on(future).unwrap();
967 assert_eq!(result.text, "Hello world");
968 assert_eq!(result.model, "test-model");
969 assert!(matches!(result.stop_reason, SamplingStopReason::EndTurn));
970 }
971
972 #[test]
973 fn transport_sampling_sender_create_message_image() {
974 let sender = make_sender_with_responder(|_| {
975 serde_json::json!({
976 "content": {"type": "image", "data": "aW1hZ2VkYXRh", "mimeType": "image/png"},
977 "role": "assistant",
978 "model": "vision-model",
979 "stopReason": "maxTokens"
980 })
981 });
982 let sampling = TransportSamplingSender::new(sender);
983
984 let request = SamplingRequest {
985 messages: vec![fastmcp_core::SamplingRequestMessage {
986 role: SamplingRole::User,
987 text: "Describe image".to_string(),
988 }],
989 max_tokens: 50,
990 system_prompt: None,
991 temperature: None,
992 stop_sequences: vec![],
993 model_hints: vec![],
994 };
995
996 let future = SamplingSender::create_message(&sampling, request);
997 let result = fastmcp_core::block_on(future).unwrap();
998 assert!(result.text.contains("image"));
1000 assert!(result.text.contains("image/png"));
1001 assert_eq!(result.model, "vision-model");
1002 assert!(matches!(result.stop_reason, SamplingStopReason::MaxTokens));
1003 }
1004
1005 #[test]
1006 fn transport_sampling_sender_create_message_with_model_hints() {
1007 let sender = make_sender_with_responder(|req| {
1008 let params: serde_json::Value =
1010 serde_json::from_value(req.params.clone().unwrap()).unwrap();
1011 assert!(params["modelPreferences"]["hints"].is_array());
1012 serde_json::json!({
1013 "content": {"type": "text", "text": "ok"},
1014 "role": "assistant",
1015 "model": "preferred",
1016 "stopReason": "stopSequence"
1017 })
1018 });
1019 let sampling = TransportSamplingSender::new(sender);
1020
1021 let request = SamplingRequest {
1022 messages: vec![fastmcp_core::SamplingRequestMessage {
1023 role: SamplingRole::User,
1024 text: "Hi".to_string(),
1025 }],
1026 max_tokens: 10,
1027 system_prompt: None,
1028 temperature: None,
1029 stop_sequences: vec![],
1030 model_hints: vec!["claude-3".to_string()],
1031 };
1032
1033 let future = SamplingSender::create_message(&sampling, request);
1034 let result = fastmcp_core::block_on(future).unwrap();
1035 assert!(matches!(
1036 result.stop_reason,
1037 SamplingStopReason::StopSequence
1038 ));
1039 }
1040
1041 #[test]
1042 fn transport_sampling_sender_create_message_assistant_role() {
1043 let sender = make_sender_with_responder(|req| {
1044 let params: serde_json::Value =
1045 serde_json::from_value(req.params.clone().unwrap()).unwrap();
1046 assert_eq!(params["messages"][0]["role"], "assistant");
1047 serde_json::json!({
1048 "content": {"type": "text", "text": "continued"},
1049 "role": "assistant",
1050 "model": "m",
1051 "stopReason": "endTurn"
1052 })
1053 });
1054 let sampling = TransportSamplingSender::new(sender);
1055
1056 let request = SamplingRequest {
1057 messages: vec![fastmcp_core::SamplingRequestMessage {
1058 role: SamplingRole::Assistant,
1059 text: "Previous response".to_string(),
1060 }],
1061 max_tokens: 10,
1062 system_prompt: None,
1063 temperature: None,
1064 stop_sequences: vec![],
1065 model_hints: vec![],
1066 };
1067
1068 let future = SamplingSender::create_message(&sampling, request);
1069 let result = fastmcp_core::block_on(future).unwrap();
1070 assert_eq!(result.text, "continued");
1071 }
1072
1073 #[test]
1076 fn transport_elicitation_sender_form_accept_with_content() {
1077 let sender = make_sender_with_responder(|req| {
1078 let params: serde_json::Value =
1079 serde_json::from_value(req.params.clone().unwrap()).unwrap();
1080 assert_eq!(params["mode"], "form");
1081 serde_json::json!({
1082 "action": "accept",
1083 "content": {
1084 "name": "Alice",
1085 "age": 30,
1086 "active": true,
1087 "score": 9.5,
1088 "tags": ["a", "b"],
1089 "empty": null
1090 }
1091 })
1092 });
1093 let elicitation = TransportElicitationSender::new(sender);
1094
1095 let request = ElicitationRequest {
1096 message: "Fill the form".to_string(),
1097 mode: ElicitationMode::Form,
1098 schema: Some(serde_json::json!({"type": "object"})),
1099 url: None,
1100 elicitation_id: None,
1101 };
1102
1103 let future = ElicitationSender::elicit(&elicitation, request);
1104 let result = fastmcp_core::block_on(future).unwrap();
1105 assert!(matches!(result.action, ElicitationAction::Accept));
1106 let content = result.content.unwrap();
1107 assert_eq!(content["name"], serde_json::json!("Alice"));
1108 assert_eq!(content["age"], serde_json::json!(30));
1109 assert_eq!(content["active"], serde_json::json!(true));
1110 assert_eq!(content["score"], serde_json::json!(9.5));
1111 assert_eq!(content["tags"], serde_json::json!(["a", "b"]));
1112 assert_eq!(content["empty"], serde_json::Value::Null);
1113 }
1114
1115 #[test]
1116 fn transport_elicitation_sender_form_decline() {
1117 let sender = make_sender_with_responder(|_| {
1118 serde_json::json!({
1119 "action": "decline"
1120 })
1121 });
1122 let elicitation = TransportElicitationSender::new(sender);
1123
1124 let request = ElicitationRequest {
1125 message: "Confirm?".to_string(),
1126 mode: ElicitationMode::Form,
1127 schema: None,
1128 url: None,
1129 elicitation_id: None,
1130 };
1131
1132 let future = ElicitationSender::elicit(&elicitation, request);
1133 let result = fastmcp_core::block_on(future).unwrap();
1134 assert!(matches!(result.action, ElicitationAction::Decline));
1135 assert!(result.content.is_none());
1136 }
1137
1138 #[test]
1139 fn transport_elicitation_sender_url_mode() {
1140 let sender = make_sender_with_responder(|req| {
1141 let params: serde_json::Value =
1142 serde_json::from_value(req.params.clone().unwrap()).unwrap();
1143 assert_eq!(params["mode"], "url");
1144 assert_eq!(params["url"], "https://example.com/auth");
1145 serde_json::json!({
1146 "action": "cancel"
1147 })
1148 });
1149 let elicitation = TransportElicitationSender::new(sender);
1150
1151 let request = ElicitationRequest {
1152 message: "Please authenticate".to_string(),
1153 mode: ElicitationMode::Url,
1154 schema: None,
1155 url: Some("https://example.com/auth".to_string()),
1156 elicitation_id: Some("eid-123".to_string()),
1157 };
1158
1159 let future = ElicitationSender::elicit(&elicitation, request);
1160 let result = fastmcp_core::block_on(future).unwrap();
1161 assert!(matches!(result.action, ElicitationAction::Cancel));
1162 }
1163
1164 #[test]
1167 fn transport_roots_provider_list_roots() {
1168 let sender = make_sender_with_responder(|_| {
1169 serde_json::json!({
1170 "roots": [
1171 {"uri": "file:///home/user/project", "name": "Project"},
1172 {"uri": "file:///tmp"}
1173 ]
1174 })
1175 });
1176 let roots = TransportRootsProvider::new(sender);
1177 let result = roots.list_roots().unwrap();
1178 assert_eq!(result.len(), 2);
1179 assert_eq!(result[0].uri, "file:///home/user/project");
1180 assert_eq!(result[0].name, Some("Project".to_string()));
1181 assert_eq!(result[1].uri, "file:///tmp");
1182 assert!(result[1].name.is_none());
1183 }
1184
1185 #[test]
1186 fn transport_roots_provider_empty_roots() {
1187 let sender = make_sender_with_responder(|_| serde_json::json!({ "roots": [] }));
1188 let roots = TransportRootsProvider::new(sender);
1189 let result = roots.list_roots().unwrap();
1190 assert!(result.is_empty());
1191 }
1192
1193 #[test]
1198 fn request_sender_cancelled_cx_returns_cancelled_error() {
1199 let pending = Arc::new(PendingRequests::new());
1200 let send_fn: TransportSendFn = Arc::new(|_| Ok(()));
1202 let sender = RequestSender::new(Arc::clone(&pending), send_fn);
1203
1204 let cx = Cx::for_testing();
1205 cx.set_cancel_requested(true);
1206
1207 let result: McpResult<serde_json::Value> =
1208 sender.send_request(&cx, "test/cancel", serde_json::json!({}));
1209 let err = result.unwrap_err();
1210 assert_eq!(err.code, McpErrorCode::RequestCancelled);
1211 }
1212
1213 #[test]
1216 fn transport_elicitation_sender_url_mode_defaults() {
1217 let sender = make_sender_with_responder(|req| {
1218 let params: serde_json::Value =
1219 serde_json::from_value(req.params.clone().unwrap()).unwrap();
1220 assert_eq!(params["mode"], "url");
1221 assert_eq!(params["url"], "");
1223 assert_eq!(params["elicitationId"], "");
1224 serde_json::json!({ "action": "accept" })
1225 });
1226 let elicitation = TransportElicitationSender::new(sender);
1227
1228 let request = ElicitationRequest {
1229 message: "Auth".to_string(),
1230 mode: ElicitationMode::Url,
1231 schema: None,
1232 url: None,
1233 elicitation_id: None,
1234 };
1235
1236 let future = ElicitationSender::elicit(&elicitation, request);
1237 let result = fastmcp_core::block_on(future).unwrap();
1238 assert!(matches!(result.action, ElicitationAction::Accept));
1239 }
1240
1241 #[test]
1244 fn transport_roots_provider_transport_failure() {
1245 let pending = Arc::new(PendingRequests::new());
1246 let send_fn: TransportSendFn = Arc::new(|_| Err("network error".to_string()));
1247 let sender = RequestSender::new(pending, send_fn);
1248 let roots = TransportRootsProvider::new(sender);
1249
1250 let result = roots.list_roots();
1251 assert!(result.is_err());
1252 assert!(
1253 result
1254 .unwrap_err()
1255 .message
1256 .contains("Failed to send request")
1257 );
1258 }
1259
1260 #[test]
1263 fn transport_sampling_sender_transport_failure() {
1264 let pending = Arc::new(PendingRequests::new());
1265 let send_fn: TransportSendFn = Arc::new(|_| Err("connection reset".to_string()));
1266 let sender = RequestSender::new(pending, send_fn);
1267 let sampling = TransportSamplingSender::new(sender);
1268
1269 let request = SamplingRequest {
1270 messages: vec![fastmcp_core::SamplingRequestMessage {
1271 role: SamplingRole::User,
1272 text: "Hi".to_string(),
1273 }],
1274 max_tokens: 10,
1275 system_prompt: None,
1276 temperature: None,
1277 stop_sequences: vec![],
1278 model_hints: vec![],
1279 };
1280
1281 let future = SamplingSender::create_message(&sampling, request);
1282 let result = fastmcp_core::block_on(future);
1283 assert!(result.is_err());
1284 assert!(
1285 result
1286 .unwrap_err()
1287 .message
1288 .contains("Failed to send request")
1289 );
1290 }
1291
1292 #[test]
1295 fn transport_sampling_sender_multiple_messages() {
1296 let sender = make_sender_with_responder(|req| {
1297 let params: serde_json::Value =
1298 serde_json::from_value(req.params.clone().unwrap()).unwrap();
1299 let messages = params["messages"].as_array().unwrap();
1300 assert_eq!(messages.len(), 3);
1301 assert_eq!(messages[0]["role"], "user");
1302 assert_eq!(messages[1]["role"], "assistant");
1303 assert_eq!(messages[2]["role"], "user");
1304 serde_json::json!({
1305 "content": {"type": "text", "text": "done"},
1306 "role": "assistant",
1307 "model": "m",
1308 "stopReason": "endTurn"
1309 })
1310 });
1311 let sampling = TransportSamplingSender::new(sender);
1312
1313 let request = SamplingRequest {
1314 messages: vec![
1315 fastmcp_core::SamplingRequestMessage {
1316 role: SamplingRole::User,
1317 text: "Hello".to_string(),
1318 },
1319 fastmcp_core::SamplingRequestMessage {
1320 role: SamplingRole::Assistant,
1321 text: "Hi".to_string(),
1322 },
1323 fastmcp_core::SamplingRequestMessage {
1324 role: SamplingRole::User,
1325 text: "Follow up".to_string(),
1326 },
1327 ],
1328 max_tokens: 100,
1329 system_prompt: None,
1330 temperature: None,
1331 stop_sequences: vec![],
1332 model_hints: vec![],
1333 };
1334
1335 let future = SamplingSender::create_message(&sampling, request);
1336 let result = fastmcp_core::block_on(future).unwrap();
1337 assert_eq!(result.text, "done");
1338 }
1339
1340 #[test]
1343 fn request_sender_id_cleaned_from_pending_after_success() {
1344 let pending = Arc::new(PendingRequests::new());
1345 let pending_clone = Arc::clone(&pending);
1346 let send_fn: TransportSendFn = Arc::new(move |msg| {
1347 if let JsonRpcMessage::Request(req) = msg {
1348 let id = req.id.clone().unwrap();
1349 let response = JsonRpcResponse::success(id, serde_json::json!(null));
1350 pending_clone.route_response(&response);
1351 }
1352 Ok(())
1353 });
1354 let sender = RequestSender::new(Arc::clone(&pending), send_fn);
1355 let cx = Cx::for_testing();
1356 let _: serde_json::Value = sender
1357 .send_request(&cx, "test/method", serde_json::json!({}))
1358 .unwrap();
1359
1360 let first_id = RequestId::Number(1_000_000);
1362 let response = JsonRpcResponse::success(first_id, serde_json::json!(null));
1363 assert!(!pending.route_response(&response));
1364 }
1365}