use crate::agent;
use crate::config::*;
use crate::handler::{RequestHandler, RequestHandlerFuture};
use crate::middleware::Middleware;
use crate::{Body, Error};
use http::{Request, Response};
use lazy_static::lazy_static;
use std::fmt;
use std::future::Future;
use std::iter::FromIterator;
use std::net::SocketAddr;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
lazy_static! {
    /// Default value for the `User-Agent` header, naming the linked curl
    /// version and this crate's version.
    static ref USER_AGENT: String = {
        let curl_version = curl::Version::get();
        format!(
            "curl/{} chttp/{}",
            curl_version.version(),
            env!("CARGO_PKG_VERSION")
        )
    };
}
/// Builder for configuring and constructing an [`HttpClient`].
///
/// Options set here become client-wide defaults; a request may override any
/// of them by carrying its own value of the same type in its extensions.
#[derive(Default)]
pub struct HttpClientBuilder {
    /// Default configuration values, stored by type.
    defaults: http::Extensions,
    /// Installed middleware, in the order it was added.
    middleware: Vec<Box<dyn Middleware>>,
}
impl HttpClientBuilder {
    /// Create a new builder with no configured defaults and no middleware.
    pub fn new() -> Self {
        Self::default()
    }

    /// Enable cookie handling by installing a cookie jar middleware.
    #[cfg(feature = "cookies")]
    pub fn cookies(self) -> Self {
        self.middleware_impl(crate::cookies::CookieJar::default())
    }

    /// Add a middleware layer to the client.
    #[cfg(feature = "middleware-api")]
    pub fn middleware(self, middleware: impl Middleware) -> Self {
        self.middleware_impl(middleware)
    }

    // Only called from the feature-gated `cookies` and `middleware` methods in
    // this file; with neither feature enabled it has no callers here, so the
    // lint is silenced.
    #[allow(unused)]
    fn middleware_impl(mut self, middleware: impl Middleware) -> Self {
        self.middleware.push(Box::new(middleware));
        self
    }

    /// Set a default timeout for the whole request (curl's `timeout` option).
    pub fn timeout(mut self, timeout: Duration) -> Self {
        self.defaults.insert(Timeout(timeout));
        self
    }

    /// Set a default timeout for establishing a connection (curl's
    /// `connect_timeout` option).
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.defaults.insert(ConnectTimeout(timeout));
        self
    }

    /// Set the default policy for following redirects.
    pub fn redirect_policy(mut self, policy: RedirectPolicy) -> Self {
        self.defaults.insert(policy);
        self
    }

    /// Enable the `AutoReferer` option as a client default.
    pub fn auto_referer(mut self) -> Self {
        self.defaults.insert(AutoReferer);
        self
    }

    /// Set the HTTP version to prefer.
    ///
    /// Only HTTP/1.0, HTTP/1.1 and HTTP/2 are mapped onto curl; any other
    /// version lets curl choose.
    pub fn preferred_http_version(mut self, version: http::Version) -> Self {
        self.defaults.insert(PreferredHttpVersion(version));
        self
    }

    /// Enable TCP keep-alive, probing at the given interval.
    pub fn tcp_keepalive(mut self, interval: Duration) -> Self {
        self.defaults.insert(TcpKeepAlive(interval));
        self
    }

    /// Enable `TCP_NODELAY` on connections.
    pub fn tcp_nodelay(mut self) -> Self {
        self.defaults.insert(TcpNoDelay);
        self
    }

    /// Route requests through the given proxy.
    pub fn proxy(mut self, proxy: http::Uri) -> Self {
        self.defaults.insert(Proxy(proxy));
        self
    }

    /// Limit the upload speed (passed to curl's `max_send_speed`).
    pub fn max_upload_speed(mut self, max: u64) -> Self {
        self.defaults.insert(MaxUploadSpeed(max));
        self
    }

    /// Limit the download speed (passed to curl's `max_recv_speed`).
    pub fn max_download_speed(mut self, max: u64) -> Self {
        self.defaults.insert(MaxDownloadSpeed(max));
        self
    }

    /// Use the given DNS servers for name resolution.
    ///
    /// If curl rejects this setting, a warning is logged and the request
    /// proceeds without it.
    pub fn dns_servers(mut self, servers: impl IntoIterator<Item = SocketAddr>) -> Self {
        self.defaults.insert(DnsServers::from_iter(servers));
        self
    }

    /// Restrict TLS to the given ciphers; the names are joined with `:` for curl.
    pub fn ssl_ciphers(mut self, ciphers: impl IntoIterator<Item = String>) -> Self {
        self.defaults.insert(SslCiphers::from_iter(ciphers));
        self
    }

    /// Present the given TLS client certificate to servers.
    pub fn ssl_client_certificate(mut self, certificate: ClientCertificate) -> Self {
        self.defaults.insert(certificate);
        self
    }

    /// Build the configured client.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying agent cannot be created.
    pub fn build(self) -> Result<HttpClient, Error> {
        Ok(HttpClient {
            agent: agent::new()?,
            defaults: self.defaults,
            middleware: self.middleware,
        })
    }
}
impl fmt::Debug for HttpClientBuilder {
    /// Print only the type name; the defaults and middleware are opaque.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("HttpClientBuilder")
    }
}
/// An HTTP client that executes requests through curl.
///
/// Create one with [`HttpClient::new`] or configure one via
/// [`HttpClient::builder`].
pub struct HttpClient {
    /// Agent that executes submitted curl handles.
    agent: agent::Handle,
    /// Client-wide configuration, used when a request has no override.
    defaults: http::Extensions,
    /// Installed middleware, in the order it was added.
    middleware: Vec<Box<dyn Middleware>>,
}
impl Default for HttpClient {
    /// Build a client with the default configuration.
    ///
    /// # Panics
    ///
    /// Panics if the client cannot be initialized.
    fn default() -> Self {
        let builder = HttpClientBuilder::new();
        builder.build().expect("client failed to initialize")
    }
}
impl HttpClient {
pub fn new() -> Self {
Self::default()
}
pub(crate) fn shared() -> &'static Self {
lazy_static! {
static ref SHARED: HttpClient = HttpClient::new();
}
&SHARED
}
pub fn builder() -> HttpClientBuilder {
HttpClientBuilder::default()
}
#[inline]
pub fn get<U>(&self, uri: U) -> Result<Response<Body>, Error>
where
http::Uri: http::HttpTryFrom<U>,
{
self.get_async(uri).join()
}
pub fn get_async<U>(&self, uri: U) -> ResponseFuture<'_>
where
http::Uri: http::HttpTryFrom<U>,
{
self.send_builder_async(http::Request::get(uri), Body::empty())
}
#[inline]
pub fn head<U>(&self, uri: U) -> Result<Response<Body>, Error>
where
http::Uri: http::HttpTryFrom<U>,
{
self.head_async(uri).join()
}
pub fn head_async<U>(&self, uri: U) -> ResponseFuture<'_>
where
http::Uri: http::HttpTryFrom<U>,
{
self.send_builder_async(http::Request::head(uri), Body::empty())
}
#[inline]
pub fn post<U>(&self, uri: U, body: impl Into<Body>) -> Result<Response<Body>, Error>
where
http::Uri: http::HttpTryFrom<U>,
{
self.post_async(uri, body).join()
}
pub fn post_async<U>(&self, uri: U, body: impl Into<Body>) -> ResponseFuture<'_>
where
http::Uri: http::HttpTryFrom<U>,
{
self.send_builder_async(http::Request::post(uri), body)
}
#[inline]
pub fn put<U>(&self, uri: U, body: impl Into<Body>) -> Result<Response<Body>, Error>
where
http::Uri: http::HttpTryFrom<U>,
{
self.put_async(uri, body).join()
}
pub fn put_async<U>(&self, uri: U, body: impl Into<Body>) -> ResponseFuture<'_>
where
http::Uri: http::HttpTryFrom<U>,
{
self.send_builder_async(http::Request::put(uri), body)
}
#[inline]
pub fn delete<U>(&self, uri: U) -> Result<Response<Body>, Error>
where
http::Uri: http::HttpTryFrom<U>,
{
self.delete_async(uri).join()
}
pub fn delete_async<U>(&self, uri: U) -> ResponseFuture<'_>
where
http::Uri: http::HttpTryFrom<U>,
{
self.send_builder_async(http::Request::delete(uri), Body::empty())
}
#[inline]
pub fn send<B: Into<Body>>(&self, request: Request<B>) -> Result<Response<Body>, Error> {
self.send_async(request).join()
}
pub fn send_async<B: Into<Body>>(&self, request: Request<B>) -> ResponseFuture<'_> {
let mut request = request.map(Into::into);
request
.headers_mut()
.entry(http::header::USER_AGENT)
.unwrap()
.or_insert(USER_AGENT.parse().unwrap());
for middleware in self.middleware.iter().rev() {
request = middleware.filter_request(request);
}
ResponseFuture {
client: self,
error: None,
request: Some(request),
inner: None,
}
}
fn send_builder_async(
&self,
mut builder: http::request::Builder,
body: impl Into<Body>,
) -> ResponseFuture<'_> {
match builder.body(body.into()) {
Ok(request) => self.send_async(request),
Err(e) => ResponseFuture {
client: self,
error: Some(e.into()),
request: None,
inner: None,
},
}
}
fn create_easy_handle(
&self,
request: Request<Body>,
) -> Result<(curl::easy::Easy2<RequestHandler>, RequestHandlerFuture), Error> {
let (parts, body) = request.into_parts();
let has_body = !body.is_empty();
let body_size = body.len();
let (handler, future) = RequestHandler::new(body);
macro_rules! extension {
($first:expr) => {
$first.get()
};
($first:expr, $($rest:expr),+) => {
$first.get().or_else(|| extension!($($rest),*))
};
}
let mut easy = curl::easy::Easy2::new(handler);
easy.verbose(log::log_enabled!(log::Level::Trace))?;
easy.signal(false)?;
if let Some(Timeout(timeout)) = extension!(parts.extensions, self.defaults) {
easy.timeout(*timeout)?;
}
if let Some(ConnectTimeout(timeout)) = extension!(parts.extensions, self.defaults) {
easy.connect_timeout(*timeout)?;
}
if let Some(TcpKeepAlive(interval)) = extension!(parts.extensions, self.defaults) {
easy.tcp_keepalive(true)?;
easy.tcp_keepintvl(*interval)?;
}
if let Some(TcpNoDelay) = extension!(parts.extensions, self.defaults) {
easy.tcp_nodelay(true)?;
}
if let Some(redirect_policy) = extension!(parts.extensions, self.defaults) {
match redirect_policy {
RedirectPolicy::Follow => {
easy.follow_location(true)?;
}
RedirectPolicy::Limit(max) => {
easy.follow_location(true)?;
easy.max_redirections(*max)?;
}
RedirectPolicy::None => {
easy.follow_location(false)?;
}
}
}
if let Some(MaxUploadSpeed(limit)) = extension!(parts.extensions, self.defaults) {
easy.max_send_speed(*limit)?;
}
if let Some(MaxDownloadSpeed(limit)) = extension!(parts.extensions, self.defaults) {
easy.max_recv_speed(*limit)?;
}
easy.http_version(match extension!(parts.extensions, self.defaults) {
Some(PreferredHttpVersion(http::Version::HTTP_10)) => curl::easy::HttpVersion::V10,
Some(PreferredHttpVersion(http::Version::HTTP_11)) => curl::easy::HttpVersion::V11,
Some(PreferredHttpVersion(http::Version::HTTP_2)) => curl::easy::HttpVersion::V2,
_ => curl::easy::HttpVersion::Any,
})?;
if let Some(Proxy(proxy)) = extension!(parts.extensions, self.defaults) {
easy.proxy(&format!("{}", proxy))?;
}
if let Some(DnsServers(addrs)) = extension!(parts.extensions, self.defaults) {
let dns_string = addrs
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.join(",");
if let Err(e) = easy.dns_servers(&dns_string) {
log::warn!("DNS servers could not be configured: {}", e);
}
}
if let Some(SslCiphers(ciphers)) = extension!(parts.extensions, self.defaults) {
easy.ssl_cipher_list(&ciphers.join(":"))?;
}
if let Some(cert) = extension!(parts.extensions, self.defaults) {
easy.ssl_client_certificate(cert)?;
}
easy.accept_encoding("")?;
match (parts.method, has_body) {
(http::Method::GET, false) => {
easy.get(true)?;
}
(http::Method::HEAD, has_body) => {
easy.custom_request("HEAD")?;
easy.nobody(true)?;
easy.upload(has_body)?;
}
(http::Method::POST, _) => {
easy.post(true)?;
}
(http::Method::PUT, _) => {
easy.upload(true)?;
}
(method, has_body) => {
easy.custom_request(method.as_str())?;
easy.upload(has_body)?;
}
}
easy.url(&parts.uri.to_string())?;
let mut headers = curl::easy::List::new();
for (name, value) in parts.headers.iter() {
let header = format!("{}: {}", name.as_str(), value.to_str().unwrap());
headers.append(&header)?;
}
easy.http_headers(headers)?;
if let Some(len) = body_size {
easy.in_filesize(len as u64)?;
}
Ok((easy, future))
}
}
impl fmt::Debug for HttpClient {
    /// Print only the type name; the agent, defaults and middleware are opaque.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("HttpClient")
    }
}
/// Helpers for applying TLS client certificate settings to a curl handle.
trait EasyExt {
    /// Access the underlying curl easy handle.
    fn easy(&mut self) -> &mut curl::easy::Easy2<RequestHandler>;

    /// Configure the client certificate file and type, plus its private key
    /// (PEM/DER) or password (P12) when one is provided.
    fn ssl_client_certificate(&mut self, cert: &ClientCertificate) -> Result<(), curl::Error> {
        match cert {
            ClientCertificate::PEM { path, private_key } => {
                let handle = self.easy();
                handle.ssl_cert(path)?;
                handle.ssl_cert_type("PEM")?;
                match private_key {
                    Some(key) => self.ssl_private_key(key),
                    None => Ok(()),
                }
            }
            ClientCertificate::DER { path, private_key } => {
                let handle = self.easy();
                handle.ssl_cert(path)?;
                handle.ssl_cert_type("DER")?;
                match private_key {
                    Some(key) => self.ssl_private_key(key),
                    None => Ok(()),
                }
            }
            ClientCertificate::P12 { path, password } => {
                let handle = self.easy();
                handle.ssl_cert(path)?;
                handle.ssl_cert_type("P12")?;
                match password {
                    Some(password) => handle.key_password(password),
                    None => Ok(()),
                }
            }
        }
    }

    /// Configure the private key file and type, plus its password if any.
    fn ssl_private_key(&mut self, key: &PrivateKey) -> Result<(), curl::Error> {
        match key {
            PrivateKey::PEM { path, password } => {
                let handle = self.easy();
                handle.ssl_key(path)?;
                handle.ssl_key_type("PEM")?;
                match password {
                    Some(password) => handle.key_password(password),
                    None => Ok(()),
                }
            }
            PrivateKey::DER { path, password } => {
                let handle = self.easy();
                handle.ssl_key(path)?;
                handle.ssl_key_type("DER")?;
                match password {
                    Some(password) => handle.key_password(password),
                    None => Ok(()),
                }
            }
        }
    }
}
impl EasyExt for curl::easy::Easy2<RequestHandler> {
    /// An `Easy2` handle is its own easy handle, so this returns `self`.
    fn easy(&mut self) -> &mut curl::easy::Easy2<RequestHandler> {
        self
    }
}
/// Future that resolves to the response of a request sent by an [`HttpClient`].
///
/// The request is submitted lazily: nothing is sent until the future is
/// polled or joined for the first time, so dropping it unpolled sends nothing.
#[derive(Debug)]
#[must_use = "futures do nothing unless polled"]
pub struct ResponseFuture<'c> {
    /// Client that executes the request.
    client: &'c HttpClient,
    /// Error raised while building the request, reported on first poll.
    error: Option<Error>,
    /// Request waiting to be submitted; taken on first poll.
    request: Option<Request<Body>>,
    /// Future for the submitted request, present once it has been submitted.
    inner: Option<RequestHandlerFuture>,
}
impl<'c> ResponseFuture<'c> {
    /// Submit the pending request to the client's agent, if not yet submitted.
    ///
    /// # Errors
    ///
    /// Returns the stored request-building error (only once), or any error
    /// raised while creating or submitting the curl handle.
    fn maybe_initialize(&mut self) -> Result<(), Error> {
        if let Some(error) = self.error.take() {
            return Err(error);
        }

        let request = match self.request.take() {
            Some(request) => request,
            None => return Ok(()),
        };

        let (easy, future) = self.client.create_easy_handle(request)?;
        self.client.agent.submit_request(easy)?;
        self.inner = Some(future);
        Ok(())
    }

    /// Pass a successful response through each middleware's response filter,
    /// in registration order; errors are returned untouched.
    fn complete(&self, output: <Self as Future>::Output) -> <Self as Future>::Output {
        let middleware = &self.client.middleware;
        output.map(|response| {
            middleware
                .iter()
                .fold(response, |response, layer| layer.filter_response(response))
        })
    }

    /// Synchronously wait for the response, submitting the request first if needed.
    ///
    /// # Panics
    ///
    /// Panics if there is no submitted request to wait on, which happens
    /// when `poll` has already reported this future's error.
    fn join(mut self) -> Result<Response<Body>, Error> {
        self.maybe_initialize()?;
        let inner = self.inner.take().expect("join called after poll");
        self.complete(inner.join())
    }
}
impl Future for ResponseFuture<'_> {
    type Output = Result<Response<Body>, Error>;

    /// Submit the request on first poll, then drive it to completion.
    ///
    /// A finished response is passed through the client's response middleware
    /// before being returned.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if let Err(error) = self.maybe_initialize() {
            return Poll::Ready(Err(error));
        }

        // With no submitted request there is nothing left to drive.
        let polled = match self.inner.as_mut() {
            Some(inner) => Pin::new(inner).poll(cx),
            None => return Poll::Pending,
        };

        polled.map(|output| self.complete(output))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles only if `T` is `Send`.
    fn assert_send<T: Send>() {}

    /// Compiles only if `T` is `Sync`.
    fn assert_sync<T: Sync>() {}

    /// The client must be shareable across threads; the builder must be movable.
    #[test]
    fn traits() {
        assert_send::<HttpClient>();
        assert_sync::<HttpClient>();
        assert_send::<HttpClientBuilder>();
    }
}