hyperlocal_next/
client.rs1use hex::FromHex;
2use hyper::{body::Body, rt::ReadBufCursor, Uri};
3use hyper_util::{
4 client::legacy::{
5 connect::{Connected, Connection},
6 Client,
7 },
8 rt::{TokioExecutor, TokioIo},
9};
10use pin_project_lite::pin_project;
11use std::{
12 future::Future,
13 io,
14 io::Error,
15 path::{Path, PathBuf},
16 pin::Pin,
17 task::{Context, Poll},
18};
19use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
20use tower_service::Service;
21
22pin_project! {
23 #[derive(Debug)]
24 pub struct UnixStream {
25 #[pin]
26 unix_stream: tokio::net::UnixStream,
27 }
28}
29
30impl UnixStream {
31 async fn connect(path: impl AsRef<Path>) -> io::Result<Self> {
32 let unix_stream = tokio::net::UnixStream::connect(path).await?;
33 Ok(Self { unix_stream })
34 }
35}
36
37impl AsyncWrite for UnixStream {
38 fn poll_write(
39 self: Pin<&mut Self>,
40 cx: &mut Context<'_>,
41 buf: &[u8],
42 ) -> Poll<Result<usize, io::Error>> {
43 self.project().unix_stream.poll_write(cx, buf)
44 }
45
46 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
47 self.project().unix_stream.poll_flush(cx)
48 }
49
50 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
51 self.project().unix_stream.poll_shutdown(cx)
52 }
53}
54
55impl hyper::rt::Write for UnixStream {
56 fn poll_write(
57 self: Pin<&mut Self>,
58 cx: &mut Context<'_>,
59 buf: &[u8],
60 ) -> Poll<Result<usize, Error>> {
61 self.project().unix_stream.poll_write(cx, buf)
62 }
63
64 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
65 self.project().unix_stream.poll_flush(cx)
66 }
67
68 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
69 self.project().unix_stream.poll_shutdown(cx)
70 }
71}
72
73impl AsyncRead for UnixStream {
74 fn poll_read(
75 self: Pin<&mut Self>,
76 cx: &mut Context<'_>,
77 buf: &mut ReadBuf<'_>,
78 ) -> Poll<io::Result<()>> {
79 self.project().unix_stream.poll_read(cx, buf)
80 }
81}
82
83impl hyper::rt::Read for UnixStream {
84 fn poll_read(
85 self: Pin<&mut Self>,
86 cx: &mut Context<'_>,
87 buf: ReadBufCursor<'_>,
88 ) -> Poll<Result<(), Error>> {
89 let mut t = TokioIo::new(self.project().unix_stream);
90 Pin::new(&mut t).poll_read(cx, buf)
91 }
92}
93
/// A stateless hyper connector that opens Unix domain socket connections.
///
/// The target socket is taken from the request URI: it must use the `unix`
/// scheme and carry the socket path hex-encoded in its host component.
#[derive(Debug, Default, Clone, Copy)]
pub struct UnixConnector;

// The connector holds no data and would be `Unpin` automatically; the explicit
// impl is kept so the public API stays exactly as before.
impl Unpin for UnixConnector {}
116
117impl Service<Uri> for UnixConnector {
118 type Response = UnixStream;
119 type Error = io::Error;
120 #[allow(clippy::type_complexity)]
121 type Future =
122 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;
123
124 fn call(&mut self, req: Uri) -> Self::Future {
125 let fut = async move {
126 let path = parse_socket_path(&req)?;
127 UnixStream::connect(path).await
128 };
129
130 Box::pin(fut)
131 }
132
133 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
134 Poll::Ready(Ok(()))
135 }
136}
137
138impl Connection for UnixStream {
139 fn connected(&self) -> Connected {
140 Connected::new()
141 }
142}
143
144fn parse_socket_path(uri: &Uri) -> Result<PathBuf, io::Error> {
145 if uri.scheme_str() != Some("unix") {
146 return Err(io::Error::new(
147 io::ErrorKind::InvalidInput,
148 "invalid URL, scheme must be unix",
149 ));
150 }
151
152 if let Some(host) = uri.host() {
153 let bytes = Vec::from_hex(host).map_err(|_| {
154 io::Error::new(
155 io::ErrorKind::InvalidInput,
156 "invalid URL, host must be a hex-encoded path",
157 )
158 })?;
159
160 Ok(PathBuf::from(String::from_utf8_lossy(&bytes).into_owned()))
161 } else {
162 Err(io::Error::new(
163 io::ErrorKind::InvalidInput,
164 "invalid URL, host must be present",
165 ))
166 }
167}
168
169pub trait UnixClientExt<B: Body + Send> {
172 #[must_use]
184 fn unix() -> Client<UnixConnector, B>
185 where
186 B::Data: Send,
187 {
188 Client::builder(TokioExecutor::new()).build(UnixConnector)
189 }
190}
191
192impl<B: Body + Send> UnixClientExt<B> for Client<UnixConnector, B> {}