tokio_vsock/
split.rs

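//! Split a [`VsockStream`] into separate `AsyncRead` and `AsyncWrite` handles.
//!
//! [`split`] borrows the stream and returns a [`ReadHalf`] / [`WriteHalf`] pair.
//! [`split_owned`] takes ownership and returns an [`OwnedReadHalf`] / [`OwnedWriteHalf`]
//! pair. To restore the original stream from the owned halves, use
//! [`OwnedReadHalf::unsplit`].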
6
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8
9use crate::VsockStream;
10use futures::ready;
11use std::fmt;
12use std::io;
13use std::pin::Pin;
14use std::sync::Arc;
15use std::task::{Context, Poll};
16
17/// Splits a ``VsockStream`` into a readable half and a writeable half
18pub fn split(stream: &mut VsockStream) -> (ReadHalf<'_>, WriteHalf<'_>) {
19    // Safety: we have an exclusive reference to the stream so we can safely get a readonly and
20    // write only reference to it.
21    (ReadHalf(stream), WriteHalf(stream))
22}
23
24/// The readable half of a value returned from [`split`](split()).
25pub struct ReadHalf<'a>(&'a VsockStream);
26
27/// The writable half of a value returned from [`split`](split()).
28pub struct WriteHalf<'a>(&'a VsockStream);
29
30impl AsyncRead for ReadHalf<'_> {
31    fn poll_read(
32        self: Pin<&mut Self>,
33        cx: &mut Context<'_>,
34        buf: &mut ReadBuf<'_>,
35    ) -> Poll<io::Result<()>> {
36        self.0.poll_read_priv(cx, buf)
37    }
38}
39
40impl AsyncWrite for WriteHalf<'_> {
41    fn poll_write(
42        self: Pin<&mut Self>,
43        cx: &mut Context<'_>,
44        buf: &[u8],
45    ) -> Poll<Result<usize, io::Error>> {
46        self.0.poll_write_priv(cx, buf)
47    }
48
49    #[inline]
50    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
51        // Not buffered so flush is a No-op
52        Poll::Ready(Ok(()))
53    }
54
55    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
56        // TODO: This could maybe block?
57        self.0.shutdown(std::net::Shutdown::Write)?;
58        Poll::Ready(Ok(()))
59    }
60}
61
62pub fn split_owned(stream: VsockStream) -> (OwnedReadHalf, OwnedWriteHalf) {
63    let inner = Arc::new(Inner::new(stream));
64
65    let rd = OwnedReadHalf {
66        inner: inner.clone(),
67    };
68
69    let wr = OwnedWriteHalf { inner };
70
71    (rd, wr)
72}
73
74/// The readable half of a value returned from [`split_owned`](split_owned()).
75pub struct OwnedReadHalf {
76    inner: Arc<Inner>,
77}
78
79/// The writable half of a value returned from [`split_owned`](split_owned()).
80pub struct OwnedWriteHalf {
81    inner: Arc<Inner>,
82}
83
84struct Inner(tokio::sync::Mutex<VsockStream>);
85
86impl Inner {
87    fn new(stream: VsockStream) -> Self {
88        Self(tokio::sync::Mutex::new(stream))
89    }
90}
91
92struct Guard<'a>(tokio::sync::MutexGuard<'a, VsockStream>);
93
94impl OwnedReadHalf {
95    /// Checks if this `ReadHalf` and some `WriteHalf` were split from the same
96    /// stream.
97    pub fn is_pair_of(&self, other: &OwnedWriteHalf) -> bool {
98        other.is_pair_of(self)
99    }
100
101    /// Reunites with a previously split `WriteHalf`.
102    ///
103    /// # Panics
104    ///
105    /// If this `ReadHalf` and the given `WriteHalf` do not originate from the
106    /// same `split` operation this method will panic.
107    /// This can be checked ahead of time by comparing the stream ID
108    /// of the two halves.
109    #[track_caller]
110    pub fn unsplit(self, wr: OwnedWriteHalf) -> VsockStream {
111        if self.is_pair_of(&wr) {
112            drop(wr);
113
114            let inner = Arc::try_unwrap(self.inner)
115                .ok()
116                .expect("`Arc::try_unwrap` failed");
117
118            inner.0.into_inner()
119        } else {
120            panic!("Unrelated `split::Write` passed to `split::Read::unsplit`.")
121        }
122    }
123}
124
125impl OwnedWriteHalf {
126    /// Checks if this `WriteHalf` and some `ReadHalf` were split from the same
127    /// stream.
128    pub fn is_pair_of(&self, other: &OwnedReadHalf) -> bool {
129        Arc::ptr_eq(&self.inner, &other.inner)
130    }
131}
132
133impl AsyncRead for OwnedReadHalf {
134    fn poll_read(
135        self: Pin<&mut Self>,
136        cx: &mut Context<'_>,
137        buf: &mut ReadBuf<'_>,
138    ) -> Poll<io::Result<()>> {
139        let mut inner = ready!(self.inner.poll_lock(cx));
140        inner.stream_pin().poll_read(cx, buf)
141    }
142}
143
144impl AsyncWrite for OwnedWriteHalf {
145    fn poll_write(
146        self: Pin<&mut Self>,
147        cx: &mut Context<'_>,
148        buf: &[u8],
149    ) -> Poll<Result<usize, io::Error>> {
150        let mut inner = ready!(self.inner.poll_lock(cx));
151        inner.stream_pin().poll_write(cx, buf)
152    }
153
154    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
155        let mut inner = ready!(self.inner.poll_lock(cx));
156        inner.stream_pin().poll_flush(cx)
157    }
158
159    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
160        let mut inner = ready!(self.inner.poll_lock(cx));
161        inner.stream_pin().poll_shutdown(cx)
162    }
163}
164
165impl Inner {
166    fn poll_lock(&self, cx: &mut Context<'_>) -> Poll<Guard<'_>> {
167        if let Ok(guard) = self.0.try_lock() {
168            Poll::Ready(Guard(guard))
169        } else {
170            // Spin... but investigate a better strategy
171
172            std::thread::yield_now();
173            cx.waker().wake_by_ref();
174
175            Poll::Pending
176        }
177    }
178}
179
180impl Guard<'_> {
181    fn stream_pin(&mut self) -> Pin<&mut VsockStream> {
182        // safety: the stream is pinned in `Arc` and the `Guard` ensures mutual
183        // exclusion.
184        unsafe { Pin::new_unchecked(&mut *self.0) }
185    }
186}
187
188impl fmt::Debug for OwnedReadHalf {
189    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
190        fmt.debug_struct("split::OwnedReadHalf").finish()
191    }
192}
193
194impl fmt::Debug for OwnedWriteHalf {
195    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
196        fmt.debug_struct("split::OwnedWriteHalf").finish()
197    }
198}
199
200impl fmt::Debug for ReadHalf<'_> {
201    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
202        f.debug_tuple("split::ReadHalf").finish()
203    }
204}
205
206impl fmt::Debug for WriteHalf<'_> {
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        f.debug_tuple("split::WriteHalf").finish()
209    }
210}