hyperlocal_with_windows/
client_windows.rs1use hex::FromHex;
2use hyper::{body::Body, rt::ReadBufCursor, Uri};
3use hyper_util::{
4 client::legacy::{
5 connect::{Connected, Connection},
6 Client,
7 },
8 rt::{TokioExecutor, TokioIo},
9};
10use pin_project_lite::pin_project;
11use std::{
12 future::Future,
13 io,
14 io::Error,
15 path::{Path, PathBuf},
16 pin::Pin,
17 task::{Context, Poll},
18};
19use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
20use tower_service::Service;
21
22use crate::windows::convert_unix_stream_to_nb_tcp_stream;
23
24pin_project! {
25 #[derive(Debug)]
27 pub struct UnixStream {
28 #[pin]
29 unix_stream: tokio::net::TcpStream,
30 }
31}
32
33impl UnixStream {
34 async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
35 let path = path.as_ref().to_owned();
36 let unix_stream =
37 tokio::task::spawn_blocking(move || uds_windows::UnixStream::connect(path)).await??;
38 let unix_stream =
39 tokio::net::TcpStream::from_std(convert_unix_stream_to_nb_tcp_stream(unix_stream))?;
40 Ok(Self { unix_stream })
41 }
42}
43
44impl AsyncWrite for UnixStream {
45 fn poll_write(
46 self: Pin<&mut Self>,
47 cx: &mut Context<'_>,
48 buf: &[u8],
49 ) -> Poll<Result<usize, io::Error>> {
50 self.project().unix_stream.poll_write(cx, buf)
51 }
52
53 fn poll_flush(
54 self: Pin<&mut Self>,
55 cx: &mut Context<'_>,
56 ) -> Poll<Result<(), io::Error>> {
57 self.project().unix_stream.poll_flush(cx)
58 }
59
60 fn poll_shutdown(
61 self: Pin<&mut Self>,
62 cx: &mut Context<'_>,
63 ) -> Poll<Result<(), io::Error>> {
64 self.project().unix_stream.poll_shutdown(cx)
65 }
66
67 fn poll_write_vectored(
68 self: Pin<&mut Self>,
69 cx: &mut Context<'_>,
70 bufs: &[io::IoSlice<'_>],
71 ) -> Poll<Result<usize, Error>> {
72 self.project().unix_stream.poll_write_vectored(cx, bufs)
73 }
74
75 fn is_write_vectored(&self) -> bool {
76 self.unix_stream.is_write_vectored()
77 }
78}
79
80impl hyper::rt::Write for UnixStream {
81 fn poll_write(
82 self: Pin<&mut Self>,
83 cx: &mut Context<'_>,
84 buf: &[u8],
85 ) -> Poll<Result<usize, Error>> {
86 self.project().unix_stream.poll_write(cx, buf)
87 }
88
89 fn poll_flush(
90 self: Pin<&mut Self>,
91 cx: &mut Context<'_>,
92 ) -> Poll<Result<(), Error>> {
93 self.project().unix_stream.poll_flush(cx)
94 }
95
96 fn poll_shutdown(
97 self: Pin<&mut Self>,
98 cx: &mut Context<'_>,
99 ) -> Poll<Result<(), Error>> {
100 self.project().unix_stream.poll_shutdown(cx)
101 }
102}
103
104impl AsyncRead for UnixStream {
105 fn poll_read(
106 self: Pin<&mut Self>,
107 cx: &mut Context<'_>,
108 buf: &mut ReadBuf<'_>,
109 ) -> Poll<io::Result<()>> {
110 self.project().unix_stream.poll_read(cx, buf)
111 }
112}
113
114impl hyper::rt::Read for UnixStream {
115 fn poll_read(
116 self: Pin<&mut Self>,
117 cx: &mut Context<'_>,
118 buf: ReadBufCursor<'_>,
119 ) -> Poll<Result<(), Error>> {
120 let mut t = TokioIo::new(self.project().unix_stream);
121 Pin::new(&mut t).poll_read(cx, buf)
122 }
123}
124
/// Stateless connector that opens a [`UnixStream`] for `unix://` URIs.
///
/// It carries no data, so it is trivially copyable and default-constructible.
#[derive(Debug, Default, Clone, Copy)]
pub struct UnixConnector;
145
146impl Unpin for UnixConnector {}
147
148impl Service<Uri> for UnixConnector {
149 type Response = UnixStream;
150 type Error = io::Error;
151 #[allow(clippy::type_complexity)]
152 type Future =
153 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
154
155 fn call(
156 &mut self,
157 req: Uri,
158 ) -> Self::Future {
159 let fut = async move {
160 let path = parse_socket_path(&req)?;
161 UnixStream::connect(path).await
162 };
163
164 Box::pin(fut)
165 }
166
167 fn poll_ready(
168 &mut self,
169 _cx: &mut Context<'_>,
170 ) -> Poll<Result<(), Self::Error>> {
171 Poll::Ready(Ok(()))
172 }
173}
174
175impl Connection for UnixStream {
176 fn connected(&self) -> Connected {
177 Connected::new()
178 }
179}
180
181fn parse_socket_path(uri: &Uri) -> Result<PathBuf, io::Error> {
182 if uri.scheme_str() != Some("unix") {
183 return Err(io::Error::new(
184 io::ErrorKind::InvalidInput,
185 "invalid URL, scheme must be unix",
186 ));
187 }
188
189 if let Some(host) = uri.host() {
190 let bytes = Vec::from_hex(host).map_err(|_| {
191 io::Error::new(
192 io::ErrorKind::InvalidInput,
193 "invalid URL, host must be a hex-encoded path",
194 )
195 })?;
196
197 Ok(PathBuf::from(String::from_utf8_lossy(&bytes).into_owned()))
198 } else {
199 Err(io::Error::new(
200 io::ErrorKind::InvalidInput,
201 "invalid URL, host must be present",
202 ))
203 }
204}
205
206pub trait UnixClientExt<B: Body + Send> {
209 #[must_use]
221 fn unix() -> Client<UnixConnector, B>
222 where
223 B::Data: Send,
224 {
225 Client::builder(TokioExecutor::new()).build(UnixConnector)
226 }
227}
228
229impl<B: Body + Send> UnixClientExt<B> for Client<UnixConnector, B> {}