use std::sync::Arc;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use base64::Engine as _;
use reqwest::header::HeaderValue;
use zeroize::Zeroizing;
use crate::error::ClientError;
/// Opaque handle around a configured [`reqwest::Client`].
///
/// Custom [`TransportConfig`] implementations create one with
/// [`HttpClient::new`]; only code inside this crate can get the inner client
/// back out, so the wrapper is the whole public surface.
#[non_exhaustive]
pub struct HttpClient(reqwest::Client);

impl HttpClient {
    /// Wraps an already-configured `reqwest` client.
    pub fn new(client: reqwest::Client) -> Self {
        HttpClient(client)
    }

    /// Consumes the handle and returns the wrapped client (crate-internal).
    pub(crate) fn into_inner(self) -> reqwest::Client {
        let HttpClient(inner) = self;
        inner
    }
}

impl std::fmt::Debug for HttpClient {
    /// Prints only the type name; nothing about the inner client is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("HttpClient")
    }
}
/// Strategy for producing the [`HttpClient`] this crate sends requests with.
///
/// The `Send + Sync` bound lets one configuration be shared between threads.
pub trait TransportConfig: Send + Sync {
    /// Builds a new [`HttpClient`].
    ///
    /// # Errors
    ///
    /// Returns a [`ClientError`] when the underlying client cannot be built.
    fn build_client(&self) -> Result<HttpClient, ClientError>;
}

/// Transport that uses the crate's default `reqwest` client configuration.
#[derive(Debug, Clone)]
pub struct DefaultTransport;

impl TransportConfig for DefaultTransport {
    fn build_client(&self) -> Result<HttpClient, ClientError> {
        let client = default_reqwest_client()?;
        Ok(HttpClient::new(client))
    }
}
/// Transport that trusts a single caller-supplied CA certificate.
///
/// The certificate is kept as raw DER bytes and is only parsed when
/// [`TransportConfig::build_client`] runs.
#[derive(Clone)]
pub struct CustomCaTransport {
    der_cert: Vec<u8>,
}

impl CustomCaTransport {
    /// Creates a transport from DER-encoded certificate bytes.
    ///
    /// No validation happens here; malformed bytes are reported by
    /// [`TransportConfig::build_client`].
    pub fn new(der_cert: Vec<u8>) -> Self {
        Self { der_cert }
    }

    /// Creates a transport from PEM input, using the certificate extracted by
    /// `parse_first_pem_cert`.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::InvalidArgument`] when no certificate can be
    /// extracted from `pem_bytes`.
    pub fn from_pem_bytes(pem_bytes: &[u8]) -> Result<Self, ClientError> {
        match parse_first_pem_cert(pem_bytes) {
            Some(der_cert) => Ok(Self::new(der_cert)),
            None => Err(ClientError::InvalidArgument(
                "CustomCaTransport::from_pem_bytes: no PEM-framed certificate found in input"
                    .into(),
            )),
        }
    }
}
/// Opening delimiter shared by every PEM frame (RFC 7468 §2).
const PEM_BEGIN: &str = "-----BEGIN ";

/// Lazily walks `text`, yielding `(label, body)` for each complete PEM frame.
///
/// Text between frames (comments, blank lines) is skipped. Iteration stops at
/// the first `BEGIN` line that has no terminating newline, or whose matching
/// `-----END <label>-----` marker cannot be found.
fn pem_frames<'a>(text: &'a str) -> impl Iterator<Item = (&'a str, &'a str)> + 'a {
    let mut rest = text;
    std::iter::from_fn(move || {
        let current: &'a str = rest;
        let begin_idx = current.find(PEM_BEGIN)?;
        let after_begin = &current[begin_idx + PEM_BEGIN.len()..];
        let begin_eol = after_begin.find('\n')?;
        // `trim` also drops a trailing '\r', so CRLF files parse too.
        let label = after_begin[..begin_eol].trim().trim_end_matches('-').trim();
        let body_region = &after_begin[begin_eol + 1..];
        let end_marker = format!("-----END {label}-----");
        let end_offset = body_region.find(end_marker.as_str())?;
        let body = &body_region[..end_offset];
        rest = &body_region[end_offset + end_marker.len()..];
        Some((label, body))
    })
}

/// Reports whether a PEM label frames an X.509 certificate.
///
/// RFC 7468 §5.1 specifies `CERTIFICATE`; `X509 CERTIFICATE` and
/// `X.509 CERTIFICATE` are historical spellings it notes may still be
/// encountered. Any other label (`PRIVATE KEY`, `CERTIFICATE REQUEST`, ...) is
/// not a certificate and must never be handed to the root store.
fn is_pem_cert_label(label: &str) -> bool {
    matches!(label, "CERTIFICATE" | "X509 CERTIFICATE" | "X.509 CERTIFICATE")
}

/// Removes all whitespace from a PEM body and base64-decodes it.
/// Returns `None` when the body is not valid standard base64.
fn decode_pem_body(body: &str) -> Option<Vec<u8>> {
    use base64::Engine as _;
    let compact: String = body.chars().filter(|c| !c.is_whitespace()).collect();
    base64::engine::general_purpose::STANDARD
        .decode(compact)
        .ok()
}

/// Returns the DER bytes of the first certificate frame in `input`.
///
/// Frames with non-certificate labels are skipped, so a combined PEM file with
/// a private key ahead of the certificate still yields the certificate.
/// Returns `None` when `input` is not UTF-8, holds no complete certificate
/// frame, or the first certificate frame's body is not valid base64. The DER
/// itself is not validated here.
fn parse_first_pem_cert(input: &[u8]) -> Option<Vec<u8>> {
    let text = std::str::from_utf8(input).ok()?;
    let (_, body) = pem_frames(text).find(|&(label, _)| is_pem_cert_label(label))?;
    decode_pem_body(body)
}

/// Returns the DER bytes of every certificate frame in `input`, in order.
///
/// Frames with non-certificate labels are skipped. Certificate frames whose
/// body is not valid base64 are also skipped. Non-UTF-8 input yields an empty
/// vector.
fn parse_all_pem_certs(input: &[u8]) -> Vec<Vec<u8>> {
    let Ok(text) = std::str::from_utf8(input) else {
        return Vec::new();
    };
    pem_frames(text)
        .filter(|&(label, _)| is_pem_cert_label(label))
        .filter_map(|(_, body)| decode_pem_body(body))
        .collect()
}
impl std::fmt::Debug for CustomCaTransport {
    /// Prints only the DER length, so raw certificate bytes are not rendered.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let byte_count = self.der_cert.len();
        let mut out = f.debug_struct("CustomCaTransport");
        out.field("der_cert", &format_args!("<{byte_count} bytes>"));
        out.finish()
    }
}
impl TransportConfig for CustomCaTransport {
    /// Builds a client whose only trust anchor is the configured certificate.
    /// `reqwest`'s built-in roots are disabled, and the connect timeout is 10
    /// seconds.
    ///
    /// # Errors
    ///
    /// Fails when the stored bytes are not a valid DER certificate or the
    /// client cannot be built.
    fn build_client(&self) -> Result<HttpClient, ClientError> {
        let root = reqwest::Certificate::from_der(&self.der_cert)
            .map_err(ClientError::from_reqwest)?;
        reqwest::ClientBuilder::new()
            .connect_timeout(std::time::Duration::from_secs(10))
            .tls_built_in_root_certs(false)
            .add_root_certificate(root)
            .build()
            .map(HttpClient::new)
            .map_err(ClientError::from_reqwest)
    }
}
/// Builder for a [`BuilderTransport`]: custom trust roots plus an optional
/// client certificate for mutual TLS.
///
/// Every method takes and returns the builder by value, so calls chain.
#[derive(Default)]
pub struct CustomTransportBuilder {
    roots_der: Vec<Vec<u8>>,
    client_identity: Option<(Vec<u8>, Vec<u8>)>,
}

impl CustomTransportBuilder {
    /// Creates a builder with no roots and no client certificate.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds one DER-encoded root certificate.
    ///
    /// The bytes are not validated here; malformed DER is reported by
    /// [`TransportConfig::build_client`].
    #[must_use]
    pub fn add_root_der(mut self, der: Vec<u8>) -> Self {
        self.roots_der.push(der);
        self
    }

    /// Adds the certificate extracted from `pem` (via `parse_first_pem_cert`)
    /// as a root.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::InvalidArgument`] when no certificate can be
    /// extracted from `pem`.
    pub fn add_root_pem(self, pem: &[u8]) -> Result<Self, ClientError> {
        let der = parse_first_pem_cert(pem).ok_or_else(|| {
            ClientError::InvalidArgument(
                "CustomTransportBuilder::add_root_pem: no PEM-framed certificate found in input"
                    .into(),
            )
        })?;
        Ok(self.add_root_der(der))
    }

    /// Adds every certificate extracted from a PEM bundle (via
    /// `parse_all_pem_certs`) as a root.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::InvalidArgument`] when the bundle yields no
    /// certificates at all.
    pub fn add_roots_pem_bundle(mut self, pem_bundle: &[u8]) -> Result<Self, ClientError> {
        let ders = parse_all_pem_certs(pem_bundle);
        if ders.is_empty() {
            return Err(ClientError::InvalidArgument(
                "CustomTransportBuilder::add_roots_pem_bundle: no PEM-framed \
                 certificates found in bundle"
                    .into(),
            ));
        }
        self.roots_der.extend(ders);
        Ok(self)
    }

    /// Sets the client certificate and private key (both PEM) for mutual TLS,
    /// replacing any pair set earlier.
    ///
    /// Parsing is deferred to [`TransportConfig::build_client`], so malformed
    /// input is reported there.
    #[must_use]
    pub fn with_client_cert(mut self, cert_pem: Vec<u8>, key_pem: Vec<u8>) -> Self {
        self.client_identity = Some((cert_pem, key_pem));
        self
    }

    /// Finishes configuration and returns the transport.
    #[must_use]
    pub fn build(self) -> BuilderTransport {
        BuilderTransport {
            roots_der: self.roots_der,
            client_identity: self.client_identity,
        }
    }
}

impl std::fmt::Debug for CustomTransportBuilder {
    /// Shows only the root count and whether a client certificate is set.
    /// A derived `Debug` would print `key_pem`, i.e. the private key.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let identity = if self.client_identity.is_some() {
            "client cert configured"
        } else {
            "no client cert"
        };
        f.debug_struct("CustomTransportBuilder")
            .field(
                "roots_der",
                &format_args!("<{} root cert(s)>", self.roots_der.len()),
            )
            .field("client_identity", &format_args!("<{identity}>"))
            .finish()
    }
}
/// Transport produced by [`CustomTransportBuilder::build`].
///
/// Holds DER root certificates and an optional `(certificate PEM, key PEM)`
/// pair. Nothing is parsed until [`TransportConfig::build_client`] runs.
#[derive(Clone)]
pub struct BuilderTransport {
    roots_der: Vec<Vec<u8>>,
    client_identity: Option<(Vec<u8>, Vec<u8>)>,
}

impl std::fmt::Debug for BuilderTransport {
    /// Summarizes the configuration without printing certificate or key bytes.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let root_count = self.roots_der.len();
        let identity_state = match self.client_identity {
            Some(_) => "client cert configured",
            None => "no client cert",
        };
        f.debug_struct("BuilderTransport")
            .field("roots_der", &format_args!("<{root_count} root cert(s)>"))
            .field("client_identity", &format_args!("<{identity_state}>"))
            .finish()
    }
}
impl TransportConfig for BuilderTransport {
fn build_client(&self) -> Result<HttpClient, ClientError> {
let mut builder = reqwest::ClientBuilder::new()
.connect_timeout(std::time::Duration::from_secs(10))
.tls_built_in_root_certs(false);
for der in &self.roots_der {
let cert = reqwest::Certificate::from_der(der).map_err(ClientError::from_reqwest)?;
builder = builder.add_root_certificate(cert);
}
if let Some((cert_pem, key_pem)) = &self.client_identity {
let mut bundle = Vec::with_capacity(cert_pem.len() + key_pem.len() + 1);
bundle.extend_from_slice(cert_pem);
if !cert_pem.ends_with(b"\n") {
bundle.push(b'\n');
}
bundle.extend_from_slice(key_pem);
let identity =
reqwest::Identity::from_pem(&bundle).map_err(ClientError::from_reqwest)?;
builder = builder.identity(identity);
}
let client = builder.build().map_err(ClientError::from_reqwest)?;
Ok(HttpClient::new(client))
}
}
/// A borrowed `(name, value)` pair describing an authentication header.
///
/// The value is a credential. It is only reachable through
/// [`AuthHeader::expose_value`], and it is redacted from `Debug` output.
#[non_exhaustive]
#[derive(Clone, Copy)]
pub struct AuthHeader<'a> {
    name: &'a str,
    value: &'a str,
}

impl<'a> AuthHeader<'a> {
    /// Pairs a header name with its (secret) value.
    pub fn new(name: &'a str, value: &'a str) -> Self {
        AuthHeader { name, value }
    }

    /// Returns the header name.
    pub fn name(&self) -> &'a str {
        let AuthHeader { name, .. } = *self;
        name
    }

    /// Returns the raw credential. The explicit method name makes every call
    /// site that reads the secret easy to find.
    pub fn expose_value(&self) -> &'a str {
        let AuthHeader { value, .. } = *self;
        value
    }
}

impl std::fmt::Debug for AuthHeader<'_> {
    /// Shows the header name and a fixed placeholder in place of the value.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        const REDACTED: &str = "[REDACTED]";
        let mut out = f.debug_struct("AuthHeader");
        out.field("name", &self.name);
        out.field("value", &REDACTED);
        out.finish()
    }
}
/// Produces an optional authentication header.
///
/// The `Send + Sync` bound lets one provider be shared between threads.
pub trait AuthProvider: Send + Sync {
    /// Returns the header to use, or `None` when no authentication applies.
    fn auth_header(&self) -> Option<AuthHeader<'_>>;
}

/// Provider that never produces an authentication header.
#[derive(Debug, Clone)]
pub struct NoneAuth;

impl AuthProvider for NoneAuth {
    fn auth_header(&self) -> Option<AuthHeader<'_>> {
        Option::None
    }
}
/// Bearer-token authentication (RFC 6750).
///
/// Stores the complete `Bearer <token>` header value in a buffer that is
/// zeroized on drop.
#[derive(Clone)]
pub struct BearerAuth {
    header_string: Zeroizing<String>,
}

impl BearerAuth {
    /// Validates `token` and precomputes the header value.
    ///
    /// # Errors
    ///
    /// - [`ClientError::InvalidArgument`] when `token` is empty or contains
    ///   ASCII whitespace.
    /// - The error from `ClientError::from_invalid_header` when the resulting
    ///   value is not a legal HTTP header value (e.g. it contains a control
    ///   character).
    pub fn new(token: &str) -> Result<Self, ClientError> {
        let malformed = token.is_empty() || token.bytes().any(|b| b.is_ascii_whitespace());
        if malformed {
            return Err(ClientError::InvalidArgument(
                "BearerAuth token may not be empty or contain whitespace (RFC 6750 §2.1)".into(),
            ));
        }
        let header_string = Zeroizing::new(format!("Bearer {token}"));
        // Validation only; the parsed `HeaderValue` itself is not kept.
        HeaderValue::from_str(header_string.as_str()).map_err(ClientError::from_invalid_header)?;
        Ok(BearerAuth { header_string })
    }
}

impl std::fmt::Debug for BearerAuth {
    /// Never prints the token.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("BearerAuth");
        out.field("token", &"[REDACTED]");
        out.finish()
    }
}

impl AuthProvider for BearerAuth {
    fn auth_header(&self) -> Option<AuthHeader<'_>> {
        let value: &str = self.header_string.as_str();
        Some(AuthHeader::new("authorization", value))
    }
}
/// HTTP Basic authentication (RFC 7617).
///
/// Only the finished `Basic <base64>` header value is stored, in a buffer that
/// is zeroized on drop. The intermediate plaintext and base64 buffers are
/// zeroized as well.
#[derive(Clone)]
pub struct BasicAuth {
    header_string: Zeroizing<String>,
}

impl BasicAuth {
    /// Encodes `username:password` and precomputes the header value.
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::InvalidArgument`] when `username` contains `:`.
    /// The first colon separates user-id from password, so RFC 7617 §2 treats
    /// such a user-id as invalid.
    pub fn new(username: &str, password: &str) -> Result<Self, ClientError> {
        if username.contains(':') {
            return Err(ClientError::InvalidArgument(
                "BasicAuth username may not contain ':'".into(),
            ));
        }
        let plaintext = Zeroizing::new(format!("{username}:{password}"));
        // Base64 is trivially reversible, so the encoded form is as sensitive
        // as the plaintext. Wrapping it ensures it is wiped rather than
        // simply freed.
        let encoded = Zeroizing::new(BASE64_STANDARD.encode(plaintext.as_bytes()));
        let header_string = Zeroizing::new(format!("Basic {}", encoded.as_str()));
        HeaderValue::from_str(&header_string).map_err(ClientError::from_invalid_header)?;
        Ok(Self { header_string })
    }
}

impl std::fmt::Debug for BasicAuth {
    /// Never prints the username, password, or their encoding.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BasicAuth")
            .field("credentials", &"[REDACTED]")
            .finish()
    }
}

impl AuthProvider for BasicAuth {
    fn auth_header(&self) -> Option<AuthHeader<'_>> {
        Some(AuthHeader::new("authorization", &self.header_string))
    }
}
/// Builds the crate's baseline `reqwest` client: library defaults plus a
/// 10-second connect timeout.
fn default_reqwest_client() -> Result<reqwest::Client, ClientError> {
    const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
    let builder = reqwest::ClientBuilder::new().connect_timeout(CONNECT_TIMEOUT);
    builder.build().map_err(ClientError::from_reqwest)
}
/// Forwards to the boxed transport, so `Box<dyn TransportConfig>` can be used
/// wherever a `TransportConfig` is expected.
impl TransportConfig for Box<dyn TransportConfig> {
    fn build_client(&self) -> Result<HttpClient, ClientError> {
        // Dereference explicitly down to the trait object. A plain coercion
        // from `&Box<dyn TransportConfig>` could instead unsize the Box itself
        // (it implements the trait through this very impl) and recurse.
        let inner: &dyn TransportConfig = &**self;
        inner.build_client()
    }
}

/// Forwards to the shared provider behind the `Arc`.
impl AuthProvider for Arc<dyn AuthProvider> {
    fn auth_header(&self) -> Option<AuthHeader<'_>> {
        // Explicit `&**` for the same reason as the `Box` transport impl above.
        let inner: &dyn AuthProvider = &**self;
        inner.auth_header()
    }
}

/// Forwards to the boxed provider.
impl AuthProvider for Box<dyn AuthProvider> {
    fn auth_header(&self) -> Option<AuthHeader<'_>> {
        let inner: &dyn AuthProvider = &**self;
        inner.auth_header()
    }
}
#[cfg(test)]
mod tests {
    // Unit tests for the transport and auth types in this module. Several TLS
    // tests read the fixtures under `tests/fixtures/tls/`: `test-ca.pem`, and
    // `test-ca.der` as its reference DER form.
    use super::*;

    // ---- Auth providers: construction and header values --------------------

    #[test]
    fn none_auth_no_header() {
        assert!(NoneAuth.auth_header().is_none());
    }

    #[test]
    fn bearer_auth_valid_constructs() {
        assert!(BearerAuth::new("tok123").is_ok());
    }

    #[test]
    fn bearer_auth_header() {
        let auth = BearerAuth::new("tok123").expect("valid ASCII token must construct");
        let header = auth.auth_header().expect("BearerAuth must return a header");
        assert_eq!(header.name(), "authorization");
        assert_eq!(header.expose_value(), "Bearer tok123");
    }

    #[test]
    fn bearer_auth_invalid_token_rejected() {
        // \x01 is not whitespace, so this exercises the HeaderValue check
        // rather than the explicit whitespace check.
        let result = BearerAuth::new("tok\x01abc");
        assert!(
            result.is_err(),
            "token with C0 control character must be rejected by constructor"
        );
    }

    #[test]
    fn basic_auth_valid_constructs() {
        assert!(BasicAuth::new("alice", "s3cr3t").is_ok());
    }

    #[test]
    fn basic_auth_colon_in_username_rejected() {
        let result = BasicAuth::new("ali:ce", "s3cr3t");
        match result {
            Ok(_) => panic!("username with colon must be rejected by constructor"),
            Err(e) => {
                let err_msg = e.to_string();
                assert!(
                    err_msg.contains("username"),
                    "error message should mention 'username', got: {err_msg}"
                );
            }
        }
    }

    #[test]
    fn basic_auth_header() {
        // "YWxpY2U6czNjcjN0" is base64("alice:s3cr3t").
        let auth = BasicAuth::new("alice", "s3cr3t").expect("valid credentials must construct");
        let header = auth.auth_header().expect("BasicAuth must return a header");
        assert_eq!(header.name(), "authorization");
        assert_eq!(header.expose_value(), "Basic YWxpY2U6czNjcjN0");
    }

    // ---- TLS transports and PEM parsing -------------------------------------

    #[test]
    fn custom_ca_transport_no_build_with_empty_cert() {
        let transport = CustomCaTransport::new(vec![]);
        assert!(transport.build_client().is_err(), "empty DER must fail");
    }

    #[test]
    fn parse_all_pem_certs_handles_multi_cert_bundle() {
        // Two copies of the same certificate, surrounded by non-PEM text that
        // the parser has to skip.
        let single = std::fs::read("tests/fixtures/tls/test-ca.pem")
            .expect("test-ca.pem fixture must exist");
        let mut bundle = b"# Comment that the parser must ignore\n".to_vec();
        bundle.extend_from_slice(&single);
        bundle.extend_from_slice(b"\n# Another comment between frames\n");
        bundle.extend_from_slice(&single);
        bundle.extend_from_slice(b"\n# Trailing comment\n");
        let ders = parse_all_pem_certs(&bundle);
        assert_eq!(ders.len(), 2, "two-cert bundle must produce two DER blobs");
        assert!(!ders[0].is_empty(), "first DER must be non-empty");
        assert!(!ders[1].is_empty(), "second DER must be non-empty");
        assert_eq!(
            ders[0], ders[1],
            "duplicate-input bundle must produce identical DER blobs"
        );
    }

    #[test]
    fn custom_transport_builder_single_pem_root_builds() {
        let pem = std::fs::read("tests/fixtures/tls/test-ca.pem")
            .expect("test-ca.pem fixture must exist");
        let transport = CustomTransportBuilder::new()
            .add_root_pem(&pem)
            .expect("PEM fixture must parse")
            .build();
        transport
            .build_client()
            .expect("single-root build_client must succeed");
    }

    #[test]
    fn custom_transport_builder_multi_root_bundle_builds() {
        let single = std::fs::read("tests/fixtures/tls/test-ca.pem")
            .expect("test-ca.pem fixture must exist");
        let mut bundle = single.clone();
        bundle.extend_from_slice(b"\n");
        bundle.extend_from_slice(&single);
        let transport = CustomTransportBuilder::new()
            .add_roots_pem_bundle(&bundle)
            .expect("two-cert PEM bundle must parse")
            .build();
        transport
            .build_client()
            .expect("multi-root build_client must succeed");
    }

    #[test]
    fn custom_transport_builder_add_root_pem_invalid_returns_invalid_argument() {
        let result = CustomTransportBuilder::new().add_root_pem(b"not a pem");
        match result {
            Ok(_) => panic!("garbage input must not produce a valid builder"),
            Err(ClientError::InvalidArgument(msg)) => {
                assert!(
                    msg.contains("CustomTransportBuilder::add_root_pem"),
                    "error must identify the offending method: {msg}"
                );
            }
            Err(other) => panic!("expected InvalidArgument, got {other:?}"),
        }
    }

    #[test]
    fn custom_transport_builder_empty_bundle_returns_invalid_argument() {
        let result = CustomTransportBuilder::new().add_roots_pem_bundle(b"plain text");
        match result {
            Ok(_) => panic!("input without PEM frames must not produce a valid builder"),
            Err(ClientError::InvalidArgument(msg)) => {
                assert!(
                    msg.contains("CustomTransportBuilder::add_roots_pem_bundle"),
                    "error must identify the offending method: {msg}"
                );
            }
            Err(other) => panic!("expected InvalidArgument, got {other:?}"),
        }
    }

    #[test]
    fn custom_transport_builder_with_client_cert_invalid_fails_at_build() {
        // Client-identity PEM is only parsed in build_client, so the builder
        // accepts garbage here and the failure shows up later.
        let pem = std::fs::read("tests/fixtures/tls/test-ca.pem")
            .expect("test-ca.pem fixture must exist");
        let transport = CustomTransportBuilder::new()
            .add_root_pem(&pem)
            .expect("PEM fixture must parse")
            .with_client_cert(b"not a cert PEM".to_vec(), b"not a key PEM".to_vec())
            .build();
        let result = transport.build_client();
        assert!(
            matches!(result, Err(ClientError::Http(_))),
            "invalid client identity must surface as ClientError::Http, got {result:?}"
        );
    }

    #[test]
    fn builder_transport_debug_does_not_leak_cert_bytes() {
        let canary = vec![0xCA_u8; 32];
        let transport = CustomTransportBuilder::new().add_root_der(canary).build();
        let dbg = format!("{transport:?}");
        assert!(
            !dbg.contains("cacacacacacacacacacacacacacacacacacacacacacacacacacacacacacacaca"),
            "BuilderTransport Debug must not contain lowercase-hex DER bytes; got: {dbg}"
        );
        assert!(
            dbg.contains("1 root cert"),
            "BuilderTransport Debug must surface the root count for diagnostics; got: {dbg}"
        );
    }

    // ---- BearerAuth token validation ----------------------------------------

    #[test]
    fn bearer_auth_empty_token_rejected() {
        let result = BearerAuth::new("");
        match result {
            Ok(_) => panic!("empty token must be rejected by constructor"),
            Err(ClientError::InvalidArgument(msg)) => {
                assert!(
                    msg.contains("empty"),
                    "error message should mention 'empty', got: {msg}"
                );
            }
            Err(e) => panic!("expected InvalidArgument, got: {e}"),
        }
    }

    #[test]
    fn bearer_auth_whitespace_only_token_rejected() {
        let result = BearerAuth::new(" ");
        match result {
            Ok(_) => panic!("whitespace-only token must be rejected by constructor"),
            Err(ClientError::InvalidArgument(msg)) => {
                assert!(
                    msg.contains("whitespace"),
                    "error message should mention 'whitespace', got: {msg}"
                );
            }
            Err(e) => panic!("expected InvalidArgument, got: {e}"),
        }
    }

    // ---- HttpClient wrapper -------------------------------------------------

    #[tokio::test]
    async fn default_transport_builds_client() {
        DefaultTransport
            .build_client()
            .expect("DefaultTransport::build_client must succeed");
    }

    #[tokio::test]
    async fn build_client_returns_opaque_http_client() {
        let result: Result<HttpClient, ClientError> = DefaultTransport.build_client();
        let http = result.expect("DefaultTransport::build_client must succeed");
        let dbg = format!("{http:?}");
        assert_eq!(
            dbg, "HttpClient",
            "HttpClient Debug must be opaque; the wrapper is the only public surface"
        );
    }

    #[test]
    fn http_client_new_is_callable_from_custom_transport_impl() {
        // Shows that a user-defined TransportConfig can build the opaque
        // wrapper through the public `HttpClient::new` constructor.
        struct StubTransport;
        impl TransportConfig for StubTransport {
            fn build_client(&self) -> Result<HttpClient, ClientError> {
                let client = reqwest::ClientBuilder::new()
                    .build()
                    .map_err(ClientError::from_reqwest)?;
                Ok(HttpClient::new(client))
            }
        }
        StubTransport
            .build_client()
            .expect("custom transport must build the opaque HttpClient");
    }

    // ---- Debug redaction: secrets must never appear in Debug output ---------

    #[test]
    fn auth_header_debug_redacts_value() {
        const CANARY: &str = "CANARY-AUTH-VALUE-DO-NOT-LEAK-456";
        let header = AuthHeader::new("authorization", CANARY);
        let dbg = format!("{header:?}");
        assert!(
            !dbg.contains(CANARY),
            "AuthHeader Debug must not contain the canary value: {dbg}"
        );
        assert!(
            dbg.contains("[REDACTED]"),
            "AuthHeader Debug must render '[REDACTED]' for the value field: {dbg}"
        );
        assert!(
            dbg.contains("authorization"),
            "AuthHeader Debug should include the header name for diagnostic value: {dbg}"
        );
    }

    #[test]
    fn auth_header_expose_value_returns_credential_bytes() {
        const VALUE: &str = "Bearer some-token-123";
        let header = AuthHeader::new("authorization", VALUE);
        assert_eq!(header.name(), "authorization");
        assert_eq!(header.expose_value(), VALUE);
    }

    #[test]
    fn bearer_auth_debug_does_not_leak_token() {
        const CANARY: &str = "CANARY-TOKEN-DO-NOT-LEAK-123";
        let auth = BearerAuth::new(CANARY).expect("valid ASCII token must construct");
        let dbg = format!("{auth:?}");
        assert!(
            !dbg.contains(CANARY),
            "BearerAuth Debug must not contain the raw token; got: {dbg}"
        );
    }

    #[test]
    fn basic_auth_debug_does_not_leak_credentials() {
        const CANARY_USER: &str = "CANARY-USER-DO-NOT-LEAK";
        const CANARY_PASS: &str = "CANARY-PASS-DO-NOT-LEAK";
        let auth =
            BasicAuth::new(CANARY_USER, CANARY_PASS).expect("valid credentials must construct");
        let dbg = format!("{auth:?}");
        assert!(
            !dbg.contains(CANARY_USER),
            "BasicAuth Debug must not contain the raw username; got: {dbg}"
        );
        assert!(
            !dbg.contains(CANARY_PASS),
            "BasicAuth Debug must not contain the raw password; got: {dbg}"
        );
        // The base64 form is as revealing as the plaintext, so check it too.
        let base64_pair = BASE64_STANDARD.encode(format!("{CANARY_USER}:{CANARY_PASS}"));
        assert!(
            !dbg.contains(&base64_pair),
            "BasicAuth Debug must not contain the base64-encoded credentials; got: {dbg}"
        );
    }

    #[test]
    fn custom_ca_transport_debug_does_not_leak_der_bytes() {
        // 0xCA bytes in lowercase hex, uppercase hex, and decimal (202)
        // renderings must all be absent from the Debug output.
        let canary_der = vec![0xCA_u8; 32];
        let transport = CustomCaTransport::new(canary_der);
        let dbg = format!("{transport:?}");
        assert!(
            !dbg.contains("cacacacacacacacacacacacacacacacacacacacacacacacacacacacacacacaca"),
            "CustomCaTransport Debug must not contain lowercase-hex DER bytes; got: {dbg}"
        );
        assert!(
            !dbg.contains("CACACACACACACACACACACACACACACACACACACACACACACACACACACACACACACACA"),
            "CustomCaTransport Debug must not contain uppercase-hex DER bytes; got: {dbg}"
        );
        assert!(
            !dbg.contains("202, 202, 202"),
            "CustomCaTransport Debug must not contain decimal-byte DER bytes; got: {dbg}"
        );
        assert!(
            dbg.contains("32 bytes"),
            "CustomCaTransport Debug should record the DER byte length; got: {dbg}"
        );
    }

    // ---- CustomCaTransport::from_pem_bytes ----------------------------------

    // Compile-time embedded fixtures. Per the assertion message below,
    // TEST_CA_DER is the openssl-produced DER form of TEST_CA_PEM.
    const TEST_CA_PEM: &[u8] = include_bytes!("../tests/fixtures/tls/test-ca.pem");
    const TEST_CA_DER: &[u8] = include_bytes!("../tests/fixtures/tls/test-ca.der");

    #[test]
    fn from_pem_bytes_extracts_der_matching_openssl_oracle() {
        let transport = CustomCaTransport::from_pem_bytes(TEST_CA_PEM)
            .expect("test-ca.pem fixture must parse as a valid CA");
        assert_eq!(
            transport.der_cert.as_slice(),
            TEST_CA_DER,
            "PEM-decoded DER must match the openssl-produced reference DER fixture"
        );
    }

    #[test]
    fn from_pem_bytes_rejects_empty_input() {
        let err = CustomCaTransport::from_pem_bytes(b"").expect_err("empty input must be rejected");
        assert!(
            matches!(err, ClientError::InvalidArgument(_)),
            "empty input must surface as InvalidArgument; got {err:?}"
        );
    }

    #[test]
    fn from_pem_bytes_rejects_input_with_no_pem_framing() {
        let err = CustomCaTransport::from_pem_bytes(b"this is not a PEM file")
            .expect_err("non-PEM input must be rejected");
        assert!(
            matches!(err, ClientError::InvalidArgument(_)),
            "non-PEM input must surface as InvalidArgument; got {err:?}"
        );
    }

    #[test]
    fn from_pem_bytes_rejects_pem_with_invalid_base64() {
        let bad =
            b"-----BEGIN CERTIFICATE-----\nNOT VALID BASE64 @#$%\n-----END CERTIFICATE-----\n";
        let err =
            CustomCaTransport::from_pem_bytes(bad).expect_err("invalid base64 must be rejected");
        assert!(
            matches!(err, ClientError::InvalidArgument(_)),
            "invalid-base64 PEM must surface as InvalidArgument; got {err:?}"
        );
    }

    #[test]
    fn from_pem_bytes_accepts_garbage_der_payload_deferring_validation_to_build() {
        // The PEM helper only checks framing and base64; DER validity is left
        // to build_client.
        use base64::Engine as _;
        let garbage_der = [0u8; 16];
        let body = base64::engine::general_purpose::STANDARD.encode(garbage_der);
        let pem = format!("-----BEGIN CERTIFICATE-----\n{body}\n-----END CERTIFICATE-----\n");
        let transport = CustomCaTransport::from_pem_bytes(pem.as_bytes())
            .expect("PEM framing OK + base64 OK = constructor accepts");
        assert_eq!(
            transport.der_cert.as_slice(),
            &garbage_der,
            "PEM helper must extract the exact base64-decoded bytes"
        );
    }
}