memory_socket/lib.rs
1//! Provides an in-memory socket abstraction.
2//!
3//! The `memory-socket` crate provides the [`MemoryListener`] and [`MemorySocket`] types which can
4//! be thought of as in-memory versions of the standard library `TcpListener` and `TcpStream`
5//! types.
6//!
7//! ## Feature flags
8//!
9//! - `async`: Adds async support for [`MemorySocket`] and [`MemoryListener`]
10//!
11//! [`MemoryListener`]: struct.MemoryListener.html
12//! [`MemorySocket`]: struct.MemorySocket.html
13
14use bytes::{buf::BufExt, Buf, Bytes, BytesMut};
15use flume::{Receiver, Sender};
16use once_cell::sync::Lazy;
17use std::{
18 collections::HashMap,
19 io::{ErrorKind, Read, Result, Write},
20 num::NonZeroU16,
21 sync::Mutex,
22};
23
24#[cfg(feature = "async")]
25mod r#async;
26
27#[cfg(feature = "async")]
28pub use r#async::IncomingStream;
29
30static SWITCHBOARD: Lazy<Mutex<SwitchBoard>> =
31 Lazy::new(|| Mutex::new(SwitchBoard(HashMap::default(), 1)));
32
33struct SwitchBoard(HashMap<NonZeroU16, Sender<MemorySocket>>, u16);
34
35/// An in-memory socket server, listening for connections.
36///
37/// After creating a `MemoryListener` by [`bind`]ing it to a socket address, it listens
38/// for incoming connections. These can be accepted by calling [`accept`] or by iterating over
39/// iterating over the [`Incoming`] iterator returned by [`incoming`][`MemoryListener::incoming`].
40///
41/// The socket will be closed when the value is dropped.
42///
43/// [`accept`]: #method.accept
44/// [`bind`]: #method.bind
45/// [`Incoming`]: struct.Incoming.html
46/// [`MemoryListener::incoming`]: #method.incoming
47///
48/// # Examples
49///
50/// ```no_run
51/// use std::io::{Read, Result, Write};
52///
53/// use memory_socket::{MemoryListener, MemorySocket};
54///
55/// fn write_stormlight(mut stream: MemorySocket) -> Result<()> {
56/// let msg = b"The most important step a person can take is always the next one.";
57/// stream.write_all(msg)?;
58/// stream.flush()
59/// }
60///
61/// fn main() -> Result<()> {
62/// let mut listener = MemoryListener::bind(16)?;
63///
64/// // accept connections and process them serially
65/// for stream in listener.incoming() {
66/// write_stormlight(stream?)?;
67/// }
68/// Ok(())
69/// }
70/// ```
71pub struct MemoryListener {
72 incoming: Receiver<MemorySocket>,
73 port: NonZeroU16,
74}
75
76impl Drop for MemoryListener {
77 fn drop(&mut self) {
78 let mut switchboard = (&*SWITCHBOARD).lock().unwrap();
79 // Remove the Sending side of the channel in the switchboard when
80 // MemoryListener is dropped
81 switchboard.0.remove(&self.port);
82 }
83}
84
85impl MemoryListener {
86 /// Creates a new `MemoryListener` which will be bound to the specified
87 /// port.
88 ///
89 /// The returned listener is ready for accepting connections.
90 ///
91 /// Binding with a port number of `0` will request that a port be assigned
92 /// to this listener. The port allocated can be queried via the
93 /// [`local_addr`] method.
94 ///
95 /// [`local_addr`]: #method.local_addr
96 ///
97 /// # Examples
98 ///
99 /// Create a MemoryListener bound to port 16:
100 ///
101 /// ```no_run
102 /// use memory_socket::MemoryListener;
103 ///
104 /// # fn main () -> ::std::io::Result<()> {
105 /// let listener = MemoryListener::bind(16)?;
106 /// # Ok(())}
107 /// ```
108 pub fn bind(port: u16) -> Result<Self> {
109 let mut switchboard = (&*SWITCHBOARD).lock().unwrap();
110
111 // Get the port we should bind to. If 0 was given, use a random port
112 let port = if let Some(port) = NonZeroU16::new(port) {
113 if switchboard.0.contains_key(&port) {
114 return Err(ErrorKind::AddrInUse.into());
115 }
116
117 port
118 } else {
119 loop {
120 let port = NonZeroU16::new(switchboard.1).unwrap_or_else(|| unreachable!());
121
122 // The switchboard is full and all ports are in use
123 if switchboard.0.len() == (std::u16::MAX - 1) as usize {
124 return Err(ErrorKind::AddrInUse.into());
125 }
126
127 // Instead of overflowing to 0, resume searching at port 1 since port 0 isn't a
128 // valid port to bind to.
129 if switchboard.1 == std::u16::MAX {
130 switchboard.1 = 1;
131 } else {
132 switchboard.1 += 1;
133 }
134
135 if !switchboard.0.contains_key(&port) {
136 break port;
137 }
138 }
139 };
140
141 let (sender, receiver) = flume::unbounded();
142 switchboard.0.insert(port, sender);
143
144 Ok(Self {
145 incoming: receiver,
146 port,
147 })
148 }
149
150 /// Returns the local address that this listener is bound to.
151 ///
152 /// This can be useful, for example, when binding to port `0` to figure out
153 /// which port was actually bound.
154 ///
155 /// # Examples
156 ///
157 /// ```
158 /// use memory_socket::MemoryListener;
159 ///
160 /// # fn main () -> ::std::io::Result<()> {
161 /// let listener = MemoryListener::bind(16)?;
162 ///
163 /// assert_eq!(listener.local_addr(), 16);
164 /// # Ok(())}
165 /// ```
166 pub fn local_addr(&self) -> u16 {
167 self.port.get()
168 }
169
170 /// Returns an iterator over the connections being received on this
171 /// listener.
172 ///
173 /// The returned iterator will never return `None`. Iterating over
174 /// it is equivalent to calling [`accept`] in a loop.
175 ///
176 /// [`accept`]: #method.accept
177 ///
178 /// # Examples
179 ///
180 /// ```no_run
181 /// use memory_socket::MemoryListener;
182 /// use std::io::{Read, Write};
183 ///
184 /// let mut listener = MemoryListener::bind(80).unwrap();
185 ///
186 /// for stream in listener.incoming() {
187 /// match stream {
188 /// Ok(stream) => {
189 /// println!("new client!");
190 /// }
191 /// Err(e) => { /* connection failed */ }
192 /// }
193 /// }
194 /// ```
195 pub fn incoming(&self) -> Incoming<'_> {
196 Incoming { inner: self }
197 }
198
199 /// Accept a new incoming connection from this listener.
200 ///
201 /// This function will block the calling thread until a new connection
202 /// is established. When established, the corresponding [`MemorySocket`]
203 /// will be returned.
204 ///
205 /// [`MemorySocket`]: struct.MemorySocket.html
206 ///
207 /// # Examples
208 ///
209 /// ```no_run
210 /// use std::net::TcpListener;
211 /// use memory_socket::MemoryListener;
212 ///
213 /// let mut listener = MemoryListener::bind(8080).unwrap();
214 /// match listener.accept() {
215 /// Ok(_socket) => println!("new client!"),
216 /// Err(e) => println!("couldn't get client: {:?}", e),
217 /// }
218 /// ```
219 pub fn accept(&self) -> Result<MemorySocket> {
220 self.incoming.iter().next().ok_or_else(|| unreachable!())
221 }
222}
223
224/// An iterator that infinitely [`accept`]s connections on a [`MemoryListener`].
225///
226/// This `struct` is created by the [`incoming`] method on [`MemoryListener`].
227/// See its documentation for more info.
228///
229/// [`accept`]: struct.MemoryListener.html#method.accept
230/// [`incoming`]: struct.MemoryListener.html#method.incoming
231/// [`MemoryListener`]: struct.MemoryListener.html
232pub struct Incoming<'a> {
233 inner: &'a MemoryListener,
234}
235
236impl<'a> Iterator for Incoming<'a> {
237 type Item = Result<MemorySocket>;
238
239 fn next(&mut self) -> Option<Self::Item> {
240 Some(self.inner.accept())
241 }
242}
243
244/// An in-memory stream between two local sockets.
245///
246/// A `MemorySocket` can either be created by connecting to an endpoint, via the
247/// [`connect`] method, or by [accepting] a connection from a [listener].
248/// It can be read or written to using the `Read` and `Write` traits.
249///
250/// # Examples
251///
252/// ```
253/// use std::io::{Read, Result, Write};
254/// use memory_socket::MemorySocket;
255///
256/// # fn main() -> Result<()> {
257/// let (mut socket_a, mut socket_b) = MemorySocket::new_pair();
258///
259/// socket_a.write_all(b"stormlight")?;
260/// socket_a.flush()?;
261///
262/// let mut buf = [0; 10];
263/// socket_b.read_exact(&mut buf)?;
264/// assert_eq!(&buf, b"stormlight");
265///
266/// # Ok(())}
267/// ```
268///
269/// [`connect`]: struct.MemorySocket.html#method.connect
270/// [accepting]: struct.MemoryListener.html#method.accept
271/// [listener]: struct.MemoryListener.html
272pub struct MemorySocket {
273 incoming: Receiver<Bytes>,
274 outgoing: Sender<Bytes>,
275 write_buffer: BytesMut,
276 current_buffer: Option<Bytes>,
277 seen_eof: bool,
278}
279
280impl MemorySocket {
281 fn new(incoming: Receiver<Bytes>, outgoing: Sender<Bytes>) -> Self {
282 Self {
283 incoming,
284 outgoing,
285 write_buffer: BytesMut::new(),
286 current_buffer: None,
287 seen_eof: false,
288 }
289 }
290
291 /// Construct both sides of an in-memory socket.
292 ///
293 /// # Examples
294 ///
295 /// ```
296 /// use memory_socket::MemorySocket;
297 ///
298 /// let (socket_a, socket_b) = MemorySocket::new_pair();
299 /// ```
300 pub fn new_pair() -> (Self, Self) {
301 let (a_tx, a_rx) = flume::unbounded();
302 let (b_tx, b_rx) = flume::unbounded();
303 let a = Self::new(a_rx, b_tx);
304 let b = Self::new(b_rx, a_tx);
305
306 (a, b)
307 }
308
309 /// Create a new in-memory Socket connected to the specified port.
310 ///
311 /// This function will create a new MemorySocket socket and attempt to connect it to
312 /// the `port` provided.
313 ///
314 /// # Examples
315 ///
316 /// ```
317 /// use memory_socket::MemorySocket;
318 ///
319 /// # fn main () -> ::std::io::Result<()> {
320 /// # let _listener = memory_socket::MemoryListener::bind(16)?;
321 /// let socket = MemorySocket::connect(16)?;
322 /// # Ok(())}
323 /// ```
324 pub fn connect(port: u16) -> Result<MemorySocket> {
325 let mut switchboard = (&*SWITCHBOARD).lock().unwrap();
326
327 // Find port to connect to
328 let port = NonZeroU16::new(port).ok_or_else(|| ErrorKind::AddrNotAvailable)?;
329
330 let sender = switchboard
331 .0
332 .get_mut(&port)
333 .ok_or_else(|| ErrorKind::AddrNotAvailable)?;
334
335 let (socket_a, socket_b) = Self::new_pair();
336
337 // Send the socket to the listener
338 sender
339 .send(socket_a)
340 .map_err(|_| ErrorKind::AddrNotAvailable)?;
341
342 Ok(socket_b)
343 }
344}
345
346impl Read for MemorySocket {
347 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
348 let mut bytes_read = 0;
349
350 loop {
351 // If we've already filled up the buffer then we can return
352 if bytes_read == buf.len() {
353 return Ok(bytes_read);
354 }
355
356 match self.current_buffer {
357 // We still have data to copy to `buf`
358 Some(ref mut current_buffer) if current_buffer.has_remaining() => {
359 let bytes_to_read =
360 ::std::cmp::min(buf.len() - bytes_read, current_buffer.remaining());
361 debug_assert!(bytes_to_read > 0);
362
363 current_buffer
364 .take(bytes_to_read)
365 .copy_to_slice(&mut buf[bytes_read..(bytes_read + bytes_to_read)]);
366 bytes_read += bytes_to_read;
367 }
368
369 // Either we've exhausted our current buffer or we don't have one
370 _ => {
371 // If we've read anything up to this point return the bytes read
372 if bytes_read > 0 {
373 return Ok(bytes_read);
374 }
375
376 self.current_buffer = match self.incoming.recv() {
377 Ok(buf) => Some(buf),
378
379 // The remote side hung up, if this is the first time we've seen EOF then
380 // we should return `Ok(0)` otherwise an UnexpectedEof Error
381 Err(_) => {
382 if self.seen_eof {
383 return Err(ErrorKind::UnexpectedEof.into());
384 } else {
385 self.seen_eof = true;
386 return Ok(0);
387 }
388 }
389 }
390 }
391 }
392 }
393 }
394}
395
396impl Write for MemorySocket {
397 fn write(&mut self, buf: &[u8]) -> Result<usize> {
398 self.write_buffer.extend_from_slice(buf);
399 Ok(buf.len())
400 }
401
402 fn flush(&mut self) -> Result<()> {
403 if !self.write_buffer.is_empty() {
404 self.outgoing
405 .send(self.write_buffer.split().freeze())
406 .map_err(|_| ErrorKind::BrokenPipe.into())
407 } else {
408 Ok(())
409 }
410 }
411}