rama_hyper_util/rt/
tokio.rs1#![allow(dead_code)]
2use std::{
4 future::Future,
5 pin::Pin,
6 task::{Context, Poll},
7 time::{Duration, Instant},
8};
9
10use hyper::rt::{Executor, Sleep, Timer};
11use pin_project_lite::pin_project;
12
/// Executor handle that hands futures to the Tokio runtime.
///
/// The type carries no state; its [`Executor`] implementation forwards every
/// future to `tokio::spawn`.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct TokioExecutor {}
17
18pin_project! {
19 #[derive(Debug)]
22 pub struct TokioIo<T> {
23 #[pin]
24 inner: T,
25 }
26}
27
/// Stateless [`Timer`] implementation whose sleeps are created through `tokio::time`.
#[derive(Clone, Debug, Default)]
#[non_exhaustive]
pub struct TokioTimer;
32
33pin_project! {
36 #[derive(Debug)]
37 struct TokioSleep {
38 #[pin]
39 inner: tokio::time::Sleep,
40 }
41}
42
43impl<Fut> Executor<Fut> for TokioExecutor
46where
47 Fut: Future + Send + 'static,
48 Fut::Output: Send + 'static,
49{
50 fn execute(&self, fut: Fut) {
51 tokio::spawn(fut);
52 }
53}
54
55impl TokioExecutor {
56 pub fn new() -> Self {
58 Self {}
59 }
60}
61
62impl<T> TokioIo<T> {
65 pub fn new(inner: T) -> Self {
67 Self { inner }
68 }
69
70 pub fn inner(&self) -> &T {
72 &self.inner
73 }
74
75 pub fn into_inner(self) -> T {
77 self.inner
78 }
79}
80
81impl<T> hyper::rt::Read for TokioIo<T>
82where
83 T: tokio::io::AsyncRead,
84{
85 fn poll_read(
86 self: Pin<&mut Self>,
87 cx: &mut Context<'_>,
88 mut buf: hyper::rt::ReadBufCursor<'_>,
89 ) -> Poll<Result<(), std::io::Error>> {
90 let n = unsafe {
91 let mut tbuf = tokio::io::ReadBuf::uninit(buf.as_mut());
92 match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut tbuf) {
93 Poll::Ready(Ok(())) => tbuf.filled().len(),
94 other => return other,
95 }
96 };
97
98 unsafe {
99 buf.advance(n);
100 }
101 Poll::Ready(Ok(()))
102 }
103}
104
105impl<T> hyper::rt::Write for TokioIo<T>
106where
107 T: tokio::io::AsyncWrite,
108{
109 fn poll_write(
110 self: Pin<&mut Self>,
111 cx: &mut Context<'_>,
112 buf: &[u8],
113 ) -> Poll<Result<usize, std::io::Error>> {
114 tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
115 }
116
117 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
118 tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
119 }
120
121 fn poll_shutdown(
122 self: Pin<&mut Self>,
123 cx: &mut Context<'_>,
124 ) -> Poll<Result<(), std::io::Error>> {
125 tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
126 }
127
128 fn is_write_vectored(&self) -> bool {
129 tokio::io::AsyncWrite::is_write_vectored(&self.inner)
130 }
131
132 fn poll_write_vectored(
133 self: Pin<&mut Self>,
134 cx: &mut Context<'_>,
135 bufs: &[std::io::IoSlice<'_>],
136 ) -> Poll<Result<usize, std::io::Error>> {
137 tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
138 }
139}
140
141impl<T> tokio::io::AsyncRead for TokioIo<T>
142where
143 T: hyper::rt::Read,
144{
145 fn poll_read(
146 self: Pin<&mut Self>,
147 cx: &mut Context<'_>,
148 tbuf: &mut tokio::io::ReadBuf<'_>,
149 ) -> Poll<Result<(), std::io::Error>> {
150 let filled = tbuf.filled().len();
152 let sub_filled = unsafe {
153 let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());
154
155 match hyper::rt::Read::poll_read(self.project().inner, cx, buf.unfilled()) {
156 Poll::Ready(Ok(())) => buf.filled().len(),
157 other => return other,
158 }
159 };
160
161 let n_filled = filled + sub_filled;
162 let n_init = sub_filled;
164 unsafe {
165 tbuf.assume_init(n_init);
166 tbuf.set_filled(n_filled);
167 }
168
169 Poll::Ready(Ok(()))
170 }
171}
172
173impl<T> tokio::io::AsyncWrite for TokioIo<T>
174where
175 T: hyper::rt::Write,
176{
177 fn poll_write(
178 self: Pin<&mut Self>,
179 cx: &mut Context<'_>,
180 buf: &[u8],
181 ) -> Poll<Result<usize, std::io::Error>> {
182 hyper::rt::Write::poll_write(self.project().inner, cx, buf)
183 }
184
185 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
186 hyper::rt::Write::poll_flush(self.project().inner, cx)
187 }
188
189 fn poll_shutdown(
190 self: Pin<&mut Self>,
191 cx: &mut Context<'_>,
192 ) -> Poll<Result<(), std::io::Error>> {
193 hyper::rt::Write::poll_shutdown(self.project().inner, cx)
194 }
195
196 fn is_write_vectored(&self) -> bool {
197 hyper::rt::Write::is_write_vectored(&self.inner)
198 }
199
200 fn poll_write_vectored(
201 self: Pin<&mut Self>,
202 cx: &mut Context<'_>,
203 bufs: &[std::io::IoSlice<'_>],
204 ) -> Poll<Result<usize, std::io::Error>> {
205 hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
206 }
207}
208
209impl Timer for TokioTimer {
212 fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> {
213 Box::pin(TokioSleep {
214 inner: tokio::time::sleep(duration),
215 })
216 }
217
218 fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn Sleep>> {
219 Box::pin(TokioSleep {
220 inner: tokio::time::sleep_until(deadline.into()),
221 })
222 }
223
224 fn reset(&self, sleep: &mut Pin<Box<dyn Sleep>>, new_deadline: Instant) {
225 if let Some(sleep) = sleep.as_mut().downcast_mut_pin::<TokioSleep>() {
226 sleep.reset(new_deadline)
227 }
228 }
229}
230
231impl TokioTimer {
232 pub fn new() -> Self {
234 Self {}
235 }
236}
237
238impl Future for TokioSleep {
239 type Output = ();
240
241 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
242 self.project().inner.poll(cx)
243 }
244}
245
246impl Sleep for TokioSleep {}
247
248impl TokioSleep {
249 fn reset(self: Pin<&mut Self>, deadline: Instant) {
250 self.project().inner.as_mut().reset(deadline.into());
251 }
252}
253
#[cfg(test)]
mod tests {
    use crate::rt::TokioExecutor;
    use hyper::rt::Executor;
    use tokio::sync::oneshot;

    /// A future handed to `TokioExecutor` runs: it signals over a oneshot
    /// channel and the test awaits that signal.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn simple_execute() -> Result<(), Box<dyn std::error::Error>> {
        let (done_tx, done_rx) = oneshot::channel();

        TokioExecutor::new().execute(async move {
            done_tx.send(()).unwrap();
        });

        done_rx.await?;
        Ok(())
    }
}