1use futures_util::future::BoxFuture;
2use hex::FromHex;
3use hyper::{
4 client::connect::{Connected, Connection},
5 service::Service,
6 Body, Client, Uri,
7};
8use pin_project::pin_project;
9use std::{
10 io,
11 path::{Path, PathBuf},
12 pin::Pin,
13 task::{Context, Poll},
14};
15
/// A Unix domain socket stream used as the transport returned by `UnixConnector`.
///
/// This is a thin wrapper around `tokio::net::UnixStream`. The wrapper exists so that
/// this crate can implement hyper's `Connection` trait for it.
#[pin_project]
#[derive(Debug)]
pub struct UnixStream {
    // `#[pin]` makes `project()` yield a `Pin<&mut tokio::net::UnixStream>`, which the
    // AsyncRead/AsyncWrite impls need in order to delegate their `poll_*` calls.
    #[pin]
    unix_stream: tokio::net::UnixStream,
}
22
23impl UnixStream {
24 async fn connect<P>(path: P) -> std::io::Result<Self>
25 where
26 P: AsRef<Path>,
27 {
28 let unix_stream = tokio::net::UnixStream::connect(path).await?;
29 Ok(Self { unix_stream })
30 }
31}
32
33impl tokio::io::AsyncWrite for UnixStream {
34 fn poll_write(
35 self: Pin<&mut Self>,
36 cx: &mut Context<'_>,
37 buf: &[u8],
38 ) -> Poll<Result<usize, io::Error>> {
39 self.project().unix_stream.poll_write(cx, buf)
40 }
41 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
42 self.project().unix_stream.poll_flush(cx)
43 }
44 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
45 self.project().unix_stream.poll_shutdown(cx)
46 }
47}
48
49impl tokio::io::AsyncRead for UnixStream {
50 fn poll_read(
51 self: Pin<&mut Self>,
52 cx: &mut Context<'_>,
53 buf: &mut [u8],
54 ) -> Poll<io::Result<usize>> {
55 self.project().unix_stream.poll_read(cx, buf)
56 }
57}
58
/// A stateless hyper connector that opens Unix domain socket connections.
///
/// The socket path is read from the request URI by `parse_socket_path`.
// No manual `impl Unpin` is needed: a field-less unit struct is automatically `Unpin`,
// so the previous explicit impl was redundant.
#[derive(Clone, Copy, Debug, Default)]
pub struct UnixConnector;
78
79impl Service<Uri> for UnixConnector {
80 type Response = UnixStream;
81 type Error = std::io::Error;
82 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
83 fn call(&mut self, req: Uri) -> Self::Future {
84 let fut = async move {
85 let path = parse_socket_path(req)?;
86 UnixStream::connect(path).await
87 };
88
89 Box::pin(fut)
90 }
91 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
92 Poll::Ready(Ok(()))
93 }
94}
95
/// Marks `UnixStream` as a hyper client connection.
impl Connection for UnixStream {
    // Reports default connection metadata; no extra information (such as a
    // negotiated protocol) is attached.
    fn connected(&self) -> Connected {
        Connected::new()
    }
}
101
102fn parse_socket_path(uri: Uri) -> Result<std::path::PathBuf, io::Error> {
103 if uri.scheme_str() != Some("unix") {
104 return Err(io::Error::new(
105 io::ErrorKind::InvalidInput,
106 "invalid URL, scheme must be unix",
107 ));
108 }
109
110 if let Some(host) = uri.host() {
111 let bytes = Vec::from_hex(host).map_err(|_| {
112 io::Error::new(
113 io::ErrorKind::InvalidInput,
114 "invalid URL, host must be a hex-encoded path",
115 )
116 })?;
117
118 Ok(PathBuf::from(String::from_utf8_lossy(&bytes).into_owned()))
119 } else {
120 Err(io::Error::new(
121 io::ErrorKind::InvalidInput,
122 "invalid URL, host must be present",
123 ))
124 }
125}
126
127pub trait UnixClientExt {
130 fn unix() -> Client<UnixConnector, Body> {
140 Client::builder().build(UnixConnector)
141 }
142}
143
144impl UnixClientExt for Client<UnixConnector> {}