use std::sync::Arc;
use std::sync::Once;
use reqwest::Client;
use rustls::client::danger::{
HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier,
};
use rustls::crypto::{ring as rustls_ring, CryptoProvider};
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, SignatureScheme};
use super::transport::{Transport, TransportError, TransportRequest, TransportResponse};
/// Connect timeout, in seconds, applied to every client built by `HttpTransport::new`.
const CONNECT_TIMEOUT_SECS: u64 = 10;

/// HTTP(S) transport backed by `reqwest` using `rustls` for TLS.
///
/// Cloning is cheap: all clones share the same underlying clients through an `Arc`.
#[derive(Clone, Debug)]
pub struct HttpTransport {
    inner: Arc<HttpTransportInner>,
}

/// Shared state behind `HttpTransport`: one HTTP client per certificate policy.
#[derive(Debug)]
struct HttpTransportInner {
    /// Client that verifies server certificates; used unless a request opts out.
    strict_client: Client,
    /// Client that accepts any server certificate; used only for requests that
    /// set `accept_invalid_certs`.
    trust_all_client: Client,
}
impl HttpTransport {
pub fn new() -> Result<Self, TransportError> {
install_default_crypto_provider();
let strict_tls = strict_rustls_config();
let trust_all_tls = trust_all_rustls_config();
let strict_client = Client::builder()
.connect_timeout(std::time::Duration::from_secs(CONNECT_TIMEOUT_SECS))
.redirect(reqwest::redirect::Policy::none())
.use_preconfigured_tls(strict_tls)
.build()
.map_err(|e| TransportError::Other(format!("build strict client: {e}")))?;
let trust_all_client = Client::builder()
.connect_timeout(std::time::Duration::from_secs(CONNECT_TIMEOUT_SECS))
.redirect(reqwest::redirect::Policy::none())
.use_preconfigured_tls(trust_all_tls)
.build()
.map_err(|e| TransportError::Other(format!("build trust-all client: {e}")))?;
Ok(Self {
inner: Arc::new(HttpTransportInner {
strict_client,
trust_all_client,
}),
})
}
}
impl Transport for HttpTransport {
    /// Sends `request` and returns its status code, headers and fully buffered body.
    ///
    /// The certificate-verifying client is used unless `request.accept_invalid_certs`
    /// is set, in which case the client that accepts any certificate is used.
    /// `request.timeout` is applied as the per-request timeout.
    ///
    /// # Errors
    ///
    /// - `TransportError::Other` if `request.method` is not a valid HTTP method.
    /// - `TransportError::Timeout` if the timeout elapses, whether while sending
    ///   the request or while reading the response body.
    /// - `TransportError::ConnectionFailed` if the connection cannot be established.
    /// - `TransportError::SendFailed` for any other failure before a response arrives.
    /// - `TransportError::ReceiveFailed` if reading the response body fails for a
    ///   reason other than the timeout.
    async fn send_request(
        &self,
        request: &TransportRequest,
    ) -> Result<TransportResponse, TransportError> {
        let client = if request.accept_invalid_certs {
            &self.inner.trust_all_client
        } else {
            &self.inner.strict_client
        };
        let method = reqwest::Method::from_bytes(request.method.as_bytes())
            .map_err(|e| TransportError::Other(format!("invalid method: {e}")))?;
        let mut builder = client.request(method, &request.url).timeout(request.timeout);
        for (k, v) in &request.headers {
            builder = builder.header(k, v);
        }
        if !request.body.is_empty() {
            builder = builder.body(request.body.clone());
        }
        let response = match builder.send().await {
            Ok(r) => r,
            Err(e) if e.is_timeout() => return Err(TransportError::Timeout),
            Err(e) if e.is_connect() => {
                return Err(TransportError::ConnectionFailed(format_error_chain(&e)))
            }
            Err(e) => return Err(TransportError::SendFailed(format_error_chain(&e))),
        };
        let status_code = response.status().as_u16();
        // Decode lossily instead of using `HeaderValue::to_str`: that rejects any
        // value with bytes outside visible ASCII, which would silently drop the
        // whole header. Invalid UTF-8 becomes U+FFFD instead.
        let headers: Vec<(String, String)> = response
            .headers()
            .iter()
            .map(|(name, value)| {
                (
                    name.to_string(),
                    String::from_utf8_lossy(value.as_bytes()).into_owned(),
                )
            })
            .collect();
        let body = match response.bytes().await {
            Ok(bytes) => bytes.to_vec(),
            // reqwest's per-request timeout also bounds the body read; report it
            // the same way as a timeout on the send path.
            Err(e) if e.is_timeout() => return Err(TransportError::Timeout),
            Err(e) => return Err(TransportError::ReceiveFailed(format_error_chain(&e))),
        };
        Ok(TransportResponse {
            status_code,
            body,
            headers,
        })
    }
}
/// Installs the `ring` rustls crypto provider as the process-wide default.
///
/// The installation is attempted at most once per process.
fn install_default_crypto_provider() {
    static INSTALL: Once = Once::new();
    INSTALL.call_once(|| {
        let provider = rustls_ring::default_provider();
        // `install_default` fails only when another provider is already the
        // process default. That provider is kept, so the error is discarded.
        let _already_installed = provider.install_default();
    });
}
/// Builds a rustls client configuration that verifies server certificates
/// against the bundled `webpki_roots` trust anchors, with rustls' safe default
/// protocol versions and no client authentication.
///
/// # Panics
///
/// Panics if the `ring` provider rejects rustls' safe default protocol versions.
fn strict_rustls_config() -> ClientConfig {
    let mut trust_anchors = RootCertStore::empty();
    trust_anchors.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

    let versioned_builder = ClientConfig::builder_with_provider(crypto_provider())
        .with_safe_default_protocol_versions()
        .expect("rustls default protocol versions");

    versioned_builder
        .with_root_certificates(trust_anchors)
        .with_no_client_auth()
}
fn trust_all_rustls_config() -> ClientConfig {
ClientConfig::builder_with_provider(crypto_provider())
.with_safe_default_protocol_versions()
.expect("rustls default protocol versions")
.dangerous()
.with_custom_certificate_verifier(Arc::new(AcceptAnyCertVerifier))
.with_no_client_auth()
}
/// Returns a newly allocated, reference-counted `ring`-backed rustls crypto
/// provider, as needed by `ClientConfig::builder_with_provider`.
fn crypto_provider() -> Arc<CryptoProvider> {
    let provider: CryptoProvider = rustls_ring::default_provider();
    Arc::from(provider)
}
/// Renders `e` followed by every error in its `source()` chain as a single
/// line, with the individual messages separated by `" | "`.
fn format_error_chain(e: &dyn std::error::Error) -> String {
    let causes = std::iter::successors(e.source(), |&cause| cause.source());
    let mut parts = vec![e.to_string()];
    parts.extend(causes.map(|cause| cause.to_string()));
    parts.join(" | ")
}
/// `ServerCertVerifier` that approves every certificate and every handshake
/// signature without inspecting them.
///
/// DANGER: this switches off TLS server authentication, so connections using it
/// are exposed to impersonation and man-in-the-middle attacks. In this file it is
/// installed only by `trust_all_rustls_config`.
#[derive(Debug)]
struct AcceptAnyCertVerifier;

impl ServerCertVerifier for AcceptAnyCertVerifier {
    /// Accepts the presented certificate chain unconditionally. The chain, the
    /// expected server name, the OCSP response and the current time are ignored.
    fn verify_server_cert(
        &self,
        _leaf: &CertificateDer<'_>,
        _chain: &[CertificateDer<'_>],
        _expected_name: &ServerName<'_>,
        _ocsp: &[u8],
        _at: UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        let verdict = ServerCertVerified::assertion();
        Ok(verdict)
    }

    /// Accepts any TLS 1.2 handshake signature without checking it.
    fn verify_tls12_signature(
        &self,
        _signed: &[u8],
        _signer: &CertificateDer<'_>,
        _signature: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        let verdict = HandshakeSignatureValid::assertion();
        Ok(verdict)
    }

    /// Accepts any TLS 1.3 handshake signature without checking it.
    fn verify_tls13_signature(
        &self,
        _signed: &[u8],
        _signer: &CertificateDer<'_>,
        _signature: &DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        let verdict = HandshakeSignatureValid::assertion();
        Ok(verdict)
    }

    /// Signature schemes this verifier claims to support, in the order listed.
    /// No signature is ever checked, so the list is simply a broad set of
    /// common RSA, ECDSA and EdDSA schemes.
    fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
        const OFFERED: [SignatureScheme; 11] = [
            SignatureScheme::RSA_PKCS1_SHA256,
            SignatureScheme::RSA_PKCS1_SHA384,
            SignatureScheme::RSA_PKCS1_SHA512,
            SignatureScheme::ECDSA_NISTP256_SHA256,
            SignatureScheme::ECDSA_NISTP384_SHA384,
            SignatureScheme::ECDSA_NISTP521_SHA512,
            SignatureScheme::RSA_PSS_SHA256,
            SignatureScheme::RSA_PSS_SHA384,
            SignatureScheme::RSA_PSS_SHA512,
            SignatureScheme::ED25519,
            SignatureScheme::ED448,
        ];
        OFFERED.to_vec()
    }
}