1use std::{fmt::Debug, sync::Arc};
4
5use compio_buf::{BufResult, IoBuf, IoBufMut, IoVectoredBuf, IoVectoredBufMut};
6use futures_util::lock::Mutex;
7
8use crate::{AsyncRead, AsyncReadAt, AsyncWrite, AsyncWriteAt, IoResult, util::bilock::BiLock};
9
10pub fn split<T: AsyncRead + AsyncWrite>(stream: T) -> (ReadHalf<T>, WriteHalf<T>) {
13 Split::new(stream).split()
14}
15
16pub fn split_unsync<T: AsyncRead + AsyncWrite>(
20 stream: T,
21) -> (UnsyncReadHalf<T>, UnsyncWriteHalf<T>) {
22 UnsyncSplit::new(stream).split()
23}
24
/// A value that can be divided into a read half and a write half.
///
/// Implementations provided alongside this trait include [`Split`],
/// [`UnsyncSplit`], and plain `(R, W)` tuples (returned unchanged).
pub trait Splittable {
    /// Consumes `self` and returns `(read_half, write_half)`.
    fn split(self) -> (Self::ReadHalf, Self::WriteHalf);

    /// Type of the half returned first by [`Splittable::split`].
    type ReadHalf;

    /// Type of the half returned second by [`Splittable::split`].
    type WriteHalf;
}
60
61#[derive(Debug)]
75pub struct UnsyncSplit<T>(BiLock<T>, BiLock<T>);
76
77impl<T> UnsyncSplit<T> {
78 pub fn new(stream: T) -> Self {
80 let (r, w) = BiLock::new(stream);
81 UnsyncSplit(r, w)
82 }
83}
84
85impl<T> Splittable for UnsyncSplit<T> {
86 type ReadHalf = UnsyncReadHalf<T>;
87 type WriteHalf = UnsyncWriteHalf<T>;
88
89 fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
90 (UnsyncReadHalf(self.0), UnsyncWriteHalf(self.1))
91 }
92}
93
94impl<R, W> Splittable for (R, W) {
95 type ReadHalf = R;
96 type WriteHalf = W;
97
98 fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
99 self
100 }
101}
102
103#[derive(Debug)]
105pub struct UnsyncReadHalf<T>(BiLock<T>);
106
107impl<T> UnsyncReadHalf<T> {
108 #[track_caller]
116 pub fn unsplit(self, other: UnsyncWriteHalf<T>) -> T {
117 self.0.try_join(other.0).expect(
118 "`UnsyncReadHalf` and `UnsyncWriteHalf` must originate from the same `UnsyncSplit`",
119 )
120 }
121}
122
123impl<T: AsyncRead> AsyncRead for UnsyncReadHalf<T> {
124 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
125 self.0.lock().await.read(buf).await
126 }
127
128 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
129 self.0.lock().await.read_vectored(buf).await
130 }
131}
132
133impl<T: AsyncReadAt> AsyncReadAt for UnsyncReadHalf<T> {
134 async fn read_at<B: IoBufMut>(&self, buf: B, pos: u64) -> BufResult<usize, B> {
135 self.0.lock().await.read_at(buf, pos).await
136 }
137}
138
139#[derive(Debug)]
141pub struct UnsyncWriteHalf<T>(BiLock<T>);
142
143impl<T: AsyncWrite> AsyncWrite for UnsyncWriteHalf<T> {
144 async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
145 self.0.lock().await.write(buf).await
146 }
147
148 async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
149 self.0.lock().await.write_vectored(buf).await
150 }
151
152 async fn flush(&mut self) -> IoResult<()> {
153 self.0.lock().await.flush().await
154 }
155
156 async fn shutdown(&mut self) -> IoResult<()> {
157 self.0.lock().await.shutdown().await
158 }
159}
160
161impl<T: AsyncWriteAt> AsyncWriteAt for UnsyncWriteHalf<T> {
162 async fn write_at<B: IoBuf>(&mut self, buf: B, pos: u64) -> BufResult<usize, B> {
163 self.0.lock().await.write_at(buf, pos).await
164 }
165
166 async fn write_vectored_at<B: IoVectoredBuf>(
167 &mut self,
168 buf: B,
169 pos: u64,
170 ) -> BufResult<usize, B> {
171 self.0.lock().await.write_vectored_at(buf, pos).await
172 }
173}
174
175#[derive(Debug)]
177pub struct Split<T>(Arc<Mutex<T>>);
178
179impl<T> Split<T> {
180 pub fn new(stream: T) -> Self {
182 Split(Arc::new(Mutex::new(stream)))
183 }
184}
185
186impl<T: AsyncRead + AsyncWrite> Splittable for Split<T> {
187 type ReadHalf = ReadHalf<T>;
188 type WriteHalf = WriteHalf<T>;
189
190 fn split(self) -> (Self::ReadHalf, Self::WriteHalf) {
191 (ReadHalf(self.0.clone()), WriteHalf(self.0))
192 }
193}
194
195#[derive(Debug)]
197pub struct ReadHalf<T>(Arc<Mutex<T>>);
198
199impl<T: Unpin> ReadHalf<T> {
200 #[track_caller]
209 pub fn unsplit(self, w: WriteHalf<T>) -> T {
210 if Arc::ptr_eq(&self.0, &w.0) {
211 drop(w);
212 let inner = Arc::try_unwrap(self.0).expect("`Arc::try_unwrap` failed");
213 inner.into_inner()
214 } else {
215 #[cold]
216 fn panic_unrelated() -> ! {
217 panic!("Unrelated `WriteHalf` passed to `ReadHalf::unsplit`.")
218 }
219
220 panic_unrelated()
221 }
222 }
223}
224
225impl<T: AsyncRead> AsyncRead for ReadHalf<T> {
226 async fn read<B: IoBufMut>(&mut self, buf: B) -> BufResult<usize, B> {
227 self.0.lock().await.read(buf).await
228 }
229
230 async fn read_vectored<V: IoVectoredBufMut>(&mut self, buf: V) -> BufResult<usize, V> {
231 self.0.lock().await.read_vectored(buf).await
232 }
233}
234
235impl<T: AsyncReadAt> AsyncReadAt for ReadHalf<T> {
236 async fn read_at<B: IoBufMut>(&self, buf: B, pos: u64) -> BufResult<usize, B> {
237 self.0.lock().await.read_at(buf, pos).await
238 }
239}
240
241#[derive(Debug)]
243pub struct WriteHalf<T>(Arc<Mutex<T>>);
244
245impl<T: AsyncWrite> AsyncWrite for WriteHalf<T> {
246 async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
247 self.0.lock().await.write(buf).await
248 }
249
250 async fn write_vectored<B: IoVectoredBuf>(&mut self, buf: B) -> BufResult<usize, B> {
251 self.0.lock().await.write_vectored(buf).await
252 }
253
254 async fn flush(&mut self) -> IoResult<()> {
255 self.0.lock().await.flush().await
256 }
257
258 async fn shutdown(&mut self) -> IoResult<()> {
259 self.0.lock().await.shutdown().await
260 }
261}
262
263impl<T: AsyncWriteAt> AsyncWriteAt for WriteHalf<T> {
264 async fn write_at<B: IoBuf>(&mut self, buf: B, pos: u64) -> BufResult<usize, B> {
265 self.0.lock().await.write_at(buf, pos).await
266 }
267
268 async fn write_vectored_at<B: IoVectoredBuf>(
269 &mut self,
270 buf: B,
271 pos: u64,
272 ) -> BufResult<usize, B> {
273 self.0.lock().await.write_vectored_at(buf, pos).await
274 }
275}