1use std::sync::mpsc::{self, Receiver, RecvTimeoutError, Sender};
58use std::time::Duration;
59
60use asupersync::Cx;
61use fastmcp_protocol::JsonRpcMessage;
62
63use crate::{Codec, Transport, TransportError};
64
/// Default time `recv` blocks on the channel before re-checking for cancellation (50 ms).
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_micros(50_000);
67
/// One endpoint of an in-memory, channel-backed [`Transport`].
///
/// Endpoints are created in connected pairs (see [`create_memory_transport_pair`]):
/// each endpoint's `sender` feeds the other endpoint's `receiver`.
pub struct MemoryTransport {
    /// Outgoing half: messages sent here are delivered to the peer's `receiver`.
    sender: Sender<JsonRpcMessage>,
    /// Incoming half: yields messages sent by the peer.
    receiver: Receiver<JsonRpcMessage>,
    /// Codec instance created in `new`.
    /// NOTE(review): not read by any code visible in this chunk — confirm it is still needed.
    codec: Codec,
    /// Local closed flag. Set by `close`, and by `recv` when the peer disconnects;
    /// once set, `send` and `recv` return `TransportError::Closed`.
    closed: bool,
    /// How long each `recv_timeout` call blocks before `recv` re-checks for cancellation.
    poll_interval: Duration,
}
97
98impl std::fmt::Debug for MemoryTransport {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct("MemoryTransport")
101 .field("closed", &self.closed)
102 .field("poll_interval", &self.poll_interval)
103 .finish_non_exhaustive()
104 }
105}
106
107impl MemoryTransport {
108 fn new(sender: Sender<JsonRpcMessage>, receiver: Receiver<JsonRpcMessage>) -> Self {
113 Self {
114 sender,
115 receiver,
116 codec: Codec::new(),
117 closed: false,
118 poll_interval: DEFAULT_POLL_INTERVAL,
119 }
120 }
121
122 #[must_use]
127 pub fn with_poll_interval(mut self, interval: Duration) -> Self {
128 self.poll_interval = interval;
129 self
130 }
131
132 #[must_use]
134 pub fn is_closed(&self) -> bool {
135 self.closed
136 }
137}
138
139impl Transport for MemoryTransport {
140 fn send(&mut self, cx: &Cx, message: &JsonRpcMessage) -> Result<(), TransportError> {
141 if cx.is_cancel_requested() {
143 return Err(TransportError::Cancelled);
144 }
145
146 if self.closed {
147 return Err(TransportError::Closed);
148 }
149
150 self.sender
152 .send(message.clone())
153 .map_err(|_| TransportError::Closed)
154 }
155
156 fn recv(&mut self, cx: &Cx) -> Result<JsonRpcMessage, TransportError> {
157 if cx.is_cancel_requested() {
159 return Err(TransportError::Cancelled);
160 }
161
162 if self.closed {
163 return Err(TransportError::Closed);
164 }
165
166 loop {
168 match self.receiver.recv_timeout(self.poll_interval) {
169 Ok(message) => return Ok(message),
170 Err(RecvTimeoutError::Timeout) => {
171 if cx.is_cancel_requested() {
173 return Err(TransportError::Cancelled);
174 }
175 }
177 Err(RecvTimeoutError::Disconnected) => {
178 self.closed = true;
179 return Err(TransportError::Closed);
180 }
181 }
182 }
183 }
184
185 fn close(&mut self) -> Result<(), TransportError> {
186 self.closed = true;
187 Ok(())
189 }
190}
191
192#[must_use]
226pub fn create_memory_transport_pair() -> (MemoryTransport, MemoryTransport) {
227 create_memory_transport_pair_with_capacity(64)
228}
229
230#[must_use]
246pub fn create_memory_transport_pair_with_capacity(
247 _capacity: usize,
248) -> (MemoryTransport, MemoryTransport) {
249 let (client_to_server_tx, client_to_server_rx) = mpsc::channel();
253 let (server_to_client_tx, server_to_client_rx) = mpsc::channel();
254
255 let client = MemoryTransport::new(client_to_server_tx, server_to_client_rx);
256 let server = MemoryTransport::new(server_to_client_tx, client_to_server_rx);
257
258 (client, server)
259}
260
/// Builder for a connected pair of [`MemoryTransport`]s with non-default settings.
#[derive(Debug, Clone)]
pub struct MemoryTransportBuilder {
    /// Poll interval applied to both endpoints produced by `build`.
    poll_interval: Duration,
}
277
278impl Default for MemoryTransportBuilder {
279 fn default() -> Self {
280 Self::new()
281 }
282}
283
284impl MemoryTransportBuilder {
285 #[must_use]
287 pub fn new() -> Self {
288 Self {
289 poll_interval: DEFAULT_POLL_INTERVAL,
290 }
291 }
292
293 #[must_use]
295 pub fn poll_interval(mut self, interval: Duration) -> Self {
296 self.poll_interval = interval;
297 self
298 }
299
300 #[must_use]
302 pub fn build(self) -> (MemoryTransport, MemoryTransport) {
303 let (mut client, mut server) = create_memory_transport_pair();
304 client.poll_interval = self.poll_interval;
305 server.poll_interval = self.poll_interval;
306 (client, server)
307 }
308}
309
// Unit tests for the in-memory transport: message delivery, cancellation,
// closing semantics, cross-thread use, and poll-interval configuration.
#[cfg(test)]
mod tests {
    use super::*;
    use fastmcp_protocol::{JsonRpcRequest, JsonRpcResponse, RequestId};
    use std::thread;

    // A request sent by the client arrives at the server with its method and id intact.
    #[test]
    fn test_basic_send_receive() {
        let (mut client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        let request = JsonRpcRequest::new("test/method", None, 1i64);
        client.send_request(&cx, &request).unwrap();

        let msg = server.recv(&cx).unwrap();
        match msg {
            JsonRpcMessage::Request(req) => {
                assert_eq!(req.method, "test/method");
                assert_eq!(req.id, Some(RequestId::Number(1)));
            }
            _ => panic!("Expected request"),
        }
    }

    // Request flows client -> server and the response flows back server -> client.
    #[test]
    fn test_bidirectional_communication() {
        let (mut client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        let request = JsonRpcRequest::new("ping", None, 1i64);
        client.send_request(&cx, &request).unwrap();

        let _msg = server.recv(&cx).unwrap();
        let response =
            JsonRpcResponse::success(RequestId::Number(1), serde_json::json!({"pong": true}));
        server.send_response(&cx, &response).unwrap();

        let msg = client.recv(&cx).unwrap();
        match msg {
            JsonRpcMessage::Response(resp) => {
                assert!(resp.result.is_some());
            }
            _ => panic!("Expected response"),
        }
    }

    // Several queued messages are received in the order they were sent.
    #[test]
    fn test_multiple_messages() {
        let (mut client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        for i in 1..=5 {
            let request = JsonRpcRequest::new(format!("method_{i}"), None, i as i64);
            client.send_request(&cx, &request).unwrap();
        }

        for i in 1..=5 {
            let msg = server.recv(&cx).unwrap();
            match msg {
                JsonRpcMessage::Request(req) => {
                    assert_eq!(req.method, format!("method_{i}"));
                }
                _ => panic!("Expected request"),
            }
        }
    }

    // A cancellation requested before `recv` makes it return `Cancelled`
    // immediately, even though no message is queued.
    #[test]
    fn test_cancellation_on_recv() {
        let (client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        cx.set_cancel_requested(true);

        let result = server.recv(&cx);
        assert!(matches!(result, Err(TransportError::Cancelled)));

        // The client is otherwise unused; drop it explicitly only after the assertion.
        drop(client);
    }

    // A cancellation requested before `send` makes it return `Cancelled`.
    #[test]
    fn test_cancellation_on_send() {
        let (mut client, _server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        cx.set_cancel_requested(true);

        let request = JsonRpcRequest::new("test", None, 1i64);
        let result = client.send_request(&cx, &request);
        assert!(matches!(result, Err(TransportError::Cancelled)));
    }

    // Once the client is closed and dropped, its sender is gone, so the
    // server's `recv` observes the disconnected channel and reports `Closed`.
    #[test]
    fn test_close_signals_disconnection() {
        let (mut client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        client.close().unwrap();
        drop(client);

        let result = server.recv(&cx);
        assert!(matches!(result, Err(TransportError::Closed)));
    }

    // Sending on a locally closed transport fails with `Closed`.
    #[test]
    fn test_send_after_close_fails() {
        let (mut client, _server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        client.close().unwrap();

        let request = JsonRpcRequest::new("test", None, 1i64);
        let result = client.send_request(&cx, &request);
        assert!(matches!(result, Err(TransportError::Closed)));
    }

    // Receiving on a locally closed transport fails with `Closed`, even though a
    // message is already queued: the closed flag is checked before the channel.
    #[test]
    fn test_recv_after_close_fails() {
        let (mut client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        let request = JsonRpcRequest::new("test", None, 1i64);
        client.send_request(&cx, &request).unwrap();

        server.close().unwrap();

        let result = server.recv(&cx);
        assert!(matches!(result, Err(TransportError::Closed)));
    }

    // Each endpoint is moved into its own thread; a request/response round trip
    // completes across the threads.
    #[test]
    fn test_cross_thread_communication() {
        let (mut client, mut server) = create_memory_transport_pair();

        let server_handle = thread::spawn(move || {
            let cx = Cx::for_testing();

            let msg = server.recv(&cx).unwrap();
            let request_id = match &msg {
                JsonRpcMessage::Request(req) => req.id.clone().unwrap(),
                _ => panic!("Expected request"),
            };

            // Echo the request id back so the response correlates with the request.
            let response = JsonRpcResponse::success(request_id, serde_json::json!({"ok": true}));
            server.send_response(&cx, &response).unwrap();
        });

        let client_handle = thread::spawn(move || {
            let cx = Cx::for_testing();

            let request = JsonRpcRequest::new("cross_thread_test", None, 42i64);
            client.send_request(&cx, &request).unwrap();

            let msg = client.recv(&cx).unwrap();
            match msg {
                JsonRpcMessage::Response(resp) => {
                    assert!(resp.result.is_some());
                }
                _ => panic!("Expected response"),
            }
        });

        // `join().unwrap()` propagates any assertion panic from either thread.
        server_handle.join().unwrap();
        client_handle.join().unwrap();
    }

    // The builder applies its poll interval to both endpoints.
    #[test]
    fn test_builder_custom_poll_interval() {
        use std::time::Duration;

        let (client, server) = MemoryTransportBuilder::new()
            .poll_interval(Duration::from_millis(5))
            .build();

        assert_eq!(client.poll_interval, Duration::from_millis(5));
        assert_eq!(server.poll_interval, Duration::from_millis(5));
    }

    // `is_closed` reflects only local state: closing the client does not set
    // the server's flag while the server has not called `recv`.
    #[test]
    fn test_is_closed() {
        let (mut client, server) = create_memory_transport_pair();

        assert!(!client.is_closed());
        assert!(!server.is_closed());

        client.close().unwrap();

        assert!(client.is_closed());
        assert!(!server.is_closed());
    }

    // `with_poll_interval` overrides the default interval on a single endpoint.
    #[test]
    fn test_with_poll_interval() {
        use std::time::Duration;

        let (client, _server) = create_memory_transport_pair();
        let client = client.with_poll_interval(Duration::from_millis(100));

        assert_eq!(client.poll_interval, Duration::from_millis(100));
    }
}