use http::uri::Scheme;
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::fmt::{Debug, Display, Formatter};
use std::str::FromStr;
use std::{
convert::Infallible,
fmt,
future::Future,
marker::PhantomData,
rc::Rc,
task::{Poll, Waker},
time::Duration,
};
use crate::proxy_wasm::types::{Bytes, Status};
use serde::de::StdError;
use crate::http_constants::{
DEFAULT_TIMEOUT, HEADER_AUTHORITY, HEADER_METHOD, HEADER_PATH, HEADER_SCHEME, HEADER_STATUS,
METHOD_DELETE, METHOD_GET, METHOD_OPTIONS, METHOD_POST, METHOD_PUT, USER_AGENT_HEADER,
};
use crate::user_agent::UserAgent;
use crate::{
extract::{Extract, FromContext},
host::Host,
reactor::root::{BoxedExtractor, RootReactor},
types::{Cid, RequestId},
};
/// Metadata describing an HTTP call response event: which dispatched call it answers and
/// how large the response's header, body and trailer sections are.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HttpCallResponse {
    /// Identifier of the dispatched call this response belongs to.
    pub request_id: RequestId,
    /// Number of response headers.
    pub num_headers: usize,
    /// Length of the response body; `DefaultResponseExtractor` reads up to this many bytes
    /// starting at offset 0.
    pub body_size: usize,
    /// Number of response trailers.
    pub num_trailers: usize,
}
/// Client used to dispatch outbound HTTP calls through the host.
///
/// Obtain one through [`FromContext`], then start a request with `HttpClient::request`.
pub struct HttpClient {
    /// Reactor where response extractors of dispatched calls are registered.
    reactor: Rc<RootReactor>,
    /// Host used to dispatch calls and to read the response buffers.
    host: Rc<dyn Host>,
    /// Its value is sent as the user-agent header on every request.
    user_agent: Rc<UserAgent>,
}
/// Errors produced by HTTP calls made through [`HttpClient`].
#[derive(thiserror::Error, Debug, Clone)]
pub enum HttpClientError {
    /// Dispatching the call failed; carries the status returned by the host's
    /// `dispatch_http_call`.
    #[error("Proxy status problem: {0:?}")]
    Status(Status),
    /// A request was awaited while handling a create-context event.
    /// NOTE(review): not constructed in this part of the file — presumably raised elsewhere.
    #[error("Request awaited on create context event")]
    AwaitedOnCreateContext,
}
impl HttpClient {
pub(crate) fn new(
reactor: Rc<RootReactor>,
host: Rc<dyn Host>,
user_agent: Rc<UserAgent>,
) -> Self {
Self {
reactor,
host,
user_agent,
}
}
pub fn request<'a>(
&'a self,
service: &'a Service,
) -> RequestBuilder<'a, DefaultResponseExtractor> {
RequestBuilder::new(self, service, DefaultResponseExtractor)
}
}
/// Builds an [`HttpClient`] from the reactor, host and user agent available in a context.
impl<C> FromContext<C> for HttpClient
where
    Rc<dyn Host>: FromContext<C, Error = Infallible>,
    Rc<RootReactor>: FromContext<C, Error = Infallible>,
{
    type Error = Infallible;

    fn from_context(context: &C) -> Result<Self, Self::Error> {
        // Arguments are evaluated left to right: reactor, then host, then user agent.
        Ok(Self::new(
            context.extract()?,
            context.extract()?,
            context.extract()?,
        ))
    }
}
/// A dispatched (or failed-to-dispatch) HTTP call.
///
/// Awaiting it yields the value produced by the request's response extractor, or an
/// [`HttpClientError`].
pub struct Request<T> {
    reactor: Rc<RootReactor>,
    /// Id returned by the host's dispatch; `0` for requests whose dispatch failed.
    request_id: RequestId,
    /// Context paused and waker registered on the first pending poll.
    cid_and_waker: Option<(Cid, Waker)>,
    /// Set when dispatch failed; such requests were never registered with the reactor.
    error: Option<HttpClientError>,
    _response_type: PhantomData<T>,
}
/// Read access to the data of an HTTP call response.
pub trait ResponseBuffers {
    /// Status code of the response.
    fn status_code(&self) -> u32;
    /// Value of the response header `name`, if present.
    fn header(&self, name: &str) -> Option<String>;
    /// All response headers as `(name, value)` pairs.
    fn headers(&self) -> Vec<(String, String)>;
    /// Up to `max_size` bytes of the response body starting at offset `start`
    /// (`None` presumably means no body is available — confirm per implementor).
    fn body(&self, start: usize, max_size: usize) -> Option<Bytes>;
    /// All response trailers as `(name, value)` pairs.
    fn trailers(&self) -> Vec<(String, String)>;
}
impl ResponseBuffers for Rc<dyn Host> {
fn status_code(&self) -> u32 {
self.header(HEADER_STATUS)
.and_then(|status| status.parse::<u32>().ok())
.unwrap_or_default()
}
fn header(&self, name: &str) -> Option<String> {
self.get_http_call_response_header(name)
}
fn headers(&self) -> Vec<(String, String)> {
self.get_http_call_response_headers()
}
fn body(&self, start: usize, max_size: usize) -> Option<Bytes> {
self.get_http_call_response_body(start, max_size)
}
fn trailers(&self) -> Vec<(String, String)> {
self.get_http_call_response_trailers()
}
}
/// Converts an HTTP call response into a typed value.
///
/// `extract` consumes the extractor, so each instance produces at most one value.
pub trait ResponseExtractor {
    /// Value the request resolves to.
    type Output;
    /// Builds the output from the response metadata `event` and the response `buffers`.
    fn extract(self, event: &HttpCallResponse, buffers: &dyn ResponseBuffers) -> Self::Output;
}
/// A [`ResponseExtractor`] backed by a closure; build one with `FnResponseExtractor::from_fn`.
pub struct FnResponseExtractor<F> {
    // Closure invoked once with the response metadata and buffers.
    function: F,
}
impl<F, T> ResponseExtractor for FnResponseExtractor<F>
where
F: FnOnce(&HttpCallResponse, &dyn ResponseBuffers) -> T,
{
type Output = T;
fn extract(self, event: &HttpCallResponse, buffers: &dyn ResponseBuffers) -> Self::Output {
(self.function)(event, buffers)
}
}
impl<F, T> FnResponseExtractor<F>
where
    F: FnOnce(&HttpCallResponse, &dyn ResponseBuffers) -> T,
{
    /// Wraps `function` so it can be used as a [`ResponseExtractor`].
    ///
    /// The closure receives the response metadata and buffers; its return value becomes the
    /// request's output. The closure bound is enforced by the enclosing `impl`.
    pub fn from_fn(function: F) -> FnResponseExtractor<F> {
        FnResponseExtractor { function }
    }
}
/// Builder for an HTTP call to a [`Service`]; created by `HttpClient::request`.
///
/// Options left unset fall back to defaults when the request is sent.
pub struct RequestBuilder<'a, E> {
    client: &'a HttpClient,
    /// Converts the response into the request's output.
    extractor: E,
    service: &'a Service,
    /// Overrides the service URI's path (and query) for the `HEADER_PATH` header.
    path: Option<&'a str>,
    /// Caller headers; pseudo-headers and the user agent are appended at send time.
    headers: Option<Vec<(&'a str, &'a str)>>,
    body: Option<&'a [u8]>,
    trailers: Option<Vec<(&'a str, &'a str)>>,
    /// Falls back to `DEFAULT_TIMEOUT` when unset.
    timeout: Option<Duration>,
}
impl<'a, E> RequestBuilder<'a, E>
where
E: ResponseExtractor + 'static,
E::Output: 'static,
{
fn new(client: &'a HttpClient, service: &'a Service, extractor: E) -> Self {
Self {
client,
extractor,
service,
path: None,
headers: None,
body: None,
trailers: None,
timeout: None,
}
}
pub fn extractor<T>(self, extractor: T) -> RequestBuilder<'a, T>
where
T: ResponseExtractor,
{
RequestBuilder {
client: self.client,
extractor,
service: self.service,
path: self.path,
headers: self.headers,
body: self.body,
trailers: self.trailers,
timeout: self.timeout,
}
}
pub fn extract_with<F, T>(self, function: F) -> RequestBuilder<'a, FnResponseExtractor<F>>
where
F: FnOnce(&HttpCallResponse, &dyn ResponseBuffers) -> T,
{
self.extractor(FnResponseExtractor::from_fn(function))
}
pub fn path(mut self, path: &'a str) -> Self {
self.path = Some(path);
self
}
pub fn headers(mut self, headers: Vec<(&'a str, &'a str)>) -> Self {
self.headers = Some(headers);
self
}
pub fn body(mut self, body: &'a [u8]) -> Self {
self.body = Some(body);
self
}
pub fn trailers(mut self, trailers: Vec<(&'a str, &'a str)>) -> Self {
self.trailers = Some(trailers);
self
}
pub fn timeout(mut self, timeout: Duration) -> Self {
self.timeout = Some(timeout);
self
}
pub fn post(self) -> Request<E::Output> {
self.send(METHOD_POST)
}
pub fn put(self) -> Request<E::Output> {
self.send(METHOD_PUT)
}
pub fn get(self) -> Request<E::Output> {
self.send(METHOD_GET)
}
pub fn options(self) -> Request<E::Output> {
self.send(METHOD_OPTIONS)
}
pub fn delete(self) -> Request<E::Output> {
self.send(METHOD_DELETE)
}
#[must_use]
pub fn send(mut self, method: &str) -> Request<E::Output> {
let mut headers = self.headers.take().unwrap_or_default();
headers.push((HEADER_PATH, self.path.unwrap_or(self.service.uri().path())));
headers.push((HEADER_AUTHORITY, self.service.uri().authority()));
headers.push((HEADER_METHOD, method));
headers.push((USER_AGENT_HEADER, self.client.user_agent.value()));
headers.push((HEADER_SCHEME, self.service.uri().scheme()));
let body = self.body.take();
let trailers = self.trailers.take().unwrap_or_default();
let timeout = self.timeout.take().unwrap_or(DEFAULT_TIMEOUT);
match self.client.host.dispatch_http_call(
self.service.cluster_name(),
headers,
body,
trailers,
timeout,
) {
Ok(request_id) => {
let request_id: RequestId = request_id.into();
let extractor = boxed_extractor(self.client.host.clone(), self.extractor);
self.client.reactor.insert_extractor(request_id, extractor);
Request::new(self.client.reactor.clone(), request_id)
}
Err(err) => Request::error(self.client.reactor.clone(), HttpClientError::Status(err)),
}
}
}
impl<E: ResponseExtractor> ResponseExtractor for RequestBuilder<'_, E> {
type Output = E::Output;
fn extract(self, event: &HttpCallResponse, buffers: &dyn ResponseBuffers) -> Self::Output {
self.extractor.extract(event, buffers)
}
}
/// Type-erases `extractor` so the reactor can store it.
///
/// The returned closure, when given the response event, runs `extractor` against `buffers`
/// and boxes the output; `Request::poll` later downcasts that box back to `E::Output`.
fn boxed_extractor<E>(buffers: Rc<dyn Host>, extractor: E) -> BoxedExtractor
where
    E: ResponseExtractor + 'static,
    E::Output: 'static,
{
    Box::new(move |event| {
        let output = extractor.extract(event, &buffers);
        Box::new(output)
    })
}
/// Extractor that ignores the response; the request resolves to `()`.
pub struct EmptyResponseExtractor;

impl ResponseExtractor for EmptyResponseExtractor {
    type Output = ();

    fn extract(self, _: &HttpCallResponse, _: &dyn ResponseBuffers) {}
}
impl<T> Request<T> {
    /// Wraps a successfully dispatched call identified by `request_id`.
    fn new(reactor: Rc<RootReactor>, request_id: RequestId) -> Self {
        Self {
            reactor,
            request_id,
            cid_and_waker: None,
            error: None,
            _response_type: PhantomData,
        }
    }

    /// Wraps a call whose dispatch failed; it resolves to `error` and uses request id `0`.
    fn error(reactor: Rc<RootReactor>, error: HttpClientError) -> Self {
        // `Request` implements `Drop`, so struct-update syntax is unavailable; set the field.
        let mut failed = Self::new(reactor, RequestId::from(0));
        failed.error = Some(error);
        failed
    }

    /// Identifier of this call (`0` if dispatch failed).
    pub fn id(&self) -> RequestId {
        self.request_id
    }
}
/// Releases everything the reactor holds for this call: its extractor, any stored response,
/// and the registered waker.
impl<T> Drop for Request<T> {
    fn drop(&mut self) {
        // Failed dispatches never registered anything with the reactor.
        if self.error.is_some() {
            return;
        }
        let id = self.request_id;
        let reactor = &self.reactor;
        reactor.remove_extractor(id);
        reactor.remove_response(id);
        reactor.remove_client(id);
    }
}
/// Resolves to the extracted response once the reactor has stored it.
///
/// On the first `Pending` poll the request registers the task's waker with the reactor and
/// pauses the currently active context. If a later poll comes with a waker that would not
/// wake the same task, the registered waker is replaced. A request whose dispatch failed
/// resolves to that error on every poll.
impl<T: Unpin + 'static> Future for Request<T> {
    type Output = Result<T, HttpClientError>;

    fn poll(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> Poll<Self::Output> {
        // Cloned rather than taken: `Drop` checks `error` to skip reactor cleanup for
        // requests that were never registered.
        if let Some(error) = self.error.clone() {
            return Poll::Ready(Err(error));
        }
        if let Some((_event, content)) = self.reactor.remove_response(self.request_id) {
            let content = content.expect("response content should have been extracted");
            // The stored extractor boxed an `E::Output`, which is this request's `T`.
            let content = content.downcast().expect("downcasting");
            Poll::Ready(Ok(*content))
        } else {
            let this = &mut *self.as_mut();
            match this.cid_and_waker.as_ref() {
                None => {
                    let cid = this.reactor.active_cid();
                    this.reactor
                        .insert_client(this.request_id, cx.waker().clone());
                    // NOTE(review): the context is paused here; un-pausing is presumably done
                    // by the reactor once the response arrives — confirm.
                    this.reactor.set_paused(cid, true);
                    this.cid_and_waker = Some((cid, cx.waker().clone()));
                }
                Some((cid, waker)) if !waker.will_wake(cx.waker()) => {
                    // Polled with a different waker: swap the registration for the new one.
                    let _ = this
                        .reactor
                        .remove_client(this.request_id)
                        .expect("stored client waker");
                    this.reactor
                        .insert_client(this.request_id, cx.waker().clone());
                    this.cid_and_waker = Some((*cid, cx.waker().clone()));
                }
                Some(_) => {}
            }
            Poll::Pending
        }
    }
}
/// Default extractor: collects the response headers into a map and reads the whole body.
///
/// Headers sharing the exact same name are merged into one value joined with `,` in the
/// order the host reports them. The body is read as `body_size` bytes from offset 0; it is
/// empty when the host returns no body.
pub struct DefaultResponseExtractor;

impl ResponseExtractor for DefaultResponseExtractor {
    type Output = HttpClientResponse;

    fn extract(self, event: &HttpCallResponse, buffers: &dyn ResponseBuffers) -> Self::Output {
        let mut headers = HashMap::new();
        for (name, value) in buffers.headers() {
            match headers.entry(name) {
                Entry::Vacant(slot) => {
                    slot.insert(value);
                }
                Entry::Occupied(mut slot) => {
                    // Append in place instead of rebuilding the whole value with `format!`.
                    let merged: &mut String = slot.get_mut();
                    merged.push(',');
                    merged.push_str(&value);
                }
            }
        }
        let body = buffers.body(0, event.body_size).unwrap_or_default();
        HttpClientResponse::new(headers, body)
    }
}
/// Response produced by [`DefaultResponseExtractor`]: headers plus the full body.
#[derive(Debug)]
pub struct HttpClientResponse {
    headers: HashMap<String, String>,
    body: Bytes,
}

impl HttpClientResponse {
    /// Wraps already-collected headers and body.
    pub fn new(headers: HashMap<String, String>, body: Bytes) -> Self {
        HttpClientResponse { headers, body }
    }

    /// Parses the status pseudo-header; a missing or non-numeric value yields `0`.
    pub fn status_code(&self) -> u32 {
        match self.header(HEADER_STATUS) {
            Some(raw_status) => raw_status.parse().unwrap_or(0),
            None => 0,
        }
    }

    /// All response headers, keyed by the name reported by the host.
    pub fn headers(&self) -> &HashMap<String, String> {
        &self.headers
    }

    /// Looks up a header by name, ignoring ASCII case.
    ///
    /// If several keys differ only in case, which one is returned is unspecified
    /// (it follows the map's iteration order).
    pub fn header(&self, header: &str) -> Option<&String> {
        self.headers
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(header))
            .map(|(_, value)| value)
    }

    /// Raw response body.
    pub fn body(&self) -> &[u8] {
        &self.body
    }

    /// Body decoded as UTF-8, with invalid sequences replaced by U+FFFD.
    pub fn as_utf8_lossy(&self) -> String {
        String::from_utf8_lossy(&self.body).into_owned()
    }
}
/// Error returned when parsing a `Uri` fails.
pub struct InvalidUri(InvalidUriKind);

/// Reason a URI string was rejected.
enum InvalidUriKind {
    /// The underlying `http::Uri` parser rejected the input.
    Delegate(http::uri::InvalidUri),
    /// The URI parsed but has no authority.
    MissingAuthority,
    /// The scheme is missing or is not `http`, `https` or `h2`.
    InvalidSchema,
}

impl Display for InvalidUri {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let message = match &self.0 {
            InvalidUriKind::Delegate(inner) => return Display::fmt(inner, f),
            InvalidUriKind::MissingAuthority => "authority missing",
            InvalidUriKind::InvalidSchema => "scheme not supported",
        };
        // `pad` is what `Display for str` uses, so width/fill flags are honoured as before.
        f.pad(message)
    }
}

/// Debug output is identical to the `Display` message.
impl Debug for InvalidUri {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        Display::fmt(self, f)
    }
}

impl StdError for InvalidUri {}
/// Target URI of a [`Service`].
///
/// Values produced by `str::parse` always have an authority and an `http`, `https` or `h2`
/// scheme; `Uri::default()` wraps `http::Uri::default()` and carries no such guarantee.
#[derive(Clone, Debug, Default)]
pub struct Uri {
    delegate: http::Uri,
}

impl Uri {
    /// Path plus query string (e.g. `/foo?some=val`), used as the default request path.
    /// Falls back to `http::Uri::path` when the URI has no path-and-query component.
    pub fn path(&self) -> &str {
        self.delegate
            .path_and_query()
            .map(|path_and_query| path_and_query.as_str())
            .unwrap_or_else(|| self.delegate.path())
    }

    /// Scheme as text, or `""` when the URI has none.
    pub fn scheme(&self) -> &str {
        self.delegate.scheme_str().unwrap_or_default()
    }

    /// Authority (host and optional port), or `""` when the URI has none.
    pub fn authority(&self) -> &str {
        self.delegate
            .authority()
            .map(|authority| authority.as_str())
            .unwrap_or_default()
    }
}

impl Display for Uri {
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        // Format the delegate directly instead of allocating an intermediate `String`; a fresh
        // `{}` keeps the previous output exactly (caller width/fill flags are not applied).
        write!(f, "{}", self.delegate)
    }
}
impl FromStr for Uri {
type Err = InvalidUri;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.parse::<http::Uri>() {
Ok(delegate) => {
if delegate.authority().is_none() {
return Err(InvalidUri(InvalidUriKind::MissingAuthority));
}
if delegate
.scheme()
.map(|s| {
!s.eq(&Scheme::HTTP)
&& !s.eq(&Scheme::HTTPS)
&& !s.as_str().eq_ignore_ascii_case("h2")
})
.unwrap_or(true)
{
return Err(InvalidUri(InvalidUriKind::InvalidSchema));
}
Ok(Self { delegate })
}
Err(e) => Err(InvalidUri(InvalidUriKind::Delegate(e))),
}
}
}
/// An upstream target: the cluster calls are dispatched to, plus the URI whose path,
/// authority and scheme populate the request's pseudo-headers.
#[derive(Clone, Debug, Default)]
pub struct Service {
    cluster_name: String,
    uri: Uri,
}

impl Service {
    /// Builds a service whose cluster name is `{name}.{namespace}.svc`.
    pub fn from(name: &str, namespace: &str, uri: Uri) -> Service {
        Service {
            cluster_name: format!("{name}.{namespace}.svc"),
            uri,
        }
    }

    /// Builds a service targeting the cluster named exactly `cluster_name`.
    pub fn new(cluster_name: &str, uri: Uri) -> Service {
        Service {
            cluster_name: cluster_name.to_owned(),
            uri,
        }
    }

    /// Name of the cluster calls are dispatched to.
    pub fn cluster_name(&self) -> &str {
        self.cluster_name.as_str()
    }

    /// URI used to fill the path, authority and scheme of requests.
    pub fn uri(&self) -> &Uri {
        &self.uri
    }
}
#[cfg(test)]
mod test {
    use super::Uri;

    #[test]
    fn successfully_parse_http() {
        assert!("http://some.com/foo?some=val".parse::<Uri>().is_ok());
    }

    #[test]
    fn successfully_parse_https() {
        assert!("https://some.com/foo".parse::<Uri>().is_ok());
    }

    #[test]
    fn successfully_parse_h2() {
        assert!("h2://some.com/foo".parse::<Uri>().is_ok());
    }

    // The `h2` scheme is compared with `eq_ignore_ascii_case`, so upper case must be accepted.
    #[test]
    fn successfully_parse_uppercase_h2() {
        assert!("H2://some.com/foo".parse::<Uri>().is_ok());
    }

    #[test]
    fn error_invalid_scheme() {
        assert!("ftp://some.com/foo".parse::<Uri>().is_err());
    }

    #[test]
    fn error_on_missing_scheme() {
        assert!("some.com/foo".parse::<Uri>().is_err());
    }

    // Authority-form input has an authority but no scheme; it must still be rejected.
    #[test]
    fn error_on_authority_without_scheme() {
        assert!("some.com:8080".parse::<Uri>().is_err());
    }

    #[test]
    fn error_on_missing_host() {
        assert!("/foo".parse::<Uri>().is_err());
    }
}