// Both TLS backends are imported under the same `HttpsConnector` name, so only one
// of them may be enabled; fail the build early with a readable message otherwise.
#[cfg(all(feature = "rustls", feature = "tls"))]
compile_error!(
    "`tls` and `rustls` features are mutually exclusive. You should enable only one of them"
);
36
37use async_socks5::AddrKind;
38use futures::{
39 ready,
40 task::{Context, Poll},
41};
42use http::uri::Scheme;
43use hyper::{service::Service, Uri};
44#[cfg(feature = "rustls")]
45use hyper_rustls::HttpsConnector;
46#[cfg(feature = "tls")]
47use hyper_tls::HttpsConnector;
48use std::{future::Future, io, pin::Pin};
49use tokio::io::{AsyncRead, AsyncWrite, BufStream};
50
51pub use async_socks5::Auth;
52
53#[cfg(feature = "tls")]
54pub use hyper_tls::native_tls::Error as TlsError;
55
56#[derive(Debug, thiserror::Error)]
57pub enum Error {
58 #[error("{0}")]
59 Socks(
60 #[from]
61 #[source]
62 async_socks5::Error,
63 ),
64 #[error("{0}")]
65 Io(
66 #[from]
67 #[source]
68 io::Error,
69 ),
70 #[error("{0}")]
71 Connector(
72 #[from]
73 #[source]
74 BoxedError,
75 ),
76 #[error("Missing host")]
77 MissingHost,
78}
79
80pub type SocksFuture<R> = Pin<Box<dyn Future<Output = Result<R, Error>> + Send>>;
84
85pub type BoxedError = Box<dyn std::error::Error + Send + Sync>;
86
87#[derive(Debug, Clone, PartialEq, Eq, Hash)]
89pub struct SocksConnector<C> {
90 pub proxy_addr: Uri,
91 pub auth: Option<Auth>,
92 pub connector: C,
93}
94
95impl<C> SocksConnector<C> {
96 #[cfg(feature = "tls")]
98 pub fn with_tls(self) -> Result<HttpsConnector<Self>, TlsError> {
99 let args = (self, hyper_tls::native_tls::TlsConnector::new()?.into());
100 Ok(HttpsConnector::from(args))
101 }
102
103 #[cfg(feature = "rustls")]
105 pub fn with_tls(self) -> Result<HttpsConnector<Self>, io::Error> {
106 use rusttls::ClientConfig;
107 use std::sync::Arc;
108
109 let mut config = ClientConfig::new();
110 config.root_store = match rustls_native_certs::load_native_certs() {
111 Ok(store) => store,
112 Err((_, err)) => return Err(err),
113 };
114
115 let config = Arc::new(config);
116
117 let args = (self, config);
118 Ok(HttpsConnector::from(args))
119 }
120}
121
122impl<C> SocksConnector<C>
123where
124 C: Service<Uri>,
125 C::Response: AsyncRead + AsyncWrite + Send + Unpin,
126 C::Error: Into<BoxedError>,
127{
128 async fn call_async(mut self, target_addr: Uri) -> Result<C::Response, Error> {
129 let host = target_addr
130 .host()
131 .map(str::to_string)
132 .ok_or(Error::MissingHost)?;
133 let port =
134 target_addr
135 .port_u16()
136 .unwrap_or(if target_addr.scheme() == Some(&Scheme::HTTPS) {
137 443
138 } else {
139 80
140 });
141 let target_addr = AddrKind::Domain(host, port);
142
143 let stream = self
144 .connector
145 .call(self.proxy_addr)
146 .await
147 .map_err(Into::<BoxedError>::into)?;
148 let mut buf_stream = BufStream::new(stream);
149 let _ = async_socks5::connect(&mut buf_stream, target_addr, self.auth).await?;
150 Ok(buf_stream.into_inner())
151 }
152}
153
154impl<C> Service<Uri> for SocksConnector<C>
155where
156 C: Service<Uri> + Clone + Send + 'static,
157 C::Response: AsyncRead + AsyncWrite + Send + Unpin,
158 C::Error: Into<BoxedError>,
159 C::Future: Send,
160{
161 type Response = C::Response;
162 type Error = Error;
163 type Future = SocksFuture<C::Response>;
164
165 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
166 ready!(self.connector.poll_ready(cx)).map_err(Into::<BoxedError>::into)?;
167 Poll::Ready(Ok(()))
168 }
169
170 fn call(&mut self, req: Uri) -> Self::Future {
171 let this = self.clone();
172 Box::pin(async move { this.call_async(req).await })
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use hyper::{client::HttpConnector, Body, Client};

    // These tests talk to a SOCKS5 proxy at `PROXY_ADDR` and fetch real pages.
    const PROXY_ADDR: &str = "socks5://127.0.0.1:1080";
    const PROXY_USERNAME: &str = "hyper";
    const PROXY_PASSWORD: &str = "proxy";
    const HTTP_ADDR: &str = "http://google.com";
    const HTTPS_ADDR: &str = "https://google.com";

    /// Describes one request scenario to run through the proxy.
    struct Tester {
        /// Target to fetch.
        uri: Uri,
        /// Credentials for the proxy, if any.
        auth: Option<Auth>,
        /// When set, deliberately picks the connector that does not match the scheme.
        swap_connector: bool,
    }

    impl Tester {
        /// Scenario for `uri` without credentials and with the matching connector.
        fn uri(uri: Uri) -> Tester {
            Tester {
                uri,
                auth: None,
                swap_connector: false,
            }
        }

        /// Scenario targeting `HTTP_ADDR`.
        fn http() -> Self {
            let target = Uri::from_static(HTTP_ADDR);
            Self::uri(target)
        }

        /// Scenario targeting `HTTPS_ADDR`.
        fn https() -> Self {
            let target = Uri::from_static(HTTPS_ADDR);
            Self::uri(target)
        }

        /// Adds the test proxy credentials.
        fn with_auth(self) -> Self {
            let credentials = Auth {
                username: PROXY_USERNAME.to_string(),
                password: PROXY_PASSWORD.to_string(),
            };
            Self {
                auth: Some(credentials),
                ..self
            }
        }

        /// Requests the mismatching connector for the target scheme.
        fn swap_connector(self) -> Self {
            Self {
                swap_connector: true,
                ..self
            }
        }

        /// Performs the request and panics if it fails.
        async fn test(self) {
            let Tester {
                uri,
                auth,
                swap_connector,
            } = self;

            let mut connector = HttpConnector::new();
            connector.enforce_http(false);
            let socks = SocksConnector {
                proxy_addr: Uri::from_static(PROXY_ADDR),
                auth,
                connector,
            };

            // Plain connector for `http` targets, TLS for the rest; `swap_connector`
            // inverts that choice.
            let is_http = uri.scheme() == Some(&Scheme::HTTP);
            let use_plain = is_http != swap_connector;

            let response = if use_plain {
                Client::builder().build::<_, Body>(socks).get(uri)
            } else {
                let https = socks.with_tls().unwrap();
                Client::builder().build::<_, Body>(https).get(uri)
            };
            let _ = response.await.unwrap();
        }
    }

    #[tokio::test]
    async fn http_no_auth() {
        Tester::http().test().await
    }

    #[tokio::test]
    async fn https_no_auth() {
        Tester::https().test().await
    }

    #[tokio::test]
    async fn http_auth() {
        Tester::http().with_auth().test().await
    }

    #[tokio::test]
    async fn https_auth() {
        Tester::https().with_auth().test().await
    }

    #[tokio::test]
    async fn http_no_auth_swap() {
        Tester::http().swap_connector().test().await
    }

    #[should_panic = "IncompleteMessage"]
    #[tokio::test]
    async fn https_no_auth_swap() {
        Tester::https().swap_connector().test().await
    }

    #[tokio::test]
    async fn http_auth_swap() {
        Tester::http().with_auth().swap_connector().test().await
    }

    #[should_panic = "IncompleteMessage"]
    #[tokio::test]
    async fn https_auth_swap() {
        Tester::https().with_auth().swap_connector().test().await
    }
}