use std::collections::HashSet;
use std::env;
use std::fmt::Debug;
use std::fmt::Formatter;
use std::io;
use std::str::FromStr;
use std::sync::Arc;

use futures::TryStreamExt;
use http::Request;
use http::Response;
use log::debug;
use reqwest::redirect::Policy;
use reqwest::ClientBuilder;
use reqwest::Url;

use super::body::IncomingAsyncBody;
use super::dns::*;
use super::parse_content_length;
use super::AsyncBody;
use super::Body;
use crate::Error;
use crate::ErrorKind;
use crate::Result;
/// HTTP client bundling an async `reqwest::Client` and a blocking `ureq::Agent`.
///
/// Cloning an `HttpClient` clones both inner clients.
#[derive(Clone)]
pub struct HttpClient {
    /// Client used by `send_async`.
    async_client: reqwest::Client,
    /// Agent used by the blocking `send`.
    sync_client: ureq::Agent,
}
impl Debug for HttpClient {
    /// Writes only the type name; neither inner client is included in the
    /// debug output.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("HttpClient")
    }
}
impl HttpClient {
pub fn new() -> Result<Self> {
let async_client = {
let mut builder = ClientBuilder::new();
builder = builder.no_gzip();
builder = builder.no_brotli();
builder = builder.no_deflate();
builder = builder.redirect(Policy::none());
#[cfg(feature = "trust-dns")]
let builder = builder.dns_resolver(Arc::new(AsyncTrustDnsResolver::new().unwrap()));
#[cfg(not(feature = "trust-dns"))]
let builder = builder.dns_resolver(Arc::new(AsyncStdDnsResolver::default()));
builder.build().map_err(|err| {
Error::new(ErrorKind::Unexpected, "async client build failed").set_source(err)
})?
};
let sync_client = {
let mut builder = ureq::AgentBuilder::new();
for key in ["http_proxy", "https_proxy", "HTTP_PROXY", "HTTPS_PROXY"] {
if let Ok(proxy) = env::var(key) {
if let Ok(proxy) = ureq::Proxy::new(proxy) {
debug!("sync client: set proxy to {proxy:?}");
builder = builder.proxy(proxy);
}
}
}
let builder = builder.resolver(StdDnsResolver::default());
builder.build()
};
Ok(HttpClient {
async_client,
sync_client,
})
}
pub fn with_client(async_client: reqwest::Client, sync_client: ureq::Agent) -> Self {
Self {
async_client,
sync_client,
}
}
pub fn async_client(&self) -> reqwest::Client {
self.async_client.clone()
}
pub fn sync_client(&self) -> ureq::Agent {
self.sync_client.clone()
}
pub fn send(&self, req: Request<Body>) -> Result<Response<Body>> {
let (parts, body) = req.into_parts();
let mut ur = self
.sync_client
.request(parts.method.as_str(), &parts.uri.to_string());
for (k, v) in parts.headers.iter() {
ur = ur.set(k.as_str(), v.to_str().expect("must be valid header"));
}
let resp = match ur.send(body) {
Ok(resp) => resp,
Err(err_resp) => match err_resp {
ureq::Error::Status(_code, resp) => resp,
ureq::Error::Transport(transport) => {
let is_temperary = matches!(
transport.kind(),
ureq::ErrorKind::Dns
| ureq::ErrorKind::ConnectionFailed
| ureq::ErrorKind::Io
);
let mut err = Error::new(ErrorKind::Unexpected, "send blocking request")
.with_operation("http_util::Client::send")
.set_source(transport);
if is_temperary {
err = err.set_temporary();
}
return Err(err);
}
},
};
let mut hr = Response::builder().status(resp.status());
for name in resp.headers_names() {
if let Some(value) = resp.header(&name) {
hr = hr.header(name, value);
}
}
let resp = hr
.body(Body::Reader(Box::new(resp.into_reader())))
.expect("response must build succeed");
Ok(resp)
}
pub async fn send_async(&self, req: Request<AsyncBody>) -> Result<Response<IncomingAsyncBody>> {
let is_head = req.method() == http::Method::HEAD;
let (parts, body) = req.into_parts();
let mut req_builder = self
.async_client
.request(
parts.method,
Url::from_str(&parts.uri.to_string()).expect("input request url must be valid"),
)
.version(parts.version)
.headers(parts.headers);
req_builder = if let AsyncBody::Multipart(field, r) = body {
let mut form = reqwest::multipart::Form::new();
let part = reqwest::multipart::Part::stream(AsyncBody::Reader(r));
form = form.part(field, part);
req_builder.multipart(form)
} else {
req_builder.body(body)
};
let resp = req_builder.send().await.map_err(|err| {
let is_temperary = !(
err.is_builder() ||
err.is_redirect() ||
err.is_status()
);
let mut oerr = Error::new(ErrorKind::Unexpected, "send async request")
.with_operation("http_util::Client::send_async")
.set_source(err);
if is_temperary {
oerr = oerr.set_temporary();
}
oerr
})?;
let content_length = if is_head {
None
} else {
parse_content_length(resp.headers()).expect("response content length must be valid")
};
let mut hr = Response::builder()
.version(resp.version())
.status(resp.status());
for (k, v) in resp.headers().iter() {
hr = hr.header(k, v);
}
let stream = resp.bytes_stream().map_err(|err| {
io::Error::new(
if err.is_body() {
io::ErrorKind::Interrupted
} else {
io::ErrorKind::Other
},
err,
)
});
let body = IncomingAsyncBody::new(Box::new(stream), content_length);
let resp = hr.body(body).expect("response must build succeed");
Ok(resp)
}
}