compio_io/util/
split.rs

1//! Functionality to split an I/O type into separate read and write halves.
2
3use std::{fmt::Debug, sync::Arc};
4
5use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
6use futures_util::lock::Mutex;
7
8use crate::{AsyncRead, AsyncReadAt, AsyncWrite, AsyncWriteAt, IoResult, util::bilock::BiLock};
9
10/// Splits a single value implementing `AsyncRead + AsyncWrite` into separate
11/// [`AsyncRead`] and [`AsyncWrite`] handles.with internal synchronization.
12pub fn split<T: AsyncRead + AsyncWrite>(stream: T) -> (ReadHalf<T>, WriteHalf<T>) {
13    Split::new(stream).split()
14}
15
16/// Splits a single value implementing `AsyncRead + AsyncWrite` into separate
17/// [`AsyncRead`] and [`AsyncWrite`] handles without internal synchronization
18/// (not `Send` and `Sync`).
19pub fn split_unsync<T: AsyncRead + AsyncWrite>(
20    stream: T,
21) -> (UnsyncReadHalf<T>, UnsyncWriteHalf<T>) {
22    UnsyncSplit::new(stream).split()
23}
24
/// A trait for types that can be split into separate read and write halves.
///
/// This trait enables an I/O type to be divided into two separate components:
/// one for reading and one for writing. This is particularly useful in async
/// contexts where you might want to perform concurrent read and write
/// operations from different tasks.
///
/// # Implementors
/// - Any `(R, W)` tuple implements this trait, returning its two elements
///   unchanged as the read and write halves.
/// - `TcpStream`, `UnixStream` and references to them in `compio::net`
///   implement this trait without any lock thanks to the underlying sockets'
///   duplex nature.
/// - `File` and named pipes in `compio::fs` implement this trait with both
///   [`Splittable::ReadHalf`] and [`Splittable::WriteHalf`] being the file
///   itself, since it's reference-counted under the hood.
/// - For other types to be compatible with this trait, they must be wrapped
///   with [`UnsyncSplit`] or [`Split`], which wrap the type in an unsynced or
///   synced lock respectively.
pub trait Splittable {
    /// The type of the read half, which normally implements [`AsyncRead`] or
    /// [`AsyncReadAt`].
    type ReadHalf;

    /// The type of the write half, which normally implements [`AsyncWrite`] or
    /// [`AsyncWriteAt`].
    type WriteHalf;

    /// Consumes `self` and returns a tuple containing separate read and write
    /// halves.
    ///
    /// The returned halves can be used independently to perform read and write
    /// operations respectively, potentially from different tasks
    /// concurrently.
    fn split(self) -> (Self::ReadHalf, Self::WriteHalf);
}
60
61/// Enables splitting an I/O type into separate read and write halves
62/// without requiring thread-safety.
63///
64/// # Examples
65///
66/// ```ignore
67/// use compio::io::util::UnsyncSplit;
68///
69/// // Create a splittable stream
70/// let stream = /* some stream */;
71/// let unsync = UnsyncSplit::new(stream);
72/// let (read_half, write_half) = unsync.split();
73/// ```
74#[derive(Debug)]
75pub struct UnsyncSplit<T>(BiLock<T>, BiLock<T>);
76
77impl<T> UnsyncSplit<T> {
78    /// Creates a new `UnsyncSplit` from the given stream.
79    pub fn new(stream: T) -> Self {
80        let (r, w) = BiLock::new(stream);
81        UnsyncSplit(r, w)
82    }
83}
84
85impl<T> Splittable for UnsyncSplit<T> {
86    type ReadHalf = UnsyncReadHalf<T>;
87    type WriteHalf = UnsyncWriteHalf<T>;
88
89    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
90        (UnsyncReadHalf(self.0), UnsyncWriteHalf(self.1))
91    }
92}
93
94impl<R, W> Splittable for (R, W) {
95    type ReadHalf = R;
96    type WriteHalf = W;
97
98    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
99        self
100    }
101}
102
103/// The readable half of a value returned from [`split`].
104#[derive(Debug)]
105pub struct UnsyncReadHalf<T>(BiLock<T>);
106
107impl<T> UnsyncReadHalf<T> {
108    /// Reunites with a previously split [`UnsyncWriteHalf`].
109    ///
110    /// # Panics
111    ///
112    /// If this [`UnsyncReadHalf`] and the given [`UnsyncWriteHalf`] do not
113    /// originate from the same [`split_unsync`](super::split_unsync) operation
114    /// this method will panic.
115    #[track_caller]
116    pub fn unsplit(self, other: UnsyncWriteHalf<T>) -> T {
117        self.0.try_join(other.0).expect(
118            "`UnsyncReadHalf` and `UnsyncWriteHalf` must originate from the same `UnsyncSplit`",
119        )
120    }
121}
122
123impl<T: AsyncRead> AsyncRead for UnsyncReadHalf<T> {
124    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
125        self.0.lock().await.read(buf).await
126    }
127
128    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
129        self.0.lock().await.read_vectored(buf).await
130    }
131}
132
133impl<T: AsyncReadAt> AsyncReadAt for UnsyncReadHalf<T> {
134    async fn read_at<B: IoBufMut>(&self, buf: B, pos: u64) -> BufResult<usize, B> {
135        self.0.lock().await.read_at(buf, pos).await
136    }
137}
138
139/// The writable half of a value returned from [`split`](super::split).
140#[derive(Debug)]
141pub struct UnsyncWriteHalf<T>(BiLock<T>);
142
143impl<T: AsyncWrite> AsyncWrite for UnsyncWriteHalf<T> {
144    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
145        self.0.lock().await.write(buf).await
146    }
147
148    async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
149        self.0.lock().await.write_vectored(buf).await
150    }
151
152    async fn flush(&mut self) -> IoResult<()> {
153        self.0.lock().await.flush().await
154    }
155
156    async fn shutdown(&mut self) -> IoResult<()> {
157        self.0.lock().await.shutdown().await
158    }
159}
160
161impl<T: AsyncWriteAt> AsyncWriteAt for UnsyncWriteHalf<T> {
162    async fn write_at<B: IoBuf>(&mut self, buf: B, pos: u64) -> BufResult<usize, B> {
163        self.0.lock().await.write_at(buf, pos).await
164    }
165
166    async fn write_vectored_at<B: IoVectoredBuf>(
167        &mut self,
168        buf: B,
169        pos: u64,
170    ) -> BufResult<usize, B> {
171        self.0.lock().await.write_vectored_at(buf, pos).await
172    }
173}
174
175/// Splitting an I/O type into separate read and write halves
176#[derive(Debug)]
177pub struct Split<T>(Arc<Mutex<T>>);
178
179impl<T> Split<T> {
180    /// Creates a new `Split` from the given stream.
181    pub fn new(stream: T) -> Self {
182        Split(Arc::new(Mutex::new(stream)))
183    }
184}
185
186impl<T: AsyncRead + AsyncWrite> Splittable for Split<T> {
187    type ReadHalf = ReadHalf<T>;
188    type WriteHalf = WriteHalf<T>;
189
190    fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
191        (ReadHalf(self.0.clone()), WriteHalf(self.0))
192    }
193}
194
195/// The readable half of a value returned from [`split`](super::split).
196#[derive(Debug)]
197pub struct ReadHalf<T>(Arc<Mutex<T>>);
198
199impl<T: Unpin> ReadHalf<T> {
200    /// Reunites with a previously split [`WriteHalf`].
201    ///
202    /// # Panics
203    ///
204    /// If this [`ReadHalf`] and the given [`WriteHalf`] do not originate from
205    /// the same [`split`](super::split) operation this method will panic.
206    /// This can be checked ahead of time by comparing the stored pointer
207    /// of the two halves.
208    #[track_caller]
209    pub fn unsplit(self, w: WriteHalf<T>) -> T {
210        if Arc::ptr_eq(&self.0, &w.0) {
211            drop(w);
212            let inner = Arc::try_unwrap(self.0).expect("`Arc::try_unwrap` failed");
213            inner.into_inner()
214        } else {
215            #[cold]
216            fn panic_unrelated() -> ! {
217                panic!("Unrelated `WriteHalf` passed to `ReadHalf::unsplit`.")
218            }
219
220            panic_unrelated()
221        }
222    }
223}
224
225impl<T: AsyncRead> AsyncRead for ReadHalf<T> {
226    async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
227        self.0.lock().await.read(buf).await
228    }
229
230    async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
231        self.0.lock().await.read_vectored(buf).await
232    }
233}
234
235impl<T: AsyncReadAt> AsyncReadAt for ReadHalf<T> {
236    async fn read_at<B: IoBufMut>(&self, buf: B, pos: u64) -> BufResult<usize, B> {
237        self.0.lock().await.read_at(buf, pos).await
238    }
239}
240
241/// The writable half of a value returned from [`split`](super::split).
242#[derive(Debug)]
243pub struct WriteHalf<T>(Arc<Mutex<T>>);
244
245impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
246    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
247        self.0.lock().await.write(buf).await
248    }
249
250    async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
251        self.0.lock().await.write_vectored(buf).await
252    }
253
254    async fn flush(&mut self) -> IoResult<()> {
255        self.0.lock().await.flush().await
256    }
257
258    async fn shutdown(&mut self) -> IoResult<()> {
259        self.0.lock().await.shutdown().await
260    }
261}
262
263impl<T: AsyncWriteAt> AsyncWriteAt for WriteHalf<T> {
264    async fn write_at<B: IoBuf>(&mut self, buf: B, pos: u64) -> BufResult<usize, B> {
265        self.0.lock().await.write_at(buf, pos).await
266    }
267
268    async fn write_vectored_at<B: IoVectoredBuf>(
269        &mut self,
270        buf: B,
271        pos: u64,
272    ) -> BufResult<usize, B> {
273        self.0.lock().await.write_vectored_at(buf, pos).await
274    }
275}