1use crate::{Transport, transport::{TransportError, ListenerEvent}};
22use fnv::FnvHashMap;
23use futures::{future::{self, Ready}, prelude::*, channel::mpsc, task::Context, task::Poll};
24use lazy_static::lazy_static;
25use multiaddr::{Protocol, Multiaddr};
26use parking_lot::Mutex;
27use rw_stream_sink::RwStreamSink;
28use std::{collections::hash_map::Entry, error, fmt, io, num::NonZeroU64, pin::Pin};
29
lazy_static! {
    // Process-wide registry shared by every `MemoryTransport`: maps each
    // registered memory port to the sender half of that port's connection queue.
    static ref HUB: Hub = Hub(Mutex::new(FnvHashMap::default()));
}

/// Registry of in-memory ports, guarded by a mutex; see `Hub::register_port`.
struct Hub(Mutex<FnvHashMap<NonZeroU64, ChannelSender>>);

/// Sending end of a port's connection queue. Each item is the connection half
/// destined for the port's owner, paired with the port of the dialer.
type ChannelSender = mpsc::Sender<(Channel<Vec<u8>>, NonZeroU64)>;

/// Receiving end matching `ChannelSender`, held by the owner of a port.
type ChannelReceiver = mpsc::Receiver<(Channel<Vec<u8>>, NonZeroU64)>;
43
44impl Hub {
45 fn register_port(&self, port: u64) -> Option<(ChannelReceiver, NonZeroU64)> {
50 let mut hub = self.0.lock();
51
52 let port = if let Some(port) = NonZeroU64::new(port) {
53 port
54 } else {
55 loop {
56 let port = match NonZeroU64::new(rand::random()) {
57 Some(p) => p,
58 None => continue,
59 };
60 if !hub.contains_key(&port) {
61 break port;
62 }
63 }
64 };
65
66 let (tx, rx) = mpsc::channel(2);
67 match hub.entry(port) {
68 Entry::Occupied(_) => return None,
69 Entry::Vacant(e) => e.insert(tx)
70 };
71
72 Some((rx, port))
73 }
74
75 fn unregister_port(&self, port: &NonZeroU64) -> Option<ChannelSender> {
76 self.0.lock().remove(port)
77 }
78
79 fn get(&self, port: &NonZeroU64) -> Option<ChannelSender> {
80 self.0.lock().get(port).cloned()
81 }
82}
83
/// Transport whose connections never leave the process.
///
/// Listeners register a port in a process-wide hub, and dialers reach them
/// through in-memory channels. Addresses have the form `/memory/<port>`.
#[derive(Clone, Copy, Debug, Default)]
pub struct MemoryTransport;
87
/// Future returned by `MemoryTransport::dial`. It hands one half of a new
/// connection to the listener and resolves to the other half.
pub struct DialFuture {
    /// Ephemeral port registered in the hub for the dialer. It is sent to the
    /// listener alongside the connection and reported there as the remote address.
    dial_port: NonZeroU64,
    /// Sender registered in the hub for the port being dialed.
    sender: ChannelSender,
    /// Half delivered to the listener; `None` once it has been sent.
    channel_to_send: Option<Channel<Vec<u8>>>,
    /// Half returned to the dialer; `None` once the future has completed.
    channel_to_return: Option<Channel<Vec<u8>>>,
}
101
102impl DialFuture {
103 fn new(port: NonZeroU64) -> Option<Self> {
104 let sender = HUB.get(&port)?;
105
106 let (_dial_port_channel, dial_port) = HUB.register_port(0)
107 .expect("there to be some random unoccupied port.");
108
109 let (a_tx, a_rx) = mpsc::channel(4096);
110 let (b_tx, b_rx) = mpsc::channel(4096);
111 Some(DialFuture {
112 dial_port,
113 sender,
114 channel_to_send: Some(RwStreamSink::new(Chan {
115 incoming: a_rx,
116 outgoing: b_tx,
117 dial_port: None,
118 })),
119 channel_to_return: Some(RwStreamSink::new(Chan {
120 incoming: b_rx,
121 outgoing: a_tx,
122 dial_port: Some(dial_port),
123 })),
124 })
125 }
126}
127
128impl Future for DialFuture {
129 type Output = Result<Channel<Vec<u8>>, MemoryTransportError>;
130
131 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
132
133 match self.sender.poll_ready(cx) {
134 Poll::Pending => return Poll::Pending,
135 Poll::Ready(Ok(())) => {},
136 Poll::Ready(Err(_)) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
137 }
138
139 let channel_to_send = self.channel_to_send.take()
140 .expect("Future should not be polled again once complete");
141 let dial_port = self.dial_port;
142 match self.sender.start_send((channel_to_send, dial_port)) {
143 Err(_) => return Poll::Ready(Err(MemoryTransportError::Unreachable)),
144 Ok(()) => {}
145 }
146
147 Poll::Ready(Ok(self.channel_to_return.take()
148 .expect("Future should not be polled again once complete")))
149 }
150}
151
152impl Transport for MemoryTransport {
153 type Output = Channel<Vec<u8>>;
154 type Error = MemoryTransportError;
155 type Listener = Listener;
156 type ListenerUpgrade = Ready<Result<Self::Output, Self::Error>>;
157 type Dial = DialFuture;
158
159 fn listen_on(self, addr: Multiaddr) -> Result<Self::Listener, TransportError<Self::Error>> {
160 let port = if let Ok(port) = parse_memory_addr(&addr) {
161 port
162 } else {
163 return Err(TransportError::MultiaddrNotSupported(addr));
164 };
165
166 let (rx, port) = match HUB.register_port(port) {
167 Some((rx, port)) => (rx, port),
168 None => return Err(TransportError::Other(MemoryTransportError::Unreachable)),
169 };
170
171 let listener = Listener {
172 port,
173 addr: Protocol::Memory(port.get()).into(),
174 receiver: rx,
175 tell_listen_addr: true
176 };
177
178 Ok(listener)
179 }
180
181 fn dial(self, addr: Multiaddr) -> Result<DialFuture, TransportError<Self::Error>> {
182 let port = if let Ok(port) = parse_memory_addr(&addr) {
183 if let Some(port) = NonZeroU64::new(port) {
184 port
185 } else {
186 return Err(TransportError::Other(MemoryTransportError::Unreachable));
187 }
188 } else {
189 return Err(TransportError::MultiaddrNotSupported(addr));
190 };
191
192 DialFuture::new(port).ok_or(TransportError::Other(MemoryTransportError::Unreachable))
193 }
194
195 fn address_translation(&self, _server: &Multiaddr, _observed: &Multiaddr) -> Option<Multiaddr> {
196 None
197 }
198}
199
/// Errors produced by `MemoryTransport`.
#[derive(Clone, Copy, Debug)]
pub enum MemoryTransportError {
    /// There is no listener to connect to on the requested port.
    Unreachable,
    /// The requested port is already occupied.
    AlreadyInUse,
}

impl fmt::Display for MemoryTransportError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            MemoryTransportError::Unreachable => "No listener on the given port.",
            MemoryTransportError::AlreadyInUse => "Port already occupied.",
        };
        f.write_str(message)
    }
}

impl error::Error for MemoryTransportError {}
219
/// Stream of incoming connections on a registered memory port, returned by
/// `MemoryTransport::listen_on`. Dropping it unregisters the port from the hub.
pub struct Listener {
    /// Port this listener is registered on.
    port: NonZeroU64,
    /// `/memory/<port>` address, reported in `NewAddress` and used as `local_addr`.
    addr: Multiaddr,
    /// Receives `(connection half, dialer port)` pairs pushed by dialers.
    receiver: ChannelReceiver,
    /// Whether the initial `NewAddress` event still has to be emitted.
    tell_listen_addr: bool
}
231
232impl Stream for Listener {
233 type Item = Result<ListenerEvent<Ready<Result<Channel<Vec<u8>>, MemoryTransportError>>, MemoryTransportError>, MemoryTransportError>;
234
235 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
236 if self.tell_listen_addr {
237 self.tell_listen_addr = false;
238 return Poll::Ready(Some(Ok(ListenerEvent::NewAddress(self.addr.clone()))))
239 }
240
241 let (channel, dial_port) = match Stream::poll_next(Pin::new(&mut self.receiver), cx) {
242 Poll::Pending => return Poll::Pending,
243 Poll::Ready(None) => panic!("Alive listeners always have a sender."),
244 Poll::Ready(Some(v)) => v,
245 };
246
247 let event = ListenerEvent::Upgrade {
248 upgrade: future::ready(Ok(channel)),
249 local_addr: self.addr.clone(),
250 remote_addr: Protocol::Memory(dial_port.get()).into()
251 };
252
253 Poll::Ready(Some(Ok(event)))
254 }
255}
256
257impl Drop for Listener {
258 fn drop(&mut self) {
259 let val_in = HUB.unregister_port(&self.port);
260 debug_assert!(val_in.is_some());
261 }
262}
263
264fn parse_memory_addr(a: &Multiaddr) -> Result<u64, ()> {
266 let mut iter = a.iter();
267
268 let port = if let Some(Protocol::Memory(port)) = iter.next() {
269 port
270 } else {
271 return Err(());
272 };
273
274 if iter.next().is_some() {
275 return Err(());
276 }
277
278 Ok(port)
279}
280
/// One end of an in-memory connection, wrapped in `RwStreamSink` so it can be
/// used with byte-oriented reads and writes (`read_exact`/`write_all` in the tests).
pub type Channel<T> = RwStreamSink<Chan<T>>;

/// Raw end of an in-memory connection, made of a pair of `mpsc` channels.
pub struct Chan<T = Vec<u8>> {
    /// Items sent by the remote end.
    incoming: mpsc::Receiver<T>,
    /// Items destined for the remote end.
    outgoing: mpsc::Sender<T>,

    /// Set only on the dialer's end: the port registered in the hub for the
    /// dialer, which this end unregisters when dropped.
    dial_port: Option<NonZeroU64>,
}

// Marks `Chan` as `Unpin` for every `T`. The `Stream`/`Sink` impls rely on
// this to reach fields through `Pin<&mut Self>`.
impl<T> Unpin for Chan<T> {
}
304
305impl<T> Stream for Chan<T> {
306 type Item = Result<T, io::Error>;
307
308 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
309 match Stream::poll_next(Pin::new(&mut self.incoming), cx) {
310 Poll::Pending => Poll::Pending,
311 Poll::Ready(None) => Poll::Ready(None),
312 Poll::Ready(Some(v)) => Poll::Ready(Some(Ok(v))),
313 }
314 }
315}
316
317impl<T> Sink<T> for Chan<T> {
318 type Error = io::Error;
319
320 fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
321 self.outgoing.poll_ready(cx)
322 .map(|v| v.map_err(|_| io::ErrorKind::BrokenPipe.into()))
323 }
324
325 fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
326 self.outgoing.start_send(item).map_err(|_| io::ErrorKind::BrokenPipe.into())
327 }
328
329 fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
330 Poll::Ready(Ok(()))
331 }
332
333 fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
334 Poll::Ready(Ok(()))
335 }
336}
337
338impl<T: AsRef<[u8]>> Into<RwStreamSink<Chan<T>>> for Chan<T> {
339 fn into(self) -> RwStreamSink<Chan<T>> {
340 RwStreamSink::new(self)
341 }
342}
343
344impl<T> Drop for Chan<T> {
345 fn drop(&mut self) {
346 if let Some(port) = self.dial_port {
347 let channel_sender = HUB.unregister_port(&port);
348 debug_assert!(channel_sender.is_some());
349 }
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    // Only a single `/memory/<port>` component is accepted; port 0 parses fine
    // (the wildcard is handled by the callers).
    #[test]
    fn parse_memory_addr_works() {
        assert_eq!(parse_memory_addr(&"/memory/5".parse().unwrap()), Ok(5));
        assert_eq!(parse_memory_addr(&"/tcp/150".parse().unwrap()), Err(()));
        assert_eq!(parse_memory_addr(&"/memory/0".parse().unwrap()), Ok(0));
        assert_eq!(parse_memory_addr(&"/memory/5/tcp/150".parse().unwrap()), Err(()));
        assert_eq!(parse_memory_addr(&"/tcp/150/memory/5".parse().unwrap()), Err(()));
        assert_eq!(parse_memory_addr(&"/memory/1234567890".parse().unwrap()), Ok(1_234_567_890));
    }

    // A port is occupied exactly as long as the listener holding it is alive.
    #[test]
    fn listening_twice() {
        let transport = MemoryTransport::default();
        // Each listener here is a temporary dropped at the end of its
        // statement, which frees the port again immediately.
        assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok());
        assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok());
        let _listener = transport.listen_on("/memory/1639174018481".parse().unwrap()).unwrap();
        // While `_listener` is alive the port cannot be registered again.
        assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_err());
        assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_err());
        drop(_listener);
        assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok());
        assert!(transport.listen_on("/memory/1639174018481".parse().unwrap()).is_ok());
    }

    // Dialing fails until a listener is registered on the target port.
    #[test]
    fn port_not_in_use() {
        let transport = MemoryTransport::default();
        assert!(transport.dial("/memory/810172461024613".parse().unwrap()).is_err());
        let _listener = transport.listen_on("/memory/810172461024613".parse().unwrap()).unwrap();
        assert!(transport.dial("/memory/810172461024613".parse().unwrap()).is_ok());
    }

    // Bytes written by the dialer arrive unchanged at the listener.
    #[test]
    fn communicating_between_dialer_and_listener() {
        let msg = [1, 2, 3];

        // `saturating_add(1)` avoids port 0, which cannot be dialed.
        let rand_port = rand::random::<u64>().saturating_add(1);
        let t1_addr: Multiaddr = format!("/memory/{}", rand_port).parse().unwrap();
        let cloned_t1_addr = t1_addr.clone();

        let t1 = MemoryTransport::default();

        let listener = async move {
            let listener = t1.listen_on(t1_addr.clone()).unwrap();

            // Skip the `NewAddress` event and take the first incoming connection.
            let upgrade = listener.filter_map(|ev| futures::future::ready(
                ListenerEvent::into_upgrade(ev.unwrap())
            )).next().await.unwrap();

            let mut socket = upgrade.0.await.unwrap();

            let mut buf = [0; 3];
            socket.read_exact(&mut buf).await.unwrap();

            assert_eq!(buf, msg);
        };

        let t2 = MemoryTransport::default();
        let dialer = async move {
            let mut socket = t2.dial(cloned_t1_addr).unwrap().await.unwrap();
            socket.write_all(&msg).await.unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    // The listener sees the dialer's own ephemeral port as the remote
    // address, not its own listen address.
    #[test]
    fn dialer_address_unequal_to_listener_address() {
        let listener_addr: Multiaddr = Protocol::Memory(
            rand::random::<u64>().saturating_add(1),
        ).into();
        let listener_addr_cloned = listener_addr.clone();

        let listener_transport = MemoryTransport::default();

        let listener = async move {
            let mut listener = listener_transport.listen_on(listener_addr.clone())
                .unwrap();
            while let Some(ev) = listener.next().await {
                if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() {
                    assert!(
                        remote_addr != listener_addr,
                        "Expect dialer address not to equal listener address."
                    );
                    return;
                }
            }
        };

        let dialer = async move {
            MemoryTransport::default().dial(listener_addr_cloned)
                .unwrap()
                .await
                .unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }

    // The dialer's ephemeral port stays registered while the connection is
    // alive and is released once the dialer's end is dropped. Two oneshot
    // channels sequence the check: the listener asks the dialer to drop its
    // end, then waits for confirmation before asserting.
    #[test]
    fn dialer_port_is_deregistered() {
        let (terminate, should_terminate) = futures::channel::oneshot::channel();
        let (terminated, is_terminated) = futures::channel::oneshot::channel();

        let listener_addr: Multiaddr = Protocol::Memory(
            rand::random::<u64>().saturating_add(1),
        ).into();
        let listener_addr_cloned = listener_addr.clone();

        let listener_transport = MemoryTransport::default();

        let listener = async move {
            let mut listener = listener_transport.listen_on(listener_addr.clone())
                .unwrap();
            while let Some(ev) = listener.next().await {
                if let ListenerEvent::Upgrade { remote_addr, .. } = ev.unwrap() {
                    let dialer_port = NonZeroU64::new(
                        parse_memory_addr(&remote_addr).unwrap(),
                    ).unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_some(),
                        "Expect dialer port to stay registered while connection is in use.",
                    );

                    terminate.send(()).unwrap();
                    is_terminated.await.unwrap();

                    assert!(
                        HUB.get(&dialer_port).is_none(),
                        "Expect dialer port to be deregistered once connection is dropped.",
                    );

                    return;
                }
            }
        };

        let dialer = async move {
            let _chan = MemoryTransport::default().dial(listener_addr_cloned)
                .unwrap()
                .await
                .unwrap();

            should_terminate.await.unwrap();
            drop(_chan);
            terminated.send(()).unwrap();
        };

        futures::executor::block_on(futures::future::join(listener, dialer));
    }
}