memory_socket/
lib.rs

1//! Provides an in-memory socket abstraction.
2//!
3//! The `memory-socket` crate provides the [`MemoryListener`] and [`MemorySocket`] types which can
4//! be thought of as in-memory versions of the standard library `TcpListener` and `TcpStream`
5//! types.
6//!
7//! ## Feature flags
8//!
9//! - `async`: Adds async support for [`MemorySocket`] and [`MemoryListener`]
10//!
11//! [`MemoryListener`]: struct.MemoryListener.html
12//! [`MemorySocket`]: struct.MemorySocket.html
13
14use bytes::{buf::BufExt, Buf, Bytes, BytesMut};
15use flume::{Receiver, Sender};
16use once_cell::sync::Lazy;
17use std::{
18    collections::HashMap,
19    io::{ErrorKind, Read, Result, Write},
20    num::NonZeroU16,
21    sync::Mutex,
22};
23
24#[cfg(feature = "async")]
25mod r#async;
26
27#[cfg(feature = "async")]
28pub use r#async::IncomingStream;
29
30static SWITCHBOARD: Lazy<Mutex<SwitchBoard>> =
31    Lazy::new(|| Mutex::new(SwitchBoard(HashMap::default(), 1)));
32
33struct SwitchBoard(HashMap<NonZeroU16, Sender<MemorySocket>>, u16);
34
35/// An in-memory socket server, listening for connections.
36///
37/// After creating a `MemoryListener` by [`bind`]ing it to a socket address, it listens
38/// for incoming connections. These can be accepted by calling [`accept`] or by iterating over
39/// iterating over the [`Incoming`] iterator returned by [`incoming`][`MemoryListener::incoming`].
40///
41/// The socket will be closed when the value is dropped.
42///
43/// [`accept`]: #method.accept
44/// [`bind`]: #method.bind
45/// [`Incoming`]: struct.Incoming.html
46/// [`MemoryListener::incoming`]: #method.incoming
47///
48/// # Examples
49///
50/// ```no_run
51/// use std::io::{Read, Result, Write};
52///
53/// use memory_socket::{MemoryListener, MemorySocket};
54///
55/// fn write_stormlight(mut stream: MemorySocket) -> Result<()> {
56///     let msg = b"The most important step a person can take is always the next one.";
57///     stream.write_all(msg)?;
58///     stream.flush()
59/// }
60///
61/// fn main() -> Result<()> {
62///     let mut listener = MemoryListener::bind(16)?;
63///
64///     // accept connections and process them serially
65///     for stream in listener.incoming() {
66///         write_stormlight(stream?)?;
67///     }
68///     Ok(())
69/// }
70/// ```
71pub struct MemoryListener {
72    incoming: Receiver<MemorySocket>,
73    port: NonZeroU16,
74}
75
76impl Drop for MemoryListener {
77    fn drop(&mut self) {
78        let mut switchboard = (&*SWITCHBOARD).lock().unwrap();
79        // Remove the Sending side of the channel in the switchboard when
80        // MemoryListener is dropped
81        switchboard.0.remove(&self.port);
82    }
83}
84
85impl MemoryListener {
86    /// Creates a new `MemoryListener` which will be bound to the specified
87    /// port.
88    ///
89    /// The returned listener is ready for accepting connections.
90    ///
91    /// Binding with a port number of `0` will request that a port be assigned
92    /// to this listener. The port allocated can be queried via the
93    /// [`local_addr`] method.
94    ///
95    /// [`local_addr`]: #method.local_addr
96    ///
97    /// # Examples
98    ///
99    /// Create a MemoryListener bound to port 16:
100    ///
101    /// ```no_run
102    /// use memory_socket::MemoryListener;
103    ///
104    /// # fn main () -> ::std::io::Result<()> {
105    /// let listener = MemoryListener::bind(16)?;
106    /// # Ok(())}
107    /// ```
108    pub fn bind(port: u16) -> Result<Self> {
109        let mut switchboard = (&*SWITCHBOARD).lock().unwrap();
110
111        // Get the port we should bind to.  If 0 was given, use a random port
112        let port = if let Some(port) = NonZeroU16::new(port) {
113            if switchboard.0.contains_key(&port) {
114                return Err(ErrorKind::AddrInUse.into());
115            }
116
117            port
118        } else {
119            loop {
120                let port = NonZeroU16::new(switchboard.1).unwrap_or_else(|| unreachable!());
121
122                // The switchboard is full and all ports are in use
123                if switchboard.0.len() == (std::u16::MAX - 1) as usize {
124                    return Err(ErrorKind::AddrInUse.into());
125                }
126
127                // Instead of overflowing to 0, resume searching at port 1 since port 0 isn't a
128                // valid port to bind to.
129                if switchboard.1 == std::u16::MAX {
130                    switchboard.1 = 1;
131                } else {
132                    switchboard.1 += 1;
133                }
134
135                if !switchboard.0.contains_key(&port) {
136                    break port;
137                }
138            }
139        };
140
141        let (sender, receiver) = flume::unbounded();
142        switchboard.0.insert(port, sender);
143
144        Ok(Self {
145            incoming: receiver,
146            port,
147        })
148    }
149
150    /// Returns the local address that this listener is bound to.
151    ///
152    /// This can be useful, for example, when binding to port `0` to figure out
153    /// which port was actually bound.
154    ///
155    /// # Examples
156    ///
157    /// ```
158    /// use memory_socket::MemoryListener;
159    ///
160    /// # fn main () -> ::std::io::Result<()> {
161    /// let listener = MemoryListener::bind(16)?;
162    ///
163    /// assert_eq!(listener.local_addr(), 16);
164    /// # Ok(())}
165    /// ```
166    pub fn local_addr(&self) -> u16 {
167        self.port.get()
168    }
169
170    /// Returns an iterator over the connections being received on this
171    /// listener.
172    ///
173    /// The returned iterator will never return `None`. Iterating over
174    /// it is equivalent to calling [`accept`] in a loop.
175    ///
176    /// [`accept`]: #method.accept
177    ///
178    /// # Examples
179    ///
180    /// ```no_run
181    /// use memory_socket::MemoryListener;
182    /// use std::io::{Read, Write};
183    ///
184    /// let mut listener = MemoryListener::bind(80).unwrap();
185    ///
186    /// for stream in listener.incoming() {
187    ///     match stream {
188    ///         Ok(stream) => {
189    ///             println!("new client!");
190    ///         }
191    ///         Err(e) => { /* connection failed */ }
192    ///     }
193    /// }
194    /// ```
195    pub fn incoming(&self) -> Incoming<'_> {
196        Incoming { inner: self }
197    }
198
199    /// Accept a new incoming connection from this listener.
200    ///
201    /// This function will block the calling thread until a new connection
202    /// is established. When established, the corresponding [`MemorySocket`]
203    /// will be returned.
204    ///
205    /// [`MemorySocket`]: struct.MemorySocket.html
206    ///
207    /// # Examples
208    ///
209    /// ```no_run
210    /// use std::net::TcpListener;
211    /// use memory_socket::MemoryListener;
212    ///
213    /// let mut listener = MemoryListener::bind(8080).unwrap();
214    /// match listener.accept() {
215    ///     Ok(_socket) => println!("new client!"),
216    ///     Err(e) => println!("couldn't get client: {:?}", e),
217    /// }
218    /// ```
219    pub fn accept(&self) -> Result<MemorySocket> {
220        self.incoming.iter().next().ok_or_else(|| unreachable!())
221    }
222}
223
224/// An iterator that infinitely [`accept`]s connections on a [`MemoryListener`].
225///
226/// This `struct` is created by the [`incoming`] method on [`MemoryListener`].
227/// See its documentation for more info.
228///
229/// [`accept`]: struct.MemoryListener.html#method.accept
230/// [`incoming`]: struct.MemoryListener.html#method.incoming
231/// [`MemoryListener`]: struct.MemoryListener.html
232pub struct Incoming<'a> {
233    inner: &'a MemoryListener,
234}
235
236impl<'a> Iterator for Incoming<'a> {
237    type Item = Result<MemorySocket>;
238
239    fn next(&mut self) -> Option<Self::Item> {
240        Some(self.inner.accept())
241    }
242}
243
244/// An in-memory stream between two local sockets.
245///
246/// A `MemorySocket` can either be created by connecting to an endpoint, via the
247/// [`connect`] method, or by [accepting] a connection from a [listener].
248/// It can be read or written to using the `Read` and `Write` traits.
249///
250/// # Examples
251///
252/// ```
253/// use std::io::{Read, Result, Write};
254/// use memory_socket::MemorySocket;
255///
256/// # fn main() -> Result<()> {
257/// let (mut socket_a, mut socket_b) = MemorySocket::new_pair();
258///
259/// socket_a.write_all(b"stormlight")?;
260/// socket_a.flush()?;
261///
262/// let mut buf = [0; 10];
263/// socket_b.read_exact(&mut buf)?;
264/// assert_eq!(&buf, b"stormlight");
265///
266/// # Ok(())}
267/// ```
268///
269/// [`connect`]: struct.MemorySocket.html#method.connect
270/// [accepting]: struct.MemoryListener.html#method.accept
271/// [listener]: struct.MemoryListener.html
272pub struct MemorySocket {
273    incoming: Receiver<Bytes>,
274    outgoing: Sender<Bytes>,
275    write_buffer: BytesMut,
276    current_buffer: Option<Bytes>,
277    seen_eof: bool,
278}
279
280impl MemorySocket {
281    fn new(incoming: Receiver<Bytes>, outgoing: Sender<Bytes>) -> Self {
282        Self {
283            incoming,
284            outgoing,
285            write_buffer: BytesMut::new(),
286            current_buffer: None,
287            seen_eof: false,
288        }
289    }
290
291    /// Construct both sides of an in-memory socket.
292    ///
293    /// # Examples
294    ///
295    /// ```
296    /// use memory_socket::MemorySocket;
297    ///
298    /// let (socket_a, socket_b) = MemorySocket::new_pair();
299    /// ```
300    pub fn new_pair() -> (Self, Self) {
301        let (a_tx, a_rx) = flume::unbounded();
302        let (b_tx, b_rx) = flume::unbounded();
303        let a = Self::new(a_rx, b_tx);
304        let b = Self::new(b_rx, a_tx);
305
306        (a, b)
307    }
308
309    /// Create a new in-memory Socket connected to the specified port.
310    ///
311    /// This function will create a new MemorySocket socket and attempt to connect it to
312    /// the `port` provided.
313    ///
314    /// # Examples
315    ///
316    /// ```
317    /// use memory_socket::MemorySocket;
318    ///
319    /// # fn main () -> ::std::io::Result<()> {
320    /// # let _listener = memory_socket::MemoryListener::bind(16)?;
321    /// let socket = MemorySocket::connect(16)?;
322    /// # Ok(())}
323    /// ```
324    pub fn connect(port: u16) -> Result<MemorySocket> {
325        let mut switchboard = (&*SWITCHBOARD).lock().unwrap();
326
327        // Find port to connect to
328        let port = NonZeroU16::new(port).ok_or_else(|| ErrorKind::AddrNotAvailable)?;
329
330        let sender = switchboard
331            .0
332            .get_mut(&port)
333            .ok_or_else(|| ErrorKind::AddrNotAvailable)?;
334
335        let (socket_a, socket_b) = Self::new_pair();
336
337        // Send the socket to the listener
338        sender
339            .send(socket_a)
340            .map_err(|_| ErrorKind::AddrNotAvailable)?;
341
342        Ok(socket_b)
343    }
344}
345
346impl Read for MemorySocket {
347    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
348        let mut bytes_read = 0;
349
350        loop {
351            // If we've already filled up the buffer then we can return
352            if bytes_read == buf.len() {
353                return Ok(bytes_read);
354            }
355
356            match self.current_buffer {
357                // We still have data to copy to `buf`
358                Some(ref mut current_buffer) if current_buffer.has_remaining() => {
359                    let bytes_to_read =
360                        ::std::cmp::min(buf.len() - bytes_read, current_buffer.remaining());
361                    debug_assert!(bytes_to_read > 0);
362
363                    current_buffer
364                        .take(bytes_to_read)
365                        .copy_to_slice(&mut buf[bytes_read..(bytes_read + bytes_to_read)]);
366                    bytes_read += bytes_to_read;
367                }
368
369                // Either we've exhausted our current buffer or we don't have one
370                _ => {
371                    // If we've read anything up to this point return the bytes read
372                    if bytes_read > 0 {
373                        return Ok(bytes_read);
374                    }
375
376                    self.current_buffer = match self.incoming.recv() {
377                        Ok(buf) => Some(buf),
378
379                        // The remote side hung up, if this is the first time we've seen EOF then
380                        // we should return `Ok(0)` otherwise an UnexpectedEof Error
381                        Err(_) => {
382                            if self.seen_eof {
383                                return Err(ErrorKind::UnexpectedEof.into());
384                            } else {
385                                self.seen_eof = true;
386                                return Ok(0);
387                            }
388                        }
389                    }
390                }
391            }
392        }
393    }
394}
395
396impl Write for MemorySocket {
397    fn write(&mut self, buf: &[u8]) -> Result<usize> {
398        self.write_buffer.extend_from_slice(buf);
399        Ok(buf.len())
400    }
401
402    fn flush(&mut self) -> Result<()> {
403        if !self.write_buffer.is_empty() {
404            self.outgoing
405                .send(self.write_buffer.split().freeze())
406                .map_err(|_| ErrorKind::BrokenPipe.into())
407        } else {
408            Ok(())
409        }
410    }
411}