use std::io::{self, Read, Write};
use std::net::SocketAddr;
use std::os::fd::{AsFd, AsRawFd, BorrowedFd, FromRawFd, RawFd};
use std::pin::Pin;
use std::task::{Context, Poll, Waker};
use std::time::Duration;
use mio::{Interest, Token};
use super::{AsyncRead, AsyncWrite, waker_to_ptr};
use crate::io::IoHandle;
/// A TCP stream backed by a `mio::net::TcpStream` and driven through an `IoHandle`.
///
/// The stream is registered with the I/O driver lazily: `token` stays `None`
/// until the first poll-style operation calls `ensure_registered`.
pub struct TcpStream {
/// The underlying mio socket.
inner: mio::net::TcpStream,
/// Handle used to register/deregister the socket and to query or clear readiness.
io: IoHandle,
/// Token returned by `IoHandle::register`; `None` while the socket is unregistered.
token: Option<Token>,
/// Pointer obtained from `waker_to_ptr` for the task whose waker was installed last;
/// compared on every poll to decide whether the waker has to be replaced.
registered_task: *mut u8,
}
impl TcpStream {
    /// Wraps a mio stream without registering it with the I/O driver.
    ///
    /// Registration happens lazily on the first poll-style operation.
    pub(crate) fn new(inner: mio::net::TcpStream, io: IoHandle) -> Self {
        Self {
            inner,
            io,
            token: None,
            registered_task: std::ptr::null_mut(),
        }
    }

    /// Issues a connect to `addr` through `mio::net::TcpStream::connect`.
    ///
    /// # Errors
    /// Returns any error reported by mio while creating or connecting the socket.
    pub fn connect(addr: SocketAddr, io: IoHandle) -> io::Result<Self> {
        let inner = mio::net::TcpStream::connect(addr)?;
        Ok(Self::new(inner, io))
    }

    /// Adopts a standard-library stream.
    ///
    /// The socket is switched to non-blocking mode first: mio's `from_std`
    /// leaves the blocking mode untouched, and a blocking socket would stall
    /// the executor thread instead of returning `WouldBlock`.
    ///
    /// # Errors
    /// Returns the error from `set_nonblocking` if the mode cannot be changed.
    pub fn from_std(stream: std::net::TcpStream, io: IoHandle) -> io::Result<Self> {
        stream.set_nonblocking(true)?;
        let inner = mio::net::TcpStream::from_std(stream);
        Ok(Self::new(inner, io))
    }

    /// Converts the stream back into a standard-library stream.
    ///
    /// The socket is deregistered from the I/O driver (best effort) and its
    /// descriptor is handed over unchanged; the blocking mode is not modified.
    pub fn into_std(mut self) -> io::Result<std::net::TcpStream> {
        if let Some(token) = self.token.take() {
            // Deregistration failure is deliberately ignored: the descriptor is
            // leaving this runtime either way.
            // SAFETY: NOTE(review) — `deregister`'s contract is not visible in this
            // file; `token` is the one `register` returned for `self.inner`.
            let _ = unsafe { self.io.deregister(&mut self.inner, token) };
        }
        let fd = self.inner.as_raw_fd();
        // Skip `Drop` so the descriptor is neither closed nor deregistered twice.
        std::mem::forget(self);
        // SAFETY: `fd` is the open socket that `self.inner` owned; that owner was
        // just forgotten, so the new `std::net::TcpStream` is the sole owner.
        Ok(unsafe { std::net::TcpStream::from_raw_fd(fd) })
    }

    /// Returns the local socket address.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner.local_addr()
    }

    /// Returns the remote peer's socket address.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        self.inner.peer_addr()
    }

    /// Borrows the socket as a `socket2::SockRef` for option access.
    fn socket_ref(&self) -> socket2::SockRef<'_> {
        socket2::SockRef::from(&self.inner)
    }

    /// Reads the `TCP_NODELAY` option.
    pub fn nodelay(&self) -> io::Result<bool> {
        self.inner.nodelay()
    }

    /// Sets the `TCP_NODELAY` option.
    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
        self.inner.set_nodelay(nodelay)
    }

    /// Reads the socket's time-to-live value.
    pub fn ttl(&self) -> io::Result<u32> {
        self.socket_ref().ttl()
    }

    /// Sets the socket's time-to-live value.
    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
        self.socket_ref().set_ttl(ttl)
    }

    /// Reads the linger setting.
    pub fn linger(&self) -> io::Result<Option<Duration>> {
        self.socket_ref().linger()
    }

    /// Sets the linger setting.
    pub fn set_linger(&self, duration: Option<Duration>) -> io::Result<()> {
        self.socket_ref().set_linger(duration)
    }

    /// Reads whether keepalive is enabled.
    pub fn keepalive(&self) -> io::Result<bool> {
        self.socket_ref().keepalive()
    }

    /// Enables or disables keepalive.
    pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
        self.socket_ref().set_keepalive(keepalive)
    }

    /// Reads the send buffer size.
    pub fn send_buffer_size(&self) -> io::Result<usize> {
        self.socket_ref().send_buffer_size()
    }

    /// Sets the send buffer size.
    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
        self.socket_ref().set_send_buffer_size(size)
    }

    /// Reads the receive buffer size.
    pub fn recv_buffer_size(&self) -> io::Result<usize> {
        self.socket_ref().recv_buffer_size()
    }

    /// Sets the receive buffer size.
    pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
        self.socket_ref().set_recv_buffer_size(size)
    }

    /// Takes the pending socket error, if any.
    pub fn take_error(&self) -> io::Result<Option<io::Error>> {
        self.socket_ref().take_error()
    }

    /// Attempts a single read without waiting.
    ///
    /// When the read fails with `WouldBlock` and the stream is registered, the
    /// readable flag is cleared so that a following `readable().await` waits
    /// for new readiness instead of returning immediately on a stale flag.
    pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
        let result = (&self.inner).read(buf);
        if matches!(&result, Err(e) if e.kind() == io::ErrorKind::WouldBlock) {
            if let Some(token) = self.token {
                // Only `&self` is available here, so operate on a copy of the
                // handle (`IoHandle` is passed by value out of borrows elsewhere
                // in this file, i.e. it is `Copy`).
                // NOTE(review): assumes copies of `IoHandle` address the same
                // driver state — confirm against `IoHandle`.
                #[allow(unused_mut)] // `mut` only matters if the method takes `&mut self`
                let mut io = self.io;
                io.clear_readable(token);
            }
        }
        result
    }

    /// Attempts a single write without waiting.
    ///
    /// When the write fails with `WouldBlock` and the stream is registered, the
    /// writable flag is cleared so that a following `writable().await` waits
    /// for new readiness instead of returning immediately on a stale flag.
    pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
        let result = (&self.inner).write(buf);
        if matches!(&result, Err(e) if e.kind() == io::ErrorKind::WouldBlock) {
            if let Some(token) = self.token {
                // See `try_read` for why a copy of the handle is used.
                #[allow(unused_mut)] // `mut` only matters if the method takes `&mut self`
                let mut io = self.io;
                io.clear_writable(token);
            }
        }
        result
    }

    /// Reads into `buf` without removing the data from the socket's queue.
    pub fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
        // SAFETY: `u8` and `MaybeUninit<u8>` have identical layout, and the slice
        // is already initialized; the callee is only expected to write bytes into it.
        let buf = unsafe {
            &mut *(buf as *mut [u8] as *mut [std::mem::MaybeUninit<u8>])
        };
        self.socket_ref().peek(buf)
    }

    /// Reads into `buf`, awaiting `poll_read` until it completes.
    pub async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_read(cx, buf)).await
    }

    /// Writes from `buf`, awaiting `poll_write` until it completes.
    ///
    /// Returns the number of bytes written, which may be less than `buf.len()`.
    pub async fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        std::future::poll_fn(|cx| Pin::new(&mut *self).poll_write(cx, buf)).await
    }

    /// Writes the whole of `buf`, looping over `write`.
    ///
    /// # Errors
    /// Returns `ErrorKind::WriteZero` if a write reports zero bytes while data
    /// remains, or the first error returned by `write`.
    pub async fn write_all(&mut self, mut buf: &[u8]) -> io::Result<()> {
        while !buf.is_empty() {
            let n = self.write(buf).await?;
            if n == 0 {
                return Err(io::Error::new(
                    io::ErrorKind::WriteZero,
                    "failed to write whole buffer",
                ));
            }
            buf = &buf[n..];
        }
        Ok(())
    }

    /// Polls for read readiness as reported by the I/O handle.
    ///
    /// Registers the stream (or refreshes the waker) first; returns `Pending`
    /// while the handle does not report the token as readable.
    pub fn poll_read_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if let Err(e) = self.ensure_registered(cx) {
            return Poll::Ready(Err(e));
        }
        if let Some(token) = self.token {
            if self.io.readiness(token).readable {
                return Poll::Ready(Ok(()));
            }
        }
        Poll::Pending
    }

    /// Polls for write readiness as reported by the I/O handle.
    ///
    /// Registers the stream (or refreshes the waker) first; returns `Pending`
    /// while the handle does not report the token as writable.
    pub fn poll_write_ready(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        if let Err(e) = self.ensure_registered(cx) {
            return Poll::Ready(Err(e));
        }
        if let Some(token) = self.token {
            if self.io.readiness(token).writable {
                return Poll::Ready(Ok(()));
            }
        }
        Poll::Pending
    }

    /// Waits until the I/O handle reports the stream as readable.
    pub async fn readable(&mut self) -> io::Result<()> {
        std::future::poll_fn(|cx| self.poll_read_ready(cx)).await
    }

    /// Waits until the I/O handle reports the stream as writable.
    pub async fn writable(&mut self) -> io::Result<()> {
        std::future::poll_fn(|cx| self.poll_write_ready(cx)).await
    }

    /// Splits the stream into borrowed read and write halves.
    ///
    /// Both halves carry the same raw pointer to `self`; their `PhantomData`
    /// ties them to this mutable borrow.
    pub fn split(&mut self) -> (ReadHalf<'_>, WriteHalf<'_>) {
        let ptr = self as *mut TcpStream;
        (
            ReadHalf { stream: ptr, _marker: std::marker::PhantomData },
            WriteHalf { stream: ptr, _marker: std::marker::PhantomData },
        )
    }

    /// Splits the stream into owned halves that share it through an `Rc`.
    pub fn into_split(self) -> (OwnedReadHalf, OwnedWriteHalf) {
        use std::rc::Rc;
        let shared = Rc::new(std::cell::UnsafeCell::new(self));
        (
            OwnedReadHalf {
                stream: Rc::clone(&shared),
            },
            OwnedWriteHalf { stream: shared },
        )
    }

    /// Makes sure the stream is registered and that the driver holds the
    /// waker of the task currently polling.
    ///
    /// The waker is only replaced when the polling task differs from the one
    /// recorded in `registered_task`.
    #[inline(always)]
    fn ensure_registered(&mut self, cx: &Context<'_>) -> io::Result<()> {
        let task_ptr = waker_to_ptr(cx);
        if let Some(token) = self.token {
            if task_ptr != self.registered_task {
                self.io.set_waker(token, cx.waker().clone());
                self.registered_task = task_ptr;
            }
            return Ok(());
        }
        self.do_register(task_ptr, cx.waker().clone())
    }

    /// First-time registration for both read and write interest.
    #[cold]
    fn do_register(&mut self, task_ptr: *mut u8, waker: Waker) -> io::Result<()> {
        let interest = Interest::READABLE | Interest::WRITABLE;
        let token = self.io.register(&mut self.inner, interest, waker)?;
        self.token = Some(token);
        self.registered_task = task_ptr;
        Ok(())
    }
}
impl AsyncRead for TcpStream {
    /// Attempts a read on the socket after making sure it is registered.
    ///
    /// A `WouldBlock` result clears the readable flag for the stream's token
    /// and yields `Pending`; every other outcome is returned as-is.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let stream = self.get_mut();
        stream.ensure_registered(cx)?;

        let err = match stream.inner.read(buf) {
            Ok(count) => return Poll::Ready(Ok(count)),
            Err(err) => err,
        };
        if err.kind() != io::ErrorKind::WouldBlock {
            return Poll::Ready(Err(err));
        }

        // Nothing to read right now: drop the readable flag before parking.
        if let Some(token) = stream.token {
            stream.io.clear_readable(token);
        }
        Poll::Pending
    }
}
impl AsyncWrite for TcpStream {
    /// Attempts a write on the socket after making sure it is registered.
    ///
    /// A `WouldBlock` result clears the writable flag for the stream's token
    /// and yields `Pending`; every other outcome is returned as-is.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let stream = self.get_mut();
        stream.ensure_registered(cx)?;

        let err = match stream.inner.write(buf) {
            Ok(written) => return Poll::Ready(Ok(written)),
            Err(err) => err,
        };
        if err.kind() != io::ErrorKind::WouldBlock {
            return Poll::Ready(Err(err));
        }
        if let Some(token) = stream.token {
            stream.io.clear_writable(token);
        }
        Poll::Pending
    }

    /// Flushes the socket, handling `WouldBlock` the same way as `poll_write`.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        let stream = self.get_mut();
        stream.ensure_registered(cx)?;

        let err = match stream.inner.flush() {
            Ok(()) => return Poll::Ready(Ok(())),
            Err(err) => err,
        };
        if err.kind() != io::ErrorKind::WouldBlock {
            return Poll::Ready(Err(err));
        }
        if let Some(token) = stream.token {
            stream.io.clear_writable(token);
        }
        Poll::Pending
    }

    /// Shuts down the write direction of the socket.
    ///
    /// A `NotConnected` error is treated as success; the call never returns `Pending`.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        let outcome = self.get_mut().inner.shutdown(std::net::Shutdown::Write);
        Poll::Ready(match outcome {
            Err(err) if err.kind() != io::ErrorKind::NotConnected => Err(err),
            _ => Ok(()),
        })
    }
}
impl std::fmt::Debug for TcpStream {
    /// Prints the raw descriptor and whether the stream holds a driver token.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fd = self.inner.as_raw_fd();
        let registered = self.token.is_some();
        let mut out = f.debug_struct("TcpStream");
        out.field("fd", &fd);
        out.field("registered", &registered);
        out.finish()
    }
}
impl AsFd for TcpStream {
    /// Borrows the descriptor of the underlying mio socket.
    fn as_fd(&self) -> BorrowedFd<'_> {
        AsFd::as_fd(&self.inner)
    }
}
impl AsRawFd for TcpStream {
    /// Returns the raw descriptor of the underlying mio socket.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}
impl Drop for TcpStream {
    /// Deregisters the socket from the I/O driver if it was ever registered.
    /// A deregistration error is ignored.
    fn drop(&mut self) {
        let Some(token) = self.token else {
            return;
        };
        // SAFETY: NOTE(review) — `deregister`'s contract is not visible in this
        // file; `token` is the one `register` returned for `self.inner`.
        let _ = unsafe { self.io.deregister(&mut self.inner, token) };
    }
}
/// Borrowed read half of a `TcpStream`, produced by `TcpStream::split`.
pub struct ReadHalf<'a> {
/// Raw pointer to the stream being split; `WriteHalf` holds the same pointer.
stream: *mut TcpStream,
/// Ties the half to the `&'a mut TcpStream` borrow it was created from.
_marker: std::marker::PhantomData<&'a mut TcpStream>,
}
impl ReadHalf<'_> {
    /// Re-materialises the mutable reference to the underlying stream.
    fn stream(&mut self) -> &mut TcpStream {
        let raw: *mut TcpStream = self.stream;
        // SAFETY: `raw` was derived from a `&mut TcpStream` whose borrow lasts for
        // the lifetime carried by `_marker`.
        // NOTE(review): the write half holds the same pointer; this relies on the
        // two halves never holding a `&mut TcpStream` at the same time — confirm.
        unsafe { &mut *raw }
    }
}
impl AsyncRead for ReadHalf<'_> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
Pin::new(this.stream()).poll_read(cx, buf)
}
}
/// Borrowed write half of a `TcpStream`, produced by `TcpStream::split`.
pub struct WriteHalf<'a> {
/// Raw pointer to the stream being split; `ReadHalf` holds the same pointer.
stream: *mut TcpStream,
/// Ties the half to the `&'a mut TcpStream` borrow it was created from.
_marker: std::marker::PhantomData<&'a mut TcpStream>,
}
impl WriteHalf<'_> {
    /// Re-materialises the mutable reference to the underlying stream.
    fn stream(&mut self) -> &mut TcpStream {
        let raw: *mut TcpStream = self.stream;
        // SAFETY: `raw` was derived from a `&mut TcpStream` whose borrow lasts for
        // the lifetime carried by `_marker`.
        // NOTE(review): the read half holds the same pointer; this relies on the
        // two halves never holding a `&mut TcpStream` at the same time — confirm.
        unsafe { &mut *raw }
    }
}
impl AsyncWrite for WriteHalf<'_> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let this = self.get_mut();
Pin::new(this.stream()).poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
Pin::new(this.stream()).poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let this = self.get_mut();
Pin::new(this.stream()).poll_shutdown(cx)
}
}
/// Owned read half of a `TcpStream`, produced by `TcpStream::into_split`.
pub struct OwnedReadHalf {
/// The stream, shared with the matching `OwnedWriteHalf` through an `Rc`.
stream: std::rc::Rc<std::cell::UnsafeCell<TcpStream>>,
}
impl OwnedReadHalf {
    /// Joins this half with `write` back into the original `TcpStream`.
    ///
    /// # Errors
    /// Returns `ReuniteError` when the halves do not share the same stream, or
    /// when the shared stream still has other owners after `write` is dropped.
    pub fn reunite(self, write: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
        use std::rc::Rc;

        if !Rc::ptr_eq(&self.stream, &write.stream) {
            return Err(ReuniteError);
        }
        // Release the write half's reference so this half can become the sole owner.
        drop(write);
        match Rc::try_unwrap(self.stream) {
            Ok(cell) => Ok(cell.into_inner()),
            Err(_) => Err(ReuniteError),
        }
    }

    /// Returns the remote peer's socket address.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        // SAFETY: NOTE(review) — a shared reference is created while the write half
        // may create `&mut TcpStream` from the same cell; sound only if the two are
        // never alive at once (the `Rc` keeps both halves on one thread) — confirm.
        let stream: &TcpStream = unsafe { &*self.stream.get() };
        stream.peer_addr()
    }

    /// Returns the local socket address.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        // SAFETY: see `peer_addr`.
        let stream: &TcpStream = unsafe { &*self.stream.get() };
        stream.local_addr()
    }
}
impl AsyncRead for OwnedReadHalf {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let stream = unsafe { &mut *self.stream.get() };
Pin::new(stream).poll_read(cx, buf)
}
}
/// Owned write half of a `TcpStream`, produced by `TcpStream::into_split`.
pub struct OwnedWriteHalf {
/// The stream, shared with the matching `OwnedReadHalf` through an `Rc`.
stream: std::rc::Rc<std::cell::UnsafeCell<TcpStream>>,
}
impl OwnedWriteHalf {
pub fn reunite(self, read: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
read.reunite(self)
}
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
unsafe { &*self.stream.get() }.peer_addr()
}
pub fn local_addr(&self) -> io::Result<SocketAddr> {
unsafe { &*self.stream.get() }.local_addr()
}
}
impl AsyncWrite for OwnedWriteHalf {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
let stream = unsafe { &mut *self.stream.get() };
Pin::new(stream).poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let stream = unsafe { &mut *self.stream.get() };
Pin::new(stream).poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let stream = unsafe { &mut *self.stream.get() };
Pin::new(stream).poll_shutdown(cx)
}
}
/// Error returned by `reunite` when the two owned halves cannot be joined back
/// into a single `TcpStream`.
#[derive(Debug)]
pub struct ReuniteError;
impl std::fmt::Display for ReuniteError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("halves do not belong to the same TcpStream")
    }
}
// Marker impl: `ReuniteError` carries no source error, so the default methods suffice.
impl std::error::Error for ReuniteError {}
/// A TCP listener backed by a `mio::net::TcpListener` and driven through an `IoHandle`.
///
/// Like `TcpStream`, it is registered with the I/O driver lazily on first poll.
pub struct TcpListener {
/// The underlying mio listener.
inner: mio::net::TcpListener,
/// Handle used to register/deregister the listener with the I/O driver.
io: IoHandle,
/// Token returned by `IoHandle::register`; `None` while the listener is unregistered.
token: Option<Token>,
/// Pointer obtained from `waker_to_ptr` for the task whose waker was installed last.
registered_task: *mut u8,
}
impl TcpListener {
    /// Wraps a mio listener without registering it with the I/O driver.
    fn from_mio(inner: mio::net::TcpListener, io: IoHandle) -> Self {
        Self {
            inner,
            io,
            token: None,
            registered_task: std::ptr::null_mut(),
        }
    }

    /// Binds a new listener to `addr`.
    ///
    /// # Errors
    /// Returns any error reported by `mio::net::TcpListener::bind`.
    pub fn bind(addr: SocketAddr, io: IoHandle) -> io::Result<Self> {
        let inner = mio::net::TcpListener::bind(addr)?;
        Ok(Self::from_mio(inner, io))
    }

    /// Adopts a standard-library listener.
    ///
    /// The socket is switched to non-blocking mode first: mio's `from_std`
    /// leaves the blocking mode untouched, and a blocking listener would stall
    /// the executor thread inside `accept` instead of returning `WouldBlock`.
    ///
    /// # Errors
    /// Returns the error from `set_nonblocking` if the mode cannot be changed.
    pub fn from_std(listener: std::net::TcpListener, io: IoHandle) -> io::Result<Self> {
        listener.set_nonblocking(true)?;
        let inner = mio::net::TcpListener::from_std(listener);
        Ok(Self::from_mio(inner, io))
    }

    /// Returns the address the listener is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.inner.local_addr()
    }

    /// Reads the socket's time-to-live value.
    pub fn ttl(&self) -> io::Result<u32> {
        socket2::SockRef::from(&self.inner).ttl()
    }

    /// Sets the socket's time-to-live value.
    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
        socket2::SockRef::from(&self.inner).set_ttl(ttl)
    }

    /// Returns a future that resolves to the next accepted connection and the
    /// peer's address.
    pub fn accept(&mut self) -> Accept<'_> {
        Accept { listener: self }
    }

    /// Makes sure the listener is registered and that the driver holds the
    /// waker of the task currently polling.
    ///
    /// The waker is only replaced when the polling task differs from the one
    /// recorded in `registered_task`.
    #[inline(always)]
    fn ensure_registered(&mut self, cx: &Context<'_>) -> io::Result<()> {
        let task_ptr = waker_to_ptr(cx);
        if let Some(token) = self.token {
            if task_ptr != self.registered_task {
                self.io.set_waker(token, cx.waker().clone());
                self.registered_task = task_ptr;
            }
            return Ok(());
        }
        self.do_register(task_ptr, cx.waker().clone())
    }

    /// First-time registration; a listener only needs read interest.
    #[cold]
    fn do_register(&mut self, task_ptr: *mut u8, waker: Waker) -> io::Result<()> {
        let token = self.io.register(&mut self.inner, Interest::READABLE, waker)?;
        self.token = Some(token);
        self.registered_task = task_ptr;
        Ok(())
    }
}
impl std::fmt::Debug for TcpListener {
    /// Prints the raw descriptor and whether the listener holds a driver token.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fd = self.inner.as_raw_fd();
        let registered = self.token.is_some();
        let mut out = f.debug_struct("TcpListener");
        out.field("fd", &fd);
        out.field("registered", &registered);
        out.finish()
    }
}
impl AsFd for TcpListener {
    /// Borrows the descriptor of the underlying mio listener.
    fn as_fd(&self) -> BorrowedFd<'_> {
        AsFd::as_fd(&self.inner)
    }
}
impl AsRawFd for TcpListener {
    /// Returns the raw descriptor of the underlying mio listener.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}
impl Drop for TcpListener {
    /// Deregisters the listener from the I/O driver if it was ever registered.
    /// A deregistration error is ignored.
    fn drop(&mut self) {
        let Some(token) = self.token else {
            return;
        };
        // SAFETY: NOTE(review) — `deregister`'s contract is not visible in this
        // file; `token` is the one `register` returned for `self.inner`.
        let _ = unsafe { self.io.deregister(&mut self.inner, token) };
    }
}
/// Future returned by `TcpListener::accept`; resolves to an accepted stream and
/// the peer's address.
pub struct Accept<'a> {
/// The listener being polled, exclusively borrowed for the future's lifetime.
listener: &'a mut TcpListener,
}
impl std::future::Future for Accept<'_> {
    type Output = io::Result<(TcpStream, SocketAddr)>;

    /// Tries to accept one connection after making sure the listener is registered.
    ///
    /// The accepted socket is wrapped in a `TcpStream` that uses the listener's
    /// `IoHandle`. `WouldBlock` yields `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let listener = &mut *self.get_mut().listener;
        listener.ensure_registered(cx)?;

        // NOTE(review): unlike `TcpStream::poll_read`, the readable flag is not
        // cleared on `WouldBlock` here — confirm that this is intentional.
        match listener.inner.accept() {
            Ok((socket, peer)) => {
                Poll::Ready(Ok((TcpStream::new(socket, listener.io), peer)))
            }
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            Err(err) => Poll::Ready(Err(err)),
        }
    }
}
/// A TCP socket that has not yet been turned into a stream or a listener.
///
/// Options can be configured on it before calling `connect` or `listen`.
pub struct TcpSocket {
/// The underlying socket2 socket, created in non-blocking mode by the constructors.
inner: socket2::Socket,
}
impl TcpSocket {
    /// Creates a non-blocking TCP stream socket for the given address family.
    fn new_nonblocking(domain: socket2::Domain) -> io::Result<Self> {
        let protocol = Some(socket2::Protocol::TCP);
        let inner = socket2::Socket::new(domain, socket2::Type::STREAM, protocol)?;
        inner.set_nonblocking(true)?;
        Ok(Self { inner })
    }

    /// Creates a non-blocking IPv4 TCP socket.
    pub fn new_v4() -> io::Result<Self> {
        Self::new_nonblocking(socket2::Domain::IPV4)
    }

    /// Creates a non-blocking IPv6 TCP socket.
    pub fn new_v6() -> io::Result<Self> {
        Self::new_nonblocking(socket2::Domain::IPV6)
    }

    /// Sets the address-reuse option.
    pub fn set_reuseaddr(&self, reuseaddr: bool) -> io::Result<()> {
        self.inner.set_reuse_address(reuseaddr)
    }

    /// Reads the address-reuse option.
    pub fn reuseaddr(&self) -> io::Result<bool> {
        self.inner.reuse_address()
    }

    /// Sets the port-reuse option (Unix only).
    #[cfg(unix)]
    pub fn set_reuseport(&self, reuseport: bool) -> io::Result<()> {
        self.inner.set_reuse_port(reuseport)
    }

    /// Reads the port-reuse option (Unix only).
    #[cfg(unix)]
    pub fn reuseport(&self) -> io::Result<bool> {
        self.inner.reuse_port()
    }

    /// Enables or disables keepalive.
    pub fn set_keepalive(&self, keepalive: bool) -> io::Result<()> {
        self.inner.set_keepalive(keepalive)
    }

    /// Reads whether keepalive is enabled.
    pub fn keepalive(&self) -> io::Result<bool> {
        self.inner.keepalive()
    }

    /// Sets the `TCP_NODELAY` option.
    pub fn set_nodelay(&self, nodelay: bool) -> io::Result<()> {
        self.inner.set_nodelay(nodelay)
    }

    /// Reads the `TCP_NODELAY` option.
    pub fn nodelay(&self) -> io::Result<bool> {
        self.inner.nodelay()
    }

    /// Sets the linger setting.
    pub fn set_linger(&self, duration: Option<Duration>) -> io::Result<()> {
        self.inner.set_linger(duration)
    }

    /// Reads the linger setting.
    pub fn linger(&self) -> io::Result<Option<Duration>> {
        self.inner.linger()
    }

    /// Sets the send buffer size.
    pub fn set_send_buffer_size(&self, size: usize) -> io::Result<()> {
        self.inner.set_send_buffer_size(size)
    }

    /// Reads the send buffer size.
    pub fn send_buffer_size(&self) -> io::Result<usize> {
        self.inner.send_buffer_size()
    }

    /// Sets the receive buffer size.
    pub fn set_recv_buffer_size(&self, size: usize) -> io::Result<()> {
        self.inner.set_recv_buffer_size(size)
    }

    /// Reads the receive buffer size.
    pub fn recv_buffer_size(&self) -> io::Result<usize> {
        self.inner.recv_buffer_size()
    }

    /// Sets the socket's time-to-live value.
    pub fn set_ttl(&self, ttl: u32) -> io::Result<()> {
        self.inner.set_ttl(ttl)
    }

    /// Reads the socket's time-to-live value.
    pub fn ttl(&self) -> io::Result<u32> {
        self.inner.ttl()
    }

    /// Binds the socket to `addr`.
    pub fn bind(&self, addr: SocketAddr) -> io::Result<()> {
        let target = addr.into();
        self.inner.bind(&target)
    }

    /// Starts connecting to `addr` and converts the socket into a `TcpStream`.
    ///
    /// `EINPROGRESS` and `EALREADY` from the connect call are not treated as
    /// failures; any other error is returned and the socket is dropped.
    pub fn connect(self, addr: SocketAddr, io: IoHandle) -> io::Result<TcpStream> {
        if let Err(err) = self.inner.connect(&addr.into()) {
            let still_connecting = matches!(
                err.raw_os_error(),
                Some(code) if code == libc::EINPROGRESS || code == libc::EALREADY
            );
            if !still_connecting {
                return Err(err);
            }
        }
        let std_stream = std::net::TcpStream::from(self.inner);
        Ok(TcpStream::new(mio::net::TcpStream::from_std(std_stream), io))
    }

    /// Starts listening with the given `backlog` and converts the socket into
    /// a `TcpListener`.
    pub fn listen(self, backlog: i32, io: IoHandle) -> io::Result<TcpListener> {
        self.inner.listen(backlog)?;
        let std_listener = std::net::TcpListener::from(self.inner);
        let inner = mio::net::TcpListener::from_std(std_listener);
        Ok(TcpListener {
            inner,
            io,
            token: None,
            registered_task: std::ptr::null_mut(),
        })
    }
}
impl std::fmt::Debug for TcpSocket {
    /// Prints the raw descriptor of the socket.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let fd = self.inner.as_raw_fd();
        let mut out = f.debug_struct("TcpSocket");
        out.field("fd", &fd);
        out.finish()
    }
}
impl AsFd for TcpSocket {
    /// Borrows the descriptor of the underlying socket2 socket.
    fn as_fd(&self) -> BorrowedFd<'_> {
        AsFd::as_fd(&self.inner)
    }
}
impl AsRawFd for TcpSocket {
    /// Returns the raw descriptor of the underlying socket2 socket.
    fn as_raw_fd(&self) -> RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}
// Tests for the TCP primitives in this module.
#[cfg(test)]
mod tests {
use super::*;
use crate::{Runtime, spawn_boxed};
use nexus_rt::WorldBuilder;
use std::cell::Cell;
use std::rc::Rc;
// End-to-end echo: one spawned task accepts a connection and echoes what it
// reads, a second spawned task connects, sends "hello" and checks the reply.
#[test]
fn tcp_echo() {
let wb = WorldBuilder::new();
let mut world = wb.build();
let mut rt = Runtime::new(&mut world);
// Set by the client task once the echoed bytes have been verified.
let done = Rc::new(Cell::new(false));
let done2 = done.clone();
rt.block_on(async move {
// Port 0 requests an ephemeral port; the actual address is read back below.
let listener = TcpListener::bind(
"127.0.0.1:0".parse().unwrap(),
crate::context::io(),
).expect("bind failed");
let addr = listener.local_addr().unwrap();
// Server task: accept one connection and echo a single read.
spawn_boxed(async move {
let mut listener = listener;
let (mut stream, _peer) = listener.accept().await.unwrap();
let mut buf = [0u8; 64];
let n = stream.read(&mut buf).await.unwrap();
stream.write_all(&buf[..n]).await.unwrap();
});
let io = crate::context::io();
let flag = done2;
// Client task: sleeps 10 ms before connecting, then does the round trip.
spawn_boxed(async move {
crate::context::sleep(std::time::Duration::from_millis(10)).await;
let mut client = TcpStream::connect(addr, io).unwrap();
client.write_all(b"hello").await.unwrap();
let mut buf = [0u8; 64];
let n = client.read(&mut buf).await.unwrap();
assert_eq!(&buf[..n], b"hello");
flag.set(true);
});
// Keep the root future alive for 500 ms so both spawned tasks can finish.
crate::context::sleep(std::time::Duration::from_millis(500)).await;
});
assert!(done.get(), "echo exchange never completed");
}
// Option setters on `TcpSocket` are read back through the matching getters.
#[test]
fn tcp_socket_builder() {
let socket = TcpSocket::new_v4().unwrap();
socket.set_reuseaddr(true).unwrap();
assert!(socket.reuseaddr().unwrap());
socket.set_nodelay(true).unwrap();
assert!(socket.nodelay().unwrap());
socket.set_send_buffer_size(65536).unwrap();
// Only a lower bound is asserted for the reported buffer size.
assert!(socket.send_buffer_size().unwrap() >= 65536);
}
}