rama_hyper_util/rt/
tokio.rs

1#![allow(dead_code)]
//! Tokio runtime integration for hyper: IO adapters, a task executor, and a timer.
3use std::{
4    future::Future,
5    pin::Pin,
6    task::{Context, Poll},
7    time::{Duration, Instant},
8};
9
10use hyper::rt::{Executor, Sleep, Timer};
11use pin_project_lite::pin_project;
12
/// Executor that hands every future to the ambient `tokio` runtime.
///
/// Futures are spawned with [`tokio::spawn`], which must be called from
/// within a running Tokio runtime.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct TokioExecutor {}
17
18pin_project! {
19    /// A wrapping implementing hyper IO traits for a type that
20    /// implements Tokio's IO traits.
21    #[derive(Debug)]
22    pub struct TokioIo<T> {
23        #[pin]
24        inner: T,
25    }
26}
27
/// Timer backed by the `tokio` runtime's time driver.
///
/// Sleeps are created with `tokio::time::sleep` / `tokio::time::sleep_until`.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct TokioTimer;
32
33// Use TokioSleep to get tokio::time::Sleep to implement Unpin.
34// see https://docs.rs/tokio/latest/tokio/time/struct.Sleep.html
35pin_project! {
36    #[derive(Debug)]
37    struct TokioSleep {
38        #[pin]
39        inner: tokio::time::Sleep,
40    }
41}
42
43// ===== impl TokioExecutor =====
44
45impl<Fut> Executor<Fut> for TokioExecutor
46where
47    Fut: Future + Send + 'static,
48    Fut::Output: Send + 'static,
49{
50    fn execute(&self, fut: Fut) {
51        tokio::spawn(fut);
52    }
53}
54
55impl TokioExecutor {
56    /// Create new executor that relies on [`tokio::spawn`] to execute futures.
57    pub fn new() -> Self {
58        Self {}
59    }
60}
61
62// ==== impl TokioIo =====
63
64impl<T> TokioIo<T> {
65    /// Wrap a type implementing Tokio's IO traits.
66    pub fn new(inner: T) -> Self {
67        Self { inner }
68    }
69
70    /// Borrow the inner type.
71    pub fn inner(&self) -> &T {
72        &self.inner
73    }
74
75    /// Consume this wrapper and get the inner type.
76    pub fn into_inner(self) -> T {
77        self.inner
78    }
79}
80
impl<T> hyper::rt::Read for TokioIo<T>
where
    T: tokio::io::AsyncRead,
{
    /// Reads from the wrapped Tokio reader into hyper's cursor.
    ///
    /// The cursor's unfilled region is lent to Tokio as an uninitialized
    /// `ReadBuf`. On success the cursor advances by the number of bytes Tokio
    /// reports as filled. `Pending` and errors are forwarded unchanged, and
    /// the cursor is left untouched.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: hyper::rt::ReadBufCursor<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // SAFETY: `buf.as_mut()` exposes the cursor's unfilled, possibly
        // uninitialized bytes. They are only handed to `ReadBuf::uninit`,
        // which tracks what the inner reader writes and never de-initializes
        // bytes.
        let n = unsafe {
            let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
            match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
                // Bytes the inner reader wrote at the start of the lent region.
                Poll::Ready(Ok(())) => tbuf.filled().len(),
                // `Pending` or an error: return before touching the cursor.
                other => return other,
            }
        };

        // SAFETY: `n` is the filled length of a `ReadBuf` that started empty
        // over `buf`'s unfilled region, so those `n` bytes are initialized.
        // NOTE(review): this relies on the inner reader not replacing `tbuf`
        // with a different `ReadBuf` — confirm this is an acceptable assumption.
        unsafe {
            buf.advance(n);
        }
        Poll::Ready(Ok(()))
    }
}
104
105impl<T> hyper::rt::Write for TokioIo<T>
106where
107    T: tokio::io::AsyncWrite,
108{
109    fn poll_write(
110        self: Pin<&mut Self>,
111        cx: &mut Context<'_>,
112        buf: &[u8],
113    ) -> Poll<Result<usize, std::io::Error>> {
114        tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
115    }
116
117    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
118        tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
119    }
120
121    fn poll_shutdown(
122        self: Pin<&mut Self>,
123        cx: &mut Context<'_>,
124    ) -> Poll<Result<(), std::io::Error>> {
125        tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
126    }
127
128    fn is_write_vectored(&self) -> bool {
129        tokio::io::AsyncWrite::is_write_vectored(&self.inner)
130    }
131
132    fn poll_write_vectored(
133        self: Pin<&mut Self>,
134        cx: &mut Context<'_>,
135        bufs: &[std::io::IoSlice<'_>],
136    ) -> Poll<Result<usize, std::io::Error>> {
137        tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
138    }
139}
140
impl<T> tokio::io::AsyncRead for TokioIo<T>
where
    T: hyper::rt::Read,
{
    /// Reads from the wrapped hyper reader into Tokio's `ReadBuf`.
    ///
    /// Tokio's unfilled region is lent to hyper as an uninitialized
    /// `hyper::rt::ReadBuf`. On success the bytes hyper filled are appended
    /// after `tbuf`'s existing filled region. `Pending` and errors are
    /// forwarded unchanged, and `tbuf` is left untouched.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        tbuf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // Bytes already filled before this call; new data lands right after.
        let filled = tbuf.filled().len();
        // SAFETY: `unfilled_mut` exposes Tokio's unfilled, possibly
        // uninitialized bytes. They are only wrapped in hyper's
        // `ReadBuf::uninit`, which tracks what the inner reader writes.
        let sub_filled = unsafe {
            let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());

            match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) {
                // Bytes written at the start of Tokio's unfilled region.
                Poll::Ready(Ok(())) => buf.filled().len(),
                other => return other,
            }
        };

        let n_filled = filled + sub_filled;
        // At least sub_filled bytes had to have been initialized.
        // `assume_init` counts from the start of the unfilled region, so it
        // takes `sub_filled` here, not the absolute `n_filled`.
        let n_init = sub_filled;
        // SAFETY: the inner reader filled (and thus initialized) `sub_filled`
        // bytes immediately after the previously filled region.
        unsafe {
            tbuf.assume_init(n_init);
            tbuf.set_filled(n_filled);
        }

        Poll::Ready(Ok(()))
    }
}
172
173impl<T> tokio::io::AsyncWrite for TokioIo<T>
174where
175    T: hyper::rt::Write,
176{
177    fn poll_write(
178        self: Pin<&mut Self>,
179        cx: &mut Context<'_>,
180        buf: &[u8],
181    ) -> Poll<Result<usize, std::io::Error>> {
182        hyper::rt::Write::poll_write(self.project().inner, cx, buf)
183    }
184
185    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
186        hyper::rt::Write::poll_flush(self.project().inner, cx)
187    }
188
189    fn poll_shutdown(
190        self: Pin<&mut Self>,
191        cx: &mut Context<'_>,
192    ) -> Poll<Result<(), std::io::Error>> {
193        hyper::rt::Write::poll_shutdown(self.project().inner, cx)
194    }
195
196    fn is_write_vectored(&self) -> bool {
197        hyper::rt::Write::is_write_vectored(&self.inner)
198    }
199
200    fn poll_write_vectored(
201        self: Pin<&mut Self>,
202        cx: &mut Context<'_>,
203        bufs: &[std::io::IoSlice<'_>],
204    ) -> Poll<Result<usize, std::io::Error>> {
205        hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
206    }
207}
208
209// ==== impl TokioTimer =====
210
211impl Timer for TokioTimer {
212    fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> {
213        Box::pin(TokioSleep {
214            inner: tokio::time::sleep(duration),
215        })
216    }
217
218    fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn Sleep>> {
219        Box::pin(TokioSleep {
220            inner: tokio::time::sleep_until(deadline.into()),
221        })
222    }
223
224    fn reset(&self, sleep: &mut Pin<Box<dyn Sleep>>, new_deadline: Instant) {
225        if let Some(sleep) = sleep.as_mut().downcast_mut_pin::<TokioSleep>() {
226            sleep.reset(new_deadline)
227        }
228    }
229}
230
231impl TokioTimer {
232    /// Create a new TokioTimer
233    pub fn new() -> Self {
234        Self {}
235    }
236}
237
238impl Future for TokioSleep {
239    type Output = ();
240
241    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
242        self.project().inner.poll(cx)
243    }
244}
245
246impl Sleep for TokioSleep {}
247
248impl TokioSleep {
249    fn reset(self: Pin<&mut Self>, deadline: Instant) {
250        self.project().inner.as_mut().reset(deadline.into());
251    }
252}
253
#[cfg(test)]
mod tests {
    use crate::rt::TokioExecutor;
    use hyper::rt::Executor;
    use tokio::sync::oneshot;

    /// A future handed to `TokioExecutor` must actually run: the spawned task
    /// signals completion over a oneshot channel.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn simple_execute() -> Result<(), Box<dyn std::error::Error>> {
        let (done_tx, done_rx) = oneshot::channel();
        TokioExecutor::new().execute(async move {
            done_tx.send(()).unwrap();
        });
        done_rx.await?;
        Ok(())
    }
}