hyper_unix_connector/
lib.rs

1//! Connect hyper servers and clients to Unix-domain sockets.
2//!
3//! Most of this crate's functionality is borrowed from [hyperlocal](https://github.com/softprops/hyperlocal).
4//! This crate supports async/await, while hyperlocal does not (yet).
5//!
6//! See [`UnixClient`] and [`UnixConnector`] for examples.
7
8use anyhow::{anyhow, Error};
9use core::{
10    pin::Pin,
11    task::{Context, Poll},
12};
13use hex::FromHex;
14use pin_project::pin_project;
15use std::borrow::Cow;
16use std::future::Future;
17use std::path::Path;
18use tokio::io::ReadBuf;
19
/// A Unix-domain-socket address that converts into hyper's [`hyper::Uri`] type.
///
/// The value can be turned into a [`hyper::Uri`] with `.into()` and then used
/// with any of the HTTP factory methods on hyper's `Client` interface, or
/// when creating requests.
///
/// ```no_run
/// extern crate hyper;
/// extern crate hyper_unix_connector;
///
/// let url: hyper::Uri = hyper_unix_connector::Uri::new(
///   "/path/to/socket", "/urlpath?key=value"
///  ).into();
///  let req = hyper::Request::get(url).body(()).unwrap();
/// ```
#[derive(Debug)]
pub struct Uri<'a> {
    /// The complete encoded URI: `unix://<hex-encoded socket path>:0<request path>`,
    /// where the request path carries its leading slash and any query string.
    encoded: Cow<'a, str>,
}
41
42impl<'a> Into<hyper::Uri> for Uri<'a> {
43    fn into(self) -> hyper::Uri {
44        self.encoded.as_ref().parse().unwrap()
45    }
46}
47
48impl<'a> Uri<'a> {
49    /// Productes a new `Uri` from path to domain socket and request path.
50    /// request path should include a leading slash
51    pub fn new<P>(socket: P, path: &'a str) -> Self
52    where
53        P: AsRef<Path>,
54    {
55        let host = hex::encode(socket.as_ref().to_string_lossy().as_bytes());
56        let host_str = format!("unix://{}:0{}", host, path);
57        Uri {
58            encoded: Cow::Owned(host_str),
59        }
60    }
61
62    // fixme: would like to just use hyper::Result and hyper::error::UriError here
63    // but UriError its not exposed for external use
64    fn socket_path(uri: &hyper::Uri) -> Option<String> {
65        uri.host()
66            .iter()
67            .filter_map(|host| {
68                Vec::from_hex(host)
69                    .ok()
70                    .map(|raw| String::from_utf8_lossy(&raw).into_owned())
71            })
72            .next()
73    }
74}
75
76/// Wrapper around [`tokio::net::UnixListener`] that works with [`hyper`] servers.
77///
78/// Useful for making [`hyper`] servers listen on Unix sockets. For the client side, see
79/// [`UnixClient`].
80///
81/// # Example
82/// ```rust
83/// # std::fs::remove_file("./my-unix-socket").unwrap_or_else(|_| ());
84/// # let mut rt = tokio::runtime::Runtime::new().unwrap();
85/// # rt.block_on(async {
86/// use hyper::service::{make_service_fn, service_fn};
87/// use hyper::{Body, Error, Response, Server};
88/// use hyper_unix_connector::UnixConnector;
89///
90/// let uc: UnixConnector = tokio::net::UnixListener::bind("./my-unix-socket")
91///     .unwrap()
92///     .into();
93/// Server::builder(uc).serve(make_service_fn(|_| {
94///     async move {
95///         Ok::<_, Error>(service_fn(|_| {
96///             async move { Ok::<_, Error>(Response::new(Body::from("Hello, World"))) }
97///         }))
98///     }
99/// }));
100/// # });
101/// # std::fs::remove_file("./my-unix-socket").unwrap_or_else(|_| ());
102/// ```
103#[derive(Debug)]
104pub struct UnixConnector(tokio::net::UnixListener);
105
106impl From<tokio::net::UnixListener> for UnixConnector {
107    fn from(u: tokio::net::UnixListener) -> Self {
108        UnixConnector(u)
109    }
110}
111
112impl Into<tokio::net::UnixListener> for UnixConnector {
113    fn into(self) -> tokio::net::UnixListener {
114        self.0
115    }
116}
117
118impl hyper::server::accept::Accept for UnixConnector {
119    type Conn = tokio::net::UnixStream;
120    type Error = Error;
121
122    fn poll_accept(
123        self: Pin<&mut Self>,
124        cx: &mut Context,
125    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
126        self.0
127            .poll_accept(cx)
128            .map_ok(|(stream, _addr)| stream)
129            .map_err(|e| e.into())
130            .map(Some)
131    }
132}
133
134/// Newtype for [`tokio::net::UnixStream`] so that it can work with hyper's `Client`.
135#[pin_project]
136#[derive(Debug)]
137pub struct UDS(#[pin] tokio::net::UnixStream);
138
139impl From<tokio::net::UnixStream> for UDS {
140    fn from(f: tokio::net::UnixStream) -> Self {
141        Self(f)
142    }
143}
144
145impl Into<tokio::net::UnixStream> for UDS {
146    fn into(self) -> tokio::net::UnixStream {
147        self.0
148    }
149}
150
// Generates one forwarding method for `UDS`.
//
// Invocation shape:
//   conn_impl_fn!(name |receiver: Type, arg: Type, ...| -> ReturnType ;;);
//
// The generated method projects the pinned receiver to the wrapped
// `tokio::net::UnixStream` and calls the method of the same name on it,
// passing the remaining arguments through unchanged.
macro_rules! conn_impl_fn {
    (
        $method: ident
        |$receiver: ident: $receiver_ty: ty, $($arg: ident: $arg_ty: ty),*|
        -> $output: ty ;;
    ) => {
        fn $method($receiver: $receiver_ty, $($arg: $arg_ty),*) -> $output {
            let inner: Pin<&mut tokio::net::UnixStream> = $receiver.project().0;
            inner.$method($($arg),*)
        }
    };
}
159
160impl tokio::io::AsyncRead for UDS {
161    conn_impl_fn!(poll_read |self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>| -> Poll<std::io::Result<()>> ;;);
162}
163
164impl tokio::io::AsyncWrite for UDS {
165    conn_impl_fn!(poll_write    |self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]| -> Poll<std::io::Result<usize>> ;;);
166    conn_impl_fn!(poll_flush    |self: Pin<&mut Self>, cx: &mut Context<'_>| -> Poll<std::io::Result<()>> ;;);
167    conn_impl_fn!(poll_shutdown |self: Pin<&mut Self>, cx: &mut Context<'_>| -> Poll<std::io::Result<()>> ;;);
168}
169
/// Connector that resolves a `unix://` address built by [`Uri`] into a
/// connected [`UDS`] stream.
///
/// Pass it to hyper's client builder to make [`hyper`] clients talk to
/// Unix-domain sockets. For the server side, see [`UnixConnector`].
///
/// # Example
/// ```rust
/// use hyper_unix_connector::{Uri, UnixClient};
/// use hyper::{Body, Client};
///
/// let client: Client<UnixClient, Body> = Client::builder().build(UnixClient);
/// let addr: hyper::Uri = Uri::new("./my_unix_socket", "/").into();
/// client.get(addr);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct UnixClient;
186
187impl hyper::service::Service<hyper::Uri> for UnixClient {
188    type Response = UDS;
189    type Error = Error;
190    type Future =
191        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
192
193    fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
194        Poll::Ready(Ok(()))
195    }
196
197    fn call(&mut self, dst: hyper::Uri) -> Self::Future {
198        Box::pin(async move {
199            match dst.scheme_str() {
200                Some("unix") => (),
201                _ => return Err(anyhow!("Invalid uri {:?}", dst)),
202            }
203
204            let path = match Uri::socket_path(&dst) {
205                Some(path) => path,
206
207                None => return Err(anyhow!("Invalid uri {:?}", dst)),
208            };
209
210            let st = tokio::net::UnixStream::connect(&path).await?;
211            Ok(st.into())
212        })
213    }
214}
215
216impl hyper::client::connect::Connection for UDS {
217    fn connected(&self) -> hyper::client::connect::Connected {
218        hyper::client::connect::Connected::new()
219    }
220}
221
#[cfg(test)]
mod test {
    use crate::{UnixClient, UnixConnector, Uri};
    use futures_util::stream::{StreamExt, TryStreamExt};
    use hyper::service::{make_service_fn, service_fn};
    use hyper::{Body, Client, Error, Response, Server};

    /// Starts a hyper server on a Unix socket, issues one GET through
    /// `UnixClient`, and checks that the response body round-trips.
    #[test]
    fn ping() -> Result<(), anyhow::Error> {
        const PING_RESPONSE: &str = "Hello, World";
        const TEST_UNIX_ADDR: &str = "my-unix-socket";

        // Best-effort removal of a socket file left behind by an earlier run;
        // a missing file is not an error.
        let _ = std::fs::remove_file(TEST_UNIX_ADDR);

        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()?;
        rt.block_on(async {
            // server
            let uc: UnixConnector = tokio::net::UnixListener::bind(TEST_UNIX_ADDR)
                .expect("bind unixlistener")
                .into();
            let srv_fut = Server::builder(uc).serve(make_service_fn(|_| async move {
                Ok::<_, Error>(service_fn(|_| async move {
                    Ok::<_, Error>(Response::new(Body::from(PING_RESPONSE)))
                }))
            }));

            // client
            let client: Client<UnixClient, Body> = Client::builder().build(UnixClient);

            tokio::spawn(async move {
                if let Err(e) = srv_fut.await {
                    // `panic!` needs a format string; passing the error value
                    // directly is deprecated and rejected in edition 2021.
                    panic!("server error: {}", e);
                }
            });

            let addr: hyper::Uri = Uri::new(TEST_UNIX_ADDR, "/").into();
            let body = client.get(addr).await.unwrap().into_body();
            let payload: Vec<u8> = body
                .map(|b| b.map(|v| v.to_vec()))
                .try_concat()
                .await
                .unwrap();
            let resp = String::from_utf8(payload).expect("body utf8");
            assert_eq!(resp, PING_RESPONSE);
        });

        // Best-effort cleanup of the socket file created by this test.
        let _ = std::fs::remove_file(TEST_UNIX_ADDR);

        Ok(())
    }
}