// libp2p_core/transport/memory.rs

1// Copyright 2018 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use crate::{Transport, transport::{TransportError, ListenerEvent}};
22use fnv::FnvHashMap;
23use futures::{future::{self, Ready}, prelude::*, channel::mpsc, task::Context, task::Poll};
24use lazy_static::lazy_static;
25use multiaddr::{Protocol, Multiaddr};
26use parking_lot::Mutex;
27use rw_stream_sink::RwStreamSink;
28use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64, pin::Pin};
29
30lazy_static! {
31    static ref HUB: Hub = Hub(Mutex::new(FnvHashMap::default()));
32}
33
34struct Hub(Mutex<FnvHashMap<NonZeroU64, ChannelSender>>);
35
36/// A [`mpsc::Sender`] enabling a [`DialFuture`] to send a [`Channel`] and the
37/// port of the dialer to a [`Listener`].
38type ChannelSender = mpsc::Sender<(Channel<Vec<u8>>, NonZeroU64)>;
39
40/// A [`mpsc::Receiver`] enabling a [`Listener`] to receive a [`Channel`] and
41/// the port of the dialer from a [`DialFuture`].
42type ChannelReceiver = mpsc::Receiver<(Channel<Vec<u8>>, NonZeroU64)>;
43
44impl Hub {
45    /// Registers the given port on the hub.
46    ///
47    /// Randomizes port when given port is `0`. Returns [`None`] when given port
48    /// is already occupied.
49    fn register_port(&self, port: u64) -> Option<(ChannelReceiver, NonZeroU64)> {
50        let mut hub = self.0.lock();
51
52        let port = if let Some(port) = NonZeroU64::new(port) {
53            port
54        } else {
55            loop {
56                let port = match NonZeroU64::new(rand::random()) {
57                    Some(p) => p,
58                    None => continue,
59                };
60                if !hub.contains_key(&port) {
61                    break port;
62                }
63            }
64        };
65
66        let (tx, rx) = mpsc::channel(2);
67        match hub.entry(port) {
68            Entry::Occupied(_) => return None,
69            Entry::Vacant(e) => e.insert(tx)
70        };
71
72        Some((rx, port))
73    }
74
75    fn unregister_port(&self, port: &NonZeroU64) -> Option<ChannelSender> {
76        self.0.lock().remove(port)
77    }
78
79    fn get(&self, port: &NonZeroU64) -> Option<ChannelSender> {
80        self.0.lock().get(port).cloned()
81    }
82}
83
/// Transport that supports `/memory/N` multiaddresses.
///
/// Connections never leave the current process: dialers and listeners are
/// paired through a process-wide registry of ports.
#[derive(Clone, Copy, Debug, Default)]
pub struct MemoryTransport;
87
88/// Connection to a `MemoryTransport` currently being opened.
89pub struct DialFuture {
90    /// Ephemeral source port.
91    ///
92    /// These ports mimic TCP ephemeral source ports but are not actually used
93    /// by the memory transport due to the direct use of channels. They merely
94    /// ensure that every connection has a unique address for each dialer, which
95    /// is not at the same time a listen address (analogous to TCP).
96    dial_port: NonZeroU64,
97    sender: ChannelSender,
98    channel_to_send: Option<Channel<Vec<u8>>>,
99    channel_to_return: Option<Channel<Vec<u8>>>,
100}
101
102impl DialFuture {
103    fn new(port: NonZeroU64) -> Option<Self> {
104        let sender = HUB.get(&port)?;
105
106        let (_dial_port_channel, dial_port) = HUB.register_port(0)
107            .expect("there to be some random unoccupied port.");
108
109        let (a_tx, a_rx) = mpsc::channel(4096);
110        let (b_tx, b_rx) = mpsc::channel(4096);
111        Some(DialFuture {
112            dial_port,
113            sender,
114            channel_to_send: Some(RwStreamSink::new(Chan {
115                incoming: a_rx,
116                outgoing: b_tx,
117                dial_port: None,
118            })),
119            channel_to_return: Some(RwStreamSink::new(Chan {
120                incoming: b_rx,
121                outgoing: a_tx,
122                dial_port: Some(dial_port),
123            })),
124        })
125    }
126}
127
128impl Future for DialFuture {
129    type Output = Result<Channel<Vec<u8>>, MemoryTransportError>;
130
131    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
132
133        match self.sender.poll_ready(cx) {
134            Poll::Pending => return Poll::Pending,
135            Poll::Ready(Ok(())) => {},
136            Poll::Ready(Err(_)) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
137        }
138
139        let channel_to_send = self.channel_to_send.take()
140            .expect("Future should not be polled again once complete");
141        let dial_port = self.dial_port;
142        match self.sender.start_send((channel_to_send, dial_port)) {
143            Err(_) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
144            Ok(()) => {}
145        }
146
147        Poll::Ready(Ok(self.channel_to_return.take()
148                .expect("Future should not be polled again once complete")))
149    }
150}
151
152impl Transport for MemoryTransport {
153    type Output = Channel<Vec<u8>>;
154    type Error = MemoryTransportError;
155    type Listener = Listener;
156    type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
157    type Dial = DialFuture;
158
159    fn listen_on(self, addr: Multiaddr) -> Result<Self::Listener, TransportError<Self::Error>> {
160        let port = if let Ok(port) = parse_memory_addr(&addr) {
161            port
162        } else {
163            return Err(TransportError::MultiaddrNotSupported(addr));
164        };
165
166        let (rx, port) = match HUB.register_port(port) {
167            Some((rx, port)) => (rx, port),
168            None => return Err(TransportError::Other(MemoryTransportError::Unreachable)),
169        };
170
171        let listener = Listener {
172            port,
173            addr: Protocol::Memory(port.get()).into(),
174            receiver: rx,
175            tell_listen_addr: true
176        };
177
178        Ok(listener)
179    }
180
181    fn dial(self, addr: Multiaddr) -> Result<DialFuture, TransportError<Self::Error>> {
182        let port = if let Ok(port) = parse_memory_addr(&addr) {
183            if let Some(port) = NonZeroU64::new(port) {
184                port
185            } else {
186                return Err(TransportError::Other(MemoryTransportError::Unreachable));
187            }
188        } else {
189            return Err(TransportError::MultiaddrNotSupported(addr));
190        };
191
192        DialFuture::new(port).ok_or(TransportError::Other(MemoryTransportError::Unreachable))
193    }
194
195    fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option<Multiaddr> {
196        None
197    }
198}
199
/// Error that can be produced from the `MemoryTransport`.
#[derive(Debug, Copy, Clone)]
pub enum MemoryTransportError {
    /// There's no listener on the given port.
    Unreachable,
    /// Tries to listen on a port that is already in use.
    AlreadyInUse,
}

impl fmt::Display for MemoryTransportError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            MemoryTransportError::Unreachable => "No listener on the given port.",
            MemoryTransportError::AlreadyInUse => "Port already occupied.",
        };
        f.write_str(message)
    }
}

/// Neither variant wraps an underlying error, so the default `source()` is used.
impl error::Error for MemoryTransportError {}
219
220/// Listener for memory connections.
221pub struct Listener {
222    /// Port we're listening on.
223    port: NonZeroU64,
224    /// The address we are listening on.
225    addr: Multiaddr,
226    /// Receives incoming connections.
227    receiver: ChannelReceiver,
228    /// Generate `ListenerEvent::NewAddress` to inform about our listen address.
229    tell_listen_addr: bool
230}
231
232impl Stream for Listener {
233    type Item = Result<ListenerEvent<Ready<Result<Channel<Vec<u8>>, MemoryTransportError>>, MemoryTransportError>, MemoryTransportError>;
234
235    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
236        if self.tell_listen_addr {
237            self.tell_listen_addr = false;
238            return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone()))))
239        }
240
241        let (channel, dial_port) = match Stream::poll_next(Pin::new(&mut self.receiver), cx) {
242            Poll::Pending => return Poll::Pending,
243            Poll::Ready(None) => panic!("Alive listeners always have a sender."),
244            Poll::Ready(Some(v)) => v,
245        };
246
247        let event = ListenerEvent::Upgrade {
248            upgrade: future::ready(Ok(channel)),
249            local_addr: self.addr.clone(),
250            remote_addr: Protocol::Memory(dial_port.get()).into()
251        };
252
253        Poll::Ready(Some(Ok(event)))
254    }
255}
256
257impl Drop for Listener {
258    fn drop(&mut self) {
259        let val_in = HUB.unregister_port(&self.port);
260        debug_assert!(val_in.is_some());
261    }
262}
263
264/// If the address is `/memory/n`, returns the value of `n`.
265fn parse_memory_addr(a: &Multiaddr) -> Result<u64, ()> {
266    let mut iter = a.iter();
267
268    let port = if let Some(Protocol::Memory(port)) = iter.next() {
269        port
270    } else {
271        return Err(());
272    };
273
274    if iter.next().is_some() {
275        return Err(());
276    }
277
278    Ok(port)
279}
280
281/// A channel represents an established, in-memory, logical connection between two endpoints.
282///
283/// Implements `AsyncRead` and `AsyncWrite`.
284pub type Channel<T> = RwStreamSink<Chan<T>>;
285
286/// A channel represents an established, in-memory, logical connection between two endpoints.
287///
288/// Implements `Sink` and `Stream`.
289pub struct Chan<T = Vec<u8>> {
290    incoming: mpsc::Receiver<T>,
291    outgoing: mpsc::Sender<T>,
292
293    // Needed in [`Drop`] implementation of [`Chan`] to unregister the dialing
294    // port with the global [`HUB`]. Is [`Some`] when [`Chan`] of dialer and
295    // [`None`] when [`Chan`] of listener.
296    //
297    // Note: Listening port is unregistered in [`Drop`] implementation of
298    // [`Listener`].
299    dial_port: Option<NonZeroU64>,
300}
301
302impl<T> Unpin for Chan<T> {
303}
304
305impl<T> Stream for Chan<T> {
306    type Item = Result<T, io::Error>;
307
308    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
309        match Stream::poll_next(Pin::new(&mut self.incoming), cx) {
310            Poll::Pending => Poll::Pending,
311            Poll::Ready(None) => Poll::Ready(None),
312            Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
313        }
314    }
315}
316
317impl<T> Sink<T> for Chan<T> {
318    type Error = io::Error;
319
320    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
321        self.outgoing.poll_ready(cx)
322            .map(|v| v.map_err(|_| io::ErrorKind::BrokenPipe.into()))
323    }
324
325    fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
326        self.outgoing.start_send(item).map_err(|_| io::ErrorKind::BrokenPipe.into())
327    }
328
329    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
330        Poll::Ready(Ok(()))
331    }
332
333    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
334        Poll::Ready(Ok(()))
335    }
336}
337
338impl<T: AsRef<[u8]>> Into<RwStreamSink<Chan<T>>> for Chan<T> {
339    fn into(self) -> RwStreamSink<Chan<T>> {
340        RwStreamSink::new(self)
341    }
342}
343
344impl<T> Drop for Chan<T> {
345    fn drop(&mut self) {
346        if let Some(port) = self.dial_port {
347            let channel_sender = HUB.unregister_port(&port);
348            debug_assert!(channel_sender.is_some());
349        }
350    }
351}
352
// Tests for address parsing, port registration on the global hub, and
// end-to-end communication between a dialer and a listener.
#[cfg(test)]
mod tests {
    use super::*;

    // Only a lone `/memory/n` component is accepted, including `n == 0`.
    #[test]
    fn parse_memory_addr_works() {
        assert_eq!(parse_memory_addr(&"/memory/5".parse().unwrap()), Ok(5));
        assert_eq!(parse_memory_addr(&"/tcp/150".parse().unwrap()), Err(()));
        assert_eq!(parse_memory_addr(&"/memory/0".parse().unwrap()), Ok(0));
        assert_eq!(parse_memory_addr(&"/memory/5/tcp/150".parse().unwrap()), Err(()));
        assert_eq!(parse_memory_addr(&"/tcp/150/memory/5".parse().unwrap()), Err(()));
        assert_eq!(parse_memory_addr(&"/memory/1234567890".parse().unwrap()), Ok(1_234_567_890));
    }

    // A port can only be held by one live listener at a time. The listeners
    // created inside `assert!` are temporaries that are dropped immediately,
    // which releases the port again.
    #[test]
    fn listening_twice() {
        let transport = MemoryTransport::default();
        assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok());
        assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok());
        let _listener = transport.listen_on("/memory/1639174018481".parse().unwrap()).unwrap();
        assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_err());
        assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_err());
        drop(_listener);
        assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok());
        assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok());
    }

    // Dialing fails synchronously until something is listening on the port.
    #[test]
    fn port_not_in_use() {
        let transport = MemoryTransport::default();
        assert!(transport.dial("/memory/810172461024613".parse().unwrap()).is_err());
        let _listener = transport.listen_on("/memory/810172461024613".parse().unwrap()).unwrap();
        assert!(transport.dial("/memory/810172461024613".parse().unwrap()).is_ok());
    }

    // Bytes written by the dialer arrive unchanged at the listener.
    #[test]
    fn communicating_between_dialer_and_listener() {
        let msg = [1, 2, 3];

        // Setup listener.

        // `saturating_add(1)` keeps the port non-zero, so it is used as-is.
        let rand_port = rand::random::<u64>().saturating_add(1);
        let t1_addr: Multiaddr = format!("/memory/{}", rand_port).parse().unwrap();
        let cloned_t1_addr = t1_addr.clone();

        let t1 = MemoryTransport::default();

        let listener = async move {
            let listener = t1.listen_on(t1_addr.clone()).unwrap();

            // Skip the initial `NewAddress` event and take the first upgrade.
            let upgrade = listener.filter_map(|ev| futures::future::ready(
                ListenerEvent::into_upgrade(ev.unwrap())
            )).next().await.unwrap();

            let mut socket = upgrade.0.await.unwrap();

            let mut buf = [0; 3];
            socket.read_exact(&mut buf).await.unwrap();

            assert_eq!(buf, msg);
        };

        // Setup dialer.

        let t2 = MemoryTransport::default();
        let dialer = async move {
            let mut socket = t2.dial(cloned_t1_addr).unwrap().await.unwrap();
            socket.write_all(&msg).await.unwrap();
        };

        // Wait for both to finish.

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    // The listener sees the dialer's ephemeral port as the remote address,
    // never its own listen address.
    #[test]
    fn dialer_address_unequal_to_listener_address() {
        let listener_addr: Multiaddr = Protocol::Memory(
            rand::random::<u64>().saturating_add(1),
        ).into();
        let listener_addr_cloned = listener_addr.clone();

        let listener_transport = MemoryTransport::default();

        let listener = async move {
            let mut listener = listener_transport.listen_on(listener_addr.clone())
                .unwrap();
            while let Some(ev) = listener.next().await {
                if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() {
                    assert!(
                        remote_addr != listener_addr,
                        "Expect dialer address not to equal listener address."
                    );
                    return;
                }
            }
        };

        let dialer = async move {
            MemoryTransport::default().dial(listener_addr_cloned)
                .unwrap()
                .await
                .unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    // The dialer's ephemeral port stays registered while its `Chan` is alive
    // and is removed from the hub once that `Chan` is dropped. The two oneshot
    // channels force this ordering: the listener checks, signals the dialer to
    // drop its half, waits for confirmation, then checks again.
    #[test]
    fn dialer_port_is_deregistered() {
        let (terminate, should_terminate) = futures::channel::oneshot::channel();
        let (terminated, is_terminated) = futures::channel::oneshot::channel();

        let listener_addr: Multiaddr = Protocol::Memory(
            rand::random::<u64>().saturating_add(1),
        ).into();
        let listener_addr_cloned = listener_addr.clone();

        let listener_transport = MemoryTransport::default();

        let listener = async move {
            let mut listener = listener_transport.listen_on(listener_addr.clone())
                .unwrap();
            while let Some(ev) = listener.next().await {
                if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() {
                    let dialer_port = NonZeroU64::new(
                        parse_memory_addr(&remote_addr).unwrap(),
                    ).unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_some(),
                        "Expect dialer port to stay registered while connection is in use.",
                    );

                    terminate.send(()).unwrap();
                    is_terminated.await.unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_none(),
                        "Expect dialer port to be deregistered once connection is dropped.",
                    );

                    return;
                }
            }
        };

        let dialer = async move {
            let _chan = MemoryTransport::default().dial(listener_addr_cloned)
                .unwrap()
                .await
                .unwrap();

            should_terminate.await.unwrap();
            drop(_chan);
            terminated.send(()).unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }
}