use alloc::boxed::Box;
use alloc::format;
use alloc::sync::Arc;
use alloc::vec;
use alloc::vec::Vec;
use core::cell::UnsafeCell;
use core::mem;
use core::net::SocketAddr;
use core::pin::Pin;
use core::task::{Context, Poll};
use futures::{AsyncRead, AsyncWrite, Future, ready};
use crate::linux::io_uring::ffi::{SOCK_CLOEXEC, SOCK_NONBLOCK, SocketDomain, SocketType};
use crate::linux::io_uring::{
Close, Fd, IoUring, IoVec, MsgHdr, Read, RecvMsg, SendMsg, Write, socket_addr_to_dual_stack,
};
use crate::linux::net::{NetworkError, Result, SocketBufferAllocation, get_buffer_pool};
use crate::linux::sys::{self, Errno};
use crate::{linux, net};
/// A listening TCP socket whose `accept` calls are submitted through io_uring.
///
/// Streams returned by `accept` share this listener's ring.
pub struct Listener {
    /// Ring used for `accept`; cloned into every accepted `Stream`.
    ring: Arc<IoUring>,
    /// Descriptor of the bound, listening socket.
    fd: Fd,
    /// The address passed to `bind`, reported verbatim by `local_addr`.
    /// NOTE(review): if `bind` was given port 0 this still reports port 0,
    /// not the port the kernel assigned.
    local_addr: SocketAddr,
}
/// Bookkeeping for one `sendmsg` issued by `Stream::poll_write`, stored in
/// `Stream::current_send`.
///
/// NOTE: struct fields drop in declaration order, so `buf` is released before
/// the future held in `state`.
struct SendFuture {
    /// Ring the operation is submitted on (a clone of the stream's ring).
    ring: Arc<IoUring>,
    /// Socket the bytes are sent on.
    fd: Fd,
    /// Owned copy of the bytes handed to `poll_write`; the iovec given to
    /// `sendmsg` points into this buffer, so it must not be resized while the
    /// operation is in flight.
    buf: Vec<u8>,
    /// The boxed `sendmsg` future; `None` until `poll_write` starts the
    /// operation.
    state: UnsafeCell<
        Option<Pin<Box<dyn Future<Output = crate::linux::io_uring::Result<usize>> + Send>>>,
    >,
}
/// Bookkeeping for one `recvmsg` issued by `Stream::poll_read`, stored in
/// `Stream::current_recv`.
struct RecvFuture {
    /// Ring the operation is submitted on (a clone of the stream's ring).
    ring: Arc<IoUring>,
    /// Socket the bytes are received from.
    fd: Fd,
    /// Length of the caller's buffer when the operation was created; the
    /// receive buffer is allocated with this many bytes.
    buf_len: usize,
    /// The boxed `recvmsg` future paired with the buffer the kernel writes
    /// into; `poll_read` copies out of that buffer once the future completes.
    /// Tuple elements drop in order, so the future is dropped before the
    /// buffer. `None` until `poll_read` starts the operation.
    state: UnsafeCell<
        Option<(
            Pin<Box<dyn Future<Output = crate::linux::io_uring::Result<usize>> + Send>>,
            Vec<u8>,
        )>,
    >,
}
/// Bookkeeping for the `close` issued by `Stream::poll_close`, stored in
/// `Stream::current_close`.
struct CloseFuture {
    /// Ring the operation is submitted on (a clone of the stream's ring).
    ring: Arc<IoUring>,
    /// Descriptor being closed.
    fd: Fd,
    /// The boxed close future; `None` until `poll_close` starts the operation.
    state: UnsafeCell<
        Option<Pin<Box<dyn Future<Output = crate::linux::io_uring::Result<()>> + Send>>>,
    >,
}
/// A TCP stream whose sends, receives and close are submitted through
/// io_uring.
///
/// Each `current_*` slot holds at most one operation of its kind.
pub struct Stream {
    /// Ring all operations on this stream are submitted to.
    ring: Arc<IoUring>,
    /// Connected socket descriptor.
    fd: Fd,
    /// Local address recorded at construction.
    /// NOTE(review): `connect` and `accept` in this file do not query the
    /// kernel for it — see those constructors.
    local_addr: SocketAddr,
    /// Peer address recorded at construction.
    peer_addr: SocketAddr,
    /// Optional buffer-pool allocation; the constructors in this file always
    /// set it to `None`.
    buffer_allocation: Option<SocketBufferAllocation>,
    /// Slot for the send operation driven by `poll_write`.
    current_send: UnsafeCell<Option<SendFuture>>,
    /// Slot for the receive operation driven by `poll_read`.
    current_recv: UnsafeCell<Option<RecvFuture>>,
    /// Slot for the close operation driven by `poll_close`.
    current_close: UnsafeCell<Option<CloseFuture>>,
}
// SAFETY(review): every access to the `UnsafeCell` slots in this file goes
// through `Pin::get_unchecked_mut` in the `AsyncRead`/`AsyncWrite` impls,
// i.e. through exclusive `&mut` access. The `&self` methods (`local_addr`,
// `peer_addr`) never touch them, so sharing `&Stream` exposes no interior
// mutation.
// NOTE(review): `Send` additionally asserts that the ring, `Fd` and
// `SocketBufferAllocation` may move between threads; their definitions are
// not visible here — confirm.
unsafe impl Send for Stream {}
unsafe impl Sync for Stream {}
impl AsyncWrite for Stream {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<futures::io::Result<usize>> {
unsafe {
let this = self.get_unchecked_mut();
let current_send = &mut *this.current_send.get();
let send_future = SendFuture {
ring: this.ring.clone(),
fd: this.fd,
buf: buf.to_vec(),
state: UnsafeCell::new(None),
};
*current_send = Some(send_future);
if let Some(send_op) = current_send {
let state = &mut *send_op.state.get();
if state.is_none() {
let buf_ptr = send_op.buf.as_ptr();
let buf_len = send_op.buf.len();
let ring = send_op.ring.clone();
let fd = send_op.fd;
let mut iovec = IoVec {
base: buf_ptr as *mut u8,
len: buf_len,
};
let mut msghdr = MsgHdr {
name: core::ptr::null_mut(),
namelen: 0,
iov: &mut iovec as *mut IoVec,
iovlen: 1,
control: core::ptr::null_mut(),
controllen: 0,
flags: 0,
};
let fut = Box::pin(async move {
ring.sendmsg(fd, &msghdr, 0).await.await
});
*state = Some(fut);
}
match state.as_mut().unwrap().as_mut().poll(cx) {
Poll::Ready(Ok(n)) => {
*current_send = None; Poll::Ready(Ok(n))
}
Poll::Ready(Err(e)) => {
*current_send = None; Poll::Ready(Err(futures::io::Error::new(
futures::io::ErrorKind::Other,
format!("io_uring sendmsg error: {}", e),
)))
}
Poll::Pending => Poll::Pending,
}
} else {
unreachable!("Just created send operation");
}
}
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<futures::io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<futures::io::Result<()>> {
unsafe {
let this = self.get_unchecked_mut();
let current_close = &mut *this.current_close.get();
if current_close.is_none() {
*current_close = Some(CloseFuture {
ring: this.ring.clone(),
fd: this.fd,
state: UnsafeCell::new(None),
});
}
if let Some(close_op) = current_close {
let state = &mut *close_op.state.get();
if state.is_none() {
let ring = close_op.ring.clone();
let fd = close_op.fd;
let fut = Box::pin(async move { ring.close(fd).await.await });
*state = Some(fut);
}
match state.as_mut().unwrap().as_mut().poll(cx) {
Poll::Ready(Ok(())) => {
*current_close = None;
Poll::Ready(Ok(()))
}
Poll::Ready(Err(e)) => {
*current_close = None;
Poll::Ready(Err(futures::io::Error::new(
futures::io::ErrorKind::Other,
format!("io_uring close error: {}", e),
)))
}
Poll::Pending => Poll::Pending,
}
} else {
unreachable!("Just created close operation");
}
}
}
}
impl AsyncRead for Stream {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<futures::io::Result<usize>> {
unsafe {
let this = self.get_unchecked_mut();
let current_recv = &mut *this.current_recv.get();
let recv_future = RecvFuture {
ring: this.ring.clone(),
fd: this.fd,
buf_len: buf.len(),
state: UnsafeCell::new(None),
};
*current_recv = Some(recv_future);
if let Some(recv_op) = current_recv {
let state = &mut *recv_op.state.get();
if state.is_none() {
let buf_len = recv_op.buf_len;
let ring = recv_op.ring.clone();
let fd = recv_op.fd;
let mut recv_buf = vec![0u8; buf_len];
let buf_ptr = recv_buf.as_mut_ptr();
let mut iovec = IoVec {
base: buf_ptr,
len: buf_len,
};
let mut msghdr = MsgHdr {
name: core::ptr::null_mut(),
namelen: 0,
iov: &mut iovec as *mut IoVec,
iovlen: 1,
control: core::ptr::null_mut(),
controllen: 0,
flags: 0,
};
let fut = Box::pin(async move {
ring.recvmsg(fd, &mut msghdr).await.await
});
*state = Some((fut, recv_buf));
}
let (fut, recv_buf) = state.as_mut().unwrap();
match fut.as_mut().poll(cx) {
Poll::Ready(Ok(n)) => {
buf[..n].copy_from_slice(&recv_buf[..n]);
*current_recv = None; Poll::Ready(Ok(n))
}
Poll::Ready(Err(e)) => {
*current_recv = None; Poll::Ready(Err(futures::io::Error::new(
futures::io::ErrorKind::Other,
format!("io_uring recvmsg error: {}", e),
)))
}
Poll::Pending => Poll::Pending,
}
} else {
unreachable!("Just created recv operation");
}
}
}
}
impl net::tcp::Stream<linux::runtime::Runtime, linux::runtime::Share> for Stream {
    /// Opens a non-blocking, close-on-exec TCP connection to `addr`.
    ///
    /// A new io_uring instance (capacity 256) is created for the connection
    /// and held by the returned stream.
    ///
    /// # Errors
    ///
    /// * `NetworkError::Internal` if the ring or the socket cannot be created.
    /// * `NetworkError::ConnectionRefused` for any failure of the connect
    ///   operation itself. The socket is closed before returning so its
    ///   descriptor does not leak.
    fn connect(addr: core::net::SocketAddr) -> impl Future<Output = net::Result<Self>>
    where
        Self: Sized,
    {
        async move {
            let ring = Arc::new(IoUring::with_capacity(256).map_err(|e| {
                NetworkError::Internal(format!("Failed to create io_uring: {}", e))
            })?);
            let domain = match addr {
                SocketAddr::V4(_) => SocketDomain::Inet as i32,
                SocketAddr::V6(_) => SocketDomain::Inet6 as i32,
            };
            let fd = ring
                .socket(
                    domain,
                    SocketType::Stream as i32 | SOCK_NONBLOCK | SOCK_CLOEXEC,
                    0,
                )
                .await
                .await
                .map_err(|e| NetworkError::Internal(format!("Failed to create socket: {}", e)))?;
            let (sock_addr, _addr_len) = socket_addr_to_dual_stack(addr);
            let connected = ring
                .connect(fd, sock_addr)
                .await
                .await
                .map_err(|_| NetworkError::ConnectionRefused);
            if connected.is_err() {
                // Best effort: the connect failure is the error worth
                // reporting, but the descriptor must not leak.
                let _ = ring.close(fd).await.await;
            }
            connected?;
            // NOTE(review): this records the peer address as the local one;
            // the real local address would come from getsockname(2) on `fd`.
            let local_addr = addr;
            Ok(Stream {
                ring,
                fd,
                local_addr,
                peer_addr: addr,
                buffer_allocation: None,
                current_send: UnsafeCell::new(None),
                current_recv: UnsafeCell::new(None),
                current_close: UnsafeCell::new(None),
            })
        }
    }

    /// Returns the local address recorded when the stream was constructed.
    fn local_addr(&self) -> net::Result<core::net::SocketAddr> {
        Ok(self.local_addr)
    }

    /// Returns the peer address recorded when the stream was constructed.
    fn peer_addr(&self) -> net::Result<core::net::SocketAddr> {
        Ok(self.peer_addr)
    }
}
impl net::tcp::Listener<linux::runtime::Runtime, linux::runtime::Share> for Listener {
    /// Creates a non-blocking, close-on-exec TCP socket bound to `addr` and
    /// listening with a backlog of 128.
    ///
    /// A new io_uring instance (capacity 256) is created and shared with every
    /// accepted stream.
    ///
    /// # Errors
    ///
    /// * `NetworkError::Internal` if the ring or socket cannot be created, or
    ///   if `listen` fails.
    /// * `NetworkError::AddressInUse` for any `bind` failure.
    ///
    /// When `bind` or `listen` fails, the socket is closed before returning.
    fn bind(addr: core::net::SocketAddr) -> impl Future<Output = net::Result<Self>>
    where
        Self: Sized,
    {
        async move {
            let ring = Arc::new(IoUring::with_capacity(256).map_err(|e| {
                NetworkError::Internal(format!("Failed to create io_uring: {}", e))
            })?);
            let domain = match addr {
                SocketAddr::V4(_) => SocketDomain::Inet as i32,
                SocketAddr::V6(_) => SocketDomain::Inet6 as i32,
            };
            let fd = ring
                .socket(
                    domain,
                    SocketType::Stream as i32 | SOCK_NONBLOCK | SOCK_CLOEXEC,
                    0,
                )
                .await
                .await
                .map_err(|e| NetworkError::Internal(format!("Failed to create socket: {}", e)))?;

            // Scoped so the raw socket address is not held across the awaits
            // below.
            let bound = {
                let (sock_addr, addr_len) = socket_addr_to_dual_stack(addr);
                // SAFETY: `sock_addr` is a live local and `addr_len` is the
                // length `socket_addr_to_dual_stack` reported for it.
                let res = unsafe {
                    sys::bind(
                        *fd,
                        &sock_addr as *const _ as *const sys::SockAddr,
                        addr_len as u32,
                    )
                };
                res.map_err(|_| NetworkError::AddressInUse)
            };
            if bound.is_err() {
                // Best effort: report the bind failure, but do not leak the fd.
                let _ = ring.close(fd).await.await;
            }
            bound?;

            // SAFETY: `fd` is the socket created above and is still open.
            // NOTE(review): confirm this covers `sys::listen`'s contract.
            let listened = unsafe { sys::listen(*fd, 128) };
            let listened =
                listened.map_err(|e| NetworkError::Internal(format!("Listen failed: {}", e)));
            if listened.is_err() {
                // Best effort: report the listen failure, but do not leak the fd.
                let _ = ring.close(fd).await.await;
            }
            listened?;

            Ok(Listener {
                ring,
                fd,
                local_addr: addr,
            })
        }
    }

    /// Waits for an incoming connection on this listener's ring.
    ///
    /// The returned stream shares the listener's ring.
    ///
    /// # Errors
    ///
    /// Any io_uring accept failure is reported as `NetworkError::Internal`.
    fn accept(&self) -> impl Future<Output = net::Result<(Stream, core::net::SocketAddr)>> {
        async move {
            let (client_fd, sock_addr) = self
                .ring
                .accept(self.fd)
                .await
                .await
                .map_err(|e| NetworkError::Internal(format!("Accept failed: {}", e)))?;
            // NOTE(review): `sock_addr` (presumably the remote address filled
            // in by accept) is ignored. The listener's own address is reported
            // as the peer address and as the stream's local address.
            // Converting `sock_addr` needs a helper not visible here.
            let peer_addr = self.local_addr;
            let stream = Stream {
                ring: self.ring.clone(),
                fd: client_fd,
                local_addr: self.local_addr,
                peer_addr,
                buffer_allocation: None,
                current_send: UnsafeCell::new(None),
                current_recv: UnsafeCell::new(None),
                current_close: UnsafeCell::new(None),
            };
            Ok((stream, peer_addr))
        }
    }

    /// Returns the address the listener was bound with.
    fn local_addr(&self) -> net::Result<core::net::SocketAddr> {
        Ok(self.local_addr)
    }
}