Skip to main content

fastmcp_transport/
memory.rs

1//! In-memory transport for testing MCP servers without subprocess spawning.
2//!
3//! This module provides a channel-based transport for direct client-server
4//! communication within the same process. Essential for unit testing MCP
5//! servers without network/IO overhead.
6//!
7//! # Overview
8//!
//! The [`MemoryTransport`] uses `std::sync::mpsc` channels to enable bidirectional
//! message passing between client and server. Create a pair using
//! [`create_memory_transport_pair`], which returns connected client and server
//! transports.
13//!
14//! # Example
15//!
16//! ```ignore
//! use fastmcp_transport::memory::create_memory_transport_pair;
//! use fastmcp_transport::Transport;
//! use fastmcp_protocol::JsonRpcRequest;
//! use asupersync::Cx;
//!
//! // Create connected pair (endpoints are `mut` because send/recv take `&mut self`)
//! let (mut client_transport, mut server_transport) = create_memory_transport_pair();
23//!
24//! // Use in separate threads/tasks
25//! // Client sends, server receives (and vice versa)
26//! let cx = Cx::for_testing();
27//! let request = JsonRpcRequest::new("test", None, 1i64);
28//! client_transport.send_request(&cx, &request)?;
29//!
30//! // Server receives the message
31//! let msg = server_transport.recv(&cx)?;
32//! ```
33//!
34//! # Testing Servers
35//!
36//! The primary use case is testing servers without subprocess spawning:
37//!
38//! ```ignore
39//! use fastmcp_transport::memory::{create_memory_transport_pair, MemoryTransport};
40//! use std::thread;
41//!
42//! let (mut client, mut server) = create_memory_transport_pair();
43//!
44//! // Spawn server handler in a thread
45//! let server_handle = thread::spawn(move || {
46//!     // Pass server transport to your server's run loop
47//!     run_server_with_transport(server);
48//! });
49//!
50//! // Use client to test
51//! let cx = Cx::for_testing();
52//! client.send_request(&cx, &init_request)?;
53//! let response = client.recv(&cx)?;
54//! assert!(matches!(response, JsonRpcMessage::Response(_)));
55//! ```
56
57use std::sync::mpsc::{self, Receiver, RecvTimeoutError, Sender};
58use std::time::Duration;
59
60use asupersync::Cx;
61use fastmcp_protocol::JsonRpcMessage;
62
63use crate::{Codec, Transport, TransportError};
64
/// Default interval between cancellation checks while `recv` waits for a message.
///
/// `recv` itself never times out: it keeps waiting until a message arrives, the
/// peer disconnects, or cancellation is requested. This value only bounds how
/// long a cancellation request can go unnoticed while `recv` is blocked.
const DEFAULT_POLL_INTERVAL: Duration = Duration::from_millis(50);
67
68/// In-memory transport using channels for message passing.
69///
70/// This transport enables direct communication between a client and server
71/// without any network or I/O overhead. Messages are passed through
72/// bounded MPSC channels.
73///
74/// # Thread Safety
75///
76/// The transport is `Send` and can be passed to other threads, but it is
77/// not `Sync`. Each endpoint (client/server) should be used from a single
78/// thread at a time.
79///
80/// # Cancellation
81///
82/// Recv operations poll the channel with a timeout, checking for cancellation
83/// between polls. This ensures proper integration with asupersync's
84/// cancellation mechanism.
85pub struct MemoryTransport {
86    /// Channel for sending messages to the peer.
87    sender: Sender<JsonRpcMessage>,
88    /// Channel for receiving messages from the peer.
89    receiver: Receiver<JsonRpcMessage>,
90    /// Codec for validation (not used for serialization in memory transport).
91    codec: Codec,
92    /// Whether the transport has been closed.
93    closed: bool,
94    /// Poll interval for cancellation checks during recv.
95    poll_interval: Duration,
96}
97
98impl std::fmt::Debug for MemoryTransport {
99    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
100        f.debug_struct("MemoryTransport")
101            .field("closed", &self.closed)
102            .field("poll_interval", &self.poll_interval)
103            .finish_non_exhaustive()
104    }
105}
106
107impl MemoryTransport {
108    /// Creates a new memory transport from channel endpoints.
109    ///
110    /// This is an internal constructor. Use [`create_memory_transport_pair`]
111    /// to create a connected pair of transports.
112    fn new(sender: Sender<JsonRpcMessage>, receiver: Receiver<JsonRpcMessage>) -> Self {
113        Self {
114            sender,
115            receiver,
116            codec: Codec::new(),
117            closed: false,
118            poll_interval: DEFAULT_POLL_INTERVAL,
119        }
120    }
121
122    /// Sets the poll interval for cancellation checks during recv.
123    ///
124    /// Lower values provide faster cancellation response but use more CPU.
125    /// Default is 50ms.
126    #[must_use]
127    pub fn with_poll_interval(mut self, interval: Duration) -> Self {
128        self.poll_interval = interval;
129        self
130    }
131
132    /// Returns whether this transport has been closed.
133    #[must_use]
134    pub fn is_closed(&self) -> bool {
135        self.closed
136    }
137}
138
139impl Transport for MemoryTransport {
140    fn send(&mut self, cx: &Cx, message: &JsonRpcMessage) -> Result<(), TransportError> {
141        // Check for cancellation before send
142        if cx.is_cancel_requested() {
143            return Err(TransportError::Cancelled);
144        }
145
146        if self.closed {
147            return Err(TransportError::Closed);
148        }
149
150        // Clone and send the message through the channel
151        self.sender
152            .send(message.clone())
153            .map_err(|_| TransportError::Closed)
154    }
155
156    fn recv(&mut self, cx: &Cx) -> Result<JsonRpcMessage, TransportError> {
157        // Check for cancellation before receive
158        if cx.is_cancel_requested() {
159            return Err(TransportError::Cancelled);
160        }
161
162        if self.closed {
163            return Err(TransportError::Closed);
164        }
165
166        // Poll with timeout to allow cancellation checks
167        loop {
168            match self.receiver.recv_timeout(self.poll_interval) {
169                Ok(message) => return Ok(message),
170                Err(RecvTimeoutError::Timeout) => {
171                    // Check for cancellation between polls
172                    if cx.is_cancel_requested() {
173                        return Err(TransportError::Cancelled);
174                    }
175                    // Continue polling
176                }
177                Err(RecvTimeoutError::Disconnected) => {
178                    self.closed = true;
179                    return Err(TransportError::Closed);
180                }
181            }
182        }
183    }
184
185    fn close(&mut self) -> Result<(), TransportError> {
186        self.closed = true;
187        // Dropping sender will signal disconnection to the peer
188        Ok(())
189    }
190}
191
192/// Creates a connected pair of memory transports.
193///
194/// Returns `(client, server)` transports where:
195/// - Messages sent on `client` are received on `server`
196/// - Messages sent on `server` are received on `client`
197///
198/// # Channel Capacity
199///
200/// Uses bounded channels with a default capacity of 64 messages.
201/// This prevents unbounded memory growth if one side is slower.
202///
203/// # Example
204///
205/// ```
206/// use fastmcp_transport::memory::create_memory_transport_pair;
207/// use fastmcp_transport::Transport;
208/// use fastmcp_protocol::{JsonRpcMessage, JsonRpcRequest};
209/// use asupersync::Cx;
210///
211/// let (mut client, mut server) = create_memory_transport_pair();
212/// let cx = Cx::for_testing();
213///
214/// // Client sends a request
215/// let request = JsonRpcRequest::new("test/method", None, 1i64);
216/// client.send_request(&cx, &request).unwrap();
217///
218/// // Server receives it
219/// let msg = server.recv(&cx).unwrap();
220/// match msg {
221///     JsonRpcMessage::Request(req) => assert_eq!(req.method, "test/method"),
222///     _ => panic!("Expected request"),
223/// }
224/// ```
225#[must_use]
226pub fn create_memory_transport_pair() -> (MemoryTransport, MemoryTransport) {
227    create_memory_transport_pair_with_capacity(64)
228}
229
230/// Creates a connected pair of memory transports with specified channel capacity.
231///
232/// # Arguments
233///
234/// * `capacity` - Maximum number of messages that can be buffered in each direction.
235///   If 0, creates unbounded channels (not recommended for production use).
236///
237/// # Example
238///
239/// ```
240/// use fastmcp_transport::memory::create_memory_transport_pair_with_capacity;
241///
242/// // Small buffer for testing backpressure
243/// let (client, server) = create_memory_transport_pair_with_capacity(4);
244/// ```
245#[must_use]
246pub fn create_memory_transport_pair_with_capacity(
247    _capacity: usize,
248) -> (MemoryTransport, MemoryTransport) {
249    // Note: std::sync::mpsc doesn't have bounded channels, so we use unbounded.
250    // For bounded behavior, users should use the crossbeam crate.
251    // This is a simplification for the initial implementation.
252    let (client_to_server_tx, client_to_server_rx) = mpsc::channel();
253    let (server_to_client_tx, server_to_client_rx) = mpsc::channel();
254
255    let client = MemoryTransport::new(client_to_server_tx, server_to_client_rx);
256    let server = MemoryTransport::new(server_to_client_tx, client_to_server_rx);
257
258    (client, server)
259}
260
/// Configures and creates a connected pair of memory transports.
///
/// Settings chosen on the builder are applied to both the client and the
/// server transport produced by `build`.
///
/// # Example
///
/// ```
/// use fastmcp_transport::memory::MemoryTransportBuilder;
/// use std::time::Duration;
///
/// let (client, server) = MemoryTransportBuilder::new()
///     .poll_interval(Duration::from_millis(10))
///     .build();
/// ```
#[derive(Clone, Debug)]
pub struct MemoryTransportBuilder {
    /// Cancellation poll interval applied to both transports by `build`.
    poll_interval: Duration,
}
277
278impl Default for MemoryTransportBuilder {
279    fn default() -> Self {
280        Self::new()
281    }
282}
283
284impl MemoryTransportBuilder {
285    /// Creates a new builder with default settings.
286    #[must_use]
287    pub fn new() -> Self {
288        Self {
289            poll_interval: DEFAULT_POLL_INTERVAL,
290        }
291    }
292
293    /// Sets the poll interval for cancellation checks during recv.
294    #[must_use]
295    pub fn poll_interval(mut self, interval: Duration) -> Self {
296        self.poll_interval = interval;
297        self
298    }
299
300    /// Builds the transport pair with the configured settings.
301    #[must_use]
302    pub fn build(self) -> (MemoryTransport, MemoryTransport) {
303        let (mut client, mut server) = create_memory_transport_pair();
304        client.poll_interval = self.poll_interval;
305        server.poll_interval = self.poll_interval;
306        (client, server)
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use fastmcp_protocol::{JsonRpcRequest, JsonRpcResponse, RequestId};
    use std::thread;
    use std::time::Duration;

    /// A request sent by the client arrives intact on the server.
    #[test]
    fn test_basic_send_receive() {
        let cx = Cx::for_testing();
        let (mut client, mut server) = create_memory_transport_pair();

        client
            .send_request(&cx, &JsonRpcRequest::new("test/method", None, 1i64))
            .unwrap();

        let req = match server.recv(&cx).unwrap() {
            JsonRpcMessage::Request(req) => req,
            _ => panic!("Expected request"),
        };
        assert_eq!(req.method, "test/method");
        assert_eq!(req.id, Some(RequestId::Number(1)));
    }

    /// The server can answer a request and the client receives the response.
    #[test]
    fn test_bidirectional_communication() {
        let cx = Cx::for_testing();
        let (mut client, mut server) = create_memory_transport_pair();

        client
            .send_request(&cx, &JsonRpcRequest::new("ping", None, 1i64))
            .unwrap();
        let _ping = server.recv(&cx).unwrap();

        let body = serde_json::json!({"pong": true});
        let pong = JsonRpcResponse::success(RequestId::Number(1), body);
        server.send_response(&cx, &pong).unwrap();

        match client.recv(&cx).unwrap() {
            JsonRpcMessage::Response(resp) => assert!(resp.result.is_some()),
            _ => panic!("Expected response"),
        }
    }

    /// Several queued messages are delivered in the order they were sent.
    #[test]
    fn test_multiple_messages() {
        let cx = Cx::for_testing();
        let (mut client, mut server) = create_memory_transport_pair();

        (1..=5).for_each(|n| {
            let request = JsonRpcRequest::new(format!("method_{n}"), None, n as i64);
            client.send_request(&cx, &request).unwrap();
        });

        for expected in (1..=5).map(|n| format!("method_{n}")) {
            match server.recv(&cx).unwrap() {
                JsonRpcMessage::Request(req) => assert_eq!(req.method, expected),
                _ => panic!("Expected request"),
            }
        }
    }

    /// A pending cancellation makes `recv` return immediately.
    #[test]
    fn test_cancellation_on_recv() {
        let cx = Cx::for_testing();
        cx.set_cancel_requested(true);

        // The peer stays alive for the whole test, so the only possible outcome
        // is cancellation rather than disconnection.
        let (client, mut server) = create_memory_transport_pair();
        assert!(matches!(server.recv(&cx), Err(TransportError::Cancelled)));
        drop(client);
    }

    /// A pending cancellation makes `send` fail before anything is queued.
    #[test]
    fn test_cancellation_on_send() {
        let cx = Cx::for_testing();
        cx.set_cancel_requested(true);
        let (mut client, _server) = create_memory_transport_pair();

        let outcome = client.send_request(&cx, &JsonRpcRequest::new("test", None, 1i64));
        assert!(matches!(outcome, Err(TransportError::Cancelled)));
    }

    /// Closing and dropping one side is reported as `Closed` on the other.
    #[test]
    fn test_close_signals_disconnection() {
        let cx = Cx::for_testing();
        let (mut client, mut server) = create_memory_transport_pair();

        client.close().unwrap();
        drop(client);

        assert!(matches!(server.recv(&cx), Err(TransportError::Closed)));
    }

    /// A closed endpoint refuses to send.
    #[test]
    fn test_send_after_close_fails() {
        let cx = Cx::for_testing();
        let (mut client, _server) = create_memory_transport_pair();
        client.close().unwrap();

        let outcome = client.send_request(&cx, &JsonRpcRequest::new("test", None, 1i64));
        assert!(matches!(outcome, Err(TransportError::Closed)));
    }

    /// A closed endpoint refuses to receive, even with a message already buffered.
    #[test]
    fn test_recv_after_close_fails() {
        let cx = Cx::for_testing();
        let (mut client, mut server) = create_memory_transport_pair();

        client
            .send_request(&cx, &JsonRpcRequest::new("test", None, 1i64))
            .unwrap();
        server.close().unwrap();

        assert!(matches!(server.recv(&cx), Err(TransportError::Closed)));
    }

    /// Endpoints can be moved to separate threads and exchange a request/response.
    #[test]
    fn test_cross_thread_communication() {
        let (mut client, mut server) = create_memory_transport_pair();

        let server_thread = thread::spawn(move || {
            let cx = Cx::for_testing();
            let id = match server.recv(&cx).unwrap() {
                JsonRpcMessage::Request(req) => req.id.clone().unwrap(),
                _ => panic!("Expected request"),
            };
            let reply = JsonRpcResponse::success(id, serde_json::json!({"ok": true}));
            server.send_response(&cx, &reply).unwrap();
        });

        let client_thread = thread::spawn(move || {
            let cx = Cx::for_testing();
            client
                .send_request(&cx, &JsonRpcRequest::new("cross_thread_test", None, 42i64))
                .unwrap();
            match client.recv(&cx).unwrap() {
                JsonRpcMessage::Response(resp) => assert!(resp.result.is_some()),
                _ => panic!("Expected response"),
            }
        });

        server_thread.join().unwrap();
        client_thread.join().unwrap();
    }

    /// The builder applies its poll interval to both endpoints.
    #[test]
    fn test_builder_custom_poll_interval() {
        let five_ms = Duration::from_millis(5);
        let (client, server) = MemoryTransportBuilder::new().poll_interval(five_ms).build();

        assert_eq!(client.poll_interval, five_ms);
        assert_eq!(server.poll_interval, five_ms);
    }

    /// Closing one side does not flip the other side's flag until it observes it.
    #[test]
    fn test_is_closed() {
        let (mut client, server) = create_memory_transport_pair();
        assert!(!client.is_closed());
        assert!(!server.is_closed());

        client.close().unwrap();

        assert!(client.is_closed());
        // The server only learns about the close when a recv fails.
        assert!(!server.is_closed());
    }

    /// `with_poll_interval` overrides the default interval.
    #[test]
    fn test_with_poll_interval() {
        let wanted = Duration::from_millis(100);
        let (client, _server) = create_memory_transport_pair();

        let client = client.with_poll_interval(wanted);
        assert_eq!(client.poll_interval, wanted);
    }
}