use std::error::Error;
use std::fmt;
use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::shim::tokio::net::tcp::stream::{noop_cx, TcpStream};
use crate::sys;
/// Read half of a [`TcpStream`] produced by `split_owned`.
///
/// Shares ownership of the stream with the matching [`OwnedWriteHalf`]
/// through an `Arc`; the pair can be joined again with
/// [`OwnedReadHalf::reunite`].
#[derive(Debug)]
pub struct OwnedReadHalf {
    // Stream shared with the paired write half.
    inner: Arc<TcpStream>,
}
/// Write half of a [`TcpStream`] produced by `split_owned`.
///
/// While `shutdown_on_drop` is set, dropping this half issues a write-side
/// shutdown on the underlying socket (see the `Drop` impl).
#[derive(Debug)]
pub struct OwnedWriteHalf {
    // Stream shared with the paired read half.
    inner: Arc<TcpStream>,
    // When true, `Drop` shuts down the write side. Cleared by `forget`, by a
    // successful `poll_shutdown`, and on the success path of `reunite`.
    shutdown_on_drop: bool,
}
/// Splits `stream` into independently owned read and write halves.
///
/// Both halves hold the stream through one shared `Arc`. The write half
/// starts out armed to shut down the write side of the socket when dropped.
pub(crate) fn split_owned(stream: TcpStream) -> (OwnedReadHalf, OwnedWriteHalf) {
    let shared = Arc::new(stream);
    let read_half = OwnedReadHalf {
        inner: Arc::clone(&shared),
    };
    let write_half = OwnedWriteHalf {
        inner: shared,
        shutdown_on_drop: true,
    };
    (read_half, write_half)
}
impl OwnedReadHalf {
    /// Joins this read half with `other` to recover the original [`TcpStream`].
    ///
    /// # Errors
    ///
    /// Returns [`ReuniteError`], which hands the halves back to the caller,
    /// when `other` was not split from the same stream.
    pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
        reunite(self, other)
    }

    /// Attempts a read into `buf` by delegating to the shared stream's
    /// `try_read`.
    pub fn try_read(&self, buf: &mut [u8]) -> io::Result<usize> {
        let stream: &TcpStream = &self.inner;
        stream.try_read(buf)
    }

    /// Awaits the shared stream's `peek` into `buf`.
    pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let stream: &TcpStream = &self.inner;
        stream.peek(buf).await
    }

    /// Poll-based counterpart of [`OwnedReadHalf::peek`]; forwards to the
    /// shared stream's `poll_peek`.
    pub fn poll_peek(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<usize>> {
        let stream: &TcpStream = &self.inner;
        stream.poll_peek(cx, buf)
    }

    /// Local socket address, as reported by the shared stream.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        let stream: &TcpStream = &self.inner;
        stream.local_addr()
    }

    /// Remote socket address, as reported by the shared stream.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        let stream: &TcpStream = &self.inner;
        stream.peer_addr()
    }
}
impl OwnedWriteHalf {
    /// Joins this write half with `other` to recover the original [`TcpStream`].
    ///
    /// # Errors
    ///
    /// Returns [`ReuniteError`], which hands the halves back to the caller,
    /// when `other` was not split from the same stream.
    pub fn reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
        reunite(other, self)
    }

    /// Drops this half without performing the drop-time write-side shutdown.
    pub fn forget(self) {
        let mut disarmed = self;
        disarmed.shutdown_on_drop = false;
        drop(disarmed);
    }

    /// Attempts a write of `buf` by delegating to the shared stream's
    /// `try_write`.
    pub fn try_write(&self, buf: &[u8]) -> io::Result<usize> {
        let stream: &TcpStream = &self.inner;
        stream.try_write(buf)
    }

    /// Local socket address, as reported by the shared stream.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        let stream: &TcpStream = &self.inner;
        stream.local_addr()
    }

    /// Remote socket address, as reported by the shared stream.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        let stream: &TcpStream = &self.inner;
        stream.peer_addr()
    }
}
fn reunite(read: OwnedReadHalf, write: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
if Arc::ptr_eq(&read.inner, &write.inner) {
let mut write = write;
write.shutdown_on_drop = false;
drop(write);
Arc::try_unwrap(read.inner).map_err(|inner| {
ReuniteError(
OwnedReadHalf {
inner: inner.clone(),
},
OwnedWriteHalf {
inner,
shutdown_on_drop: true,
},
)
})
} else {
Err(ReuniteError(
read,
OwnedWriteHalf {
inner: write.inner.clone(),
shutdown_on_drop: write.shutdown_on_drop,
},
))
}
}
/// Error returned by `reunite` when the two halves were not split from the
/// same stream.
///
/// The halves are carried back to the caller: `.0` is the read half and `.1`
/// the write half.
#[derive(Debug)]
pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf);

impl fmt::Display for ReuniteError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("tried to reunite halves that are not from the same socket")
    }
}

impl Error for ReuniteError {}
impl Drop for OwnedWriteHalf {
    /// Requests a write-side shutdown of the socket unless this half was
    /// disarmed (`shutdown_on_drop == false`).
    ///
    /// Best-effort: the result of the shutdown call is discarded. The task
    /// context comes from `noop_cx()` — presumably a context whose waker does
    /// nothing, so a pending shutdown is not resumed (confirm against
    /// `noop_cx`).
    fn drop(&mut self) {
        if !self.shutdown_on_drop {
            return;
        }
        let socket_fd = self.inner.fd();
        let _ = sys(|handle| handle.poll_shutdown_write(socket_fd, &mut noop_cx()));
    }
}
impl AsyncRead for OwnedReadHalf {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
self.inner.poll_read_priv(cx, buf)
}
}
impl AsyncWrite for OwnedWriteHalf {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.inner.poll_write_priv(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
let res = self.inner.poll_shutdown_priv(cx);
if let Poll::Ready(Ok(())) = &res {
self.shutdown_on_drop = false;
}
res
}
}
impl AsRef<TcpStream> for OwnedReadHalf {
fn as_ref(&self) -> &TcpStream {
&self.inner
}
}
impl AsRef<TcpStream> for OwnedWriteHalf {
fn as_ref(&self) -> &TcpStream {
&self.inner
}
}