use std::{
future::Future,
io::{self, ErrorKind},
net::{SocketAddr, ToSocketAddrs},
os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, OwnedFd},
pin::Pin,
task::{Context, Poll},
};
use io_uring::{opcode, types};
use libc::{AF_INET, AF_INET6, SOCK_STREAM};
use tokio_stream::Stream;
use crate::reactor::{MultishotReactorIo, Reactor, ReactorIo};
use super::{
read::{AsyncRead, AsyncReader},
sock_addr::CSockAddr,
write::{AsyncWrite, AsyncWriter},
};
/// A TCP socket in listening mode whose incoming connections are delivered
/// through the io_uring reactor.
///
/// Implements [`Stream`], yielding one [`TcpStream`] per accepted connection.
pub struct TcpListener {
    /// The listening socket; closed when the listener is dropped.
    // NOTE(review): fields drop in declaration order, so this fd is closed
    // before `io` is dropped; confirm that is the intended order while a
    // multishot accept may still be in flight.
    inner: OwnedFd,
    /// Reactor handle for the multishot accept request issued by `poll_next`.
    io: MultishotReactorIo,
}
/// Creates an unbound, unconnected TCP socket whose address family matches `addr`.
///
/// The descriptor is created with `SOCK_CLOEXEC` so it is not inherited by
/// programs started through `exec` from a forked child.
///
/// # Errors
///
/// Returns the OS error reported by `socket(2)`.
fn mk_sock(addr: &SocketAddr) -> std::io::Result<OwnedFd> {
    let family = if addr.is_ipv4() { AF_INET } else { AF_INET6 };
    // SAFETY: plain FFI call with integer arguments only; no pointers involved.
    let fd = unsafe { libc::socket(family, SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
    if fd == -1 {
        return Err(std::io::Error::last_os_error());
    }
    // SAFETY: `fd` was just returned by a successful `socket(2)` call and is
    // not owned by anything else, so transferring ownership to `OwnedFd` is sound.
    Ok(unsafe { OwnedFd::from_raw_fd(fd) })
}
impl TcpListener {
    /// Creates a listening TCP socket bound to the first usable address in `addrs`.
    ///
    /// Every address yielded by `addrs` is tried in order. For each one, a socket
    /// of the matching family is created, bound, and put into listening mode
    /// with a backlog of 1024. The first address for which all three steps
    /// succeed wins, and the returned listener obtains a multishot reactor
    /// handle for accepting connections.
    ///
    /// # Errors
    ///
    /// Returns the address-resolution error, or the error from the last failed
    /// attempt. If `addrs` yields no addresses at all, the error kind is
    /// [`ErrorKind::NotFound`].
    pub fn bind(addrs: impl ToSocketAddrs) -> std::io::Result<Self> {
        let mut last_err: std::io::Error = ErrorKind::NotFound.into();
        for addr in addrs.to_socket_addrs()? {
            // A socket-creation failure for one address (e.g. IPv6 unsupported
            // on this host) must not prevent trying the remaining addresses.
            let sock = match mk_sock(&addr) {
                Ok(sock) => sock,
                Err(e) => {
                    last_err = e;
                    continue;
                }
            };
            let caddr: CSockAddr = addr.into();
            // SAFETY: `sock` is an open socket we own, and `caddr` outlives the
            // call and describes `caddr.len` bytes of sockaddr data.
            if unsafe { libc::bind(sock.as_raw_fd(), caddr.as_ptr(), caddr.len as _) } == -1 {
                last_err = std::io::Error::last_os_error();
                continue;
            }
            // SAFETY: `sock` is an open, bound socket that we own.
            if unsafe { libc::listen(sock.as_raw_fd(), 1024) } == -1 {
                last_err = std::io::Error::last_os_error();
                continue;
            }
            return Ok(Self {
                inner: sock,
                io: Reactor::new_multishot_io(),
            });
        }
        Err(last_err)
    }
}
impl Stream for TcpListener {
    type Item = std::io::Result<TcpStream>;

    /// Yields the connections produced by a multishot accept on the listener.
    ///
    /// The closure handed to `submit_or_get_result` builds the `AcceptMulti`
    /// entry; presumably the reactor only invokes it when no request is
    /// outstanding (verify against `MultishotReactorIo`). Each accepted
    /// descriptor becomes the owned socket of a new [`TcpStream`].
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let listener = self.get_mut();
        let listen_raw = listener.inner.as_raw_fd();
        let completion = listener.io.submit_or_get_result(|| {
            let accept = opcode::AcceptMulti::new(types::Fd(listen_raw)).build();
            (accept, cx.waker().clone())
        });
        completion.map(|next| {
            next.map(|accepted| {
                accepted.map(|conn_fd| TcpStream {
                    // SAFETY: the fd comes from a completed accept and is not
                    // owned by anything else yet.
                    inner: unsafe { OwnedFd::from_raw_fd(conn_fd) },
                })
            })
        })
    }
}
/// A connected TCP socket whose reads and writes are issued through the
/// io_uring reactor (see its `AsyncRead` / `AsyncWrite` impls).
///
/// Obtained from `TcpStream::connect` or yielded by a listening `TcpListener`.
pub struct TcpStream {
    /// The connected socket; closed when the stream is dropped.
    inner: OwnedFd,
}
impl TcpStream {
pub async fn connect<A: ToSocketAddrs>(addrs: A) -> std::io::Result<Self> {
let addrs = addrs.to_socket_addrs()?;
let mut last_err: std::io::Error = ErrorKind::InvalidData.into();
for addr in addrs {
let sock = mk_sock(&addr)?;
let connect = SockConnect {
fd: sock.as_fd(),
io: Reactor::new_io(),
addr,
};
match connect.await {
Ok(()) => return Ok(Self { inner: sock }),
Err(e) => last_err = e,
}
}
Err(last_err)
}
}
struct SockConnect<'fd> {
fd: BorrowedFd<'fd>,
io: ReactorIo,
addr: SocketAddr,
}
impl Future for SockConnect<'_> {
type Output = io::Result<()>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let addr: CSockAddr = self.addr.into();
let entry =
opcode::Connect::new(types::Fd(self.fd.as_raw_fd()), addr.as_ptr(), addr.len as _);
self.io
.submit_or_get_result(|| (entry.build(), cx.waker().clone()))
.map(|x| x.map(|_| ()))
}
}
impl AsyncRead for TcpStream {
    /// Reads from the socket into `buf` through a fresh reactor handle.
    ///
    /// The future resolves to the reactor's result, presumably the number of
    /// bytes read. Sockets have no file offset, so the reader is flagged as
    /// non-seekable.
    fn read(&mut self, buf: &mut [u8]) -> impl Future<Output = io::Result<usize>> {
        let fd = self.inner.as_fd();
        let io = Reactor::new_io();
        AsyncReader {
            fd,
            io,
            buf,
            seekable: false,
        }
    }
}
impl AsyncWrite for TcpStream {
    /// Writes `buf` to the socket through a fresh reactor handle.
    ///
    /// The future resolves to the reactor's result, presumably the number of
    /// bytes written. Sockets have no file offset, so the writer is flagged as
    /// non-seekable.
    fn write(&mut self, buf: &[u8]) -> impl Future<Output = io::Result<usize>> {
        let fd = self.inner.as_fd();
        let io = Reactor::new_io();
        AsyncWriter {
            fd,
            io,
            buf,
            seekable: false,
        }
    }
}