hyper_client_sockets/
tokio.rs1#[cfg(feature = "vsock")]
2use std::{
3 io::{Read, Write},
4 os::fd::{AsRawFd, FromRawFd, IntoRawFd, OwnedFd},
5 task::Poll,
6};
7#[cfg(feature = "vsock")]
8use std::{pin::Pin, task::Context};
9
10#[cfg(any(feature = "unix", feature = "firecracker"))]
11use hyper_util::rt::TokioIo;
12#[cfg(any(feature = "unix", feature = "firecracker"))]
13use std::path::Path;
14#[cfg(any(feature = "unix", feature = "firecracker"))]
15use tokio::net::UnixStream;
16
17#[cfg(feature = "firecracker")]
18use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
19
20#[cfg(feature = "vsock")]
21use tokio::io::unix::AsyncFd;
22
23use crate::Backend;
24
/// [`Backend`] implementation whose connections are driven by the tokio runtime.
///
/// This is a stateless marker type: all connection logic lives in its `Backend` impl.
#[derive(Clone, Debug)]
pub struct TokioBackend;
28
29impl Backend for TokioBackend {
30 #[cfg(feature = "unix")]
31 #[cfg_attr(docsrs, doc(cfg(feature = "unix")))]
32 type UnixIo = TokioIo<UnixStream>;
33
34 #[cfg(feature = "vsock")]
35 #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
36 type VsockIo = TokioVsockIo;
37
38 #[cfg(feature = "firecracker")]
39 #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))]
40 type FirecrackerIo = TokioIo<UnixStream>;
41
42 #[cfg(feature = "unix")]
43 #[cfg_attr(docsrs, doc(cfg(feature = "unix")))]
44 async fn connect_to_unix_socket(socket_path: &Path) -> Result<Self::UnixIo, std::io::Error> {
45 Ok(TokioIo::new(UnixStream::connect(socket_path).await?))
46 }
47
48 #[cfg(feature = "vsock")]
49 #[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
50 async fn connect_to_vsock_socket(addr: vsock::VsockAddr) -> Result<Self::VsockIo, std::io::Error> {
51 TokioVsockIo::connect(addr).await
52 }
53
54 #[cfg(feature = "firecracker")]
55 #[cfg_attr(docsrs, doc(cfg(feature = "firecracker")))]
56 async fn connect_to_firecracker_socket(
57 host_socket_path: &Path,
58 guest_port: u32,
59 ) -> Result<Self::FirecrackerIo, std::io::Error> {
60 let mut stream = UnixStream::connect(host_socket_path).await?;
61 stream.write_all(format!("CONNECT {guest_port}\n").as_bytes()).await?;
62
63 let mut lines = BufReader::new(&mut stream).lines();
64 match lines.next_line().await {
65 Ok(Some(line)) => {
66 if !line.starts_with("OK") {
67 return Err(std::io::Error::new(
68 std::io::ErrorKind::ConnectionRefused,
69 "Firecracker refused to establish a tunnel to the given guest port",
70 ));
71 }
72 }
73 _ => {
74 return Err(std::io::Error::new(
75 std::io::ErrorKind::InvalidInput,
76 "Could not read Firecracker response",
77 ))
78 }
79 };
80
81 Ok(TokioIo::new(stream))
82 }
83}
84
/// Connected vsock stream registered with the tokio reactor through [`AsyncFd`].
///
/// It is created by the vsock connector of the tokio backend and implements hyper's
/// `rt::Read` and `rt::Write` traits so hyper can use it directly as a connection.
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
pub struct TokioVsockIo(AsyncFd<vsock::VsockStream>);
90
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
impl TokioVsockIo {
    /// Opens an `AF_VSOCK` stream socket and connects it to `addr` without blocking the runtime.
    ///
    /// The steps are:
    /// 1. Make the socket non-blocking and close-on-exec.
    /// 2. Start a non-blocking `connect(2)`.
    /// 3. Wait on the tokio reactor until the socket is writable.
    /// 4. Read `SO_ERROR` to learn whether the connection succeeded.
    ///
    /// # Errors
    ///
    /// Returns the OS error from `socket(2)`, `fcntl(2)`, `connect(2)` (anything other than
    /// `EINPROGRESS`) or `getsockopt(2)`. Also returns the deferred connect error reported
    /// through `SO_ERROR`, or any error from registering the descriptor with the reactor.
    async fn connect(addr: vsock::VsockAddr) -> Result<Self, std::io::Error> {
        // SAFETY: plain FFI call with constant arguments; the return value is checked below.
        let raw = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM, 0) };
        if raw < 0 {
            return Err(std::io::Error::last_os_error());
        }
        // SAFETY: `raw` is a freshly created, valid descriptor that nothing else owns.
        // Taking ownership immediately means every early return below closes it via
        // `OwnedFd`'s `Drop`, and the error is always captured before that close runs.
        let socket = unsafe { OwnedFd::from_raw_fd(raw) };

        Self::set_nonblocking_and_cloexec(&socket)?;

        // NOTE(review): the pointer cast assumes `vsock::VsockAddr` is layout-compatible with
        // `libc::sockaddr_vm`. Confirm this when upgrading the `vsock` crate.
        // SAFETY: the address pointer and length describe `addr`, which outlives the call.
        let rc = unsafe {
            libc::connect(
                socket.as_raw_fd(),
                &addr as *const _ as *const libc::sockaddr,
                size_of::<libc::sockaddr_vm>() as libc::socklen_t,
            )
        };
        if rc < 0 {
            let err = std::io::Error::last_os_error();
            // A non-blocking connect reports EINPROGRESS; completion is awaited below.
            if err.raw_os_error() != Some(libc::EINPROGRESS) {
                return Err(err);
            }
        }

        let async_fd = AsyncFd::new(socket)?;

        loop {
            let mut guard = async_fd.writable().await?;
            let connection_check = guard.try_io(|fd| Self::take_socket_error(fd.get_ref()));

            match connection_check {
                Ok(Ok(())) => {
                    // Hand the connected descriptor over to a `VsockStream`, which the
                    // Read/Write impls operate on.
                    let raw = async_fd.into_inner().into_raw_fd();
                    // SAFETY: `raw` was just released by its `OwnedFd`, so the new
                    // `VsockStream` becomes its sole owner.
                    let stream = unsafe { vsock::VsockStream::from_raw_fd(raw) };
                    return Ok(TokioVsockIo(AsyncFd::new(stream)?));
                }
                Ok(Err(err)) => return Err(err),
                // Readiness was spurious; `try_io` cleared it, so wait again.
                Err(_would_block) => continue,
            }
        }
    }

    /// Puts `fd` into non-blocking mode and marks it close-on-exec.
    ///
    /// These are two different flag families. `O_NONBLOCK` is a file *status* flag, set
    /// with `F_SETFL`. Close-on-exec is a file *descriptor* flag, set with
    /// `F_SETFD`/`FD_CLOEXEC`; passing `O_CLOEXEC` to `F_SETFL` is silently ignored.
    fn set_nonblocking_and_cloexec(fd: &OwnedFd) -> Result<(), std::io::Error> {
        let raw = fd.as_raw_fd();

        // SAFETY: `raw` is a valid open descriptor borrowed from `fd` for this call.
        if unsafe { libc::fcntl(raw, libc::F_SETFL, libc::O_NONBLOCK) } < 0 {
            return Err(std::io::Error::last_os_error());
        }
        // SAFETY: as above.
        if unsafe { libc::fcntl(raw, libc::F_SETFD, libc::FD_CLOEXEC) } < 0 {
            return Err(std::io::Error::last_os_error());
        }

        Ok(())
    }

    /// Reads the socket's pending error (`SO_ERROR`) to find out how a non-blocking connect ended.
    ///
    /// Returns `Ok(())` when no error is pending. Otherwise returns the pending error as an
    /// `io::Error`, or the `getsockopt` failure itself.
    fn take_socket_error(fd: &OwnedFd) -> Result<(), std::io::Error> {
        let mut sock_err: libc::c_int = 0;
        let mut sock_err_len = size_of::<libc::c_int>() as libc::socklen_t;

        // SAFETY: both out-pointers reference live locals, and `sock_err_len` holds the size
        // of `sock_err`.
        let rc = unsafe {
            libc::getsockopt(
                fd.as_raw_fd(),
                libc::SOL_SOCKET,
                libc::SO_ERROR,
                &mut sock_err as *mut libc::c_int as *mut libc::c_void,
                &mut sock_err_len as *mut libc::socklen_t,
            )
        };
        if rc < 0 {
            return Err(std::io::Error::last_os_error());
        }

        match sock_err {
            0 => Ok(()),
            code => Err(std::io::Error::from_raw_os_error(code)),
        }
    }
}
163
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
impl hyper::rt::Write for TokioVsockIo {
    /// Writes as much of `buf` as the socket accepts right now.
    ///
    /// When the socket is not writable, the task is registered with the reactor and
    /// `Pending` is returned. Interrupted writes and stale readiness are retried.
    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, std::io::Error>> {
        loop {
            // `ready!` propagates `Pending`; `?` turns a readiness error into `Ready(Err(_))`.
            let mut writable = std::task::ready!(self.0.poll_write_ready(cx))?;

            let attempt = writable.try_io(|fd| fd.get_ref().write(buf));
            match attempt {
                // The socket reported EWOULDBLOCK, so `try_io` cleared readiness; poll again.
                Err(_would_block) => {}
                // A signal interrupted the syscall; simply retry.
                Ok(Err(err)) if err.kind() == std::io::ErrorKind::Interrupted => {}
                // Either bytes were written or a real error occurred; report it as-is.
                Ok(outcome) => return Poll::Ready(outcome),
            }
        }
    }

    /// Always succeeds immediately: `poll_write` passes data straight to the socket and
    /// keeps no buffer of its own, so there is nothing to flush.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Always succeeds immediately. The socket's write half is not shut down here.
    fn poll_shutdown(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
        Poll::Ready(Ok(()))
    }
}
192
#[cfg(feature = "vsock")]
#[cfg_attr(docsrs, doc(cfg(feature = "vsock")))]
impl hyper::rt::Read for TokioVsockIo {
    /// Reads available bytes from the vsock socket into the unfilled part of `buf`.
    ///
    /// `read(2)` writes straight into the cursor's possibly-uninitialized spare capacity
    /// through a raw pointer. No `&mut [u8]` is ever formed over uninitialized memory, and
    /// no `Read` impl is trusted not to look at it. A zero-byte read (end of stream) advances
    /// nothing, which signals EOF to hyper. Interrupted reads and stale readiness are retried.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        mut buf: hyper::rt::ReadBufCursor<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        loop {
            let mut guard = match self.0.poll_read_ready(cx) {
                Poll::Ready(Ok(guard)) => guard,
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                Poll::Pending => return Poll::Pending,
            };

            // SAFETY: the slice is only used to obtain a write target for the kernel; none of
            // its bytes are read here and none are de-initialized.
            let unfilled = unsafe { buf.as_mut() };
            let (dst, len) = (unfilled.as_mut_ptr().cast::<libc::c_void>(), unfilled.len());

            let attempt = guard.try_io(|fd| {
                // SAFETY: `dst`/`len` describe the writable unfilled region of `buf`, which is
                // not otherwise accessed until this call returns.
                let rc = unsafe { libc::read(fd.get_ref().as_raw_fd(), dst, len) };
                // A negative return means failure, with the cause left in errno.
                usize::try_from(rc).map_err(|_| std::io::Error::last_os_error())
            });

            match attempt {
                Ok(Ok(amount)) => {
                    // SAFETY: the kernel initialized exactly the first `amount` bytes of `dst`.
                    unsafe {
                        buf.advance(amount);
                    }
                    return Poll::Ready(Ok(()));
                }
                Ok(Err(ref err)) if err.kind() == std::io::ErrorKind::Interrupted => continue,
                Ok(Err(err)) => return Poll::Ready(Err(err)),
                // EWOULDBLOCK: `try_io` cleared readiness, so poll again.
                Err(_would_block) => continue,
            }
        }
    }
}