use std::{
future::Future,
io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures::{FutureExt, ready};
use multiaddr::Multiaddr;
use tokio::net::{TcpListener, TcpStream};
use tokio_stream::Stream;
use super::{Transport, dns::DnsResolver};
use crate::{
transports::dns::{DnsResolverRef, SystemDnsResolver},
types::TransportProtocol,
utils::{multiaddr::socketaddr_to_multiaddr, network::supports_ipv6},
};
/// TCP transport configuration.
///
/// Holds optional socket options that `configure` applies to each stream, the DNS resolver used
/// by `listen`/`dial` to resolve a multiaddr, and the list of protocols reported by
/// `supported_protocols`.
#[derive(Clone)]
pub struct TcpTransport {
/// IP time-to-live passed to `TcpStream::set_ttl` when `Some`; left untouched when `None`.
ttl: Option<u32>,
/// Value passed to `TcpStream::set_nodelay` when `Some`; left untouched when `None`.
nodelay: Option<bool>,
/// Resolver used to resolve the multiaddr given to `listen` and `dial`.
/// `new` installs `SystemDnsResolver`.
dns_resolver: DnsResolverRef,
/// Protocols computed once in `new`: always `Ipv4`, plus `Ipv6` when `supports_ipv6()` is true.
supported_protocols: Vec<TransportProtocol>,
}
impl TcpTransport {
setter_mut!(set_ttl, ttl, Option<u32>);
setter_mut!(set_nodelay, nodelay, Option<bool>);
pub fn new() -> Self {
let mut supported_protocols = vec![TransportProtocol::Ipv4];
if supports_ipv6() {
supported_protocols.push(TransportProtocol::Ipv6);
}
Self {
ttl: None,
nodelay: None,
dns_resolver: Arc::new(SystemDnsResolver),
supported_protocols,
}
}
pub fn set_dns_resolver<T: DnsResolver>(&mut self, dns_resolver: T) -> &mut Self {
self.dns_resolver = Arc::new(dns_resolver);
self
}
fn configure(&self, socket: &TcpStream) -> io::Result<()> {
if let Some(ttl) = self.ttl {
socket.set_ttl(ttl)?;
}
if let Some(nodelay) = self.nodelay {
socket.set_nodelay(nodelay)?;
}
Ok(())
}
}
impl Default for TcpTransport {
fn default() -> Self {
Self::new()
}
}
#[crate::async_trait]
impl Transport for TcpTransport {
    type Error = io::Error;
    type Listener = TcpInbound;
    type Output = TcpStream;

    /// Resolves `addr` with the configured DNS resolver and binds a TCP listener to the result.
    ///
    /// Returns the inbound connection stream together with the address reported by the
    /// listener's `local_addr()`, converted to a `Multiaddr`.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` if resolution fails (wrapped with a descriptive message), if the
    /// bind fails, or if the local address cannot be read.
    async fn listen(&self, addr: &Multiaddr) -> Result<(Self::Listener, Multiaddr), Self::Error> {
        let resolution = self.dns_resolver.resolve(addr.clone());
        let bind_addr = resolution
            .await
            .map_err(|err| io::Error::other(format!("Failed to resolve address: {err}")))?;

        let tcp_listener = TcpListener::bind(&bind_addr).await?;
        let bound_addr = socketaddr_to_multiaddr(&tcp_listener.local_addr()?);
        let inbound = TcpInbound::new(self.clone(), tcp_listener);
        Ok((inbound, bound_addr))
    }

    /// Resolves `addr` with the configured DNS resolver and opens a TCP connection to it.
    ///
    /// The connect future is wrapped in [`TcpOutbound`], which applies this transport's socket
    /// options to the stream before yielding it.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` if resolution fails (wrapped with a descriptive message), if the
    /// connection attempt fails, or if applying the socket options fails.
    async fn dial(&self, addr: &Multiaddr) -> Result<Self::Output, Self::Error> {
        let resolution = self.dns_resolver.resolve(addr.clone());
        let remote_addr = resolution
            .await
            .map_err(|err| io::Error::other(format!("Address resolution failed: {err}")))?;

        let connecting = TcpStream::connect(remote_addr).boxed();
        TcpOutbound::new(connecting, self.clone()).await
    }

    /// Returns a copy of the protocol list computed when the transport was created.
    fn supported_protocols(&self) -> Vec<TransportProtocol> {
        self.supported_protocols.to_vec()
    }
}
/// Future wrapper for an outbound TCP connection.
///
/// When the wrapped future yields a `TcpStream`, the stream is passed through
/// `TcpTransport::configure` before being returned (see the `Future` impl).
pub struct TcpOutbound<F> {
/// The wrapped connect future; the `Future` impl requires it to output `io::Result<TcpStream>`.
future: F,
/// Transport whose socket options are applied to the connected stream.
config: TcpTransport,
}
impl<F> TcpOutbound<F> {
pub fn new(future: F, config: TcpTransport) -> Self {
Self { future, config }
}
}
impl<F> Future for TcpOutbound<F>
where F: Future<Output = io::Result<TcpStream>> + Unpin
{
    type Output = io::Result<TcpStream>;

    /// Polls the wrapped connect future.
    ///
    /// Stays pending while the inner future is pending. Once it completes, a connection error
    /// is passed through unchanged; a connected stream is configured with the stored transport
    /// options and returned, or the configuration error is returned instead.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let connected = ready!(Pin::new(&mut self.future).poll(cx));
        let configured = connected.and_then(|stream| {
            self.config.configure(&stream)?;
            Ok(stream)
        });
        Poll::Ready(configured)
    }
}
/// Stream of inbound TCP connections accepted from a bound listener.
///
/// Each accepted socket is passed through `TcpTransport::configure` before being yielded
/// (see the `Stream` impl).
pub struct TcpInbound {
/// Listener that is polled for new connections.
listener: TcpListener,
/// Transport whose socket options are applied to every accepted socket.
config: TcpTransport,
}
impl TcpInbound {
pub fn new(config: TcpTransport, listener: TcpListener) -> Self {
Self { listener, config }
}
}
impl Stream for TcpInbound {
    type Item = io::Result<(TcpStream, Multiaddr)>;

    /// Polls the listener for the next inbound connection.
    ///
    /// Stays pending while no connection is ready. An accept error is yielded as `Some(Err(_))`.
    /// An accepted socket is configured with the stored transport options and yielded together
    /// with the peer address converted to a `Multiaddr`; a configuration error is yielded in
    /// its place. This implementation never yields `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let accepted = ready!(self.listener.poll_accept(cx)).and_then(|(socket, remote)| {
            self.config.configure(&socket)?;
            Ok((socket, socketaddr_to_multiaddr(&remote)))
        });
        Poll::Ready(Some(accepted))
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// The chained setters store their arguments wrapped in `Some`.
    #[test]
    fn configure() {
        let mut transport = TcpTransport::new();
        transport.set_nodelay(true).set_ttl(789);

        assert_eq!((transport.nodelay, transport.ttl), (Some(true), Some(789)));
    }
}