hyper_client_sockets/
async_io.rs

1#[cfg(any(feature = "unix", feature = "firecracker"))]
2use std::path::Path;
3#[cfg(feature = "vsock")]
4use std::{
5    io::{Read, Write},
6    os::fd::{AsRawFd, FromRawFd, IntoRawFd},
7    pin::Pin,
8    task::{Context, Poll},
9};
10
11#[cfg(any(feature = "unix", feature = "vsock", feature = "firecracker"))]
12use async_io::Async;
13#[cfg(feature = "firecracker")]
14use futures_lite::{io::BufReader, AsyncBufReadExt, AsyncWriteExt, StreamExt};
15#[cfg(any(feature = "unix", feature = "firecracker"))]
16use smol_hyper::rt::FuturesIo;
17#[cfg(feature = "vsock")]
18use vsock::VsockAddr;
19
20use crate::Backend;
21
/// A [`Backend`] for hyper-client-sockets whose connections are driven by the reactor of the
/// `async-io` crate.
#[derive(Clone, Debug)]
pub struct AsyncIoBackend;
25
26impl Backend for AsyncIoBackend {
27    #[cfg(feature = "unix")]
28    #[cfg_attr(docsrs, doc(cfg(feature = "unix")))]
29    type UnixIo = FuturesIo<Async<std::os::unix::net::UnixStream>>;
30
31    #[cfg(feature = "vsock")]
32    #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
33    type VsockIo = AsyncVsockIo;
34
35    #[cfg(feature = "firecracker")]
36    #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))]
37    type FirecrackerIo = FuturesIo<Async<std::os::unix::net::UnixStream>>;
38
39    #[cfg(feature = "unix")]
40    #[cfg_attr(docsrs, doc(cfg(feature = "unix")))]
41    async fn connect_to_unix_socket(socket_path: &Path) -> Result<Self::UnixIo, std::io::Error> {
42        Ok(FuturesIo::new(
43            Async::<std::os::unix::net::UnixStream>::connect(socket_path).await?,
44        ))
45    }
46
47    #[cfg(feature = "vsock")]
48    #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
49    async fn connect_to_vsock_socket(addr: vsock::VsockAddr) -> Result<Self::VsockIo, std::io::Error> {
50        Ok(AsyncVsockIo::connect(addr).await?)
51    }
52
53    #[cfg(feature = "firecracker")]
54    #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))]
55    async fn connect_to_firecracker_socket(
56        host_socket_path: &Path,
57        guest_port: u32,
58    ) -> Result<Self::FirecrackerIo, std::io::Error> {
59        let mut stream = Async::<std::os::unix::net::UnixStream>::connect(host_socket_path).await?;
60        stream.write_all(format!("CONNECT {guest_port}\n").as_bytes()).await?;
61
62        let mut lines = BufReader::new(&mut stream).lines();
63        match lines.next().await {
64            Some(Ok(line)) => {
65                if !line.starts_with("OK") {
66                    return Err(std::io::Error::new(
67                        std::io::ErrorKind::ConnectionRefused,
68                        "Firecracker refused to establish a tunnel to the given guest port",
69                    ));
70                }
71            }
72            _ => {
73                return Err(std::io::Error::new(
74                    std::io::ErrorKind::InvalidInput,
75                    "Could not read Firecracker response",
76                ))
77            }
78        };
79
80        Ok(FuturesIo::new(stream))
81    }
82}
83
/// Hyper-compatible I/O object for a vsock connection.
///
/// The socket's file descriptor is held as a [`std::fs::File`] inside an [`Async`] handle, so
/// readiness notifications come from the async-io reactor.
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
#[derive(Debug)]
pub struct AsyncVsockIo(Async<std::fs::File>);
87
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
impl AsyncVsockIo {
    /// Opens a non-blocking, close-on-exec `AF_VSOCK` stream socket and connects it to `addr`.
    ///
    /// The future resolves once the connection attempt has finished: the socket is awaited until
    /// it reports writability, and only then is `SO_ERROR` read to tell success from failure.
    ///
    /// # Errors
    ///
    /// Returns the OS error reported by `socket`, `fcntl`, `connect` or `getsockopt`, the pending
    /// socket error when the connection attempt failed, or any error raised while registering the
    /// descriptor with the reactor.
    async fn connect(addr: VsockAddr) -> Result<Self, std::io::Error> {
        // SAFETY: `socket` takes no pointer arguments; its return value is checked below.
        let raw_fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM, 0) };
        if raw_fd < 0 {
            return Err(std::io::Error::last_os_error());
        }

        // SAFETY: `raw_fd` is a freshly created, open descriptor that nothing else owns.
        // Giving it an owner immediately means every early return below closes it, and the
        // close only happens after the error value has already been built from `errno`.
        let socket = unsafe { std::fs::File::from_raw_fd(raw_fd) };

        // Close-on-exec is a descriptor flag and must be set through F_SETFD; passing
        // O_CLOEXEC to F_SETFL has no effect.
        // SAFETY: `socket` keeps the descriptor open; this fcntl command takes an integer.
        if unsafe { libc::fcntl(socket.as_raw_fd(), libc::F_SETFD, libc::FD_CLOEXEC) } < 0 {
            return Err(std::io::Error::last_os_error());
        }

        // The socket has to be non-blocking before `connect` so the call returns EINPROGRESS
        // instead of blocking the executor thread.
        // SAFETY: as above.
        if unsafe { libc::fcntl(socket.as_raw_fd(), libc::F_SETFL, libc::O_NONBLOCK) } < 0 {
            return Err(std::io::Error::last_os_error());
        }

        // NOTE(review): the pointer cast assumes `VsockAddr` has the layout of
        // `libc::sockaddr_vm` — confirm against the `vsock` crate.
        // SAFETY: `addr` is alive for the whole call and the length passed is that of
        // `sockaddr_vm`.
        let rc = unsafe {
            libc::connect(
                socket.as_raw_fd(),
                &addr as *const _ as *const libc::sockaddr,
                std::mem::size_of::<libc::sockaddr_vm>() as libc::socklen_t,
            )
        };
        if rc < 0 {
            let err = std::io::Error::last_os_error();
            // EINPROGRESS is the expected answer for a non-blocking connect.
            if err.raw_os_error() != Some(libc::EINPROGRESS) {
                return Err(err);
            }
        }

        let async_fd = Async::new(socket)?;

        // A connecting socket becomes writable when the attempt completes, successfully or
        // not. Reading SO_ERROR before that point would report "no error" for a connection
        // that is still in progress.
        async_fd.writable().await?;

        let mut sock_err: libc::c_int = 0;
        let mut sock_err_len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
        // SAFETY: both pointers refer to live locals and `sock_err_len` holds the size of
        // `sock_err`.
        let rc = unsafe {
            libc::getsockopt(
                async_fd.get_ref().as_raw_fd(),
                libc::SOL_SOCKET,
                libc::SO_ERROR,
                &mut sock_err as *mut libc::c_int as *mut libc::c_void,
                &mut sock_err_len,
            )
        };
        if rc < 0 {
            return Err(std::io::Error::last_os_error());
        }
        if sock_err != 0 {
            return Err(std::io::Error::from_raw_os_error(sock_err));
        }

        Ok(AsyncVsockIo(async_fd))
    }
}
163
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
impl hyper::rt::Write for AsyncVsockIo {
    /// Writes as much of `buf` as the socket accepts in a single `write` call.
    ///
    /// Every attempt first polls the reactor for writability. An attempt that fails with
    /// `Interrupted` or `WouldBlock` goes back to polling; any other error is returned.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
        loop {
            match self.0.poll_writable(cx) {
                Poll::Ready(Ok(())) => {}
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                Poll::Pending => return Poll::Pending,
            }

            match self.0.get_ref().write(buf) {
                Ok(written) => return Poll::Ready(Ok(written)),
                Err(err)
                    if matches!(
                        err.kind(),
                        std::io::ErrorKind::Interrupted | std::io::ErrorKind::WouldBlock
                    ) =>
                {
                    continue
                }
                Err(err) => return Poll::Ready(Err(err)),
            }
        }
    }

    /// Nothing is buffered by this type, so flushing completes immediately.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Shuts down the write half of the socket with `shutdown(SHUT_WR)`.
    ///
    /// A socket that is no longer connected counts as already shut down; every other failure
    /// is returned to the caller.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        // SAFETY: the descriptor is owned by `self.0` and therefore open for the whole call;
        // `shutdown` takes no pointer arguments.
        let rc = unsafe { libc::shutdown(self.0.get_ref().as_raw_fd(), libc::SHUT_WR) };
        if rc == 0 {
            return Poll::Ready(Ok(()));
        }

        let err = std::io::Error::last_os_error();
        if err.kind() == std::io::ErrorKind::NotConnected {
            Poll::Ready(Ok(()))
        } else {
            Poll::Ready(Err(err))
        }
    }
}
196
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
impl hyper::rt::Read for AsyncVsockIo {
    /// Reads from the socket into the unfilled part of `buf` with a single `read` call and
    /// advances the cursor by the number of bytes received.
    ///
    /// Every attempt first polls the reactor for readability. An attempt that fails with
    /// `Interrupted` or `WouldBlock` goes back to polling; any other error is returned.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: hyper::rt::ReadBufCursor<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let fd = self.0.get_ref().as_raw_fd();

        loop {
            match self.0.poll_readable(cx) {
                Poll::Ready(Ok(())) => {}
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                Poll::Pending => return Poll::Pending,
            }

            // The destination may be uninitialised, so it is handed to the kernel as a raw
            // pointer rather than being turned into a `&mut [u8]` for `std::io::Read`.
            // SAFETY: the slice is only used for its pointer and length; nothing here
            // de-initialises bytes the cursor already considers initialised.
            let spare = unsafe { buf.as_mut() };
            // SAFETY: the pointer/length pair describes writable memory borrowed from `buf`,
            // and `fd` stays open for as long as `self` is alive.
            let received = unsafe { libc::read(fd, spare.as_mut_ptr().cast::<libc::c_void>(), spare.len()) };

            // A negative return value signals failure and does not convert to `usize`.
            let Ok(amount) = usize::try_from(received) else {
                let err = std::io::Error::last_os_error();
                match err.kind() {
                    std::io::ErrorKind::Interrupted | std::io::ErrorKind::WouldBlock => continue,
                    _ => return Poll::Ready(Err(err)),
                }
            };

            // SAFETY: the kernel has just written `amount` bytes at the start of the unfilled
            // region.
            unsafe {
                buf.advance(amount);
            }
            return Poll::Ready(Ok(()));
        }
    }
}