hyper_unix_connector/
lib.rs1use anyhow::{anyhow, Error};
9use core::{
10 pin::Pin,
11 task::{Context, Poll},
12};
13use hex::FromHex;
14use pin_project::pin_project;
15use std::borrow::Cow;
16use std::future::Future;
17use std::path::Path;
18use tokio::io::ReadBuf;
19
/// A `unix://` URI that addresses an HTTP path on a server listening on a Unix
/// domain socket.
///
/// Build one with [`Uri::new`] and convert it into a [`hyper::Uri`] to issue a
/// request through a hyper client built with [`UnixClient`].
#[derive(Debug)]
pub struct Uri<'a> {
    /// The complete encoded URI text, `unix://<hex-encoded socket path>:0<path>`.
    encoded: Cow<'a, str>,
}
41
42impl<'a> Into<hyper::Uri> for Uri<'a> {
43 fn into(self) -> hyper::Uri {
44 self.encoded.as_ref().parse().unwrap()
45 }
46}
47
48impl<'a> Uri<'a> {
49 pub fn new<P>(socket: P, path: &'a str) -> Self
52 where
53 P: AsRef<Path>,
54 {
55 let host = hex::encode(socket.as_ref().to_string_lossy().as_bytes());
56 let host_str = format!("unix://{}:0{}", host, path);
57 Uri {
58 encoded: Cow::Owned(host_str),
59 }
60 }
61
62 fn socket_path(uri: &hyper::Uri) -> Option<String> {
65 uri.host()
66 .iter()
67 .filter_map(|host| {
68 Vec::from_hex(host)
69 .ok()
70 .map(|raw| String::from_utf8_lossy(&raw).into_owned())
71 })
72 .next()
73 }
74}
75
76#[derive(Debug)]
104pub struct UnixConnector(tokio::net::UnixListener);
105
106impl From<tokio::net::UnixListener> for UnixConnector {
107 fn from(u: tokio::net::UnixListener) -> Self {
108 UnixConnector(u)
109 }
110}
111
112impl Into<tokio::net::UnixListener> for UnixConnector {
113 fn into(self) -> tokio::net::UnixListener {
114 self.0
115 }
116}
117
118impl hyper::server::accept::Accept for UnixConnector {
119 type Conn = tokio::net::UnixStream;
120 type Error = Error;
121
122 fn poll_accept(
123 self: Pin<&mut Self>,
124 cx: &mut Context,
125 ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
126 self.0
127 .poll_accept(cx)
128 .map_ok(|(stream, _addr)| stream)
129 .map_err(|e| e.into())
130 .map(Some)
131 }
132}
133
134#[pin_project]
136#[derive(Debug)]
137pub struct UDS(#[pin] tokio::net::UnixStream);
138
139impl From<tokio::net::UnixStream> for UDS {
140 fn from(f: tokio::net::UnixStream) -> Self {
141 Self(f)
142 }
143}
144
145impl Into<tokio::net::UnixStream> for UDS {
146 fn into(self) -> tokio::net::UnixStream {
147 self.0
148 }
149}
150
151macro_rules! conn_impl_fn {
152 ($fn: ident |$first_var: ident: $first_typ: ty, $($var: ident: $typ: ty),*| -> $ret: ty ;;) => {
153 fn $fn ($first_var: $first_typ, $( $var: $typ ),* ) -> $ret {
154 let ux: Pin<&mut tokio::net::UnixStream> = $first_var.project().0;
155 ux.$fn($($var),*)
156 }
157 };
158}
159
160impl tokio::io::AsyncRead for UDS {
161 conn_impl_fn!(poll_read |self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>| -> Poll<std::io::Result<()>> ;;);
162}
163
164impl tokio::io::AsyncWrite for UDS {
165 conn_impl_fn!(poll_write |self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]| -> Poll<std::io::Result<usize>> ;;);
166 conn_impl_fn!(poll_flush |self: Pin<&mut Self>, cx: &mut Context<'_>| -> Poll<std::io::Result<()>> ;;);
167 conn_impl_fn!(poll_shutdown |self: Pin<&mut Self>, cx: &mut Context<'_>| -> Poll<std::io::Result<()>> ;;);
168}
169
/// Connector for hyper's `Client` that dials Unix domain sockets.
///
/// It accepts only `unix://` URIs whose host is a hex-encoded socket path, as
/// built by [`Uri::new`].
#[derive(Clone, Copy, Debug)]
pub struct UnixClient;
186
187impl hyper::service::Service<hyper::Uri> for UnixClient {
188 type Response = UDS;
189 type Error = Error;
190 type Future =
191 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
192
193 fn poll_ready(&mut self, _cx: &mut Context) -> Poll<Result<(), Self::Error>> {
194 Poll::Ready(Ok(()))
195 }
196
197 fn call(&mut self, dst: hyper::Uri) -> Self::Future {
198 Box::pin(async move {
199 match dst.scheme_str() {
200 Some("unix") => (),
201 _ => return Err(anyhow!("Invalid uri {:?}", dst)),
202 }
203
204 let path = match Uri::socket_path(&dst) {
205 Some(path) => path,
206
207 None => return Err(anyhow!("Invalid uri {:?}", dst)),
208 };
209
210 let st = tokio::net::UnixStream::connect(&path).await?;
211 Ok(st.into())
212 })
213 }
214}
215
216impl hyper::client::connect::Connection for UDS {
217 fn connected(&self) -> hyper::client::connect::Connected {
218 hyper::client::connect::Connected::new()
219 }
220}
221
#[cfg(test)]
mod test {
    use crate::{UnixClient, UnixConnector, Uri};
    use futures_util::stream::{StreamExt, TryStreamExt};
    use hyper::service::{make_service_fn, service_fn};
    use hyper::{Body, Client, Error, Response, Server};

    /// Removes the socket file when dropped, so it is cleaned up even if the
    /// test panics part-way through.
    struct RemoveOnDrop(&'static str);

    impl Drop for RemoveOnDrop {
        fn drop(&mut self) {
            // Best effort: the file may already be gone.
            let _ = std::fs::remove_file(self.0);
        }
    }

    /// End-to-end round trip. The server serves a fixed body on a Unix socket via
    /// [`UnixConnector`]; a hyper client using [`UnixClient`] fetches it.
    #[test]
    fn ping() -> Result<(), anyhow::Error> {
        const PING_RESPONSE: &str = "Hello, World";
        const TEST_UNIX_ADDR: &str = "my-unix-socket";

        // A socket file left by an earlier run would make `bind` fail. A missing
        // file is the normal case, so the result is deliberately ignored.
        let _ = std::fs::remove_file(TEST_UNIX_ADDR);
        // Declared before `rt`, so the runtime is dropped first and the file is
        // removed afterwards.
        let _cleanup = RemoveOnDrop(TEST_UNIX_ADDR);

        let rt = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()?;
        rt.block_on(async {
            let uc: UnixConnector = tokio::net::UnixListener::bind(TEST_UNIX_ADDR)
                .expect("bind unixlistener")
                .into();
            let srv_fut = Server::builder(uc).serve(make_service_fn(|_| async move {
                Ok::<_, Error>(service_fn(|_| async move {
                    Ok::<_, Error>(Response::new(Body::from(PING_RESPONSE)))
                }))
            }));

            let client: Client<UnixClient, Body> = Client::builder().build(UnixClient);

            tokio::spawn(async move {
                if let Err(e) = srv_fut.await {
                    // A literal format string is required: `panic!(e)` is an
                    // error in Rust 2021 and a `non_fmt_panics` warning in 2018.
                    panic!("server error: {}", e);
                }
            });

            let addr: hyper::Uri = Uri::new(TEST_UNIX_ADDR, "/").into();
            let body = client.get(addr).await.unwrap().into_body();
            let payload: Vec<u8> = body
                .map(|b| b.map(|v| v.to_vec()))
                .try_concat()
                .await
                .unwrap();
            let resp = String::from_utf8(payload).expect("body utf8");
            assert_eq!(resp, PING_RESPONSE);
        });

        Ok(())
    }
}