1mod http;
2pub mod prelude;
3#[cfg(feature = "websocket")]
4mod ws;
5
6use std::{
7 io,
8 net::IpAddr,
9 pin::Pin,
10 sync::Arc,
11 task::{Context, Poll},
12};
13
14pub use crate::http::*;
15use async_net::TcpStream;
16use futures::{AsyncRead, AsyncWrite};
17use futures_rustls::{client::TlsStream, rustls::ClientConfig, TlsConnector};
18use rustls_pki_types::{InvalidDnsNameError, ServerName};
19#[cfg(feature = "websocket")]
20pub use ws::*;
21
/// A connected byte stream: either a plain TCP socket or a TLS client session
/// layered on top of a TCP socket.
///
/// Construct one with [`Transport::connect`]; it implements `AsyncRead` and
/// `AsyncWrite` by delegating to whichever stream it wraps.
pub enum Transport {
    /// Plaintext TCP connection (used when no TLS config is supplied).
    Tcp(TcpStream),
    /// TLS client stream running over a TCP connection.
    Tls(TlsStream<TcpStream>),
}
26
27impl Transport {
28 async fn connect(tls: Option<Arc<ClientConfig>>, host: &str, port: u16) -> Result<Self, TransportError> {
29 let server = ServerName::try_from(host)
30 .map_err(|err| TransportError::InvalidDnsName(Arc::new(err)))?
31 .to_owned();
32 let tcp = match &server {
33 ServerName::DnsName(name) => TcpStream::connect((name.as_ref(), port)).await,
34 ServerName::IpAddress(ip) => TcpStream::connect((IpAddr::from(*ip), port)).await,
35 _ => unreachable!(),
36 }
37 .map_err(|err| TransportError::TcpConnect(Arc::new(err)))?;
38 let transport = match tls {
39 None => Transport::Tcp(tcp),
40 Some(client_config) => {
41 let tls = TlsConnector::from(client_config)
42 .connect(server, tcp)
43 .await
44 .map_err(|err| TransportError::TlsConnect(Arc::new(err)))?;
45 Transport::Tls(tls)
46 }
47 };
48 Ok(transport)
49 }
50}
51
// Explicit marker so `Pin<&mut Transport>` can be unwrapped with `get_mut()` in
// the I/O impls below. NOTE(review): both wrapped streams are already passed to
// `Pin::new` (which requires `Unpin`), so this impl is likely redundant with the
// auto-trait — kept for explicitness.
impl Unpin for Transport {}
53
54impl AsyncRead for Transport {
55 fn poll_read(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8]) -> Poll<io::Result<usize>> {
56 match self.get_mut() {
57 Transport::Tcp(tcp) => Pin::new(tcp).poll_read(cx, buf),
58 Transport::Tls(tls) => Pin::new(tls).poll_read(cx, buf),
59 }
60 }
61}
62
63impl AsyncWrite for Transport {
64 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
65 match self.get_mut() {
66 Transport::Tcp(tcp) => Pin::new(tcp).poll_write(cx, buf),
67 Transport::Tls(tls) => Pin::new(tls).poll_write(cx, buf),
68 }
69 }
70
71 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
72 match self.get_mut() {
73 Transport::Tcp(tcp) => Pin::new(tcp).poll_flush(cx),
74 Transport::Tls(tls) => Pin::new(tls).poll_flush(cx),
75 }
76 }
77
78 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
79 match self.get_mut() {
80 Transport::Tcp(tcp) => Pin::new(tcp).poll_close(cx),
81 Transport::Tls(tls) => Pin::new(tls).poll_close(cx),
82 }
83 }
84}
85
86use thiserror::Error;
87
88#[derive(Error, Debug, Clone)]
89pub enum TransportError {
90 #[error("invalid host name: {0:?}")]
91 InvalidDnsName(Arc<InvalidDnsNameError>),
92 #[error("tcp connect error: {0:?}")]
93 TcpConnect(Arc<io::Error>),
94 #[error("tls connect error: {0:?}")]
95 TlsConnect(Arc<io::Error>),
96}
97
// Only available when at least one rustls crypto backend feature is enabled.
#[cfg(any(feature = "ring", feature = "aws-lc-rs"))]
lazy_static::lazy_static! {
    /// Shared TLS client configuration, built on first access.
    ///
    /// Trusts the `webpki_roots` bundle, uses no client certificate, and
    /// advertises only `http/1.1` via ALPN.
    ///
    /// # Panics
    ///
    /// On first access, if the selected crypto provider cannot support rustls'
    /// safe default protocol versions (the `expect` below).
    pub (crate) static ref DEFAULT_CLIENT_CONFIG: Arc<ClientConfig> = {
        // Rebuild each bundled root as a `rustls_pki_types::TrustAnchor`.
        // NOTE(review): presumably bridges webpki-roots' anchor type to the one
        // `RootCertStore::extend` expects — confirm against the pinned versions.
        let roots = webpki_roots::TLS_SERVER_ROOTS
            .iter()
            .map(|t| {
                let t = t.to_owned();
                rustls_pki_types::TrustAnchor {
                    subject: t.subject.into(),
                    subject_public_key_info: t.subject_public_key_info.into(),
                    name_constraints: t.name_constraints.map(Into::into),
                }
            });
        let mut root_store = futures_rustls::rustls::RootCertStore::empty();
        root_store.extend(roots);
        // Backend selection: if both features are enabled, aws-lc-rs wins.
        #[cfg(all(feature = "ring", not(feature = "aws-lc-rs")))]
        let provider = futures_rustls::rustls::crypto::ring::default_provider();
        #[cfg(feature = "aws-lc-rs")]
        let provider = futures_rustls::rustls::crypto::aws_lc_rs::default_provider();

        let mut config = ClientConfig::builder_with_provider(Arc::new(provider))
            .with_safe_default_protocol_versions()
            .expect("could not enable default TLS versions")
            .with_root_certificates(root_store)
            .with_no_client_auth();
        // Offer only HTTP/1.1 so the server cannot select HTTP/2 during ALPN.
        config.alpn_protocols.push(b"http/1.1".to_vec());
        Arc::new(config)
    };
}