1use std::sync::mpsc::{self, Receiver, RecvTimeoutError, Sender};
58use std::time::Duration;
59
60use asupersync::Cx;
61use fastmcp_protocol::JsonRpcMessage;
62
63use crate::{Codec, Transport, TransportError};
64
/// Default upper bound on how long a blocked `recv` waits before it wakes up
/// to re-check the context for cancellation (50 ms).
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_micros(50_000);
67
68pub struct MemoryTransport {
86 sender: Sender<JsonRpcMessage>,
88 receiver: Receiver<JsonRpcMessage>,
90 codec: Codec,
92 closed: bool,
94 poll_interval: Duration,
96}
97
98impl std::fmt::Debug for MemoryTransport {
99 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100 f.debug_struct("MemoryTransport")
101 .field("closed", &self.closed)
102 .field("poll_interval", &self.poll_interval)
103 .finish_non_exhaustive()
104 }
105}
106
107impl MemoryTransport {
108 fn new(sender: Sender<JsonRpcMessage>, receiver: Receiver<JsonRpcMessage>) -> Self {
113 Self {
114 sender,
115 receiver,
116 codec: Codec::new(),
117 closed: false,
118 poll_interval: DEFAULT_POLL_INTERVAL,
119 }
120 }
121
122 #[must_use]
127 pub fn with_poll_interval(mut self, interval: Duration) -> Self {
128 self.poll_interval = interval;
129 self
130 }
131
132 #[must_use]
134 pub fn is_closed(&self) -> bool {
135 self.closed
136 }
137}
138
139impl Transport for MemoryTransport {
140 fn send(&mut self, cx: &Cx, message: &JsonRpcMessage) -> Result<(), TransportError> {
141 if cx.is_cancel_requested() {
143 return Err(TransportError::Cancelled);
144 }
145
146 if self.closed {
147 return Err(TransportError::Closed);
148 }
149
150 self.sender
152 .send(message.clone())
153 .map_err(|_| TransportError::Closed)
154 }
155
156 fn recv(&mut self, cx: &Cx) -> Result<JsonRpcMessage, TransportError> {
157 if cx.is_cancel_requested() {
159 return Err(TransportError::Cancelled);
160 }
161
162 if self.closed {
163 return Err(TransportError::Closed);
164 }
165
166 loop {
168 match self.receiver.recv_timeout(self.poll_interval) {
169 Ok(message) => return Ok(message),
170 Err(RecvTimeoutError::Timeout) => {
171 if cx.is_cancel_requested() {
173 return Err(TransportError::Cancelled);
174 }
175 }
177 Err(RecvTimeoutError::Disconnected) => {
178 self.closed = true;
179 return Err(TransportError::Closed);
180 }
181 }
182 }
183 }
184
185 fn close(&mut self) -> Result<(), TransportError> {
186 self.closed = true;
187 Ok(())
189 }
190}
191
192#[must_use]
226pub fn create_memory_transport_pair() -> (MemoryTransport, MemoryTransport) {
227 create_memory_transport_pair_with_capacity(64)
228}
229
230#[must_use]
246pub fn create_memory_transport_pair_with_capacity(
247 _capacity: usize,
248) -> (MemoryTransport, MemoryTransport) {
249 let (client_to_server_tx, client_to_server_rx) = mpsc::channel();
253 let (server_to_client_tx, server_to_client_rx) = mpsc::channel();
254
255 let client = MemoryTransport::new(client_to_server_tx, server_to_client_rx);
256 let server = MemoryTransport::new(server_to_client_tx, client_to_server_rx);
257
258 (client, server)
259}
260
/// Builder for a `MemoryTransport` pair whose two endpoints share one
/// configuration.
#[derive(Clone, Debug)]
pub struct MemoryTransportBuilder {
    /// Poll interval applied to both endpoints produced by `build`.
    poll_interval: Duration,
}
277
278impl Default for MemoryTransportBuilder {
279 fn default() -> Self {
280 Self::new()
281 }
282}
283
284impl MemoryTransportBuilder {
285 #[must_use]
287 pub fn new() -> Self {
288 Self {
289 poll_interval: DEFAULT_POLL_INTERVAL,
290 }
291 }
292
293 #[must_use]
295 pub fn poll_interval(mut self, interval: Duration) -> Self {
296 self.poll_interval = interval;
297 self
298 }
299
300 #[must_use]
302 pub fn build(self) -> (MemoryTransport, MemoryTransport) {
303 let (mut client, mut server) = create_memory_transport_pair();
304 client.poll_interval = self.poll_interval;
305 server.poll_interval = self.poll_interval;
306 (client, server)
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use fastmcp_protocol::{JsonRpcRequest, JsonRpcResponse, RequestId};
    use std::thread;

    #[test]
    fn test_basic_send_receive() {
        let (mut client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        let request = JsonRpcRequest::new("test/method", None, 1i64);
        client.send_request(&cx, &request).unwrap();

        let JsonRpcMessage::Request(req) = server.recv(&cx).unwrap() else {
            panic!("expected request");
        };
        assert_eq!(req.method, "test/method");
        assert_eq!(req.id, Some(RequestId::Number(1)));
    }

    #[test]
    fn test_bidirectional_communication() {
        let (mut client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        let request = JsonRpcRequest::new("ping", None, 1i64);
        client.send_request(&cx, &request).unwrap();

        let JsonRpcMessage::Request(req) = server.recv(&cx).unwrap() else {
            panic!("expected request");
        };
        assert_eq!(req.method, "ping");

        let response =
            JsonRpcResponse::success(RequestId::Number(1), serde_json::json!({"pong": true}));
        server.send_response(&cx, &response).unwrap();

        let JsonRpcMessage::Response(resp) = client.recv(&cx).unwrap() else {
            panic!("expected response");
        };
        assert!(resp.result.is_some());
    }

    #[test]
    fn test_multiple_messages() {
        let (mut client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        for i in 1..=5i64 {
            let request = JsonRpcRequest::new(format!("method_{i}"), None, i);
            client.send_request(&cx, &request).unwrap();
        }

        for i in 1..=5i64 {
            let JsonRpcMessage::Request(req) = server.recv(&cx).unwrap() else {
                panic!("expected request");
            };
            assert_eq!(req.method, format!("method_{i}"));
        }
    }

    #[test]
    fn test_cancellation_on_recv() {
        // Keep the peer alive so the failure can only come from cancellation.
        let (_client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        cx.set_cancel_requested(true);

        let result = server.recv(&cx);
        assert!(matches!(result, Err(TransportError::Cancelled)));
    }

    #[test]
    fn test_cancellation_on_send() {
        let (mut client, _server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        cx.set_cancel_requested(true);

        let request = JsonRpcRequest::new("test", None, 1i64);
        let result = client.send_request(&cx, &request);
        assert!(matches!(result, Err(TransportError::Cancelled)));
    }

    #[test]
    fn test_close_signals_disconnection() {
        let (mut client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        client.close().unwrap();
        drop(client);

        let result = server.recv(&cx);
        assert!(matches!(result, Err(TransportError::Closed)));
    }

    #[test]
    fn test_send_after_close_fails() {
        let (mut client, _server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        client.close().unwrap();

        let request = JsonRpcRequest::new("test", None, 1i64);
        let result = client.send_request(&cx, &request);
        assert!(matches!(result, Err(TransportError::Closed)));
    }

    #[test]
    fn test_recv_after_close_fails() {
        let (mut client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        // A message is pending, but a closed endpoint must not hand it out.
        let request = JsonRpcRequest::new("test", None, 1i64);
        client.send_request(&cx, &request).unwrap();

        server.close().unwrap();

        let result = server.recv(&cx);
        assert!(matches!(result, Err(TransportError::Closed)));
    }

    #[test]
    fn test_cross_thread_communication() {
        let (mut client, mut server) = create_memory_transport_pair();

        let server_handle = thread::spawn(move || {
            let cx = Cx::for_testing();

            let JsonRpcMessage::Request(req) = server.recv(&cx).unwrap() else {
                panic!("expected request");
            };
            assert_eq!(req.method, "cross_thread_test");
            let request_id = req.id.clone().expect("request should carry an id");

            let response = JsonRpcResponse::success(request_id, serde_json::json!({"ok": true}));
            server.send_response(&cx, &response).unwrap();
        });

        let client_handle = thread::spawn(move || {
            let cx = Cx::for_testing();

            let request = JsonRpcRequest::new("cross_thread_test", None, 42i64);
            client.send_request(&cx, &request).unwrap();

            let JsonRpcMessage::Response(resp) = client.recv(&cx).unwrap() else {
                panic!("expected response");
            };
            assert!(resp.result.is_some());
            assert_eq!(resp.id, Some(RequestId::Number(42)));
        });

        server_handle.join().unwrap();
        client_handle.join().unwrap();
    }

    #[test]
    fn test_builder_custom_poll_interval() {
        let (client, server) = MemoryTransportBuilder::new()
            .poll_interval(Duration::from_millis(5))
            .build();

        assert_eq!(client.poll_interval, Duration::from_millis(5));
        assert_eq!(server.poll_interval, Duration::from_millis(5));
    }

    #[test]
    fn test_is_closed() {
        let (mut client, server) = create_memory_transport_pair();

        assert!(!client.is_closed());
        assert!(!server.is_closed());

        client.close().unwrap();

        assert!(client.is_closed());
        // The peer's flag only changes once it observes the disconnect in `recv`.
        assert!(!server.is_closed());
    }

    #[test]
    fn test_with_poll_interval() {
        let (client, _server) = create_memory_transport_pair();
        let client = client.with_poll_interval(Duration::from_millis(100));

        assert_eq!(client.poll_interval, Duration::from_millis(100));
    }

    #[test]
    fn test_debug_format() {
        let (client, _server) = create_memory_transport_pair();
        let debug = format!("{client:?}");
        assert!(debug.contains("MemoryTransport"));
        assert!(debug.contains("closed: false"));
    }

    #[test]
    fn test_debug_format_closed() {
        let (mut client, _server) = create_memory_transport_pair();
        client.close().unwrap();
        let debug = format!("{client:?}");
        assert!(debug.contains("closed: true"));
    }

    #[test]
    fn test_send_response_and_receive() {
        let (mut client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        let response =
            JsonRpcResponse::success(RequestId::Number(99), serde_json::json!({"val": 42}));
        server.send_response(&cx, &response).unwrap();

        let JsonRpcMessage::Response(resp) = client.recv(&cx).unwrap() else {
            panic!("expected response");
        };
        assert_eq!(resp.id, Some(RequestId::Number(99)));
    }

    #[test]
    fn test_send_to_dropped_peer_fails() {
        let (mut client, server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        drop(server);

        let request = JsonRpcRequest::new("test", None, 1i64);
        let result = client.send_request(&cx, &request);
        assert!(matches!(result, Err(TransportError::Closed)));
    }

    #[test]
    fn test_recv_from_dropped_peer_returns_closed() {
        let (client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        drop(client);

        let result = server.recv(&cx);
        assert!(matches!(result, Err(TransportError::Closed)));
        assert!(server.is_closed());
    }

    #[test]
    fn test_buffered_messages_delivered_before_disconnect() {
        let (mut client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        let request = JsonRpcRequest::new("last_words", None, 7i64);
        client.send_request(&cx, &request).unwrap();
        drop(client);

        // The message sent before the drop must still arrive...
        let JsonRpcMessage::Request(req) = server.recv(&cx).unwrap() else {
            panic!("expected request");
        };
        assert_eq!(req.method, "last_words");
        assert!(!server.is_closed());

        // ...and only then is the disconnect reported.
        assert!(matches!(server.recv(&cx), Err(TransportError::Closed)));
        assert!(server.is_closed());
    }

    #[test]
    fn test_create_pair_with_capacity() {
        let (mut client, mut server) = create_memory_transport_pair_with_capacity(2);
        let cx = Cx::for_testing();

        let request = JsonRpcRequest::new("test", None, 1i64);
        client.send_request(&cx, &request).unwrap();
        let msg = server.recv(&cx).unwrap();
        assert!(matches!(msg, JsonRpcMessage::Request(_)));
    }

    #[test]
    fn test_builder_default() {
        let builder = MemoryTransportBuilder::default();
        let (client, server) = builder.build();
        assert_eq!(client.poll_interval, DEFAULT_POLL_INTERVAL);
        assert_eq!(server.poll_interval, DEFAULT_POLL_INTERVAL);
    }

    #[test]
    fn test_close_is_idempotent() {
        let (mut client, _server) = create_memory_transport_pair();
        client.close().unwrap();
        assert!(client.is_closed());
        client.close().unwrap();
        assert!(client.is_closed());
    }

    #[test]
    fn test_message_ordering() {
        let (mut client, mut server) = create_memory_transport_pair();
        let cx = Cx::for_testing();

        for i in 0..10i64 {
            let request = JsonRpcRequest::new(format!("msg_{i}"), None, i);
            client.send_request(&cx, &request).unwrap();
        }

        for i in 0..10i64 {
            let JsonRpcMessage::Request(req) = server.recv(&cx).unwrap() else {
                panic!("expected request");
            };
            assert_eq!(req.method, format!("msg_{i}"));
        }
    }

    #[test]
    fn test_cancellation_during_poll() {
        let (_client, mut server) = MemoryTransportBuilder::new()
            .poll_interval(Duration::from_millis(5))
            .build();

        let cx = Cx::for_testing();

        // Request cancellation from another thread while `recv` is blocked polling.
        let cx_clone = cx.clone();
        let handle = thread::spawn(move || {
            thread::sleep(Duration::from_millis(20));
            cx_clone.set_cancel_requested(true);
        });

        let result = server.recv(&cx);
        assert!(matches!(result, Err(TransportError::Cancelled)));

        handle.join().unwrap();
    }
}