use crate::cx::Cx;
use crate::io::{AsyncRead, AsyncReadVectored, AsyncWrite, ReadBuf};
use crate::runtime::io_driver::IoRegistration;
use crate::runtime::reactor::Interest;
use parking_lot::Mutex;
use std::io::{self, IoSliceMut, Read, Write};
use std::net::Shutdown;
use std::os::unix::net;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll, Waker};
#[derive(Debug)]
pub struct ReadHalf<'a> {
inner: &'a net::UnixStream,
}
impl<'a> ReadHalf<'a> {
pub(crate) fn new(inner: &'a net::UnixStream) -> Self {
Self { inner }
}
}
impl AsyncRead for ReadHalf<'_> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if Cx::current().is_some_and(|cx| cx.checkpoint().is_err()) {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "cancelled")));
}
let mut inner = self.inner;
match inner.read(buf.unfilled()) {
Ok(n) => {
buf.advance(n);
Poll::Ready(Ok(()))
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
crate::net::tcp::stream::fallback_rewake(cx);
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
}
impl AsyncReadVectored for ReadHalf<'_> {
fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [IoSliceMut<'_>],
) -> Poll<io::Result<usize>> {
if Cx::current().is_some_and(|cx| cx.checkpoint().is_err()) {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "cancelled")));
}
let mut inner = self.inner;
match inner.read_vectored(bufs) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
crate::net::tcp::stream::fallback_rewake(cx);
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
}
#[derive(Debug)]
pub struct WriteHalf<'a> {
inner: &'a net::UnixStream,
}
impl<'a> WriteHalf<'a> {
pub(crate) fn new(inner: &'a net::UnixStream) -> Self {
Self { inner }
}
}
impl AsyncWrite for WriteHalf<'_> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
if Cx::current().is_some_and(|cx| cx.checkpoint().is_err()) {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "cancelled")));
}
let mut inner = self.inner;
match inner.write(buf) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
crate::net::tcp::stream::fallback_rewake(cx);
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
if Cx::current().is_some_and(|cx| cx.checkpoint().is_err()) {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "cancelled")));
}
let mut inner = self.inner;
match inner.write_vectored(bufs) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
crate::net::tcp::stream::fallback_rewake(cx);
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
fn is_write_vectored(&self) -> bool {
true
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if Cx::current().is_some_and(|cx| cx.checkpoint().is_err()) {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "cancelled")));
}
let mut inner = self.inner;
match inner.flush() {
Ok(()) => Poll::Ready(Ok(())),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
crate::net::tcp::stream::fallback_rewake(cx);
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if Cx::current().is_some_and(|cx| cx.checkpoint().is_err()) {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "cancelled")));
}
match self.inner.shutdown(Shutdown::Write) {
Ok(()) => Poll::Ready(Ok(())),
Err(e) if e.kind() == io::ErrorKind::NotConnected => Poll::Ready(Ok(())),
Err(e) => Poll::Ready(Err(e)),
}
}
}
struct CombinedWaker {
read: Option<Waker>,
write: Option<Waker>,
}
use std::task::Wake;
impl Wake for CombinedWaker {
fn wake(self: Arc<Self>) {
self.wake_by_ref();
}
fn wake_by_ref(self: &Arc<Self>) {
if let Some(w) = &self.read {
w.wake_by_ref();
}
if let Some(w) = &self.write {
w.wake_by_ref();
}
}
}
fn combined_waker(read: Option<&Waker>, write: Option<&Waker>) -> Waker {
Waker::from(Arc::new(CombinedWaker {
read: read.cloned(),
write: write.cloned(),
}))
}
#[inline]
fn registration_interest(read_waiter: bool, write_waiter: bool, fallback: Interest) -> Interest {
let mut interest = Interest::empty();
if read_waiter {
interest |= Interest::READABLE;
}
if write_waiter {
interest |= Interest::WRITABLE;
}
if interest.is_empty() {
fallback
} else {
interest
}
}
struct SplitIoState {
registration: Option<IoRegistration>,
read_waker: Option<Waker>,
write_waker: Option<Waker>,
combined_waker: Option<Waker>,
}
pub(crate) struct UnixStreamInner {
state: Mutex<SplitIoState>,
stream: Arc<net::UnixStream>,
}
impl std::fmt::Debug for UnixStreamInner {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("UnixStreamInner")
.field("stream", &self.stream)
.field("state", &"...")
.finish()
}
}
impl UnixStreamInner {
#[allow(clippy::too_many_lines)]
fn register_interest(&self, cx: &Context<'_>, interest: Interest) -> io::Result<()> {
let mut guard = self.state.lock();
let mut wakers_changed = false;
if interest.is_readable() {
if !guard
.read_waker
.as_ref()
.is_some_and(|w| w.will_wake(cx.waker()))
{
guard.read_waker = Some(cx.waker().clone());
wakers_changed = true;
}
}
if interest.is_writable() {
if !guard
.write_waker
.as_ref()
.is_some_and(|w| w.will_wake(cx.waker()))
{
guard.write_waker = Some(cx.waker().clone());
wakers_changed = true;
}
}
if wakers_changed || guard.combined_waker.is_none() {
guard.combined_waker = Some(combined_waker(
guard.read_waker.as_ref(),
guard.write_waker.as_ref(),
));
}
let mut dropped_reg = None;
let mut early_return = None;
let mut wakers_to_wake = None;
{
let SplitIoState {
registration,
read_waker,
write_waker,
combined_waker: cached_combined_waker,
} = &mut *guard;
if let Some(reg) = registration.as_mut() {
let combined_interest =
registration_interest(read_waker.is_some(), write_waker.is_some(), interest);
let waker = cached_combined_waker
.as_ref()
.expect("combined waker initialized");
match reg.rearm(combined_interest, waker) {
Ok(true) => early_return = Some(Ok(())),
Ok(false) => {
dropped_reg = registration.take();
}
Err(err) if err.kind() == io::ErrorKind::NotConnected => {
dropped_reg = registration.take();
wakers_to_wake = Some((read_waker.clone(), write_waker.clone()));
early_return = Some(Ok(()));
}
Err(err) => early_return = Some(Err(err)),
}
}
}
if let Some(res) = early_return {
drop(guard);
drop(dropped_reg);
if let Some((rw, ww)) = wakers_to_wake {
if let Some(w) = rw {
w.wake();
}
if let Some(w) = ww {
w.wake();
}
}
return res;
}
let waker = guard
.combined_waker
.as_ref()
.expect("combined waker initialized")
.clone();
let register_interest = registration_interest(
guard.read_waker.is_some(),
guard.write_waker.is_some(),
interest,
);
let Some(current) = Cx::current() else {
crate::net::tcp::stream::fallback_rewake(cx);
return Ok(());
};
let Some(driver) = current.io_driver_handle() else {
crate::net::tcp::stream::fallback_rewake(cx);
return Ok(());
};
let result = match driver.register(&*self.stream, register_interest, waker) {
Ok(registration) => {
guard.registration = Some(registration);
Ok(())
}
Err(err) if err.kind() == io::ErrorKind::Unsupported => {
crate::net::tcp::stream::fallback_rewake(cx);
Ok(())
}
Err(err) if err.kind() == io::ErrorKind::NotConnected => {
crate::net::tcp::stream::fallback_rewake(cx);
Ok(())
}
Err(err) => Err(err),
};
drop(guard);
result
}
fn clear_waiter_on_drop(&self, interest: Interest) {
let mut guard = self.state.lock();
let mut wakers_changed = interest.is_readable();
if wakers_changed {
guard.read_waker = None;
}
if interest.is_writable() {
guard.write_waker = None;
wakers_changed = true;
}
if wakers_changed || guard.combined_waker.is_none() {
guard.combined_waker = Some(combined_waker(
guard.read_waker.as_ref(),
guard.write_waker.as_ref(),
));
}
let desired_interest = registration_interest(
guard.read_waker.is_some(),
guard.write_waker.is_some(),
Interest::empty(),
);
let mut clear_registration = desired_interest.is_empty();
let mut wakers_to_wake = None;
if !clear_registration {
let combined = guard
.combined_waker
.as_ref()
.expect("combined waker initialized")
.clone();
let is_some = guard.registration.is_some();
let rearm_ok = guard
.registration
.as_mut()
.is_some_and(|reg| matches!(reg.rearm(desired_interest, &combined), Ok(true)));
if is_some {
if !rearm_ok {
clear_registration = true;
wakers_to_wake = Some((guard.read_waker.clone(), guard.write_waker.clone()));
}
} else {
wakers_to_wake = Some((guard.read_waker.clone(), guard.write_waker.clone()));
}
}
let dropped_reg = if clear_registration {
guard.registration.take()
} else {
None
};
drop(guard);
drop(dropped_reg);
if let Some((rw, ww)) = wakers_to_wake {
if let Some(w) = rw {
w.wake();
}
if let Some(w) = ww {
w.wake();
}
}
}
}
impl Drop for OwnedReadHalf {
fn drop(&mut self) {
self.inner.clear_waiter_on_drop(Interest::READABLE);
}
}
#[derive(Debug)]
pub struct OwnedReadHalf {
inner: Arc<UnixStreamInner>,
}
impl OwnedReadHalf {
pub(crate) fn new_pair(
stream: Arc<net::UnixStream>,
registration: Option<IoRegistration>,
) -> (Self, OwnedWriteHalf) {
let inner = Arc::new(UnixStreamInner {
stream,
state: Mutex::new(SplitIoState {
registration,
read_waker: None,
write_waker: None,
combined_waker: None,
}),
});
(
Self {
inner: inner.clone(),
},
OwnedWriteHalf {
inner,
shutdown_on_drop: true,
},
)
}
pub fn reunite(self, other: OwnedWriteHalf) -> Result<super::UnixStream, ReuniteError> {
if Arc::ptr_eq(&self.inner, &other.inner) {
let mut other = other;
other.shutdown_on_drop = false;
let registration = self.inner.state.lock().registration.take();
Ok(super::UnixStream::from_parts(
self.inner.stream.clone(),
registration,
))
} else {
Err(ReuniteError(self, other))
}
}
}
impl AsyncRead for OwnedReadHalf {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if Cx::current().is_some_and(|cx| cx.checkpoint().is_err()) {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "cancelled")));
}
let inner: &net::UnixStream = &self.inner.stream;
match (&*inner).read(buf.unfilled()) {
Ok(n) => {
buf.advance(n);
Poll::Ready(Ok(()))
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if let Err(e) = self.inner.register_interest(cx, Interest::READABLE) {
return Poll::Ready(Err(e));
}
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
}
impl AsyncReadVectored for OwnedReadHalf {
fn poll_read_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &mut [IoSliceMut<'_>],
) -> Poll<io::Result<usize>> {
if Cx::current().is_some_and(|cx| cx.checkpoint().is_err()) {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "cancelled")));
}
let inner: &net::UnixStream = &self.inner.stream;
match (&*inner).read_vectored(bufs) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if let Err(err) = self.inner.register_interest(cx, Interest::READABLE) {
return Poll::Ready(Err(err));
}
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
}
#[derive(Debug)]
pub struct OwnedWriteHalf {
inner: Arc<UnixStreamInner>,
shutdown_on_drop: bool,
}
impl OwnedWriteHalf {
pub fn shutdown(&self) -> io::Result<()> {
self.inner.stream.shutdown(Shutdown::Write)
}
pub fn set_shutdown_on_drop(&mut self, shutdown: bool) {
self.shutdown_on_drop = shutdown;
}
}
impl AsyncWrite for OwnedWriteHalf {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
if Cx::current().is_some_and(|cx| cx.checkpoint().is_err()) {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "cancelled")));
}
let inner: &net::UnixStream = &self.inner.stream;
match (&*inner).write(buf) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if let Err(err) = self.inner.register_interest(cx, Interest::WRITABLE) {
return Poll::Ready(Err(err));
}
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
if Cx::current().is_some_and(|cx| cx.checkpoint().is_err()) {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "cancelled")));
}
let inner: &net::UnixStream = &self.inner.stream;
match (&*inner).write_vectored(bufs) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if let Err(err) = self.inner.register_interest(cx, Interest::WRITABLE) {
return Poll::Ready(Err(err));
}
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
fn is_write_vectored(&self) -> bool {
true
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if Cx::current().is_some_and(|cx| cx.checkpoint().is_err()) {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "cancelled")));
}
let inner: &net::UnixStream = &self.inner.stream;
match (&*inner).flush() {
Ok(()) => Poll::Ready(Ok(())),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
if let Err(e) = self.inner.register_interest(cx, Interest::WRITABLE) {
return Poll::Ready(Err(e));
}
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
}
fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
if Cx::current().is_some_and(|cx| cx.checkpoint().is_err()) {
return Poll::Ready(Err(io::Error::new(io::ErrorKind::Interrupted, "cancelled")));
}
match self.inner.stream.shutdown(Shutdown::Write) {
Ok(()) => Poll::Ready(Ok(())),
Err(e) if e.kind() == io::ErrorKind::NotConnected => Poll::Ready(Ok(())),
Err(e) => Poll::Ready(Err(e)),
}
}
}
impl Drop for OwnedWriteHalf {
fn drop(&mut self) {
self.inner.clear_waiter_on_drop(Interest::WRITABLE);
if self.shutdown_on_drop {
let _ = self.inner.stream.shutdown(Shutdown::Write);
}
}
}
#[derive(Debug)]
pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf);
impl std::fmt::Display for ReuniteError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
"tried to reunite halves that are not from the same socket"
)
}
}
impl std::error::Error for ReuniteError {}
#[cfg(test)]
mod tests {
use super::*;
use crate::cx::Cx;
use std::task::{Context, Poll, Waker};
fn noop_waker() -> Waker {
std::task::Waker::noop().clone()
}
#[test]
fn test_borrowed_halves() {
let (s1, _s2) = net::UnixStream::pair().expect("pair failed");
s1.set_nonblocking(true).expect("set_nonblocking failed");
let _read = ReadHalf::new(&s1);
let _write = WriteHalf::new(&s1);
}
#[test]
fn test_owned_halves() {
let (s1, _s2) = net::UnixStream::pair().expect("pair failed");
s1.set_nonblocking(true).expect("set_nonblocking failed");
let stream = super::super::UnixStream::from_std(s1).expect("wrap stream");
let (_read, _write) = stream.into_split();
}
#[test]
fn test_reunite_success() {
let (s1, _s2) = net::UnixStream::pair().expect("pair failed");
s1.set_nonblocking(true).expect("set_nonblocking failed");
let stream = super::super::UnixStream::from_std(s1).expect("wrap stream");
let (read, write) = stream.into_split();
let _reunited = read.reunite(write).expect("reunite should succeed");
}
#[test]
fn test_reunite_failure() {
let (s1, _s2a) = net::UnixStream::pair().expect("pair failed");
let (s2, _s2b) = net::UnixStream::pair().expect("pair failed");
s1.set_nonblocking(true).expect("set_nonblocking failed");
s2.set_nonblocking(true).expect("set_nonblocking failed");
let stream1 = super::super::UnixStream::from_std(s1).expect("wrap stream1");
let stream2 = super::super::UnixStream::from_std(s2).expect("wrap stream2");
let (read1, _write1) = stream1.into_split();
let (_read2, write2) = stream2.into_split();
let err = read1.reunite(write2).expect_err("reunite should fail");
assert!(err.to_string().contains("not from the same socket"));
}
#[test]
fn registration_interest_prefers_waiter_union() {
let both = registration_interest(true, true, Interest::READABLE);
assert_eq!(both, Interest::READABLE | Interest::WRITABLE);
let write_only = registration_interest(false, true, Interest::READABLE);
assert_eq!(write_only, Interest::WRITABLE);
let fallback = registration_interest(false, false, Interest::READABLE);
assert_eq!(fallback, Interest::READABLE);
}
#[test]
fn borrowed_split_halves_return_interrupted_when_cancel_requested() {
let (s1, _s2) = net::UnixStream::pair().expect("pair failed");
s1.set_nonblocking(true).expect("set_nonblocking failed");
let cx = Cx::for_testing();
cx.set_cancel_requested(true);
let _guard = Cx::set_current(Some(cx));
let mut read_half = ReadHalf::new(&s1);
let mut write_half = WriteHalf::new(&s1);
let waker = noop_waker();
let mut task_cx = Context::from_waker(&waker);
let mut buf = [0u8; 8];
let mut read_buf = crate::io::ReadBuf::new(&mut buf);
let read =
crate::io::AsyncRead::poll_read(Pin::new(&mut read_half), &mut task_cx, &mut read_buf);
assert!(matches!(
read,
Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::Interrupted
));
let write =
crate::io::AsyncWrite::poll_write(Pin::new(&mut write_half), &mut task_cx, b"hello");
assert!(matches!(
write,
Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::Interrupted
));
let flush = crate::io::AsyncWrite::poll_flush(Pin::new(&mut write_half), &mut task_cx);
assert!(matches!(
flush,
Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::Interrupted
));
let shutdown =
crate::io::AsyncWrite::poll_shutdown(Pin::new(&mut write_half), &mut task_cx);
assert!(matches!(
shutdown,
Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::Interrupted
));
}
#[test]
fn owned_split_halves_return_interrupted_when_cancel_requested() {
let (s1, _s2) = net::UnixStream::pair().expect("pair failed");
s1.set_nonblocking(true).expect("set_nonblocking failed");
let cx = Cx::for_testing();
cx.set_cancel_requested(true);
let _guard = Cx::set_current(Some(cx));
let stream = super::super::UnixStream::from_std(s1).expect("wrap stream");
let (mut read_half, mut write_half) = stream.into_split();
let waker = noop_waker();
let mut task_cx = Context::from_waker(&waker);
let mut buf = [0u8; 8];
let mut read_buf = crate::io::ReadBuf::new(&mut buf);
let read =
crate::io::AsyncRead::poll_read(Pin::new(&mut read_half), &mut task_cx, &mut read_buf);
assert!(matches!(
read,
Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::Interrupted
));
let write =
crate::io::AsyncWrite::poll_write(Pin::new(&mut write_half), &mut task_cx, b"hello");
assert!(matches!(
write,
Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::Interrupted
));
let flush = crate::io::AsyncWrite::poll_flush(Pin::new(&mut write_half), &mut task_cx);
assert!(matches!(
flush,
Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::Interrupted
));
let shutdown =
crate::io::AsyncWrite::poll_shutdown(Pin::new(&mut write_half), &mut task_cx);
assert!(matches!(
shutdown,
Poll::Ready(Err(ref err)) if err.kind() == io::ErrorKind::Interrupted
));
}
}