use std::net::SocketAddr;
use anyhow::Context;
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
use tokio::net::{TcpListener, TcpStream};
use tracing::info;
/// Value handed to `set_tcp_fastopen` for listeners. On Linux it becomes the
/// `TCP_FASTOPEN` option value (the kernel's limit on pending Fast Open
/// requests); on macOS the option is set to 1 instead and this is unused.
const TFO_QUEUE_LEN: i32 = 256;

/// Creates a Tokio [`TcpListener`] bound to `addr` with `SO_REUSEADDR` set.
///
/// `addr` must be a literal socket address such as `"0.0.0.0:8080"` or
/// `"[::]:8080"`; host names are not resolved. When `fast_open` is true,
/// TCP Fast Open is requested on the socket before it is bound. The listen
/// backlog is fixed at 1024.
///
/// # Errors
///
/// Fails if `addr` does not parse, or if socket creation, option setting,
/// `bind`, `listen`, or the conversion into a Tokio listener fails.
pub async fn create_listener(addr: &str, fast_open: bool) -> anyhow::Result<TcpListener> {
    let parsed = addr
        .parse::<SocketAddr>()
        .with_context(|| format!("invalid listen address: {addr}"))?;

    let family = match parsed {
        SocketAddr::V4(_) => Domain::IPV4,
        SocketAddr::V6(_) => Domain::IPV6,
    };
    let raw = Socket::new(family, Type::STREAM, Some(Protocol::TCP))
        .context("failed to create socket")?;
    raw.set_reuse_address(true)?;

    // Socket options must be applied before the socket starts listening.
    if fast_open {
        set_tcp_fastopen(&raw, TFO_QUEUE_LEN)?;
        info!("TCP Fast Open enabled on listener (queue={TFO_QUEUE_LEN})");
    }

    let bind_target = SockAddr::from(parsed);
    raw.bind(&bind_target)
        .with_context(|| format!("bind {addr}"))?;
    raw.listen(1024)
        .with_context(|| format!("listen {addr}"))?;

    // Tokio's `from_std` expects the std listener to already be non-blocking.
    raw.set_nonblocking(true)?;

    let std_listener: std::net::TcpListener = raw.into();
    TcpListener::from_std(std_listener).context("convert to tokio TcpListener")
}
pub async fn connect(addr: &str, fast_open: bool) -> anyhow::Result<TcpStream> {
if !fast_open {
return TcpStream::connect(addr)
.await
.with_context(|| format!("connect to {addr}"));
}
let sock_addr: SocketAddr = tokio::net::lookup_host(addr)
.await
.with_context(|| format!("resolve {addr}"))?
.next()
.with_context(|| format!("no addresses for {addr}"))?;
let domain = if sock_addr.is_ipv6() {
Domain::IPV6
} else {
Domain::IPV4
};
let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP))
.context("failed to create socket")?;
set_tcp_fastopen_connect(&socket)?;
socket.set_nonblocking(true)?;
match socket.connect(&SockAddr::from(sock_addr)) {
Ok(()) => {}
Err(e) if e.raw_os_error() == Some(libc::EINPROGRESS) => {}
Err(e) => return Err(e).with_context(|| format!("connect to {addr}")),
}
let std_stream: std::net::TcpStream = socket.into();
let stream = TcpStream::from_std(std_stream).context("convert to tokio TcpStream")?;
stream
.writable()
.await
.with_context(|| format!("connect to {addr}"))?;
if let Some(e) = stream.take_error()? {
return Err(e).with_context(|| format!("connect to {addr}"));
}
Ok(stream)
}
fn set_tcp_fastopen(socket: &Socket, queue_len: i32) -> anyhow::Result<()> {
#[cfg(target_os = "linux")]
{
use std::os::unix::io::AsRawFd;
let fd = socket.as_raw_fd();
let val = queue_len;
let ret = unsafe {
libc::setsockopt(
fd,
libc::IPPROTO_TCP,
libc::TCP_FASTOPEN,
&val as *const _ as *const libc::c_void,
std::mem::size_of_val(&val) as libc::socklen_t,
)
};
if ret != 0 {
return Err(std::io::Error::last_os_error()).context("setsockopt TCP_FASTOPEN");
}
}
#[cfg(target_os = "macos")]
{
use std::os::unix::io::AsRawFd;
let fd = socket.as_raw_fd();
let val: i32 = 1; let ret = unsafe {
libc::setsockopt(
fd,
libc::IPPROTO_TCP,
0x105, &val as *const _ as *const libc::c_void,
std::mem::size_of_val(&val) as libc::socklen_t,
)
};
if ret != 0 {
return Err(std::io::Error::last_os_error()).context("setsockopt TCP_FASTOPEN");
}
}
let _ = queue_len;
Ok(())
}
fn set_tcp_fastopen_connect(socket: &Socket) -> anyhow::Result<()> {
#[cfg(target_os = "linux")]
{
use std::os::unix::io::AsRawFd;
let fd = socket.as_raw_fd();
let val: i32 = 1;
let ret = unsafe {
libc::setsockopt(
fd,
libc::IPPROTO_TCP,
30, &val as *const _ as *const libc::c_void,
std::mem::size_of_val(&val) as libc::socklen_t,
)
};
if ret != 0 {
return Err(std::io::Error::last_os_error()).context("setsockopt TCP_FASTOPEN_CONNECT");
}
}
#[cfg(target_os = "macos")]
{
use std::os::unix::io::AsRawFd;
let fd = socket.as_raw_fd();
let val: i32 = 1;
let ret = unsafe {
libc::setsockopt(
fd,
libc::IPPROTO_TCP,
0x105, &val as *const _ as *const libc::c_void,
std::mem::size_of_val(&val) as libc::socklen_t,
)
};
if ret != 0 {
return Err(std::io::Error::last_os_error()).context("setsockopt TCP_FASTOPEN");
}
}
Ok(())
}