hyperlocal_with_windows/
client.rs

1use hex::FromHex;
2use hyper::{body::Body, rt::ReadBufCursor, Uri};
3use hyper_util::{
4    client::legacy::{
5        connect::{Connected, Connection},
6        Client,
7    },
8    rt::{TokioExecutor, TokioIo},
9};
10use pin_project_lite::pin_project;
11use std::{
12    future::Future,
13    io,
14    io::Error,
15    path::{Path, PathBuf},
16    pin::Pin,
17    task::{Context, Poll},
18};
19use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
20use tower_service::Service;
21
22pin_project! {
23    /// Wrapper around [`tokio::net::UnixStream`].
24    #[derive(Debug)]
25    pub struct UnixStream {
26        #[pin]
27        unix_stream: tokio::net::UnixStream,
28    }
29}
30
31impl UnixStream {
32    async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
33        let unix_stream = tokio::net::UnixStream::connect(path).await?;
34        Ok(Self { unix_stream })
35    }
36}
37
38impl AsyncWrite for UnixStream {
39    fn poll_write(
40        self: Pin<&mut Self>,
41        cx: &mut Context<'_>,
42        buf: &[u8],
43    ) -> Poll<Result<usize, io::Error>> {
44        self.project().unix_stream.poll_write(cx, buf)
45    }
46
47    fn poll_flush(
48        self: Pin<&mut Self>,
49        cx: &mut Context<'_>,
50    ) -> Poll<Result<(), io::Error>> {
51        self.project().unix_stream.poll_flush(cx)
52    }
53
54    fn poll_shutdown(
55        self: Pin<&mut Self>,
56        cx: &mut Context<'_>,
57    ) -> Poll<Result<(), io::Error>> {
58        self.project().unix_stream.poll_shutdown(cx)
59    }
60
61    fn poll_write_vectored(
62        self: Pin<&mut Self>,
63        cx: &mut Context<'_>,
64        bufs: &[io::IoSlice<'_>],
65    ) -> Poll<Result<usize, Error>> {
66        self.project().unix_stream.poll_write_vectored(cx, bufs)
67    }
68
69    fn is_write_vectored(&self) -> bool {
70        self.unix_stream.is_write_vectored()
71    }
72}
73
74impl hyper::rt::Write for UnixStream {
75    fn poll_write(
76        self: Pin<&mut Self>,
77        cx: &mut Context<'_>,
78        buf: &[u8],
79    ) -> Poll<Result<usize, Error>> {
80        self.project().unix_stream.poll_write(cx, buf)
81    }
82
83    fn poll_flush(
84        self: Pin<&mut Self>,
85        cx: &mut Context<'_>,
86    ) -> Poll<Result<(), Error>> {
87        self.project().unix_stream.poll_flush(cx)
88    }
89
90    fn poll_shutdown(
91        self: Pin<&mut Self>,
92        cx: &mut Context<'_>,
93    ) -> Poll<Result<(), Error>> {
94        self.project().unix_stream.poll_shutdown(cx)
95    }
96}
97
98impl AsyncRead for UnixStream {
99    fn poll_read(
100        self: Pin<&mut Self>,
101        cx: &mut Context<'_>,
102        buf: &mut ReadBuf<'_>,
103    ) -> Poll<io::Result<()>> {
104        self.project().unix_stream.poll_read(cx, buf)
105    }
106}
107
108impl hyper::rt::Read for UnixStream {
109    fn poll_read(
110        self: Pin<&mut Self>,
111        cx: &mut Context<'_>,
112        buf: ReadBufCursor<'_>,
113    ) -> Poll<Result<(), Error>> {
114        let mut t = TokioIo::new(self.project().unix_stream);
115        Pin::new(&mut t).poll_read(cx, buf)
116    }
117}
118
/// A connector that lets a hyper-util legacy [`Client`] speak HTTP over a Unix
/// domain socket.
///
/// It resolves `unix://` URIs whose host component is the hex-encoded socket
/// path.
///
/// # Example
/// ```
/// use http_body_util::Full;
/// use hyper::body::Bytes;
/// use hyper_util::{client::legacy::Client, rt::TokioExecutor};
/// use hyperlocal_with_windows::UnixConnector;
///
/// let connector = UnixConnector;
/// let client: Client<UnixConnector, Full<Bytes>> =
///     Client::builder(TokioExecutor::new()).build(connector);
/// ```
///
/// # Note
/// If you don't need access to the low-level [`Client`] builder interface,
/// consider using the [`UnixClientExt`] trait instead.
// No explicit `Unpin` impl: a field-less unit struct is `Unpin` automatically.
#[derive(Clone, Copy, Debug, Default)]
pub struct UnixConnector;
141
142impl Service<Uri> for UnixConnector {
143    type Response = UnixStream;
144    type Error = io::Error;
145    #[allow(clippy::type_complexity)]
146    type Future =
147        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
148
149    fn call(
150        &mut self,
151        req: Uri,
152    ) -> Self::Future {
153        let fut = async move {
154            let path = parse_socket_path(&req)?;
155            UnixStream::connect(path).await
156        };
157
158        Box::pin(fut)
159    }
160
161    fn poll_ready(
162        &mut self,
163        _cx: &mut Context<'_>,
164    ) -> Poll<Result<(), Self::Error>> {
165        Poll::Ready(Ok(()))
166    }
167}
168
169impl Connection for UnixStream {
170    fn connected(&self) -> Connected {
171        Connected::new()
172    }
173}
174
175fn parse_socket_path(uri: &Uri) -> Result<PathBuf, io::Error> {
176    if uri.scheme_str() != Some("unix") {
177        return Err(io::Error::new(
178            io::ErrorKind::InvalidInput,
179            "invalid URL, scheme must be unix",
180        ));
181    }
182
183    if let Some(host) = uri.host() {
184        let bytes = Vec::from_hex(host).map_err(|_| {
185            io::Error::new(
186                io::ErrorKind::InvalidInput,
187                "invalid URL, host must be a hex-encoded path",
188            )
189        })?;
190
191        Ok(PathBuf::from(String::from_utf8_lossy(&bytes).into_owned()))
192    } else {
193        Err(io::Error::new(
194            io::ErrorKind::InvalidInput,
195            "invalid URL, host must be present",
196        ))
197    }
198}
199
200/// Extension trait for constructing a hyper HTTP client over a Unix domain
201/// socket.
202pub trait UnixClientExt<B: Body + Send> {
203    /// Construct a client which speaks HTTP over a Unix domain socket
204    ///
205    /// # Example
206    /// ```
207    /// use http_body_util::Full;
208    /// use hyper::body::Bytes;
209    /// use hyper_util::client::legacy::Client;
210    /// use hyperlocal_with_windows::{UnixClientExt, UnixConnector};
211    ///
212    /// let client: Client<UnixConnector, Full<Bytes>> = Client::unix();
213    /// ```
214    #[must_use]
215    fn unix() -> Client<UnixConnector, B>
216    where
217        B::Data: Send,
218    {
219        Client::builder(TokioExecutor::new()).build(UnixConnector)
220    }
221}
222
223impl<B: Body + Send> UnixClientExt<B> for Client<UnixConnector, B> {}