use std::fmt;
use base64::{Engine, engine::general_purpose::STANDARD as base64};
use crate::http::error::HttpError;
use crate::http::header::{self, HeaderName, HeaderValue};
use crate::service::{Identity, Middleware, Service, ServiceFactory, Stack, boxed};
use crate::{SharedCfg, time::Millis};
use super::error::{ClientBuilderError, ClientError};
use super::sender::Sender;
use super::service::{ServiceRequest, ServiceResponse};
use super::{Client, ClientConfig, Connector, cfg::ClientConfigInner};
/// Builder used to configure and construct a [`Client`].
///
/// `M` is the middleware stack that [`ClientBuilder::build`] wraps around the
/// request `Sender`. It starts out as [`Identity`] and each call to
/// [`ClientBuilder::middleware`] wraps it in a new [`Stack`].
#[derive(Debug)]
pub struct ClientBuilder<M = Identity> {
    /// Middleware applied around the `Sender` service when the client is built.
    middleware: M,
    /// Redirect flag: `true` from `new`, cleared by `disable_redirects`.
    allow_redirects: bool,
    /// Redirect limit: `10` from `new`, overwritten by `max_redirects`.
    max_redirects: usize,
    /// Client configuration accumulated by the builder methods and handed to
    /// `ClientConfig::new` in `build`.
    config: ClientConfigInner,
    /// Connector whose `create` produces the transport service in `build`.
    connector: Connector,
}
impl Default for ClientBuilder<Identity> {
fn default() -> Self {
Self::new()
}
}
impl ClientBuilder<Identity> {
    /// Creates a builder with no middleware (`Identity`), redirects enabled
    /// with a limit of 10, and default configuration and connector.
    #[must_use]
    pub fn new() -> Self {
        // Build the defaulted parts first, in the same order as the fields.
        let config = ClientConfigInner::default();
        let connector = Connector::default();

        Self {
            middleware: Identity,
            allow_redirects: true,
            max_redirects: 10,
            config,
            connector,
        }
    }
}
impl<M> ClientBuilder<M> {
#[must_use]
pub fn connector<T>(mut self, connector: Connector) -> Self {
self.connector = connector;
self
}
#[must_use]
pub fn response_timeout<T: Into<Millis>>(mut self, timeout: T) -> Self {
self.config.timeout = timeout.into();
self
}
#[must_use]
pub fn disable_timeout(mut self) -> Self {
self.config.timeout = Millis::ZERO;
self
}
#[must_use]
pub fn disable_redirects(mut self) -> Self {
self.allow_redirects = false;
self
}
#[must_use]
pub fn max_redirects(mut self, num: usize) -> Self {
self.max_redirects = num;
self
}
#[must_use]
pub fn no_default_headers(mut self) -> Self {
self.config.default_headers = false;
self
}
#[must_use]
pub fn response_payload_limit(mut self, limit: usize) -> Self {
self.config.pl_limit = limit;
self
}
#[must_use]
pub fn response_payload_timeout(mut self, timeout: Millis) -> Self {
self.config.pl_timeout = timeout;
self
}
#[must_use]
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where
HeaderName: TryFrom<K>,
HeaderValue: TryFrom<V>,
<HeaderName as TryFrom<K>>::Error: fmt::Debug + Into<HttpError>,
<HeaderValue as TryFrom<V>>::Error: fmt::Debug + Into<HttpError>,
{
match HeaderName::try_from(key) {
Ok(key) => match HeaderValue::try_from(value) {
Ok(value) => {
self.config.headers.append(key, value);
}
Err(e) => log::error!("Header value error: {e:?}"),
},
Err(e) => log::error!("Header name error: {e:?}"),
}
self
}
#[must_use]
pub fn basic_auth<U>(self, username: U, password: Option<&str>) -> Self
where
U: fmt::Display,
{
let auth = match password {
Some(password) => format!("{username}:{password}"),
None => format!("{username}:"),
};
self.header(
header::AUTHORIZATION,
format!("Basic {}", base64.encode(auth)),
)
}
#[must_use]
pub fn bearer_auth<T>(self, token: T) -> Self
where
T: fmt::Display,
{
self.header(header::AUTHORIZATION, format!("Bearer {token}"))
}
#[must_use]
pub fn middleware<U>(self, mw: U) -> ClientBuilder<Stack<U, M>> {
ClientBuilder {
middleware: Stack::new(mw, self.middleware),
allow_redirects: self.allow_redirects,
max_redirects: self.max_redirects,
config: self.config,
connector: self.connector,
}
}
pub async fn build<T>(mut self, cfg: T) -> Result<Client, ClientBuilderError>
where
T: Into<SharedCfg>,
M: Middleware<Sender, ClientConfig>,
M::Service: Service<ServiceRequest, Response = ServiceResponse, Error = ClientError>
+ 'static,
{
let cfg = cfg.into();
self.config.cfg = cfg.clone();
let config = ClientConfig::new(self.config);
let svc = self
.connector
.create(cfg)
.await
.map_err(|_| ClientBuilderError::ConnectorFailed)?;
let svc = boxed::service(
self.middleware
.create(Sender::new(svc, config.clone()), config.clone()),
);
Ok(Client::with_service(svc.into(), config))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the `Authorization` header configured on `builder` as a string.
    ///
    /// Panics if the header is absent or is not visible ASCII.
    fn auth_header<M>(builder: &ClientBuilder<M>) -> &str {
        builder
            .config
            .headers
            .get(header::AUTHORIZATION)
            .expect("Authorization header should be set")
            .to_str()
            .expect("Authorization header should be valid ASCII")
    }

    #[crate::rt_test]
    async fn basics() {
        let builder = ClientBuilder::new()
            .disable_timeout()
            .disable_redirects()
            .max_redirects(10)
            .no_default_headers();
        // `disable_timeout` was previously exercised but never checked.
        assert_eq!(builder.config.timeout, Millis::ZERO);
        assert!(!builder.allow_redirects);
        assert!(!builder.config.default_headers);
        assert_eq!(builder.max_redirects, 10);
    }

    #[crate::rt_test]
    async fn response_timeout() {
        let builder = ClientBuilder::new().response_timeout(Millis(500));
        assert_eq!(builder.config.timeout, Millis(500));
    }

    #[crate::rt_test]
    async fn response_payload_limit() {
        let builder = ClientBuilder::default();
        assert_eq!(builder.config.pl_limit, 262_144);
        let builder = builder.response_payload_limit(10);
        assert_eq!(builder.config.pl_limit, 10);
    }

    #[crate::rt_test]
    async fn response_payload_timeout() {
        let builder = ClientBuilder::default();
        assert_eq!(builder.config.pl_timeout, Millis(10_000));
        let builder = builder.response_payload_timeout(Millis(10));
        assert_eq!(builder.config.pl_timeout, Millis(10));
    }

    #[crate::rt_test]
    async fn valid_header_name() {
        let builder = ClientBuilder::new().header("Content-Length", 1);
        assert!(builder.config.headers.contains_key("Content-Length"));
    }

    #[crate::rt_test]
    async fn invalid_header_name() {
        let builder = ClientBuilder::new().header("no valid header name", 1);
        assert!(!builder.config.headers.contains_key("no valid header name"));
    }

    #[crate::rt_test]
    async fn valid_header_value() {
        let valid_header_value = HeaderValue::from(1234);
        let builder = ClientBuilder::new().header("Content-Length", &valid_header_value);
        assert_eq!(
            builder.config.headers.get("Content-Length"),
            Some(&valid_header_value)
        );
    }

    #[crate::rt_test]
    async fn invalid_header_value() {
        let builder = ClientBuilder::new().header("Content-Length", "\n");
        assert!(!builder.config.headers.contains_key("Content-Length"));
    }

    #[crate::rt_test]
    async fn client_basic_auth() {
        let client = ClientBuilder::new().basic_auth("username", Some("password"));
        assert_eq!(auth_header(&client), "Basic dXNlcm5hbWU6cGFzc3dvcmQ=");

        let client = ClientBuilder::new().basic_auth("username", None);
        assert_eq!(auth_header(&client), "Basic dXNlcm5hbWU6");
    }

    #[crate::rt_test]
    async fn client_bearer_auth() {
        let client = ClientBuilder::new().bearer_auth("someS3cr3tAutht0k3n");
        assert_eq!(auth_header(&client), "Bearer someS3cr3tAutht0k3n");
    }
}