use futures::{future, future::FutureResult, prelude::*, Async, Poll};
use libp2p_core as swarm;
use log::{debug, error};
use multiaddr::{Protocol, Multiaddr, ToMultiaddr};
use std::fmt;
use std::io::{self, Read, Write};
use std::net::SocketAddr;
use std::time::Duration;
use swarm::{Transport, transport::TransportError};
use tk_listen::{ListenExt, SleepOnError};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_tcp::{ConnectFuture, Incoming, TcpListener, TcpStream};
/// Configuration for the TCP/IP transport.
///
/// Every `Option` field left as `None` means "do not touch that socket
/// option"; the values are applied to each socket by `apply_config`.
#[derive(Debug, Clone)]
pub struct TcpConfig {
    /// Duration handed to the `sleep_on_error` adaptor that wraps the
    /// listener's stream of incoming connections.
    sleep_on_error: Duration,
    /// Receive buffer size to set on sockets, if any.
    recv_buffer_size: Option<usize>,
    /// Send buffer size to set on sockets, if any.
    send_buffer_size: Option<usize>,
    /// TTL to set on sockets, if any.
    ttl: Option<u32>,
    /// Keep-alive setting. The outer `Option` says whether to touch the socket
    /// at all; the inner value is what gets passed to `set_keepalive`.
    keepalive: Option<Option<Duration>>,
    /// `TCP_NODELAY` value to set on sockets, if any.
    nodelay: Option<bool>,
}

impl Default for TcpConfig {
    /// Returns the same configuration as `TcpConfig::new()`.
    ///
    /// A derived `Default` would set `sleep_on_error` to a zero duration,
    /// which silently disagrees with `new()` (100 ms), so the values are
    /// spelled out here instead.
    fn default() -> Self {
        TcpConfig {
            sleep_on_error: Duration::from_millis(100),
            recv_buffer_size: None,
            send_buffer_size: None,
            ttl: None,
            keepalive: None,
            nodelay: None,
        }
    }
}
impl TcpConfig {
    /// Creates a configuration with a 100 ms `sleep_on_error` and no socket
    /// options set.
    #[inline]
    pub fn new() -> TcpConfig {
        let sleep_on_error = Duration::from_millis(100);
        TcpConfig {
            sleep_on_error,
            nodelay: None,
            keepalive: None,
            ttl: None,
            send_buffer_size: None,
            recv_buffer_size: None,
        }
    }

    /// Sets the receive buffer size to apply to sockets.
    #[inline]
    pub fn recv_buffer_size(self, value: usize) -> Self {
        TcpConfig {
            recv_buffer_size: Some(value),
            ..self
        }
    }

    /// Sets the send buffer size to apply to sockets.
    #[inline]
    pub fn send_buffer_size(self, value: usize) -> Self {
        TcpConfig {
            send_buffer_size: Some(value),
            ..self
        }
    }

    /// Sets the TTL to apply to sockets.
    #[inline]
    pub fn ttl(self, value: u32) -> Self {
        TcpConfig {
            ttl: Some(value),
            ..self
        }
    }

    /// Sets the keep-alive value to apply to sockets. `value` is stored
    /// as-is (including `None`) and later passed to `set_keepalive`.
    #[inline]
    pub fn keepalive(self, value: Option<Duration>) -> Self {
        TcpConfig {
            keepalive: Some(value),
            ..self
        }
    }

    /// Sets the `TCP_NODELAY` value to apply to sockets.
    #[inline]
    pub fn nodelay(self, value: bool) -> Self {
        TcpConfig {
            nodelay: Some(value),
            ..self
        }
    }
}
impl Transport for TcpConfig {
    type Output = TcpTransStream;
    type Error = io::Error;
    type Listener = TcpListenStream;
    type ListenerUpgrade = FutureResult<Self::Output, io::Error>;
    type Dial = TcpDialFut;

    /// Binds a TCP listener on `addr`.
    ///
    /// Returns the stream of incoming connections together with the address
    /// actually bound (the listener's local address when it can be queried,
    /// otherwise `addr` unchanged).
    ///
    /// # Errors
    ///
    /// * `TransportError::MultiaddrNotSupported` if `addr` is not exactly
    ///   `/ip4|ip6/.../tcp/...`.
    /// * `TransportError::Other` if binding the socket fails.
    fn listen_on(self, addr: Multiaddr) -> Result<(Self::Listener, Multiaddr), TransportError<Self::Error>> {
        let socket_addr = match multiaddr_to_socketaddr(&addr) {
            Ok(socket_addr) => socket_addr,
            Err(()) => return Err(TransportError::MultiaddrNotSupported(addr)),
        };

        // Propagate a bind failure *before* logging anything, so that
        // "Now listening on ..." is only ever emitted for a live listener.
        let listener = TcpListener::bind(&socket_addr).map_err(TransportError::Other)?;

        // Report the address really bound (relevant when port 0 was asked for).
        let new_addr = match listener.local_addr() {
            Ok(local) => local
                .to_multiaddr()
                .expect("multiaddr generated from socket addr is always valid"),
            Err(_) => addr,
        };
        debug!("Now listening on {}", new_addr);

        let inner = listener.incoming().sleep_on_error(self.sleep_on_error);
        Ok((
            TcpListenStream {
                inner: Ok(inner),
                config: self,
            },
            new_addr,
        ))
    }

    /// Starts connecting to `addr`.
    ///
    /// # Errors
    ///
    /// * `TransportError::MultiaddrNotSupported` if `addr` is not exactly
    ///   `/ip4|ip6/.../tcp/...`.
    /// * `TransportError::Other` with `ConnectionRefused` if the port is `0`
    ///   or the IP is the unspecified address; no connection is attempted.
    fn dial(self, addr: Multiaddr) -> Result<Self::Dial, TransportError<Self::Error>> {
        let socket_addr = match multiaddr_to_socketaddr(&addr) {
            Ok(socket_addr) => socket_addr,
            Err(()) => return Err(TransportError::MultiaddrNotSupported(addr)),
        };

        if socket_addr.port() == 0 || socket_addr.ip().is_unspecified() {
            debug!("Instantly refusing dialing {}, as it is invalid", addr);
            return Err(TransportError::Other(io::ErrorKind::ConnectionRefused.into()));
        }

        debug!("Dialing {}", addr);
        Ok(TcpDialFut {
            inner: TcpStream::connect(&socket_addr),
            config: self,
        })
    }

    /// Builds an address made of the first (IP) component of `observed`
    /// followed by every component of `server` after its first one.
    ///
    /// Returns `None` unless both addresses start with an `Ip4` or `Ip6`
    /// component.
    fn nat_traversal(&self, server: &Multiaddr, observed: &Multiaddr) -> Option<Multiaddr> {
        let mut address = Multiaddr::empty();

        // Only the leading component of each address is inspected here.
        match server.iter().zip(observed.iter()).next() {
            Some((Protocol::Ip4(_), x @ Protocol::Ip4(_))) => address.append(x),
            Some((Protocol::Ip6(_), x @ Protocol::Ip4(_))) => address.append(x),
            Some((Protocol::Ip4(_), x @ Protocol::Ip6(_))) => address.append(x),
            Some((Protocol::Ip6(_), x @ Protocol::Ip6(_))) => address.append(x),
            _ => return None,
        }

        // Keep the rest of the server address (e.g. its TCP port) untouched.
        for proto in server.iter().skip(1) {
            address.append(proto)
        }

        Some(address)
    }
}
/// Converts a multiaddr of the exact form `/ip4/<ip>/tcp/<port>` or
/// `/ip6/<ip>/tcp/<port>` into a `SocketAddr`.
///
/// Any other shape (too few components, extra trailing components, or other
/// protocols) yields `Err(())`.
fn multiaddr_to_socketaddr(addr: &Multiaddr) -> Result<SocketAddr, ()> {
    let mut components = addr.iter();

    // Pull the first two components plus one more: the third must be absent.
    match (components.next(), components.next(), components.next()) {
        (Some(Protocol::Ip4(ip)), Some(Protocol::Tcp(port)), None) => {
            Ok(SocketAddr::new(ip.into(), port))
        }
        (Some(Protocol::Ip6(ip)), Some(Protocol::Tcp(port)), None) => {
            Ok(SocketAddr::new(ip.into(), port))
        }
        _ => Err(()),
    }
}
/// Applies every option that is set in `config` to `socket`.
///
/// Options are applied in the order: receive buffer size, send buffer size,
/// TTL, keep-alive, no-delay. Options left as `None` are skipped. The first
/// failing setter aborts the function and its error is returned.
fn apply_config(config: &TcpConfig, socket: &TcpStream) -> Result<(), io::Error> {
    config
        .recv_buffer_size
        .map_or(Ok(()), |size| socket.set_recv_buffer_size(size))?;
    config
        .send_buffer_size
        .map_or(Ok(()), |size| socket.set_send_buffer_size(size))?;
    config.ttl.map_or(Ok(()), |hops| socket.set_ttl(hops))?;
    config
        .keepalive
        .map_or(Ok(()), |interval| socket.set_keepalive(interval))?;
    config
        .nodelay
        .map_or(Ok(()), |enabled| socket.set_nodelay(enabled))?;
    Ok(())
}
/// Future returned by `TcpConfig::dial`.
///
/// Resolves to a `TcpTransStream` once the underlying connect future has
/// completed and the socket options from `config` have been applied.
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct TcpDialFut {
// The in-progress connection attempt created by `TcpStream::connect`.
inner: ConnectFuture,
// Socket options to apply to the stream once it is connected.
config: TcpConfig,
}
impl Future for TcpDialFut {
    type Item = TcpTransStream;
    type Error = io::Error;

    /// Drives the connection attempt.
    ///
    /// Once connected, the configured socket options are applied and the
    /// stream is returned wrapped in a `TcpTransStream`. A connect error is
    /// logged at debug level and returned; an error from applying the
    /// options is returned as-is.
    fn poll(&mut self) -> Poll<TcpTransStream, io::Error> {
        let stream = match self.inner.poll() {
            Ok(Async::Ready(connected)) => connected,
            Ok(Async::NotReady) => return Ok(Async::NotReady),
            Err(err) => {
                debug!("Error while dialing => {:?}", err);
                return Err(err);
            }
        };

        apply_config(&self.config, &stream)?;
        Ok(Async::Ready(TcpTransStream { inner: stream }))
    }
}
/// Stream of incoming connections returned by `TcpConfig::listen_on`.
pub struct TcpListenStream {
// `Ok` holds the incoming-connection stream wrapped by `sleep_on_error`.
// `Err(Some(e))` holds an error that the next `poll` returns; `Err(None)`
// means that error was already taken, and polling again panics.
inner: Result<SleepOnError<Incoming>, Option<io::Error>>,
// Socket options applied to every accepted connection.
config: TcpConfig,
}
impl Stream for TcpListenStream {
type Item = (FutureResult<TcpTransStream, io::Error>, Multiaddr);
type Error = io::Error;
fn poll(
&mut self,
) -> Poll<
Option<(FutureResult<TcpTransStream, io::Error>, Multiaddr)>,
io::Error,
> {
let inner = match self.inner {
Ok(ref mut inc) => inc,
Err(ref mut err) => {
return Err(err.take().expect("poll called again after error"));
}
};
loop {
match inner.poll() {
Ok(Async::Ready(Some(sock))) => {
let addr = match sock.peer_addr() {
Ok(addr) => addr
.to_multiaddr()
.expect("generating a multiaddr from a socket addr never fails"),
Err(err) => {
error!("Ignored incoming because could't determine its \
address: {:?}", err);
continue
},
};
match apply_config(&self.config, &sock) {
Ok(()) => (),
Err(err) => return Ok(Async::Ready(Some((future::err(err), addr)))),
};
debug!("Incoming connection from {}", addr);
let ret = future::ok(TcpTransStream { inner: sock });
break Ok(Async::Ready(Some((ret, addr))))
}
Ok(Async::Ready(None)) => break Ok(Async::Ready(None)),
Ok(Async::NotReady) => break Ok(Async::NotReady),
Err(()) => unreachable!("sleep_on_error never produces an error"),
}
}
}
}
impl fmt::Debug for TcpListenStream {
    /// Prints `TcpListenStream` for a live listener, `TcpListenStream(Errored)`
    /// once the stored error has been taken, and `TcpListenStream(<error>)`
    /// while an error is still stored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let stored = match self.inner {
            Ok(_) => return write!(f, "TcpListenStream"),
            Err(ref stored) => stored,
        };

        if let Some(ref err) = *stored {
            write!(f, "TcpListenStream({:?})", err)
        } else {
            write!(f, "TcpListenStream(Errored)")
        }
    }
}
/// Wrapper around a `TcpStream` produced by dialing or listening.
///
/// Reads, writes and shutdown are forwarded to the inner stream, and a debug
/// message is logged when the wrapper is dropped.
#[derive(Debug)]
pub struct TcpTransStream {
// The wrapped TCP stream that all I/O is forwarded to.
inner: TcpStream,
}
impl Read for TcpTransStream {
    /// Reads into `buf` from the wrapped stream and returns its result
    /// unchanged.
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        Read::read(&mut self.inner, buf)
    }
}
// No methods are overridden: the trait's default implementations are used.
impl AsyncRead for TcpTransStream {}
impl Write for TcpTransStream {
    /// Writes `buf` to the wrapped stream and returns its result unchanged.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        Write::write(&mut self.inner, buf)
    }

    /// Flushes the wrapped stream.
    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        Write::flush(&mut self.inner)
    }
}
impl AsyncWrite for TcpTransStream {
#[inline]
fn shutdown(&mut self) -> Poll<(), io::Error> {
AsyncWrite::shutdown(&mut self.inner)
}
}
impl Drop for TcpTransStream {
    /// Logs at debug level that the connection was dropped, including the
    /// peer address when it can still be queried.
    #[inline]
    fn drop(&mut self) {
        match self.inner.peer_addr() {
            Ok(addr) => debug!("Dropped TCP connection to {:?}", addr),
            Err(_) => debug!("Dropped TCP connection to undeterminate peer"),
        }
    }
}
// Unit tests for the TCP transport: multiaddr conversion, a dial/listen
// round trip, port-0 replacement, address validation and NAT traversal.
#[cfg(test)]
mod tests {
use tokio::runtime::current_thread::Runtime;
use super::{multiaddr_to_socketaddr, TcpConfig};
use futures::stream::Stream;
use futures::Future;
use multiaddr::Multiaddr;
use std;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use super::swarm::Transport;
use tokio_io;
// A UDP multiaddr must be rejected; IPv4 and IPv6 TCP multiaddrs must convert
// to the matching `SocketAddr`.
#[test]
fn multiaddr_to_tcp_conversion() {
use std::net::Ipv6Addr;
assert!(
multiaddr_to_socketaddr(&"/ip4/127.0.0.1/udp/1234".parse::<Multiaddr>().unwrap())
.is_err()
);
assert_eq!(
multiaddr_to_socketaddr(&"/ip4/127.0.0.1/tcp/12345".parse::<Multiaddr>().unwrap()),
Ok(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
12345,
))
);
assert_eq!(
multiaddr_to_socketaddr(
&"/ip4/255.255.255.255/tcp/8080"
.parse::<Multiaddr>()
.unwrap()
),
Ok(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(255, 255, 255, 255)),
8080,
))
);
assert_eq!(
multiaddr_to_socketaddr(&"/ip6/::1/tcp/12345".parse::<Multiaddr>().unwrap()),
Ok(SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
12345,
))
);
assert_eq!(
multiaddr_to_socketaddr(
&"/ip6/ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/tcp/8080"
.parse::<Multiaddr>()
.unwrap()
),
Ok(SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(
65535, 65535, 65535, 65535, 65535, 65535, 65535, 65535,
)),
8080,
))
);
}
// Round trip: a background thread listens on 127.0.0.1:12345 and asserts that
// each accepted connection sends the bytes [1, 2, 3]; the main thread dials
// that address and writes those bytes.
// NOTE(review): relies on the fixed port 12345 being free and on the 100 ms
// sleep below, which presumably gives the listener time to bind - confirm
// this is not a source of flakiness.
#[test]
fn communicating_between_dialer_and_listener() {
use std::io::Write;
std::thread::spawn(move || {
let addr = "/ip4/127.0.0.1/tcp/12345".parse::<Multiaddr>().unwrap();
let tcp = TcpConfig::new();
let mut rt = Runtime::new().unwrap();
let handle = rt.handle();
let listener = tcp.listen_on(addr).unwrap().0.for_each(|(sock, _)| {
sock.and_then(|sock| {
let handle_conn = tokio_io::io::read_exact(sock, [0; 3])
.map(|(_, buf)| assert_eq!(buf, [1, 2, 3]))
.map_err(|err| panic!("IO error {:?}", err));
handle.spawn(handle_conn).unwrap();
Ok(())
})
});
rt.block_on(listener).unwrap();
rt.run().unwrap();
});
std::thread::sleep(std::time::Duration::from_millis(100));
let addr = "/ip4/127.0.0.1/tcp/12345".parse::<Multiaddr>().unwrap();
let tcp = TcpConfig::new();
let socket = tcp.dial(addr.clone()).unwrap();
let action = socket.then(|sock| -> Result<(), ()> {
sock.unwrap().write(&[0x1, 0x2, 0x3]).unwrap();
Ok(())
});
let mut rt = Runtime::new().unwrap();
let _ = rt.block_on(action).unwrap();
}
// Listening on IPv4 port 0 must return an address that no longer contains
// "tcp/0".
#[test]
fn replace_port_0_in_returned_multiaddr_ipv4() {
let tcp = TcpConfig::new();
let addr = "/ip4/127.0.0.1/tcp/0".parse::<Multiaddr>().unwrap();
assert!(addr.to_string().contains("tcp/0"));
let (_, new_addr) = tcp.listen_on(addr).unwrap();
assert!(!new_addr.to_string().contains("tcp/0"));
}
// Same check as above for an IPv6 loopback address.
#[test]
fn replace_port_0_in_returned_multiaddr_ipv6() {
let tcp = TcpConfig::new();
let addr: Multiaddr = "/ip6/::1/tcp/0".parse().unwrap();
assert!(addr.to_string().contains("tcp/0"));
let (_, new_addr) = tcp.listen_on(addr).unwrap();
assert!(!new_addr.to_string().contains("tcp/0"));
}
// A multiaddr with components after `/ip4/.../tcp/...` must be refused by
// `listen_on`.
#[test]
fn larger_addr_denied() {
let tcp = TcpConfig::new();
let addr = "/ip4/127.0.0.1/tcp/12345/tcp/12345"
.parse::<Multiaddr>()
.unwrap();
assert!(tcp.listen_on(addr).is_err());
}
// NAT traversal keeps the server's port and takes the IP from the observed
// address (IPv4 server, IPv4 observed).
#[test]
fn nat_traversal() {
let tcp = TcpConfig::new();
let server = "/ip4/127.0.0.1/tcp/10000".parse::<Multiaddr>().unwrap();
let observed = "/ip4/80.81.82.83/tcp/25000".parse::<Multiaddr>().unwrap();
let out = tcp.nat_traversal(&server, &observed);
assert_eq!(
out.unwrap(),
"/ip4/80.81.82.83/tcp/10000".parse::<Multiaddr>().unwrap()
);
}
// Same as above with an IPv6 server address and an IPv4 observed address.
#[test]
fn nat_traversal_ipv6_to_ipv4() {
let tcp = TcpConfig::new();
let server = "/ip6/::1/tcp/10000".parse::<Multiaddr>().unwrap();
let observed = "/ip4/80.81.82.83/tcp/25000".parse::<Multiaddr>().unwrap();
let out = tcp.nat_traversal(&server, &observed);
assert_eq!(
out.unwrap(),
"/ip4/80.81.82.83/tcp/10000".parse::<Multiaddr>().unwrap()
);
}
// Same as above with an IPv4 server address and an IPv6 observed address.
#[test]
fn nat_traversal_ipv4_to_ipv6() {
let tcp = TcpConfig::new();
let server = "/ip4/127.0.0.1/tcp/10000".parse::<Multiaddr>().unwrap();
let observed = "/ip6/2001:db8::1/tcp/25000".parse::<Multiaddr>().unwrap();
let out = tcp.nat_traversal(&server, &observed);
assert_eq!(
out.unwrap(),
"/ip6/2001:db8::1/tcp/10000".parse::<Multiaddr>().unwrap()
);
}
}