// rpc_it/transports.rs

// Re-export the tokio adapters so their `pub` wrapper types and extension traits are
// reachable from outside this file, the same way the in-memory transport is exported.
#[cfg(feature = "tokio")]
pub use tokio_io::*;
/// Adapters that let tokio I/O types be used with this crate's transport traits.
#[cfg(feature = "tokio")]
mod tokio_io {
    use std::{pin::Pin, task::Poll};

    use tokio::io::ReadBuf;

    use crate::transport::FrameReader;

    /// Wraps a [`tokio::io::AsyncWrite`] so it can be used as a
    /// [`crate::transport::AsyncFrameWrite`].
    pub struct TokioWriteFrameWrapper<T> {
        inner: T,
    }

    /// Extension trait that wraps a tokio writer in a [`TokioWriteFrameWrapper`].
    pub trait ToWriteFrame<T> {
        /// Consumes `self` and returns it wrapped in a [`TokioWriteFrameWrapper`].
        fn to_write_frame(self) -> TokioWriteFrameWrapper<T>;
    }

    impl<T> ToWriteFrame<T> for T
    where
        T: tokio::io::AsyncWrite + Unpin + Send + 'static,
    {
        fn to_write_frame(self) -> TokioWriteFrameWrapper<T> {
            TokioWriteFrameWrapper { inner: self }
        }
    }

    impl<T> crate::transport::AsyncFrameWrite for TokioWriteFrameWrapper<T>
    where
        T: tokio::io::AsyncWrite + Unpin + Send + 'static,
    {
        /// Performs one write of the pending bytes in `buf` to the inner writer and
        /// advances `buf` past however many bytes were accepted.
        ///
        /// # Errors
        /// Propagates errors from the inner writer. Returns
        /// [`std::io::ErrorKind::WriteZero`] if the inner writer accepts zero bytes of a
        /// non-empty buffer, because a caller polling until `buf` is drained would
        /// otherwise never make progress.
        fn poll_write(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &mut FrameReader,
        ) -> std::task::Poll<std::io::Result<()>> {
            let pending = buf.as_slice();
            let was_empty = pending.is_empty();

            match Pin::new(&mut self.inner).poll_write(cx, pending) {
                Poll::Pending => Poll::Pending,
                Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
                // Zero progress on a non-empty buffer: report it instead of letting the
                // caller retry forever (same convention as `std::io::Write::write_all`).
                Poll::Ready(Ok(0)) if !was_empty => {
                    Poll::Ready(Err(std::io::ErrorKind::WriteZero.into()))
                }
                Poll::Ready(Ok(written)) => {
                    buf.advance(written);
                    Poll::Ready(Ok(()))
                }
            }
        }

        /// Flushes the inner tokio writer.
        fn poll_flush(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            Pin::new(&mut self.inner).poll_flush(cx)
        }

        /// Closes the stream by delegating to the inner writer's `poll_shutdown`.
        fn poll_close(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
        ) -> std::task::Poll<std::io::Result<()>> {
            Pin::new(&mut self.inner).poll_shutdown(cx)
        }
    }

    /// Wraps a [`tokio::io::AsyncRead`] so it can be used as a
    /// [`futures_util::AsyncRead`].
    pub struct TokioAsyncReadWrapper<T> {
        inner: T,
    }

    /// Extension trait that wraps a tokio reader in a [`TokioAsyncReadWrapper`].
    pub trait ToAsyncRead<T> {
        /// Consumes `self` and returns it wrapped in a [`TokioAsyncReadWrapper`].
        fn to_async_read(self) -> TokioAsyncReadWrapper<T>;
    }

    impl<T> ToAsyncRead<T> for T
    where
        T: tokio::io::AsyncRead + Unpin + Send + 'static,
    {
        fn to_async_read(self) -> TokioAsyncReadWrapper<T> {
            TokioAsyncReadWrapper { inner: self }
        }
    }

    impl<T> futures_util::AsyncRead for TokioAsyncReadWrapper<T>
    where
        T: tokio::io::AsyncRead + Unpin + Send + 'static,
    {
        /// Reads into `buf` through a tokio [`ReadBuf`] and reports how many bytes the
        /// inner reader filled. A fill of zero is forwarded unchanged as `Ok(0)`.
        fn poll_read(
            mut self: std::pin::Pin<&mut Self>,
            cx: &mut std::task::Context<'_>,
            buf: &mut [u8],
        ) -> std::task::Poll<std::io::Result<usize>> {
            let mut read_buf = ReadBuf::new(buf);
            match Pin::new(&mut self.inner).poll_read(cx, &mut read_buf) {
                Poll::Pending => Poll::Pending,
                Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
                Poll::Ready(Ok(())) => Poll::Ready(Ok(read_buf.filled().len())),
            }
        }
    }
}
93
#[cfg(feature = "in-memory")]
pub use in_memory_::*;
/// A single-process, in-memory frame transport made of a paired writer and reader.
#[cfg(feature = "in-memory")]
mod in_memory_ {
    use std::{
        collections::VecDeque,
        pin::Pin,
        sync::Arc,
        task::{Context, Poll},
    };

    use bytes::Bytes;
    use futures_util::task::AtomicWaker;
    use parking_lot::Mutex;

    use crate::transport::{AsyncFrameRead, AsyncFrameWrite, FrameReader};

    /// State shared between an [`InMemoryWriter`] and its paired [`InMemoryReader`].
    struct InMemoryInner {
        // Frames written but not yet read, in FIFO order.
        chunks: VecDeque<Bytes>,
        // Waker registered by the reader when it found no frame to return.
        waker: AtomicWaker,

        // Set when the writer is closed or dropped. Once `chunks` is drained, the reader
        // then reports `BrokenPipe`.
        writer_dropped: bool,
        // Set when the reader is dropped. Subsequent writes fail with `BrokenPipe`.
        reader_dropped: bool,
    }

    /// Writing half of an in-memory transport; see [`new_in_memory`].
    pub struct InMemoryWriter(Arc<Mutex<InMemoryInner>>);
    /// Reading half of an in-memory transport; see [`new_in_memory`].
    pub struct InMemoryReader(Arc<Mutex<InMemoryInner>>);

    /// Creates a connected writer/reader pair.
    ///
    /// Each write on the writer enqueues one frame, which the reader yields in order.
    pub fn new_in_memory() -> (InMemoryWriter, InMemoryReader) {
        let inner = Arc::new(Mutex::new(InMemoryInner {
            chunks: VecDeque::new(),
            waker: AtomicWaker::new(),

            writer_dropped: false,
            reader_dropped: false,
        }));

        // One clone for the writer; the reader takes the original handle.
        (InMemoryWriter(Arc::clone(&inner)), InMemoryReader(inner))
    }

    impl InMemoryWriter {
        /// Marks the writer side as finished and wakes a reader that may be parked in
        /// `poll_next`, so it can observe end-of-stream instead of waiting forever.
        fn mark_closed(&self) {
            let mut inner = self.0.lock();
            inner.writer_dropped = true;
            inner.waker.wake();
        }
    }

    impl AsyncFrameWrite for InMemoryWriter {
        /// Takes the entire contents of `buf` as one frame and enqueues it for the
        /// reader. This never returns `Pending`.
        ///
        /// # Errors
        /// Returns `BrokenPipe` if the reader has been dropped.
        fn poll_write(
            self: Pin<&mut Self>,
            _cx: &mut Context<'_>,
            buf: &mut FrameReader,
        ) -> Poll<std::io::Result<()>> {
            let mut inner = self.0.lock();
            if inner.reader_dropped {
                return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
            }

            inner.chunks.push_back(buf.take().freeze());
            inner.waker.wake();
            Poll::Ready(Ok(()))
        }

        /// No-op: frames are handed to the shared queue as soon as they are written.
        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            Poll::Ready(Ok(()))
        }

        /// Closes the writer side and wakes the reader.
        ///
        /// Without the wake, a reader already parked on an empty queue would not
        /// notice the close until the writer was dropped.
        fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
            self.mark_closed();
            Poll::Ready(Ok(()))
        }
    }

    impl Drop for InMemoryWriter {
        fn drop(&mut self) {
            self.mark_closed();
        }
    }

    impl AsyncFrameRead for InMemoryReader {
        /// Yields the next queued frame.
        ///
        /// If no frame is queued, returns `BrokenPipe` when the writer is closed or
        /// dropped. Otherwise it registers the task's waker and returns `Pending`.
        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<Bytes>> {
            let mut inner = self.0.lock();
            if let Some(chunk) = inner.chunks.pop_front() {
                return Poll::Ready(Ok(chunk));
            }

            if inner.writer_dropped {
                return Poll::Ready(Err(std::io::ErrorKind::BrokenPipe.into()));
            }

            // Registered while holding the lock, so a concurrent write cannot slip in
            // between the empty check and the registration and lose the wakeup.
            inner.waker.register(cx.waker());
            Poll::Pending
        }
    }

    impl Drop for InMemoryReader {
        fn drop(&mut self) {
            // The writer never waits, so only the flag is needed; no wake.
            let mut inner = self.0.lock();
            inner.reader_dropped = true;
        }
    }
}