// generic_async_http_client/tcp/mod.rs

1use std::io;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4#[cfg(feature = "rustls_byoc")]
5use std::{convert::TryFrom, sync::Arc};
6
7#[cfg(feature = "proxies")]
8mod socks5;
9#[cfg(feature = "proxies")]
10use socks5::connect_via_socks_prx;
11#[cfg(feature = "proxies")]
12mod http;
13#[cfg(feature = "proxies")]
14use http::connect_via_http_prx;
15
16#[cfg(feature = "use_async_h1")]
17use async_std::{
18    io::{Read, Write},
19    net::TcpStream,
20};
21#[cfg(all(feature = "use_async_h1", feature = "proxies"))]
22use http_types::Url as Uri;
23#[cfg(all(feature = "use_hyper", feature = "proxies"))]
24use hyper::http::uri::Uri;
25#[cfg(feature = "use_hyper")]
26use hyper::rt::{Read, ReadBufCursor, Write};
27#[cfg(feature = "use_hyper")]
28use tokio::{
29    io::{AsyncRead as _, AsyncWrite as _},
30    net::TcpStream,
31};
32
33#[cfg(any(feature = "async_native_tls", feature = "hyper_native_tls"))]
34use async_native_tls::{TlsConnector, TlsStream};
35#[cfg(all(feature = "rustls_byoc", feature = "use_async_h1"))]
36use futures_rustls::{
37    client::TlsStream,
38    rustls::{pki_types::ServerName, ClientConfig, RootCertStore},
39    TlsConnector,
40};
41#[cfg(all(feature = "rustls_byoc", feature = "use_hyper"))]
42use tokio_rustls::{
43    client::TlsStream,
44    rustls::{pki_types::ServerName, ClientConfig, RootCertStore},
45    TlsConnector,
46};
47#[cfg(feature = "rustls_byoc")]
48use webpki_roots::TLS_SERVER_ROOTS;
49
50pub struct Stream {
51    state: State,
52}
53enum State {
54    #[cfg(any(
55        feature = "rustls_byoc",
56        feature = "hyper_native_tls",
57        feature = "async_native_tls"
58    ))]
59    Tls(TlsStream<TcpStream>),
60    Plain(TcpStream),
61}
62
63#[cfg(feature = "proxies")]
64pub mod proxy {
65    use super::*;
66    use async_trait::async_trait;
67
68    /// Sets the global proxy to a `&'static Proxy`.
69    pub fn set_proxy(proxy: &'static dyn Proxy) {
70        unsafe {
71            GLOBAL_PROXY = proxy;
72        }
73    }
74    /// Sets the global proxy to a `Box<Proxy>`.
75    ///
76    /// This is a simple convenience wrapper over `set_proxy`, which takes a
77    /// `Box<Proxy>` rather than a `&'static Proxy`. See the documentation for
78    /// [`set_proxy`] for more details.
79    pub fn set_boxed_proxy(proxy: Box<dyn Proxy>) {
80        set_proxy(Box::leak(proxy))
81    }
82    /// Returns a reference to the proxy.
83    pub fn proxy() -> &'static dyn Proxy {
84        unsafe { GLOBAL_PROXY }
85    }
86    static mut GLOBAL_PROXY: &dyn Proxy = &EnvProxy;
87
88    /// Trait to implement custom proxies
89    #[async_trait]
90    pub trait Proxy: Sync + Send {
91        /// create a new TCP connection to the target
92        async fn connect_w_proxy(&self, host: &str, port: u16, tls: bool) -> io::Result<TcpStream>;
93    }
94    /// Use a direct connection
95    pub struct NoProxy;
96    #[async_trait]
97    impl Proxy for NoProxy {
98        async fn connect_w_proxy(
99            &self,
100            host: &str,
101            port: u16,
102            _tls: bool,
103        ) -> io::Result<TcpStream> {
104            TcpStream::connect((host, port)).await
105        }
106    }
107    /// Default Proxy. Performs auto detection from ENV. Supports the schemes `http(s)://` and `socks5(h)://`.
108    /// 
109    /// `http_proxy`, `HTTPS_PROXY` should be set for protocol-specific proxies.
110    /// General proxy should be set with `ALL_PROXY`
111    ///
112    /// A comma-separated list of host names that shouldn't go through any proxy is
113    /// set in (only an asterisk, '*' matches all hosts) `NO_PROXY`
114    pub struct EnvProxy;
115    #[async_trait]
116    impl Proxy for EnvProxy {
117        async fn connect_w_proxy(&self, host: &str, port: u16, tls: bool) -> io::Result<TcpStream> {
118            let mut prx = std::env::var("ALL_PROXY")
119                .or_else(|_| std::env::var("all_proxy"))
120                .ok();
121            if prx.is_none() && tls {
122                prx = std::env::var("HTTPS_PROXY")
123                    .or_else(|_| std::env::var("https_proxy"))
124                    .ok();
125            }
126            if prx.is_none() && !tls {
127                prx = std::env::var("HTTP_PROXY")
128                    .or_else(|_| std::env::var("http_proxy"))
129                    .ok();
130            }
131            if let Ok(no_proxy) = std::env::var("NO_PROXY").or_else(|_| std::env::var("no_proxy")) {
132                for h in no_proxy.split(',') {
133                    match h.trim() {
134                        a if a == host => {}
135                        "*" => {}
136                        _ => continue,
137                    }
138                    log::debug!("using no proxy due to env NO_PROXY");
139                    prx = None;
140                    break;
141                }
142            }
143            match prx {
144                None => TcpStream::connect((host, port)).await,
145                Some(proxy) => {
146                    let url = proxy
147                        .parse::<Uri>()
148                        .map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e))?;
149                    #[cfg(feature = "use_hyper")]
150                    let (phost, scheme) = (url.host(), url.scheme_str());
151                    #[cfg(feature = "use_async_h1")]
152                    let (phost, scheme) = (url.host_str(), Some(url.scheme()));
153
154                    let phost = match phost {
155                        Some(s) => s,
156                        None => {
157                            return Err(io::Error::new(
158                                io::ErrorKind::InvalidInput,
159                                "missing proxy host",
160                            ));
161                        }
162                    };
163                    #[cfg(feature = "use_hyper")]
164                    let pport = url.port().map(|p| p.as_u16());
165                    #[cfg(feature = "use_async_h1")]
166                    let pport = url.port();
167
168                    let pport = match pport {
169                        Some(port) => port,
170                        None => match scheme {
171                            Some("https") => 443,
172                            Some("http") => 80,
173                            Some("socks5") => 1080,
174                            Some("socks5h") => 1080,
175                            _ => {
176                                return Err(io::Error::new(
177                                    io::ErrorKind::InvalidInput,
178                                    "missing proxy port",
179                                ))
180                            }
181                        },
182                    };
183                    log::info!("using proxy {}:{}", phost, pport);
184                    match scheme {
185                        Some("http") => connect_via_http_prx(host, port, phost, pport).await,
186                        Some(socks5) if socks5 == "socks5" || socks5 == "socks5h" => {
187                            connect_via_socks_prx(host, port, phost, pport, socks5 == "socks5h")
188                                .await
189                        }
190                        _ => {
191                            return Err(io::Error::new(
192                                io::ErrorKind::InvalidInput,
193                                "unsupported proxy scheme",
194                            ))
195                        }
196                    }
197                }
198            }
199        }
200    }
201
202    #[cfg(test)]
203    mod tests {
204        use crate::tests::{
205            assert_stream, block_on, listen_somewhere, spawn, TcpListener, WriteExt,
206        };
207        #[test]
208        fn prx_from_env() {
209            async fn server(listener: TcpListener) -> std::io::Result<bool> {
210                let (mut stream, _) = listener.accept().await?;
211
212                assert_stream(
213                    &mut stream,
214                    format!("CONNECT whatever:80 HTTP/1.1\r\nHost: whatever:80\r\n\r\n").as_bytes(),
215                )
216                .await?;
217                stream.write_all(b"HTTP/1.1 200 Connected\r\n\r\n").await?;
218
219                assert_stream(
220                    &mut stream,
221                    format!("GET /bla HTTP/1.1\r\nhost: whatever\r\ncontent-length: 0\r\n\r\n")
222                        .as_bytes(),
223                )
224                .await?;
225                stream
226                    .write_all(b"HTTP/1.1 200 OK\r\ncontent-length: 3\r\n\r\nabc")
227                    .await?;
228
229                Ok(true)
230            }
231            block_on(async {
232                let (listener, pport, phost) = listen_somewhere().await?;
233                std::env::set_var("HTTP_PROXY", format!("http://{phost}:{pport}/"));
234                std::env::set_var("NO_PROXY", &phost);
235                let t = spawn(server(listener));
236
237                let r = crate::Request::get("http://whatever/bla");
238                let mut aw = r.exec().await?;
239
240                assert_eq!(aw.status_code(), 200, "wrong status");
241                assert_eq!(aw.text().await?, "abc", "wrong text");
242                assert!(t.await?, "not cool");
243                Ok(())
244            })
245            .unwrap();
246        }
247    }
248}
249
250#[cfg(any(
251    feature = "rustls_byoc",
252    feature = "hyper_native_tls",
253    feature = "async_native_tls"
254))]
255fn get_tls_connector() -> io::Result<TlsConnector> {
256    #[cfg(feature = "rustls_byoc")]
257    {
258        let mut root_store = RootCertStore::empty();
259        root_store.extend(TLS_SERVER_ROOTS.iter().cloned());
260
261        let mut config = ClientConfig::builder()
262            .with_root_certificates(root_store)
263            .with_no_client_auth();
264
265        #[cfg(all(feature = "use_hyper", feature = "http2"))]
266        config.alpn_protocols.push(b"h2".to_vec());
267        config.alpn_protocols.push(b"http/1.1".to_vec());
268
269        Ok(TlsConnector::from(Arc::new(config)))
270    }
271    #[cfg(any(feature = "async_native_tls", feature = "hyper_native_tls"))]
272    return Ok(TlsConnector::new());
273}
274
275impl Stream {
276    pub async fn connect(host: &str, port: u16, tls: bool) -> io::Result<Stream> {
277        #[cfg(feature = "proxies")]
278        let tcp = proxy::proxy().connect_w_proxy(host, port, tls).await?;
279        #[cfg(not(feature = "proxies"))]
280        let tcp = TcpStream::connect((host, port)).await?;
281        log::trace!("connected to {}:{}", host, port);
282
283        if tls {
284            #[cfg(any(
285                feature = "hyper_native_tls",
286                feature = "async_native_tls",
287                feature = "rustls_byoc"
288            ))]
289            {
290                #[cfg(feature = "rustls_byoc")]
291                let host = ServerName::try_from(host)
292                    .map_err(|_e| io::Error::new(io::ErrorKind::InvalidInput, "Invalid DNS name"))?
293                    .to_owned();
294                let tlsc = get_tls_connector()?;
295
296                let tls = tlsc.connect(host, tcp).await;
297                return match tls {
298                    Ok(stream) => {
299                        log::trace!("wrapped TLS");
300                        Ok(Stream {
301                            state: State::Tls(stream),
302                        })
303                    }
304                    Err(e) => {
305                        log::error!("TLS Handshake: {}", e);
306                        #[cfg(feature = "rustls_byoc")]
307                        {
308                            Err(e)
309                        }
310                        #[cfg(any(feature = "hyper_native_tls", feature = "async_native_tls"))]
311                        Err(io::Error::new(io::ErrorKind::InvalidInput, e))
312                    }
313                };
314            }
315            #[cfg(not(any(
316                feature = "rustls_byoc",
317                feature = "hyper_native_tls",
318                feature = "async_native_tls"
319            )))]
320            return Err(io::Error::new(
321                io::ErrorKind::InvalidInput,
322                "no TLS backend available",
323            ));
324        } else {
325            return Ok(Stream {
326                state: State::Plain(tcp),
327            });
328        }
329    }
330}
331
332#[cfg(feature = "use_hyper")]
333impl Stream {
334    pub fn get_proto(&self) -> hyper::Version {
335        #[cfg(feature = "rustls_byoc")]
336        if let State::Tls(ref t) = self.state {
337            let (_, s) = t.get_ref();
338            if Some(&b"h2"[..]) == s.alpn_protocol() {
339                return hyper::Version::HTTP_2;
340            }
341        }
342        hyper::Version::HTTP_11
343    }
344}
345
346impl Write for Stream {
347    fn poll_write(
348        self: Pin<&mut Self>,
349        cx: &mut Context<'_>,
350        buf: &[u8],
351    ) -> Poll<io::Result<usize>> {
352        let pin = self.get_mut();
353        match pin.state {
354            #[cfg(any(
355                feature = "rustls_byoc",
356                feature = "hyper_native_tls",
357                feature = "async_native_tls"
358            ))]
359            State::Tls(ref mut t) => Pin::new(t).poll_write(cx, buf),
360            State::Plain(ref mut t) => Pin::new(t).poll_write(cx, buf),
361        }
362    }
363
364    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
365        let pin = self.get_mut();
366        match pin.state {
367            #[cfg(any(
368                feature = "rustls_byoc",
369                feature = "hyper_native_tls",
370                feature = "async_native_tls"
371            ))]
372            State::Tls(ref mut t) => Pin::new(t).poll_flush(cx),
373            State::Plain(ref mut t) => Pin::new(t).poll_flush(cx),
374        }
375    }
376
377    #[cfg(feature = "use_async_h1")]
378    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
379        let pin = self.get_mut();
380        match pin.state {
381            #[cfg(any(
382                feature = "rustls_byoc",
383                feature = "hyper_native_tls",
384                feature = "async_native_tls"
385            ))]
386            State::Tls(ref mut t) => Pin::new(t).poll_close(cx),
387            State::Plain(ref mut t) => Pin::new(t).poll_close(cx),
388        }
389    }
390
391    #[cfg(feature = "use_hyper")]
392    fn poll_shutdown(
393        self: Pin<&mut Self>,
394        cx: &mut Context<'_>,
395    ) -> Poll<std::result::Result<(), std::io::Error>> {
396        let pin = self.get_mut();
397        match pin.state {
398            #[cfg(any(
399                feature = "rustls_byoc",
400                feature = "hyper_native_tls",
401                feature = "async_native_tls"
402            ))]
403            State::Tls(ref mut t) => Pin::new(t).poll_shutdown(cx),
404            State::Plain(ref mut t) => Pin::new(t).poll_shutdown(cx),
405        }
406    }
407}
408impl Read for Stream {
409    #[cfg(feature = "use_async_h1")]
410    fn poll_read(
411        self: Pin<&mut Self>,
412        cx: &mut Context<'_>,
413        buf: &mut [u8],
414    ) -> Poll<io::Result<usize>> {
415        let pin = self.get_mut();
416        match pin.state {
417            #[cfg(any(
418                feature = "rustls_byoc",
419                feature = "hyper_native_tls",
420                feature = "async_native_tls"
421            ))]
422            State::Tls(ref mut t) => Pin::new(t).poll_read(cx, buf),
423            State::Plain(ref mut t) => Pin::new(t).poll_read(cx, buf),
424        }
425    }
426    #[cfg(feature = "use_hyper")]
427    fn poll_read(
428        self: Pin<&mut Self>,
429        cx: &mut Context<'_>,
430        mut buf: ReadBufCursor<'_>,
431    ) -> Poll<io::Result<()>> {
432        let pin = self.get_mut();
433        let f = {
434            let mut tbuf = tokio::io::ReadBuf::uninit(unsafe { buf.as_mut() });
435            let p = match pin.state {
436                #[cfg(any(
437                    feature = "rustls_byoc",
438                    feature = "hyper_native_tls",
439                    feature = "async_native_tls"
440                ))]
441                State::Tls(ref mut t) => Pin::new(t).poll_read(cx, &mut tbuf),
442                State::Plain(ref mut t) => Pin::new(t).poll_read(cx, &mut tbuf),
443            };
444            match p {
445                Poll::Ready(Ok(())) => tbuf.filled().len(),
446                o => return o,
447            }
448        };
449        unsafe {
450            buf.advance(f);
451        }
452        Poll::Ready(Ok(()))
453    }
454}