hyper_client_sockets/
tokio.rs

1#[cfg(feature = "vsock")]
2use std::{
3    io::{Read, Write},
4    os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd},
5    task::Poll,
6};
7#[cfg(feature = "vsock")]
8use std::{pin::Pin, task::Context};
9
10#[cfg(any(feature = "unix", feature = "firecracker"))]
11use hyper_util::rt::TokioIo;
12#[cfg(any(feature = "unix", feature = "firecracker"))]
13use std::path::Path;
14#[cfg(any(feature = "unix", feature = "firecracker"))]
15use tokio::net::UnixStream;
16
17#[cfg(feature = "firecracker")]
18use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
19
20#[cfg(feature = "vsock")]
21use tokio::io::unix::AsyncFd;
22
23use crate::Backend;
24
/// The Tokio flavour of [Backend]: a zero-sized marker type whose [Backend] implementation
/// establishes every supported connection kind through the Tokio reactor.
#[derive(Clone, Debug)]
pub struct TokioBackend;
28
29impl Backend for TokioBackend {
30    #[cfg(feature = "unix")]
31    #[cfg_attr(docsrs, doc(cfg(feature = "unix")))]
32    type UnixIo = TokioIo<UnixStream>;
33
34    #[cfg(feature = "vsock")]
35    #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
36    type VsockIo = TokioVsockIo;
37
38    #[cfg(feature = "firecracker")]
39    #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))]
40    type FirecrackerIo = TokioIo<UnixStream>;
41
42    #[cfg(feature = "unix")]
43    #[cfg_attr(docsrs, doc(cfg(feature = "unix")))]
44    async fn connect_to_unix_socket(socket_path: &Path) -> Result<Self::UnixIo, std::io::Error> {
45        Ok(TokioIo::new(UnixStream::connect(socket_path).await?))
46    }
47
48    #[cfg(feature = "vsock")]
49    #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
50    async fn connect_to_vsock_socket(addr: vsock::VsockAddr) -> Result<Self::VsockIo, std::io::Error> {
51        TokioVsockIo::connect(addr).await
52    }
53
54    #[cfg(feature = "firecracker")]
55    #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))]
56    async fn connect_to_firecracker_socket(
57        host_socket_path: &Path,
58        guest_port: u32,
59    ) -> Result<Self::FirecrackerIo, std::io::Error> {
60        let mut stream = UnixStream::connect(host_socket_path).await?;
61        stream.write_all(format!("CONNECT {guest_port}\n").as_bytes()).await?;
62
63        let mut lines = BufReader::new(&mut stream).lines();
64        match lines.next_line().await {
65            Ok(Some(line)) => {
66                if !line.starts_with("OK") {
67                    return Err(std::io::Error::new(
68                        std::io::ErrorKind::ConnectionRefused,
69                        "Firecracker refused to establish a tunnel to the given guest port",
70                    ));
71                }
72            }
73            _ => {
74                return Err(std::io::Error::new(
75                    std::io::ErrorKind::InvalidInput,
76                    "Could not read Firecracker response",
77                ))
78            }
79        };
80
81        Ok(TokioIo::new(stream))
82    }
83}
84
/// hyper-compatible IO object for a vsock connection.
///
/// It wraps a [vsock::VsockStream] registered with the Tokio reactor through an [AsyncFd];
/// the read and write implementations wait on that registration for readiness.
/// The implementation re-creates the relevant part of the tokio-vsock crate.
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
pub struct TokioVsockIo(AsyncFd<vsock::VsockStream>);
90
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
impl TokioVsockIo {
    /// Creates a non-blocking, close-on-exec vsock stream socket, starts connecting it to
    /// `addr` and waits on the Tokio reactor until the connection attempt has finished.
    ///
    /// # Errors
    ///
    /// Returns the OS error if the socket cannot be created or configured, if `connect`
    /// fails with anything other than `EINPROGRESS`, if the descriptor cannot be registered
    /// with the reactor, or if the connection attempt completes with a pending socket error
    /// (`SO_ERROR`). The descriptor is closed on every error path.
    async fn connect(addr: vsock::VsockAddr) -> Result<Self, std::io::Error> {
        // SAFETY: plain syscall taking only integer constants; the result is checked below.
        let raw_socket = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM, 0) };
        if raw_socket < 0 {
            return Err(std::io::Error::last_os_error());
        }

        // Take ownership right away so that every early return below closes the descriptor
        // on drop, and only after the relevant errno has already been captured.
        // SAFETY: `raw_socket` was just returned by `socket` and is owned by nothing else.
        let socket = unsafe { OwnedFd::from_raw_fd(raw_socket) };

        // Close-on-exec is a descriptor flag, so it has to be set through F_SETFD;
        // F_SETFL only changes file status flags and ignores O_CLOEXEC.
        // SAFETY: `socket` is a valid open descriptor for the duration of the call.
        if unsafe { libc::fcntl(socket.as_raw_fd(), libc::F_SETFD, libc::FD_CLOEXEC) } < 0 {
            return Err(std::io::Error::last_os_error());
        }

        // SAFETY: `socket` is a valid open descriptor for the duration of the call.
        if unsafe { libc::fcntl(socket.as_raw_fd(), libc::F_SETFL, libc::O_NONBLOCK) } < 0 {
            return Err(std::io::Error::last_os_error());
        }

        // NOTE(review): the pointer cast assumes `vsock::VsockAddr` has the layout of
        // `libc::sockaddr_vm`, exactly as the original code did - confirm against the vsock crate.
        // SAFETY: the pointer refers to `addr`, which outlives the call; the length passed is
        // the size of `sockaddr_vm`.
        let connect_status = unsafe {
            libc::connect(
                socket.as_raw_fd(),
                &addr as *const _ as *const libc::sockaddr,
                size_of::<libc::sockaddr_vm>() as libc::socklen_t,
            )
        };
        if connect_status < 0 {
            let err = std::io::Error::last_os_error();
            // EINPROGRESS is the expected result for a non-blocking connect: completion is
            // awaited below. Everything else is a real failure.
            if err.raw_os_error() != Some(libc::EINPROGRESS) {
                return Err(err);
            }
        }

        let async_fd = AsyncFd::new(socket)?;

        loop {
            // The socket becomes writable once the connection attempt has finished,
            // successfully or not.
            let mut guard = async_fd.writable().await?;

            let connection_check = guard.try_io(|fd| {
                let mut sock_err: libc::c_int = 0;
                let mut sock_err_len: libc::socklen_t = size_of::<libc::c_int>() as libc::socklen_t;
                // SAFETY: both out-pointers refer to live locals of the advertised size.
                let status = unsafe {
                    libc::getsockopt(
                        fd.as_raw_fd(),
                        libc::SOL_SOCKET,
                        libc::SO_ERROR,
                        &mut sock_err as *mut _ as *mut libc::c_void,
                        &mut sock_err_len as *mut libc::socklen_t,
                    )
                };

                if status < 0 {
                    return Err(std::io::Error::last_os_error());
                }

                // SO_ERROR carries the outcome of the asynchronous connect; zero means success.
                if sock_err == 0 {
                    Ok(())
                } else {
                    Err(std::io::Error::from_raw_os_error(sock_err))
                }
            });

            match connection_check {
                Ok(Ok(_)) => {
                    // Re-register the same descriptor, now typed as a `VsockStream`.
                    // SAFETY: `into_raw_fd` releases ownership of a valid open descriptor,
                    // which is immediately handed to the `VsockStream`.
                    return Ok(TokioVsockIo(AsyncFd::new(unsafe {
                        vsock::VsockStream::from_raw_fd(async_fd.into_inner().into_raw_fd())
                    })?));
                }
                Ok(Err(err)) => return Err(err),
                // Readiness was cleared by `try_io`; wait for the next writable event.
                Err(_would_block) => continue,
            }
        }
    }
}
163
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
impl hyper::rt::Write for TokioVsockIo {
    /// Writes as much of `buf` as the socket accepts once the reactor reports it writable.
    ///
    /// Returns `Poll::Pending` while the socket is not writable. Writes interrupted by a
    /// signal are retried; any other write error is returned as is.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
        loop {
            let mut guard = match self.0.poll_write_ready(cx) {
                Poll::Ready(Ok(guard)) => guard,
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                Poll::Pending => return Poll::Pending,
            };

            match guard.try_io(|inner| inner.get_ref().write(buf)) {
                Ok(Ok(amount)) => return Poll::Ready(Ok(amount)),
                Ok(Err(ref err)) if err.kind() == std::io::ErrorKind::Interrupted => continue,
                Ok(Err(err)) => return Poll::Ready(Err(err)),
                // The readiness event was stale and has been cleared; poll for a new one.
                Err(_would_block) => continue,
            }
        }
    }

    /// Nothing is buffered by this type, so flushing completes immediately.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Shuts down the write half of the connection so that the peer observes end-of-stream
    /// without having to wait for the descriptor to be dropped.
    ///
    /// A socket that is already disconnected counts as successfully shut down.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        match self.0.get_ref().shutdown(std::net::Shutdown::Write) {
            Ok(()) => Poll::Ready(Ok(())),
            // The peer has already torn the connection down; there is nothing left to shut down.
            Err(err) if err.kind() == std::io::ErrorKind::NotConnected => Poll::Ready(Ok(())),
            Err(err) => Poll::Ready(Err(err)),
        }
    }
}
192
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
impl hyper::rt::Read for TokioVsockIo {
    /// Reads from the socket into the unfilled part of `buf` once the reactor reports it
    /// readable, then advances the cursor by the number of bytes read.
    ///
    /// Returns `Poll::Pending` while the socket is not readable. Reads interrupted by a
    /// signal are retried; any other read error is returned as is.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: hyper::rt::ReadBufCursor<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // NOTE(review): this views the cursor's possibly-uninitialised bytes as `[u8]` and
        // relies on the stream's `read` only ever writing into the slice - confirm that this
        // holds for `vsock::VsockStream`.
        let unfilled = unsafe { &mut *(buf.as_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };

        loop {
            // Yields `Pending` when the socket is not readable and bails out on a reactor error.
            let mut ready_guard = std::task::ready!(self.0.poll_read_ready(cx))?;

            let attempt = ready_guard.try_io(|async_fd| async_fd.get_ref().read(unfilled));

            // A would-block result means the readiness event was stale; it has been cleared,
            // so go back and poll for a fresh one.
            let Ok(read_result) = attempt else {
                continue;
            };

            match read_result {
                Ok(filled) => {
                    // SAFETY: `read` reported that the first `filled` bytes of the unfilled
                    // region were written.
                    unsafe {
                        buf.advance(filled);
                    }

                    return Poll::Ready(Ok(()));
                }
                Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(err) => return Poll::Ready(Err(err)),
            }
        }
    }
}