use std::{
convert::TryFrom,
io,
num::NonZeroU16,
pin::Pin,
task::{Context, Poll},
};
use futures::stream::Stream;
use multiaddr::{Multiaddr, Protocol};
use crate::{
memsocket::{self, MemoryListener, MemorySocket},
transports::Transport,
types::TransportProtocol,
};
/// A [`Transport`] whose connections are `memsocket` [`MemorySocket`]s rather than a
/// network protocol.
///
/// Addresses handled by this transport have the form `/memory/<port>`.
#[derive(Clone, Debug, Default)]
pub struct MemoryTransport;
impl MemoryTransport {
pub fn acquire_next_memsocket_port() -> NonZeroU16 {
memsocket::acquire_next_memsocket_port()
}
pub fn release_next_memsocket_port(port: NonZeroU16) {
memsocket::release_memsocket_port(port);
}
}
#[crate::async_trait]
impl Transport for MemoryTransport {
    type Error = io::Error;
    type Listener = Listener;
    type Output = MemorySocket;

    /// Binds a [`MemoryListener`] on the port named by `addr` (`/memory/<port>`).
    ///
    /// Returns the listener stream together with the address read back from the bound
    /// listener. That address may differ from the requested one, for example when port 0
    /// is requested.
    ///
    /// # Errors
    ///
    /// Fails if `addr` is not a valid memory multiaddr or if binding the port fails.
    async fn listen(&self, addr: &Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> {
        let requested = parse_addr(addr)?;
        let inner = MemoryListener::bind(requested)?;
        let bound = Multiaddr::from(Protocol::Memory(u64::from(inner.local_addr())));
        Ok((Listener { inner }, bound))
    }

    /// Connects a [`MemorySocket`] to the port named by `addr` (`/memory/<port>`).
    ///
    /// # Errors
    ///
    /// Fails if `addr` is not a valid memory multiaddr or if the connection attempt fails.
    async fn dial(&self, addr: &Multiaddr) -> Result<Self::Output, Self::Error> {
        let socket = MemorySocket::connect(parse_addr(addr)?)?;
        Ok(socket)
    }

    /// This transport only speaks the in-memory protocol.
    fn supported_protocols(&self) -> Vec<TransportProtocol> {
        vec![TransportProtocol::Memory]
    }
}
/// Extracts the port from a multiaddr of the exact form `/memory/<port>`.
///
/// # Errors
///
/// Returns an error of kind [`io::ErrorKind::InvalidInput`] if either:
/// - `addr` does not consist of exactly one `Memory` component, or
/// - the memory port does not fit in a `u16`.
fn parse_addr(addr: &Multiaddr) -> io::Result<u16> {
    let mut components = addr.iter();
    let port = match (components.next(), components.next()) {
        (Some(Protocol::Memory(port)), None) => port,
        _ => {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("Invalid Multiaddr '{addr:?}'"),
            ));
        },
    };
    // The multiaddr stores memory ports as `u64`, but memsockets are addressed by `u16`.
    // The address is caller input, so an out-of-range port is reported as an error
    // rather than panicking.
    u16::try_from(port).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            format!("Invalid Multiaddr '{addr:?}': memory port {port} exceeds {}", u16::MAX),
        )
    })
}
/// Stream of inbound connections produced by `MemoryTransport::listen`.
///
/// Each item is an accepted [`MemorySocket`] paired with an address for the dialer.
#[derive(Debug)]
#[must_use = "streams do nothing unless polled"]
pub struct Listener {
    /// The bound memsocket listener that connections are accepted from.
    inner: MemoryListener,
}
impl Stream for Listener {
type Item = io::Result<(MemorySocket, Multiaddr)>;
fn poll_next(mut self: Pin<&mut Self>, context: &mut Context) -> Poll<Option<Self::Item>> {
let mut incoming = self.inner.incoming();
match Pin::new(&mut incoming).poll_next(context) {
Poll::Ready(Some(Ok(socket))) => {
let dialer_addr = Protocol::Memory(0).into();
Poll::Ready(Some(Ok((socket, dialer_addr))))
},
Poll::Ready(Some(Err(e))) => Poll::Ready(Some(Err(e))),
Poll::Ready(None) => Poll::Ready(None),
Poll::Pending => Poll::Pending,
}
}
}
#[cfg(test)]
mod test {
    use futures::{future::join, stream::StreamExt};
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    use super::*;

    /// A listener bound via `/memory/0` accepts a dialed connection and reads every byte
    /// the dialer writes before it hangs up.
    #[tokio::test]
    async fn simple_listen_and_dial() -> Result<(), ::std::io::Error> {
        let transport = MemoryTransport;
        let (listener, bound_addr) = transport.listen(&"/memory/0".parse().unwrap()).await?;

        let accept_side = async move {
            let (next, _rest) = listener.into_future().await;
            let (mut inbound, _remote) = next.unwrap().unwrap();
            let mut received = Vec::new();
            inbound.read_to_end(&mut received).await.unwrap();
            assert_eq!(received, b"hello world");
        };

        let mut outbound = transport.dial(&bound_addr).await?;
        let dial_side = async move {
            outbound.write_all(b"hello world").await.unwrap();
            outbound.flush().await.unwrap();
        };

        join(dial_side, accept_side).await;
        Ok(())
    }

    /// Non-memory multiaddrs are rejected with `InvalidInput` by both `listen` and `dial`.
    #[tokio::test]
    async fn unsupported_multiaddrs() {
        let transport = MemoryTransport;

        let listen_err = transport
            .listen(&"/ip4/127.0.0.1/tcp/0".parse().unwrap())
            .await
            .unwrap_err();
        assert_eq!(listen_err.kind(), io::ErrorKind::InvalidInput);

        let dial_err = transport
            .dial(&"/ip4/127.0.0.1/tcp/22".parse().unwrap())
            .await
            .unwrap_err();
        assert_eq!(dial_err.kind(), io::ErrorKind::InvalidInput);
    }

    /// Two back-to-back acquisitions yield distinct ports, and both can be released.
    #[test]
    fn acquire_release_memsocket_port() {
        let first = MemoryTransport::acquire_next_memsocket_port();
        let second = MemoryTransport::acquire_next_memsocket_port();
        assert_ne!(first, second);
        MemoryTransport::release_next_memsocket_port(first);
        MemoryTransport::release_next_memsocket_port(second);
    }
}