Skip to main content

openwire_tokio/
lib.rs

1use std::future::Future;
2use std::io;
3use std::net::SocketAddr;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
8use hyper::rt::{Executor, Sleep, Timer};
9use openwire_core::{
10    next_connection_id, BoxConnection, BoxFuture, BoxTaskHandle, CallContext, Connected,
11    Connection, ConnectionInfo, DnsResolver, EstablishmentStage, TaskHandle, TcpConnector,
12    WireError, WireErrorKind, WireExecutor,
13};
14use pin_project_lite::pin_project;
15use tracing::instrument::WithSubscriber;
16
17#[derive(Debug)]
18struct TokioTaskHandle(tokio::task::JoinHandle<()>);
19
20impl TaskHandle for TokioTaskHandle {
21    fn abort(&self) {
22        self.0.abort();
23    }
24}
25
/// Executor that spawns futures onto the current Tokio runtime.
///
/// Marked `#[non_exhaustive]`, so code outside this crate builds it through
/// [`TokioExecutor::new`] or [`Default`] rather than the unit-struct literal.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct TokioExecutor;

impl TokioExecutor {
    /// Creates a new executor handle.
    pub fn new() -> Self {
        Self::default()
    }
}
35
36impl<Fut> Executor<Fut> for TokioExecutor
37where
38    Fut: Future + Send + 'static,
39    Fut::Output: Send + 'static,
40{
41    fn execute(&self, future: Fut) {
42        tokio::spawn(future.with_current_subscriber());
43    }
44}
45
46impl WireExecutor for TokioExecutor {
47    fn spawn(&self, future: BoxFuture<()>) -> Result<BoxTaskHandle, WireError> {
48        Ok(Box::new(TokioTaskHandle(tokio::spawn(
49            future.with_current_subscriber(),
50        ))))
51    }
52}
53
pin_project! {
    /// Adapter between Tokio's IO traits and hyper's `rt` IO traits.
    ///
    /// If `T` implements `tokio::io::AsyncRead`/`AsyncWrite`, then `TokioIo<T>`
    /// implements `hyper::rt::Read`/`Write`. If `T` implements hyper's traits, then
    /// `TokioIo<T>` implements Tokio's.
    #[derive(Debug)]
    pub struct TokioIo<T> {
        // Structurally pinned so the wrapped IO object can be polled through
        // `Pin<&mut Self>` without requiring `T: Unpin`.
        #[pin]
        inner: T,
    }
}
61
62impl<T> TokioIo<T> {
63    pub fn new(inner: T) -> Self {
64        Self { inner }
65    }
66
67    pub fn inner(&self) -> &T {
68        &self.inner
69    }
70
71    pub fn inner_mut(&mut self) -> &mut T {
72        &mut self.inner
73    }
74
75    pub fn into_inner(self) -> T {
76        self.inner
77    }
78}
79
impl<T> hyper::rt::Read for TokioIo<T>
where
    T: tokio::io::AsyncRead,
{
    /// Reads from the wrapped Tokio reader directly into hyper's cursor, without an
    /// intermediate copy.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: hyper::rt::ReadBufCursor<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // SAFETY: `buf.as_mut()` exposes the cursor's unfilled tail as
        // `&mut [MaybeUninit<u8>]`, and hyper requires that no byte there be
        // de-initialized. The slice is only accessed through tokio's `ReadBuf`, whose
        // safe API writes initialized bytes and never de-initializes memory.
        let filled = unsafe {
            let mut read_buf = tokio::io::ReadBuf::uninit(buf.as_mut());
            match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut read_buf) {
                Poll::Ready(Ok(())) => read_buf.filled().len(),
                // Pending and errors are forwarded unchanged; the cursor is not advanced.
                other => return other,
            }
        };

        // SAFETY: `read_buf` started at the cursor's unfilled region and reported
        // `filled` bytes written, so exactly those `filled` bytes are initialized.
        unsafe {
            buf.advance(filled);
        }
        Poll::Ready(Ok(()))
    }
}
103
104impl<T> hyper::rt::Write for TokioIo<T>
105where
106    T: tokio::io::AsyncWrite,
107{
108    fn poll_write(
109        self: Pin<&mut Self>,
110        cx: &mut Context<'_>,
111        buf: &[u8],
112    ) -> Poll<Result<usize, std::io::Error>> {
113        tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
114    }
115
116    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
117        tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
118    }
119
120    fn poll_shutdown(
121        self: Pin<&mut Self>,
122        cx: &mut Context<'_>,
123    ) -> Poll<Result<(), std::io::Error>> {
124        tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
125    }
126
127    fn is_write_vectored(&self) -> bool {
128        tokio::io::AsyncWrite::is_write_vectored(&self.inner)
129    }
130
131    fn poll_write_vectored(
132        self: Pin<&mut Self>,
133        cx: &mut Context<'_>,
134        bufs: &[std::io::IoSlice<'_>],
135    ) -> Poll<Result<usize, std::io::Error>> {
136        tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
137    }
138}
139
impl<T> tokio::io::AsyncRead for TokioIo<T>
where
    T: hyper::rt::Read,
{
    /// Reads from the wrapped hyper reader into the unfilled part of `read_buf`.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        read_buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // Length already filled before this call; new data is appended after it.
        let filled = read_buf.filled().len();
        // SAFETY: `unfilled_mut` requires that we never de-initialize bytes. The slice
        // is only accessed through hyper's `ReadBuf`, whose safe API writes
        // initialized data and never de-initializes memory.
        let newly_filled = unsafe {
            let mut hyper_buf = hyper::rt::ReadBuf::uninit(read_buf.unfilled_mut());
            match hyper::rt::Read::poll_read(self.project().inner, cx, hyper_buf.unfilled()) {
                Poll::Ready(Ok(())) => hyper_buf.filled().len(),
                // Pending and errors are forwarded unchanged; `read_buf` is untouched.
                other => return other,
            }
        };

        // SAFETY: the hyper reader filled, and therefore initialized, `newly_filled`
        // bytes at the start of the unfilled region. `assume_init` counts from the
        // start of the unfilled region, while `set_filled` takes an absolute length,
        // hence `filled + newly_filled`.
        unsafe {
            read_buf.assume_init(newly_filled);
            read_buf.set_filled(filled + newly_filled);
        }

        Poll::Ready(Ok(()))
    }
}
166
167impl<T> tokio::io::AsyncWrite for TokioIo<T>
168where
169    T: hyper::rt::Write,
170{
171    fn poll_write(
172        self: Pin<&mut Self>,
173        cx: &mut Context<'_>,
174        buf: &[u8],
175    ) -> Poll<Result<usize, std::io::Error>> {
176        hyper::rt::Write::poll_write(self.project().inner, cx, buf)
177    }
178
179    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
180        hyper::rt::Write::poll_flush(self.project().inner, cx)
181    }
182
183    fn poll_shutdown(
184        self: Pin<&mut Self>,
185        cx: &mut Context<'_>,
186    ) -> Poll<Result<(), std::io::Error>> {
187        hyper::rt::Write::poll_shutdown(self.project().inner, cx)
188    }
189
190    fn is_write_vectored(&self) -> bool {
191        hyper::rt::Write::is_write_vectored(&self.inner)
192    }
193
194    fn poll_write_vectored(
195        self: Pin<&mut Self>,
196        cx: &mut Context<'_>,
197        bufs: &[std::io::IoSlice<'_>],
198    ) -> Poll<Result<usize, std::io::Error>> {
199        hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
200    }
201}
202
/// Hyper [`Timer`] implementation backed by `tokio::time`.
///
/// Marked `#[non_exhaustive]`, so code outside this crate builds it through
/// [`TokioTimer::new`] or [`Default`] rather than the unit-struct literal.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct TokioTimer;

impl TokioTimer {
    /// Creates a new timer handle.
    pub fn new() -> Self {
        TokioTimer
    }
}
212
pin_project! {
    /// [`Sleep`] implementation handed to hyper by [`TokioTimer`]; wraps a
    /// `tokio::time::Sleep`.
    #[derive(Debug)]
    struct TokioSleep {
        // Pinned structurally: `tokio::time::Sleep` is `!Unpin`, so it is polled and
        // reset through `Pin`.
        #[pin]
        inner: tokio::time::Sleep,
    }
}
220
221impl Timer for TokioTimer {
222    fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> {
223        Box::pin(TokioSleep {
224            inner: tokio::time::sleep(duration),
225        })
226    }
227
228    fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn Sleep>> {
229        Box::pin(TokioSleep {
230            inner: tokio::time::sleep_until(deadline.into()),
231        })
232    }
233
234    fn reset(&self, sleep: &mut Pin<Box<dyn Sleep>>, new_deadline: Instant) {
235        if let Some(tokio_sleep) = sleep.as_mut().downcast_mut_pin::<TokioSleep>() {
236            tokio_sleep.reset(new_deadline);
237        }
238    }
239
240    fn now(&self) -> Instant {
241        tokio::time::Instant::now().into()
242    }
243}
244
245impl Future for TokioSleep {
246    type Output = ();
247
248    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
249        self.project().inner.poll(cx)
250    }
251}
252
253impl Sleep for TokioSleep {}
254
255impl TokioSleep {
256    fn reset(self: Pin<&mut Self>, deadline: Instant) {
257        self.project().inner.as_mut().reset(deadline.into());
258    }
259}
260
261#[derive(Clone, Debug, Default)]
262pub struct SystemDnsResolver;
263
264impl DnsResolver for SystemDnsResolver {
265    fn resolve(
266        &self,
267        ctx: CallContext,
268        host: String,
269        port: u16,
270    ) -> BoxFuture<Result<Vec<SocketAddr>, WireError>> {
271        Box::pin(async move {
272            ctx.listener().dns_start(&ctx, &host, port);
273            match tokio::net::lookup_host((host.as_str(), port)).await {
274                Ok(addrs) => {
275                    let addrs: Vec<_> = addrs.collect();
276                    if addrs.is_empty() {
277                        let error = WireError::dns(
278                            "DNS resolution returned no socket addresses",
279                            io::Error::new(io::ErrorKind::NotFound, "empty DNS result"),
280                        );
281                        ctx.listener().dns_failed(&ctx, &host, &error);
282                        return Err(error);
283                    }
284                    ctx.listener().dns_end(&ctx, &host, &addrs);
285                    Ok(addrs)
286                }
287                Err(error) => {
288                    let error = WireError::dns("DNS resolution failed", error);
289                    ctx.listener().dns_failed(&ctx, &host, &error);
290                    Err(error)
291                }
292            }
293        })
294    }
295}
296
297#[derive(Clone, Debug, Default)]
298pub struct TokioTcpConnector;
299
300impl TcpConnector for TokioTcpConnector {
301    fn connect(
302        &self,
303        ctx: CallContext,
304        addr: SocketAddr,
305        timeout: Option<Duration>,
306    ) -> BoxFuture<Result<BoxConnection, WireError>> {
307        Box::pin(async move {
308            ctx.listener().connect_start(&ctx, addr);
309            let connect = tokio::net::TcpStream::connect(addr);
310            let stream = match timeout {
311                Some(timeout) => match tokio::time::timeout(timeout, connect).await {
312                    Ok(result) => {
313                        result.map_err(|error| WireError::tcp_connect("TCP connect failed", error))
314                    }
315                    Err(error) => Err(WireError::with_source(
316                        WireErrorKind::Timeout,
317                        format!("connection timed out after {timeout:?}"),
318                        error,
319                    )
320                    .with_establishment(EstablishmentStage::Tcp, true)
321                    .with_connect_timeout()),
322                },
323                None => connect
324                    .await
325                    .map_err(|error| WireError::tcp_connect("TCP connect failed", error)),
326            };
327            let stream = match stream {
328                Ok(stream) => stream,
329                Err(error) => {
330                    ctx.listener().connect_failed(&ctx, addr, &error);
331                    return Err(error);
332                }
333            };
334
335            stream
336                .set_nodelay(true)
337                .map_err(|error| WireError::tcp_connect("failed to configure TCP_NODELAY", error))
338                .inspect_err(|error| {
339                    ctx.listener().connect_failed(&ctx, addr, error);
340                })?;
341
342            let info = ConnectionInfo {
343                id: next_connection_id(),
344                remote_addr: stream.peer_addr().ok(),
345                local_addr: stream.local_addr().ok(),
346                tls: false,
347            };
348
349            ctx.mark_connection_established();
350            ctx.listener().connect_end(&ctx, info.id, addr);
351
352            Ok(Box::new(TcpConnection {
353                inner: TokioIo::new(stream),
354                info,
355            }) as BoxConnection)
356        })
357    }
358}
359
360struct TcpConnection {
361    inner: TokioIo<tokio::net::TcpStream>,
362    info: ConnectionInfo,
363}
364
365impl Connection for TcpConnection {
366    fn connected(&self) -> Connected {
367        Connected::new().info(self.info.clone())
368    }
369}
370
371impl hyper::rt::Read for TcpConnection {
372    fn poll_read(
373        self: Pin<&mut Self>,
374        cx: &mut Context<'_>,
375        buf: hyper::rt::ReadBufCursor<'_>,
376    ) -> Poll<Result<(), io::Error>> {
377        Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
378    }
379}
380
381impl hyper::rt::Write for TcpConnection {
382    fn poll_write(
383        self: Pin<&mut Self>,
384        cx: &mut Context<'_>,
385        buf: &[u8],
386    ) -> Poll<Result<usize, io::Error>> {
387        Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
388    }
389
390    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
391        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
392    }
393
394    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
395        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
396    }
397
398    fn is_write_vectored(&self) -> bool {
399        self.inner.is_write_vectored()
400    }
401
402    fn poll_write_vectored(
403        self: Pin<&mut Self>,
404        cx: &mut Context<'_>,
405        bufs: &[io::IoSlice<'_>],
406    ) -> Poll<Result<usize, io::Error>> {
407        Pin::new(&mut self.get_mut().inner).poll_write_vectored(cx, bufs)
408    }
409}
410
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use hyper::rt::Executor;
    use hyper::rt::Timer;
    use tokio::sync::oneshot;

    use super::{TokioExecutor, TokioTimer};

    #[tokio::test]
    async fn tokio_executor_spawns_background_future() {
        let (tx, rx) = oneshot::channel();
        TokioExecutor::new().execute(async move {
            let _ = tx.send(());
        });
        rx.await.expect("executor future should complete");
    }

    #[tokio::test]
    async fn tokio_timer_reset_moves_sleep_deadline() {
        let timer = TokioTimer::new();
        let mut sleep = hyper::rt::Timer::sleep(&timer, Duration::from_secs(5));
        timer.reset(&mut sleep, timer.now() + Duration::from_millis(1));
        // Bound the wait well below the original 5s deadline. If `reset` failed to move
        // the deadline, the timeout fires and the test fails, instead of passing after
        // waiting out the full 5 seconds.
        tokio::time::timeout(Duration::from_secs(1), sleep)
            .await
            .expect("reset should move the sleep deadline earlier");
    }
}