sfo_split/
split.rs

1#![allow(unused)]
2
3use std::io;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::sync::Mutex;
7use std::task::{Context, Poll};
8
9pub struct ReadHalf<T> {
10    inner: Arc<Inner<T>>,
11}
12
13pub struct WriteHalf<T> {
14    inner: Arc<Inner<T>>,
15}
16
17pub fn split<T>(stream: T) -> (ReadHalf<T>, WriteHalf<T>)
18{
19    let inner = Arc::new(Inner {
20        stream: Mutex::new(stream),
21    });
22
23    let rd = ReadHalf {
24        inner: inner.clone(),
25    };
26
27    let wr = WriteHalf { inner };
28
29    (rd, wr)
30}
31
/// State shared by a [`ReadHalf`] / [`WriteHalf`] pair; each half holds one
/// `Arc` reference to it.
struct Inner<T> {
    /// The wrapped stream. Every access from either half takes this lock.
    stream: Mutex<T>,
}
35
36impl<T> Inner<T> {
37    fn with_lock<R>(&self, f: impl FnOnce(Pin<&mut T>) -> R) -> R {
38        let mut guard = self.stream.lock().unwrap();
39
40        let stream = unsafe { Pin::new_unchecked(&mut *guard) };
41
42        f(stream)
43    }
44}
45
46impl<T> ReadHalf<T> {
47    pub fn with_lock<R>(&self, f: impl FnOnce(Pin<&mut T>) -> R) -> R {
48        self.inner.with_lock(f)
49    }
50
51    pub fn is_pair_of(&self, other: &WriteHalf<T>) -> bool {
52        other.is_pair_of(self)
53    }
54
55    #[track_caller]
56    pub fn unsplit(self, wr: WriteHalf<T>) -> T
57    where
58        T: Unpin,
59    {
60        if self.is_pair_of(&wr) {
61            drop(wr);
62
63            let inner = Arc::try_unwrap(self.inner)
64                .ok()
65                .expect("`Arc::try_unwrap` failed");
66
67            inner.stream.into_inner().unwrap()
68        } else {
69            panic!("Unrelated `split::Write` passed to `split::Read::unsplit`.")
70        }
71    }
72}
73
74impl<T> WriteHalf<T> {
75    pub fn with_lock<R>(&self, f: impl FnOnce(Pin<&mut T>) -> R) -> R {
76        self.inner.with_lock(f)
77    }
78
79    pub fn is_pair_of(&self, other: &ReadHalf<T>) -> bool {
80        Arc::ptr_eq(&self.inner, &other.inner)
81    }
82}
83
// Each `poll_read` takes the shared lock for exactly one call and forwards to
// the wrapped stream. The lock is released before returning, so it is never
// held across a `Poll::Pending`.
#[cfg(feature = "io")]
impl<T: tokio::io::AsyncRead> tokio::io::AsyncRead for ReadHalf<T> {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        let shared = &self.inner;
        shared.with_lock(|pinned| T::poll_read(pinned, cx, buf))
    }
}
94
// Every `AsyncWrite` method takes the shared lock for exactly one call and
// forwards to the wrapped stream. The lock is released before the method
// returns, so it is never held across a `Poll::Pending`.
#[cfg(feature = "io")]
impl<T: tokio::io::AsyncWrite> tokio::io::AsyncWrite for WriteHalf<T> {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let shared = &self.inner;
        shared.with_lock(|pinned| T::poll_write(pinned, cx, buf))
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[io::IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        let shared = &self.inner;
        shared.with_lock(|pinned| T::poll_write_vectored(pinned, cx, bufs))
    }

    fn is_write_vectored(&self) -> bool {
        // Takes the lock too, because the answer comes from the wrapped stream.
        let shared = &self.inner;
        shared.with_lock(|pinned| T::is_write_vectored(&*pinned))
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let shared = &self.inner;
        shared.with_lock(|pinned| T::poll_flush(pinned, cx))
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let shared = &self.inner;
        shared.with_lock(|pinned| T::poll_shutdown(pinned, cx))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal stand-in stream. `split` places no trait bounds on `T`, so any
    // value works for exercising pairing, locking and `unsplit`.
    #[derive(Debug, PartialEq)]
    struct MockStream {
        id: usize,
        data: Vec<u8>,
    }

    impl MockStream {
        fn new(id: usize) -> Self {
            Self { id, data: vec![] }
        }
    }

    #[test]
    fn test_split_creates_paired_halves() {
        let stream = MockStream::new(1);
        let (read_half, write_half) = split(stream);

        // Both halves come from the same `split` call, so the pairing check
        // succeeds no matter which side performs it.
        assert!(read_half.is_pair_of(&write_half));
        assert!(write_half.is_pair_of(&read_half));
    }

    #[test]
    fn test_halves_from_different_streams_are_not_paired() {
        let stream1 = MockStream::new(1);
        let stream2 = MockStream::new(2);

        let (read_half1, _write_half1) = split(stream1);
        let (_read_half2, write_half2) = split(stream2);

        // Halves from different `split` calls must never be reported as a pair.
        assert!(!read_half1.is_pair_of(&write_half2));
        assert!(!write_half2.is_pair_of(&read_half1));
    }

    #[test]
    #[should_panic(expected = "Unrelated `split::Write` passed to `split::Read::unsplit`.")]
    fn test_unsplit_panics_when_halves_are_not_paired() {
        let stream1 = MockStream::new(1);
        let stream2 = MockStream::new(2);

        let (read_half, _write_half) = split(stream1);
        let (_read_half2, write_half2) = split(stream2);

        // Recombining halves that were not split together must panic.
        let _ = read_half.unsplit(write_half2);
    }

    #[test]
    fn test_unsplit_returns_original_stream() {
        let (read_half, write_half) = split(MockStream::new(7));

        assert_eq!(read_half.unsplit(write_half), MockStream::new(7));
    }

    #[test]
    fn test_with_lock_functionality() {
        let (read_half, write_half) = split(MockStream::new(1));

        // A mutation made through one half is visible through the other,
        // because both lock the same underlying stream.
        write_half.with_lock(|mut stream| stream.data.extend_from_slice(b"abc"));
        let seen = read_half.with_lock(|stream| stream.data.clone());
        assert_eq!(seen, b"abc");

        // Both halves hand out a reference to the same stream instance.
        let read_ptr = read_half.with_lock(|stream| &*stream as *const MockStream as usize);
        let write_ptr = write_half.with_lock(|stream| &*stream as *const MockStream as usize);
        assert_eq!(read_ptr, write_ptr);

        // Recombining yields the stream including the mutation made above.
        assert_eq!(read_half.unsplit(write_half).data, b"abc".to_vec());
    }
}