#[cfg(unix)]
use std::os::unix::prelude::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::prelude::{AsRawHandle, FromRawSocket, RawHandle};
use std::{
cell::UnsafeCell,
future::Future,
io,
net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
};
use super::stream::TcpStream;
use crate::{
driver::{op::Op, shared_fd::SharedFd},
io::stream::Stream,
net::ListenerConfig,
};
/// A TCP listener whose `accept` is submitted as an async operation (`Op::accept`)
/// on the crate's driver.
pub struct TcpListener {
    /// Descriptor handle used for driver operations (`Op::accept`) and for `AsRawFd`.
    fd: SharedFd,
    /// A std listener aliasing the same descriptor as `fd`, used for synchronous queries
    /// such as `local_addr`. It is `Some` from construction until `Drop` takes it out to
    /// release the alias without closing the descriptor.
    sys_listener: Option<std::net::TcpListener>,
    /// Lazily-filled cache of listener metadata (currently only the local address),
    /// mutated through `&self` in `local_addr`.
    meta: UnsafeCell<ListenerMeta>,
}
impl TcpListener {
pub(crate) fn from_shared_fd(fd: SharedFd) -> Self {
#[cfg(unix)]
let sys_listener = unsafe { std::net::TcpListener::from_raw_fd(fd.raw_fd()) };
#[cfg(windows)]
let sys_listener = unsafe { std::net::TcpListener::from_raw_socket(todo!()) };
Self {
fd,
sys_listener: Some(sys_listener),
meta: UnsafeCell::new(ListenerMeta::default()),
}
}
pub fn bind_with_config<A: ToSocketAddrs>(
addr: A,
config: &ListenerConfig,
) -> io::Result<Self> {
let addr = addr
.to_socket_addrs()?
.next()
.ok_or_else(|| io::Error::new(io::ErrorKind::Other, "empty address"))?;
let domain = if addr.is_ipv6() {
socket2::Domain::IPV6
} else {
socket2::Domain::IPV4
};
let sys_listener =
socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;
#[cfg(all(unix, feature = "legacy"))]
Self::set_non_blocking(&sys_listener)?;
let addr = socket2::SockAddr::from(addr);
#[cfg(unix)]
if config.reuse_port {
sys_listener.set_reuse_port(true)?;
}
if config.reuse_addr {
sys_listener.set_reuse_address(true)?;
}
if let Some(send_buf_size) = config.send_buf_size {
sys_listener.set_send_buffer_size(send_buf_size)?;
}
if let Some(recv_buf_size) = config.recv_buf_size {
sys_listener.set_recv_buffer_size(recv_buf_size)?;
}
sys_listener.bind(&addr)?;
sys_listener.listen(config.backlog)?;
#[cfg(unix)]
let fd = SharedFd::new(sys_listener.into_raw_fd())?;
#[cfg(windows)]
let fd = unimplemented!();
Ok(Self::from_shared_fd(fd))
}
pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
let cfg = ListenerConfig::default();
Self::bind_with_config(addr, &cfg)
}
#[cfg(unix)]
pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
let op = Op::accept(&self.fd)?;
let completion = op.await;
let fd = completion.meta.result?;
let stream = TcpStream::from_shared_fd(SharedFd::new(fd as _)?);
let storage = completion.data.addr.0.as_ptr() as *const _ as *const libc::sockaddr_storage;
let addr = unsafe {
match (*storage).ss_family as libc::c_int {
libc::AF_INET => {
let addr: &libc::sockaddr_in = &*(storage as *const libc::sockaddr_in);
let ip = Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes());
let port = u16::from_be(addr.sin_port);
SocketAddr::V4(SocketAddrV4::new(ip, port))
}
libc::AF_INET6 => {
let addr: &libc::sockaddr_in6 = &*(storage as *const libc::sockaddr_in6);
let ip = Ipv6Addr::from(addr.sin6_addr.s6_addr);
let port = u16::from_be(addr.sin6_port);
SocketAddr::V6(SocketAddrV6::new(
ip,
port,
addr.sin6_flowinfo,
addr.sin6_scope_id,
))
}
_ => {
return Err(io::ErrorKind::InvalidInput.into());
}
}
};
Ok((stream, addr))
}
#[cfg(windows)]
pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
unimplemented!()
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
let meta = self.meta.get();
if let Some(addr) = unsafe { &*meta }.local_addr {
return Ok(addr);
}
self.sys_listener
.as_ref()
.unwrap()
.local_addr()
.map(|addr| {
unsafe { &mut *meta }.local_addr = Some(addr);
addr
})
}
#[cfg(all(unix, feature = "legacy"))]
fn set_non_blocking(_socket: &socket2::Socket) -> io::Result<()> {
crate::driver::CURRENT.with(|x| match x {
#[cfg(all(target_os = "linux", feature = "iouring"))]
crate::driver::Inner::Uring(_) => Ok(()),
crate::driver::Inner::Legacy(_) => _socket.set_nonblocking(true),
})
}
}
impl Stream for TcpListener {
    type Item = io::Result<(TcpStream, SocketAddr)>;
    type NextFuture<'a> = impl Future<Output = Option<Self::Item>> + 'a;

    /// Waits for the next inbound connection.
    ///
    /// This stream never terminates: every call resolves to `Some`, carrying either the
    /// accepted `(stream, peer_addr)` pair or the error `accept` reported.
    fn next(&mut self) -> Self::NextFuture<'_> {
        async move {
            let accepted = self.accept().await;
            Some(accepted)
        }
    }
}
impl std::fmt::Debug for TcpListener {
    /// Renders the listener as `TcpListener { fd: .. }`; only the shared fd is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TcpListener");
        out.field("fd", &self.fd);
        out.finish()
    }
}
#[cfg(unix)]
impl AsRawFd for TcpListener {
    /// Exposes the raw descriptor held by the listener's shared fd; ownership is not
    /// transferred to the caller.
    fn as_raw_fd(&self) -> RawFd {
        let shared = &self.fd;
        shared.raw_fd()
    }
}
#[cfg(windows)]
impl AsRawHandle for TcpListener {
    /// Exposes the raw handle held by the listener's shared fd; ownership is not
    /// transferred to the caller.
    fn as_raw_handle(&self) -> RawHandle {
        let shared = &self.fd;
        shared.raw_handle()
    }
}
impl Drop for TcpListener {
    /// Releases the std listener's alias of the descriptor without closing it.
    ///
    /// `sys_listener` is built over the same descriptor as `fd` in `from_shared_fd`, so
    /// closing is left to `fd` (presumably `SharedFd` owns and closes it — confirm).
    /// `std::mem::forget` gives up the std handle's ownership on every platform, so this
    /// never panics during drop.
    fn drop(&mut self) {
        if let Some(listener) = self.sys_listener.take() {
            // Forget rather than drop: dropping would close a descriptor `fd` still holds.
            std::mem::forget(listener);
        }
    }
}
/// Cached, lazily-computed information about a `TcpListener`.
#[derive(Debug, Default, Clone)]
struct ListenerMeta {
    /// The bound local address; `None` until the first successful
    /// `TcpListener::local_addr` call stores it.
    local_addr: Option<SocketAddr>,
}