// tsyncp/util/tcp.rs

//! Light wrappers around [tokio::net::tcp]'s [OwnedReadHalf](tokio::net::tcp::OwnedReadHalf) and [OwnedWriteHalf](tokio::net::tcp::OwnedWriteHalf).

3use std::io;
4use std::ops::{Deref, DerefMut};
5use std::pin::Pin;
6use std::task::{Context, Poll};
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8use tokio::net::{tcp, TcpStream};
9
10pub use tokio::net::tcp::ReuniteError;
11
12pub(crate) fn reunite(
13    read: OwnedReadHalf,
14    mut write: OwnedWriteHalf,
15) -> Result<TcpStream, ReuniteError> {
16    let w_inner = write.take_inner();
17
18    // don't need wrapper anymore
19    drop(write);
20
21    // reunite will call .forget() on tcp::OwnedWriteHalf
22    read.inner.reunite(w_inner)
23}
24
25/// Light wrapper around [tokio::net::tcp::OwnedReadHalf].
26#[derive(Debug)]
27pub struct OwnedReadHalf {
28    inner: tcp::OwnedReadHalf,
29}
30
31impl AsyncRead for OwnedReadHalf {
32    fn poll_read(
33        mut self: Pin<&mut Self>,
34        cx: &mut Context<'_>,
35        buf: &mut ReadBuf<'_>,
36    ) -> Poll<io::Result<()>> {
37        Pin::new(&mut self.inner).poll_read(cx, buf)
38    }
39}
40
41impl From<tcp::OwnedReadHalf> for OwnedReadHalf {
42    fn from(r: tcp::OwnedReadHalf) -> Self {
43        Self { inner: r }
44    }
45}
46
47impl Deref for OwnedReadHalf {
48    type Target = tcp::OwnedReadHalf;
49
50    fn deref(&self) -> &Self::Target {
51        &self.inner
52    }
53}
54
55impl DerefMut for OwnedReadHalf {
56    fn deref_mut(&mut self) -> &mut Self::Target {
57        &mut self.inner
58    }
59}
60
61/// Light wrapper around [tokio::net::tcp::OwnedWriteHalf] to stop it from shutting down TCP stream when
62/// it drops.
63#[derive(Debug)]
64pub struct OwnedWriteHalf {
65    inner: Option<tcp::OwnedWriteHalf>,
66    should_forget: bool,
67}
68
69impl OwnedWriteHalf {
70    // Should be only used before dropping the struct.
71    fn take_inner(&mut self) -> tcp::OwnedWriteHalf {
72        let inner = self.inner.take().expect("should exist");
73
74        self.should_forget = false;
75
76        inner
77    }
78}
79
80impl From<tcp::OwnedWriteHalf> for OwnedWriteHalf {
81    fn from(w: tcp::OwnedWriteHalf) -> Self {
82        Self {
83            inner: Some(w),
84            should_forget: true,
85        }
86    }
87}
88
89impl AsyncWrite for OwnedWriteHalf {
90    fn poll_write(
91        mut self: Pin<&mut Self>,
92        cx: &mut Context<'_>,
93        buf: &[u8],
94    ) -> Poll<io::Result<usize>> {
95        let inner: &mut tcp::OwnedWriteHalf = self.deref_mut();
96        Pin::new(inner).poll_write(cx, buf)
97    }
98
99    fn poll_write_vectored(
100        mut self: Pin<&mut Self>,
101        cx: &mut Context<'_>,
102        bufs: &[io::IoSlice<'_>],
103    ) -> Poll<io::Result<usize>> {
104        let inner: &mut tcp::OwnedWriteHalf = self.deref_mut();
105        Pin::new(inner).poll_write_vectored(cx, bufs)
106    }
107
108    fn is_write_vectored(&self) -> bool {
109        let inner: &tcp::OwnedWriteHalf = self.deref();
110        inner.is_write_vectored()
111    }
112
113    #[inline]
114    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<io::Result<()>> {
115        // tcp flush is a no-op
116        Poll::Ready(Ok(()))
117    }
118
119    // `poll_shutdown` on a write half shutdowns the stream in the "write" direction.
120    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
121        let inner: &mut tcp::OwnedWriteHalf = self.deref_mut();
122        Pin::new(inner).poll_shutdown(cx)
123    }
124}
125
126impl Drop for OwnedWriteHalf {
127    fn drop(&mut self) {
128        if self.should_forget {
129            let inner = self.take_inner();
130            inner.forget();
131        }
132    }
133}
134
135impl Deref for OwnedWriteHalf {
136    type Target = tcp::OwnedWriteHalf;
137
138    fn deref(&self) -> &Self::Target {
139        self.inner.as_ref().expect("Should exist")
140    }
141}
142
143impl DerefMut for OwnedWriteHalf {
144    fn deref_mut(&mut self) -> &mut Self::Target {
145        self.inner.as_mut().expect("Should exist")
146    }
147}