1use std::future::Future;
2use std::io;
3use std::net::SocketAddr;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
8use hyper::rt::{Executor, Sleep, Timer};
9use openwire_core::{
10 next_connection_id, BoxConnection, BoxFuture, BoxTaskHandle, CallContext, Connected,
11 Connection, ConnectionInfo, DnsResolver, EstablishmentStage, TaskHandle, TcpConnector,
12 WireError, WireErrorKind, WireExecutor,
13};
14use pin_project_lite::pin_project;
15use tracing::instrument::WithSubscriber;
16
17#[derive(Debug)]
18struct TokioTaskHandle(tokio::task::JoinHandle<()>);
19
20impl TaskHandle for TokioTaskHandle {
21 fn abort(&self) {
22 self.0.abort();
23 }
24}
25
/// Executor that spawns futures with `tokio::spawn`.
///
/// Because it relies on `tokio::spawn`, it must be used from within a tokio runtime.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct TokioExecutor;

impl TokioExecutor {
    /// Creates the executor; equivalent to `TokioExecutor::default()`.
    pub fn new() -> Self {
        TokioExecutor
    }
}
35
36impl<Fut> Executor<Fut> for TokioExecutor
37where
38 Fut: Future + Send + 'static,
39 Fut::Output: Send + 'static,
40{
41 fn execute(&self, future: Fut) {
42 tokio::spawn(future.with_current_subscriber());
43 }
44}
45
46impl WireExecutor for TokioExecutor {
47 fn spawn(&self, future: BoxFuture<()>) -> Result<BoxTaskHandle, WireError> {
48 Ok(Box::new(TokioTaskHandle(tokio::spawn(
49 future.with_current_subscriber(),
50 ))))
51 }
52}
53
pin_project! {
    /// Bidirectional adapter between tokio's and hyper's IO traits.
    ///
    /// Wrapping a tokio `AsyncRead`/`AsyncWrite` yields hyper's `rt::Read`/`rt::Write`;
    /// wrapping a hyper `rt::Read`/`rt::Write` yields tokio's `AsyncRead`/`AsyncWrite`.
    /// `inner` is structurally pinned so it can be polled through `Pin<&mut Self>`.
    #[derive(Debug)]
    pub struct TokioIo<T> {
        #[pin]
        inner: T,
    }
}
61
62impl<T> TokioIo<T> {
63 pub fn new(inner: T) -> Self {
64 Self { inner }
65 }
66
67 pub fn inner(&self) -> &T {
68 &self.inner
69 }
70
71 pub fn inner_mut(&mut self) -> &mut T {
72 &mut self.inner
73 }
74
75 pub fn into_inner(self) -> T {
76 self.inner
77 }
78}
79
impl<T> hyper::rt::Read for TokioIo<T>
where
    T: tokio::io::AsyncRead,
{
    /// Reads from the inner tokio reader directly into hyper's cursor, without zero-filling
    /// the destination first.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: hyper::rt::ReadBufCursor<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // SAFETY: `ReadBufCursor::as_mut` requires that no already-initialized bytes get
        // de-initialized. The unfilled slice is only exposed through a tokio `ReadBuf`, whose
        // API lets the inner reader initialize bytes but never de-initialize them.
        let filled = unsafe {
            let mut read_buf = tokio::io::ReadBuf::uninit(buf.as_mut());
            match tokio::io::AsyncRead::poll_read(self.project().inner, cx, &mut read_buf) {
                Poll::Ready(Ok(())) => read_buf.filled().len(),
                // Pending or an error: return as-is without advancing hyper's cursor.
                other => return other,
            }
        };

        // SAFETY: `advance(n)` requires the next `n` bytes to be initialized. `filled` is the
        // length of the tokio `ReadBuf`'s filled region, which begins at the cursor position
        // and is always initialized.
        unsafe {
            buf.advance(filled);
        }
        Poll::Ready(Ok(()))
    }
}
103
104impl<T> hyper::rt::Write for TokioIo<T>
105where
106 T: tokio::io::AsyncWrite,
107{
108 fn poll_write(
109 self: Pin<&mut Self>,
110 cx: &mut Context<'_>,
111 buf: &[u8],
112 ) -> Poll<Result<usize, std::io::Error>> {
113 tokio::io::AsyncWrite::poll_write(self.project().inner, cx, buf)
114 }
115
116 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
117 tokio::io::AsyncWrite::poll_flush(self.project().inner, cx)
118 }
119
120 fn poll_shutdown(
121 self: Pin<&mut Self>,
122 cx: &mut Context<'_>,
123 ) -> Poll<Result<(), std::io::Error>> {
124 tokio::io::AsyncWrite::poll_shutdown(self.project().inner, cx)
125 }
126
127 fn is_write_vectored(&self) -> bool {
128 tokio::io::AsyncWrite::is_write_vectored(&self.inner)
129 }
130
131 fn poll_write_vectored(
132 self: Pin<&mut Self>,
133 cx: &mut Context<'_>,
134 bufs: &[std::io::IoSlice<'_>],
135 ) -> Poll<Result<usize, std::io::Error>> {
136 tokio::io::AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs)
137 }
138}
139
impl<T> tokio::io::AsyncRead for TokioIo<T>
where
    T: hyper::rt::Read,
{
    /// Reads from the inner hyper reader into the unfilled tail of tokio's `ReadBuf`, then
    /// records how many bytes were written.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        read_buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // Bytes already filled before this call; new data is appended after them.
        let filled = read_buf.filled().len();
        // SAFETY: `ReadBuf::unfilled_mut` requires that already-initialized bytes are never
        // de-initialized. The slice is only exposed through a hyper `ReadBuf`, whose API lets
        // the inner reader initialize bytes but not de-initialize them.
        let newly_filled = unsafe {
            let mut hyper_buf = hyper::rt::ReadBuf::uninit(read_buf.unfilled_mut());
            match hyper::rt::Read::poll_read(self.project().inner, cx, hyper_buf.unfilled()) {
                Poll::Ready(Ok(())) => hyper_buf.filled().len(),
                // Pending or an error: return as-is without touching tokio's counters.
                other => return other,
            }
        };

        // SAFETY: the hyper reader filled exactly `newly_filled` bytes at the start of the
        // unfilled region, so `assume_init(newly_filled)` is sound; `set_filled` then marks
        // those same (now initialized) bytes as filled.
        unsafe {
            read_buf.assume_init(newly_filled);
            read_buf.set_filled(filled + newly_filled);
        }

        Poll::Ready(Ok(()))
    }
}
166
167impl<T> tokio::io::AsyncWrite for TokioIo<T>
168where
169 T: hyper::rt::Write,
170{
171 fn poll_write(
172 self: Pin<&mut Self>,
173 cx: &mut Context<'_>,
174 buf: &[u8],
175 ) -> Poll<Result<usize, std::io::Error>> {
176 hyper::rt::Write::poll_write(self.project().inner, cx, buf)
177 }
178
179 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
180 hyper::rt::Write::poll_flush(self.project().inner, cx)
181 }
182
183 fn poll_shutdown(
184 self: Pin<&mut Self>,
185 cx: &mut Context<'_>,
186 ) -> Poll<Result<(), std::io::Error>> {
187 hyper::rt::Write::poll_shutdown(self.project().inner, cx)
188 }
189
190 fn is_write_vectored(&self) -> bool {
191 hyper::rt::Write::is_write_vectored(&self.inner)
192 }
193
194 fn poll_write_vectored(
195 self: Pin<&mut Self>,
196 cx: &mut Context<'_>,
197 bufs: &[std::io::IoSlice<'_>],
198 ) -> Poll<Result<usize, std::io::Error>> {
199 hyper::rt::Write::poll_write_vectored(self.project().inner, cx, bufs)
200 }
201}
202
/// hyper `Timer` implementation backed by tokio's time driver.
#[non_exhaustive]
#[derive(Clone, Debug, Default)]
pub struct TokioTimer;

impl TokioTimer {
    /// Creates the timer; equivalent to `TokioTimer::default()`.
    pub fn new() -> Self {
        TokioTimer
    }
}
212
pin_project! {
    /// hyper `Sleep` wrapper around a pinned `tokio::time::Sleep`; created by `TokioTimer`.
    #[derive(Debug)]
    struct TokioSleep {
        #[pin]
        inner: tokio::time::Sleep,
    }
}
220
221impl Timer for TokioTimer {
222 fn sleep(&self, duration: Duration) -> Pin<Box<dyn Sleep>> {
223 Box::pin(TokioSleep {
224 inner: tokio::time::sleep(duration),
225 })
226 }
227
228 fn sleep_until(&self, deadline: Instant) -> Pin<Box<dyn Sleep>> {
229 Box::pin(TokioSleep {
230 inner: tokio::time::sleep_until(deadline.into()),
231 })
232 }
233
234 fn reset(&self, sleep: &mut Pin<Box<dyn Sleep>>, new_deadline: Instant) {
235 if let Some(tokio_sleep) = sleep.as_mut().downcast_mut_pin::<TokioSleep>() {
236 tokio_sleep.reset(new_deadline);
237 }
238 }
239
240 fn now(&self) -> Instant {
241 tokio::time::Instant::now().into()
242 }
243}
244
245impl Future for TokioSleep {
246 type Output = ();
247
248 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
249 self.project().inner.poll(cx)
250 }
251}
252
// Empty marker impl: lets `TokioSleep` be boxed as `dyn Sleep` by `TokioTimer` and recovered
// with `downcast_mut_pin` in `TokioTimer::reset`.
impl Sleep for TokioSleep {}
254
255impl TokioSleep {
256 fn reset(self: Pin<&mut Self>, deadline: Instant) {
257 self.project().inner.as_mut().reset(deadline.into());
258 }
259}
260
/// DNS resolver that delegates to `tokio::net::lookup_host`.
#[derive(Debug, Default, Clone)]
pub struct SystemDnsResolver;
263
264impl DnsResolver for SystemDnsResolver {
265 fn resolve(
266 &self,
267 ctx: CallContext,
268 host: String,
269 port: u16,
270 ) -> BoxFuture<Result<Vec<SocketAddr>, WireError>> {
271 Box::pin(async move {
272 ctx.listener().dns_start(&ctx, &host, port);
273 match tokio::net::lookup_host((host.as_str(), port)).await {
274 Ok(addrs) => {
275 let addrs: Vec<_> = addrs.collect();
276 if addrs.is_empty() {
277 let error = WireError::dns(
278 "DNS resolution returned no socket addresses",
279 io::Error::new(io::ErrorKind::NotFound, "empty DNS result"),
280 );
281 ctx.listener().dns_failed(&ctx, &host, &error);
282 return Err(error);
283 }
284 ctx.listener().dns_end(&ctx, &host, &addrs);
285 Ok(addrs)
286 }
287 Err(error) => {
288 let error = WireError::dns("DNS resolution failed", error);
289 ctx.listener().dns_failed(&ctx, &host, &error);
290 Err(error)
291 }
292 }
293 })
294 }
295}
296
/// TCP connector built on `tokio::net::TcpStream`, with `TCP_NODELAY` enabled on every stream.
#[derive(Debug, Default, Clone)]
pub struct TokioTcpConnector;
299
300impl TcpConnector for TokioTcpConnector {
301 fn connect(
302 &self,
303 ctx: CallContext,
304 addr: SocketAddr,
305 timeout: Option<Duration>,
306 ) -> BoxFuture<Result<BoxConnection, WireError>> {
307 Box::pin(async move {
308 ctx.listener().connect_start(&ctx, addr);
309 let connect = tokio::net::TcpStream::connect(addr);
310 let stream = match timeout {
311 Some(timeout) => match tokio::time::timeout(timeout, connect).await {
312 Ok(result) => {
313 result.map_err(|error| WireError::tcp_connect("TCP connect failed", error))
314 }
315 Err(error) => Err(WireError::with_source(
316 WireErrorKind::Timeout,
317 format!("connection timed out after {timeout:?}"),
318 error,
319 )
320 .with_establishment(EstablishmentStage::Tcp, true)
321 .with_connect_timeout()),
322 },
323 None => connect
324 .await
325 .map_err(|error| WireError::tcp_connect("TCP connect failed", error)),
326 };
327 let stream = match stream {
328 Ok(stream) => stream,
329 Err(error) => {
330 ctx.listener().connect_failed(&ctx, addr, &error);
331 return Err(error);
332 }
333 };
334
335 stream
336 .set_nodelay(true)
337 .map_err(|error| WireError::tcp_connect("failed to configure TCP_NODELAY", error))
338 .inspect_err(|error| {
339 ctx.listener().connect_failed(&ctx, addr, error);
340 })?;
341
342 let info = ConnectionInfo {
343 id: next_connection_id(),
344 remote_addr: stream.peer_addr().ok(),
345 local_addr: stream.local_addr().ok(),
346 tls: false,
347 };
348
349 ctx.mark_connection_established();
350 ctx.listener().connect_end(&ctx, info.id, addr);
351
352 Ok(Box::new(TcpConnection {
353 inner: TokioIo::new(stream),
354 info,
355 }) as BoxConnection)
356 })
357 }
358}
359
/// Connection produced by `TokioTcpConnector`: a tokio TCP stream adapted to hyper's IO traits,
/// plus the metadata captured when it was established.
struct TcpConnection {
    // Stream wrapped in `TokioIo` so hyper's `rt::Read`/`rt::Write` can be forwarded to it.
    inner: TokioIo<tokio::net::TcpStream>,
    // Reported to callers through `Connection::connected`.
    info: ConnectionInfo,
}
364
365impl Connection for TcpConnection {
366 fn connected(&self) -> Connected {
367 Connected::new().info(self.info.clone())
368 }
369}
370
371impl hyper::rt::Read for TcpConnection {
372 fn poll_read(
373 self: Pin<&mut Self>,
374 cx: &mut Context<'_>,
375 buf: hyper::rt::ReadBufCursor<'_>,
376 ) -> Poll<Result<(), io::Error>> {
377 Pin::new(&mut self.get_mut().inner).poll_read(cx, buf)
378 }
379}
380
381impl hyper::rt::Write for TcpConnection {
382 fn poll_write(
383 self: Pin<&mut Self>,
384 cx: &mut Context<'_>,
385 buf: &[u8],
386 ) -> Poll<Result<usize, io::Error>> {
387 Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)
388 }
389
390 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
391 Pin::new(&mut self.get_mut().inner).poll_flush(cx)
392 }
393
394 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
395 Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
396 }
397
398 fn is_write_vectored(&self) -> bool {
399 self.inner.is_write_vectored()
400 }
401
402 fn poll_write_vectored(
403 self: Pin<&mut Self>,
404 cx: &mut Context<'_>,
405 bufs: &[io::IoSlice<'_>],
406 ) -> Poll<Result<usize, io::Error>> {
407 Pin::new(&mut self.get_mut().inner).poll_write_vectored(cx, bufs)
408 }
409}
410
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use hyper::rt::Executor;
    use hyper::rt::Timer;
    use tokio::sync::oneshot;

    use super::{TokioExecutor, TokioTimer};

    // Bound for waits that should finish almost immediately. It is far below the 5s sleep in
    // the reset test, so a `reset` that fails to move the deadline is reported as a failure
    // instead of merely making the test slow.
    const PROMPT: Duration = Duration::from_secs(1);

    #[tokio::test]
    async fn tokio_executor_spawns_background_future() {
        let (tx, rx) = oneshot::channel();
        TokioExecutor::new().execute(async move {
            let _ = tx.send(());
        });
        tokio::time::timeout(PROMPT, rx)
            .await
            .expect("executor future should run promptly")
            .expect("executor future should complete");
    }

    #[tokio::test]
    async fn tokio_timer_reset_moves_sleep_deadline() {
        let timer = TokioTimer::new();
        let mut sleep = Timer::sleep(&timer, Duration::from_secs(5));
        timer.reset(&mut sleep, timer.now() + Duration::from_millis(1));
        tokio::time::timeout(PROMPT, sleep)
            .await
            .expect("reset should pull the deadline in from 5s to ~1ms");
    }
}