use core::{
net::{SocketAddr, SocketAddrV4, SocketAddrV6},
pin::Pin,
sync::atomic::{AtomicBool, Ordering},
task::{Context, Poll},
};
use pin_project_lite::pin_project;
use std::io::{Error, ErrorKind, Result};
use tokio::net::{self, TcpListener, TcpSocket, TcpStream, ToSocketAddrs};
// Sealing module. `private` is not `pub`, so code outside this module cannot name
// `Sealed`, and therefore cannot implement `Tcp` (which has `Sealed` as a supertrait).
mod private {
    // Reachable through the public `Tcp` supertrait list but not nameable from outside;
    // the `unnameable_types` expectation acknowledges that this is intentional.
    #[expect(unnameable_types, reason = "want Tcp to be 'sealed'")]
    pub trait Sealed {}
}
use private::Sealed;
/// Common interface over the TCP listeners in this module: tokio's [`TcpListener`] and
/// [`DualStackTcpListener`].
///
/// The trait is sealed: it requires the private `Sealed` supertrait, so it cannot be
/// implemented outside this module.
pub trait Tcp: Sized + Sealed {
    /// Resolves `addr` and creates a listener bound to it.
    fn bind<A>(addr: A) -> impl Future<Output = Result<Self>>
    where
        A: ToSocketAddrs;
    /// Waits for the next incoming connection, yielding the stream and the remote address.
    fn accept(&self) -> impl Future<Output = Result<(TcpStream, SocketAddr)>> + Send + Sync;
    /// Polls for the next incoming connection, yielding the stream and the remote address.
    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>>;
}
impl Sealed for TcpListener {}
/// Every [`Tcp`] method forwards to tokio's inherent method of the same name.
impl Tcp for TcpListener {
    /// Forwards to tokio's inherent `TcpListener::bind`.
    #[inline]
    fn bind<A: ToSocketAddrs>(addr: A) -> impl Future<Output = Result<Self>> {
        // Type-relative paths resolve inherent associated functions before trait ones, so
        // this is tokio's `bind`, not a recursive call into `Tcp::bind`.
        Self::bind(addr)
    }
    /// Forwards to tokio's inherent `TcpListener::accept`.
    #[inline]
    fn accept(&self) -> impl Future<Output = Result<(TcpStream, SocketAddr)>> + Send + Sync {
        Self::accept(self)
    }
    /// Forwards to tokio's inherent `TcpListener::poll_accept`.
    #[inline]
    fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>> {
        Self::poll_accept(self, cx)
    }
}
/// A pair of tokio [`TcpListener`]s, one with an IPv6 local address and one with an IPv4
/// local address, that accept connections together.
///
/// Built with [`DualStackTcpListener::from_sockets`] or [`Tcp::bind`].
#[derive(Debug)]
pub struct DualStackTcpListener {
    /// Listener whose local address is IPv6.
    ip6: TcpListener,
    /// Listener whose local address is IPv4.
    ip4: TcpListener,
    /// Toggled by each `accept`/`poll_accept` call to alternate which listener is served
    /// first.
    ip6_first: AtomicBool,
}
impl DualStackTcpListener {
    /// Builds a dual-stack listener from two sockets of different IP versions.
    ///
    /// Each argument pairs a socket with the backlog passed to [`TcpSocket::listen`]. The
    /// sockets may be given in either order: the one whose [`TcpSocket::local_addr`] is IPv6
    /// becomes the IPv6 listener. `socket_1` is put into the listening state before
    /// `socket_2`. Presumably both sockets have already been bound — TODO confirm callers.
    ///
    /// # Errors
    ///
    /// Returns an error if reading either local address fails, if both sockets have the
    /// same IP version ([`ErrorKind::InvalidData`]), or if either `listen` call fails.
    #[inline]
    pub fn from_sockets(
        (socket_1, backlog_1): (TcpSocket, u32),
        (socket_2, backlog_2): (TcpSocket, u32),
    ) -> Result<Self> {
        socket_1
            .local_addr()
            .and_then(|addr_1| {
                socket_2
                    .local_addr()
                    .map(|addr_2| (addr_1.is_ipv6(), addr_2.is_ipv6()))
            })
            .and_then(|(first_is_ip6, second_is_ip6)| {
                if first_is_ip6 == second_is_ip6 {
                    Err(Error::new(
                        ErrorKind::InvalidData,
                        "TcpSockets are the same IP version",
                    ))
                } else {
                    socket_1.listen(backlog_1).and_then(|listener_1| {
                        socket_2.listen(backlog_2).map(|listener_2| {
                            let (ip6, ip4) = if first_is_ip6 {
                                (listener_1, listener_2)
                            } else {
                                (listener_2, listener_1)
                            };
                            Self {
                                ip6,
                                ip4,
                                ip6_first: AtomicBool::new(true),
                            }
                        })
                    })
                }
            })
    }
    /// Returns the local addresses of the IPv6 and IPv4 listeners, in that order.
    ///
    /// # Errors
    ///
    /// Returns the first error from querying either listener's local address.
    ///
    /// # Panics
    ///
    /// Panics if the IPv6 listener's address is not IPv6 or the IPv4 listener's address is
    /// not IPv4, which would mean this value was constructed incorrectly.
    #[expect(clippy::unreachable, reason = "we want to crash when there is a bug")]
    #[inline]
    pub fn local_addr(&self) -> Result<(SocketAddrV6, SocketAddrV4)> {
        self.ip6
            .local_addr()
            .and_then(|addr6| self.ip4.local_addr().map(|addr4| (addr6, addr4)))
            .map(|(addr6, addr4)| {
                let SocketAddr::V6(v6) = addr6 else {
                    unreachable!("there is a bug in DualStackTcpListener::bind")
                };
                let SocketAddr::V4(v4) = addr4 else {
                    unreachable!("there is a bug in DualStackTcpListener::bind")
                };
                (v6, v4)
            })
    }
    /// Sets the TTL on the IPv6 listener, then on the IPv4 listener.
    ///
    /// If the IPv6 call fails, the IPv4 listener is left untouched.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by either listener.
    #[inline]
    pub fn set_ttl(&self, ttl_ip6: u32, ttl_ip4: u32) -> Result<()> {
        // NOTE(review): tokio's `set_ttl` sets the `IP_TTL` option, including on the IPv6
        // listener; whether that influences IPv6 hop limits is platform-dependent — confirm.
        match self.ip6.set_ttl(ttl_ip6) {
            Ok(()) => self.ip4.set_ttl(ttl_ip4),
            Err(err) => Err(err),
        }
    }
    /// Returns the TTL values of the IPv6 and IPv4 listeners, in that order.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by either listener; the IPv4 listener is not
    /// queried when the IPv6 query fails.
    #[inline]
    pub fn ttl(&self) -> Result<(u32, u32)> {
        match self.ip6.ttl() {
            Ok(ip6) => self.ip4.ttl().map(|ip4| (ip6, ip4)),
            Err(err) => Err(err),
        }
    }
}
// Future over two accept futures that resolves with whichever produces a result first.
// `fut_1` is polled first on every poll, so it wins when both are ready.
pin_project! {
    struct AcceptFut<
        F: Future<Output = Result<(TcpStream, SocketAddr)>>,
        F2: Future<Output = Result<(TcpStream, SocketAddr)>>,
    > {
        // Higher-priority accept future.
        #[pin]
        fut_1: F,
        // Lower-priority accept future; polled only while `fut_1` is pending.
        #[pin]
        fut_2: F2,
    }
}
impl<F, F2> Future for AcceptFut<F, F2>
where
    F: Future<Output = Result<(TcpStream, SocketAddr)>>,
    F2: Future<Output = Result<(TcpStream, SocketAddr)>>,
{
    type Output = Result<(TcpStream, SocketAddr)>;
    /// Polls `fut_1`; only if it is still pending is `fut_2` polled. Both therefore have
    /// the task's waker registered whenever this returns `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projection = self.project();
        if let Poll::Ready(outcome) = projection.fut_1.poll(cx) {
            Poll::Ready(outcome)
        } else {
            projection.fut_2.poll(cx)
        }
    }
}
impl Sealed for DualStackTcpListener {}
impl Tcp for DualStackTcpListener {
#[inline]
async fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self> {
match net::lookup_host(addr).await {
Ok(socks) => {
let mut last_err = None;
let mut ip6_opt = None;
let mut ip4_opt = None;
for sock in socks {
match ip6_opt {
None => match ip4_opt {
None => {
let is_ip6 = sock.is_ipv6();
match TcpListener::bind(sock).await {
Ok(ip) => {
if is_ip6 {
ip6_opt = Some(ip);
} else {
ip4_opt = Some(ip);
}
}
Err(err) => last_err = Some(err),
}
}
Some(ip4) => {
if sock.is_ipv6() {
match TcpListener::bind(sock).await {
Ok(ip6) => {
return Ok(Self {
ip6,
ip4,
ip6_first: AtomicBool::new(true),
});
}
Err(err) => last_err = Some(err),
}
}
ip4_opt = Some(ip4);
}
},
Some(ip6) => {
if sock.is_ipv4() {
match TcpListener::bind(sock).await {
Ok(ip4) => {
return Ok(Self {
ip6,
ip4,
ip6_first: AtomicBool::new(true),
});
}
Err(err) => last_err = Some(err),
}
}
ip6_opt = Some(ip6);
}
}
}
Err(last_err.unwrap_or_else(|| {
Error::new(
ErrorKind::InvalidInput,
"could not resolve to an IPv6 and IPv4 address",
)
}))
}
Err(err) => Err(err),
}
}
#[inline]
fn accept(&self) -> impl Future<Output = Result<(TcpStream, SocketAddr)>> + Send + Sync {
if self.ip6_first.swap(false, Ordering::Relaxed) {
AcceptFut {
fut_1: self.ip6.accept(),
fut_2: self.ip4.accept(),
}
} else {
self.ip6_first.store(true, Ordering::Relaxed);
AcceptFut {
fut_1: self.ip4.accept(),
fut_2: self.ip6.accept(),
}
}
}
#[inline]
fn poll_accept(&self, cx: &mut Context<'_>) -> Poll<Result<(TcpStream, SocketAddr)>> {
if self.ip6_first.swap(false, Ordering::Relaxed) {
self.ip6.poll_accept(cx)
} else {
self.ip6_first.store(true, Ordering::Relaxed);
self.ip4.poll_accept(cx)
}
}
}