1#![allow(unused)]
2
3use std::io;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::sync::Mutex;
7use std::task::{Context, Poll};
8
9pub struct ReadHalf<T> {
10 inner: Arc<Inner<T>>,
11}
12
13pub struct WriteHalf<T> {
14 inner: Arc<Inner<T>>,
15}
16
17pub fn split<T>(stream: T) -> (ReadHalf<T>, WriteHalf<T>)
18{
19 let inner = Arc::new(Inner {
20 stream: Mutex::new(stream),
21 });
22
23 let rd = ReadHalf {
24 inner: inner.clone(),
25 };
26
27 let wr = WriteHalf { inner };
28
29 (rd, wr)
30}
31
/// State shared by a [`ReadHalf`] / [`WriteHalf`] pair.
///
/// Both halves hold an `Arc` to this value. The mutex makes sure only one
/// half touches the stream at a time.
struct Inner<T> {
    stream: Mutex<T>,
}

impl<T> Inner<T> {
    /// Locks the shared stream and runs `f` with a pinned mutable reference to it.
    ///
    /// If an earlier `f` panicked while holding the lock, the mutex is poisoned.
    /// This method recovers from that instead of propagating it. Otherwise a
    /// single panic in one half would make every later call on either half
    /// panic as well.
    fn with_lock<R>(&self, f: impl FnOnce(Pin<&mut T>) -> R) -> R {
        let mut guard = self
            .stream
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);

        // SAFETY: while shared, the stream is never moved out of the mutex.
        // It lives in the `Arc` allocation and is dropped there in place.
        // The only code that moves it out (`ReadHalf::unsplit`) requires `T: Unpin`.
        let stream = unsafe { Pin::new_unchecked(&mut *guard) };

        f(stream)
    }
}
45
46impl<T> ReadHalf<T> {
47 fn with_lock<R>(&self, f: impl FnOnce(Pin<&mut T>) -> R) -> R {
48 self.inner.with_lock(f)
49 }
50
51 pub fn is_pair_of(&self, other: &WriteHalf<T>) -> bool {
52 other.is_pair_of(self)
53 }
54
55 #[track_caller]
56 pub fn unsplit(self, wr: WriteHalf<T>) -> T
57 where
58 T: Unpin,
59 {
60 if self.is_pair_of(&wr) {
61 drop(wr);
62
63 let inner = Arc::try_unwrap(self.inner)
64 .ok()
65 .expect("`Arc::try_unwrap` failed");
66
67 inner.stream.into_inner().unwrap()
68 } else {
69 panic!("Unrelated `split::Write` passed to `split::Read::unsplit`.")
70 }
71 }
72}
73
74impl<T> WriteHalf<T> {
75 fn with_lock<R>(&self, f: impl FnOnce(Pin<&mut T>) -> R) -> R {
76 self.inner.with_lock(f)
77 }
78
79 pub fn is_pair_of(&self, other: &ReadHalf<T>) -> bool {
80 Arc::ptr_eq(&self.inner, &other.inner)
81 }
82}
83
#[cfg(feature = "io")]
impl<T: tokio::io::AsyncRead> tokio::io::AsyncRead for ReadHalf<T> {
    /// Forwards the read to the shared stream.
    ///
    /// The lock is held only for the duration of this single poll.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let half: &Self = &self;
        half.with_lock(|pinned| pinned.poll_read(cx, buf))
    }
}
94
/// Every write-side operation is forwarded to the shared stream.
///
/// Each call takes the shared lock for just that one call.
#[cfg(feature = "io")]
impl<T: tokio::io::AsyncWrite> tokio::io::AsyncWrite for WriteHalf<T> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        self.with_lock(|pinned| pinned.poll_write(cx, buf))
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        self.with_lock(|pinned| pinned.poll_flush(cx))
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        self.with_lock(|pinned| pinned.poll_shutdown(cx))
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        self.with_lock(|pinned| pinned.poll_write_vectored(cx, bufs))
    }

    fn is_write_vectored(&self) -> bool {
        // Even this query needs the lock, because the stream is only reachable through the mutex.
        self.with_lock(|pinned| pinned.is_write_vectored())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `Unpin` stand-in for a real stream; tests only inspect its fields.
    #[derive(Debug, PartialEq)]
    struct MockStream {
        id: usize,
        data: Vec<u8>,
    }

    impl MockStream {
        fn new(id: usize) -> Self {
            Self { id, data: vec![] }
        }
    }

    #[test]
    fn test_split_creates_paired_halves() {
        let stream = MockStream::new(1);
        let (read_half, write_half) = split(stream);

        assert!(read_half.is_pair_of(&write_half));
        assert!(write_half.is_pair_of(&read_half));
    }

    #[test]
    fn test_halves_from_different_streams_are_not_paired() {
        let stream1 = MockStream::new(1);
        let stream2 = MockStream::new(2);

        let (read_half1, _write_half1) = split(stream1);
        let (_read_half2, write_half2) = split(stream2);

        assert!(!read_half1.is_pair_of(&write_half2));
        assert!(!write_half2.is_pair_of(&read_half1));
    }

    #[test]
    #[should_panic(expected = "Unrelated `split::Write` passed to `split::Read::unsplit`.")]
    fn test_unsplit_panics_when_halves_are_not_paired() {
        let stream1 = MockStream::new(1);
        let stream2 = MockStream::new(2);

        let (read_half, _write_half) = split(stream1);
        let (_read_half2, write_half2) = split(stream2);

        let _ = read_half.unsplit(write_half2);
    }

    #[test]
    fn test_unsplit_returns_original_stream() {
        let (read_half, write_half) = split(MockStream::new(7));
        write_half.with_lock(|mut stream| stream.data.push(42));

        let stream = read_half.unsplit(write_half);

        assert_eq!(
            stream,
            MockStream {
                id: 7,
                data: vec![42],
            }
        );
    }

    #[test]
    fn test_halves_share_one_stream() {
        let (read_half, write_half) = split(MockStream::new(1));

        let read_ptr = {
            let guard = read_half.inner.stream.lock().unwrap();
            &*guard as *const MockStream
        };
        let write_ptr = {
            let guard = write_half.inner.stream.lock().unwrap();
            &*guard as *const MockStream
        };

        assert_eq!(read_ptr, write_ptr);
    }

    #[test]
    fn test_with_lock_functionality() {
        let (read_half, write_half) = split(MockStream::new(3));

        // A mutation made through one half must be visible through the other.
        write_half.with_lock(|mut stream| stream.data.extend_from_slice(b"hello"));
        let (id, data) = read_half.with_lock(|stream| (stream.id, stream.data.clone()));

        assert_eq!(id, 3);
        assert_eq!(data, b"hello".to_vec());
    }
}