hyperlocal_with_windows/
client_windows.rs

1use hex::FromHex;
2use hyper::{body::Body, rt::ReadBufCursor, Uri};
3use hyper_util::{
4    client::legacy::{
5        connect::{Connected, Connection},
6        Client,
7    },
8    rt::{TokioExecutor, TokioIo},
9};
10use pin_project_lite::pin_project;
11use std::{
12    future::Future,
13    io,
14    io::Error,
15    path::{Path, PathBuf},
16    pin::Pin,
17    task::{Context, Poll},
18};
19use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
20use tower_service::Service;
21
22use crate::windows::convert_unix_stream_to_nb_tcp_stream;
23
pin_project! {
    /// A client-side Unix domain socket stream.
    ///
    /// The socket is opened with `uds_windows` and then converted into a
    /// non-blocking [`tokio::net::TcpStream`] so that Tokio can drive its I/O.
    /// That conversion happens in `UnixStream::connect`.
    /// NOTE(review): the `TcpStream` type is presumably used only as a generic
    /// socket handle for the Tokio reactor. TCP-specific methods on it are not
    /// meaningful for this socket — confirm before exposing them.
    #[derive(Debug)]
    pub struct UnixStream {
        // Pinned inner socket; every I/O trait impl on `UnixStream` forwards to it.
        #[pin]
        unix_stream: tokio::net::TcpStream,
    }
}
32
33impl UnixStream {
34    async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
35        let path = path.as_ref().to_owned();
36        let unix_stream =
37            tokio::task::spawn_blocking(move || uds_windows::UnixStream::connect(path)).await??;
38        let unix_stream =
39            tokio::net::TcpStream::from_std(convert_unix_stream_to_nb_tcp_stream(unix_stream))?;
40        Ok(Self { unix_stream })
41    }
42}
43
44impl AsyncWrite for UnixStream {
45    fn poll_write(
46        self: Pin<&mut Self>,
47        cx: &mut Context<'_>,
48        buf: &[u8],
49    ) -> Poll<Result<usize, io::Error>> {
50        self.project().unix_stream.poll_write(cx, buf)
51    }
52
53    fn poll_flush(
54        self: Pin<&mut Self>,
55        cx: &mut Context<'_>,
56    ) -> Poll<Result<(), io::Error>> {
57        self.project().unix_stream.poll_flush(cx)
58    }
59
60    fn poll_shutdown(
61        self: Pin<&mut Self>,
62        cx: &mut Context<'_>,
63    ) -> Poll<Result<(), io::Error>> {
64        self.project().unix_stream.poll_shutdown(cx)
65    }
66
67    fn poll_write_vectored(
68        self: Pin<&mut Self>,
69        cx: &mut Context<'_>,
70        bufs: &[io::IoSlice<'_>],
71    ) -> Poll<Result<usize, Error>> {
72        self.project().unix_stream.poll_write_vectored(cx, bufs)
73    }
74
75    fn is_write_vectored(&self) -> bool {
76        self.unix_stream.is_write_vectored()
77    }
78}
79
80impl hyper::rt::Write for UnixStream {
81    fn poll_write(
82        self: Pin<&mut Self>,
83        cx: &mut Context<'_>,
84        buf: &[u8],
85    ) -> Poll<Result<usize, Error>> {
86        self.project().unix_stream.poll_write(cx, buf)
87    }
88
89    fn poll_flush(
90        self: Pin<&mut Self>,
91        cx: &mut Context<'_>,
92    ) -> Poll<Result<(), Error>> {
93        self.project().unix_stream.poll_flush(cx)
94    }
95
96    fn poll_shutdown(
97        self: Pin<&mut Self>,
98        cx: &mut Context<'_>,
99    ) -> Poll<Result<(), Error>> {
100        self.project().unix_stream.poll_shutdown(cx)
101    }
102}
103
104impl AsyncRead for UnixStream {
105    fn poll_read(
106        self: Pin<&mut Self>,
107        cx: &mut Context<'_>,
108        buf: &mut ReadBuf<'_>,
109    ) -> Poll<io::Result<()>> {
110        self.project().unix_stream.poll_read(cx, buf)
111    }
112}
113
114impl hyper::rt::Read for UnixStream {
115    fn poll_read(
116        self: Pin<&mut Self>,
117        cx: &mut Context<'_>,
118        buf: ReadBufCursor<'_>,
119    ) -> Poll<Result<(), Error>> {
120        let mut t = TokioIo::new(self.project().unix_stream);
121        Pin::new(&mut t).poll_read(cx, buf)
122    }
123}
124
/// A connector that lets a [`Client`] speak HTTP over a Unix domain socket.
///
/// The target socket is taken from the request URI. Its scheme must be
/// `unix`, and its host must be the hex-encoded socket path.
///
/// # Example
/// ```
/// use http_body_util::Full;
/// use hyper::body::Bytes;
/// use hyper_util::{client::legacy::Client, rt::TokioExecutor};
/// use hyperlocal_with_windows::UnixConnector;
///
/// let connector = UnixConnector;
/// let client: Client<UnixConnector, Full<Bytes>> =
///     Client::builder(TokioExecutor::new()).build(connector);
/// ```
///
/// # Note
/// If you don't need access to the low-level [`Client`] builder interface,
/// consider using the [`UnixClientExt`] trait instead.
///
/// [`Client`]: hyper_util::client::legacy::Client
// A unit struct is automatically `Unpin`, so no explicit `impl Unpin` is needed.
#[derive(Clone, Copy, Debug, Default)]
pub struct UnixConnector;
147
148impl Service<Uri> for UnixConnector {
149    type Response = UnixStream;
150    type Error = io::Error;
151    #[allow(clippy::type_complexity)]
152    type Future =
153        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
154
155    fn call(
156        &mut self,
157        req: Uri,
158    ) -> Self::Future {
159        let fut = async move {
160            let path = parse_socket_path(&req)?;
161            UnixStream::connect(path).await
162        };
163
164        Box::pin(fut)
165    }
166
167    fn poll_ready(
168        &mut self,
169        _cx: &mut Context<'_>,
170    ) -> Poll<Result<(), Self::Error>> {
171        Poll::Ready(Ok(()))
172    }
173}
174
175impl Connection for UnixStream {
176    fn connected(&self) -> Connected {
177        Connected::new()
178    }
179}
180
181fn parse_socket_path(uri: &Uri) -> Result<PathBuf, io::Error> {
182    if uri.scheme_str() != Some("unix") {
183        return Err(io::Error::new(
184            io::ErrorKind::InvalidInput,
185            "invalid URL, scheme must be unix",
186        ));
187    }
188
189    if let Some(host) = uri.host() {
190        let bytes = Vec::from_hex(host).map_err(|_| {
191            io::Error::new(
192                io::ErrorKind::InvalidInput,
193                "invalid URL, host must be a hex-encoded path",
194            )
195        })?;
196
197        Ok(PathBuf::from(String::from_utf8_lossy(&bytes).into_owned()))
198    } else {
199        Err(io::Error::new(
200            io::ErrorKind::InvalidInput,
201            "invalid URL, host must be present",
202        ))
203    }
204}
205
206/// Extension trait for constructing a hyper HTTP client over a Unix domain
207/// socket.
208pub trait UnixClientExt<B: Body + Send> {
209    /// Construct a client which speaks HTTP over a Unix domain socket
210    ///
211    /// # Example
212    /// ```
213    /// use http_body_util::Full;
214    /// use hyper::body::Bytes;
215    /// use hyper_util::client::legacy::Client;
216    /// use hyperlocal_with_windows::{UnixClientExt, UnixConnector};
217    ///
218    /// let client: Client<UnixConnector, Full<Bytes>> = Client::unix();
219    /// ```
220    #[must_use]
221    fn unix() -> Client<UnixConnector, B>
222    where
223        B::Data: Send,
224    {
225        Client::builder(TokioExecutor::new()).build(UnixConnector)
226    }
227}
228
229impl<B: Body + Send> UnixClientExt<B> for Client<UnixConnector, B> {}