hyperlocal_next/
client.rs

1use hex::FromHex;
2use hyper::{body::Body, rt::ReadBufCursor, Uri};
3use hyper_util::{
4    client::legacy::{
5        connect::{Connected, Connection},
6        Client,
7    },
8    rt::{TokioExecutor, TokioIo},
9};
10use pin_project_lite::pin_project;
11use std::{
12    future::Future,
13    io,
14    io::Error,
15    path::{Path, PathBuf},
16    pin::Pin,
17    task::{Context, Poll},
18};
19use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
20use tower_service::Service;
21
22pin_project! {
23    #[derive(Debug)]
24    pub struct UnixStream {
25        #[pin]
26        unix_stream: tokio::net::UnixStream,
27    }
28}
29
30impl UnixStream {
31    async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
32        let unix_stream = tokio::net::UnixStream::connect(path).await?;
33        Ok(Self { unix_stream })
34    }
35}
36
37impl AsyncWrite for UnixStream {
38    fn poll_write(
39        self: Pin<&mut Self>,
40        cx: &mut Context<'_>,
41        buf: &[u8],
42    ) -> Poll<Result<usize, io::Error>> {
43        self.project().unix_stream.poll_write(cx, buf)
44    }
45
46    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
47        self.project().unix_stream.poll_flush(cx)
48    }
49
50    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
51        self.project().unix_stream.poll_shutdown(cx)
52    }
53}
54
55impl hyper::rt::Write for UnixStream {
56    fn poll_write(
57        self: Pin<&mut Self>,
58        cx: &mut Context<'_>,
59        buf: &[u8],
60    ) -> Poll<Result<usize, Error>> {
61        self.project().unix_stream.poll_write(cx, buf)
62    }
63
64    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
65        self.project().unix_stream.poll_flush(cx)
66    }
67
68    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
69        self.project().unix_stream.poll_shutdown(cx)
70    }
71}
72
73impl AsyncRead for UnixStream {
74    fn poll_read(
75        self: Pin<&mut Self>,
76        cx: &mut Context<'_>,
77        buf: &mut ReadBuf<'_>,
78    ) -> Poll<io::Result<()>> {
79        self.project().unix_stream.poll_read(cx, buf)
80    }
81}
82
83impl hyper::rt::Read for UnixStream {
84    fn poll_read(
85        self: Pin<&mut Self>,
86        cx: &mut Context<'_>,
87        buf: ReadBufCursor<'_>,
88    ) -> Poll<Result<(), Error>> {
89        let mut t = TokioIo::new(self.project().unix_stream);
90        Pin::new(&mut t).poll_read(cx, buf)
91    }
92}
93
/// A connector that lets a hyper-util legacy
/// [`Client`](hyper_util::client::legacy::Client) speak HTTP over a Unix domain socket.
///
/// The socket to connect to is taken from each request URI: the scheme must be `unix` and
/// the host component must be the hex-encoded socket path.
///
/// # Example
/// ```
/// use http_body_util::Full;
/// use hyper::body::Bytes;
/// use hyper_util::client::legacy::Client;
/// use hyper_util::rt::TokioExecutor;
/// use hyperlocal::UnixConnector;
///
/// let connector = UnixConnector;
/// let client: Client<UnixConnector, Full<Bytes>> = Client::builder(TokioExecutor::new()).build(connector);
/// ```
///
/// # Note
/// If you don't need access to the low-level client builder interface, consider using the
/// [`UnixClientExt`] trait instead.
// No explicit `impl Unpin` is needed: a field-less unit struct is `Unpin` automatically.
#[derive(Clone, Copy, Debug, Default)]
pub struct UnixConnector;
116
117impl Service<Uri> for UnixConnector {
118    type Response = UnixStream;
119    type Error = io::Error;
120    #[allow(clippy::type_complexity)]
121    type Future =
122        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
123
124    fn call(&mut self, req: Uri) -> Self::Future {
125        let fut = async move {
126            let path = parse_socket_path(&req)?;
127            UnixStream::connect(path).await
128        };
129
130        Box::pin(fut)
131    }
132
133    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
134        Poll::Ready(Ok(()))
135    }
136}
137
138impl Connection for UnixStream {
139    fn connected(&self) -> Connected {
140        Connected::new()
141    }
142}
143
144fn parse_socket_path(uri: &Uri) -> Result<PathBuf, io::Error> {
145    if uri.scheme_str() != Some("unix") {
146        return Err(io::Error::new(
147            io::ErrorKind::InvalidInput,
148            "invalid URL, scheme must be unix",
149        ));
150    }
151
152    if let Some(host) = uri.host() {
153        let bytes = Vec::from_hex(host).map_err(|_| {
154            io::Error::new(
155                io::ErrorKind::InvalidInput,
156                "invalid URL, host must be a hex-encoded path",
157            )
158        })?;
159
160        Ok(PathBuf::from(String::from_utf8_lossy(&bytes).into_owned()))
161    } else {
162        Err(io::Error::new(
163            io::ErrorKind::InvalidInput,
164            "invalid URL, host must be present",
165        ))
166    }
167}
168
169/// Extension trait for constructing a hyper HTTP client over a Unix domain
170/// socket.
171pub trait UnixClientExt<B: Body + Send> {
172    /// Construct a client which speaks HTTP over a Unix domain socket
173    ///
174    /// # Example
175    /// ```
176    /// use http_body_util::Full;
177    /// use hyper::body::Bytes;
178    /// use hyper_util::client::legacy::Client;
179    /// use hyperlocal::{UnixClientExt, UnixConnector};
180    ///
181    /// let client: Client<UnixConnector, Full<Bytes>> = Client::unix();
182    /// ```
183    #[must_use]
184    fn unix() -> Client<UnixConnector, B>
185    where
186        B::Data: Send,
187    {
188        Client::builder(TokioExecutor::new()).build(UnixConnector)
189    }
190}
191
192impl<B: Body + Send> UnixClientExt<B> for Client<UnixConnector, B> {}