use std::{
error::Error,
fmt, io,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::net::TcpStream;
use super::stream::{ReadHalf, WriteHalf};
/// Owned read half of a [`TcpStream`].
///
/// Wraps a [`ReadHalf`] whose socket-pair state is shared (behind an `Arc`)
/// with a matching [`OwnedWriteHalf`]. The two halves can be joined back into
/// a [`TcpStream`] with [`OwnedReadHalf::reunite`].
#[derive(Debug)]
pub struct OwnedReadHalf {
    /// The underlying read half.
    pub(crate) inner: ReadHalf,
}

impl OwnedReadHalf {
    /// Returns the local address of the socket this half belongs to.
    ///
    /// The address is read from the shared socket-pair state, so this always
    /// returns `Ok`.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        let addr = self.inner.pair.local;
        Ok(addr)
    }

    /// Returns the remote (peer) address of the socket this half belongs to.
    ///
    /// The address is read from the shared socket-pair state, so this always
    /// returns `Ok`.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        let addr = self.inner.pair.remote;
        Ok(addr)
    }

    /// Joins this read half with `other` to rebuild the original [`TcpStream`].
    ///
    /// # Errors
    ///
    /// Returns [`ReuniteError`], which carries both halves back to the caller,
    /// when `other` does not share this half's socket-pair state.
    pub fn reunite(self, other: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
        reunite(self, other)
    }

    /// Poll-based peek. Forwards directly to the inner read half's `poll_peek`.
    ///
    /// NOTE(review): by name and by analogy with tokio, this is expected to
    /// leave the peeked bytes unconsumed. Confirm against `ReadHalf::poll_peek`.
    pub fn poll_peek(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<usize>> {
        // `Self: Unpin` is already required to reach `inner` mutably through
        // the pin, so taking the plain `&mut Self` out of the pin is fine.
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_peek(cx, buf)
    }

    /// Async peek into `buf`. Forwards to the inner read half's `peek`.
    ///
    /// NOTE(review): presumably returns the number of bytes copied without
    /// consuming them. Confirm against `ReadHalf::peek`.
    pub async fn peek(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let inner = &mut self.inner;
        inner.peek(buf).await
    }
}
/// Owned write half of a [`TcpStream`].
///
/// Wraps a [`WriteHalf`] whose socket-pair state is shared (behind an `Arc`)
/// with a matching [`OwnedReadHalf`]. The two halves can be joined back into
/// a [`TcpStream`] with [`OwnedWriteHalf::reunite`].
#[derive(Debug)]
pub struct OwnedWriteHalf {
    /// The underlying write half.
    pub(crate) inner: WriteHalf,
}

impl OwnedWriteHalf {
    /// Returns the local address of the socket this half belongs to.
    ///
    /// The address is read from the shared socket-pair state, so this always
    /// returns `Ok`.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        let addr = self.inner.pair.local;
        Ok(addr)
    }

    /// Returns the remote (peer) address of the socket this half belongs to.
    ///
    /// The address is read from the shared socket-pair state, so this always
    /// returns `Ok`.
    pub fn peer_addr(&self) -> io::Result<SocketAddr> {
        let addr = self.inner.pair.remote;
        Ok(addr)
    }

    /// Joins this write half with `other` to rebuild the original [`TcpStream`].
    ///
    /// # Errors
    ///
    /// Returns [`ReuniteError`], which carries both halves back to the caller,
    /// when `other` does not share this half's socket-pair state.
    pub fn reunite(self, other: OwnedReadHalf) -> Result<TcpStream, ReuniteError> {
        // The shared helper takes (read, write), so the arguments are swapped here.
        let (read, write) = (other, self);
        reunite(read, write)
    }
}
/// Rebuilds a [`TcpStream`] from its two owned halves.
///
/// The halves match only if both point at the very same shared socket-pair
/// allocation. This is an identity check (`Arc::ptr_eq`), not a value
/// comparison. If they do not match, both halves are returned unchanged inside
/// [`ReuniteError`] so the caller keeps ownership of them.
fn reunite(read: OwnedReadHalf, write: OwnedWriteHalf) -> Result<TcpStream, ReuniteError> {
    let same_socket = Arc::ptr_eq(&read.inner.pair, &write.inner.pair);
    if !same_socket {
        return Err(ReuniteError(read, write));
    }
    Ok(TcpStream::reunite(read.inner, write.inner))
}
/// Error returned when reuniting an [`OwnedReadHalf`] and an [`OwnedWriteHalf`]
/// that do not share the same socket-pair state.
///
/// Both halves are handed back (fields `.0` and `.1`) so the caller does not
/// lose them.
#[derive(Debug)]
pub struct ReuniteError(pub OwnedReadHalf, pub OwnedWriteHalf);

impl fmt::Display for ReuniteError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Fixed message; no formatting arguments are needed.
        f.write_str("tried to reunite halves that are not from the same socket")
    }
}

/// Uses the default `Error` methods; there is no underlying `source`.
impl Error for ReuniteError {}
/// Reads are delegated unchanged to the inner read half.
impl AsyncRead for OwnedReadHalf {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // Reaching `inner` mutably through the pin already requires
        // `Self: Unpin`, so un-pinning with `get_mut` adds no new requirement.
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_read(cx, buf)
    }
}
/// Writes, flushes and shutdowns are delegated unchanged to the inner write half.
///
/// NOTE(review): `poll_write_vectored` / `is_write_vectored` are not forwarded,
/// so the trait defaults apply even if `WriteHalf` overrides them. Confirm
/// whether that is intended.
impl AsyncWrite for OwnedWriteHalf {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        Pin::new(&mut this.inner).poll_shutdown(cx)
    }
}