use std::collections::HashMap;
use std::error::Error as StdError;
use std::fmt;
#[cfg(feature = "rustls")]
use std::fs;
#[cfg(feature = "rustls")]
use std::path::Path;
use std::time::Duration;
use bytes::Bytes;
use http_body_util::BodyExt;
use hyper::{body::Incoming, Method, Request, Response, StatusCode, Uri};
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
/// Per-request overrides accepted by `request_with_options` and `post_with_options`.
///
/// Every field is optional; `None` means "use the client's setting".
#[derive(Debug, Default, Clone)]
pub struct RequestOptions {
    /// Extra headers for this request only, added after the client's default headers.
    pub headers: Option<HashMap<String, String>>,
    /// Timeout for this request only; `None` keeps the client-wide timeout.
    pub timeout: Option<Duration>,
}

impl RequestOptions {
    /// Creates options with no header or timeout overrides.
    #[must_use]
    pub fn new() -> Self {
        RequestOptions::default()
    }

    /// Sets the per-request headers, replacing any previously set on these options.
    #[must_use]
    pub fn with_headers(mut self, headers: HashMap<String, String>) -> Self {
        self.headers = Some(headers);
        self
    }

    /// Sets the per-request timeout, replacing any previously set on these options.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = Some(timeout);
        self
    }
}
/// Fully buffered response returned by the request methods.
#[derive(Debug, Clone)]
pub struct HttpResponse {
    /// Status code of the response.
    pub status: StatusCode,
    /// Response headers whose values are visible-ASCII strings; values that are not
    /// are dropped, and when a header name repeats only one of its values is kept.
    pub headers: HashMap<String, String>,
    /// The complete response body, read into memory.
    pub body: Bytes,
}
/// Errors produced by [`HttpClient`].
#[derive(Debug)]
pub enum ClientError {
    /// The operation has no WebAssembly implementation.
    WasmNotImplemented,
    /// Failure reported by hyper, e.g. while reading a response body.
    HttpError(hyper::Error),
    /// The request could not be assembled (for example an invalid header name or value).
    HttpBuildError(hyper::http::Error),
    /// Failure from the hyper-util client while sending the request.
    HttpClientError(hyper_util::client::legacy::Error),
    /// The URL string is not a valid URI.
    InvalidUri(hyper::http::uri::InvalidUri),
    /// TLS setup failure; in TLS-enabled builds request timeouts are reported here too.
    #[cfg(any(feature = "native-tls", feature = "rustls"))]
    TlsError(String),
    /// I/O failure.
    IoError(std::io::Error),
}

impl fmt::Display for ClientError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every variant except `WasmNotImplemented` renders as "<label>: <detail>".
        let (label, detail): (&str, &dyn fmt::Display) = match self {
            ClientError::WasmNotImplemented => {
                return f.write_str(
                    "Not implemented on WebAssembly (browser restricts programmatic CA trust)",
                );
            }
            ClientError::HttpError(inner) => ("HTTP error", inner),
            ClientError::HttpBuildError(inner) => ("HTTP build error", inner),
            ClientError::HttpClientError(inner) => ("HTTP client error", inner),
            ClientError::InvalidUri(inner) => ("Invalid URI", inner),
            #[cfg(any(feature = "native-tls", feature = "rustls"))]
            ClientError::TlsError(inner) => ("TLS error", inner),
            ClientError::IoError(inner) => ("IO error", inner),
        };
        write!(f, "{}: {}", label, detail)
    }
}
impl StdError for ClientError {}
impl From<hyper::Error> for ClientError {
fn from(err: hyper::Error) -> Self {
ClientError::HttpError(err)
}
}
impl From<hyper::http::uri::InvalidUri> for ClientError {
fn from(err: hyper::http::uri::InvalidUri) -> Self {
ClientError::InvalidUri(err)
}
}
impl From<std::io::Error> for ClientError {
fn from(err: std::io::Error) -> Self {
ClientError::IoError(err)
}
}
impl From<hyper::http::Error> for ClientError {
fn from(err: hyper::http::Error) -> Self {
ClientError::HttpBuildError(err)
}
}
impl From<hyper_util::client::legacy::Error> for ClientError {
fn from(err: hyper_util::client::legacy::Error) -> Self {
ClientError::HttpClientError(err)
}
}
/// Asynchronous HTTP(S) client.
///
/// All settings are fixed when the client is built; see [`HttpClientBuilder`].
/// On native targets a fresh hyper client and connector are created for every
/// request, so connections are not reused between calls.
pub struct HttpClient {
    /// Maximum wait for the response head of each request.
    timeout: Duration,
    /// Headers attached to every request.
    default_headers: HashMap<String, String>,
    /// When true, TLS certificate validation is disabled.
    #[cfg(feature = "insecure-dangerous")]
    accept_invalid_certs: bool,
    /// Extra PEM-encoded trust anchors, read by the rustls backend.
    root_ca_pem: Option<Vec<u8>>,
    /// SHA-256 digests of acceptable leaf certificates (rustls backend).
    #[cfg(feature = "rustls")]
    pinned_cert_sha256: Option<Vec<[u8; 32]>>,
}

impl HttpClient {
    /// Builds a client using the builder's default settings.
    pub fn new() -> Self {
        Self::builder().build()
    }

    /// Starts configuring a client.
    pub fn builder() -> HttpClientBuilder {
        HttpClientBuilder::new()
    }

    /// Builds a client with TLS certificate validation disabled.
    ///
    /// Despite the name, this accepts any certificate that would otherwise fail
    /// validation, not only self-signed ones. Intended for local testing only.
    #[cfg(feature = "insecure-dangerous")]
    pub fn with_self_signed_certs() -> Self {
        Self::builder().insecure_accept_invalid_certs(true).build()
    }
}
#[cfg(not(target_arch = "wasm32"))]
impl HttpClient {
#[deprecated(since = "0.4.0", note = "Use request(url, Some(options)) instead")]
pub async fn request(&self, url: &str) -> Result<HttpResponse, ClientError> {
self.request_with_options(url, None).await
}
pub async fn request_with_options(
&self,
url: &str,
options: Option<RequestOptions>,
) -> Result<HttpResponse, ClientError> {
let uri: Uri = url.parse()?;
let req = Request::builder().method(Method::GET).uri(uri);
let mut req = req;
for (key, value) in &self.default_headers {
req = req.header(key, value);
}
if let Some(options) = &options {
if let Some(headers) = &options.headers {
for (key, value) in headers {
req = req.header(key, value);
}
}
}
let req = req.body(http_body_util::Empty::<Bytes>::new())?;
if let Some(opts) = &options {
if let Some(timeout) = opts.timeout {
let client_copy = HttpClient {
timeout,
default_headers: self.default_headers.clone(),
#[cfg(feature = "insecure-dangerous")]
accept_invalid_certs: self.accept_invalid_certs,
root_ca_pem: self.root_ca_pem.clone(),
#[cfg(feature = "rustls")]
pinned_cert_sha256: self.pinned_cert_sha256.clone(),
};
client_copy.perform_request(req).await
} else {
self.perform_request(req).await
}
} else {
self.perform_request(req).await
}
}
#[deprecated(
since = "0.4.0",
note = "Use post_with_options(url, body, Some(options)) instead"
)]
pub async fn post<B: AsRef<[u8]>>(
&self,
url: &str,
body: B,
) -> Result<HttpResponse, ClientError> {
self.post_with_options(url, body, None).await
}
pub async fn post_with_options<B: AsRef<[u8]>>(
&self,
url: &str,
body: B,
options: Option<RequestOptions>,
) -> Result<HttpResponse, ClientError> {
let uri: Uri = url.parse()?;
let req = Request::builder().method(Method::POST).uri(uri);
let mut req = req;
for (key, value) in &self.default_headers {
req = req.header(key, value);
}
if let Some(options) = &options {
if let Some(headers) = &options.headers {
for (key, value) in headers {
req = req.header(key, value);
}
}
}
let body_bytes = Bytes::copy_from_slice(body.as_ref());
let req = req.body(http_body_util::Full::new(body_bytes))?;
if let Some(opts) = &options {
if let Some(timeout) = opts.timeout {
let client_copy = HttpClient {
timeout,
default_headers: self.default_headers.clone(),
#[cfg(feature = "insecure-dangerous")]
accept_invalid_certs: self.accept_invalid_certs,
root_ca_pem: self.root_ca_pem.clone(),
#[cfg(feature = "rustls")]
pinned_cert_sha256: self.pinned_cert_sha256.clone(),
};
client_copy.perform_request(req).await
} else {
self.perform_request(req).await
}
} else {
self.perform_request(req).await
}
}
async fn perform_request<B>(&self, req: Request<B>) -> Result<HttpResponse, ClientError>
where
B: hyper::body::Body + Send + 'static + Unpin,
B::Data: Send,
B::Error: Into<Box<dyn std::error::Error + Send + Sync>>,
{
#[cfg(feature = "native-tls")]
{
#[cfg(feature = "insecure-dangerous")]
if self.accept_invalid_certs {
let mut http_connector = hyper_util::client::legacy::connect::HttpConnector::new();
http_connector.enforce_http(false);
let mut tls_builder = native_tls::TlsConnector::builder();
tls_builder.danger_accept_invalid_certs(true);
let tls_connector = tls_builder.build().map_err(|e| {
ClientError::TlsError(format!("Failed to build TLS connector: {}", e))
})?;
let tokio_connector = tokio_native_tls::TlsConnector::from(tls_connector);
let connector = hyper_tls::HttpsConnector::from((http_connector, tokio_connector));
let client = Client::builder(TokioExecutor::new()).build(connector);
let resp = tokio::time::timeout(self.timeout, client.request(req))
.await
.map_err(|_| ClientError::TlsError("Request timed out".to_string()))??;
return self.build_response(resp).await;
}
let connector = hyper_tls::HttpsConnector::new();
let client = Client::builder(TokioExecutor::new()).build(connector);
let resp = tokio::time::timeout(self.timeout, client.request(req))
.await
.map_err(|_| ClientError::TlsError("Request timed out".to_string()))??;
self.build_response(resp).await
}
#[cfg(all(feature = "rustls", not(feature = "native-tls")))]
{
let mut root_cert_store = rustls::RootCertStore::empty();
let native_certs = rustls_native_certs::load_native_certs();
for cert in &native_certs.certs {
if let Err(e) = root_cert_store.add(cert.clone()) {
return Err(ClientError::TlsError(format!(
"Failed to add native cert to root store: {}",
e
)));
}
}
if let Some(ref pem_bytes) = self.root_ca_pem {
let mut reader = std::io::Cursor::new(pem_bytes);
for cert_result in rustls_pemfile::certs(&mut reader) {
match cert_result {
Ok(cert) => {
root_cert_store.add(cert).map_err(|e| {
ClientError::TlsError(format!(
"Failed to add custom cert to root store: {}",
e
))
})?;
}
Err(e) => {
return Err(ClientError::TlsError(format!(
"Failed to parse PEM cert: {}",
e
)));
}
}
}
}
let mut config_builder =
rustls::ClientConfig::builder().with_root_certificates(root_cert_store);
let rustls_config = config_builder.with_no_client_auth();
#[cfg(feature = "insecure-dangerous")]
let rustls_config = if self.accept_invalid_certs {
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified};
use rustls::pki_types::UnixTime;
use rustls::DigitallySignedStruct;
use rustls::SignatureScheme;
use std::sync::Arc;
#[derive(Debug)]
struct NoCertificateVerification {}
impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification {
fn verify_server_cert(
&self,
_end_entity: &rustls::pki_types::CertificateDer<'_>,
_intermediates: &[rustls::pki_types::CertificateDer<'_>],
_server_name: &rustls::pki_types::ServerName<'_>,
_ocsp_response: &[u8],
_now: UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
Ok(ServerCertVerified::assertion())
}
fn verify_tls12_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn verify_tls13_signature(
&self,
_message: &[u8],
_cert: &rustls::pki_types::CertificateDer<'_>,
_dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
Ok(HandshakeSignatureValid::assertion())
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
vec![
SignatureScheme::RSA_PKCS1_SHA1,
SignatureScheme::ECDSA_SHA1_Legacy,
SignatureScheme::RSA_PKCS1_SHA256,
SignatureScheme::ECDSA_NISTP256_SHA256,
SignatureScheme::RSA_PKCS1_SHA384,
SignatureScheme::ECDSA_NISTP384_SHA384,
SignatureScheme::RSA_PKCS1_SHA512,
SignatureScheme::ECDSA_NISTP521_SHA512,
SignatureScheme::RSA_PSS_SHA256,
SignatureScheme::RSA_PSS_SHA384,
SignatureScheme::RSA_PSS_SHA512,
SignatureScheme::ED25519,
SignatureScheme::ED448,
]
}
}
let mut config = rustls_config.clone();
config
.dangerous()
.set_certificate_verifier(Arc::new(NoCertificateVerification {}));
config
} else {
rustls_config
};
#[cfg(feature = "rustls")]
let rustls_config = if let Some(ref pins) = self.pinned_cert_sha256 {
use rustls::client::danger::{
HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier,
};
use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
use rustls::DigitallySignedStruct;
use rustls::SignatureScheme;
use std::sync::Arc;
struct CertificatePinner {
pins: Vec<[u8; 32]>,
inner: Arc<dyn ServerCertVerifier>,
}
impl ServerCertVerifier for CertificatePinner {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
intermediates: &[CertificateDer<'_>],
server_name: &ServerName<'_>,
ocsp_response: &[u8],
now: UnixTime,
) -> Result<ServerCertVerified, rustls::Error> {
self.inner.verify_server_cert(
end_entity,
intermediates,
server_name,
ocsp_response,
now,
)?;
use sha2::{Digest, Sha256};
let mut hasher = Sha256::new();
hasher.update(end_entity.as_ref());
let cert_hash = hasher.finalize();
for pin in &self.pins {
if pin[..] == cert_hash[..] {
return Ok(ServerCertVerified::assertion());
}
}
Err(rustls::Error::General(
"Certificate pin verification failed".into(),
))
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls12_signature(message, cert, dss)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> Result<HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls13_signature(message, cert, dss)
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
self.inner.supported_verify_schemes()
}
}
let mut config = rustls_config.clone();
let default_verifier = rustls::client::WebPkiServerVerifier::builder()
.with_root_certificates(root_cert_store.clone())
.build()
.map_err(|e| {
ClientError::TlsError(format!(
"Failed to build certificate verifier: {}",
e
))
})?;
let cert_pinner = Arc::new(CertificatePinner {
pins: pins.clone(),
inner: default_verifier,
});
config.dangerous().set_certificate_verifier(cert_pinner);
config
} else {
rustls_config
};
let mut http_connector = hyper_util::client::legacy::connect::HttpConnector::new();
http_connector.enforce_http(false);
let https_connector = hyper_rustls::HttpsConnectorBuilder::new()
.with_tls_config(rustls_config)
.https_or_http()
.enable_http1()
.build();
let client = Client::builder(TokioExecutor::new()).build(https_connector);
let resp = tokio::time::timeout(self.timeout, client.request(req))
.await
.map_err(|_| ClientError::TlsError("Request timed out".to_string()))??;
self.build_response(resp).await
}
#[cfg(not(any(feature = "native-tls", feature = "rustls")))]
{
let connector = hyper_util::client::legacy::connect::HttpConnector::new();
let client = Client::builder(TokioExecutor::new()).build(connector);
let resp = tokio::time::timeout(self.timeout, client.request(req))
.await
.map_err(|_| ClientError::TlsError("Request timed out".to_string()))??;
self.build_response(resp).await
}
}
/// Reads `resp` completely into an owned [`HttpResponse`].
///
/// Header values for which `to_str` fails (not visible ASCII) are skipped. When a
/// header name repeats, the value seen last replaces earlier ones. The whole body
/// is buffered in memory.
async fn build_response(&self, resp: Response<Incoming>) -> Result<HttpResponse, ClientError> {
    let (parts, body) = resp.into_parts();
    let headers: HashMap<String, String> = parts
        .headers
        .iter()
        .filter_map(|(name, value)| {
            value
                .to_str()
                .ok()
                .map(|text| (name.to_string(), text.to_string()))
        })
        .collect();
    let collected = body.collect().await?;
    Ok(HttpResponse {
        status: parts.status,
        headers,
        body: collected.to_bytes(),
    })
}
}
#[cfg(target_arch = "wasm32")]
impl HttpClient {
    /// Unsupported on WebAssembly; always returns [`ClientError::WasmNotImplemented`].
    #[deprecated(
        since = "0.4.0",
        note = "Use request_with_options(url, Some(options)) instead"
    )]
    pub fn request(&self, url: &str) -> Result<(), ClientError> {
        self.request_with_options(url, None)
    }

    /// Unsupported on WebAssembly; always returns [`ClientError::WasmNotImplemented`].
    pub fn request_with_options(
        &self,
        _url: &str,
        _options: Option<RequestOptions>,
    ) -> Result<(), ClientError> {
        Err(ClientError::WasmNotImplemented)
    }

    /// Unsupported on WebAssembly; always returns [`ClientError::WasmNotImplemented`].
    #[deprecated(
        since = "0.4.0",
        note = "Use post_with_options(url, body, Some(options)) instead"
    )]
    pub fn post<B: AsRef<[u8]>>(&self, url: &str, body: B) -> Result<(), ClientError> {
        self.post_with_options(url, body, None)
    }

    /// Unsupported on WebAssembly; always returns [`ClientError::WasmNotImplemented`].
    pub fn post_with_options<B: AsRef<[u8]>>(
        &self,
        _url: &str,
        _body: B,
        _options: Option<RequestOptions>,
    ) -> Result<(), ClientError> {
        Err(ClientError::WasmNotImplemented)
    }
}
/// Configures and creates an [`HttpClient`].
pub struct HttpClientBuilder {
    timeout: Duration,
    default_headers: HashMap<String, String>,
    #[cfg(feature = "insecure-dangerous")]
    accept_invalid_certs: bool,
    root_ca_pem: Option<Vec<u8>>,
    #[cfg(feature = "rustls")]
    pinned_cert_sha256: Option<Vec<[u8; 32]>>,
}

impl HttpClientBuilder {
    /// Creates a builder with a 30-second timeout, no default headers, no extra
    /// CA certificates and no certificate pins. With `insecure-dangerous`,
    /// certificate validation is left enabled.
    pub fn new() -> Self {
        Self {
            timeout: Duration::from_secs(30),
            default_headers: HashMap::new(),
            #[cfg(feature = "insecure-dangerous")]
            accept_invalid_certs: false,
            root_ca_pem: None,
            #[cfg(feature = "rustls")]
            pinned_cert_sha256: None,
        }
    }

    /// Sets the client-wide timeout for receiving each response head (default 30 s).
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Sets the headers sent with every request, replacing any set earlier on this builder.
    pub fn with_default_headers(mut self, headers: HashMap<String, String>) -> Self {
        self.default_headers = headers;
        self
    }

    /// Enables (`true`) or disables TLS certificate validation bypass.
    ///
    /// When enabled the client accepts certificates that fail validation. Only
    /// use this against endpoints you control.
    #[cfg(feature = "insecure-dangerous")]
    pub fn insecure_accept_invalid_certs(mut self, accept: bool) -> Self {
        self.accept_invalid_certs = accept;
        self
    }

    /// Trusts the PEM-encoded CA certificates in `pem_bytes` in addition to the
    /// platform roots, replacing any PEM set earlier on this builder.
    ///
    /// The bytes are stored as-is and parsed by the rustls backend when a request
    /// is made, so malformed PEM surfaces as a `ClientError::TlsError` from the
    /// request rather than here.
    #[cfg(feature = "rustls")]
    pub fn with_root_ca_pem(mut self, pem_bytes: &[u8]) -> Self {
        self.root_ca_pem = Some(pem_bytes.to_vec());
        self
    }

    /// Reads PEM-encoded CA certificates from `path`; see [`Self::with_root_ca_pem`].
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be read. Use [`Self::try_with_root_ca_file`] to
    /// receive the I/O error instead.
    #[cfg(feature = "rustls")]
    pub fn with_root_ca_file<P: AsRef<Path>>(self, path: P) -> Self {
        let path = path.as_ref();
        self.try_with_root_ca_file(path).unwrap_or_else(|e| {
            panic!(
                "Failed to read CA certificate file '{}': {}",
                path.display(),
                e
            )
        })
    }

    /// Fallible form of [`Self::with_root_ca_file`].
    ///
    /// # Errors
    ///
    /// Returns the I/O error from reading `path`; the builder is consumed in that case.
    #[cfg(feature = "rustls")]
    pub fn try_with_root_ca_file<P: AsRef<Path>>(mut self, path: P) -> std::io::Result<Self> {
        self.root_ca_pem = Some(fs::read(path)?);
        Ok(self)
    }

    /// Pins acceptable server certificates by the SHA-256 digest of their DER encoding.
    ///
    /// With the rustls backend, a connection succeeds only if normal validation
    /// passes and the leaf certificate's digest equals one of `pins`. An empty
    /// list therefore rejects every server.
    #[cfg(feature = "rustls")]
    pub fn with_pinned_cert_sha256(mut self, pins: Vec<[u8; 32]>) -> Self {
        self.pinned_cert_sha256 = Some(pins);
        self
    }

    /// Finalizes the configuration into an [`HttpClient`].
    pub fn build(self) -> HttpClient {
        HttpClient {
            timeout: self.timeout,
            default_headers: self.default_headers,
            #[cfg(feature = "insecure-dangerous")]
            accept_invalid_certs: self.accept_invalid_certs,
            root_ca_pem: self.root_ca_pem,
            #[cfg(feature = "rustls")]
            pinned_cert_sha256: self.pinned_cert_sha256,
        }
    }
}
impl Default for HttpClient {
fn default() -> Self {
Self::new()
}
}
impl Default for HttpClientBuilder {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_default_builds() {
        let client = HttpClient::builder().build();
        assert_eq!(client.timeout, Duration::from_secs(30));
        assert!(client.default_headers.is_empty());
        assert!(client.root_ca_pem.is_none());
    }

    #[test]
    fn builder_allows_timeout_and_headers() {
        let mut headers = HashMap::new();
        headers.insert("x-test".to_string(), "1".to_string());
        let builder = HttpClient::builder()
            .with_timeout(Duration::from_secs(5))
            .with_default_headers(headers.clone());
        #[cfg(feature = "rustls")]
        let builder = builder.with_root_ca_pem(b"-----BEGIN CERTIFICATE-----\n...");
        let client = builder.build();
        assert_eq!(client.timeout, Duration::from_secs(5));
        assert_eq!(client.default_headers, headers);
        #[cfg(feature = "rustls")]
        assert!(client.root_ca_pem.is_some());
    }

    #[cfg(feature = "insecure-dangerous")]
    #[test]
    fn builder_allows_insecure_when_feature_enabled() {
        let client = HttpClient::builder()
            .insecure_accept_invalid_certs(true)
            .build();
        assert!(client.accept_invalid_certs);
        let client2 = HttpClient::with_self_signed_certs();
        assert!(client2.accept_invalid_certs);
    }

    // An unparsable URL is rejected before any connection is attempted, so these
    // tests exercise the public request path without network access.
    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn request_rejects_invalid_uri() {
        let client = HttpClient::builder().build();
        let result = client.request_with_options("not a uri", None).await;
        assert!(matches!(result, Err(ClientError::InvalidUri(_))));
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[tokio::test]
    async fn post_rejects_invalid_uri() {
        let client = HttpClient::builder().build();
        let result = client
            .post_with_options("not a uri", b"payload", None)
            .await;
        assert!(matches!(result, Err(ClientError::InvalidUri(_))));
    }

    #[cfg(all(feature = "rustls", not(target_arch = "wasm32")))]
    #[test]
    fn builder_allows_root_ca_file() {
        use std::fs;
        use std::io::Write;
        // Process-unique name so concurrent test runs don't clobber each other's file.
        let cert_file =
            std::env::temp_dir().join(format!("test-ca-{}.pem", std::process::id()));
        let test_cert = b"-----BEGIN CERTIFICATE-----
MIICxjCCAa4CAQAwDQYJKoZIhvcNAQELBQAwEjEQMA4GA1UEAwwHVGVzdCBDQTAe
Fw0yNTA4MTQwMDAwMDBaFw0yNjA4MTQwMDAwMDBaMBIxEDAOBgNVBAMMB1Rlc3Qg
Q0EwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDTest...
-----END CERTIFICATE-----";
        {
            let mut file = fs::File::create(&cert_file).expect("Failed to create temp cert file");
            file.write_all(test_cert)
                .expect("Failed to write cert to temp file");
        }
        let client = HttpClient::builder().with_root_ca_file(&cert_file).build();
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = fs::remove_file(&cert_file);
        assert_eq!(client.root_ca_pem.as_deref(), Some(&test_cert[..]));
    }

    #[cfg(all(feature = "rustls", not(target_arch = "wasm32")))]
    #[test]
    #[should_panic(expected = "Failed to read CA certificate file")]
    fn builder_panics_on_missing_root_ca_file() {
        let missing = std::env::temp_dir().join(format!(
            "missing-ca-{}-does-not-exist.pem",
            std::process::id()
        ));
        let _ = HttpClient::builder().with_root_ca_file(&missing);
    }
}