1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4#[cfg(feature = "rustls_byoc")]
5use std::{convert::TryFrom, sync::Arc};
6
7#[cfg(feature = "proxies")]
8mod socks5;
9#[cfg(feature = "proxies")]
10use socks5::connect_via_socks_prx;
11#[cfg(feature = "proxies")]
12mod http;
13#[cfg(feature = "proxies")]
14use http::connect_via_http_prx;
15
16#[cfg(feature = "use_async_h1")]
17use async_std::{
18 io::{Read, Write},
19 net::TcpStream,
20};
21#[cfg(all(feature = "use_async_h1", feature = "proxies"))]
22use http_types::Url as Uri;
23#[cfg(all(feature = "use_hyper", feature = "proxies"))]
24use hyper::http::uri::Uri;
25#[cfg(feature = "use_hyper")]
26use hyper::rt::{Read, ReadBufCursor, Write};
27#[cfg(feature = "use_hyper")]
28use tokio::{
29 io::{AsyncRead as _, AsyncWrite as _},
30 net::TcpStream,
31};
32
33#[cfg(any(feature = "async_native_tls", feature = "hyper_native_tls"))]
34use async_native_tls::{TlsConnector, TlsStream};
35#[cfg(all(feature = "rustls_byoc", feature = "use_async_h1"))]
36use futures_rustls::{
37 client::TlsStream,
38 rustls::{pki_types::ServerName, ClientConfig, RootCertStore},
39 TlsConnector,
40};
41#[cfg(all(feature = "rustls_byoc", feature = "use_hyper"))]
42use tokio_rustls::{
43 client::TlsStream,
44 rustls::{pki_types::ServerName, ClientConfig, RootCertStore},
45 TlsConnector,
46};
47#[cfg(feature = "rustls_byoc")]
48use webpki_roots::TLS_SERVER_ROOTS;
49
50pub struct Stream {
51 state: State,
52}
53enum State {
54 #[cfg(any(
55 feature = "rustls_byoc",
56 feature = "hyper_native_tls",
57 feature = "async_native_tls"
58 ))]
59 Tls(TlsStream<TcpStream>),
60 Plain(TcpStream),
61}
62
#[cfg(feature = "proxies")]
pub mod proxy {
    //! Process-wide selection of how outgoing TCP connections are established
    //! (directly, or through an HTTP / SOCKS5 proxy).
    use super::*;
    use async_trait::async_trait;
    use std::sync::RwLock;

    /// Installs `proxy` as the process-wide proxy used for every new connection.
    pub fn set_proxy(proxy: &'static dyn Proxy) {
        // The lock only guards a `&'static` reference, so a poisoned lock cannot
        // expose a half-written value; recover it instead of panicking.
        let mut current = GLOBAL_PROXY
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *current = proxy;
    }

    /// Like [`set_proxy`], but takes ownership of a boxed proxy.
    ///
    /// The box is intentionally leaked so the proxy lives for the rest of the process.
    pub fn set_boxed_proxy(proxy: Box<dyn Proxy>) {
        set_proxy(Box::leak(proxy))
    }

    /// Returns the proxy currently used for new connections ([`EnvProxy`] by default).
    pub fn proxy() -> &'static dyn Proxy {
        let current = GLOBAL_PROXY
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *current
    }

    /// The process-wide proxy.
    ///
    /// Guarded by a lock so that `set_proxy` racing with `proxy()` from another
    /// thread is well defined; an unsynchronized `static mut` would be a data race.
    static GLOBAL_PROXY: RwLock<&'static dyn Proxy> = RwLock::new(&EnvProxy);

    /// Strategy for opening the TCP connection to `host:port`.
    ///
    /// `tls` is `true` when the caller is going to wrap the returned stream in TLS.
    #[async_trait]
    pub trait Proxy: Sync + Send {
        async fn connect_w_proxy(&self, host: &str, port: u16, tls: bool) -> io::Result<TcpStream>;
    }

    /// Always connects directly to the target, ignoring any proxy configuration.
    pub struct NoProxy;

    #[async_trait]
    impl Proxy for NoProxy {
        async fn connect_w_proxy(
            &self,
            host: &str,
            port: u16,
            _tls: bool,
        ) -> io::Result<TcpStream> {
            TcpStream::connect((host, port)).await
        }
    }

    /// Picks the proxy from environment variables.
    ///
    /// `ALL_PROXY`/`all_proxy` wins. Otherwise `HTTPS_PROXY`/`https_proxy` is used for
    /// TLS connections and `HTTP_PROXY`/`http_proxy` for plain ones. Empty values count
    /// as unset. A `NO_PROXY`/`no_proxy` entry equal to the target host, or `*`,
    /// disables proxying. Supported proxy schemes are `http`, `socks5` and `socks5h`.
    pub struct EnvProxy;

    /// Returns the trimmed value of the first variable in `names` that is set to a
    /// non-blank value.
    ///
    /// Blank values are skipped so that e.g. `HTTP_PROXY=` means "no proxy" instead of
    /// failing every connection with a URL parse error.
    fn non_empty_env(names: &[&str]) -> Option<String> {
        names
            .iter()
            .filter_map(|name| std::env::var(name).ok())
            .map(|value| value.trim().to_owned())
            .find(|value| !value.is_empty())
    }

    /// Whether the comma-separated `no_proxy` list excludes `host` from proxying.
    ///
    /// An entry matches when, after trimming, it equals `host` or is the wildcard `*`.
    fn no_proxy_matches(no_proxy: &str, host: &str) -> bool {
        no_proxy
            .split(',')
            .map(str::trim)
            .any(|entry| entry == "*" || entry == host)
    }

    #[async_trait]
    impl Proxy for EnvProxy {
        async fn connect_w_proxy(&self, host: &str, port: u16, tls: bool) -> io::Result<TcpStream> {
            let scheme_specific: &[&str] = if tls {
                &["HTTPS_PROXY", "https_proxy"]
            } else {
                &["HTTP_PROXY", "http_proxy"]
            };
            let mut prx =
                non_empty_env(&["ALL_PROXY", "all_proxy"]).or_else(|| non_empty_env(scheme_specific));

            if let Ok(no_proxy) = std::env::var("NO_PROXY").or_else(|_| std::env::var("no_proxy")) {
                if no_proxy_matches(&no_proxy, host) {
                    log::debug!("using no proxy due to env NO_PROXY");
                    prx = None;
                }
            }

            let proxy = match prx {
                Some(proxy) => proxy,
                None => return TcpStream::connect((host, port)).await,
            };

            let url = proxy
                .parse::<Uri>()
                .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
            #[cfg(feature = "use_hyper")]
            let (phost, scheme) = (url.host(), url.scheme_str());
            #[cfg(feature = "use_async_h1")]
            let (phost, scheme) = (url.host_str(), Some(url.scheme()));

            let phost = phost
                .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "missing proxy host"))?;

            #[cfg(feature = "use_hyper")]
            let pport = url.port().map(|p| p.as_u16());
            #[cfg(feature = "use_async_h1")]
            let pport = url.port();

            // Fall back to the scheme's well-known port when the URL has none.
            let pport = match (pport, scheme) {
                (Some(port), _) => port,
                (None, Some("https")) => 443,
                (None, Some("http")) => 80,
                (None, Some("socks5" | "socks5h")) => 1080,
                (None, _) => {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidInput,
                        "missing proxy port",
                    ))
                }
            };
            log::info!("using proxy {}:{}", phost, pport);

            match scheme {
                Some("http") => connect_via_http_prx(host, port, phost, pport).await,
                Some(socks @ ("socks5" | "socks5h")) => {
                    // `socks5h` lets the proxy resolve the target host name.
                    connect_via_socks_prx(host, port, phost, pport, socks == "socks5h").await
                }
                _ => Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "unsupported proxy scheme",
                )),
            }
        }
    }

    #[cfg(test)]
    mod tests {
        use crate::tests::{
            assert_stream, block_on, listen_somewhere, spawn, TcpListener, WriteExt,
        };
        #[test]
        fn prx_from_env() {
            async fn server(listener: TcpListener) -> std::io::Result<bool> {
                let (mut stream, _) = listener.accept().await?;

                assert_stream(
                    &mut stream,
                    "CONNECT whatever:80 HTTP/1.1\r\nHost: whatever:80\r\n\r\n".as_bytes(),
                )
                .await?;
                stream.write_all(b"HTTP/1.1 200 Connected\r\n\r\n").await?;

                assert_stream(
                    &mut stream,
                    "GET /bla HTTP/1.1\r\nhost: whatever\r\ncontent-length: 0\r\n\r\n".as_bytes(),
                )
                .await?;
                stream
                    .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 3\r\n\r\nabc")
                    .await?;

                Ok(true)
            }
            block_on(async {
                let (listener, pport, phost) = listen_somewhere().await?;
                std::env::set_var("HTTP_PROXY", format!("http://{phost}:{pport}/"));
                std::env::set_var("NO_PROXY", &phost);
                let t = spawn(server(listener));

                let r = crate::Request::get("http://whatever/bla");
                let mut aw = r.exec().await?;

                assert_eq!(aw.status_code(), 200, "wrong status");
                assert_eq!(aw.text().await?, "abc", "wrong text");
                assert!(t.await?, "not cool");
                Ok(())
            })
            .unwrap();
        }
    }
}
249
/// Builds the TLS connector for whichever TLS backend feature is enabled.
///
/// With `rustls_byoc`, server certificates are verified against the
/// `webpki_roots` trust anchors and no client certificate is sent. ALPN offers
/// `h2` (only when both `use_hyper` and `http2` are enabled) followed by
/// `http/1.1`. With a native-tls backend, that backend's default connector is
/// returned.
#[cfg(any(
    feature = "rustls_byoc",
    feature = "hyper_native_tls",
    feature = "async_native_tls"
))]
fn get_tls_connector() -> io::Result<TlsConnector> {
    #[cfg(feature = "rustls_byoc")]
    {
        let trust_anchors = {
            let mut store = RootCertStore::empty();
            store.extend(TLS_SERVER_ROOTS.iter().cloned());
            store
        };

        let mut client_config = ClientConfig::builder()
            .with_root_certificates(trust_anchors)
            .with_no_client_auth();

        // Protocols are offered in push order, so `h2` (when compiled in) comes first.
        #[cfg(all(feature = "use_hyper", feature = "http2"))]
        client_config.alpn_protocols.push(b"h2".to_vec());
        client_config.alpn_protocols.push(b"http/1.1".to_vec());

        let shared_config = Arc::new(client_config);
        Ok(TlsConnector::from(shared_config))
    }
    #[cfg(any(feature = "async_native_tls", feature = "hyper_native_tls"))]
    return Ok(TlsConnector::new());
}
274
275impl Stream {
276 pub async fn connect(host: &str, port: u16, tls: bool) -> io::Result<Stream> {
277 #[cfg(feature = "proxies")]
278 let tcp = proxy::proxy().connect_w_proxy(host, port, tls).await?;
279 #[cfg(not(feature = "proxies"))]
280 let tcp = TcpStream::connect((host, port)).await?;
281 log::trace!("connected to {}:{}", host, port);
282
283 if tls {
284 #[cfg(any(
285 feature = "hyper_native_tls",
286 feature = "async_native_tls",
287 feature = "rustls_byoc"
288 ))]
289 {
290 #[cfg(feature = "rustls_byoc")]
291 let host = ServerName::try_from(host)
292 .map_err(|_e| io::Error::new(io::ErrorKind::InvalidInput, "Invalid DNS name"))?
293 .to_owned();
294 let tlsc = get_tls_connector()?;
295
296 let tls = tlsc.connect(host, tcp).await;
297 return match tls {
298 Ok(stream) => {
299 log::trace!("wrapped TLS");
300 Ok(Stream {
301 state: State::Tls(stream),
302 })
303 }
304 Err(e) => {
305 log::error!("TLS Handshake: {}", e);
306 #[cfg(feature = "rustls_byoc")]
307 {
308 Err(e)
309 }
310 #[cfg(any(feature = "hyper_native_tls", feature = "async_native_tls"))]
311 Err(io::Error::new(io::ErrorKind::InvalidInput, e))
312 }
313 };
314 }
315 #[cfg(not(any(
316 feature = "rustls_byoc",
317 feature = "hyper_native_tls",
318 feature = "async_native_tls"
319 )))]
320 return Err(io::Error::new(
321 io::ErrorKind::InvalidInput,
322 "no TLS backend available",
323 ));
324 } else {
325 return Ok(Stream {
326 state: State::Plain(tcp),
327 });
328 }
329 }
330}
331
#[cfg(feature = "use_hyper")]
impl Stream {
    /// Reports which HTTP version hyper should speak on this connection.
    ///
    /// Returns `HTTP_2` only for a rustls session whose ALPN negotiation selected `h2`.
    /// Plain TCP and every other TLS setup yield `HTTP_11`.
    pub fn get_proto(&self) -> hyper::Version {
        #[cfg(feature = "rustls_byoc")]
        {
            if let State::Tls(tls_stream) = &self.state {
                let (_tcp, session) = tls_stream.get_ref();
                let negotiated_h2 = session.alpn_protocol() == Some(&b"h2"[..]);
                if negotiated_h2 {
                    return hyper::Version::HTTP_2;
                }
            }
        }
        hyper::Version::HTTP_11
    }
}
345
346impl Write for Stream {
347 fn poll_write(
348 self: Pin<&mut Self>,
349 cx: &mut Context<'_>,
350 buf: &[u8],
351 ) -> Poll<io::Result<usize>> {
352 let pin = self.get_mut();
353 match pin.state {
354 #[cfg(any(
355 feature = "rustls_byoc",
356 feature = "hyper_native_tls",
357 feature = "async_native_tls"
358 ))]
359 State::Tls(ref mut t) => Pin::new(t).poll_write(cx, buf),
360 State::Plain(ref mut t) => Pin::new(t).poll_write(cx, buf),
361 }
362 }
363
364 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
365 let pin = self.get_mut();
366 match pin.state {
367 #[cfg(any(
368 feature = "rustls_byoc",
369 feature = "hyper_native_tls",
370 feature = "async_native_tls"
371 ))]
372 State::Tls(ref mut t) => Pin::new(t).poll_flush(cx),
373 State::Plain(ref mut t) => Pin::new(t).poll_flush(cx),
374 }
375 }
376
377 #[cfg(feature = "use_async_h1")]
378 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
379 let pin = self.get_mut();
380 match pin.state {
381 #[cfg(any(
382 feature = "rustls_byoc",
383 feature = "hyper_native_tls",
384 feature = "async_native_tls"
385 ))]
386 State::Tls(ref mut t) => Pin::new(t).poll_close(cx),
387 State::Plain(ref mut t) => Pin::new(t).poll_close(cx),
388 }
389 }
390
391 #[cfg(feature = "use_hyper")]
392 fn poll_shutdown(
393 self: Pin<&mut Self>,
394 cx: &mut Context<'_>,
395 ) -> Poll<std::result::Result<(), std::io::Error>> {
396 let pin = self.get_mut();
397 match pin.state {
398 #[cfg(any(
399 feature = "rustls_byoc",
400 feature = "hyper_native_tls",
401 feature = "async_native_tls"
402 ))]
403 State::Tls(ref mut t) => Pin::new(t).poll_shutdown(cx),
404 State::Plain(ref mut t) => Pin::new(t).poll_shutdown(cx),
405 }
406 }
407}
408impl Read for Stream {
409 #[cfg(feature = "use_async_h1")]
410 fn poll_read(
411 self: Pin<&mut Self>,
412 cx: &mut Context<'_>,
413 buf: &mut [u8],
414 ) -> Poll<io::Result<usize>> {
415 let pin = self.get_mut();
416 match pin.state {
417 #[cfg(any(
418 feature = "rustls_byoc",
419 feature = "hyper_native_tls",
420 feature = "async_native_tls"
421 ))]
422 State::Tls(ref mut t) => Pin::new(t).poll_read(cx, buf),
423 State::Plain(ref mut t) => Pin::new(t).poll_read(cx, buf),
424 }
425 }
426 #[cfg(feature = "use_hyper")]
427 fn poll_read(
428 self: Pin<&mut Self>,
429 cx: &mut Context<'_>,
430 mut buf: ReadBufCursor<'_>,
431 ) -> Poll<io::Result<()>> {
432 let pin = self.get_mut();
433 let f = {
434 let mut tbuf = tokio::io::ReadBuf::uninit(unsafe { buf.as_mut() });
435 let p = match pin.state {
436 #[cfg(any(
437 feature = "rustls_byoc",
438 feature = "hyper_native_tls",
439 feature = "async_native_tls"
440 ))]
441 State::Tls(ref mut t) => Pin::new(t).poll_read(cx, &mut tbuf),
442 State::Plain(ref mut t) => Pin::new(t).poll_read(cx, &mut tbuf),
443 };
444 match p {
445 Poll::Ready(Ok(())) => tbuf.filled().len(),
446 o => return o,
447 }
448 };
449 unsafe {
450 buf.advance(f);
451 }
452 Poll::Ready(Ok(()))
453 }
454}