use crate::{
agent::{self, AgentBuilder},
auth::{Authentication, Credentials},
config::internal::{ConfigurableBase, SetOpt},
config::*,
default_headers::DefaultHeadersInterceptor,
handler::{RequestHandler, ResponseBodyReader},
headers,
interceptor::{self, Interceptor, InterceptorObj},
Body, Error,
};
use futures_lite::{future::block_on, io::AsyncRead, pin};
use http::{
header::{HeaderMap, HeaderName, HeaderValue},
Request, Response,
};
use once_cell::sync::Lazy;
use std::{
convert::TryFrom,
fmt,
future::Future,
io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use tracing_futures::Instrument;
/// Default `User-Agent` header value, of the form
/// `curl/<curl version> isahc/<crate version>`, computed once on first use.
static USER_AGENT: Lazy<String> = Lazy::new(|| {
    let curl_version = curl::Version::get();
    format!(
        "curl/{} isahc/{}",
        curl_version.version(),
        env!("CARGO_PKG_VERSION")
    )
});
/// Builder for configuring and constructing an [`HttpClient`].
///
/// Configuration methods consume and return the builder. Fallible methods
/// such as `default_header` do not fail immediately: the most recent error is
/// stored and returned from `build`.
pub struct HttpClientBuilder {
/// Configuration for the agent that `build` spawns and that curl handles are
/// submitted to.
agent_builder: AgentBuilder,
/// Client-wide default options stored by type; when a request is prepared,
/// an option present in the request's own extensions takes precedence.
defaults: http::Extensions,
/// Interceptors installed on the client, in registration order.
interceptors: Vec<InterceptorObj>,
/// Headers applied to every request; `build` wraps them in a
/// `DefaultHeadersInterceptor` when the map is non-empty.
default_headers: HeaderMap<HeaderValue>,
/// Most recent configuration error, returned by `build`.
error: Option<Error>,
/// Cookie jar handed to the built client, if one was configured.
#[cfg(feature = "cookies")]
cookie_jar: Option<crate::cookies::CookieJar>,
}
impl Default for HttpClientBuilder {
fn default() -> Self {
Self::new()
}
}
impl HttpClientBuilder {
pub fn new() -> Self {
let mut defaults = http::Extensions::new();
defaults.insert(VersionNegotiation::default());
defaults.insert(AutomaticDecompression(true));
defaults.insert(Authentication::default());
Self {
agent_builder: AgentBuilder::default(),
defaults,
interceptors: vec![
InterceptorObj::new(crate::redirect::RedirectInterceptor),
],
default_headers: HeaderMap::new(),
error: None,
#[cfg(feature = "cookies")]
cookie_jar: None,
}
}
#[cfg(feature = "cookies")]
pub fn cookies(self) -> Self {
self.cookie_jar(Default::default())
}
#[cfg(feature = "unstable-interceptors")]
#[inline]
pub fn interceptor(self, interceptor: impl Interceptor + 'static) -> Self {
self.interceptor_impl(interceptor)
}
#[allow(unused)]
pub(crate) fn interceptor_impl(mut self, interceptor: impl Interceptor + 'static) -> Self {
self.interceptors.push(InterceptorObj::new(interceptor));
self
}
pub fn connection_cache_ttl(mut self, ttl: Duration) -> Self {
self.defaults.insert(MaxAgeConn(ttl));
self
}
pub fn max_connections(mut self, max: usize) -> Self {
self.agent_builder = self.agent_builder.max_connections(max);
self
}
pub fn max_connections_per_host(mut self, max: usize) -> Self {
self.agent_builder = self.agent_builder.max_connections_per_host(max);
self
}
pub fn connection_cache_size(mut self, size: usize) -> Self {
self.agent_builder = self.agent_builder.connection_cache_size(size);
self.defaults.insert(CloseConnection(size == 0));
self
}
pub fn dns_cache(self, cache: impl Into<DnsCache>) -> Self {
self.configure(cache.into())
}
pub fn dns_resolve(self, map: ResolveMap) -> Self {
self.configure(map)
}
pub fn default_header<K, V>(mut self, key: K, value: V) -> Self
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
{
match HeaderName::try_from(key) {
Ok(key) => match HeaderValue::try_from(value) {
Ok(value) => {
self.default_headers.append(key, value);
}
Err(e) => {
self.error = Some(e.into().into());
}
},
Err(e) => {
self.error = Some(e.into().into());
}
}
self
}
pub fn default_headers<K, V, I, P>(mut self, headers: I) -> Self
where
HeaderName: TryFrom<K>,
<HeaderName as TryFrom<K>>::Error: Into<http::Error>,
HeaderValue: TryFrom<V>,
<HeaderValue as TryFrom<V>>::Error: Into<http::Error>,
I: IntoIterator<Item = P>,
P: HeaderPair<K, V>,
{
self.default_headers.clear();
for (key, value) in headers.into_iter().map(HeaderPair::pair) {
self = self.default_header(key, value);
}
self
}
#[allow(unused_mut)]
#[tracing::instrument(level = "debug", skip(self))]
pub fn build(mut self) -> Result<HttpClient, Error> {
if let Some(err) = self.error {
return Err(err);
}
#[cfg(feature = "cookies")]
{
let jar = self.cookie_jar.clone();
self = self.interceptor_impl(crate::cookies::interceptor::CookieInterceptor::new(jar));
}
if !self.default_headers.is_empty() {
let default_headers = std::mem::take(&mut self.default_headers);
self = self.interceptor_impl(DefaultHeadersInterceptor::from(default_headers));
}
let inner = InnerHttpClient {
agent: self.agent_builder.spawn()?,
defaults: self.defaults,
interceptors: self.interceptors,
#[cfg(feature = "cookies")]
cookie_jar: self.cookie_jar,
};
Ok(HttpClient { inner: Arc::new(inner) })
}
}
impl Configurable for HttpClientBuilder {
    /// Set the cookie jar that `build` hands to the client and its cookie
    /// interceptor, replacing any previously configured jar.
    #[cfg(feature = "cookies")]
    fn cookie_jar(self, cookie_jar: crate::cookies::CookieJar) -> Self {
        Self {
            cookie_jar: Some(cookie_jar),
            ..self
        }
    }
}
impl ConfigurableBase for HttpClientBuilder {
    /// Store `option` as a client-wide default, keyed by its type. A value
    /// of the same type already stored is replaced.
    fn configure(self, option: impl Send + Sync + 'static) -> Self {
        let mut builder = self;
        builder.defaults.insert(option);
        builder
    }
}
impl fmt::Debug for HttpClientBuilder {
    /// Prints only the type name; no configuration is exposed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("HttpClientBuilder")
    }
}
/// Abstraction over things that can be turned into a `(key, value)` header
/// pair: owned tuples, or references to tuples whose elements are `Copy`.
pub trait HeaderPair<K, V> {
    /// Convert `self` into an owned `(key, value)` tuple.
    fn pair(self) -> (K, V);
}

/// An owned tuple already is a pair; it is returned unchanged.
impl<K, V> HeaderPair<K, V> for (K, V) {
    fn pair(self) -> (K, V) {
        let (key, value) = self;
        (key, value)
    }
}

/// A borrowed tuple of `Copy` elements is copied out.
impl<K: Copy, V: Copy> HeaderPair<K, V> for &(K, V) {
    fn pair(self) -> (K, V) {
        *self
    }
}
/// An HTTP client for sending requests.
///
/// Cloning is cheap: all clones share the same inner state through an
/// [`Arc`].
#[derive(Clone)]
pub struct HttpClient {
/// State shared by every clone of this client.
inner: Arc<InnerHttpClient>,
}
/// Shared state behind every clone of an [`HttpClient`].
struct InnerHttpClient {
/// Handle used to submit configured curl easy handles for execution.
agent: agent::Handle,
/// Client-wide default options; an option present in a request's own
/// extensions takes precedence over the value stored here.
defaults: http::Extensions,
/// Interceptor chain handed to the interceptor context for every request.
interceptors: Vec<InterceptorObj>,
/// Cookie jar configured on the builder, if any.
#[cfg(feature = "cookies")]
cookie_jar: Option<crate::cookies::CookieJar>,
}
impl HttpClient {
/// Create a new HTTP client with the default configuration.
///
/// This is shorthand for `HttpClientBuilder::default().build()`.
///
/// # Errors
///
/// Returns any error produced while building the client.
#[tracing::instrument(level = "debug")]
pub fn new() -> Result<Self, Error> {
HttpClientBuilder::default().build()
}
/// Get a crate-wide client with the default configuration. It is created
/// lazily on first access and reused afterwards.
///
/// # Panics
///
/// Panics if building the shared client fails.
#[tracing::instrument(level = "debug")]
pub(crate) fn shared() -> &'static Self {
static SHARED: Lazy<HttpClient> = Lazy::new(|| HttpClient::new()
.expect("shared client failed to initialize"));
&SHARED
}
/// Create a builder for configuring a new client.
pub fn builder() -> HttpClientBuilder {
HttpClientBuilder::default()
}
/// Get the cookie jar configured on the builder, if any.
#[cfg(feature = "cookies")]
pub fn cookie_jar(&self) -> Option<&crate::cookies::CookieJar> {
self.inner.cookie_jar.as_ref()
}
/// Send a GET request to `uri`. Blocks the current thread until the
/// response future resolves.
///
/// # Errors
///
/// Returns an error if `uri` cannot be converted or the request fails.
#[inline]
pub fn get<U>(&self, uri: U) -> Result<Response<Body>, Error>
where
http::Uri: TryFrom<U>,
<http::Uri as TryFrom<U>>::Error: Into<http::Error>,
{
block_on(self.get_async(uri))
}
/// Send a GET request to `uri` asynchronously.
///
/// A `uri` conversion error is reported when the returned future is polled.
pub fn get_async<U>(&self, uri: U) -> ResponseFuture<'_>
where
http::Uri: TryFrom<U>,
<http::Uri as TryFrom<U>>::Error: Into<http::Error>,
{
self.send_builder_async(http::Request::get(uri), Body::empty())
}
/// Send a HEAD request to `uri`. Blocks the current thread until the
/// response future resolves.
///
/// # Errors
///
/// Returns an error if `uri` cannot be converted or the request fails.
#[inline]
pub fn head<U>(&self, uri: U) -> Result<Response<Body>, Error>
where
http::Uri: TryFrom<U>,
<http::Uri as TryFrom<U>>::Error: Into<http::Error>,
{
block_on(self.head_async(uri))
}
/// Send a HEAD request to `uri` asynchronously.
pub fn head_async<U>(&self, uri: U) -> ResponseFuture<'_>
where
http::Uri: TryFrom<U>,
<http::Uri as TryFrom<U>>::Error: Into<http::Error>,
{
self.send_builder_async(http::Request::head(uri), Body::empty())
}
/// Send a POST request with `body` to `uri`. Blocks the current thread
/// until the response future resolves.
///
/// # Errors
///
/// Returns an error if `uri` cannot be converted or the request fails.
#[inline]
pub fn post<U>(&self, uri: U, body: impl Into<Body>) -> Result<Response<Body>, Error>
where
http::Uri: TryFrom<U>,
<http::Uri as TryFrom<U>>::Error: Into<http::Error>,
{
block_on(self.post_async(uri, body))
}
/// Send a POST request with `body` to `uri` asynchronously.
pub fn post_async<U>(&self, uri: U, body: impl Into<Body>) -> ResponseFuture<'_>
where
http::Uri: TryFrom<U>,
<http::Uri as TryFrom<U>>::Error: Into<http::Error>,
{
self.send_builder_async(http::Request::post(uri), body.into())
}
/// Send a PUT request with `body` to `uri`. Blocks the current thread
/// until the response future resolves.
///
/// # Errors
///
/// Returns an error if `uri` cannot be converted or the request fails.
#[inline]
pub fn put<U>(&self, uri: U, body: impl Into<Body>) -> Result<Response<Body>, Error>
where
http::Uri: TryFrom<U>,
<http::Uri as TryFrom<U>>::Error: Into<http::Error>,
{
block_on(self.put_async(uri, body))
}
/// Send a PUT request with `body` to `uri` asynchronously.
pub fn put_async<U>(&self, uri: U, body: impl Into<Body>) -> ResponseFuture<'_>
where
http::Uri: TryFrom<U>,
<http::Uri as TryFrom<U>>::Error: Into<http::Error>,
{
self.send_builder_async(http::Request::put(uri), body.into())
}
/// Send a DELETE request to `uri`. Blocks the current thread until the
/// response future resolves.
///
/// # Errors
///
/// Returns an error if `uri` cannot be converted or the request fails.
#[inline]
pub fn delete<U>(&self, uri: U) -> Result<Response<Body>, Error>
where
http::Uri: TryFrom<U>,
<http::Uri as TryFrom<U>>::Error: Into<http::Error>,
{
block_on(self.delete_async(uri))
}
/// Send a DELETE request to `uri` asynchronously.
pub fn delete_async<U>(&self, uri: U) -> ResponseFuture<'_>
where
http::Uri: TryFrom<U>,
<http::Uri as TryFrom<U>>::Error: Into<http::Error>,
{
self.send_builder_async(http::Request::delete(uri), Body::empty())
}
/// Send a prepared request. Blocks the current thread until the response
/// future resolves. The request body is converted into a [`Body`].
///
/// # Errors
///
/// Returns an error if the request fails.
#[inline]
#[tracing::instrument(level = "debug", skip(self, request), err)]
pub fn send<B: Into<Body>>(&self, request: Request<B>) -> Result<Response<Body>, Error> {
block_on(self.send_async(request))
}
/// Send a prepared request asynchronously. The request body is converted
/// into a [`Body`].
#[inline]
pub fn send_async<B: Into<Body>>(&self, request: Request<B>) -> ResponseFuture<'_> {
ResponseFuture::new(self.send_async_inner(request.map(Into::into)))
}
/// Finish `builder` with `body` and send the request asynchronously.
///
/// `builder.body` runs inside the returned future, so invalid builder input
/// (for example a bad URI) surfaces as an error when the future is polled.
#[inline]
fn send_builder_async(
&self,
builder: http::request::Builder,
body: Body,
) -> ResponseFuture<'_> {
ResponseFuture::new(async move { self.send_async_inner(builder.body(body)?).await })
}
/// Common send path for all request methods.
///
/// Applies the client's default `RedirectPolicy` when the request does not
/// set one, then sends the request through the interceptor context inside a
/// `send_async` tracing span that records the method and URI.
async fn send_async_inner(&self, mut request: Request<Body>) -> Result<Response<Body>, Error> {
let span = tracing::debug_span!(
"send_async",
method = ?request.method(),
uri = ?request.uri(),
);
// A redirect policy set on the request wins over the client default.
if request.extensions().get::<RedirectPolicy>().is_none() {
if let Some(policy) = self.inner.defaults.get::<RedirectPolicy>().cloned() {
request.extensions_mut().insert(policy);
}
}
let ctx = interceptor::Context {
invoker: Arc::new(self),
interceptors: &self.inner.interceptors,
};
ctx.send(request)
.instrument(span)
.await
}
/// Translate `request` into a configured curl easy handle.
///
/// The handle is returned together with the future created by
/// `RequestHandler::new`, which resolves to the response.
///
/// # Errors
///
/// Returns an error if any curl option or header cannot be set.
fn create_easy_handle(
&self,
mut request: Request<Body>,
) -> Result<
(
curl::easy::Easy2<RequestHandler>,
impl Future<Output = Result<Response<ResponseBodyReader>, Error>>,
),
Error,
> {
// Move the body out of the request; the request handler owns it from here.
let body = std::mem::take(request.body_mut());
let has_body = !body.is_empty();
let body_length = body.len();
let (handler, future) = RequestHandler::new(body);
let mut easy = curl::easy::Easy2::new(handler);
easy.verbose(easy.get_ref().is_debug_enabled())?;
easy.signal(false)?;
// For each listed option type, apply the value from the request's
// extensions if present, otherwise the client default, if any.
macro_rules! set_opts {
($easy:expr, $extensions:expr, $defaults:expr, [$($option:ty,)*]) => {{
$(
if let Some(extension) = $extensions.get::<$option>().or_else(|| $defaults.get()) {
extension.set_opt($easy)?;
}
)*
}};
}
// Options are applied to the handle in the order listed here.
set_opts!(
&mut easy,
request.extensions(),
self.inner.defaults,
[
Timeout,
ConnectTimeout,
TcpKeepAlive,
TcpNoDelay,
NetworkInterface,
Dialer,
AutomaticDecompression,
Authentication,
Credentials,
MaxAgeConn,
MaxUploadSpeed,
MaxDownloadSpeed,
VersionNegotiation,
proxy::Proxy<Option<http::Uri>>,
proxy::Blacklist,
proxy::Proxy<Authentication>,
proxy::Proxy<Credentials>,
DnsCache,
dns::ResolveMap,
dns::Servers,
ssl::Ciphers,
ClientCertificate,
CaCertificate,
SslOption,
CloseConnection,
EnableMetrics,
IpVersion,
]
);
// Select curl's request mode. GET and HEAD use curl's built-in modes
// only when they have no body; POST and PUT always use theirs. Every
// other case sends the method name as a custom request and enables
// upload only when there is a body.
// NOTE(review): the allow presumably silences the lint about matching
// on the `http::Method` constants below; confirm it is still needed.
#[allow(indirect_structural_match)]
match (request.method(), has_body) {
(&http::Method::GET, false) => {
easy.get(true)?;
}
(&http::Method::HEAD, false) => {
easy.nobody(true)?;
}
(&http::Method::POST, _) => {
easy.post(true)?;
}
(&http::Method::PUT, _) => {
easy.upload(true)?;
}
(method, has_body) => {
easy.upload(has_body)?;
easy.custom_request(method.as_str())?;
}
}
easy.url(&uri_to_string(request.uri()))?;
// With a body, give curl its size. A parseable `Content-Length` header
// is preferred over the length reported by the body. If neither is
// known, add `Transfer-Encoding: chunked` (replacing any existing value).
if has_body {
let body_length = request
.headers()
.get("Content-Length")
.and_then(|value| value.to_str().ok())
.and_then(|value| value.parse().ok())
.or(body_length);
if let Some(len) = body_length {
// POST bodies use curl's POST size option; other methods use the
// upload file size option.
if request.method() == http::Method::POST {
easy.post_field_size(len)?;
} else {
easy.in_filesize(len)?;
}
} else {
request.headers_mut().insert(
"Transfer-Encoding",
http::header::HeaderValue::from_static("chunked"),
);
}
}
// Convert headers into curl's header list. The `TitleCaseHeaders` flag
// (request value, else client default, else false) is passed through to
// `headers::to_curl_string`.
let mut headers = curl::easy::List::new();
let title_case = request
.extensions()
.get::<TitleCaseHeaders>()
.or_else(|| self.inner.defaults.get())
.map(|v| v.0)
.unwrap_or(false);
for (name, value) in request.headers().iter() {
headers.append(&headers::to_curl_string(name, value, title_case))?;
}
easy.http_headers(headers)?;
Ok((easy, future))
}
}
impl crate::interceptor::Invoke for &HttpClient {
fn invoke<'a>(&'a self, mut request: Request<Body>) -> crate::interceptor::InterceptorFuture<'a, Error> {
Box::pin(
async move {
request
.headers_mut()
.entry(http::header::USER_AGENT)
.or_insert(USER_AGENT.parse().unwrap());
let (easy, future) = self.create_easy_handle(request)?;
self.inner.agent.submit_request(easy)?;
let response = future.await?;
let content_length = response
.headers()
.get(http::header::CONTENT_LENGTH)
.and_then(|v| v.to_str().ok())
.and_then(|v| v.parse().ok());
Ok(response.map(|reader| {
let body = ResponseBody {
inner: reader,
_client: (*self).clone(),
};
if let Some(len) = content_length {
Body::from_reader_sized(body, len)
} else {
Body::from_reader(body)
}
}))
}
)
}
}
impl fmt::Debug for HttpClient {
    /// Prints only the type name; no internal state is exposed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("HttpClient")
    }
}
pub struct ResponseFuture<'c>(Pin<Box<dyn Future<Output = Result<Response<Body>, Error>> + 'c + Send>>);
impl<'c> ResponseFuture<'c> {
fn new(future: impl Future<Output = Result<Response<Body>, Error>> + Send + 'c) -> Self {
ResponseFuture(Box::pin(future))
}
}
impl Future for ResponseFuture<'_> {
type Output = Result<Response<Body>, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.0.as_mut().poll(cx)
}
}
impl<'c> fmt::Debug for ResponseFuture<'c> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ResponseFuture").finish()
}
}
/// Response body reader paired with a clone of the client that produced it.
struct ResponseBody {
/// Reader for the response body, produced by the request handler.
inner: ResponseBodyReader,
/// Never read. Presumably held to keep the client (and its agent) alive
/// while the body is being read — confirm.
_client: HttpClient,
}
impl AsyncRead for ResponseBody {
    /// Delegate reads to the wrapped response body reader.
    ///
    /// Relies on `ResponseBodyReader` being `Unpin`, which lets a plain
    /// `&mut` borrow be pinned in place.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        Pin::new(&mut self.inner).poll_read(cx, buf)
    }
}
/// Render `uri` as the absolute string passed to curl's URL option.
///
/// Concatenates `scheme://`, authority, path and `?query`, omitting each
/// optional component (and its separator) when it is absent.
fn uri_to_string(uri: &http::Uri) -> String {
    let scheme = uri.scheme().map(|scheme| scheme.as_str());
    let authority = uri.authority().map(|authority| authority.as_str());
    let query = uri.query();

    let parts = [
        scheme,
        scheme.map(|_| "://"),
        authority,
        Some(uri.path()),
        query.map(|_| "?"),
        query,
    ];

    parts.iter().flatten().copied().collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    static_assertions::assert_impl_all!(HttpClient: Send, Sync);
    static_assertions::assert_impl_all!(HttpClientBuilder: Send);

    /// A builder with a single valid default header builds successfully.
    #[test]
    fn test_default_header() {
        HttpClientBuilder::new()
            .default_header("some-key", "some-value")
            .build()
            .expect("building a client with a valid default header should succeed");
    }

    /// Invalid header names or values are deferred and reported by `build`.
    #[test]
    fn test_invalid_default_header_fails_build() {
        let bad_name = HttpClientBuilder::new()
            .default_header("invalid key", "value")
            .build();
        assert!(bad_name.is_err(), "invalid header name should fail build");

        let bad_value = HttpClientBuilder::new()
            .default_header("some-key", "bad\r\nvalue")
            .build();
        assert!(bad_value.is_err(), "invalid header value should fail build");
    }

    /// `default_header` appends values, including repeated names.
    #[test]
    fn test_default_headers_mut() {
        let builder = HttpClientBuilder::new().default_header("some-key", "some-value");
        assert_eq!(builder.default_headers.len(), 1);

        // Repeated names keep every value.
        let builder = HttpClientBuilder::new()
            .default_header("some-key", "some-value1")
            .default_header("some-key", "some-value2");
        assert_eq!(builder.default_headers.len(), 2);

        let builder = HttpClientBuilder::new();
        assert!(builder.default_headers.is_empty());
    }

    /// `default_headers` replaces previously configured headers.
    #[test]
    fn test_default_headers_replaces_existing() {
        let builder = HttpClientBuilder::new()
            .default_header("old-key", "old-value")
            .default_headers(&[("a", "1"), ("b", "2")]);
        assert_eq!(builder.default_headers.len(), 2);
        assert!(!builder.default_headers.contains_key("old-key"));
        assert!(builder.default_headers.contains_key("a"));
        assert!(builder.default_headers.contains_key("b"));
    }
}