use std::sync::Arc;
use base64::engine::general_purpose::STANDARD as BASE64_STANDARD;
use base64::Engine as _;
use reqwest::header::HeaderValue;
use crate::error::ClientError;
/// Abstraction over how the `reqwest::Client` is constructed.
///
/// Implementors choose the client configuration (e.g. connect timeout, extra
/// trusted root certificates). The `Send + Sync` supertraits require every
/// implementor to be thread-safe, which also allows use as a
/// `Box<dyn TransportConfig>` trait object.
pub trait TransportConfig: Send + Sync {
/// Builds a fresh client from this configuration.
///
/// # Errors
///
/// Returns a [`ClientError`] if the client, or one of its configuration
/// inputs, cannot be constructed.
fn build_client(&self) -> Result<reqwest::Client, ClientError>;
}
/// Transport that builds a client with this module's baseline settings.
///
/// Delegates to `default_reqwest_client`, which sets a 10-second connect
/// timeout and otherwise uses reqwest's defaults.
#[derive(Debug, Clone)]
pub struct DefaultTransport;
impl TransportConfig for DefaultTransport {
/// Builds the baseline client via `default_reqwest_client`.
fn build_client(&self) -> Result<reqwest::Client, ClientError> {
default_reqwest_client()
}
}
/// Transport that adds a caller-supplied root certificate to the client's
/// trust store.
///
/// The certificate is stored as raw DER bytes and is only parsed inside
/// [`TransportConfig::build_client`]. An invalid certificate is therefore
/// reported when the client is built, not when this value is constructed.
#[derive(Debug, Clone)]
pub struct CustomCaTransport {
    /// DER-encoded root certificate to trust.
    der_cert: Vec<u8>,
}

impl CustomCaTransport {
    /// Wraps the given DER-encoded certificate. No validation happens here.
    pub fn new(der_cert: Vec<u8>) -> Self {
        CustomCaTransport { der_cert }
    }
}

impl TransportConfig for CustomCaTransport {
    /// Builds a client that trusts `der_cert` as an additional root, using the
    /// same 10-second connect timeout as the default transport.
    ///
    /// # Errors
    ///
    /// Fails if the certificate cannot be parsed or the client cannot be built.
    fn build_client(&self) -> Result<reqwest::Client, ClientError> {
        let root = reqwest::Certificate::from_der(self.der_cert.as_slice())?;
        let builder = reqwest::ClientBuilder::new()
            .add_root_certificate(root)
            .connect_timeout(std::time::Duration::from_secs(10));
        Ok(builder.build()?)
    }
}
/// Supplies an optional HTTP header that authenticates the client.
///
/// The `Send + Sync` supertraits require thread-safe implementors, which also
/// allows use as `Arc<dyn AuthProvider>` / `Box<dyn AuthProvider>` trait objects.
pub trait AuthProvider: Send + Sync {
/// Returns the `(header name, header value)` pair, or `None` when no
/// authentication header should be sent.
///
/// Both strings borrow from `self`. The implementations in this module use the
/// lowercase header name `authorization`.
fn auth_header(&self) -> Option<(&str, &str)>;
}
/// Provider for unauthenticated use: never supplies a header.
#[derive(Debug, Clone)]
pub struct NoneAuth;
impl AuthProvider for NoneAuth {
/// Always returns `None`.
fn auth_header(&self) -> Option<(&str, &str)> {
None
}
}
/// Bearer-token authentication (RFC 6750).
///
/// The complete header value (`Bearer <token>`) is assembled and validated once
/// in [`BearerAuth::new`]. The `Debug` impl redacts it so the token never
/// reaches logs.
#[derive(Clone)]
pub struct BearerAuth {
    /// Pre-rendered header value, e.g. `Bearer abc123`. Secret: never log it.
    header_string: String,
}

impl BearerAuth {
    /// Creates a provider that sends `token` as a bearer credential.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::InvalidArgument` if `token` is empty or contains
    /// ASCII whitespace. Propagates the header-value error if `Bearer <token>` is
    /// not a valid HTTP header value, for example because it contains control or
    /// non-ASCII characters.
    pub fn new(token: &str) -> Result<Self, ClientError> {
        // UTF-8 never uses ASCII byte values inside multi-byte sequences, so a
        // byte scan finds exactly the ASCII whitespace a `char` scan would.
        let has_whitespace = token.bytes().any(|b| b.is_ascii_whitespace());
        if token.is_empty() || has_whitespace {
            return Err(ClientError::InvalidArgument(
                "BearerAuth token may not be empty or contain whitespace (RFC 6750 §2.1)".into(),
            ));
        }
        let header_string = ["Bearer ", token].concat();
        // Validate eagerly so an unusable token is rejected at construction time.
        // The parsed `HeaderValue` itself is not kept.
        HeaderValue::from_str(&header_string)?;
        Ok(BearerAuth { header_string })
    }
}

impl std::fmt::Debug for BearerAuth {
    /// Renders as `BearerAuth { token: "[REDACTED]" }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("BearerAuth");
        out.field("token", &"[REDACTED]");
        out.finish()
    }
}

impl AuthProvider for BearerAuth {
    /// Always returns `("authorization", "Bearer <token>")`.
    fn auth_header(&self) -> Option<(&str, &str)> {
        let value = self.header_string.as_str();
        Some(("authorization", value))
    }
}
/// HTTP Basic authentication (RFC 7617).
///
/// The `Basic <base64(username:password)>` header value is computed once in
/// [`BasicAuth::new`]. The `Debug` impl redacts it.
#[derive(Clone)]
pub struct BasicAuth {
    /// Pre-rendered header value. Base64 is a reversible encoding, so this is as
    /// sensitive as the plaintext password: never log it.
    header_string: String,
}

impl BasicAuth {
    /// Creates a provider for the given credentials.
    ///
    /// The credentials are joined as `username:password`. The resulting UTF-8
    /// bytes are base64-encoded with the standard, padded alphabet.
    ///
    /// # Errors
    ///
    /// Returns `ClientError::InvalidArgument` if:
    /// - `username` contains `':'`. The password may contain `':'`, because only
    ///   the first colon separates the two fields.
    /// - either value contains an ASCII control character, which RFC 7617 §2
    ///   forbids.
    pub fn new(username: &str, password: &str) -> Result<Self, ClientError> {
        // The first ':' is the user-id/password separator, so it cannot appear in
        // the user-id.
        if username.contains(':') {
            return Err(ClientError::InvalidArgument(
                "BasicAuth username may not contain ':'".into(),
            ));
        }
        // RFC 7617 §2: the user-id and password MUST NOT contain control
        // characters (CTL = %x00-1F / %x7F, which is exactly
        // `char::is_ascii_control`). Base64 would carry them through the header
        // check unnoticed, so they are rejected explicitly here.
        let has_ctl = |s: &str| s.chars().any(|c| c.is_ascii_control());
        if has_ctl(username) || has_ctl(password) {
            return Err(ClientError::InvalidArgument(
                "BasicAuth username and password may not contain control characters (RFC 7617 §2)"
                    .into(),
            ));
        }
        let encoded = BASE64_STANDARD.encode(format!("{username}:{password}").as_bytes());
        let header_string = format!("Basic {encoded}");
        // Base64 output is always a valid header value. This check is defensive.
        HeaderValue::from_str(&header_string)?;
        Ok(Self { header_string })
    }
}

impl std::fmt::Debug for BasicAuth {
    /// Renders as `BasicAuth { credentials: "[REDACTED]" }`. Neither the
    /// username nor the password is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("BasicAuth")
            .field("credentials", &"[REDACTED]")
            .finish()
    }
}

impl AuthProvider for BasicAuth {
    /// Always returns `("authorization", "Basic <base64>")`.
    fn auth_header(&self) -> Option<(&str, &str)> {
        Some(("authorization", &self.header_string))
    }
}
/// Builds this module's baseline `reqwest::Client`: reqwest's defaults plus a
/// 10-second connect timeout.
///
/// # Errors
///
/// Any builder failure is wrapped in `ClientError::Http`.
fn default_reqwest_client() -> Result<reqwest::Client, ClientError> {
    let connect_timeout = std::time::Duration::from_secs(10);
    let builder = reqwest::Client::builder().connect_timeout(connect_timeout);
    builder.build().map_err(ClientError::Http)
}
/// Lets a boxed transport trait object satisfy a `T: TransportConfig` bound.
impl TransportConfig for Box<dyn TransportConfig> {
fn build_client(&self) -> Result<reqwest::Client, ClientError> {
// `**self` is the inner `dyn TransportConfig`. Writing `self.build_client()`
// here would resolve to this same impl and recurse forever.
(**self).build_client()
}
}
/// Lets a shared provider trait object satisfy a `T: AuthProvider` bound.
impl AuthProvider for Arc<dyn AuthProvider> {
fn auth_header(&self) -> Option<(&str, &str)> {
// Dispatch to the inner `dyn AuthProvider`. `self.auth_header()` would recurse.
(**self).auth_header()
}
}
/// Lets a boxed provider trait object satisfy a `T: AuthProvider` bound.
impl AuthProvider for Box<dyn AuthProvider> {
fn auth_header(&self) -> Option<(&str, &str)> {
// Dispatch to the inner `dyn AuthProvider`. `self.auth_header()` would recurse.
(**self).auth_header()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls `auth_header` through a generic bound, so that the forwarding impls
    /// for `Arc<dyn AuthProvider>` / `Box<dyn AuthProvider>` are what runs.
    fn header_of<A: AuthProvider>(auth: &A) -> Option<(&str, &str)> {
        auth.auth_header()
    }

    /// Calls `build_client` through a generic bound, so that the
    /// `Box<dyn TransportConfig>` forwarding impl is what runs.
    fn build_via<T: TransportConfig>(transport: &T) -> Result<reqwest::Client, ClientError> {
        transport.build_client()
    }

    #[test]
    fn none_auth_no_header() {
        assert!(NoneAuth.auth_header().is_none());
    }

    #[test]
    fn bearer_auth_valid_constructs() {
        assert!(BearerAuth::new("tok123").is_ok());
    }

    #[test]
    fn bearer_auth_header() {
        let auth = BearerAuth::new("tok123").expect("valid ASCII token must construct");
        let (name, value) = auth.auth_header().expect("BearerAuth must return a header");
        assert_eq!(name, "authorization");
        assert_eq!(value, "Bearer tok123");
    }

    #[test]
    fn bearer_auth_invalid_token_rejected() {
        let result = BearerAuth::new("tok\x01abc");
        assert!(
            result.is_err(),
            "token with C0 control character must be rejected by constructor"
        );
    }

    #[test]
    fn bearer_auth_tab_in_token_rejected() {
        // `HeaderValue` accepts HTAB, so only the explicit whitespace check
        // catches this case, and it must surface as `InvalidArgument`.
        match BearerAuth::new("tok\t123") {
            Ok(_) => panic!("token containing a tab must be rejected by constructor"),
            Err(ClientError::InvalidArgument(msg)) => assert!(
                msg.contains("whitespace"),
                "error message should mention 'whitespace', got: {msg}"
            ),
            Err(e) => panic!("expected InvalidArgument, got: {e}"),
        }
    }

    #[test]
    fn bearer_auth_debug_redacts_token() {
        let auth = BearerAuth::new("supersecret").expect("valid ASCII token must construct");
        let rendered = format!("{auth:?}");
        assert!(
            !rendered.contains("supersecret"),
            "Debug output leaked the token: {rendered}"
        );
        assert!(
            rendered.contains("[REDACTED]"),
            "Debug output should show a redaction marker, got: {rendered}"
        );
    }

    #[test]
    fn basic_auth_valid_constructs() {
        assert!(BasicAuth::new("alice", "s3cr3t").is_ok());
    }

    #[test]
    fn basic_auth_colon_in_username_rejected() {
        match BasicAuth::new("ali:ce", "s3cr3t") {
            Ok(_) => panic!("username with colon must be rejected by constructor"),
            Err(ClientError::InvalidArgument(msg)) => assert!(
                msg.contains("username"),
                "error message should mention 'username', got: {msg}"
            ),
            Err(e) => panic!("expected InvalidArgument, got: {e}"),
        }
    }

    #[test]
    fn basic_auth_header() {
        let auth = BasicAuth::new("alice", "s3cr3t").expect("valid credentials must construct");
        let (name, value) = auth.auth_header().expect("BasicAuth must return a header");
        assert_eq!(name, "authorization");
        assert_eq!(value, "Basic YWxpY2U6czNjcjN0");
    }

    #[test]
    fn basic_auth_empty_password_allowed() {
        let auth = BasicAuth::new("alice", "").expect("empty password must construct");
        // base64("alice:") == "YWxpY2U6"
        assert_eq!(auth.auth_header(), Some(("authorization", "Basic YWxpY2U6")));
    }

    #[test]
    fn basic_auth_colon_in_password_allowed() {
        // Only the first colon separates the fields, so the password may contain one.
        let auth = BasicAuth::new("alice", "a:b").expect("colon in password must construct");
        // base64("alice:a:b") == "YWxpY2U6YTpi"
        assert_eq!(
            auth.auth_header(),
            Some(("authorization", "Basic YWxpY2U6YTpi"))
        );
    }

    #[test]
    fn basic_auth_debug_redacts_credentials() {
        let auth = BasicAuth::new("alice", "s3cr3t").expect("valid credentials must construct");
        let rendered = format!("{auth:?}");
        assert!(
            !rendered.contains("s3cr3t"),
            "Debug output leaked the password: {rendered}"
        );
        assert!(
            !rendered.contains("YWxpY2U6czNjcjN0"),
            "Debug output leaked the encoded credentials: {rendered}"
        );
        assert!(
            rendered.contains("[REDACTED]"),
            "Debug output should show a redaction marker, got: {rendered}"
        );
    }

    #[test]
    fn arc_dyn_auth_provider_delegates() {
        let shared: Arc<dyn AuthProvider> =
            Arc::new(BearerAuth::new("tok123").expect("valid ASCII token must construct"));
        assert_eq!(
            header_of(&shared),
            Some(("authorization", "Bearer tok123"))
        );
    }

    #[test]
    fn box_dyn_auth_provider_delegates() {
        let boxed: Box<dyn AuthProvider> = Box::new(NoneAuth);
        assert!(header_of(&boxed).is_none());

        let boxed: Box<dyn AuthProvider> = Box::new(
            BasicAuth::new("alice", "s3cr3t").expect("valid credentials must construct"),
        );
        assert_eq!(
            header_of(&boxed),
            Some(("authorization", "Basic YWxpY2U6czNjcjN0"))
        );
    }

    #[test]
    fn custom_ca_transport_no_build_with_empty_cert() {
        let transport = CustomCaTransport::new(vec![]);
        assert!(transport.build_client().is_err(), "empty DER must fail");
    }

    #[test]
    fn box_dyn_transport_delegates() {
        let boxed: Box<dyn TransportConfig> = Box::new(CustomCaTransport::new(vec![]));
        assert!(
            build_via(&boxed).is_err(),
            "empty DER must also fail through Box<dyn TransportConfig>"
        );
    }

    #[test]
    fn bearer_auth_empty_token_rejected() {
        let result = BearerAuth::new("");
        match result {
            Ok(_) => panic!("empty token must be rejected by constructor"),
            Err(ClientError::InvalidArgument(msg)) => {
                assert!(
                    msg.contains("empty"),
                    "error message should mention 'empty', got: {msg}"
                );
            }
            Err(e) => panic!("expected InvalidArgument, got: {e}"),
        }
    }

    #[test]
    fn bearer_auth_whitespace_only_token_rejected() {
        let result = BearerAuth::new(" ");
        match result {
            Ok(_) => panic!("whitespace-only token must be rejected by constructor"),
            Err(ClientError::InvalidArgument(msg)) => {
                assert!(
                    msg.contains("whitespace"),
                    "error message should mention 'whitespace', got: {msg}"
                );
            }
            Err(e) => panic!("expected InvalidArgument, got: {e}"),
        }
    }

    #[tokio::test]
    async fn default_transport_builds_client() {
        DefaultTransport
            .build_client()
            .expect("DefaultTransport::build_client must succeed");
    }
}