use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use anyhow::bail;
use anyhow::Context;
use anyhow::Result;
use async_trait::async_trait;
use tokio::net::TcpSocket;
use tokio::time::timeout;
use crate::resolve::Target;
/// Factory that opens TCP connections whose local side is bound to one of two
/// local addresses reserved when the factory is built (one IPv4, one IPv6).
///
/// Cloning the factory shares the same reserved sockets through [`Arc`].
#[derive(Clone, Debug)]
pub struct TransportFactory {
    /// Upper bound applied to each individual `connect` attempt.
    connect_timeout: Duration,
    /// Socket bound to `0.0.0.0:0` at construction; its local address is reused
    /// as the local side of every IPv4 connection.
    reserved_v4: Arc<TcpSocket>,
    /// Socket bound to `[::]:0` at construction; its local address is reused as
    /// the local side of every IPv6 connection.
    reserved_v6: Arc<TcpSocket>,
}

impl TransportFactory {
    /// Builds a factory, reserving one local IPv4 and one local IPv6 address up front.
    ///
    /// # Errors
    ///
    /// Fails if either reserved socket cannot be created, configured, or bound.
    /// The IPv4 reservation is attempted first.
    pub fn try_new(connect_timeout: Duration) -> Result<Self> {
        let v4_socket = reserve_socket_v4().context("failed reserving local v4 socket address")?;
        let v6_socket = reserve_socket_v6().context("failed reserving local v6 socket address")?;
        Ok(Self {
            connect_timeout,
            reserved_v4: Arc::new(v4_socket),
            reserved_v6: Arc::new(v6_socket),
        })
    }
}
#[async_trait]
impl super::TransportFactory for TransportFactory {
    /// Opens a TCP connection to `target`, which must be a `Target::SocketAddr`.
    ///
    /// A fresh socket is bound to the local address of the reserved socket that
    /// matches the destination's address family. The connect attempt is then
    /// bounded by the factory's `connect_timeout`.
    ///
    /// # Errors
    ///
    /// Fails in any of these cases:
    /// - `target` is any other `Target` variant;
    /// - the reserved local address cannot be read;
    /// - the new socket cannot be built or bound;
    /// - the connect does not finish within `connect_timeout`;
    /// - the connect itself fails.
    async fn setup(&self, target: &Target) -> Result<super::Transport> {
        let addr = if let Target::SocketAddr(addr) = target {
            *addr
        } else {
            bail!("TcpTransportFactory: unsupported target: {target}")
        };

        // Reuse the reserved socket's local address for the matching family, so
        // every outgoing connection shares that local port.
        let (reserved, lookup_failure) = match addr {
            SocketAddr::V4(_) => (&self.reserved_v4, "unable to get local v4 socket addr"),
            SocketAddr::V6(_) => (&self.reserved_v6, "unable to get local v6 socket addr"),
        };
        let local_addr = reserved.local_addr().context(lookup_failure)?;

        let socket =
            new_reuseport_socket(local_addr).context("unable to build local reusable socket")?;

        // The outer `Result` reports the timeout; the inner one reports the connect error.
        let connect_outcome = timeout(self.connect_timeout, socket.connect(addr))
            .await
            .context("TCP connect timed out")?;
        let stream = connect_outcome.context("TCP connect failed")?;

        Ok(super::Transport::Tcp(stream))
    }
}
/// Creates a reusable IPv4 TCP socket bound to `0.0.0.0:0`, letting the OS pick
/// the local port.
fn reserve_socket_v4() -> Result<TcpSocket> {
    let any_v4 = SocketAddr::from((std::net::Ipv4Addr::UNSPECIFIED, 0));
    new_reuseport_socket(any_v4)
}
/// Creates a reusable IPv6 TCP socket bound to `[::]:0` (flowinfo and scope id
/// both 0), letting the OS pick the local port.
fn reserve_socket_v6() -> Result<TcpSocket> {
    let any_v6 = SocketAddr::from((std::net::Ipv6Addr::UNSPECIFIED, 0));
    new_reuseport_socket(any_v6)
}
/// Creates a TCP socket of the same address family as `local_addr`, enables
/// local-port sharing on it, and binds it to `local_addr`.
///
/// Port sharing uses `SO_REUSEPORT` wherever tokio exposes
/// [`TcpSocket::set_reuseport`] (Unix, excluding Solaris and illumos). Every
/// other target, including Windows, falls back to `SO_REUSEADDR`.
///
/// # Errors
///
/// Returns an error if the socket cannot be created, the reuse option cannot be
/// set, or the bind fails.
fn new_reuseport_socket(local_addr: SocketAddr) -> Result<TcpSocket> {
    let socket = match local_addr {
        SocketAddr::V4(_) => TcpSocket::new_v4().context("failed creating new v4 TCP socket")?,
        SocketAddr::V6(_) => TcpSocket::new_v6().context("failed creating new v6 TCP socket")?,
    };
    // This cfg mirrors the one tokio puts on `TcpSocket::set_reuseport`.
    // Gating only on "not Windows" fails to compile on Solaris/illumos and on
    // non-Unix, non-Windows targets.
    #[cfg(all(unix, not(any(target_os = "solaris", target_os = "illumos"))))]
    socket
        .set_reuseport(true)
        .context("unable to set SO_REUSEPORT")?;
    #[cfg(not(all(unix, not(any(target_os = "solaris", target_os = "illumos")))))]
    socket
        .set_reuseaddr(true)
        .context("unable to set SO_REUSEADDR")?;
    socket
        .bind(local_addr)
        .context("unable to bind local address")?;
    Ok(socket)
}