async_web_client/
lib.rs

1mod http;
2pub mod prelude;
3#[cfg(feature = "websocket")]
4mod ws;
5
6use std::{
7    io,
8    net::IpAddr,
9    pin::Pin,
10    sync::Arc,
11    task::{Context, Poll},
12};
13
14pub use crate::http::*;
15use async_net::TcpStream;
16use futures::{AsyncRead, AsyncWrite};
17use futures_rustls::{client::TlsStream, rustls::ClientConfig, TlsConnector};
18use rustls_pki_types::{InvalidDnsNameError, ServerName};
19#[cfg(feature = "websocket")]
20pub use ws::*;
21
/// A client byte stream: either a plain TCP connection or a TLS session over TCP.
///
/// `AsyncRead` / `AsyncWrite` are implemented by forwarding to whichever stream is wrapped.
pub enum Transport {
    /// Unencrypted TCP connection.
    Tcp(TcpStream),
    /// TLS client session running over a TCP connection.
    Tls(TlsStream<TcpStream>),
}
26
27impl Transport {
28    async fn connect(tls: Option<Arc<ClientConfig>>, host: &str, port: u16) -> Result<Self, TransportError> {
29        let server = ServerName::try_from(host)
30            .map_err(|err| TransportError::InvalidDnsName(Arc::new(err)))?
31            .to_owned();
32        let tcp = match &server {
33            ServerName::DnsName(name) => TcpStream::connect((name.as_ref(), port)).await,
34            ServerName::IpAddress(ip) => TcpStream::connect((IpAddr::from(*ip), port)).await,
35            _ => unreachable!(),
36        }
37        .map_err(|err| TransportError::TcpConnect(Arc::new(err)))?;
38        let transport = match tls {
39            None => Transport::Tcp(tcp),
40            Some(client_config) => {
41                let tls = TlsConnector::from(client_config)
42                    .connect(server, tcp)
43                    .await
44                    .map_err(|err| TransportError::TlsConnect(Arc::new(err)))?;
45                Transport::Tls(tls)
46            }
47        };
48        Ok(transport)
49    }
50}
51
// NOTE(review): redundant with the auto trait. The `Pin::new(...)` calls in the
// `AsyncRead`/`AsyncWrite` impls only compile if both `TcpStream` and
// `TlsStream<TcpStream>` are `Unpin`, so `Transport` would be `Unpin` anyway.
// Kept explicit so that `Pin::get_mut` / `Pin::new` on `Transport` is clearly allowed.
impl Unpin for Transport {}
53
54impl AsyncRead for Transport {
55    fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
56        match self.get_mut() {
57            Transport::Tcp(tcp) => Pin::new(tcp).poll_read(cx, buf),
58            Transport::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
59        }
60    }
61}
62
63impl AsyncWrite for Transport {
64    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
65        match self.get_mut() {
66            Transport::Tcp(tcp) => Pin::new(tcp).poll_write(cx, buf),
67            Transport::Tls(tls) => Pin::new(tls).poll_write(cx, buf),
68        }
69    }
70
71    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
72        match self.get_mut() {
73            Transport::Tcp(tcp) => Pin::new(tcp).poll_flush(cx),
74            Transport::Tls(tls) => Pin::new(tls).poll_flush(cx),
75        }
76    }
77
78    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
79        match self.get_mut() {
80            Transport::Tcp(tcp) => Pin::new(tcp).poll_close(cx),
81            Transport::Tls(tls) => Pin::new(tls).poll_close(cx),
82        }
83    }
84}
85
86use thiserror::Error;
87
88#[derive(Error, Debug, Clone)]
89pub enum TransportError {
90    #[error("invalid host name: {0:?}")]
91    InvalidDnsName(Arc<InvalidDnsNameError>),
92    #[error("tcp connect error: {0:?}")]
93    TcpConnect(Arc<io::Error>),
94    #[error("tls connect error: {0:?}")]
95    TlsConnect(Arc<io::Error>),
96}
97
#[cfg(any(feature = "ring", feature = "aws-lc-rs"))]
lazy_static::lazy_static! {
    /// Shared default TLS client configuration.
    ///
    /// Trust anchors come from `webpki_roots::TLS_SERVER_ROOTS`. The crypto
    /// provider is `aws-lc-rs` when that feature is enabled, otherwise `ring`.
    /// Protocol versions are rustls' safe defaults, no client certificate is
    /// presented, and `http/1.1` is advertised via ALPN.
    pub(crate) static ref DEFAULT_CLIENT_CONFIG: Arc<ClientConfig> = {
        // Crypto backend selection: `aws-lc-rs` takes precedence when both features are on.
        #[cfg(feature = "aws-lc-rs")]
        let crypto = futures_rustls::rustls::crypto::aws_lc_rs::default_provider();
        #[cfg(all(feature = "ring", not(feature = "aws-lc-rs")))]
        let crypto = futures_rustls::rustls::crypto::ring::default_provider();

        // Convert each bundled root into a `rustls_pki_types::TrustAnchor` for the store.
        let mut trusted = futures_rustls::rustls::RootCertStore::empty();
        trusted.extend(webpki_roots::TLS_SERVER_ROOTS.iter().map(|anchor| {
            let anchor = anchor.to_owned();
            rustls_pki_types::TrustAnchor {
                subject: anchor.subject.into(),
                subject_public_key_info: anchor.subject_public_key_info.into(),
                name_constraints: anchor.name_constraints.map(Into::into),
            }
        }));

        let versioned = ClientConfig::builder_with_provider(Arc::new(crypto))
            .with_safe_default_protocol_versions()
            .expect("could not enable default TLS versions");
        let mut config = versioned
            .with_root_certificates(trusted)
            .with_no_client_auth();
        config.alpn_protocols.push(b"http/1.1".to_vec());
        Arc::new(config)
    };
}