use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use http::{HeaderMap, HeaderName, HeaderValue, Request, Response};
use tower::Layer;
use tower::util::BoxCloneService;
use super::auth::{AuthLayer, TokenProvider};
use super::circuit_breaker::{CircuitBreaker, CircuitBreakerLayer};
#[cfg(feature = "client-connectrpc")]
use super::connect::ConnectTransport;
use super::graphql::GraphqlClient;
use super::rest::{ClientError, RestClient};
use super::retry::{RetryLayer, RetryPolicy};
use super::transport::ReqwestService;
use crate::middleware::cache::CacheConfig;
#[cfg(feature = "cache")]
use crate::client::cache::ClientCacheLayer;
use crate::BoxError;
/// Errors reported while configuring a [`ClientBuilder`] or turning it into a [`Client`].
#[derive(Debug, thiserror::Error)]
pub enum ClientBuilderError {
/// The configured base URL could not be parsed; carries the parser's message.
/// Produced by `ClientBuilder::build`.
#[error("invalid URL: {0}")]
InvalidUrl(String),
/// A default header name failed conversion into a `HeaderName`; carries the
/// conversion error's message. Produced by `ClientBuilder::default_header`.
#[error("invalid header name: {0}")]
InvalidHeaderName(String),
/// A default header value failed conversion into a `HeaderValue`; carries the
/// conversion error's message. Produced by `ClientBuilder::default_header`.
#[error("invalid header value: {0}")]
InvalidHeaderValue(String),
/// The underlying `reqwest` client could not be constructed.
#[error("failed to build HTTP client: {0}")]
Build(#[from] reqwest::Error),
/// Any other boxed error, displayed transparently.
/// NOTE(review): no code visible in this file constructs this variant — confirm callers.
#[error(transparent)]
Other(#[from] BoxError),
}
/// Accumulates the configuration for a [`Client`]; consumed by `ClientBuilder::build`.
///
/// Every field starts at its `Default` value (empty URL, no timeouts, no
/// middleware, no headers, certificate validation left enabled).
#[derive(Clone, Default)]
pub struct ClientBuilder {
// Unparsed base URL; only validated when `build` runs.
base_url: String,
// Forwarded to the reqwest builder's `timeout` when set.
timeout: Option<Duration>,
// Forwarded to the reqwest builder's `connect_timeout` when set.
connect_timeout: Option<Duration>,
// When set, `build` wraps the service in a `RetryLayer` with this policy.
retry: Option<RetryPolicy>,
// When set, `build` wraps the service in a `CircuitBreakerLayer` using this breaker.
circuit_breaker: Option<Arc<CircuitBreaker>>,
// When set AND the `cache` feature is enabled, `build` adds a `ClientCacheLayer`.
cache: Option<CacheConfig>,
// Handed to the built `Client` unchanged.
default_headers: HeaderMap,
// When set, `build` wraps the service in an `AuthLayer` using this provider.
auth: Option<Arc<dyn TokenProvider>>,
// Extra root certificates added to the reqwest builder.
root_certs: Vec<reqwest::Certificate>,
// Client identity passed to the reqwest builder when set.
identity: Option<reqwest::Identity>,
// When true, `build` calls `danger_accept_invalid_certs(true)` on the reqwest builder.
danger_accept_invalid_certs: bool,
}
impl std::fmt::Debug for ClientBuilder {
    /// Formats the builder for diagnostics.
    ///
    /// The token provider and TLS identity are rendered as fixed placeholder
    /// strings, and the root certificates only as a count, so none of their
    /// contents are printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Placeholders standing in for values that are not printed directly.
        let auth_marker = self.auth.as_ref().map(|_| "TokenProvider");
        let identity_marker = self.identity.as_ref().map(|_| "Identity");
        let root_cert_count = self.root_certs.len();

        let mut out = f.debug_struct("ClientBuilder");
        out.field("base_url", &self.base_url);
        out.field("timeout", &self.timeout);
        out.field("connect_timeout", &self.connect_timeout);
        out.field("retry", &self.retry);
        out.field("circuit_breaker", &self.circuit_breaker);
        out.field("cache", &self.cache);
        out.field("default_headers", &self.default_headers);
        out.field("auth", &auth_marker);
        out.field("root_certs", &root_cert_count);
        out.field("identity", &identity_marker);
        out.field("danger_accept_invalid_certs", &self.danger_accept_invalid_certs);
        out.finish()
    }
}
impl ClientBuilder {
    /// Creates a builder targeting `base_url`; every other setting keeps its default.
    ///
    /// The URL is stored as-is and only validated by [`ClientBuilder::build`].
    pub fn new(base_url: impl Into<String>) -> Self {
        Self::default().base_url(base_url)
    }

    /// Replaces the base URL (validated later, in [`ClientBuilder::build`]).
    #[must_use]
    pub fn base_url(self, base_url: impl Into<String>) -> Self {
        Self {
            base_url: base_url.into(),
            ..self
        }
    }

    /// Sets the timeout forwarded to the reqwest builder's `timeout`.
    #[must_use]
    pub fn timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Sets the timeout forwarded to the reqwest builder's `connect_timeout`.
    #[must_use]
    pub fn connect_timeout(self, connect_timeout: Duration) -> Self {
        Self {
            connect_timeout: Some(connect_timeout),
            ..self
        }
    }

    /// Enables the retry layer with the given policy.
    #[must_use]
    pub fn retry(self, policy: RetryPolicy) -> Self {
        Self {
            retry: Some(policy),
            ..self
        }
    }

    /// Enables the circuit-breaker layer, using the given breaker handle.
    #[must_use]
    pub fn circuit_breaker(self, breaker: Arc<CircuitBreaker>) -> Self {
        Self {
            circuit_breaker: Some(breaker),
            ..self
        }
    }

    /// Stores a cache configuration.
    ///
    /// The cache layer is only installed by [`ClientBuilder::build`] when the
    /// crate is compiled with the `cache` feature; without it the stored
    /// configuration is not applied.
    #[must_use]
    pub fn cache(self, config: CacheConfig) -> Self {
        Self {
            cache: Some(config),
            ..self
        }
    }

    /// Clears any cache configuration set earlier via [`ClientBuilder::cache`].
    #[must_use]
    pub fn no_cache(self) -> Self {
        Self {
            cache: None,
            ..self
        }
    }

    /// Adds a default header to the set handed to the built [`Client`].
    ///
    /// The header is appended, so calling this repeatedly with the same name
    /// keeps every value instead of replacing the earlier ones.
    ///
    /// # Errors
    ///
    /// Returns [`ClientBuilderError::InvalidHeaderName`] when `name` cannot be
    /// converted, otherwise [`ClientBuilderError::InvalidHeaderValue`] when
    /// `value` cannot be converted. The name is checked first.
    pub fn default_header<N, V>(mut self, name: N, value: V) -> Result<Self, ClientBuilderError>
    where
        N: TryInto<HeaderName>,
        N::Error: std::fmt::Display,
        V: TryInto<HeaderValue>,
        V::Error: std::fmt::Display,
    {
        let header_name: HeaderName = match name.try_into() {
            Ok(parsed) => parsed,
            Err(err) => return Err(ClientBuilderError::InvalidHeaderName(err.to_string())),
        };
        let header_value: HeaderValue = match value.try_into() {
            Ok(parsed) => parsed,
            Err(err) => return Err(ClientBuilderError::InvalidHeaderValue(err.to_string())),
        };
        self.default_headers.append(header_name, header_value);
        Ok(self)
    }

    /// Enables the auth layer, backed by the given token provider.
    #[must_use]
    pub fn auth<P>(self, provider: P) -> Self
    where
        P: TokenProvider,
    {
        // Erase the concrete provider type up front so the field assignment is trivial.
        let shared: Arc<dyn TokenProvider> = Arc::new(provider);
        Self {
            auth: Some(shared),
            ..self
        }
    }

    /// Adds one more root certificate for the reqwest builder.
    #[must_use]
    pub fn add_root_certificate(mut self, cert: reqwest::Certificate) -> Self {
        self.root_certs.push(cert);
        self
    }

    /// Sets the client identity passed to the reqwest builder.
    #[must_use]
    pub fn identity(self, identity: reqwest::Identity) -> Self {
        Self {
            identity: Some(identity),
            ..self
        }
    }

    /// Makes `build` call `danger_accept_invalid_certs(true)` on the reqwest builder.
    #[must_use]
    pub fn danger_accept_invalid_certs(self) -> Self {
        Self {
            danger_accept_invalid_certs: true,
            ..self
        }
    }

    /// Consumes the builder and assembles the [`Client`].
    ///
    /// Each enabled layer wraps the stack built so far, in this order:
    /// reqwest transport, auth, cache (only with the `cache` feature), retry,
    /// circuit breaker. The circuit breaker is therefore the outermost layer.
    ///
    /// # Errors
    ///
    /// Returns [`ClientBuilderError::InvalidUrl`] when the base URL does not
    /// parse, and [`ClientBuilderError::Build`] when reqwest fails to
    /// construct its client. The URL is checked first.
    pub fn build(self) -> Result<Client, ClientBuilderError> {
        let base_url = reqwest::Url::parse(&self.base_url)
            .map_err(|err| ClientBuilderError::InvalidUrl(err.to_string()))?;

        // Configure the underlying reqwest client.
        let mut transport_builder = reqwest::Client::builder();
        if let Some(limit) = self.timeout {
            transport_builder = transport_builder.timeout(limit);
        }
        if let Some(limit) = self.connect_timeout {
            transport_builder = transport_builder.connect_timeout(limit);
        }
        transport_builder = self
            .root_certs
            .into_iter()
            .fold(transport_builder, |acc, cert| acc.add_root_certificate(cert));
        if let Some(identity) = self.identity {
            transport_builder = transport_builder.identity(identity);
        }
        if self.danger_accept_invalid_certs {
            transport_builder = transport_builder.danger_accept_invalid_certs(true);
        }
        let transport = ReqwestService::new(transport_builder.build()?);

        // Stack the optional middleware, innermost first.
        let stack = BoxCloneService::new(transport);
        let stack = match self.auth {
            Some(provider) => BoxCloneService::new(AuthLayer::new(provider).layer(stack)),
            None => stack,
        };
        #[cfg(feature = "cache")]
        let stack = match self.cache {
            Some(config) => BoxCloneService::new(ClientCacheLayer::new(config).layer(stack)),
            None => stack,
        };
        let stack = match self.retry {
            Some(policy) => BoxCloneService::new(RetryLayer::new(policy).layer(stack)),
            None => stack,
        };
        let stack = match self.circuit_breaker {
            Some(breaker) => BoxCloneService::new(CircuitBreakerLayer::new(breaker).layer(stack)),
            None => stack,
        };

        Ok(Client {
            service: Arc::new(std::sync::Mutex::new(stack)),
            base_url,
            default_headers: self.default_headers,
        })
    }
}
/// Type-erased, cloneable service used by [`Client`]: takes a `Request<Bytes>`,
/// yields a `Response<Bytes>`, and reports failures as [`BoxError`].
pub type BoxedClientService = BoxCloneService<Request<Bytes>, Response<Bytes>, BoxError>;
/// HTTP client produced by `ClientBuilder::build`.
///
/// Cloning copies the URL and headers and shares the same service handle
/// (the service sits behind an `Arc`).
#[derive(Clone)]
pub struct Client {
// Shared middleware stack; `execute` locks it only long enough to clone the service.
service: std::sync::Arc<std::sync::Mutex<BoxedClientService>>,
// Parsed base URL, exposed through `base_url()`.
base_url: reqwest::Url,
// Headers collected by the builder, exposed through `default_headers()`.
default_headers: HeaderMap,
}
impl std::fmt::Debug for Client {
    /// Formats the client for diagnostics; only the base URL is printed
    /// (the service and the default headers are left out).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("Client");
        out.field("base_url", &self.base_url);
        out.finish()
    }
}
impl Client {
    /// Returns a REST client built from a clone of this client, its base URL
    /// and its default headers.
    pub fn rest(&self) -> RestClient {
        RestClient::new(self.clone(), self.base_url.clone(), self.default_headers.clone())
    }

    /// Returns a GraphQL client wrapping the REST client from [`Client::rest`].
    pub fn graphql(&self) -> GraphqlClient {
        GraphqlClient::new(self.rest())
    }

    /// Returns a Connect transport built from a clone of this client and `base_uri`.
    #[cfg(feature = "client-connectrpc")]
    pub fn connectrpc(&self, base_uri: http::Uri) -> ConnectTransport {
        ConnectTransport::new(self.clone(), base_uri)
    }

    /// Sends `req` through the client's service stack and returns the response.
    ///
    /// # Errors
    ///
    /// Any error produced by the service is returned as
    /// [`ClientError::Transport`].
    pub async fn execute(&self, req: Request<Bytes>) -> Result<Response<Bytes>, ClientError> {
        use tower::ServiceExt;
        // The guard is a temporary that is dropped at the end of this statement,
        // so the std mutex is never held across the `.await` below.
        //
        // A poisoned lock is recovered instead of panicking: the guard is only
        // used to clone the service through a shared reference, so the guarded
        // value is never left half-modified by this code, and panicking here
        // would make every clone of the client unusable from then on.
        let service = self
            .service
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone();
        service.oneshot(req).await.map_err(ClientError::Transport)
    }

    /// Returns the parsed base URL.
    pub fn base_url(&self) -> &reqwest::Url {
        &self.base_url
    }

    /// Returns the default headers collected by the builder.
    pub fn default_headers(&self) -> &HeaderMap {
        &self.default_headers
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Base URL used by tests that do not care about the exact target.
    const EXAMPLE: &str = "http://example.com";

    /// A bare builder produces a client whose URL is the normalized base URL.
    #[test]
    fn test_client_builder_defaults() {
        let built = ClientBuilder::new(EXAMPLE).build().unwrap();
        assert_eq!(built.base_url().as_str(), "http://example.com/");
    }

    /// An unparsable base URL is reported as `InvalidUrl`.
    #[test]
    fn test_client_builder_invalid_url() {
        let outcome = ClientBuilder::new("://not-a-url").build();
        assert!(matches!(outcome, Err(ClientBuilderError::InvalidUrl(_))));
    }

    /// A default header set on the builder is visible on the built client.
    #[test]
    fn test_client_builder_default_header() {
        let with_header = ClientBuilder::new(EXAMPLE)
            .default_header("x-custom", "value")
            .unwrap();
        let built = with_header.build().unwrap();
        let stored = built
            .default_headers()
            .get("x-custom")
            .and_then(|v| v.to_str().ok());
        assert_eq!(stored, Some("value"));
    }

    /// A header name containing NUL is rejected as `InvalidHeaderName`.
    #[test]
    fn test_client_builder_invalid_header_name() {
        let outcome = ClientBuilder::new(EXAMPLE).default_header("\0", "value");
        assert!(matches!(
            outcome,
            Err(ClientBuilderError::InvalidHeaderName(_))
        ));
    }

    /// Compile-time check that `Client` is both `Sync` and `Clone`.
    #[test]
    fn test_client_is_clone_and_sync() {
        fn assert_sync<T>(_t: T)
        where
            T: Sync + Clone,
        {
        }
        assert_sync(ClientBuilder::new(EXAMPLE).build().unwrap());
    }

    /// The builder's `Debug` output names the type and shows the base URL.
    #[test]
    fn test_client_builder_debug() {
        let rendered = format!("{:?}", ClientBuilder::new(EXAMPLE));
        assert!(rendered.contains("ClientBuilder"));
        assert!(rendered.contains(EXAMPLE));
    }

    /// Exercises every builder method in one chain and checks the result.
    #[test]
    fn test_client_builder_full_chain() {
        use std::num::NonZeroUsize;
        use std::sync::Arc;
        use std::time::Duration;
        // NOTE(review): `ClientBuilder::cache` takes `crate::middleware::cache::CacheConfig`;
        // presumably `crate::client::cache` re-exports the same type — confirm, and
        // confirm this path is available when the `cache` feature is disabled.
        use crate::client::cache::CacheConfig;
        use crate::client::circuit_breaker::{CircuitBreaker, CircuitBreakerConfig};
        use crate::client::retry::RetryPolicy;

        let breaker_config = CircuitBreakerConfig {
            failure_threshold: 2,
            success_threshold: 1,
            open_duration: Duration::from_millis(100),
            half_open_max_calls: 1,
        };
        let breaker = Arc::new(CircuitBreaker::new(breaker_config));
        let cache_config = CacheConfig {
            capacity: NonZeroUsize::new(10).unwrap(),
            ttl: Some(Duration::from_secs(10)),
            max_body_size: 1024,
        };
        let policy = RetryPolicy {
            max_attempts: 2,
            initial_backoff: Duration::from_millis(10),
            max_backoff: Duration::from_millis(100),
            jitter: false,
        };

        let built = ClientBuilder::new(EXAMPLE)
            .base_url("http://example.com/api")
            .timeout(Duration::from_secs(5))
            .connect_timeout(Duration::from_secs(1))
            .retry(policy)
            .circuit_breaker(Arc::clone(&breaker))
            .cache(cache_config)
            .no_cache()
            .default_header("x-foo", "bar")
            .unwrap()
            .auth(crate::client::BearerToken::new("tok"))
            .danger_accept_invalid_certs()
            .build()
            .unwrap();

        assert_eq!(built.base_url().as_str(), "http://example.com/api");
        let foo = built
            .default_headers()
            .get("x-foo")
            .and_then(|v| v.to_str().ok());
        assert_eq!(foo, Some("bar"));

        let rendered = format!("{built:?}");
        assert!(rendered.starts_with("Client { base_url: Url {"));
        assert!(rendered.contains("example.com"));
    }

    /// A header value containing NUL is rejected as `InvalidHeaderValue`.
    #[test]
    fn test_client_builder_invalid_header_value() {
        let outcome = ClientBuilder::new(EXAMPLE).default_header("x-custom", "\0");
        assert!(matches!(
            outcome,
            Err(ClientBuilderError::InvalidHeaderValue(_))
        ));
    }

    /// `Client::execute` forwards the request to the wrapped service and
    /// returns that service's response.
    #[tokio::test]
    async fn test_client_execute_via_mock_service() {
        use bytes::Bytes;
        use http::{Request, Response};

        let mock = tower::service_fn(|incoming: Request<Bytes>| async move {
            assert_eq!(incoming.uri().path(), "/hello");
            Ok::<_, crate::BoxError>(Response::new(Bytes::from_static(b"world")))
        });
        let base = reqwest::Url::parse("http://example.com").unwrap();
        let client = Client::from_service(
            super::BoxedClientService::new(mock),
            base,
            http::HeaderMap::new(),
        );

        let outgoing = Request::get("http://example.com/hello")
            .body(Bytes::new())
            .unwrap();
        let reply = client.execute(outgoing).await.unwrap();
        assert_eq!(reply.body().as_ref(), b"world");
    }
}
#[cfg(test)]
impl Client {
    /// Test-only constructor that wraps an already-built service, bypassing
    /// `ClientBuilder` and the reqwest transport.
    pub fn from_service(
        service: BoxedClientService,
        base_url: reqwest::Url,
        default_headers: http::HeaderMap,
    ) -> Self {
        let shared = Arc::new(std::sync::Mutex::new(service));
        Self {
            service: shared,
            base_url,
            default_headers,
        }
    }
}