use std::{fmt, rc::Rc};
use base64::{engine::general_purpose::STANDARD as base64, Engine};
use crate::http::error::HttpError;
use crate::http::header::{self, HeaderMap, HeaderName, HeaderValue};
use crate::{service::Service, time::Millis};
use super::connect::ConnectorWrapper;
use super::error::ConnectError;
use super::{Client, ClientConfig, Connect, Connection, Connector};
/// Builder for an HTTP [`Client`].
///
/// Create one with [`ClientBuilder::new`] (or `Default::default()`), chain the
/// configuration methods, then call [`ClientBuilder::finish`] to obtain the client.
#[derive(Debug)]
pub struct ClientBuilder {
    /// Settings that `finish()` moves into the resulting `Client`.
    config: ClientConfig,
    /// Set to `false` by `no_default_headers()`.
    ///
    /// NOTE(review): `finish()` only forwards `config`, so this flag is currently
    /// not passed to the client — confirm whether that is intended.
    default_headers: bool,
    /// Set to `false` by `disable_redirects()`.
    ///
    /// NOTE(review): not read by `finish()`; see `default_headers`.
    allow_redirects: bool,
    /// Set by `max_redirects()`.
    ///
    /// NOTE(review): not read by `finish()`; see `default_headers`.
    max_redirects: usize,
}
impl Default for ClientBuilder {
fn default() -> Self {
Self::new()
}
}
impl ClientBuilder {
pub fn new() -> Self {
ClientBuilder {
default_headers: true,
allow_redirects: true,
max_redirects: 10,
config: ClientConfig {
headers: HeaderMap::new(),
timeout: Millis(5_000),
response_pl_limit: 262_144,
response_pl_timeout: Millis(10_000),
connector: Box::new(ConnectorWrapper(Connector::default().finish().into())),
},
}
}
pub fn connector<T>(mut self, connector: T) -> Self
where
T: Service<Connect, Response = Connection, Error = ConnectError>
+ fmt::Debug
+ 'static,
{
self.config.connector = Box::new(ConnectorWrapper(connector.into()));
self
}
#[doc(hidden)]
pub fn connection(
mut self,
connection: impl super::connect::Connect + 'static,
) -> Self {
self.config.connector = Box::new(connection);
self
}
pub fn timeout<T: Into<Millis>>(mut self, timeout: T) -> Self {
self.config.timeout = timeout.into();
self
}
pub fn disable_timeout(mut self) -> Self {
self.config.timeout = Millis::ZERO;
self
}
pub fn disable_redirects(mut self) -> Self {
self.allow_redirects = false;
self
}
pub fn max_redirects(mut self, num: usize) -> Self {
self.max_redirects = num;
self
}
pub fn no_default_headers(mut self) -> Self {
self.default_headers = false;
self
}
pub fn response_payload_limit(mut self, limit: usize) -> Self {
self.config.response_pl_limit = limit;
self
}
pub fn response_payload_timeout(mut self, timeout: Millis) -> Self {
self.config.response_pl_timeout = timeout;
self
}
pub fn header<K, V>(mut self, key: K, value: V) -> Self
where
HeaderName: TryFrom<K>,
HeaderValue: TryFrom<V>,
<HeaderName as TryFrom<K>>::Error: fmt::Debug + Into<HttpError>,
<HeaderValue as TryFrom<V>>::Error: fmt::Debug + Into<HttpError>,
{
match HeaderName::try_from(key) {
Ok(key) => match HeaderValue::try_from(value) {
Ok(value) => {
self.config.headers.append(key, value);
}
Err(e) => log::error!("Header value error: {e:?}"),
},
Err(e) => log::error!("Header name error: {e:?}"),
}
self
}
pub fn basic_auth<U>(self, username: U, password: Option<&str>) -> Self
where
U: fmt::Display,
{
let auth = match password {
Some(password) => format!("{username}:{password}"),
None => format!("{username}:"),
};
self.header(
header::AUTHORIZATION,
format!("Basic {}", base64.encode(auth)),
)
}
pub fn bearer_auth<T>(self, token: T) -> Self
where
T: fmt::Display,
{
self.header(header::AUTHORIZATION, format!("Bearer {token}"))
}
pub fn finish(self) -> Client {
Client(Rc::new(self.config))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The flag setters and `disable_timeout` must each change the builder.
    #[crate::rt_test]
    async fn basics() {
        let builder = ClientBuilder::new();
        // Pin the defaults first, so the assertions below prove the setters
        // changed something instead of matching the initial values by chance.
        assert!(builder.allow_redirects);
        assert!(builder.default_headers);
        assert_eq!(builder.max_redirects, 10);
        assert_eq!(builder.config.timeout, Millis(5_000));

        let builder = builder
            .disable_timeout()
            .disable_redirects()
            .max_redirects(5)
            .no_default_headers();
        assert_eq!(builder.config.timeout, Millis::ZERO);
        assert!(!builder.allow_redirects);
        assert!(!builder.default_headers);
        assert_eq!(builder.max_redirects, 5);
    }

    /// `timeout()` overrides the 5 s default.
    #[crate::rt_test]
    async fn timeout() {
        let builder = ClientBuilder::default();
        assert_eq!(builder.config.timeout, Millis(5_000));
        let builder = builder.timeout(Millis(1_000));
        assert_eq!(builder.config.timeout, Millis(1_000));
    }

    #[crate::rt_test]
    async fn response_payload_limit() {
        let builder = ClientBuilder::default();
        assert_eq!(builder.config.response_pl_limit, 262_144);
        let builder = builder.response_payload_limit(10);
        assert_eq!(builder.config.response_pl_limit, 10);
    }

    #[crate::rt_test]
    async fn response_payload_timeout() {
        let builder = ClientBuilder::default();
        assert_eq!(builder.config.response_pl_timeout, Millis(10_000));
        let builder = builder.response_payload_timeout(Millis(10));
        assert_eq!(builder.config.response_pl_timeout, Millis(10));
    }

    #[crate::rt_test]
    async fn valid_header_name() {
        let builder = ClientBuilder::new().header("Content-Length", 1);
        assert!(builder.config.headers.contains_key("Content-Length"));
    }

    /// An invalid header name is logged and ignored.
    #[crate::rt_test]
    async fn invalid_header_name() {
        let builder = ClientBuilder::new().header("no valid header name", 1);
        assert!(!builder.config.headers.contains_key("no valid header name"));
    }

    #[crate::rt_test]
    async fn valid_header_value() {
        let valid_header_value = HeaderValue::from(1234);
        let builder = ClientBuilder::new().header("Content-Length", &valid_header_value);
        assert_eq!(
            builder.config.headers.get("Content-Length"),
            Some(&valid_header_value)
        );
    }

    /// An invalid header value is logged and the header is not added.
    #[crate::rt_test]
    async fn invalid_header_value() {
        let builder = ClientBuilder::new().header("Content-Length", "\n");
        assert!(!builder.config.headers.contains_key("Content-Length"));
    }

    #[crate::rt_test]
    async fn client_basic_auth() {
        let client = ClientBuilder::new().basic_auth("username", Some("password"));
        assert_eq!(
            client
                .config
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Basic dXNlcm5hbWU6cGFzc3dvcmQ="
        );

        // A missing password encodes as "username:".
        let client = ClientBuilder::new().basic_auth("username", None);
        assert_eq!(
            client
                .config
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Basic dXNlcm5hbWU6"
        );
    }

    #[crate::rt_test]
    async fn client_bearer_auth() {
        let client = ClientBuilder::new().bearer_auth("someS3cr3tAutht0k3n");
        assert_eq!(
            client
                .config
                .headers
                .get(header::AUTHORIZATION)
                .unwrap()
                .to_str()
                .unwrap(),
            "Bearer someS3cr3tAutht0k3n"
        );
    }
}