1use std::collections::HashMap;
37use std::sync::atomic::{AtomicU64, Ordering};
38use std::sync::{Arc, Mutex};
39
40use asupersync::Cx;
41use fastmcp_core::{
42 ElicitationAction, ElicitationMode, ElicitationRequest, ElicitationResponse, ElicitationSender,
43 McpError, McpErrorCode, McpResult, SamplingRequest, SamplingResponse, SamplingRole,
44 SamplingSender, SamplingStopReason,
45};
46use fastmcp_protocol::{JsonRpcError, JsonRpcMessage, JsonRpcRequest, JsonRpcResponse, RequestId};
47
48type ResponseSender = std::sync::mpsc::Sender<Result<serde_json::Value, JsonRpcError>>;
54type ResponseReceiver = std::sync::mpsc::Receiver<Result<serde_json::Value, JsonRpcError>>;
55
56#[derive(Debug)]
61pub struct PendingRequests {
62 pending: Mutex<HashMap<RequestId, ResponseSender>>,
64 next_id: AtomicU64,
66}
67
68impl PendingRequests {
69 #[must_use]
71 pub fn new() -> Self {
72 Self {
73 pending: Mutex::new(HashMap::new()),
74 next_id: AtomicU64::new(1_000_000),
76 }
77 }
78
79 #[allow(clippy::cast_possible_wrap)]
81 pub fn next_request_id(&self) -> RequestId {
82 let id = self.next_id.fetch_add(1, Ordering::SeqCst);
83 RequestId::Number(id as i64)
84 }
85
86 pub fn register(&self, id: RequestId) -> ResponseReceiver {
88 let (tx, rx) = std::sync::mpsc::channel();
89 let mut pending = self.pending.lock().unwrap();
90 pending.insert(id, tx);
91 rx
92 }
93
94 pub fn route_response(&self, response: &JsonRpcResponse) -> bool {
98 let Some(ref id) = response.id else {
99 return false;
100 };
101
102 let sender = {
103 let mut pending = self.pending.lock().unwrap();
104 pending.remove(id)
105 };
106
107 if let Some(sender) = sender {
108 let result = if let Some(ref error) = response.error {
109 Err(error.clone())
110 } else {
111 Ok(response.result.clone().unwrap_or(serde_json::Value::Null))
112 };
113 let _ = sender.send(result);
115 true
116 } else {
117 false
118 }
119 }
120
121 pub fn remove(&self, id: &RequestId) {
123 let mut pending = self.pending.lock().unwrap();
124 pending.remove(id);
125 }
126
127 pub fn cancel_all(&self) {
129 let mut pending = self.pending.lock().unwrap();
130 for (_, sender) in pending.drain() {
131 let _ = sender.send(Err(JsonRpcError {
132 code: McpErrorCode::InternalError.into(),
133 message: "Connection closed".to_string(),
134 data: None,
135 }));
136 }
137 }
138}
139
140impl Default for PendingRequests {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146pub type TransportSendFn = Arc<dyn Fn(&JsonRpcMessage) -> Result<(), String> + Send + Sync>;
152
153#[derive(Clone)]
158pub struct RequestSender {
159 pending: Arc<PendingRequests>,
161 send_fn: TransportSendFn,
163}
164
165impl RequestSender {
166 pub fn new(pending: Arc<PendingRequests>, send_fn: TransportSendFn) -> Self {
168 Self { pending, send_fn }
169 }
170
171 pub fn send_request<T: serde::de::DeserializeOwned>(
181 &self,
182 _cx: &Cx,
183 method: &str,
184 params: serde_json::Value,
185 ) -> McpResult<T> {
186 let id = self.pending.next_request_id();
187 let receiver = self.pending.register(id.clone());
188
189 let request = JsonRpcRequest::new(method.to_string(), Some(params), id.clone());
190 let message = JsonRpcMessage::Request(request);
191
192 if let Err(e) = (self.send_fn)(&message) {
194 self.pending.remove(&id);
195 return Err(McpError::internal_error(format!(
196 "Failed to send request: {}",
197 e
198 )));
199 }
200
201 match receiver.recv() {
204 Ok(Ok(value)) => serde_json::from_value(value)
205 .map_err(|e| McpError::internal_error(format!("Failed to parse response: {}", e))),
206 Ok(Err(error)) => Err(McpError::new(McpErrorCode::from(error.code), error.message)),
207 Err(_) => Err(McpError::internal_error(
208 "Response channel closed unexpectedly",
209 )),
210 }
211 }
212}
213
214impl std::fmt::Debug for RequestSender {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 f.debug_struct("RequestSender")
217 .field("pending", &self.pending)
218 .finish_non_exhaustive()
219 }
220}
221
222#[derive(Clone)]
228pub struct TransportSamplingSender {
229 sender: RequestSender,
230}
231
232impl TransportSamplingSender {
233 pub fn new(sender: RequestSender) -> Self {
235 Self { sender }
236 }
237}
238
239impl SamplingSender for TransportSamplingSender {
240 fn create_message(
241 &self,
242 request: SamplingRequest,
243 ) -> std::pin::Pin<Box<dyn std::future::Future<Output = McpResult<SamplingResponse>> + Send + '_>>
244 {
245 Box::pin(async move {
246 let params = fastmcp_protocol::CreateMessageParams {
248 messages: request
249 .messages
250 .into_iter()
251 .map(|m| fastmcp_protocol::SamplingMessage {
252 role: match m.role {
253 SamplingRole::User => fastmcp_protocol::Role::User,
254 SamplingRole::Assistant => fastmcp_protocol::Role::Assistant,
255 },
256 content: fastmcp_protocol::SamplingContent::Text { text: m.text },
257 })
258 .collect(),
259 max_tokens: request.max_tokens,
260 system_prompt: request.system_prompt,
261 temperature: request.temperature,
262 stop_sequences: request.stop_sequences,
263 model_preferences: if request.model_hints.is_empty() {
264 None
265 } else {
266 Some(fastmcp_protocol::ModelPreferences {
267 hints: request
268 .model_hints
269 .into_iter()
270 .map(|name| fastmcp_protocol::ModelHint { name: Some(name) })
271 .collect(),
272 ..Default::default()
273 })
274 },
275 include_context: None,
276 meta: None,
277 };
278
279 let params_value = serde_json::to_value(¶ms)
280 .map_err(|e| McpError::internal_error(format!("Failed to serialize: {}", e)))?;
281
282 let cx = Cx::for_testing();
284
285 let result: fastmcp_protocol::CreateMessageResult =
286 self.sender
287 .send_request(&cx, "sampling/createMessage", params_value)?;
288
289 Ok(SamplingResponse {
290 text: match result.content {
291 fastmcp_protocol::SamplingContent::Text { text } => text,
292 fastmcp_protocol::SamplingContent::Image { data, mime_type } => {
293 format!("[image: {} bytes, type: {}]", data.len(), mime_type)
294 }
295 },
296 model: result.model,
297 stop_reason: match result.stop_reason {
298 fastmcp_protocol::StopReason::EndTurn => SamplingStopReason::EndTurn,
299 fastmcp_protocol::StopReason::StopSequence => SamplingStopReason::StopSequence,
300 fastmcp_protocol::StopReason::MaxTokens => SamplingStopReason::MaxTokens,
301 },
302 })
303 })
304 }
305}
306
307#[derive(Clone)]
313pub struct TransportElicitationSender {
314 sender: RequestSender,
315}
316
317impl TransportElicitationSender {
318 pub fn new(sender: RequestSender) -> Self {
320 Self { sender }
321 }
322}
323
324impl ElicitationSender for TransportElicitationSender {
325 fn elicit(
326 &self,
327 request: ElicitationRequest,
328 ) -> std::pin::Pin<
329 Box<dyn std::future::Future<Output = McpResult<ElicitationResponse>> + Send + '_>,
330 > {
331 Box::pin(async move {
332 let params_value = match request.mode {
333 ElicitationMode::Form => {
334 let params = fastmcp_protocol::ElicitRequestFormParams {
335 mode: fastmcp_protocol::ElicitMode::Form,
336 message: request.message.clone(),
337 requested_schema: request.schema.unwrap_or(serde_json::json!({})),
338 };
339 serde_json::to_value(¶ms).map_err(|e| {
340 McpError::internal_error(format!("Failed to serialize: {}", e))
341 })?
342 }
343 ElicitationMode::Url => {
344 let params = fastmcp_protocol::ElicitRequestUrlParams {
345 mode: fastmcp_protocol::ElicitMode::Url,
346 message: request.message.clone(),
347 url: request.url.unwrap_or_default(),
348 elicitation_id: request.elicitation_id.unwrap_or_default(),
349 };
350 serde_json::to_value(¶ms).map_err(|e| {
351 McpError::internal_error(format!("Failed to serialize: {}", e))
352 })?
353 }
354 };
355
356 let cx = Cx::for_testing();
358
359 let result: fastmcp_protocol::ElicitResult =
360 self.sender
361 .send_request(&cx, "elicitation/elicit", params_value)?;
362
363 let content = result.content.map(|content_map| {
365 let mut map = std::collections::HashMap::new();
366 for (key, value) in content_map {
367 let json_value = match value {
368 fastmcp_protocol::ElicitContentValue::Null => serde_json::Value::Null,
369 fastmcp_protocol::ElicitContentValue::Bool(b) => serde_json::Value::Bool(b),
370 fastmcp_protocol::ElicitContentValue::Int(i) => {
371 serde_json::Value::Number(i.into())
372 }
373 fastmcp_protocol::ElicitContentValue::Float(f) => {
374 serde_json::Number::from_f64(f)
375 .map(serde_json::Value::Number)
376 .unwrap_or(serde_json::Value::Null)
377 }
378 fastmcp_protocol::ElicitContentValue::String(s) => {
379 serde_json::Value::String(s)
380 }
381 fastmcp_protocol::ElicitContentValue::StringArray(arr) => {
382 serde_json::Value::Array(
383 arr.into_iter().map(serde_json::Value::String).collect(),
384 )
385 }
386 };
387 map.insert(key, json_value);
388 }
389 map
390 });
391
392 Ok(ElicitationResponse {
393 action: match result.action {
394 fastmcp_protocol::ElicitAction::Accept => ElicitationAction::Accept,
395 fastmcp_protocol::ElicitAction::Decline => ElicitationAction::Decline,
396 fastmcp_protocol::ElicitAction::Cancel => ElicitationAction::Cancel,
397 },
398 content,
399 })
400 })
401 }
402}
403
404#[derive(Clone)]
410pub struct TransportRootsProvider {
411 sender: RequestSender,
412}
413
414impl TransportRootsProvider {
415 pub fn new(sender: RequestSender) -> Self {
417 Self { sender }
418 }
419
420 pub fn list_roots(&self) -> McpResult<Vec<fastmcp_protocol::Root>> {
422 let cx = Cx::for_testing();
423 let result: fastmcp_protocol::ListRootsResult =
424 self.sender
425 .send_request(&cx, "roots/list", serde_json::json!({}))?;
426 Ok(result.roots)
427 }
428}
429
#[cfg(test)]
mod tests {
    use super::*;

    /// A success response reaches the receiver registered under the same id.
    #[test]
    fn test_pending_requests_register_and_route() {
        let table = PendingRequests::new();
        let request_id = table.next_request_id();
        let rx = table.register(request_id.clone());

        let expected = serde_json::json!({"result": "ok"});
        let response = JsonRpcResponse::success(request_id, expected.clone());
        assert!(table.route_response(&response));

        let delivered = rx.recv().unwrap();
        assert!(delivered.is_ok());
        assert_eq!(delivered.unwrap(), expected);
    }

    /// An error response is delivered as `Err` carrying the peer's message.
    #[test]
    fn test_pending_requests_error_response() {
        let table = PendingRequests::new();
        let request_id = table.next_request_id();
        let rx = table.register(request_id.clone());

        let rpc_error = JsonRpcError {
            code: -32600,
            message: "Invalid request".to_string(),
            data: None,
        };
        let response = JsonRpcResponse::error(Some(request_id), rpc_error);
        assert!(table.route_response(&response));

        let delivered = rx.recv().unwrap();
        assert!(delivered.is_err());
        assert_eq!(delivered.unwrap_err().message, "Invalid request");
    }

    /// Cancelling fails every outstanding request.
    #[test]
    fn test_pending_requests_cancel_all() {
        let table = PendingRequests::new();
        let first_id = table.next_request_id();
        let second_id = table.next_request_id();
        let receivers = [table.register(first_id), table.register(second_id)];

        table.cancel_all();

        for rx in receivers {
            assert!(rx.recv().unwrap().is_err());
        }
    }

    /// A response whose id was never registered is not routed.
    #[test]
    fn test_route_unknown_response() {
        let table = PendingRequests::new();
        let stray = JsonRpcResponse::success(
            RequestId::Number(999999),
            serde_json::json!({"result": "ok"}),
        );
        assert!(!table.route_response(&stray));
    }
}