hyper_client_sockets/
async_io.rs1#[cfg(any(feature = "unix", feature = "firecracker"))]
2use std::path::Path;
3#[cfg(feature = "vsock")]
4use std::{
5 io::{Read, Write},
6 os::fd::{AsRawFd, FromRawFd, IntoRawFd},
7 pin::Pin,
8 task::{Context, Poll},
9};
10
11#[cfg(any(feature = "unix", feature = "vsock", feature = "firecracker"))]
12use async_io::Async;
13#[cfg(feature = "firecracker")]
14use futures_lite::{io::BufReader, AsyncBufReadExt, AsyncWriteExt, StreamExt};
15#[cfg(any(feature = "unix", feature = "firecracker"))]
16use smol_hyper::rt::FuturesIo;
17#[cfg(feature = "vsock")]
18use vsock::VsockAddr;
19
20use crate::Backend;
21
/// Marker type selecting the [`Backend`] implementation that performs its socket I/O through
/// `async_io::Async`.
///
/// The type carries no state; every operation is exposed as an associated function of its
/// [`Backend`] implementation.
#[derive(Clone, Debug)]
pub struct AsyncIoBackend;
25
26impl Backend for AsyncIoBackend {
27 #[cfg(feature = "unix")]
28 #[cfg_attr(docsrs, doc(cfg(feature = "unix")))]
29 type UnixIo = FuturesIo<Async<std::os::unix::net::UnixStream>>;
30
31 #[cfg(feature = "vsock")]
32 #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
33 type VsockIo = AsyncVsockIo;
34
35 #[cfg(feature = "firecracker")]
36 #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))]
37 type FirecrackerIo = FuturesIo<Async<std::os::unix::net::UnixStream>>;
38
39 #[cfg(feature = "unix")]
40 #[cfg_attr(docsrs, doc(cfg(feature = "unix")))]
41 async fn connect_to_unix_socket(socket_path: &Path) -> Result<Self::UnixIo, std::io::Error> {
42 Ok(FuturesIo::new(
43 Async::<std::os::unix::net::UnixStream>::connect(socket_path).await?,
44 ))
45 }
46
47 #[cfg(feature = "vsock")]
48 #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
49 async fn connect_to_vsock_socket(addr: vsock::VsockAddr) -> Result<Self::VsockIo, std::io::Error> {
50 Ok(AsyncVsockIo::connect(addr).await?)
51 }
52
53 #[cfg(feature = "firecracker")]
54 #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))]
55 async fn connect_to_firecracker_socket(
56 host_socket_path: &Path,
57 guest_port: u32,
58 ) -> Result<Self::FirecrackerIo, std::io::Error> {
59 let mut stream = Async::<std::os::unix::net::UnixStream>::connect(host_socket_path).await?;
60 stream.write_all(format!("CONNECT {guest_port}\n").as_bytes()).await?;
61
62 let mut lines = BufReader::new(&mut stream).lines();
63 match lines.next().await {
64 Some(Ok(line)) => {
65 if !line.starts_with("OK") {
66 return Err(std::io::Error::new(
67 std::io::ErrorKind::ConnectionRefused,
68 "Firecracker refused to establish a tunnel to the given guest port",
69 ));
70 }
71 }
72 _ => {
73 return Err(std::io::Error::new(
74 std::io::ErrorKind::InvalidInput,
75 "Could not read Firecracker response",
76 ))
77 }
78 };
79
80 Ok(FuturesIo::new(stream))
81 }
82}
83
/// Connected vsock stream usable as a hyper I/O object.
///
/// Wraps the socket's file descriptor, held as a [`std::fs::File`], in an `async_io::Async`
/// handle so that reads and writes can be driven by readiness polling.
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
#[derive(Debug)]
pub struct AsyncVsockIo(Async<std::fs::File>);
87
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
impl AsyncVsockIo {
    /// Creates an `AF_VSOCK` stream socket, switches it to non-blocking close-on-exec mode,
    /// starts connecting it to `addr` and resolves once the connection attempt has finished.
    ///
    /// The socket is owned by a [`std::fs::File`] from the moment it is created, so every
    /// early return closes the descriptor automatically.
    ///
    /// # Errors
    ///
    /// Returns the OS error of the failing `socket`, `fcntl`, `connect` or `getsockopt` call,
    /// the pending socket error reported through `SO_ERROR`, or any error raised while
    /// wrapping the descriptor in `Async`.
    async fn connect(addr: VsockAddr) -> Result<Self, std::io::Error> {
        // SAFETY: `socket(2)` only receives integer constants; the result is checked below.
        let raw_socket = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM, 0) };
        if raw_socket < 0 {
            return Err(std::io::Error::last_os_error());
        }

        // SAFETY: `raw_socket` was just returned by a successful `socket(2)` call and nothing
        // else owns it, so `File` may take over closing it.
        let socket = unsafe { std::fs::File::from_raw_fd(raw_socket) };

        // Close-on-exec is a descriptor flag and must be set with `F_SETFD`; `F_SETFL` only
        // accepts file status flags and silently ignores `O_CLOEXEC`.
        // SAFETY: `socket` keeps the descriptor open for the duration of the call.
        let cloexec_rc = unsafe { libc::fcntl(socket.as_raw_fd(), libc::F_SETFD, libc::FD_CLOEXEC) };
        if cloexec_rc < 0 {
            // The error is captured before `socket` is dropped (and closed) on return.
            return Err(std::io::Error::last_os_error());
        }

        // SAFETY: `socket` keeps the descriptor open for the duration of the call.
        let nonblock_rc = unsafe { libc::fcntl(socket.as_raw_fd(), libc::F_SETFL, libc::O_NONBLOCK) };
        if nonblock_rc < 0 {
            return Err(std::io::Error::last_os_error());
        }

        // SAFETY: the descriptor is open and `addr` outlives the call.
        // NOTE(review): this relies on `VsockAddr` having the layout of `libc::sockaddr_vm`;
        // confirm against the `vsock` crate.
        let connect_rc = unsafe {
            libc::connect(
                socket.as_raw_fd(),
                &addr as *const _ as *const libc::sockaddr,
                size_of::<libc::sockaddr_vm>() as libc::socklen_t,
            )
        };
        if connect_rc < 0 {
            let err = std::io::Error::last_os_error();
            // `EINPROGRESS` means the non-blocking connect was started and completes later;
            // every other error is final.
            if err.raw_os_error() != Some(libc::EINPROGRESS) {
                return Err(err);
            }
        }

        let async_fd = Async::new(socket)?;

        loop {
            // Once the socket is writable, `SO_ERROR` tells whether the connect succeeded.
            let outcome = async_fd
                .write_with(|fd| {
                    let mut sock_err: libc::c_int = 0;
                    let mut sock_err_len: libc::socklen_t = size_of::<libc::c_int>() as libc::socklen_t;
                    // SAFETY: both out-pointers refer to live locals of the advertised size.
                    let rc = unsafe {
                        libc::getsockopt(
                            fd.as_raw_fd(),
                            libc::SOL_SOCKET,
                            libc::SO_ERROR,
                            &mut sock_err as *mut _ as *mut libc::c_void,
                            &mut sock_err_len as *mut libc::socklen_t,
                        )
                    };

                    if rc < 0 {
                        return Err(std::io::Error::last_os_error());
                    }

                    if sock_err == 0 {
                        Ok(())
                    } else {
                        Err(std::io::Error::from_raw_os_error(sock_err))
                    }
                })
                .await;

            match outcome {
                Ok(()) => {
                    // Hand the connected descriptor over to a fresh `Async` wrapper.
                    // SAFETY: `into_raw_fd` releases ownership of an open descriptor, which
                    // the new `File` takes over immediately.
                    return Ok(AsyncVsockIo(Async::new(unsafe {
                        std::fs::File::from_raw_fd(async_fd.into_inner()?.into_raw_fd())
                    })?));
                }
                Err(err)
                    if err.kind() == std::io::ErrorKind::WouldBlock
                        || err.kind() == std::io::ErrorKind::Interrupted =>
                {
                    continue
                }
                Err(err) => return Err(err),
            }
        }
    }
}
163
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
impl hyper::rt::Write for AsyncVsockIo {
    /// Waits for `poll_writable` to report readiness, then attempts one write of `buf`.
    ///
    /// A write that fails with `Interrupted` or `WouldBlock` goes back to waiting for
    /// readiness; any other outcome is returned to the caller.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
        loop {
            // Returns `Pending` when not yet writable and propagates readiness errors.
            std::task::ready!(self.0.poll_writable(cx))?;

            match self.0.get_ref().write(buf) {
                Err(err)
                    if matches!(
                        err.kind(),
                        std::io::ErrorKind::Interrupted | std::io::ErrorKind::WouldBlock
                    ) =>
                {
                    continue
                }
                outcome => return Poll::Ready(outcome),
            }
        }
    }

    /// Reports success immediately; no flush is performed.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Reports success immediately; no shutdown call is issued on the socket.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        Poll::Ready(Ok(()))
    }
}
196
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
impl hyper::rt::Read for AsyncVsockIo {
    /// Waits for `poll_readable` to report readiness, then attempts one read into the
    /// unfilled part of `buf` and advances the cursor by the number of bytes read.
    ///
    /// A read that fails with `Interrupted` or `WouldBlock` goes back to waiting for
    /// readiness; any other error is returned to the caller.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: hyper::rt::ReadBufCursor<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        // View the cursor's unfilled region as a plain byte slice so it can be handed to
        // `Read::read`.
        // SAFETY: the slice is only used as the destination of `read`.
        // NOTE(review): this forms `&mut [u8]` over possibly uninitialized memory; confirm
        // this is acceptable for the `Read` implementation in use.
        let dst: &mut [u8] = unsafe { &mut *(buf.as_mut() as *mut [std::mem::MaybeUninit<u8>] as *mut [u8]) };

        loop {
            // Returns `Pending` when not yet readable and propagates readiness errors.
            std::task::ready!(self.0.poll_readable(cx))?;

            match self.0.get_ref().read(dst) {
                Ok(filled) => {
                    // SAFETY: `read` reported that `filled` bytes were written to the start
                    // of the unfilled region.
                    unsafe { buf.advance(filled) };
                    return Poll::Ready(Ok(()));
                }
                Err(err)
                    if matches!(
                        err.kind(),
                        std::io::ErrorKind::Interrupted | std::io::ErrorKind::WouldBlock
                    ) =>
                {
                    continue
                }
                Err(err) => return Poll::Ready(Err(err)),
            }
        }
    }
}