use crate::{Error, Result};
use futures_util::{
io::{AsyncRead, AsyncWrite},
stream::{self, Stream},
StreamExt, TryFutureExt,
};
use hyper::{
body::Bytes,
client::{Client, HttpConnector},
header, Body, Method, Request, Response, StatusCode,
};
#[cfg(feature = "tls")]
use hyper_openssl::HttpsConnector;
#[cfg(unix)]
use hyperlocal::UnixConnector;
#[cfg(unix)]
use hyperlocal::Uri as DomainUri;
use pin_project::pin_project;
use serde::{Deserialize, Serialize};
use url::Url;
use std::{
io,
iter::IntoIterator,
path::PathBuf,
pin::Pin,
task::{Context, Poll},
};
#[derive(Debug, Default, Clone)]
pub struct Headers(Vec<(&'static str, String)>);
impl Headers {
pub fn none() -> Option<Headers> {
None
}
pub fn add<V>(&mut self, key: &'static str, val: V)
where
V: Into<String>,
{
self.0.push((key, val.into()))
}
pub fn single<V>(key: &'static str, val: V) -> Self
where
V: Into<String>,
{
let mut h = Self::default();
h.add(key, val);
h
}
}
impl IntoIterator for Headers {
type Item = (&'static str, String);
type IntoIter = std::vec::IntoIter<Self::Item>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
pub enum Payload<B: Into<Body>> {
None,
Text(B),
Json(B),
XTar(B),
Tar(B),
}
impl Payload<Body> {
pub fn empty() -> Self {
Payload::None
}
}
impl<B: Into<Body>> Payload<B> {
pub fn into_inner(self) -> Option<B> {
match self {
Self::None => None,
Self::Text(b) => Some(b),
Self::Json(b) => Some(b),
Self::XTar(b) => Some(b),
Self::Tar(b) => Some(b),
}
}
pub fn mime_type(&self) -> Option<mime::Mime> {
match &self {
Self::None => None,
Self::Text(_) => None,
Self::Json(_) => Some(mime::APPLICATION_JSON),
Self::XTar(_) => Some("application/x-tar".parse().expect("parsed mime")),
Self::Tar(_) => Some("application/tar".parse().expect("parsed mime")),
}
}
pub fn is_none(&self) -> bool {
matches!(self, Self::None)
}
}
pub async fn body_to_string(body: Body) -> Result<String> {
let bytes = hyper::body::to_bytes(body).await?;
String::from_utf8(bytes.to_vec()).map_err(Error::from)
}
#[derive(Clone, Debug)]
pub enum Transport {
Tcp {
client: Client<HttpConnector>,
host: Url,
},
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
EncryptedTcp {
client: Client<HttpsConnector<HttpConnector>>,
host: Url,
},
#[cfg(unix)]
Unix {
client: Client<UnixConnector>,
path: PathBuf,
},
}
impl Transport {
pub fn remote_addr(&self) -> &str {
match &self {
Self::Tcp { ref host, .. } => host.as_ref(),
#[cfg(feature = "tls")]
Self::EncryptedTcp { ref host, .. } => host.as_ref(),
#[cfg(unix)]
Self::Unix { ref path, .. } => path.to_str().unwrap_or_default(),
}
}
pub async fn request<B>(
&self,
method: Method,
endpoint: impl AsRef<str>,
body: Payload<B>,
headers: Option<Headers>,
) -> Result<Response<Body>>
where
B: Into<Body>,
{
let req = self.build_request(method, endpoint, body, headers, Request::builder())?;
self.send_request(req).await
}
pub async fn request_string<B>(
&self,
method: Method,
endpoint: impl AsRef<str>,
body: Payload<B>,
headers: Option<Headers>,
) -> Result<String>
where
B: Into<Body>,
{
let body = self.get_body(method, endpoint, body, headers).await?;
body_to_string(body).await
}
pub fn stream_chunks<'transport, B>(
&'transport self,
method: Method,
endpoint: impl AsRef<str> + 'transport,
body: Payload<B>,
headers: Option<Headers>,
) -> impl Stream<Item = Result<Bytes>> + 'transport
where
B: Into<Body> + 'transport,
{
self.get_chunk_stream(method, endpoint, body, headers)
.try_flatten_stream()
}
pub fn stream_json_chunks<'transport, B>(
&'transport self,
method: Method,
endpoint: impl AsRef<str> + 'transport,
body: Payload<B>,
headers: Option<Headers>,
) -> impl Stream<Item = Result<Bytes>> + 'transport
where
B: Into<Body> + 'transport,
{
self.get_json_chunk_stream(method, endpoint, body, headers)
.try_flatten_stream()
}
pub async fn stream_upgrade<B>(
&self,
method: Method,
endpoint: impl AsRef<str>,
body: Payload<B>,
) -> Result<impl AsyncRead + AsyncWrite>
where
B: Into<Body>,
{
self.stream_upgrade_tokio(method, endpoint, body)
.await
.map(Compat::new)
.map_err(Error::from)
}
pub async fn get_body<B>(
&self,
method: Method,
endpoint: impl AsRef<str>,
body: Payload<B>,
headers: Option<Headers>,
) -> Result<Body>
where
B: Into<Body>,
{
let response = self.request(method, endpoint, body, headers).await?;
log::trace!(
"got response {} {:?}",
response.status(),
response.headers()
);
let status = response.status();
match status {
StatusCode::OK
| StatusCode::CREATED
| StatusCode::SWITCHING_PROTOCOLS
| StatusCode::NO_CONTENT => Ok(response.into_body()),
_ => {
let bytes = hyper::body::to_bytes(response.into_body()).await?;
let message_body = String::from_utf8(bytes.to_vec())?;
Err(Error::Fault {
code: status,
message: Self::get_error_message(&message_body).unwrap_or_else(|| {
status
.canonical_reason()
.unwrap_or("unknown error code")
.to_owned()
}),
})
}
}
}
async fn get_chunk_stream<B>(
&self,
method: Method,
endpoint: impl AsRef<str>,
body: Payload<B>,
headers: Option<Headers>,
) -> Result<impl Stream<Item = Result<Bytes>>>
where
B: Into<Body>,
{
self.get_body(method, endpoint, body, headers)
.await
.map(stream_body)
}
async fn get_json_chunk_stream<B>(
&self,
method: Method,
endpoint: impl AsRef<str>,
body: Payload<B>,
headers: Option<Headers>,
) -> Result<impl Stream<Item = Result<Bytes>>>
where
B: Into<Body>,
{
self.get_body(method, endpoint, body, headers)
.await
.map(stream_json_body)
}
fn build_request<B>(
&self,
method: Method,
endpoint: impl AsRef<str>,
body: Payload<B>,
headers: Option<Headers>,
builder: hyper::http::request::Builder,
) -> Result<Request<Body>>
where
B: Into<Body>,
{
let ep = endpoint.as_ref();
let uri: hyper::Uri = match self {
Transport::Tcp { host, .. } => format!("{}{}", host, ep)
.parse()
.map_err(Error::InvalidUri)?,
#[cfg(feature = "tls")]
Transport::EncryptedTcp { host, .. } => format!("{}{}", host, ep)
.parse()
.map_err(Error::InvalidUri)?,
#[cfg(unix)]
Transport::Unix { path, .. } => DomainUri::new(&path, ep).into(),
};
let req = builder.method(method).uri(&uri);
let mut req = req.header(header::HOST, "");
if let Some(h) = headers {
for (k, v) in h.into_iter() {
req = req.header(k, v);
}
}
if body.is_none() {
return Ok(req.body(Body::empty())?);
}
let mime = body.mime_type();
if let Some(c) = mime {
req = req.header(header::CONTENT_TYPE, &c.to_string());
}
req.body(body.into_inner().unwrap().into())
.map_err(Error::from)
}
async fn send_request(&self, req: Request<Body>) -> Result<Response<Body>> {
log::trace!("sending request {} {}", req.method(), req.uri());
match self {
Transport::Tcp { ref client, .. } => client.request(req),
#[cfg(feature = "tls")]
Transport::EncryptedTcp { ref client, .. } => client.request(req),
#[cfg(unix)]
Transport::Unix { ref client, .. } => client.request(req),
}
.await
.map_err(Error::from)
}
async fn stream_upgrade_tokio<B>(
&self,
method: Method,
endpoint: impl AsRef<str>,
body: Payload<B>,
) -> Result<hyper::upgrade::Upgraded>
where
B: Into<Body>,
{
let req = self.build_request(
method,
endpoint,
body,
Headers::none(),
Request::builder()
.header(header::CONNECTION, "Upgrade")
.header(header::UPGRADE, "tcp"),
)?;
let response = self.send_request(req).await?;
match response.status() {
StatusCode::SWITCHING_PROTOCOLS => Ok(hyper::upgrade::on(response).await?),
_ => Err(Error::ConnectionNotUpgraded),
}
}
fn get_error_message(body: &str) -> Option<String> {
serde_json::from_str::<ErrorResponse>(body)
.map(|e| e.message)
.ok()
}
}
#[pin_project]
struct Compat<S> {
#[pin]
tokio_multiplexer: S,
}
impl<S> Compat<S> {
fn new(tokio_multiplexer: S) -> Self {
Self { tokio_multiplexer }
}
}
impl<S> AsyncRead for Compat<S>
where
S: tokio::io::AsyncRead,
{
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut readbuf = tokio::io::ReadBuf::new(buf);
match self.project().tokio_multiplexer.poll_read(cx, &mut readbuf) {
Poll::Pending => Poll::Pending,
Poll::Ready(Ok(())) => Poll::Ready(Ok(readbuf.filled().len())),
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
}
}
}
impl<S> AsyncWrite for Compat<S>
where
S: tokio::io::AsyncWrite,
{
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.project().tokio_multiplexer.poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().tokio_multiplexer.poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.project().tokio_multiplexer.poll_shutdown(cx)
}
}
#[derive(Serialize, Deserialize)]
struct ErrorResponse {
message: String,
}
fn stream_body(body: Body) -> impl Stream<Item = Result<Bytes>> {
async fn unfold(mut body: Body) -> Option<(Result<Bytes>, Body)> {
body.next()
.await
.map(|chunk| (chunk.map_err(Error::from), body))
}
stream::unfold(body, unfold)
}
static JSON_WHITESPACE: &[u8] = b"\r\n";
fn stream_json_body(body: Body) -> impl Stream<Item = Result<Bytes>> {
async fn unfold(mut body: Body) -> Option<(Result<Bytes>, Body)> {
let mut chunk = Vec::new();
while let Some(chnk) = body.next().await {
match chnk {
Ok(chnk) => {
chunk.extend(chnk.to_vec());
if chnk.ends_with(JSON_WHITESPACE) {
break;
}
}
Err(e) => {
return Some((Err(Error::from(e)), body));
}
}
}
if chunk.is_empty() {
return None;
}
Some((Ok(Bytes::from(chunk)), body))
}
stream::unfold(body, unfold)
}