use futures::{
future::{self, Either, FutureResult},
prelude::*,
stream::{self, Chain, IterOk, Once}
};
use get_if_addrs::{IfAddr, get_if_addrs};
use ipnet::{IpNet, Ipv4Net, Ipv6Net};
use libp2p_core::{
Transport,
multiaddr::{Protocol, Multiaddr},
transport::{ListenerEvent, TransportError}
};
use log::{debug, trace};
use std::{
collections::VecDeque,
io::{self, Read, Write},
iter::{self, FromIterator},
net::{IpAddr, SocketAddr},
time::{Duration, Instant},
vec::IntoIter
};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_timer::Delay;
use tokio_tcp::{ConnectFuture, Incoming, TcpStream};
/// Configuration of the TCP transport.
///
/// Holds the socket options applied to every dialed or accepted connection
/// (each `None` option is simply not set, leaving the socket as it is) and the
/// back-off used by listeners after an accept error.
#[derive(Debug, Clone)]
pub struct TcpConfig {
    /// How long a listener stops polling for new connections after an accept error.
    sleep_on_error: Duration,
    /// Value passed to `set_recv_buffer_size` on each socket, if any.
    recv_buffer_size: Option<usize>,
    /// Value passed to `set_send_buffer_size` on each socket, if any.
    send_buffer_size: Option<usize>,
    /// Value passed to `set_ttl` on each socket, if any.
    ttl: Option<u32>,
    /// Value passed to `set_keepalive` on each socket, if any. The inner
    /// `Option` is forwarded as-is, so `Some(None)` is distinct from "unset".
    keepalive: Option<Option<Duration>>,
    /// Value passed to `set_nodelay` on each socket, if any.
    nodelay: Option<bool>,
}

impl TcpConfig {
    /// Creates a configuration that sets no socket options and pauses
    /// listeners for 100ms after an accept error.
    pub fn new() -> TcpConfig {
        TcpConfig {
            sleep_on_error: Duration::from_millis(100),
            recv_buffer_size: None,
            send_buffer_size: None,
            ttl: None,
            keepalive: None,
            nodelay: None,
        }
    }

    /// Sets the receive buffer size applied to each socket.
    pub fn recv_buffer_size(mut self, value: usize) -> Self {
        self.recv_buffer_size = Some(value);
        self
    }

    /// Sets the send buffer size applied to each socket.
    pub fn send_buffer_size(mut self, value: usize) -> Self {
        self.send_buffer_size = Some(value);
        self
    }

    /// Sets the TTL applied to each socket.
    pub fn ttl(mut self, value: u32) -> Self {
        self.ttl = Some(value);
        self
    }

    /// Sets the keep-alive setting applied to each socket.
    pub fn keepalive(mut self, value: Option<Duration>) -> Self {
        self.keepalive = Some(value);
        self
    }

    /// Sets the no-delay flag applied to each socket.
    pub fn nodelay(mut self, value: bool) -> Self {
        self.nodelay = Some(value);
        self
    }
}

/// Same as [`TcpConfig::new`].
///
/// A derived `Default` would set `sleep_on_error` to zero, so listeners built
/// from `TcpConfig::default()` would never back off after accept errors.
impl Default for TcpConfig {
    fn default() -> Self {
        TcpConfig::new()
    }
}
impl Transport for TcpConfig {
    type Output = TcpTransStream;
    type Error = io::Error;
    type Listener = TcpListener;
    type ListenerUpgrade = FutureResult<Self::Output, Self::Error>;
    type Dial = TcpDialFut;

    /// Binds a TCP listener to `addr`, which must be exactly `/ip4/<ip>/tcp/<port>`
    /// or `/ip6/<ip>/tcp/<port>`.
    ///
    /// The returned stream first yields one `NewAddress` event per listen address,
    /// then the events of the underlying `TcpListenStream`. If the requested IP is
    /// unspecified, the listen addresses are the host's interface addresses with
    /// the bound port. Otherwise, it is the single locally bound address.
    ///
    /// NOTE(review): `host_addresses` returns interface addresses of both IP
    /// families, so a listener bound to e.g. `/ip4/0.0.0.0/...` also announces
    /// IPv6 addresses. Confirm whether that is intended.
    fn listen_on(self, addr: Multiaddr) -> Result<Self::Listener, TransportError<Self::Error>> {
        let socket_addr = match multiaddr_to_socketaddr(&addr) {
            Ok(sa) => sa,
            Err(()) => return Err(TransportError::MultiaddrNotSupported(addr)),
        };

        let listener = tokio_tcp::TcpListener::bind(&socket_addr).map_err(TransportError::Other)?;
        let local_addr = listener.local_addr().map_err(TransportError::Other)?;
        // The port actually bound; differs from the requested one when it was 0.
        let port = local_addr.port();

        let addrs = if !socket_addr.ip().is_unspecified() {
            let ma = ip_to_multiaddr(local_addr.ip(), port);
            debug!("Listening on {:?}", ma);
            Addresses::One(ma)
        } else {
            let host_addrs = host_addresses(port).map_err(TransportError::Other)?;
            debug!("Listening on {:?}", host_addrs.iter().map(|(_, _, ma)| ma).collect::<Vec<_>>());
            Addresses::Many(host_addrs)
        };

        // Initial `NewAddress` announcements, emitted before any connection event.
        let announcements = match &addrs {
            Addresses::One(ma) => {
                Either::A(stream::once(Ok(ListenerEvent::NewAddress(ma.clone()))))
            }
            Addresses::Many(list) => {
                let events: Vec<_> = list
                    .iter()
                    .map(|(_, _, ma)| ListenerEvent::NewAddress(ma.clone()))
                    .collect();
                Either::B(stream::iter_ok(events))
            }
        };

        let listen_stream = TcpListenStream {
            inner: Listener::new(listener.incoming(), self.sleep_on_error),
            port,
            addrs,
            pending: VecDeque::new(),
            config: self,
        };

        let inner = match announcements {
            Either::A(once) => Either::A(once.chain(listen_stream)),
            Either::B(many) => Either::B(many.chain(listen_stream)),
        };
        Ok(TcpListener { inner })
    }

    /// Starts connecting to `addr` (same format as accepted by `listen_on`).
    ///
    /// Addresses with port 0 or an unspecified IP are refused immediately with
    /// `ConnectionRefused`. The returned future applies this configuration's
    /// socket options once connected.
    fn dial(self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
        let socket_addr = match multiaddr_to_socketaddr(&addr) {
            Err(()) => return Err(TransportError::MultiaddrNotSupported(addr)),
            Ok(sa) => sa,
        };

        if socket_addr.port() == 0 || socket_addr.ip().is_unspecified() {
            debug!("Instantly refusing dialing {}, as it is invalid", addr);
            return Err(TransportError::Other(io::ErrorKind::ConnectionRefused.into()));
        }

        debug!("Dialing {}", addr);
        Ok(TcpDialFut {
            inner: TcpStream::connect(&socket_addr),
            config: self,
        })
    }
}
fn multiaddr_to_socketaddr(addr: &Multiaddr) -> Result<SocketAddr, ()> {
let mut iter = addr.iter();
let proto1 = iter.next().ok_or(())?;
let proto2 = iter.next().ok_or(())?;
if iter.next().is_some() {
return Err(());
}
match (proto1, proto2) {
(Protocol::Ip4(ip), Protocol::Tcp(port)) => Ok(SocketAddr::new(ip.into(), port)),
(Protocol::Ip6(ip), Protocol::Tcp(port)) => Ok(SocketAddr::new(ip.into(), port)),
_ => Err(()),
}
}
/// Builds the multiaddr `/ip4/<ip>/tcp/<port>` or `/ip6/<ip>/tcp/<port>`.
fn ip_to_multiaddr(ip: IpAddr, port: u16) -> Multiaddr {
    let host = match ip {
        IpAddr::V4(v4) => Protocol::Ip4(v4),
        IpAddr::V6(v6) => Protocol::Ip6(v6),
    };
    let tcp = Protocol::Tcp(port);
    Multiaddr::from_iter(iter::once(host).chain(iter::once(tcp)))
}
/// Lists every address of every interface reported by `get_if_addrs`.
///
/// Each entry is the address itself, its network (address plus netmask as a
/// prefix length), and the matching `/ip…/tcp/<port>` multiaddr.
///
/// # Errors
/// Propagates any error returned by `get_if_addrs`.
fn host_addresses(port: u16) -> io::Result<Vec<(IpAddr, IpNet, Multiaddr)>> {
    let interfaces = get_if_addrs()?;
    let entries = interfaces
        .into_iter()
        .map(|iface| {
            let ip = iface.ip();
            let net = match iface.addr {
                IfAddr::V4(v4) => {
                    // Inverting a contiguous netmask turns its leading one bits
                    // into leading zeros, so `leading_zeros` is the prefix length.
                    let bits = (!u32::from_be_bytes(v4.netmask.octets())).leading_zeros();
                    IpNet::V4(
                        Ipv4Net::new(v4.ip, bits as u8)
                            .expect("prefix_len is the number of bits in a u32, so can not exceed 32"),
                    )
                }
                IfAddr::V6(v6) => {
                    let bits = (!u128::from_be_bytes(v6.netmask.octets())).leading_zeros();
                    IpNet::V6(
                        Ipv6Net::new(v6.ip, bits as u8)
                            .expect("prefix_len is the number of bits in a u128, so can not exceed 128"),
                    )
                }
            };
            (ip, net, ip_to_multiaddr(ip, port))
        })
        .collect();
    Ok(entries)
}
/// Applies every socket option set in `config` to `socket`.
///
/// Options are applied in a fixed order (recv buffer, send buffer, TTL,
/// keep-alive, no-delay). The first failure is returned, and the remaining
/// options are not applied.
fn apply_config(config: &TcpConfig, socket: &TcpStream) -> Result<(), io::Error> {
    config.recv_buffer_size.map_or(Ok(()), |size| socket.set_recv_buffer_size(size))?;
    config.send_buffer_size.map_or(Ok(()), |size| socket.set_send_buffer_size(size))?;
    config.ttl.map_or(Ok(()), |ttl| socket.set_ttl(ttl))?;
    config.keepalive.map_or(Ok(()), |keepalive| socket.set_keepalive(keepalive))?;
    config.nodelay.map_or(Ok(()), |on| socket.set_nodelay(on))?;
    Ok(())
}
/// Future returned by `TcpConfig::dial`.
///
/// It resolves to a `TcpTransStream` once the connection is established and
/// the configured socket options have been applied.
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct TcpDialFut {
    inner: ConnectFuture,
    config: TcpConfig,
}

impl Future for TcpDialFut {
    type Item = TcpTransStream;
    type Error = io::Error;

    /// Drives the connect. On success it applies the socket options, failing
    /// the future if any of them cannot be set. Connect errors are logged and
    /// returned.
    fn poll(&mut self) -> Poll<TcpTransStream, io::Error> {
        let stream = match self.inner.poll() {
            Ok(Async::NotReady) => return Ok(Async::NotReady),
            Ok(Async::Ready(connected)) => connected,
            Err(err) => {
                debug!("Error while dialing => {:?}", err);
                return Err(err);
            }
        };
        apply_config(&self.config, &stream)?;
        Ok(Async::Ready(TcpTransStream { inner: stream }))
    }
}
/// Stream returned by `TcpConfig::listen_on`.
///
/// It yields the initial `NewAddress` announcements followed by the events of
/// a `TcpListenStream`. The `A` variant is used when a single address is
/// announced; the `B` variant is used when there is one per host interface
/// address.
#[derive(Debug)]
pub struct TcpListener {
    inner: Either<
        Chain<Once<ListenerEvent<FutureResult<TcpTransStream, io::Error>>, io::Error>, TcpListenStream>,
        Chain<IterOk<IntoIter<ListenerEvent<FutureResult<TcpTransStream, io::Error>>>, io::Error>, TcpListenStream>
    >
}

impl Stream for TcpListener {
    type Item = ListenerEvent<FutureResult<TcpTransStream, io::Error>>;
    type Error = io::Error;

    /// Forwards to whichever chained stream is held.
    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        match &mut self.inner {
            Either::A(single) => single.poll(),
            Either::B(many) => many.poll(),
        }
    }
}
/// The listen address(es) tracked by a `TcpListenStream`.
#[derive(Debug)]
enum Addresses {
    /// The listener was bound to a specific IP; this is its bound address.
    One(Multiaddr),
    /// The listener was bound to an unspecified IP: one `(ip, network, multiaddr)`
    /// entry per host interface address. `check_for_interface_changes` may
    /// refresh these entries as connections are accepted.
    Many(Vec<(IpAddr, IpNet, Multiaddr)>)
}
/// Queue of listener events waiting to be returned by `TcpListenStream::poll`.
type Buffer = VecDeque<ListenerEvent<FutureResult<TcpTransStream, io::Error>>>;
/// Wraps a stream of incoming connections and adds a back-off after errors.
///
/// When the inner stream returns an error, the error is logged and returned,
/// and the inner stream is not polled again until `pause_duration` has elapsed.
#[derive(Debug)]
struct Listener<S> {
    stream: S,
    pause: Option<Delay>,
    pause_duration: Duration
}

impl<S> Listener<S>
where
    S: Stream,
    S::Error: std::fmt::Display
{
    /// Wraps `stream`, pausing for `duration` after each of its errors.
    fn new(stream: S, duration: Duration) -> Self {
        Listener {
            pause_duration: duration,
            pause: None,
            stream,
        }
    }
}

impl<S> Stream for Listener<S>
where
    S: Stream,
    S::Error: std::fmt::Display
{
    type Item = S::Item;
    type Error = S::Error;

    fn poll(&mut self) -> Poll<Option<S::Item>, S::Error> {
        if let Some(delay) = self.pause.as_mut() {
            if let Ok(Async::NotReady) = delay.poll() {
                return Ok(Async::NotReady);
            }
            // The pause elapsed (or the timer itself failed): resume polling.
            self.pause = None;
        }

        let polled = self.stream.poll();
        if let Err(ref e) = polled {
            debug!("error accepting incoming connection: {}", e);
            self.pause = Some(Delay::new(Instant::now() + self.pause_duration));
        }
        polled
    }
}
/// Stream of incoming connections, and of listen-address changes, for a bound
/// TCP listener.
#[derive(Debug)]
pub struct TcpListenStream {
    /// Accepted sockets, with a pause after accept errors.
    inner: Listener<Incoming>,
    /// The port the listener is actually bound to (taken from its local address).
    port: u16,
    /// The listen addresses currently known to this listener.
    addrs: Addresses,
    /// Events queued for delivery before the next socket is accepted.
    pending: Buffer,
    /// Socket options applied to each accepted connection.
    config: TcpConfig
}
/// Reconciles `listen_addrs` with the host's interfaces when an accepted socket
/// has a local IP we do not know about.
///
/// If the local IP of `socket_addr` equals one of `listen_addrs`, or lies in one
/// of their networks, nothing happens. Otherwise the host's interface addresses
/// are re-read for `listen_port` and replace `listen_addrs`. An `AddressExpired`
/// event is queued in `pending` for every IP that disappeared, and a `NewAddress`
/// event for every IP that appeared. Addresses are compared by IP only.
///
/// # Errors
/// Returns an error if the interface addresses cannot be read, or if the local
/// IP still matches no listen address after the refresh.
fn check_for_interface_changes(
    socket_addr: &SocketAddr,
    listen_port: u16,
    listen_addrs: &mut Vec<(IpAddr, IpNet, Multiaddr)>,
    pending: &mut Buffer
) -> Result<(), io::Error> {
    /// True if `ip` is one of `addrs` exactly or is contained in one of their networks.
    fn covers(addrs: &[(IpAddr, IpNet, Multiaddr)], ip: &IpAddr) -> bool {
        addrs.iter().any(|(known, net, _)| known == ip || net.contains(ip))
    }

    let local_ip = socket_addr.ip();
    if covers(listen_addrs, &local_ip) {
        return Ok(())
    }

    // The local IP is new to us: the set of host addresses may have changed.
    // `host_addresses` is evaluated first, so on error `listen_addrs` is untouched.
    let refreshed = host_addresses(listen_port)?;
    let old_listen_addrs = std::mem::replace(listen_addrs, refreshed);

    for (ip, _, ma) in &old_listen_addrs {
        if !listen_addrs.iter().any(|(i, ..)| i == ip) {
            debug!("Expired listen address: {}", ma);
            pending.push_back(ListenerEvent::AddressExpired(ma.clone()));
        }
    }

    for (ip, _, ma) in listen_addrs.iter() {
        if !old_listen_addrs.iter().any(|(i, ..)| i == ip) {
            debug!("New listen address: {}", ma);
            pending.push_back(ListenerEvent::NewAddress(ma.clone()));
        }
    }

    // After the refresh the local IP must be known; if not, something is wrong.
    if !covers(listen_addrs, &local_ip) {
        let msg = format!("{} does not match any listen address", local_ip);
        return Err(io::Error::new(io::ErrorKind::Other, msg))
    }

    Ok(())
}
impl Stream for TcpListenStream {
    type Item = ListenerEvent<FutureResult<TcpTransStream, io::Error>>;
    type Error = io::Error;

    /// Returns queued events first. Otherwise it accepts the next socket and
    /// queues a `ListenerEvent::Upgrade` for it, carrying an already-resolved
    /// future: `ok` if the socket options were applied, `err` otherwise.
    ///
    /// Sockets whose peer or local address cannot be read are logged and
    /// dropped. When listening on an unspecified IP, each socket's local
    /// address goes through `check_for_interface_changes`. That call may queue
    /// address events, or its error is returned from this poll.
    fn poll(&mut self) -> Poll<Option<Self::Item>, io::Error> {
        loop {
            if let Some(queued) = self.pending.pop_front() {
                return Ok(Async::Ready(Some(queued)))
            }

            let sock = match self.inner.poll()? {
                Async::Ready(Some(accepted)) => accepted,
                Async::Ready(None) => return Ok(Async::Ready(None)),
                Async::NotReady => return Ok(Async::NotReady),
            };

            let peer = match sock.peer_addr() {
                Ok(addr) => addr,
                Err(err) => {
                    debug!("Failed to get peer address: {:?}", err);
                    continue
                }
            };

            let local = match sock.local_addr() {
                Ok(addr) => addr,
                Err(err) => {
                    debug!("Failed to get local address of incoming socket: {:?}", err);
                    continue
                }
            };

            if let Addresses::Many(ref mut known) = self.addrs {
                check_for_interface_changes(&local, self.port, known, &mut self.pending)?;
            }

            let local_addr = ip_to_multiaddr(local.ip(), local.port());
            let remote_addr = ip_to_multiaddr(peer.ip(), peer.port());

            let upgrade = match apply_config(&self.config, &sock) {
                Ok(()) => {
                    trace!("Incoming connection from {} at {}", remote_addr, local_addr);
                    future::ok(TcpTransStream { inner: sock })
                }
                Err(err) => {
                    debug!("Error upgrading incoming connection from {}: {:?}", remote_addr, err);
                    future::err(err)
                }
            };

            self.pending.push_back(ListenerEvent::Upgrade {
                upgrade,
                local_addr,
                remote_addr,
            });
        }
    }
}
/// An established TCP connection produced by dialing or accepting.
///
/// All I/O is delegated to the wrapped `TcpStream`. Dropping it logs the peer
/// address at debug level.
#[derive(Debug)]
pub struct TcpTransStream {
    inner: TcpStream,
}

impl Read for TcpTransStream {
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
        <TcpStream as Read>::read(&mut self.inner, buf)
    }
}

impl AsyncRead for TcpTransStream {
    unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
        // SAFETY: pure delegation. The caller's contract for this method is
        // forwarded unchanged to the inner stream's implementation.
        <TcpStream as AsyncRead>::prepare_uninitialized_buffer(&self.inner, buf)
    }

    fn read_buf<B: bytes::BufMut>(&mut self, buf: &mut B) -> Poll<usize, io::Error> {
        <TcpStream as AsyncRead>::read_buf(&mut self.inner, buf)
    }
}

impl Write for TcpTransStream {
    fn write(&mut self, buf: &[u8]) -> Result<usize, io::Error> {
        <TcpStream as Write>::write(&mut self.inner, buf)
    }

    fn flush(&mut self) -> Result<(), io::Error> {
        <TcpStream as Write>::flush(&mut self.inner)
    }
}

impl AsyncWrite for TcpTransStream {
    fn shutdown(&mut self) -> Poll<(), io::Error> {
        <TcpStream as AsyncWrite>::shutdown(&mut self.inner)
    }
}

impl Drop for TcpTransStream {
    fn drop(&mut self) {
        match self.inner.peer_addr() {
            Ok(addr) => {
                debug!("Dropped TCP connection to {:?}", addr);
            }
            Err(_) => {
                debug!("Dropped TCP connection to undeterminate peer");
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use futures::{prelude::*, future::{self, Loop}, stream};
    use libp2p_core::{Transport, multiaddr::{Multiaddr, Protocol}, transport::ListenerEvent};
    use std::{net::{IpAddr, Ipv4Addr, SocketAddr}, time::Duration};
    use super::{multiaddr_to_socketaddr, TcpConfig, Listener};
    use tokio::runtime::current_thread::{self, Runtime};
    use tokio_io;

    /// `Listener` keeps yielding after errors: the items `Ok(1), Err(1), Ok(1), Err(1)`
    /// are all consumed and summed (errors included) to 4.
    #[test]
    fn pause_on_error() {
        let rs = stream::iter_result(vec![Ok(1), Err(1), Ok(1), Err(1)]);
        // Each error pauses the listener for this duration, so the test takes
        // roughly two seconds.
        let ls = Listener::new(rs, Duration::from_secs(1));
        let sum = future::loop_fn((0, ls), |(acc, ls)| {
            ls.into_future().then(move |item| {
                match item {
                    Ok((None, _)) => Ok::<_, std::convert::Infallible>(Loop::Break(acc)),
                    Ok((Some(n), rest)) => Ok(Loop::Continue((acc + n, rest))),
                    // `into_future` hands the stream back on error, so keep going.
                    Err((n, rest)) => Ok(Loop::Continue((acc + n, rest)))
                }
            })
        });
        assert_eq!(4, current_thread::block_on_all(sum).unwrap())
    }

    /// Listening on the wildcard IP reports concrete (non-unspecified) addresses
    /// with a non-zero port, and one of them accepts a dial.
    #[test]
    fn wildcard_expansion() {
        let mut listener = TcpConfig::new()
            .listen_on("/ip4/0.0.0.0/tcp/0".parse().unwrap())
            .expect("listener");
        // First announced address; used as the dial target below.
        // NOTE(review): this is whichever interface address comes first, which
        // may not be reachable for an IPv4 wildcard bind; confirm this isn't flaky.
        let addr = listener.by_ref()
            .wait()
            .next()
            .expect("some event")
            .expect("no error")
            .into_new_address()
            .expect("listen address");
        // Validate the remaining `NewAddress` events and stop at the first other
        // event (the incoming connection from the client below).
        let server = listener
            .take_while(|event| match event {
                ListenerEvent::NewAddress(a) => {
                    let mut iter = a.iter();
                    match iter.next().expect("ip address") {
                        Protocol::Ip4(ip) => assert!(!ip.is_unspecified()),
                        Protocol::Ip6(ip) => assert!(!ip.is_unspecified()),
                        other => panic!("Unexpected protocol: {}", other)
                    }
                    if let Protocol::Tcp(port) = iter.next().expect("port") {
                        assert_ne!(0, port)
                    } else {
                        panic!("No TCP port in address: {}", a)
                    }
                    Ok(true)
                }
                _ => Ok(false)
            })
            .for_each(|_| Ok(()));
        let client = TcpConfig::new().dial(addr).expect("dialer");
        tokio::run(server.join(client).map(|_| ()).map_err(|e| panic!("error: {}", e)))
    }

    /// `multiaddr_to_socketaddr` accepts `/ip4|ip6/.../tcp/...` and rejects other protocols.
    #[test]
    fn multiaddr_to_tcp_conversion() {
        use std::net::Ipv6Addr;
        assert!(
            multiaddr_to_socketaddr(&"/ip4/127.0.0.1/udp/1234".parse::<Multiaddr>().unwrap())
                .is_err()
        );
        assert_eq!(
            multiaddr_to_socketaddr(&"/ip4/127.0.0.1/tcp/12345".parse::<Multiaddr>().unwrap()),
            Ok(SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
                12345,
            ))
        );
        assert_eq!(
            multiaddr_to_socketaddr(
                &"/ip4/255.255.255.255/tcp/8080"
                    .parse::<Multiaddr>()
                    .unwrap()
            ),
            Ok(SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255)),
                8080,
            ))
        );
        assert_eq!(
            multiaddr_to_socketaddr(&"/ip6/::1/tcp/12345".parse::<Multiaddr>().unwrap()),
            Ok(SocketAddr::new(
                IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
                12345,
            ))
        );
        assert_eq!(
            multiaddr_to_socketaddr(
                &"/ip6/ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/tcp/8080"
                    .parse::<Multiaddr>()
                    .unwrap()
            ),
            Ok(SocketAddr::new(
                IpAddr::V6(Ipv6Addr::new(
                    65535, 65535, 65535, 65535, 65535, 65535, 65535, 65535,
                )),
                8080,
            ))
        );
    }

    /// Bytes written by a dialer are read by the listener.
    ///
    /// NOTE(review): this uses a fixed port (12345) and a 100ms sleep to order the
    /// listener before the dialer, and the listener thread is never joined;
    /// both may make the test flaky on busy machines.
    #[test]
    fn communicating_between_dialer_and_listener() {
        use std::io::Write;
        std::thread::spawn(move || {
            let addr = "/ip4/127.0.0.1/tcp/12345".parse::<Multiaddr>().unwrap();
            let tcp = TcpConfig::new();
            let mut rt = Runtime::new().unwrap();
            let handle = rt.handle();
            let listener = tcp.listen_on(addr).unwrap()
                .filter_map(ListenerEvent::into_upgrade)
                .for_each(|(sock, _)| {
                    sock.and_then(|sock| {
                        // Each accepted connection must deliver exactly [1, 2, 3].
                        let handle_conn = tokio_io::io::read_exact(sock, [0; 3])
                            .map(|(_, buf)| assert_eq!(buf, [1, 2, 3]))
                            .map_err(|err| panic!("IO error {:?}", err));
                        handle.spawn(handle_conn).unwrap();
                        Ok(())
                    })
                });
            rt.block_on(listener).unwrap();
            rt.run().unwrap();
        });
        // Give the listener thread time to bind before dialing.
        std::thread::sleep(std::time::Duration::from_millis(100));
        let addr = "/ip4/127.0.0.1/tcp/12345".parse::<Multiaddr>().unwrap();
        let tcp = TcpConfig::new();
        let socket = tcp.dial(addr.clone()).unwrap();
        let action = socket.then(|sock| -> Result<(), ()> {
            sock.unwrap().write(&[0x1, 0x2, 0x3]).unwrap();
            Ok(())
        });
        let mut rt = Runtime::new().unwrap();
        let _ = rt.block_on(action).unwrap();
    }

    /// Listening on port 0 (IPv4) reports the actually bound, non-zero port.
    #[test]
    fn replace_port_0_in_returned_multiaddr_ipv4() {
        let tcp = TcpConfig::new();
        let addr = "/ip4/127.0.0.1/tcp/0".parse::<Multiaddr>().unwrap();
        assert!(addr.to_string().contains("tcp/0"));
        let new_addr = tcp.listen_on(addr).unwrap().wait()
            .next()
            .expect("some event")
            .expect("no error")
            .into_new_address()
            .expect("listen address");
        assert!(!new_addr.to_string().contains("tcp/0"));
    }

    /// Listening on port 0 (IPv6) reports the actually bound, non-zero port.
    #[test]
    fn replace_port_0_in_returned_multiaddr_ipv6() {
        let tcp = TcpConfig::new();
        let addr: Multiaddr = "/ip6/::1/tcp/0".parse().unwrap();
        assert!(addr.to_string().contains("tcp/0"));
        let new_addr = tcp.listen_on(addr).unwrap().wait()
            .next()
            .expect("some event")
            .expect("no error")
            .into_new_address()
            .expect("listen address");
        assert!(!new_addr.to_string().contains("tcp/0"));
    }

    /// Multiaddrs with more than `/ip/tcp` components are rejected by `listen_on`.
    #[test]
    fn larger_addr_denied() {
        let tcp = TcpConfig::new();
        let addr = "/ip4/127.0.0.1/tcp/12345/tcp/12345"
            .parse::<Multiaddr>()
            .unwrap();
        assert!(tcp.listen_on(addr).is_err());
    }
}