use core::marker::PhantomData;
use core::future::Future;
use core::fmt;
use std::path::Path;
use crate::header;
pub mod config;
pub mod request;
pub mod response;
pub use request::Request;
pub use response::Response;
/// HTTP client wrapping a `hyper::Client`.
///
/// All behaviour (connector, default headers, timeout, redirect budget, ...) is selected at
/// compile time through the configuration type `C`; it is only carried as a `PhantomData`
/// marker, so the client itself stores nothing but the hyper client.
pub struct Client<C: config::Config + 'static = config::DefaultCfg> {
    // The underlying hyper client, using the connector chosen by `C`.
    inner: hyper::Client<C::Connector>,
    // Zero-sized marker binding the configuration type to the client.
    _config: PhantomData<C>,
}
impl Default for Client {
fn default() -> Self {
Client::<config::DefaultCfg>::new()
}
}
/// Debug output shows only the wrapped hyper client; the configuration type has no state.
impl<C: config::Config> fmt::Debug for Client<C> {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let hyper_client = &self.inner;
        formatter.write_fmt(format_args!("Yukikaze {{ HyperClient={:?} }}", hyper_client))
    }
}
pub type RequestResult = Result<response::Response, hyper::Error>;
use tokio::io::{AsyncRead, AsyncWrite};
impl<C: config::Config> Client<C> where <C::Connector as hyper::service::Service<hyper::Uri>>::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
<C::Connector as hyper::service::Service<hyper::Uri>>::Future: Send + Unpin,
<C::Connector as hyper::service::Service<hyper::Uri>>::Response: AsyncRead + AsyncWrite + hyper::client::connect::Connection + Unpin + Send
{
pub fn new() -> Client<C> {
let inner = C::config_hyper(&mut hyper::Client::builder()).build(C::Connector::default());
Self {
inner,
_config: PhantomData
}
}
fn apply_headers(request: &mut request::Request) {
C::default_headers(request);
#[cfg(feature = "compu")]
{
const DEFAULT_COMPRESS: &'static str = "br, gzip, deflate";
if C::decompress() {
let headers = request.headers_mut();
if !headers.contains_key(header::ACCEPT_ENCODING) && headers.contains_key(header::RANGE) {
headers.insert(header::ACCEPT_ENCODING, header::HeaderValue::from_static(DEFAULT_COMPRESS));
}
}
}
}
pub async fn request(&self, mut req: request::Request) -> RequestResult {
Self::apply_headers(&mut req);
#[cfg(feature = "carry_extensions")]
let mut extensions = req.extract_extensions();
let ongoing = self.inner.request(req.into());
let ongoing = matsu!(ongoing).map(|res| response::Response::new(res));
#[cfg(feature = "carry_extensions")]
{
ongoing.map(move |resp| resp.replace_extensions(&mut extensions))
}
#[cfg(not(feature = "carry_extensions"))]
{
ongoing
}
}
pub async fn send(&self, mut req: request::Request) -> Result<RequestResult, async_timer::Expired<impl Future<Output=RequestResult>, C::Timer>> {
Self::apply_headers(&mut req);
#[cfg(feature = "carry_extensions")]
let mut extensions = req.extract_extensions();
let ongoing = self.inner.request(req.into());
let ongoing = async {
let res = matsu!(ongoing);
res.map(|resp| response::Response::new(resp))
};
let timeout = C::timeout();
match timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
#[cfg(not(feature = "carry_extensions"))]
true => Ok(matsu!(ongoing)),
#[cfg(feature = "carry_extensions")]
true => Ok(matsu!(ongoing).map(move |resp| resp.replace_extensions(&mut extensions))),
false => {
let job = unsafe { async_timer::Timed::<_, C::Timer>::new_unchecked(ongoing, timeout) };
#[cfg(not(feature = "carry_extensions"))]
{
matsu!(job)
}
#[cfg(feature = "carry_extensions")]
{
matsu!(job).map(move |res| res.map(move |resp| resp.replace_extensions(&mut extensions)))
}
}
}
}
pub async fn send_redirect(&'static self, req: request::Request) -> Result<RequestResult, async_timer::Expired<impl Future<Output=RequestResult> + 'static, C::Timer>> {
let timeout = C::timeout();
match timeout.as_secs() == 0 && timeout.subsec_nanos() == 0 {
true => Ok(matsu!(self.redirect_request(req))),
false => {
let ongoing = self.redirect_request(req);
let job = unsafe { async_timer::Timed::<_, C::Timer>::new_unchecked(ongoing, timeout) };
matsu!(job)
}
}
}
pub async fn redirect_request(&self, mut req: request::Request) -> RequestResult {
use http::{Method, StatusCode};
Self::apply_headers(&mut req);
let mut rem_redirect = C::max_redirect_num();
let mut method = req.parts.method.clone();
let uri = req.parts.uri.clone();
let mut headers = req.parts.headers.clone();
let mut body = req.body.clone();
#[cfg(feature = "carry_extensions")]
let mut extensions = req.extract_extensions();
loop {
let ongoing = self.inner.request(req.into());
let res = matsu!(ongoing).map(|resp| response::Response::new(resp))?;
match res.status() {
StatusCode::SEE_OTHER => {
rem_redirect -= 1;
match rem_redirect {
#[cfg(feature = "carry_extensions")]
0 => return Ok(res.replace_extensions(&mut extensions)),
#[cfg(not(feature = "carry_extensions"))]
0 => return Ok(res),
_ => {
body = None;
method = Method::GET;
}
}
},
StatusCode::MOVED_PERMANENTLY | StatusCode::FOUND | StatusCode::TEMPORARY_REDIRECT | StatusCode::PERMANENT_REDIRECT => {
rem_redirect -= 1;
match rem_redirect {
#[cfg(feature = "carry_extensions")]
0 => return Ok(res.replace_extensions(&mut extensions)),
#[cfg(not(feature = "carry_extensions"))]
0 => return Ok(res),
_ => (),
}
}
#[cfg(feature = "carry_extensions")]
_ => return Ok(res.replace_extensions(&mut extensions)),
#[cfg(not(feature = "carry_extensions"))]
_ => return Ok(res),
}
let location = match res.headers().get(header::LOCATION).and_then(|loc| loc.to_str().ok()).and_then(|loc| loc.parse::<hyper::Uri>().ok()) {
Some(loc) => match loc.scheme().is_some() {
true => {
if let Some(prev_host) = uri.authority().map(|part| part.host()) {
match loc.authority().map(|part| part.host() == prev_host).unwrap_or(false) {
true => (),
false => {
headers.remove("authorization");
headers.remove("cookie");
headers.remove("cookie2");
headers.remove("www-authenticate");
}
}
}
loc
},
false => {
let current = Path::new(uri.path());
let loc = Path::new(loc.path());
let loc = current.join(loc);
let loc = loc.to_str().expect("Valid UTF-8 path").parse::<hyper::Uri>().expect("Valid URI");
let mut loc_parts = loc.into_parts();
loc_parts.scheme = uri.scheme().cloned();
loc_parts.authority = uri.authority().cloned();
hyper::Uri::from_parts(loc_parts).expect("Create redirect URI")
},
},
#[cfg(feature = "carry_extensions")]
None => return Ok(res.replace_extensions(&mut extensions)),
#[cfg(not(feature = "carry_extensions"))]
None => return Ok(res),
};
let (mut parts, _) = hyper::Request::<()>::new(()).into_parts();
parts.method = method.clone();
parts.uri = location;
parts.headers = headers.clone();
req = request::Request {
parts,
body: body.clone()
};
}
}
}