use super::conn::BodyBuf;
use super::connect;
use super::cookies::Cookies;
use super::Connection;
use crate::async_impl::AsyncRuntime;
use crate::params::resolve_hreq_params;
use crate::params::HReqParams;
use crate::params::QueryParams;
use crate::uri_ext::UriExt;
use crate::Body;
use crate::Error;
use crate::ResponseExt;
use cookie::Cookie;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
#[derive(Default)]
pub struct Agent {
connections: Vec<Connection>,
cookies: Option<Cookies>,
redirects: i8,
retries: i8,
pooling: bool,
use_cookies: bool,
}
impl Agent {
pub fn new() -> Self {
Agent {
connections: vec![],
cookies: None,
redirects: 5,
retries: 5,
pooling: true,
use_cookies: true,
}
}
pub fn redirects(&mut self, amount: u8) {
self.redirects = amount as i8;
}
pub fn retries(&mut self, amount: u8) {
self.retries = amount as i8;
}
pub fn pooling(&mut self, enabled: bool) {
self.pooling = enabled;
if !enabled {
self.connections.clear();
}
}
pub fn cookies(&mut self, enabled: bool) {
self.use_cookies = enabled;
if !enabled {
self.cookies = None;
}
}
pub fn get_cookies(&self, uri: &http::Uri) -> Vec<&Cookie<'static>> {
if let Some(cookies) = &self.cookies {
cookies.get(uri)
} else {
vec![]
}
}
fn reuse_from_pool(&mut self, uri: &http::Uri) -> Result<Option<&mut Connection>, Error> {
if !self.pooling {
return Ok(None);
}
let host_port = uri.host_port()?;
let ret = self
.connections
.iter_mut()
.find(|c| {
c.host_port() == &host_port && (c.is_http2() || c.unfinished_requests() == 0)
});
if ret.is_some() {
debug!("Reuse from pool: {}", uri);
}
let ret = None;
Ok(ret)
}
pub(crate) fn send_future<'a>(mut self, req: http::Request<Body>) -> ResponseFuture {
let do_fut = async move { self.send(req).await };
ResponseFuture::new(do_fut)
}
pub async fn send<B: Into<Body>>(
&mut self,
req: http::Request<B>,
) -> Result<http::Response<Body>, Error> {
let (parts, body) = req.into_parts();
let body = body.into();
let parts = resolve_hreq_params(parts);
let params = parts.extensions.get::<HReqParams>().unwrap().clone();
let mut body_buffer = BodyBuf::new(params.redirect_body_buffer);
let deadline = params.deadline();
let mut cookies = self.cookies.take();
let ret = deadline
.race(self.do_send(parts, body, params, &mut cookies, &mut body_buffer))
.await;
self.cookies = cookies;
ret
}
async fn do_send(
&mut self,
parts: http::request::Parts,
body: Body,
params: HReqParams,
cookies: &mut Option<Cookies>,
body_buffer: &mut BodyBuf,
) -> Result<http::Response<Body>, Error> {
trace!("Agent {} {}", parts.method, parts.uri);
let mut retries = self.retries;
let mut backoff_millis: u64 = 125;
let mut redirects = self.redirects;
let pooling = self.pooling;
let mut unpooled: Option<Connection> = None;
let use_cookies = self.use_cookies;
let orig_hostport = parts.uri.host_port()?.to_owned();
let mut next_req = http::Request::from_parts(parts, body);
loop {
let mut req = next_req;
let uri = req.uri().clone();
if self.use_cookies {
if let Some(cookies) = cookies {
let cookies = cookies.get(&uri);
for cookie in cookies {
let no_param = Cookie::new(cookie.name(), cookie.value());
let cval = no_param.encoded().to_string();
let val = http::header::HeaderValue::from_str(&cval)
.expect("Cookie header value");
req.headers_mut().append("cookie", val);
}
}
}
let is_idempotent = req.method().is_idempotent();
next_req = clone_to_empty_body(&req);
let conn = match self.reuse_from_pool(&uri)? {
Some(conn) => conn,
None => {
let hostport_uri = uri.host_port()?;
let mut conn: Option<Connection> = None;
let HReqParams {
force_http2,
tls_disable_verify,
..
} = params;
if orig_hostport == hostport_uri {
if let Some(arc) = params.with_override.clone() {
let hostport = &*arc;
debug!("Connect new: {} with override: {}", uri, hostport);
conn = Some(connect(hostport, force_http2, tls_disable_verify).await?);
}
}
let conn = match conn {
Some(conn) => conn,
None => {
debug!("Connect new: {}", hostport_uri);
connect(&hostport_uri, force_http2, tls_disable_verify).await?
}
};
if pooling {
self.connections.push(conn);
let idx = self.connections.len() - 1;
self.connections.get_mut(idx).unwrap()
} else {
unpooled.replace(conn);
unpooled.as_mut().unwrap()
}
}
};
debug!("{} {}", req.method(), req.uri());
match conn.send_request(req, body_buffer).await {
Ok(mut res) => {
let mut retain = true;
if use_cookies {
for cookie_head in res.headers().get_all("set-cookie") {
if let Ok(v) = cookie_head.to_str() {
if let Ok(cookie) = Cookie::parse_encoded(v.to_string()) {
if cookies.is_none() {
*cookies = Some(Cookies::new());
}
cookies.as_mut().unwrap().add(&uri, cookie);
} else {
info!("Failed to parse cookie: {}", v);
}
} else {
info!("Failed to read cookie value: {:?}", cookie_head);
}
}
}
fn is_handled_redirect(status: http::StatusCode) -> bool {
match status.as_u16() {
301 | 302 | 307 | 308 => true,
_ => false,
}
}
if is_handled_redirect(res.status()) {
redirects -= 1;
if redirects < 0 {
trace!("Not following more redirections");
break Ok(res);
}
let location = res.header("location").ok_or_else(|| {
Error::Proto("Redirect without Location header".into())
})?;
trace!("Redirect to: {}", location);
let (mut parts, body) = next_req.into_parts();
parts.uri = parts.uri.parse_relative(location)?;
next_req = http::Request::from_parts(parts, body);
let code = res.status_code();
let is_307ish = code > 303;
if let Some(body) = body_buffer.reset(is_307ish) {
let (parts, _) = next_req.into_parts();
next_req = http::Request::from_parts(parts, body);
}
if is_307ish
&& !conn.is_http2()
&& conn.host_port() == &next_req.uri().host_port()?
{
retain = false;
}
if res.body_mut().read_and_discard().await.is_err() {
retain = false;
}
if !retain {
let conn_id = conn.id();
debug!("Remove from pool: {}", conn.host_port());
self.connections.retain(|c| c.id() != conn_id);
}
continue;
}
break Ok(res);
}
Err(err) => {
let conn_id = conn.id();
self.connections.retain(|c| c.id() != conn_id);
retries -= 1;
if retries == 0 || !is_idempotent || !err.is_retryable() {
trace!("Abort with error, {}", err);
break Err(err);
}
trace!("Retrying on error, {}", err);
}
}
trace!("Retry backoff: {}ms", backoff_millis);
AsyncRuntime::timeout(Duration::from_millis(backoff_millis)).await;
backoff_millis = (backoff_millis * 2).min(10_000);
}
}
}
/// Copies `from` into a new request with an empty body. The method, URI,
/// version and headers are copied, plus the `HReqParams` and `QueryParams`
/// extensions.
///
/// Any other extensions on `from` are not carried over. The result is used as
/// the template for a follow-up request (redirect or retry).
fn clone_to_empty_body(from: &http::Request<Body>) -> http::Request<Body> {
    let mut req = http::Request::new(Body::empty());
    *req.method_mut() = from.method().clone();
    *req.uri_mut() = from.uri().clone();
    // `Version` is `Copy`; no clone needed.
    *req.version_mut() = from.version();
    *req.headers_mut() = from.headers().clone();

    let ext = req.extensions_mut();
    if let Some(params) = from.extensions().get::<HReqParams>() {
        ext.insert(params.clone());
    }
    if let Some(params) = from.extensions().get::<QueryParams>() {
        ext.insert(params.clone());
    }
    req
}
impl fmt::Debug for Agent {
    /// Writes just the type name; the pool, cookie jar and settings are not printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("Agent")
    }
}
pub struct ResponseFuture {
req: Box<dyn Future<Output = Result<http::Response<Body>, Error>> + Send>,
}
impl ResponseFuture {
pub(crate) fn new(
t: impl Future<Output = Result<http::Response<Body>, Error>> + Send + 'static,
) -> Self {
ResponseFuture { req: Box::new(t) }
}
}
impl Future for ResponseFuture {
type Output = Result<http::Response<Body>, Error>;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
let this = self.get_mut();
unsafe { Pin::new_unchecked(&mut *this.req) }.poll(cx)
}
}
impl fmt::Debug for ResponseFuture {
    /// Writes just the type name; the wrapped future is opaque.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("ResponseFuture")
    }
}